diff --git a/msm/encoding.canoto.go b/msm/encoding.canoto.go index 30f82706..918ab9c8 100644 --- a/msm/encoding.canoto.go +++ b/msm/encoding.canoto.go @@ -28,12 +28,16 @@ const ( canotoNumber_StateMachineMetadata__SimplexBlacklist = 3 canotoNumber_StateMachineMetadata__PChainHeight = 4 canotoNumber_StateMachineMetadata__Timestamp = 5 + canotoNumber_StateMachineMetadata__AuxiliaryInfo = 6 + canotoNumber_StateMachineMetadata__ICMEpochInfo = 7 canotoTag_StateMachineMetadata__SimplexEpochInfo = "\x0a" // canoto.Tag(canotoNumber_StateMachineMetadata__SimplexEpochInfo, canoto.Len) canotoTag_StateMachineMetadata__SimplexProtocolMetadata = "\x12" // canoto.Tag(canotoNumber_StateMachineMetadata__SimplexProtocolMetadata, canoto.Len) canotoTag_StateMachineMetadata__SimplexBlacklist = "\x1a" // canoto.Tag(canotoNumber_StateMachineMetadata__SimplexBlacklist, canoto.Len) canotoTag_StateMachineMetadata__PChainHeight = "\x20" // canoto.Tag(canotoNumber_StateMachineMetadata__PChainHeight, canoto.Varint) canotoTag_StateMachineMetadata__Timestamp = "\x28" // canoto.Tag(canotoNumber_StateMachineMetadata__Timestamp, canoto.Varint) + canotoTag_StateMachineMetadata__AuxiliaryInfo = "\x32" // canoto.Tag(canotoNumber_StateMachineMetadata__AuxiliaryInfo, canoto.Len) + canotoTag_StateMachineMetadata__ICMEpochInfo = "\x3a" // canoto.Tag(canotoNumber_StateMachineMetadata__ICMEpochInfo, canoto.Len) ) type canotoData_StateMachineMetadata struct { @@ -81,6 +85,26 @@ func (*StateMachineMetadata) CanotoSpec(types ...reflect.Type) *canoto.Spec { OneOf: "", TypeUint: canoto.SizeOf(zero.Timestamp), }, + canoto.FieldTypeFromField( + /*type inference:*/ (zero.AuxiliaryInfo), + /*FieldNumber: */ canotoNumber_StateMachineMetadata__AuxiliaryInfo, + /*Name: */ "AuxiliaryInfo", + /*FixedLength: */ 0, + /*Repeated: */ false, + /*OneOf: */ "", + /*Pointer: */ true, + /*types: */ types, + ), + canoto.FieldTypeFromField( + /*type inference:*/ (&zero.ICMEpochInfo), + /*FieldNumber: */ canotoNumber_StateMachineMetadata__ICMEpochInfo, + /*Name: */ "ICMEpochInfo", + /*FixedLength: */ 0, + /*Repeated: */ false, + /*OneOf: */ "", + /*Pointer: */ false, + /*types: */ types, + ), }, } s.CalculateCanotoCache() @@ -187,6 +211,52 @@ func (c *StateMachineMetadata) UnmarshalCanotoFrom(r canoto.Reader) error { if canoto.IsZero(c.Timestamp) { return canoto.ErrZeroValue } + case canotoNumber_StateMachineMetadata__AuxiliaryInfo: + if wireType != canoto.Len { + return canoto.ErrUnexpectedWireType + } + + // Read the bytes for the field. + originalUnsafe := r.Unsafe + r.Unsafe = true + var msgBytes []byte + if err := canoto.ReadBytes(&r, &msgBytes); err != nil { + return err + } + r.Unsafe = originalUnsafe + + // Unmarshal the field from the bytes. + remainingBytes := r.B + r.B = msgBytes + c.AuxiliaryInfo = canoto.MakePointer(c.AuxiliaryInfo) + if err := (c.AuxiliaryInfo).UnmarshalCanotoFrom(r); err != nil { + return err + } + r.B = remainingBytes + case canotoNumber_StateMachineMetadata__ICMEpochInfo: + if wireType != canoto.Len { + return canoto.ErrUnexpectedWireType + } + + // Read the bytes for the field. + originalUnsafe := r.Unsafe + r.Unsafe = true + var msgBytes []byte + if err := canoto.ReadBytes(&r, &msgBytes); err != nil { + return err + } + if len(msgBytes) == 0 { + return canoto.ErrZeroValue + } + r.Unsafe = originalUnsafe + + // Unmarshal the field from the bytes. + remainingBytes := r.B + r.B = msgBytes + if err := (&c.ICMEpochInfo).UnmarshalCanotoFrom(r); err != nil { + return err + } + r.B = remainingBytes default: return canoto.ErrUnknownField } @@ -207,6 +277,12 @@ func (c *StateMachineMetadata) ValidCanoto() bool { if !(&c.SimplexEpochInfo).ValidCanoto() { return false } + if c.AuxiliaryInfo != nil && !(c.AuxiliaryInfo).ValidCanoto() { + return false + } + if !(&c.ICMEpochInfo).ValidCanoto() { + return false + } return true } @@ -232,6 +308,15 @@ func (c *StateMachineMetadata) CalculateCanotoCache() { if !canoto.IsZero(c.Timestamp) { size += uint64(len(canotoTag_StateMachineMetadata__Timestamp)) + canoto.SizeUint(c.Timestamp) } + if c.AuxiliaryInfo != nil { + (c.AuxiliaryInfo).CalculateCanotoCache() + fieldSize := (c.AuxiliaryInfo).CachedCanotoSize() + size += uint64(len(canotoTag_StateMachineMetadata__AuxiliaryInfo)) + canoto.SizeUint(fieldSize) + fieldSize + } + (&c.ICMEpochInfo).CalculateCanotoCache() + if fieldSize := (&c.ICMEpochInfo).CachedCanotoSize(); fieldSize != 0 { + size += uint64(len(canotoTag_StateMachineMetadata__ICMEpochInfo)) + canoto.SizeUint(fieldSize) + fieldSize + } atomic.StoreUint64(&c.canotoData.size, size) } @@ -291,6 +376,411 @@ func (c *StateMachineMetadata) MarshalCanotoInto(w canoto.Writer) canoto.Writer canoto.Append(&w, canotoTag_StateMachineMetadata__Timestamp) canoto.AppendUint(&w, c.Timestamp) } + if c.AuxiliaryInfo != nil { + fieldSize := (c.AuxiliaryInfo).CachedCanotoSize() + canoto.Append(&w, canotoTag_StateMachineMetadata__AuxiliaryInfo) + canoto.AppendUint(&w, fieldSize) + w = (c.AuxiliaryInfo).MarshalCanotoInto(w) + } + if fieldSize := (&c.ICMEpochInfo).CachedCanotoSize(); fieldSize != 0 { + canoto.Append(&w, canotoTag_StateMachineMetadata__ICMEpochInfo) + canoto.AppendUint(&w, fieldSize) + w = (&c.ICMEpochInfo).MarshalCanotoInto(w) + } + return w +} + +const ( + canotoNumber_ICMEpochInfo__EpochStartTime = 1 + canotoNumber_ICMEpochInfo__EpochNumber = 2 + canotoNumber_ICMEpochInfo__PChainEpochHeight = 3 + + canotoTag_ICMEpochInfo__EpochStartTime = "\x08" // canoto.Tag(canotoNumber_ICMEpochInfo__EpochStartTime, canoto.Varint) + canotoTag_ICMEpochInfo__EpochNumber = "\x10" // canoto.Tag(canotoNumber_ICMEpochInfo__EpochNumber, canoto.Varint) + canotoTag_ICMEpochInfo__PChainEpochHeight = "\x18" // canoto.Tag(canotoNumber_ICMEpochInfo__PChainEpochHeight, canoto.Varint) +) + +type canotoData_ICMEpochInfo struct { + size uint64 +} + +// CanotoSpec returns the specification of this canoto message. +func (*ICMEpochInfo) CanotoSpec(...reflect.Type) *canoto.Spec { + var zero ICMEpochInfo + s := &canoto.Spec{ + Name: "ICMEpochInfo", + Fields: []canoto.FieldType{ + { + FieldNumber: canotoNumber_ICMEpochInfo__EpochStartTime, + Name: "EpochStartTime", + OneOf: "", + TypeUint: canoto.SizeOf(zero.EpochStartTime), + }, + { + FieldNumber: canotoNumber_ICMEpochInfo__EpochNumber, + Name: "EpochNumber", + OneOf: "", + TypeUint: canoto.SizeOf(zero.EpochNumber), + }, + { + FieldNumber: canotoNumber_ICMEpochInfo__PChainEpochHeight, + Name: "PChainEpochHeight", + OneOf: "", + TypeUint: canoto.SizeOf(zero.PChainEpochHeight), + }, + }, + } + s.CalculateCanotoCache() + return s +} + +// UnmarshalCanoto unmarshals a Canoto-encoded byte slice into the struct. +// +// During parsing, the canoto cache is saved. +func (c *ICMEpochInfo) UnmarshalCanoto(bytes []byte) error { + r := canoto.Reader{ + B: bytes, + } + return c.UnmarshalCanotoFrom(r) +} + +// UnmarshalCanotoFrom populates the struct from a [canoto.Reader]. Most users +// should just use UnmarshalCanoto. +// +// During parsing, the canoto cache is saved. +// +// This function enables configuration of reader options. +func (c *ICMEpochInfo) UnmarshalCanotoFrom(r canoto.Reader) error { + // Zero the struct before unmarshaling. + *c = ICMEpochInfo{} + atomic.StoreUint64(&c.canotoData.size, uint64(len(r.B))) + + var minField uint32 + for canoto.HasNext(&r) { + field, wireType, err := canoto.ReadTag(&r) + if err != nil { + return err + } + if field < minField { + return canoto.ErrInvalidFieldOrder + } + + switch field { + case canotoNumber_ICMEpochInfo__EpochStartTime: + if wireType != canoto.Varint { + return canoto.ErrUnexpectedWireType + } + + if err := canoto.ReadUint(&r, &c.EpochStartTime); err != nil { + return err + } + if canoto.IsZero(c.EpochStartTime) { + return canoto.ErrZeroValue + } + case canotoNumber_ICMEpochInfo__EpochNumber: + if wireType != canoto.Varint { + return canoto.ErrUnexpectedWireType + } + + if err := canoto.ReadUint(&r, &c.EpochNumber); err != nil { + return err + } + if canoto.IsZero(c.EpochNumber) { + return canoto.ErrZeroValue + } + case canotoNumber_ICMEpochInfo__PChainEpochHeight: + if wireType != canoto.Varint { + return canoto.ErrUnexpectedWireType + } + + if err := canoto.ReadUint(&r, &c.PChainEpochHeight); err != nil { + return err + } + if canoto.IsZero(c.PChainEpochHeight) { + return canoto.ErrZeroValue + } + default: + return canoto.ErrUnknownField + } + + minField = field + 1 + } + return nil +} + +// ValidCanoto validates that the struct can be correctly marshaled into the +// Canoto format. +// +// Specifically, ValidCanoto ensures: +// 1. All OneOfs are specified at most once. +// 2. All strings are valid utf-8. +// 3. All custom fields are ValidCanoto. +func (c *ICMEpochInfo) ValidCanoto() bool { + return true +} + +// CalculateCanotoCache populates size and OneOf caches based on the current +// values in the struct. +// +// It is not safe to copy this struct concurrently. +func (c *ICMEpochInfo) CalculateCanotoCache() { + var size uint64 + if !canoto.IsZero(c.EpochStartTime) { + size += uint64(len(canotoTag_ICMEpochInfo__EpochStartTime)) + canoto.SizeUint(c.EpochStartTime) + } + if !canoto.IsZero(c.EpochNumber) { + size += uint64(len(canotoTag_ICMEpochInfo__EpochNumber)) + canoto.SizeUint(c.EpochNumber) + } + if !canoto.IsZero(c.PChainEpochHeight) { + size += uint64(len(canotoTag_ICMEpochInfo__PChainEpochHeight)) + canoto.SizeUint(c.PChainEpochHeight) + } + atomic.StoreUint64(&c.canotoData.size, size) +} + +// CachedCanotoSize returns the previously calculated size of the Canoto +// representation from CalculateCanotoCache. +// +// If CalculateCanotoCache has not yet been called, it will return 0. +// +// If the struct has been modified since the last call to CalculateCanotoCache, +// the returned size may be incorrect. +func (c *ICMEpochInfo) CachedCanotoSize() uint64 { + return atomic.LoadUint64(&c.canotoData.size) +} + +// MarshalCanoto returns the Canoto representation of this struct. +// +// It is assumed that this struct is ValidCanoto. +// +// It is not safe to copy this struct concurrently. +func (c *ICMEpochInfo) MarshalCanoto() []byte { + c.CalculateCanotoCache() + w := canoto.Writer{ + B: make([]byte, 0, c.CachedCanotoSize()), + } + w = c.MarshalCanotoInto(w) + return w.B +} + +// MarshalCanotoInto writes the struct into a [canoto.Writer] and returns the +// resulting [canoto.Writer]. Most users should just use MarshalCanoto. +// +// It is assumed that CalculateCanotoCache has been called since the last +// modification to this struct. +// +// It is assumed that this struct is ValidCanoto. +// +// It is not safe to copy this struct concurrently. +func (c *ICMEpochInfo) MarshalCanotoInto(w canoto.Writer) canoto.Writer { + if !canoto.IsZero(c.EpochStartTime) { + canoto.Append(&w, canotoTag_ICMEpochInfo__EpochStartTime) + canoto.AppendUint(&w, c.EpochStartTime) + } + if !canoto.IsZero(c.EpochNumber) { + canoto.Append(&w, canotoTag_ICMEpochInfo__EpochNumber) + canoto.AppendUint(&w, c.EpochNumber) + } + if !canoto.IsZero(c.PChainEpochHeight) { + canoto.Append(&w, canotoTag_ICMEpochInfo__PChainEpochHeight) + canoto.AppendUint(&w, c.PChainEpochHeight) + } + return w +} + +const ( + canotoNumber_AuxiliaryInfo__Info = 1 + canotoNumber_AuxiliaryInfo__PrevAuxInfoSeq = 2 + canotoNumber_AuxiliaryInfo__ApplicationID = 3 + + canotoTag_AuxiliaryInfo__Info = "\x0a" // canoto.Tag(canotoNumber_AuxiliaryInfo__Info, canoto.Len) + canotoTag_AuxiliaryInfo__PrevAuxInfoSeq = "\x10" // canoto.Tag(canotoNumber_AuxiliaryInfo__PrevAuxInfoSeq, canoto.Varint) + canotoTag_AuxiliaryInfo__ApplicationID = "\x18" // canoto.Tag(canotoNumber_AuxiliaryInfo__ApplicationID, canoto.Varint) +) + +type canotoData_AuxiliaryInfo struct { + size uint64 +} + +// CanotoSpec returns the specification of this canoto message. +func (*AuxiliaryInfo) CanotoSpec(...reflect.Type) *canoto.Spec { + var zero AuxiliaryInfo + s := &canoto.Spec{ + Name: "AuxiliaryInfo", + Fields: []canoto.FieldType{ + { + FieldNumber: canotoNumber_AuxiliaryInfo__Info, + Name: "Info", + OneOf: "", + TypeBytes: true, + }, + { + FieldNumber: canotoNumber_AuxiliaryInfo__PrevAuxInfoSeq, + Name: "PrevAuxInfoSeq", + OneOf: "", + TypeUint: canoto.SizeOf(zero.PrevAuxInfoSeq), + }, + { + FieldNumber: canotoNumber_AuxiliaryInfo__ApplicationID, + Name: "ApplicationID", + OneOf: "", + TypeUint: canoto.SizeOf(zero.ApplicationID), + }, + }, + } + s.CalculateCanotoCache() + return s +} + +// UnmarshalCanoto unmarshals a Canoto-encoded byte slice into the struct. +// +// During parsing, the canoto cache is saved. +func (c *AuxiliaryInfo) UnmarshalCanoto(bytes []byte) error { + r := canoto.Reader{ + B: bytes, + } + return c.UnmarshalCanotoFrom(r) +} + +// UnmarshalCanotoFrom populates the struct from a [canoto.Reader]. Most users +// should just use UnmarshalCanoto. +// +// During parsing, the canoto cache is saved. +// +// This function enables configuration of reader options. +func (c *AuxiliaryInfo) UnmarshalCanotoFrom(r canoto.Reader) error { + // Zero the struct before unmarshaling. + *c = AuxiliaryInfo{} + atomic.StoreUint64(&c.canotoData.size, uint64(len(r.B))) + + var minField uint32 + for canoto.HasNext(&r) { + field, wireType, err := canoto.ReadTag(&r) + if err != nil { + return err + } + if field < minField { + return canoto.ErrInvalidFieldOrder + } + + switch field { + case canotoNumber_AuxiliaryInfo__Info: + if wireType != canoto.Len { + return canoto.ErrUnexpectedWireType + } + + if err := canoto.ReadBytes(&r, &c.Info); err != nil { + return err + } + if len(c.Info) == 0 { + return canoto.ErrZeroValue + } + case canotoNumber_AuxiliaryInfo__PrevAuxInfoSeq: + if wireType != canoto.Varint { + return canoto.ErrUnexpectedWireType + } + + if err := canoto.ReadUint(&r, &c.PrevAuxInfoSeq); err != nil { + return err + } + if canoto.IsZero(c.PrevAuxInfoSeq) { + return canoto.ErrZeroValue + } + case canotoNumber_AuxiliaryInfo__ApplicationID: + if wireType != canoto.Varint { + return canoto.ErrUnexpectedWireType + } + + if err := canoto.ReadUint(&r, &c.ApplicationID); err != nil { + return err + } + if canoto.IsZero(c.ApplicationID) { + return canoto.ErrZeroValue + } + default: + return canoto.ErrUnknownField + } + + minField = field + 1 + } + return nil +} + +// ValidCanoto validates that the struct can be correctly marshaled into the +// Canoto format. +// +// Specifically, ValidCanoto ensures: +// 1. All OneOfs are specified at most once. +// 2. All strings are valid utf-8. +// 3. All custom fields are ValidCanoto. +func (c *AuxiliaryInfo) ValidCanoto() bool { + return true +} + +// CalculateCanotoCache populates size and OneOf caches based on the current +// values in the struct. +// +// It is not safe to copy this struct concurrently. +func (c *AuxiliaryInfo) CalculateCanotoCache() { + var size uint64 + if len(c.Info) != 0 { + size += uint64(len(canotoTag_AuxiliaryInfo__Info)) + canoto.SizeBytes(c.Info) + } + if !canoto.IsZero(c.PrevAuxInfoSeq) { + size += uint64(len(canotoTag_AuxiliaryInfo__PrevAuxInfoSeq)) + canoto.SizeUint(c.PrevAuxInfoSeq) + } + if !canoto.IsZero(c.ApplicationID) { + size += uint64(len(canotoTag_AuxiliaryInfo__ApplicationID)) + canoto.SizeUint(c.ApplicationID) + } + atomic.StoreUint64(&c.canotoData.size, size) +} + +// CachedCanotoSize returns the previously calculated size of the Canoto +// representation from CalculateCanotoCache. +// +// If CalculateCanotoCache has not yet been called, it will return 0. +// +// If the struct has been modified since the last call to CalculateCanotoCache, +// the returned size may be incorrect. +func (c *AuxiliaryInfo) CachedCanotoSize() uint64 { + return atomic.LoadUint64(&c.canotoData.size) +} + +// MarshalCanoto returns the Canoto representation of this struct. +// +// It is assumed that this struct is ValidCanoto. +// +// It is not safe to copy this struct concurrently. +func (c *AuxiliaryInfo) MarshalCanoto() []byte { + c.CalculateCanotoCache() + w := canoto.Writer{ + B: make([]byte, 0, c.CachedCanotoSize()), + } + w = c.MarshalCanotoInto(w) + return w.B +} + +// MarshalCanotoInto writes the struct into a [canoto.Writer] and returns the +// resulting [canoto.Writer]. Most users should just use MarshalCanoto. +// +// It is assumed that CalculateCanotoCache has been called since the last +// modification to this struct. +// +// It is assumed that this struct is ValidCanoto. +// +// It is not safe to copy this struct concurrently. +func (c *AuxiliaryInfo) MarshalCanotoInto(w canoto.Writer) canoto.Writer { + if len(c.Info) != 0 { + canoto.Append(&w, canotoTag_AuxiliaryInfo__Info) + canoto.AppendBytes(&w, c.Info) + } + if !canoto.IsZero(c.PrevAuxInfoSeq) { + canoto.Append(&w, canotoTag_AuxiliaryInfo__PrevAuxInfoSeq) + canoto.AppendUint(&w, c.PrevAuxInfoSeq) + } + if !canoto.IsZero(c.ApplicationID) { + canoto.Append(&w, canotoTag_AuxiliaryInfo__ApplicationID) + canoto.AppendUint(&w, c.ApplicationID) + } return w } diff --git a/msm/encoding.go b/msm/encoding.go index 5ed0162a..c6610eab 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -30,10 +30,69 @@ type StateMachineMetadata struct { PChainHeight uint64 `canoto:"uint,4"` // Timestamp is the time when the block is being built, in milliseconds since Unix epoch. Timestamp uint64 `canoto:"uint,5"` + // AuxiliaryInfo is application-specific information that the StateMachine doesn't need to understand, + // but can be used by applications that care about epoch changes, such as threshold distributed public key generation. + AuxiliaryInfo *AuxiliaryInfo `canoto:"pointer,6"` + // ICMEpochInfo is the metadata that the StateMachine uses for ICM epoching. + ICMEpochInfo ICMEpochInfo `canoto:"value,7"` canotoData canotoData_StateMachineMetadata } +// ICMEpochInfo is metadata used for the ICM protocol. +// The StateMachine maintains this metadata in a similar fashion to proposerVM. +type ICMEpochInfo struct { + EpochStartTime uint64 `canoto:"uint,1"` + EpochNumber uint64 `canoto:"uint,2"` + PChainEpochHeight uint64 `canoto:"uint,3"` + + canotoData canotoData_ICMEpochInfo +} + +func (ei *ICMEpochInfo) Equal(other *ICMEpochInfo) bool { + if ei == nil { + return other == nil + } + if other == nil { + return ei == nil + } + return ei.EpochStartTime == other.EpochStartTime && ei.EpochNumber == other.EpochNumber && ei.PChainEpochHeight == other.PChainEpochHeight +} + +// AppID is an identifier for applications that care about epoch changes. +type AppID uint32 + +// AuxiliaryInfo defines application-specific information for applications that might care about epoch change, +// such as threshold distributed public key generation. +type AuxiliaryInfo struct { + // Info is opaque bytes that can be used by applications to encode any information that describes + // the current state for the application. + Info []byte `canoto:"bytes,1"` + // PrevAuxInfoSeq is a sequence number that applications can use to find previous AuxiliaryInfo in the chain. + // It is zero if this is the first AuxiliaryInfo for this epoch. + PrevAuxInfoSeq uint64 `canoto:"uint,2"` + // ApplicationID is an identifier that identifies the application. + // Can be used for backward-compatibility and upgrade purposes. + ApplicationID AppID `canoto:"uint,3"` + + canotoData canotoData_AuxiliaryInfo +} + +func (ai *AuxiliaryInfo) IsZero() bool { + var zero AuxiliaryInfo + return ai.Equal(&zero) +} + +func (ai *AuxiliaryInfo) Equal(a *AuxiliaryInfo) bool { + if ai == nil { + return a == nil + } + if a == nil { + return ai == nil + } + return bytes.Equal(ai.Info, a.Info) && ai.PrevAuxInfoSeq == a.PrevAuxInfoSeq && ai.ApplicationID == a.ApplicationID +} + // SimplexEpochInfo is metadata used by the StateMachine. type SimplexEpochInfo struct { // PChainReferenceHeight is the P-Chain height that the StateMachine uses as a reference for the current epoch. diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index c5e03a24..7740a34c 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -45,7 +45,9 @@ func TestFakeNodeEpochChangesDespiteEmptyMempool(t *testing.T) { pChainHeight.Store(200) - for node.Epoch() == 1 { + firstEpoch := node.Epoch() + + for node.Epoch() == firstEpoch { node.buildAndNotarizeBlock() if node.canFinalize() { node.tryFinalizeNextBlock() @@ -229,6 +231,7 @@ type blockState struct { type fakeNode struct { t *testing.T + epoch uint64 sm *StateMachine mempoolEmpty bool // blocks holds notarized blocks in order. Finalized blocks always form a @@ -260,8 +263,9 @@ func newFakeNode(t *testing.T) *fakeNode { sm, _ := newStateMachine(t) fn := &fakeNode{ - t: t, - sm: sm, + t: t, + sm: sm, + epoch: 1, } fn.sm.BlockBuilder = fn @@ -293,17 +297,6 @@ func newFakeNode(t *testing.T) *fakeNode { return StateMachineBlock{}, nil, fmt.Errorf("block not found") } - fn.sm.FirstEverSimplexBlock = func() *StateMachineBlock { - for _, block := range fn.blocks { - if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { - continue - } - return &block.block - } - require.FailNow(t, "block not found") - return nil - } - return fn } @@ -378,6 +371,9 @@ func (fn *fakeNode) tryFinalizeNextBlock() { if block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil { fn.blocks = fn.blocks[:nextIndex+1] fn.t.Logf("Trimmed notarized blocks, new length: %d", len(fn.blocks)) + prevEpoch := fn.epoch + fn.epoch = md.Seq + fn.t.Logf("Epoch change from %d to %d", prevEpoch, fn.epoch) } } @@ -421,6 +417,7 @@ func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { block, err := fn.sm.BuildBlock(context.Background(), simplex.ProtocolMetadata{ Seq: lastMD.Seq + 1, Round: lastMD.Round + 1, + Epoch: fn.epoch, Prev: prevBlockDigest, }, nil) require.NoError(fn.t, err) @@ -439,7 +436,8 @@ func (fn *fakeNode) prepareMetadataAndPrevBlockDigest() (*simplex.ProtocolMetada require.NoError(fn.t, err) } else { lastMD = &simplex.ProtocolMetadata{ - Prev: lastBlockDigest, + Prev: lastBlockDigest, + Epoch: 1, } } return lastMD, lastBlockDigest diff --git a/msm/misc.go b/msm/misc.go index 62b1630c..c0f13239 100644 --- a/msm/misc.go +++ b/msm/misc.go @@ -5,6 +5,7 @@ package metadata import ( "context" + "errors" "fmt" "math" "math/big" @@ -15,9 +16,11 @@ import ( // but are not imported here to prevent us from importing the entire Avalanchego codebase. // Once we incorporate Simplex into Avalanchego, we can remove this file and import the relevant code from Avalanchego instead. +var errOverflow = errors.New("overflow") + func safeAdd(a, b uint64) (uint64, error) { if a > math.MaxUint64-b { - return 0, fmt.Errorf("overflow: %d + %d > maxuint64", a, b) + return 0, fmt.Errorf("%w: %d + %d > maxuint64", errOverflow, a, b) } return a + b, nil } @@ -45,11 +48,9 @@ type VMBlock interface { // // If nil is returned, it is guaranteed that either Accept or Reject will be // called on this block, unless the VM is shut down. - Verify(context.Context) error + Verify(ctx context.Context, pChainHeight uint64) error } -type UpgradeConfig = any - type bitmask big.Int func (bm *bitmask) Bytes() []byte { diff --git a/msm/misc_test.go b/msm/misc_test.go index b78d2cd3..60e31f61 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -4,19 +4,9 @@ package metadata import ( - "bytes" - "context" - "crypto/rand" - "crypto/sha256" - "encoding/asn1" - "fmt" - "maps" "math" "testing" - "time" - "github.com/ava-labs/simplex" - "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -25,7 +15,7 @@ func TestSafeAdd(t *testing.T) { name string a, b uint64 sum uint64 - err string + err error }{ { name: "zero plus zero", @@ -50,12 +40,12 @@ func TestSafeAdd(t *testing.T) { { name: "overflow by one", a: math.MaxUint64, b: 1, - err: "overflow", + err: errOverflow, }, { name: "overflow both large", a: math.MaxUint64 - 5, b: 10, - err: "overflow", + err: errOverflow, }, { name: "max uint64 boundary no overflow", @@ -65,8 +55,8 @@ func TestSafeAdd(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { result, err := safeAdd(tc.a, tc.b) - if tc.err != "" { - require.ErrorContains(t, err, tc.err) + if tc.err != nil { + require.ErrorIs(t, err, tc.err) } else { require.NoError(t, err) require.Equal(t, tc.sum, result) @@ -152,345 +142,3 @@ func TestBitmask(t *testing.T) { require.False(t, cloned.Contains(7)) }) } - -// Test helpers - -type InnerBlock struct { - TS time.Time - BlockHeight uint64 - Bytes []byte -} - -func (i *InnerBlock) Digest() [32]byte { - return sha256.Sum256(i.Bytes) -} - -func (i *InnerBlock) Height() uint64 { - return i.BlockHeight -} - -func (i *InnerBlock) Timestamp() time.Time { - return i.TS -} - -func (i *InnerBlock) Verify(_ context.Context) error { - return nil -} - -// fakeVMBlock is a minimal VMBlock implementation for tests. -type fakeVMBlock struct { - height uint64 -} - -func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } -func (f *fakeVMBlock) Height() uint64 { return f.height } -func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } -func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } - -type outerBlock struct { - finalization *simplex.Finalization - block StateMachineBlock -} - -type blockStore map[uint64]*outerBlock - -func (bs blockStore) clone() blockStore { - newStore := make(blockStore) - maps.Copy(newStore, bs) - return newStore -} - -func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, exits := bs[seq] - if !exits { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) - } - return blk.block, blk.finalization, nil -} - -type approvalsRetriever struct { - result ValidatorSetApprovals -} - -func (a approvalsRetriever) Approvals() ValidatorSetApprovals { - return a.result -} - -type signatureVerifier struct { - err error -} - -func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { - return sv.err -} - -type signatureAggregator struct { - weightByNodeID map[string]uint64 - totalWeight uint64 -} - -type aggregatrdSignature struct { - Signatures [][]byte -} - -func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { - all := make([][]byte, 0, len(sigs)+1) - all = append(all, sigs...) - if len(existing) > 0 { - all = append(all, existing) - } - return asn1.Marshal(aggregatrdSignature{Signatures: all}) -} - -func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { - var sum uint64 - for _, signer := range signers { - sum += sv.weightByNodeID[string(signer)] - } - return sum*3 > sv.totalWeight*2 -} - -func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { - return func(weights []simplex.Node) simplex.SignatureAggregator { - s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} - for _, nw := range weights { - s.weightByNodeID[string(nw.Node)] = nw.Weight - s.totalWeight += nw.Weight - } - return s - } -} - -type noOpPChainListener struct{} - -func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { - <-ctx.Done() - return ctx.Err() -} - -type blockBuilder struct { - block VMBlock - err error -} - -func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { - // Block is always ready in tests. -} - -func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { - return bb.block, bb.err -} - -type validatorSetRetriever struct { - result NodeBLSMappings - resultMap map[uint64]NodeBLSMappings - err error -} - -func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { - if vsr.resultMap != nil { - if result, ok := vsr.resultMap[height]; ok { - return result, vsr.err - } - } - return vsr.result, vsr.err -} - -type keyAggregator struct{} - -func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - aggregated := make([]byte, 0) - for _, key := range keys { - aggregated = append(aggregated, key...) - } - return aggregated, nil -} - -var ( - genesisBlock = StateMachineBlock{ - // Genesis block metadata has all zero values - InnerBlock: &InnerBlock{ - TS: time.Now(), - Bytes: []byte{1, 2, 3}, - }, - } -) - -type dynamicApprovalsRetriever struct { - approvals *ValidatorSetApprovals -} - -func (d *dynamicApprovalsRetriever) Approvals() ValidatorSetApprovals { - return *d.approvals -} - -func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { - startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) - blocks := make([]StateMachineBlock, 0, endHeight+1) - var round, seq uint64 - for h := uint64(0); h <= endHeight; h++ { - index := len(blocks) - - if h == 0 { - blocks = append(blocks, genesisBlock) - continue - } - - if h < simplexStartHeight { - blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) - continue - } - - seq = uint64(index) - - blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) - round++ - } - return blocks -} - -func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { - content := make([]byte, 10) - _, err := rand.Read(content) - require.NoError(t, err) - - prev := genesisBlock.Digest() - if index > 0 { - prev = blocks[index-1].Digest() - } - - return StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: start.Add(time.Duration(h) * time.Second), - BlockHeight: h, - Bytes: []byte{1, 2, 3}, - }, - Metadata: StateMachineMetadata{ - PChainHeight: 100, - SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ - Round: round, - Seq: seq, - Epoch: 1, - Prev: prev, - }).Bytes(), - SimplexEpochInfo: SimplexEpochInfo{ - PrevSealingBlockHash: [32]byte{}, - PChainReferenceHeight: 100, - EpochNumber: 1, - PrevVMBlockSeq: uint64(index), - }, - }, - } -} - -func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { - content := make([]byte, 10) - _, err := rand.Read(content) - require.NoError(t, err) - - return StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: start.Add(time.Duration(h-startHeight) * time.Second), - BlockHeight: h, - Bytes: []byte{1, 2, 3}, - }, - } -} - -type testConfig struct { - blockStore blockStore - approvalsRetriever approvalsRetriever - signatureVerifier signatureVerifier - signatureAggregator signatureAggregator - blockBuilder blockBuilder - keyAggregator keyAggregator - validatorSetRetriever validatorSetRetriever -} - -func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { - bs := make(blockStore) - bs[0] = &outerBlock{block: genesisBlock} - - var testConfig testConfig - testConfig.blockStore = bs - testConfig.validatorSetRetriever.result = NodeBLSMappings{ - {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, - } - - smConfig := Config{ - GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, - LastNonSimplexBlockPChainHeight: 100, - FirstEverSimplexBlock: func() *StateMachineBlock { - var res *StateMachineBlock - min := uint64(math.MaxUint64) - for seq, block := range testConfig.blockStore { - if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { - continue - } - if seq < min { - min = seq - res = &block.block - } - } - return res - }, - GetTime: time.Now, - TimeSkewLimit: time.Second * 5, - Logger: testutil.MakeLogger(t), - GetBlock: testConfig.blockStore.getBlock, - MaxBlockBuildingWaitTime: time.Second, - ApprovalsRetriever: &testConfig.approvalsRetriever, - SignatureVerifier: &testConfig.signatureVerifier, - SignatureAggregatorCreator: newSignatureAggregatorCreator(), - BlockBuilder: &testConfig.blockBuilder, - KeyAggregator: &testConfig.keyAggregator, - GetPChainHeight: func() uint64 { - return 100 - }, - GetUpgrades: func() any { - return nil - }, - GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, - PChainProgressListener: &noOpPChainListener{}, - LastNonSimplexInnerBlock: genesisBlock.InnerBlock, - } - - sm, err := NewStateMachine(&smConfig) - require.NoError(t, err) - - return sm, &testConfig -} - -// concatAggregator concatenates signatures for easy verification in tests. -type concatAggregator struct{} - -func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { - result := bytes.Join(sigs, nil) - return append(result, existing...), nil -} - -func (concatAggregator) IsQuorum([]simplex.NodeID) bool { - return false -} - -type failingAggregator struct{} - -func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { - return nil, fmt.Errorf("aggregation failed") -} - -func (failingAggregator) IsQuorum([]simplex.NodeID) bool { - return false -} diff --git a/msm/msm.go b/msm/msm.go index 2b9c548b..8bb51cd9 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -6,13 +6,53 @@ package metadata import ( "context" "crypto/sha256" + "encoding/asn1" + "encoding/binary" + "errors" "fmt" + "math" "time" "github.com/ava-labs/simplex" "go.uber.org/zap" ) +const ( + maxSkew = 10 * time.Second +) + +var ( + errLastNonSimplexInnerBlockNil = errors.New("failed constructing zero block: last non-Simplex inner block is nil") + errInvalidProtocolMetadataSeq = errors.New("invalid ProtocolMetadata sequence number: should be > 0") + errInvalidProtocolMetadataEpoch = errors.New("invalid ProtocolMetadata epoch number") + errUnknownState = errors.New("unknown state") + errBuiltGenesisInnerBlock = errors.New("received a genesis block") + errZeroBlockParentNoInnerBlock = errors.New("zero block's parent has no inner block") + errNilBlock = errors.New("block is nil") + errInvalidPChainHeight = errors.New("invalid P-chain height") + errInvalidSimplexEpochInfo = errors.New("invalid SimplexEpochInfo") + errZeroBlockHasInnerBlock = errors.New("zero block must not have an inner block") + errZeroBlockInnerDigestMismatch = errors.New("zero block inner block digest does not match last non-Simplex inner block digest") + errZeroBlockTimestampMismatch = errors.New("zero block timestamp does not match last non-Simplex inner block timestamp") + errPrevSealingBlockNotFinalized = errors.New("previous sealing InnerBlock is not finalized") + errBlockDigestMismatch = errors.New("does not match proposed block digest") + errSealingBlockSeqUnset = errors.New("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + errEmptyNextEpochApprovals = errors.New("next epoch approvals are empty") + errPChainReferenceHeightMismatch = errors.New("unexpected P-chain reference height") + errPChainReferenceHeightDecreased = errors.New("P-chain reference height is decreasing") + errValidatorSetUnchanged = errors.New("validator set unchanged; next P-chain reference height should not have advanced") + errPChainHeightNotReached = errors.New("haven't reached referenced P-chain height yet") + errPChainHeightTooBig = errors.New("invalid P-chain height: greater than current") + errPChainHeightSmallerThanParent = errors.New("invalid P-chain height: smaller than parent block's") + errSignerSetShrunk = errors.New("some signers from parent block are missing from next epoch approvals of proposed block") + errNextEpochApprovalsShrunk = errors.New("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") + errTimestampTooBig = errors.New("invalid timestamp: exceeds maximum int64 value") + errTimestampDecreasing = errors.New("invalid timestamp: proposed timestamp is before parent block's timestamp") + errTimestampTooFarInFuture = errors.New("invalid timestamp: proposed timestamp is too far in the future compared to current time") + + signatureContext = "MSM approval" +) + // A StateMachineBlock is a representation of a parsed OuterBlock, containing the inner block and the metadata. type StateMachineBlock struct { // InnerBlock is the VM-level block, or nil if this is a block without an inner block (e.g., a Telock block). @@ -36,6 +76,41 @@ func (smb *StateMachineBlock) Digest() [32]byte { return sha256.Sum256(combined) } +// ICMEpoch defines the ICM epoch information that is maintained by the StateMachine and used for the ICM protocol. +// The Statemachine maintains this information identically to how the proposerVM maintains this information, +// and it does so by building the ICMEpochInput and then passing it into the StateMachine's ComputeICMEpoch function. +type ICMEpoch struct { + // EpochStartTime is the Unix timestamp when this ICM epoch started. + EpochStartTime uint64 + // EpochNumber is the sequential identifier of this ICM epoch. + EpochNumber uint64 + // PChainEpochHeight is the P-chain height associated with this ICM epoch. + PChainEpochHeight uint64 +} + +func (icme ICMEpoch) ToICMEpochInfo() ICMEpochInfo { + return ICMEpochInfo{ + PChainEpochHeight: icme.PChainEpochHeight, + EpochNumber: icme.EpochNumber, + EpochStartTime: icme.EpochStartTime, + } +} + +// ICMEpochInput defines the input for computing the ICM Epoch information for the next block. +type ICMEpochInput struct { + // ParentPChainHeight is the P-chain height recorded in the parent block. + ParentPChainHeight uint64 + // ParentTimestamp is the timestamp of the parent block. + ParentTimestamp time.Time + // ChildTimestamp is the timestamp of the block being built. + ChildTimestamp time.Time + // ParentEpoch is the ICM epoch information from the parent block. + ParentEpoch ICMEpoch +} + +// ICMEpochTransition computes the next ICM epoch given the current upgrade configuration and epoch input. +type ICMEpochTransition func(ICMEpochInput) ICMEpoch + // ApprovalsRetriever retrieves the approvals from validators of the next epoch for the epoch change. type ApprovalsRetriever interface { Approvals() ValidatorSetApprovals @@ -69,26 +144,8 @@ type BlockBuilder interface { WaitForPendingBlock(ctx context.Context) } -type verificationInput struct { - prevMD StateMachineMetadata - proposedBlockMD StateMachineMetadata - hasInnerBlock bool - innerBlockTimestamp time.Time // only set when hasInnerBlock is true - prevBlockSeq uint64 - nextBlockType BlockType - state state -} - -type verifier interface { - Verify(in verificationInput) error -} - // StateMachine manages block building and verification across epoch transitions. type StateMachine struct { - // verifiers is the list of verifiers used to verify proposed blocks. - // Each verifier is responsible for verifying a specific aspect of the block's metadata. - verifiers []verifier - *Config } @@ -105,8 +162,6 @@ type Config struct { GetTime func() time.Time // GetPChainHeight returns the latest known P-chain height. GetPChainHeight func() uint64 - // GetUpgrades returns the current upgrade configuration. - GetUpgrades func() UpgradeConfig // BlockBuilder builds new VM blocks. BlockBuilder BlockBuilder // Logger is used for logging state machine operations. @@ -125,8 +180,6 @@ type Config struct { SignatureVerifier SignatureVerifier // PChainProgressListener listens for changes in the P-chain height to trigger block building or epoch transitions. PChainProgressListener PChainProgressListener - // FirstEverSimplexBlock is the first block ever built by Simplex, or nil if Simplex has yet to build a block. - FirstEverSimplexBlock func() *StateMachineBlock // LastNonSimplexBlockPChainHeight is the P-chain height of the last block built by a non-Simplex proposer. // It is used to determine the validator set of the first ever Simplex epoch. LastNonSimplexBlockPChainHeight uint64 @@ -134,6 +187,8 @@ type Config struct { LastNonSimplexInnerBlock VMBlock // GenesisValidatorSet is the validator set used for the genesis block. GenesisValidatorSet NodeBLSMappings + // ComputeICMEpoch computes the ICM epoch information in order to know which P-chain height to encode. + ComputeICMEpoch ICMEpochTransition } type state uint8 @@ -148,7 +203,7 @@ const ( func NewStateMachine(config *Config) (*StateMachine, error) { if config.LastNonSimplexInnerBlock == nil { config.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") - return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") + return nil, errLastNonSimplexInnerBlockNil } sm := StateMachine{Config: config} return &sm, nil @@ -158,7 +213,7 @@ func NewStateMachine(config *Config) (*StateMachine, error) { func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.ProtocolMetadata, blacklist *simplex.Blacklist) (*StateMachineBlock, error) { // The zero sequence number is reserved for the genesis block, which should never be built. if metadata.Seq == 0 { - return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", metadata.Seq) + return nil, fmt.Errorf("%w: got %d", errInvalidProtocolMetadataSeq, metadata.Seq) } prevBlockSeq := metadata.Seq - 1 @@ -206,7 +261,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco case stateBuildBlockEpochSealed: return sm.buildBlockEpochSealed(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) default: - return nil, fmt.Errorf("unknown state %d", currentState) + return nil, fmt.Errorf("%w: %d", errUnknownState, currentState) } } @@ -214,7 +269,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco // and inner block against the previous block and the current state. func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { if block == nil { - return fmt.Errorf("InnerBlock is nil") + return errNilBlock } pmd, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) @@ -225,7 +280,9 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc seq := pmd.Seq if seq == 0 { - return fmt.Errorf("attempted to build a genesis inner block") + // This shouldn't happen, but in case we're asked to verify a block with a sequence of 0, + // we should reject it, because the zero sequence number is reserved for the genesis block, which should never be proposed. + return errBuiltGenesisInnerBlock } prevBlock, _, err := sm.GetBlock(seq-1, pmd.Prev) @@ -240,49 +297,70 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc case stateFirstSimplexBlock: err = sm.verifyBlockZero(block, prevBlock) default: - err = sm.verifyNonZeroBlock(ctx, block, prevBlock.Metadata, currentState, seq-1) + err = sm.verifyNonZeroBlock(ctx, block, &prevBlock, seq-1) } return err } -func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMachineBlock, prevBlockMD StateMachineMetadata, state state, prevSeq uint64) error { - blockType := IdentifyBlockType(block.Metadata, prevBlockMD, prevSeq) - sm.Logger.Debug("Identified block type", - zap.Stringer("blockType", blockType), - zap.Bool("nextHasBVD", block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil), - zap.Uint64("nextEpochNumber", block.Metadata.SimplexEpochInfo.EpochNumber), - zap.Bool("prevHasBVD", prevBlockMD.SimplexEpochInfo.BlockValidationDescriptor != nil), - zap.Uint64("prevEpochNumber", prevBlockMD.SimplexEpochInfo.EpochNumber), - zap.Uint64("prevNextPChainRefHeight", prevBlockMD.SimplexEpochInfo.NextPChainReferenceHeight), - zap.Uint64("prevSealingBlockSeq", prevBlockMD.SimplexEpochInfo.SealingBlockSeq), - zap.Uint64("prevSeq", prevSeq), - ) - - var innerBlockTimestamp time.Time - if block.InnerBlock != nil { - innerBlockTimestamp = block.InnerBlock.Timestamp() - } - - for _, verifier := range sm.verifiers { - if err := verifier.Verify(verificationInput{ - proposedBlockMD: block.Metadata, - nextBlockType: blockType, - prevMD: prevBlockMD, - state: state, - prevBlockSeq: prevSeq, - hasInnerBlock: block.InnerBlock != nil, - innerBlockTimestamp: innerBlockTimestamp, - }); err != nil { - sm.Logger.Debug("Invalid block", zap.Error(err)) - return err - } +func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block, prevBlock *StateMachineBlock, prevSeq uint64) error { + prevBlockMD := prevBlock.Metadata + currentState := prevBlockMD.SimplexEpochInfo.NextState() + + if err := verifyTimestamp(block, prevBlock, sm.GetTime()); err != nil { + return fmt.Errorf("failed to verify timestamp: %w", err) } - if block.InnerBlock == nil { - return nil + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := prevBlockMD.PChainHeight + proposedPChainHeight := block.Metadata.PChainHeight + + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + err := sm.verifyEpochNumber(block) + if err != nil { + return err + } + + switch currentState { + case stateBuildBlockNormalOp: + return sm.verifyNormalBlock(ctx, *prevBlock, block, prevSeq) + case stateBuildCollectingApprovals: + return sm.verifyCollectingApprovalsBlock(ctx, *prevBlock, block, prevSeq) + case stateBuildBlockEpochSealed: + return sm.verifyBlockEpochSealed(ctx, *prevBlock, block, prevSeq) + default: + return fmt.Errorf("%w: %d", errUnknownState, currentState) + } +} + +func verifyTimestamp(block *StateMachineBlock, prevBlock *StateMachineBlock, now time.Time) error { + if block.Metadata.Timestamp > math.MaxInt64 { + return fmt.Errorf("%w: timestamp %d exceeds maximum int64 value", errTimestampTooBig, block.Metadata.Timestamp) + } + + if block.Metadata.Timestamp < prevBlock.Metadata.Timestamp { + return fmt.Errorf("%w: proposed %d < parent %d", errTimestampDecreasing, block.Metadata.Timestamp, prevBlock.Metadata.Timestamp) } - return block.InnerBlock.Verify(ctx) + proposedTime := time.UnixMilli(int64(block.Metadata.Timestamp)) + + if now.Add(maxSkew).Before(proposedTime) { + return fmt.Errorf("%w: proposed timestamp %v, max skew: %v", errTimestampTooFarInFuture, proposedTime, maxSkew) + } + return nil +} + +func (sm *StateMachine) verifyEpochNumber(block *StateMachineBlock) error { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + if err != nil { + return fmt.Errorf("failed to parse ProtocolMetadata: %w", err) + } + if md.Epoch != block.Metadata.SimplexEpochInfo.EpochNumber { + return fmt.Errorf("%w: got %d, expected %d", errInvalidProtocolMetadataEpoch, md.Epoch, block.Metadata.SimplexEpochInfo.EpochNumber) + } + return nil } // buildBlockNormalOp builds a block while potentially also transitioning to a new epoch, depending on the P-chain. @@ -300,6 +378,19 @@ func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock Stat // buildBlockOrTransitionEpoch builds a block and decides whether to transition to a new epoch based on the P-chain height and validator set changes. func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, newSimplexEpochInfo SimplexEpochInfo) (*StateMachineBlock, error) { + var isSealingBlockFinalized bool + sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.EpochNumber + _, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return nil, fmt.Errorf("failed to retrieve sealing block for previous epoch (%d): %w", sealingBlockSeq, err) + } + if finalization != nil { + isSealingBlockFinalized = true + } else { + sm.Logger.Debug("Previous sealing block not finalized yet, "+ + "building normal block without epoch transition", zap.Uint64("sealingBlockSeq", sealingBlockSeq)) + } + blockBuildingDecider := sm.createBlockBuildingDecider(newSimplexEpochInfo.PChainReferenceHeight) decisionToBuildBlock, err := blockBuildingDecider.shouldBuildBlock(ctx) if err != nil { @@ -311,22 +402,207 @@ func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentB zap.Bool("transition epoch", decisionToBuildBlock.transitionEpoch), zap.Uint64("P-chain height", decisionToBuildBlock.pChainHeight)) - if decisionToBuildBlock.transitionEpoch { + if decisionToBuildBlock.transitionEpoch && isSealingBlockFinalized { sm.Logger.Debug("Transitioning epoch after building block", zap.Uint64("newPChainRefHeight", decisionToBuildBlock.pChainHeight)) newSimplexEpochInfo.NextPChainReferenceHeight = decisionToBuildBlock.pChainHeight } + now := sm.GetTime() + icmEpochInfo := computeICMEpochInfo(parentBlock, sm.ComputeICMEpoch, now) + var innerBlock VMBlock if decisionToBuildBlock.buildInnerBlock { - // TODO: This P-chain height should be taken from the ICM epoch - innerBlock, err = sm.BlockBuilder.BuildBlock(ctx, decisionToBuildBlock.pChainHeight) + innerBlock, err = sm.BlockBuilder.BuildBlock(ctx, icmEpochInfo.PChainEpochHeight) if err != nil { return nil, err } } - return sm.wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil + return wrapBlock(innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist, now, icmEpochInfo.ToICMEpochInfo()), nil +} + +func computeICMEpochInfo(parentBlock StateMachineBlock, computeICMEpoch ICMEpochTransition, childTimestamp time.Time) ICMEpoch { + parentTimestamp := time.UnixMilli(int64(parentBlock.Metadata.Timestamp)) + + icmEpochInfo := computeICMEpoch(ICMEpochInput{ + ParentPChainHeight: parentBlock.Metadata.PChainHeight, + ParentEpoch: ICMEpoch{ + PChainEpochHeight: parentBlock.Metadata.ICMEpochInfo.PChainEpochHeight, + EpochNumber: parentBlock.Metadata.ICMEpochInfo.EpochNumber, + EpochStartTime: parentBlock.Metadata.ICMEpochInfo.EpochStartTime, + }, + ChildTimestamp: childTimestamp, + ParentTimestamp: parentTimestamp, + }) + return icmEpochInfo +} + +func verifyAgainstExpected( + ctx context.Context, + innerBlock VMBlock, + expectedSimplexEpochInfo SimplexEpochInfo, + expectedPChainHeight uint64, + nextBlock *StateMachineBlock, + timestamp time.Time, + expectedIcmEpochInfo ICMEpochInfo, +) error { + if innerBlock != nil { + if err := innerBlock.Verify(ctx, expectedIcmEpochInfo.PChainEpochHeight); err != nil { + return err + } + } + expectedBlock := wrapBlock( + innerBlock, expectedSimplexEpochInfo, expectedPChainHeight, + nextBlock.Metadata.SimplexProtocolMetadata, nextBlock.Metadata.SimplexBlacklist, timestamp, expectedIcmEpochInfo) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s: %w", + expectedBlock.Digest(), + nextBlock.Digest(), + errBlockDigestMismatch) + } + return nil +} + +func (sm *StateMachine) verifyNormalBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + proposedPChainHeight := nextBlock.Metadata.PChainHeight + + timestamp := time.UnixMilli(int64(nextBlock.Metadata.Timestamp)) + + icmEpochInfo := computeICMEpochInfo(parentBlock, sm.ComputeICMEpoch, timestamp) + + if err := sm.verifyNextPChainRefHeightNormal(parentBlock.Metadata, nextBlock.Metadata.SimplexEpochInfo); err != nil { + return fmt.Errorf("failed to verify next P-chain reference height for normal block: %w", err) + } + newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight + + return verifyAgainstExpected(ctx, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock, timestamp, icmEpochInfo.ToICMEpochInfo()) +} + +func verifyPChainHeight(proposedPChainHeight uint64, currentPChainHeight uint64, prevPChainHeight uint64) error { + if proposedPChainHeight > currentPChainHeight { + return fmt.Errorf("%w: proposed %d, current %d", + errPChainHeightTooBig, proposedPChainHeight, currentPChainHeight) + } + + if prevPChainHeight > proposedPChainHeight { + return fmt.Errorf("%w: proposed %d, parent %d", + errPChainHeightSmallerThanParent, proposedPChainHeight, prevPChainHeight) + } + return nil +} + +func (sm *StateMachine) verifyNextPChainRefHeightNormal(prevMD StateMachineMetadata, next SimplexEpochInfo) error { + prev := prevMD.SimplexEpochInfo + // Next P-chain height can only increase, not decrease. + if next.NextPChainReferenceHeight > 0 && prev.PChainReferenceHeight > next.NextPChainReferenceHeight { + return fmt.Errorf("%w: previous P-chain reference height is %d and the proposed P-chain reference height is %d", errPChainReferenceHeightDecreased, prev.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // If the previous block already has a next P-chain reference height, + // we should keep the same next P-chain reference height until we reach it. + if prev.NextPChainReferenceHeight > 0 { + if next.NextPChainReferenceHeight != prev.NextPChainReferenceHeight { + return fmt.Errorf("%w: expected %d but got %d", errPChainReferenceHeightMismatch, prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) + } + return nil + } + + // If we reached here, then prev.NextPChainReferenceHeight == 0. + // If the previous block's next P-chain reference height is 0, and the new block's next P-chain reference height is > 0, + // we need to ensure that we have finalized the sealing block of the previous epoch. + if next.NextPChainReferenceHeight > 0 { + sealingBlockSeq := prev.EpochNumber + _, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return fmt.Errorf("failed to retrieve sealing block for previous epoch (%d): %w", sealingBlockSeq, err) + } + if finalization == nil { + return fmt.Errorf("%w: sealing block sequence %d", errPrevSealingBlockNotFinalized, sealingBlockSeq) + } + } + + // Make sure we have reached the next P-chain reference height, otherwise we won't be able to validate it. + pChainHeight := sm.GetPChainHeight() + + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + + // It might be that this block is the first block that has set the next P-chain reference height for the epoch, + // so check if it has done so correctly by observing whether the validator set has indeed changed. + + currentValidatorSet, err := sm.GetValidatorSet(prevMD.SimplexEpochInfo.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := sm.GetValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + // If the validator set doesn't change, we shouldn't have increased the next P-chain reference height. + if currentValidatorSet.Equal(newValidatorSet) && next.NextPChainReferenceHeight > 0 { + return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches previous block's P-chain reference height %d", + errValidatorSetUnchanged, next.NextPChainReferenceHeight, prev.PChainReferenceHeight) + } + + // Else, either the validator set has changed, or the next P-chain reference height is still 0. + // Both of these cases are fine. + + return nil +} + +// verifyNextPChainRefHeightForNewEpoch validates the proposed NextPChainReferenceHeight on the +// first block of a new epoch. +// This handles a corner case where the first block of an epoch initiates an epoch transition. +// We cannot reuse verifyNextPChainRefHeightNormal here — the baseline +// for the validator-set change check is the new epoch's PChainReferenceHeight, not the parent's, +// as in verifyNextPChainRefHeightNormal. +func (sm *StateMachine) verifyNextPChainRefHeightForNewEpoch(expectedEpochInfo SimplexEpochInfo, next SimplexEpochInfo) error { + // The first block of the epoch doesn't trigger an epoch change, we're all set. + if next.NextPChainReferenceHeight == 0 { + return nil + } + + // Next P-chain reference height cannot be smaller than the P-chain reference height, + // as the P-chain reference height itself cannot decrease, and the next P-chain reference height + // becomes the P-chain reference height when we change epochs. + if next.NextPChainReferenceHeight < expectedEpochInfo.PChainReferenceHeight { + return fmt.Errorf("%w: new epoch P-chain reference height is %d and the proposed next P-chain reference height is %d", + errPChainReferenceHeightDecreased, expectedEpochInfo.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // If we haven't reached this P-chain height yet, we cannot accept the next P-chain reference height, + // because there is no way of querying the validator set for the next P-chain reference height. + pChainHeight := sm.GetPChainHeight() + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + + currentValidatorSet, err := sm.GetValidatorSet(expectedEpochInfo.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := sm.GetValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + if currentValidatorSet.Equal(newValidatorSet) { + return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches new epoch's P-chain reference height %d", + errValidatorSetUnchanged, next.NextPChainReferenceHeight, expectedEpochInfo.PChainReferenceHeight) + } + + return nil } func (sm *StateMachine) createBlockBuildingDecider(pChainReferenceHeight uint64) blockBuildingDecider { @@ -389,9 +665,11 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet // We can only have blocks without inner blocks in Simplex blocks, but this is the first Simplex block. // Therefore, the parent block must have an inner block. sm.Logger.Error("Parent block has no inner block, cannot determine previous VM block sequence for zero block") - return nil, fmt.Errorf("failed constructing zero block: parent block has no inner block") + return nil, errZeroBlockParentNoInnerBlock } + // For the zero block, we set the timestamp to be the same as the last non-Simplex inner block's timestamp. + // We do it because we need to carry over a minimum timestamp from the non-Simplex blocks. timestamp := sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli() simplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, validatorSet, prevVMBlockSeq) @@ -415,29 +693,25 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock StateMachineBlock) error { if block == nil { - return fmt.Errorf("block is nil") + return errNilBlock } simplexEpochInfo := block.Metadata.SimplexEpochInfo - if simplexEpochInfo.EpochNumber != 1 { - return fmt.Errorf("invalid epoch number (%d), should be 1", simplexEpochInfo.EpochNumber) - } - if prevBlock.InnerBlock == nil { - return fmt.Errorf("parent inner block (%s) has no inner block", prevBlock.Digest()) + return fmt.Errorf("%w: parent digest %s", errZeroBlockParentNoInnerBlock, prevBlock.Digest()) } pChainHeight := sm.LastNonSimplexBlockPChainHeight prevVMBlockSeq := prevBlock.InnerBlock.Height() if block.Metadata.PChainHeight != pChainHeight { - return fmt.Errorf("invalid P-chain height (%d), expected to be %d", - block.Metadata.PChainHeight, pChainHeight) + return fmt.Errorf("%w: got %d, expected %d", + errInvalidPChainHeight, block.Metadata.PChainHeight, pChainHeight) } var expectedValidatorSet NodeBLSMappings - if prevBlock.InnerBlock.Height() == 0 { + if prevVMBlockSeq == 0 { expectedValidatorSet = sm.GenesisValidatorSet } else { var err error @@ -447,40 +721,156 @@ func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock Stat } } - if simplexEpochInfo.BlockValidationDescriptor == nil { - return fmt.Errorf("invalid BlockValidationDescriptor: should not be nil") - } - - membership := simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members - if !NodeBLSMappings(membership).Equal(expectedValidatorSet) { - return fmt.Errorf("invalid BlockValidationDescriptor: should match validator set at P-chain height %d", pChainHeight) - } - // If we have compared all fields so far, the rest of the fields we compare by constructing an explicit expected SimplexEpochInfo expectedSimplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, expectedValidatorSet, prevVMBlockSeq) if !expectedSimplexEpochInfo.Equal(&simplexEpochInfo) { - return fmt.Errorf("invalid SimplexEpochInfo: expected %v, got %v", expectedSimplexEpochInfo, simplexEpochInfo) + return fmt.Errorf("%w: expected %v, got %v", errInvalidSimplexEpochInfo, expectedSimplexEpochInfo, simplexEpochInfo) } // The InnerBlock must match the last non-Simplex inner block. if block.InnerBlock != nil { - return fmt.Errorf("zero block must not have an inner block") + return errZeroBlockHasInnerBlock } if prevBlock.InnerBlock.Digest() != sm.LastNonSimplexInnerBlock.Digest() { - return fmt.Errorf("zero block inner block digest does not match last non-Simplex inner block digest") + return errZeroBlockInnerDigestMismatch } // The timestamp must equal the last non-Simplex inner block's timestamp. + // We do it because we need to carry over a minimum timestamp from the non-Simplex blocks. expectedTimestamp := uint64(sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli()) if block.Metadata.Timestamp != expectedTimestamp { - return fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, block.Metadata.Timestamp) + return fmt.Errorf("%w: expected %d but got %d", errZeroBlockTimestampMismatch, expectedTimestamp, block.Metadata.Timestamp) } return nil } func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + newApprovals, err := sm.computeNewApprovals(parentBlock) + if err != nil { + return nil, err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock, prevBlockSeq, newApprovals) + + pChainHeight := parentBlock.Metadata.PChainHeight + + now := sm.GetTime() + icmEpochInfo := computeICMEpochInfo(parentBlock, sm.ComputeICMEpoch, now) + + // We might not have enough approvals to seal the current epoch, + // in which case we just carry over the approvals we have so far to the next block, + // so that eventually we'll have enough approvals to seal the epoch. + if !newApprovals.canSeal { + sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") + return sm.buildBlockImpatiently(ctx, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight, icmEpochInfo) + } + + sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") + + // Else, we have enough approvals to seal the epoch, so we create the sealing block. + return sm.createSealingBlock(ctx, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight, icmEpochInfo) +} + +func (sm *StateMachine) verifyCollectingApprovalsBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + nextMD := nextBlock.Metadata + newApprovals := nextMD.SimplexEpochInfo.NextEpochApprovals + + // The block builder should at least include its own approval in the block it builds, + // so we should have some approvals in the proposed block. + if newApprovals == nil || len(newApprovals.NodeIDs) == 0 || len(newApprovals.Signature) == 0 { + return errEmptyNextEpochApprovals + } + + prevEpochInfo := parentBlock.Metadata.SimplexEpochInfo + nextEpochInfo := nextBlock.Metadata.SimplexEpochInfo + + validators, err := sm.GetValidatorSet(prevEpochInfo.NextPChainReferenceHeight) + if err != nil { + return err + } + + err = sm.verifyNextEpochApprovalsSignature(prevEpochInfo, nextEpochInfo, validators) + if err != nil { + return err + } + + // A node cannot remove other nodes' approvals, only add its own approval if it wasn't included in the previous block. + // So the set of signers in next.NextEpochApprovals should be a superset of the set of signers in prev.NextEpochApprovals. + if err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prevEpochInfo, nextEpochInfo); err != nil { + return err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock, prevBlockSeq, &approvals{ + nodeIDs: newApprovals.NodeIDs, + signature: newApprovals.Signature, + }) + + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + approvals := bitmaskFromBytes(newApprovals.NodeIDs) + canSeal := sigAggr.IsQuorum(validators.SelectSubset(approvals)) + + // TODO: P-chain height should be taken from the ICM epoch. For now we pass the block proposer's P-chain height. + + if canSeal { + newSimplexEpochInfo, err = sm.computeSimplexEpochInfoForSealingBlock(newSimplexEpochInfo) + if err != nil { + return fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) + } + } + + timestamp := time.UnixMilli(int64(nextMD.Timestamp)) + + icmEpochInfo := computeICMEpochInfo(parentBlock, sm.ComputeICMEpoch, timestamp) + + return verifyAgainstExpected(ctx, nextBlock.InnerBlock, newSimplexEpochInfo, nextMD.PChainHeight, nextBlock, timestamp, icmEpochInfo.ToICMEpochInfo()) +} + +func (sm *StateMachine) verifyNextEpochApprovalsSignature(prev SimplexEpochInfo, next SimplexEpochInfo, validators NodeBLSMappings) error { + // First figure out which validators are approving the next epoch by looking at the bitmask of approving nodes, + // and then aggregate their public keys together to verify the signature. + + nodeIDsBitmask := next.NextEpochApprovals.NodeIDs + aggPK, err := sm.aggregatePubKeysForBitmask(nodeIDsBitmask, validators) + if err != nil { + return err + } + + pChainHeight := prev.NextPChainReferenceHeight + pChainHeightBuff := make([]byte, 8) + binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) + + signedMsg := simplex.SignedMessage{Payload: pChainHeightBuff, Context: signatureContext} + toBeSigned, err := asn1.Marshal(signedMsg) + if err != nil { + return err + } + + if err := sm.SignatureVerifier.VerifySignature(next.NextEpochApprovals.Signature, toBeSigned, aggPK); err != nil { + return fmt.Errorf("failed to verify signature: %w", err) + } + return nil +} + +func (sm *StateMachine) aggregatePubKeysForBitmask(nodeIDsBitmask []byte, validators NodeBLSMappings) ([]byte, error) { + approvingNodes := bitmaskFromBytes(nodeIDsBitmask) + publicKeys := make([][]byte, 0, len(validators)) + for i := range validators { + if !approvingNodes.Contains(i) { + continue + } + publicKeys = append(publicKeys, validators[i].BLSKey) + } + + aggPK, err := sm.KeyAggregator.AggregateKeys(publicKeys...) + if err != nil { + return nil, fmt.Errorf("failed to aggregate public keys: %w", err) + } + return aggPK, nil +} + +func computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock StateMachineBlock, prevBlockSeq uint64, newApprovals *approvals) SimplexEpochInfo { // The P-chain reference height and epoch number should remain the same until we transition to the new epoch. // The next P-chain reference height should have been set in the previous block, // which is the reason why we are collecting approvals in the first place. @@ -491,6 +881,18 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } + // This might be the first time we created approvals for the next epoch, + // so we need to initialize the NextEpochApprovals. + if newSimplexEpochInfo.NextEpochApprovals == nil { + newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} + } + // The node IDs and signature are aggregated across all past and present approvals. + newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs + newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature + return newSimplexEpochInfo +} + +func (sm *StateMachine) computeNewApprovals(parentBlock StateMachineBlock) (*approvals, error) { // We prepare information that is needed to compute the approvals for the new epoch, // such as the validator set for the next epoch, and the approvals from peers. validators, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight) @@ -498,56 +900,33 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren return nil, err } + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + // We retrieve approvals that validators have sent us for the next epoch. // These approvals are signed by validators of the next epoch. approvalsFromPeers := sm.ApprovalsRetriever.Approvals() sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) - nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight + nextPChainHeight := parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals - sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) - newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators, sm.Logger) if err != nil { return nil, err } - - // This might be the first time we created approvals for the next epoch, - // so we need to initialize the NextEpochApprovals. - if newSimplexEpochInfo.NextEpochApprovals == nil { - newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} - } - // The node IDs and signature are aggregated across all past and present approvals. - newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs - newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature - pChainHeight := parentBlock.Metadata.PChainHeight - - // We might not have enough approvals to seal the current epoch, - // in which case we just carry over the approvals we have so far to the next block, - // so that eventually we'll have enough approvals to seal the epoch. - if !newApprovals.canSeal { - sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) - } - - sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") - - // Else, we have enough approvals to seal the epoch, so we create the sealing block. - return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + return newApprovals, nil } // buildBlockImpatiently builds a block by waiting for the VM to build a block until MaxBlockBuildingWaitTime. // If the VM fails to build a block within that time, we build a block without an inner block, // so that we can continue making progress and not get stuck waiting for the VM. -func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { +func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64, icmEpochInfo ICMEpoch) (*StateMachineBlock, error) { impatientContext, cancel := context.WithTimeout(ctx, sm.MaxBlockBuildingWaitTime) defer cancel() - start := time.Now() + start := sm.GetTime() - // TODO: This P-chain height should be taken from the ICM epoch - childBlock, err := sm.BlockBuilder.BuildBlock(impatientContext, pChainHeight) + innerBlock, err := sm.BlockBuilder.BuildBlock(impatientContext, icmEpochInfo.PChainEpochHeight) if err != nil && impatientContext.Err() == nil { // If we got an error building the block, and we didn't time out, log the error but continue building the block without the inner block, // so that we can continue making progress and not get stuck on a single block. @@ -557,71 +936,72 @@ func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock S sm.Logger.Debug("Timed out waiting for block to be built, building block without inner block instead", zap.Duration("elapsed", time.Since(start)), zap.Duration("maxBlockBuildingWaitTime", sm.MaxBlockBuildingWaitTime)) } - return sm.wrapBlock(parentBlock, childBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + + now := sm.GetTime() + return wrapBlock(innerBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist, now, icmEpochInfo.ToICMEpochInfo()), nil +} + +func (sm *StateMachine) createSealingBlock(ctx context.Context, + simplexMetadata []byte, + simplexBlacklist []byte, + simplexEpochInfo SimplexEpochInfo, + pChainHeight uint64, + icmEpochInfo ICMEpoch) (*StateMachineBlock, error) { + simplexEpochInfo, err := sm.computeSimplexEpochInfoForSealingBlock(simplexEpochInfo) + if err != nil { + return nil, fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) + } + return sm.buildBlockImpatiently(ctx, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight, icmEpochInfo) } -func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { +func (sm *StateMachine) computeSimplexEpochInfoForSealingBlock(simplexEpochInfo SimplexEpochInfo) (SimplexEpochInfo, error) { validators, err := sm.GetValidatorSet(simplexEpochInfo.NextPChainReferenceHeight) if err != nil { - return nil, err + return SimplexEpochInfo{}, err } if simplexEpochInfo.BlockValidationDescriptor == nil { simplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} } simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members = validators - // If this is not the first epoch, and this is the sealing block, we set the hash of the previous sealing block. - if simplexEpochInfo.EpochNumber > 1 { - prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) - if err != nil { - sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) - return nil, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) - } - if finalization == nil { - sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) - return nil, fmt.Errorf("previous sealing InnerBlock at epoch %d is not finalized", simplexEpochInfo.EpochNumber-1) - } - simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() - } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. - firstSimplexBlock := sm.FirstEverSimplexBlock() - if firstSimplexBlock == nil { - return nil, fmt.Errorf("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") - } - simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlock.Digest() + prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) + if err != nil { + sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) + return SimplexEpochInfo{}, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber, err) } + if finalization == nil { + sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) + return SimplexEpochInfo{}, fmt.Errorf("%w: epoch %d", errPrevSealingBlockNotFinalized, simplexEpochInfo.EpochNumber) + } + simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) + return simplexEpochInfo, nil } // wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. -func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { - timestamp := parentBlock.Metadata.Timestamp - - hasChildBlock := childBlock != nil - - var newTimestamp time.Time - if hasChildBlock { - newTimestamp = childBlock.Timestamp() - timestamp = uint64(newTimestamp.UnixMilli()) - } +func wrapBlock( + childBlock VMBlock, + newSimplexEpochInfo SimplexEpochInfo, + pChainHeight uint64, + simplexMetadata, + simplexBlacklist []byte, + timestamp time.Time, + icmEpochInfo ICMEpochInfo) *StateMachineBlock { return &StateMachineBlock{ InnerBlock: childBlock, Metadata: StateMachineMetadata{ - Timestamp: timestamp, + Timestamp: uint64(timestamp.UnixMilli()), SimplexProtocolMetadata: simplexMetadata, SimplexBlacklist: simplexBlacklist, SimplexEpochInfo: newSimplexEpochInfo, PChainHeight: pChainHeight, + ICMEpochInfo: icmEpochInfo, }, } } -// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. -func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { - // We check if the sealing block has already been finalized. - // If not, we build a Telock block. - +func (sm *StateMachine) isSealingBlockFinalized(parentBlock StateMachineBlock, prevBlockSeq uint64) (bool, uint64, StateMachineBlock, error) { sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.SealingBlockSeq // If the sealing block sequence is still 0, it means previous block was the sealing block. @@ -630,30 +1010,46 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S } if sealingBlockSeq == 0 { - return nil, fmt.Errorf("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + return false, 0, StateMachineBlock{}, errSealingBlockSeqUnset } - newSimplexEpochInfo := SimplexEpochInfo{ - PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, - EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, - NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, - SealingBlockSeq: sealingBlockSeq, - PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + sealingBlock, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return false, 0, StateMachineBlock{}, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) } - sealingBlock, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + return finalization != nil, sealingBlockSeq, sealingBlock, nil +} + +// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. +func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // We check if the sealing block has already been finalized. + // If not, we build a Telock block. + isSealingBlockFinalized, sealingBlockSeq, sealingBlock, err := sm.isSealingBlockFinalized(parentBlock, prevBlockSeq) if err != nil { - return nil, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) + return nil, err } - isSealingBlockFinalized := finalization != nil + newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) + + now := sm.GetTime() + icmEpochInfo := computeICMEpochInfo(parentBlock, sm.ComputeICMEpoch, now).ToICMEpochInfo() if !isSealingBlockFinalized { pChainHeight := parentBlock.Metadata.PChainHeight - return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + return wrapBlock(nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist, now, icmEpochInfo), nil } // Else, we build a block for the new epoch. + newSimplexEpochInfo = computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo, parentBlock, sealingBlockSeq, prevBlockSeq) + + // TODO: This P-chain height should be taken from the ICM epoch + + return sm.buildBlockOrTransitionEpoch(ctx, sealingBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo) + +} + +func computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo SimplexEpochInfo, parentBlock StateMachineBlock, sealingBlockSeq uint64, prevBlockSeq uint64) SimplexEpochInfo { newSimplexEpochInfo = SimplexEpochInfo{ // P-chain reference height is previous block's NextPChainReferenceHeight. PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, @@ -661,16 +1057,64 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S EpochNumber: sealingBlockSeq, PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } + return newSimplexEpochInfo +} + +func computeSimplexEpochInfoForTelock(parentBlock StateMachineBlock, sealingBlockSeq uint64, prevBlockSeq uint64) SimplexEpochInfo { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + SealingBlockSeq: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + return newSimplexEpochInfo +} + +func (sm *StateMachine) verifyBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + isSealingBlockFinalized, sealingBlockSeq, _, err := sm.isSealingBlockFinalized(parentBlock, prevBlockSeq) + if err != nil { + return err + } + + timestamp := time.UnixMilli(int64(nextBlock.Metadata.Timestamp)) + + now := sm.GetTime() + icmEpochInfo := computeICMEpochInfo(parentBlock, sm.ComputeICMEpoch, now).ToICMEpochInfo() + + newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) + + if !isSealingBlockFinalized { + return verifyAgainstExpected(ctx, nil, newSimplexEpochInfo, nextBlock.Metadata.PChainHeight, nextBlock, timestamp, icmEpochInfo) + } + + // Else, it's a new epoch. + newSimplexEpochInfo = computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo, parentBlock, sealingBlockSeq, prevBlockSeq) + + // The first block of the new epoch may itself transition again, so trust and validate + // the proposed pchain height and (optional) next pchain reference height, mirroring + // what buildBlockOrTransitionEpoch does on the build side. + proposedPChainHeight := nextBlock.Metadata.PChainHeight + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := parentBlock.Metadata.PChainHeight + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + if err := sm.verifyNextPChainRefHeightForNewEpoch(newSimplexEpochInfo, nextBlock.Metadata.SimplexEpochInfo); err != nil { + return fmt.Errorf("failed to verify next P-chain reference height for new epoch block: %w", err) + } + newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight // TODO: This P-chain height should be taken from the ICM epoch - return sm.buildBlockOrTransitionEpoch(ctx, sealingBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo) + return verifyAgainstExpected(ctx, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock, timestamp, icmEpochInfo) } // constructSimplexZeroBlockSimplexEpochInfo constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. func constructSimplexZeroBlockSimplexEpochInfo(pChainHeight uint64, newValidatorSet NodeBLSMappings, prevVMBlockSeq uint64) SimplexEpochInfo { newSimplexEpochInfo := SimplexEpochInfo{ PChainReferenceHeight: pChainHeight, - EpochNumber: 1, + EpochNumber: prevVMBlockSeq + 1, // We treat the zero block as a special case, and we encode in it the block validation descriptor, // despite it not actually being a sealing block. This is because the zero block is the first block that introduces the validator set. BlockValidationDescriptor: &BlockValidationDescriptor{ @@ -738,7 +1182,7 @@ func computeNewApproverSignaturesAndSigners( logger simplex.Logger, ) ([]byte, bitmask, error) { if nextEpochApprovals == nil { - return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") + return nil, bitmask{}, errEmptyNextEpochApprovals } // Prepare the new signatures from the new approvals that haven't approved yet and that agree with our candidate auxiliary info digest and P-Chain height. newSignatures := make([][]byte, 0, len(approvalsFromPeers)+1) @@ -816,6 +1260,8 @@ func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes b } } +// computePrevVMBlockSeq computes the block sequence of the previous VM block (inner block). +// The block sequence of the previous VM block is the number of VM blocks that have been built since genesis. func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) uint64 { // Either our parent block has no inner block, in which case we just inherit its previous VM block sequence, if parentBlock.InnerBlock == nil { @@ -825,22 +1271,13 @@ func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) u return prevBlockSeq } -var ( - errSignerSetShrunk = fmt.Errorf("some signers from parent block are missing from next epoch approvals of proposed block") - errNextEpochApprovalsShrunk = fmt.Errorf("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") -) - -func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { - if prev.NextEpochApprovals == nil { - // Condition satisfied vacuously. +func areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { + if prev.NextEpochApprovals == nil || len(prev.NextEpochApprovals.NodeIDs) == 0 { return nil } - // Else, prev.NextEpochApprovals is not nil. - // If next.NextEpochApprovals is nil, condition is not satisfied. if next.NextEpochApprovals == nil { - return errNextEpochApprovalsShrunk + return fmt.Errorf("%w: previous block has next epoch approvals but proposed block doesn't have next epoch approvals", errNextEpochApprovalsShrunk) } - // Make sure that previous signers are still there. prevSigners := bitmaskFromBytes(prev.NextEpochApprovals.NodeIDs) nextSigners := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) diff --git a/msm/msm_test.go b/msm/msm_test.go index eff624da..26ad5cfa 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rand" "fmt" + "math" "testing" "time" @@ -26,7 +27,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { for _, testCase := range []struct { name string md simplex.ProtocolMetadata - err string + err error configure func(*StateMachine, *testConfig) mutateBlock func(*StateMachineBlock) }{ @@ -43,7 +44,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { md.Seq = 0 block.Metadata.SimplexProtocolMetadata = md.Bytes() }, - err: "attempted to build a genesis inner block", + err: errBuiltGenesisInnerBlock, }, { name: "previous block not found", @@ -51,7 +52,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { configure: func(_ *StateMachine, tc *testConfig) { delete(tc.blockStore, 0) }, - err: "failed to retrieve previous (0) inner block", + err: simplex.ErrBlockNotFound, }, { name: "parent has no inner block", @@ -61,7 +62,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { block: StateMachineBlock{}, } }, - err: "parent inner block (", + err: errZeroBlockParentNoInnerBlock, }, { name: "wrong epoch number", @@ -69,7 +70,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.EpochNumber = 2 }, - err: "invalid epoch number (2), should be 1", + err: errInvalidSimplexEpochInfo, }, { name: "P-chain height too big", @@ -77,7 +78,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.PChainHeight = 110 }, - err: "invalid P-chain height (110), expected to be 100", + err: errInvalidPChainHeight, }, { name: "P-chain height smaller than parent", @@ -85,7 +86,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { configure: func(sm *StateMachine, tc *testConfig) { sm.LastNonSimplexBlockPChainHeight = 99 }, - err: "invalid P-chain height (100), expected to be 99", + err: errInvalidPChainHeight, }, { name: "nil BlockValidationDescriptor", @@ -93,7 +94,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = nil }, - err: "invalid BlockValidationDescriptor: should not be nil", + err: errInvalidSimplexEpochInfo, }, { name: "membership mismatch", @@ -103,7 +104,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { {BLSKey: []byte{1}, Weight: 1}, } }, - err: "invalid BlockValidationDescriptor: should match validator set", + err: errInvalidSimplexEpochInfo, }, { name: "SimplexEpochInfo mismatch", @@ -111,7 +112,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 }, - err: "invalid SimplexEpochInfo", + err: errInvalidSimplexEpochInfo, }, } { t.Run(testCase.name, func(t *testing.T) { @@ -131,8 +132,8 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { } err = sm2.VerifyBlock(context.Background(), block) - if testCase.err != "" { - require.ErrorContains(t, err, testCase.err) + if testCase.err != nil { + require.ErrorIs(t, err, testCase.err) return } require.NoError(t, err) @@ -155,7 +156,7 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { md := simplex.ProtocolMetadata{ Round: 0, Seq: 43, - Epoch: 1, + Epoch: 43, Prev: preSimplexParent.Digest(), } @@ -191,7 +192,7 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: 100, - EpochNumber: 1, + EpochNumber: 43, PrevVMBlockSeq: 42, BlockValidationDescriptor: &BlockValidationDescriptor{ AggregatedMembership: AggregatedMembership{ @@ -212,12 +213,92 @@ func TestMSMNormalOp(t *testing.T) { for _, testCase := range []struct { name string setup func(*StateMachine, *testConfig) + mutateBlock func(*StateMachineBlock) + err error expectedPChainHeight uint64 expectedNextPChainRefHeight uint64 + expectedICMEpochInfo ICMEpochInfo }{ { name: "correct information", expectedPChainHeight: 100, + expectedICMEpochInfo: ICMEpochInfo{PChainEpochHeight: 100, EpochNumber: 1}, + }, + { + name: "trying to build a genesis block", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 0 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: errBuiltGenesisInnerBlock, + }, + { + name: "previous block not found", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 999 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: simplex.ErrBlockNotFound, + }, + { + name: "P-chain height too big", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 110 + }, + err: errPChainHeightTooBig, + }, + { + name: "P-chain height smaller than parent", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 0 + }, + err: errPChainHeightSmallerThanParent, + }, + { + name: "wrong epoch number", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.EpochNumber = 2 + }, + err: errInvalidProtocolMetadataEpoch, + }, + { + name: "non-nil BlockValidationDescriptor", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} + }, + err: errBlockDigestMismatch, + }, + { + name: "non-zero sealing block seq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.SealingBlockSeq = 5 + }, + err: errBlockDigestMismatch, + }, + { + name: "wrong PChainReferenceHeight", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PChainReferenceHeight = 50 + }, + err: errBlockDigestMismatch, + }, + { + name: "non-empty PrevSealingBlockHash", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevSealingBlockHash = [32]byte{1, 2, 3} + }, + err: errBlockDigestMismatch, + }, + { + name: "wrong PrevVMBlockSeq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 + }, + err: errBlockDigestMismatch, }, { name: "validator set change detected", @@ -229,14 +310,17 @@ func TestMSMNormalOp(t *testing.T) { }, expectedPChainHeight: newPChainHeight, expectedNextPChainRefHeight: newPChainHeight, + expectedICMEpochInfo: ICMEpochInfo{PChainEpochHeight: 100, EpochNumber: 1}, }, } { t.Run(testCase.name, func(t *testing.T) { chain := makeChain(t, 5, 10) sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) for i, block := range chain { - testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} + testConfig1.blockStore[uint64(i)] = &outerBlock{block: block, finalization: &simplex.Finalization{}} + testConfig2.blockStore[uint64(i)] = &outerBlock{block: block, finalization: &simplex.Finalization{}} } lastBlock := chain[len(chain)-1] @@ -252,6 +336,10 @@ func TestMSMNormalOp(t *testing.T) { blockTime := lastBlock.InnerBlock.Timestamp().Add(time.Second) + fixedTime := func() time.Time { return blockTime } + sm1.GetTime = fixedTime + sm2.GetTime = fixedTime + content := make([]byte, 10) _, err = rand.Read(content) require.NoError(t, err) @@ -264,13 +352,25 @@ func TestMSMNormalOp(t *testing.T) { if testCase.setup != nil { testCase.setup(sm1, testConfig1) + testCase.setup(sm2, testConfig2) } block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) require.NoError(t, err) require.NotNil(t, block1) - require.Equal(t, &StateMachineBlock{ + if testCase.mutateBlock != nil { + testCase.mutateBlock(block1) + } + + err = sm2.VerifyBlock(context.Background(), block1) + if testCase.err != nil { + require.ErrorIs(t, err, testCase.err) + return + } + require.NoError(t, err) + + expected := &StateMachineBlock{ InnerBlock: &InnerBlock{ TS: blockTime, BlockHeight: lastBlock.InnerBlock.Height(), @@ -287,8 +387,10 @@ func TestMSMNormalOp(t *testing.T) { PrevVMBlockSeq: lastBlock.InnerBlock.Height(), NextPChainReferenceHeight: testCase.expectedNextPChainRefHeight, }, + ICMEpochInfo: testCase.expectedICMEpochInfo, }, - }, block1) + } + require.Equal(t, expected.Digest(), block1.Digest()) }) } } @@ -342,14 +444,17 @@ func TestMSMFullEpochLifecycle(t *testing.T) { for _, testCase := range []struct { name string firstBlockBeforeSimplex StateMachineBlock + epochNum uint64 }{ { name: "building on top of genesis", firstBlockBeforeSimplex: genesis, + epochNum: 1, }, { name: "upgrading to Simplex from pre-Simplex blocks", firstBlockBeforeSimplex: notGenesis, + epochNum: notGenesis.InnerBlock.Height() + 1, }, } { t.Run(testCase.name, func(t *testing.T) { @@ -366,10 +471,41 @@ func TestMSMFullEpochLifecycle(t *testing.T) { return currentPChainHeight } + // Since we explicitly compare the built block with an expected value, + // we need the timestamps to be deterministic. So instead of using time.Now(), we use a fixed + // startTime and add offsets to it for each block. + currentTime := startTime + fixedTime := func() time.Time { return currentTime } + + // We exercise an ICM epoch transition by jumping block3's timestamp + // past the 1-second ICM-epoch window. + // ComputeICMEpoch transitions when the parent block's timestamp has + // crossed the current ICM epoch's start + 1s, so block4 (and every + // block after it) lands in ICM epoch 2. + // + // block3 is also the block where the validator set change is first + // observed, so its Metadata.PChainHeight = pChainHeight2. Since the + // transition takes input.ParentPChainHeight as the new epoch's + // PChainEpochHeight, icmEpoch2.PChainEpochHeight = pChainHeight2. + // block2, block3: ICM epoch 1, started at startTime. + // block4 onward: ICM epoch 2, started at block3's timestamp, + // PChainEpochHeight = pChainHeight2. + icmEpoch1 := ICMEpochInfo{ + PChainEpochHeight: pChainHeight1, + EpochNumber: 1, + EpochStartTime: uint64(startTime.Unix()), + } + icmEpoch2 := ICMEpochInfo{ + PChainEpochHeight: pChainHeight2, + EpochNumber: 2, + EpochStartTime: uint64(startTime.Unix()) + 1, + } + // Create fresh state machine instances for each iteration. sm, tc := newStateMachine(t) sm.GetValidatorSet = getValidatorSet sm.GetPChainHeight = getPChainHeight + sm.GetTime = fixedTime tc.blockStore[0] = &outerBlock{block: genesis} tc.blockStore[42] = &outerBlock{block: notGenesis} @@ -380,6 +516,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { smVerify, tcVerify := newStateMachine(t) smVerify.GetValidatorSet = getValidatorSet smVerify.GetPChainHeight = getPChainHeight + smVerify.GetTime = fixedTime smVerify.LastNonSimplexInnerBlock = testCase.firstBlockBeforeSimplex.InnerBlock smVerify.GenesisValidatorSet = validatorSet1 @@ -401,7 +538,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { md := simplex.ProtocolMetadata{ Seq: baseSeq + 1, Round: 0, - Epoch: 1, + Epoch: testCase.epochNum, Prev: testCase.firstBlockBeforeSimplex.Digest(), } @@ -414,7 +551,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq, BlockValidationDescriptor: &BlockValidationDescriptor{ AggregatedMembership: AggregatedMembership{ @@ -424,7 +561,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { }, }, }, block1) - addBlock(md.Seq, *block1, nil) + addBlock(md.Seq, *block1, &simplex.Finalization{}) require.NoError(t, smVerify.VerifyBlock(context.Background(), block1)) @@ -433,21 +570,23 @@ func TestMSMFullEpochLifecycle(t *testing.T) { smVerify.LatestPersistedHeight = baseSeq + 1 // ----- Step 2: Build a normal block (no validator set change) ----- + currentTime = startTime.Add(2 * time.Millisecond) tc.blockBuilder.block = nextBlock(2) - md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: 1, Prev: block1.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: testCase.epochNum, Prev: block1.Digest()} block2, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(2), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(2 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight1, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq, }, + ICMEpochInfo: icmEpoch1, }, }, block2) addBlock(md.Seq, *block2, nil) @@ -458,22 +597,27 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // Advance P-chain height so that GetValidatorSet returns a different set. currentPChainHeight = pChainHeight2 + // Jump block3's timestamp past the 1-second ICM-epoch window so + // block4 (whose parent is block3) sees parentTimestamp >= + // epochStart + 1s and transitions ICM to epoch 2. + currentTime = startTime.Add(time.Second + 3*time.Millisecond) tc.blockBuilder.block = nextBlock(3) - md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: 1, Prev: block2.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: testCase.epochNum, Prev: block2.Digest()} block3, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(3), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(3 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight2, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 2, NextPChainReferenceHeight: pChainHeight2, }, + ICMEpochInfo: icmEpoch1, }, }, block3) addBlock(md.Seq, *block3, nil) @@ -499,19 +643,20 @@ func TestMSMFullEpochLifecycle(t *testing.T) { sig, err := aggr.AppendSignatures(nil, []byte("sig1")) require.NoError(t, err) + currentTime = startTime.Add(time.Second + 4*time.Millisecond) tc.blockBuilder.block = nextBlock(4) - md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: 1, Prev: block3.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: testCase.epochNum, Prev: block3.Digest()} block4, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(4), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(4 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight2, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 3, NextPChainReferenceHeight: pChainHeight2, NextEpochApprovals: &NextEpochApprovals{ @@ -519,6 +664,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { Signature: sig, }, }, + ICMEpochInfo: icmEpoch2, }, }, block4) addBlock(md.Seq, *block4, nil) @@ -539,19 +685,20 @@ func TestMSMFullEpochLifecycle(t *testing.T) { require.NoError(t, err) bitmask = []byte{3} + currentTime = startTime.Add(time.Second + 5*time.Millisecond) tc.blockBuilder.block = nextBlock(5) - md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: 1, Prev: block4.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: testCase.epochNum, Prev: block4.Digest()} block5, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(5), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(5 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight2, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 4, NextPChainReferenceHeight: pChainHeight2, NextEpochApprovals: &NextEpochApprovals{ @@ -559,6 +706,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { Signature: sig, }, }, + ICMEpochInfo: icmEpoch2, }, }, block5) addBlock(md.Seq, *block5, nil) @@ -579,19 +727,20 @@ func TestMSMFullEpochLifecycle(t *testing.T) { require.NoError(t, err) bitmask = []byte{7} + currentTime = startTime.Add(time.Second + 6*time.Millisecond) tc.blockBuilder.block = nextBlock(6) - md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: 1, Prev: block5.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: testCase.epochNum, Prev: block5.Digest()} block6, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(6), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(6 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight2, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 5, NextPChainReferenceHeight: pChainHeight2, SealingBlockSeq: 0, @@ -606,6 +755,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { Signature: sig6, }, }, + ICMEpochInfo: icmEpoch2, }, }, block6) addBlock(md.Seq, *block6, nil) @@ -644,28 +794,31 @@ func TestMSMFullEpochLifecycle(t *testing.T) { subTestCase.setup() tc.blockBuilder.block = nextBlock(7) - md = simplex.ProtocolMetadata{Seq: baseSeq + 7, Round: 6, Epoch: 1, Prev: block6.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 7, Round: 6, Epoch: testCase.epochNum, Prev: block6.Digest()} // If the sealing block isn't finalized yet, we expect to build a Telock. // However, despite the fact that the block builder is willing to build a new block, // a Telock shouldn't contain an inner block. if tc.blockStore[sealingSeq].finalization == nil { + // Telock shares the sealing block's timestamp slot. + currentTime = startTime.Add(time.Second + 6*time.Millisecond) telock, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nil, Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(6 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight2, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, NextPChainReferenceHeight: pChainHeight2, PrevVMBlockSeq: baseSeq + 6, SealingBlockSeq: sealingSeq, }, + ICMEpochInfo: icmEpoch2, }, }, telock) @@ -675,12 +828,18 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // ----- Step 7: Build a new epoch block (sealing block is finalized) ----- + // The first block of the new epoch carries the new EpochNumber + // (= sealing block's sequence) in both SimplexEpochInfo.EpochNumber + // and the protocol metadata's Epoch field. + md.Epoch = sealingSeq + + currentTime = startTime.Add(time.Second + 7*time.Millisecond) block7, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(7), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(7 * time.Millisecond).UnixMilli()), + Timestamp: uint64(currentTime.UnixMilli()), PChainHeight: pChainHeight2, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ @@ -688,6 +847,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { EpochNumber: sealingSeq, PrevVMBlockSeq: baseSeq + 6, }, + ICMEpochInfo: icmEpoch2, }, }, block7) addBlock(md.Seq, *block7, nil) @@ -775,7 +935,7 @@ func TestAreNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(t *testing.T }, } { t.Run(tc.name, func(t *testing.T) { - err := ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(tc.prev, tc.next) + err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(tc.prev, tc.next) if tc.err != nil { require.ErrorIs(t, err, tc.err) } else { @@ -785,6 +945,154 @@ func TestAreNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(t *testing.T } } +func TestVerifyPChainHeight(t *testing.T) { + tests := []struct { + name string + proposed uint64 + current uint64 + prev uint64 + err error + }{ + { + name: "proposed equals current and parent", + proposed: 10, + current: 10, + prev: 10, + }, + { + name: "proposed equals current, above parent", + proposed: 10, + current: 10, + prev: 5, + }, + { + name: "proposed equals parent, below current", + proposed: 5, + current: 10, + prev: 5, + }, + { + name: "proposed strictly between parent and current", + proposed: 7, + current: 10, + prev: 5, + }, + { + name: "all zero", + proposed: 0, + current: 0, + prev: 0, + }, + { + name: "proposed greater than current", + proposed: 11, + current: 10, + prev: 5, + err: errPChainHeightTooBig, + }, + { + name: "proposed greater than current by one, current is zero", + proposed: 1, + current: 0, + prev: 0, + err: errPChainHeightTooBig, + }, + { + name: "parent greater than proposed", + proposed: 5, + current: 10, + prev: 6, + err: errPChainHeightSmallerThanParent, + }, + { + name: "proposed is zero, parent is non-zero", + proposed: 0, + current: 10, + prev: 1, + err: errPChainHeightSmallerThanParent, + }, + { + // When both checks would trigger, "too big" takes precedence. + name: "both checks would fire, too-big wins", + proposed: 20, + current: 10, + prev: 15, + err: errPChainHeightTooBig, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := verifyPChainHeight(tt.proposed, tt.current, tt.prev) + if tt.err == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tt.err) + }) + } +} + +func TestVerifyTimestamp(t *testing.T) { + now := time.Now() + nowMilli := uint64(now.UnixMilli()) + skewMilli := uint64(maxSkew / time.Millisecond) + + tests := []struct { + name string + proposed uint64 + prev uint64 + err error + }{ + { + name: "proposed equals parent", + proposed: nowMilli, + prev: nowMilli, + }, + { + name: "proposed after parent, well within skew", + proposed: nowMilli + 100, + prev: nowMilli - 100, + }, + { + name: "proposed exactly at now + maxSkew", + proposed: nowMilli + skewMilli, + prev: nowMilli, + }, + { + name: "proposed below parent", + proposed: nowMilli - 1, + prev: nowMilli, + err: errTimestampDecreasing, + }, + { + name: "proposed one millisecond past now + maxSkew", + proposed: nowMilli + skewMilli + 1, + prev: nowMilli, + err: errTimestampTooFarInFuture, + }, + { + name: "proposed exceeds math.MaxInt64", + proposed: uint64(math.MaxInt64) + 1, + prev: nowMilli, + err: errTimestampTooBig, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + block := &StateMachineBlock{Metadata: StateMachineMetadata{Timestamp: tt.proposed}} + prev := &StateMachineBlock{Metadata: StateMachineMetadata{Timestamp: tt.prev}} + err := verifyTimestamp(block, prev, now) + if tt.err == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tt.err) + }) + } +} + func TestComputePrevVMBlockSeq(t *testing.T) { t.Run("parent has no inner block", func(t *testing.T) { parent := StateMachineBlock{ @@ -983,6 +1291,6 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { } _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}, logger) - require.ErrorContains(t, err, "aggregation failed") + require.ErrorIs(t, err, errTestAggregationFailed) }) } diff --git a/msm/util_test.go b/msm/util_test.go new file mode 100644 index 00000000..002e3d0e --- /dev/null +++ b/msm/util_test.go @@ -0,0 +1,422 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/asn1" + "errors" + "fmt" + "maps" + "testing" + "time" + + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" +) + +// Test helpers + +type InnerBlock struct { + TS time.Time + BlockHeight uint64 + Bytes []byte +} + +func (i *InnerBlock) Digest() [32]byte { + return sha256.Sum256(i.Bytes) +} + +func (i *InnerBlock) Height() uint64 { + return i.BlockHeight +} + +func (i *InnerBlock) Timestamp() time.Time { + return i.TS +} + +func (i *InnerBlock) Verify(_ context.Context, _ uint64) error { + return nil +} + +// fakeVMBlock is a minimal VMBlock implementation for tests. +type fakeVMBlock struct { + height uint64 +} + +func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } +func (f *fakeVMBlock) Height() uint64 { return f.height } +func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } +func (f *fakeVMBlock) Verify(_ context.Context, _ uint64) error { return nil } + +type outerBlock struct { + finalization *simplex.Finalization + block StateMachineBlock +} + +type blockStore map[uint64]*outerBlock + +func (bs blockStore) clone() blockStore { + newStore := make(blockStore) + maps.Copy(newStore, bs) + return newStore +} + +func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, exits := bs[seq] + if !exits { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) + } + return blk.block, blk.finalization, nil +} + +type approvalsRetriever struct { + result ValidatorSetApprovals +} + +func (a approvalsRetriever) Approvals() ValidatorSetApprovals { + return a.result +} + +type signatureVerifier struct { + err error +} + +func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { + return sv.err +} + +type signatureAggregator struct { + weightByNodeID map[string]uint64 + totalWeight uint64 +} + +type aggregatrdSignature struct { + Signatures [][]byte +} + +func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + all := make([][]byte, 0, len(sigs)+1) + all = append(all, sigs...) + if len(existing) > 0 { + all = append(all, existing) + } + return asn1.Marshal(aggregatrdSignature{Signatures: all}) +} + +func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { + var sum uint64 + for _, signer := range signers { + sum += sv.weightByNodeID[string(signer)] + } + return sum*3 > sv.totalWeight*2 +} + +func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { + return func(weights []simplex.Node) simplex.SignatureAggregator { + s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} + for _, nw := range weights { + s.weightByNodeID[string(nw.Node)] = nw.Weight + s.totalWeight += nw.Weight + } + return s + } +} + +type noOpPChainListener struct{} + +func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { + <-ctx.Done() + return ctx.Err() +} + +type blockBuilder struct { + block VMBlock + err error +} + +func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { + // Block is always ready in tests. +} + +func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { + return bb.block, bb.err +} + +type validatorSetRetriever struct { + result NodeBLSMappings + resultMap map[uint64]NodeBLSMappings + err error +} + +func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { + if vsr.resultMap != nil { + if result, ok := vsr.resultMap[height]; ok { + return result, vsr.err + } + } + return vsr.result, vsr.err +} + +type keyAggregator struct{} + +func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + aggregated := make([]byte, 0) + for _, key := range keys { + aggregated = append(aggregated, key...) + } + return aggregated, nil +} + +var ( + genesisBlock = StateMachineBlock{ + // Genesis block metadata has all zero values + InnerBlock: &InnerBlock{ + TS: time.Now(), + Bytes: []byte{1, 2, 3}, + }, + } +) + +type dynamicApprovalsRetriever struct { + approvals *ValidatorSetApprovals +} + +func (d *dynamicApprovalsRetriever) Approvals() ValidatorSetApprovals { + return *d.approvals +} + +func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { + startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) + blocks := make([]StateMachineBlock, 0, endHeight+1) + var round, seq uint64 + for h := uint64(0); h <= endHeight; h++ { + index := len(blocks) + + if h == 0 { + blocks = append(blocks, genesisBlock) + continue + } + + if h < simplexStartHeight { + blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) + continue + } + + seq = uint64(index) + + blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) + round++ + } + return blocks +} + +func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + prev := genesisBlock.Digest() + if index > 0 { + prev = blocks[index-1].Digest() + } + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + Metadata: StateMachineMetadata{ + PChainHeight: 100, + SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ + Round: round, + Seq: seq, + Epoch: 1, + Prev: prev, + }).Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{}, + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: uint64(index), + }, + }, + } +} + +func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h-startHeight) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + } +} + +type testConfig struct { + blockStore blockStore + approvalsRetriever approvalsRetriever + signatureVerifier signatureVerifier + signatureAggregator signatureAggregator + blockBuilder blockBuilder + keyAggregator keyAggregator + validatorSetRetriever validatorSetRetriever +} + +func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { + bs := make(blockStore) + bs[0] = &outerBlock{block: genesisBlock} + + var testConfig testConfig + testConfig.blockStore = bs + testConfig.validatorSetRetriever.result = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, + } + + smConfig := Config{ + GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, + LastNonSimplexBlockPChainHeight: 100, + GetTime: time.Now, + TimeSkewLimit: time.Second * 5, + Logger: testutil.MakeLogger(t), + GetBlock: testConfig.blockStore.getBlock, + MaxBlockBuildingWaitTime: time.Second, + ApprovalsRetriever: &testConfig.approvalsRetriever, + SignatureVerifier: &testConfig.signatureVerifier, + SignatureAggregatorCreator: newSignatureAggregatorCreator(), + BlockBuilder: &testConfig.blockBuilder, + KeyAggregator: &testConfig.keyAggregator, + GetPChainHeight: func() uint64 { + return 100 + }, + GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, + PChainProgressListener: &noOpPChainListener{}, + LastNonSimplexInnerBlock: genesisBlock.InnerBlock, + ComputeICMEpoch: func(input ICMEpochInput) ICMEpoch { + // This is just the ACP-181 implementation from avalanchego + var zeroEpoch ICMEpoch + if input.ParentEpoch == zeroEpoch { + return ICMEpoch{ + PChainEpochHeight: input.ParentPChainHeight, + EpochNumber: 1, + EpochStartTime: uint64(input.ParentTimestamp.Unix()), + } + } + endTime := time.Unix(int64(input.ParentEpoch.EpochStartTime), 0).Add(time.Second) + if input.ParentTimestamp.Before(endTime) { + return input.ParentEpoch + } + return ICMEpoch{ + PChainEpochHeight: input.ParentPChainHeight, + EpochNumber: input.ParentEpoch.EpochNumber + 1, + EpochStartTime: uint64(input.ParentTimestamp.Unix()), + } + }, + } + + sm, err := NewStateMachine(&smConfig) + require.NoError(t, err) + + return sm, &testConfig +} + +// concatAggregator concatenates signatures for easy verification in tests. +type concatAggregator struct{} + +func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + result := bytes.Join(sigs, nil) + return append(result, existing...), nil +} + +func (concatAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} + +type failingAggregator struct{} + +func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +var errTestAggregationFailed = errors.New("aggregation failed") + +func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { + return nil, errTestAggregationFailed +} + +func (failingAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} + +type testBlockStore map[uint64]StateMachineBlock + +func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[seq] + if !ok { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) + } + return blk, nil, nil +} + +type testVMBlock struct { + bytes []byte + height uint64 +} + +func (b *testVMBlock) Digest() [32]byte { + return sha256.Sum256(b.bytes) +} + +func (b *testVMBlock) Height() uint64 { + return b.height +} + +func (b *testVMBlock) Timestamp() time.Time { + return time.Now() +} + +func (b *testVMBlock) Verify(_ context.Context, _ uint64) error { + return nil +} + +type testSigVerifier struct { + err error +} + +func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { + return sv.err +} + +type testKeyAggregator struct { + err error +} + +func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + if ka.err != nil { + return nil, ka.err + } + var agg []byte + for _, k := range keys { + agg = append(agg, k...) + } + return agg, nil +}