From 13cb79babb17c4eb0a56ed79a337629013ff69cb Mon Sep 17 00:00:00 2001 From: "Masih H. Derkani" Date: Mon, 23 Sep 2024 14:01:41 +0200 Subject: [PATCH] Refactor validation logic into a mutex-free pluggable struct Refactor the validation logic out of `Participant` and into its own dedicated struct that is mutex-free and listens to the progress made by the participant to infer the correct validation path. The change above significantly reduces the need for mutex control over current instance, which makes it easier to plug in extra conditional behaviour, e.g. #583. Fixes #561 --- gpbft/gpbft.go | 16 +- gpbft/gpbft_test.go | 2 +- gpbft/options.go | 14 +- gpbft/participant.go | 252 +-------------------- gpbft/participant_test.go | 6 +- gpbft/validator.go | 294 +++++++++++++++++++++++++ internal/caching/grouped_cache.go | 10 +- internal/caching/grouped_cache_test.go | 4 +- 8 files changed, 345 insertions(+), 253 deletions(-) create mode 100644 gpbft/validator.go diff --git a/gpbft/gpbft.go b/gpbft/gpbft.go index c9a091ce..2d98c473 100644 --- a/gpbft/gpbft.go +++ b/gpbft/gpbft.go @@ -211,7 +211,8 @@ type instance struct { // independently of protocol phases/rounds. decision *quorumState // tracer traces logic logs for debugging and simulation purposes. - tracer Tracer + tracer Tracer + progress ProgressObserver } func newInstance( @@ -221,7 +222,8 @@ func newInstance( data *SupplementalData, powerTable *PowerTable, aggregateVerifier Aggregate, - beacon []byte) (*instance, error) { + beacon []byte, + progress ProgressObserver) (*instance, error) { if input.IsZero() { return nil, fmt.Errorf("input is empty") } @@ -251,6 +253,7 @@ func newInstance( }, decision: newQuorumState(powerTable), tracer: participant.tracer, + progress: progress, }, nil } @@ -483,6 +486,7 @@ func (i *instance) beginQuality() error { } // Broadcast input value and wait to receive from others. i.phase = QUALITY_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.phaseTimeout = i.alarmAfterSynchrony() i.resetRebroadcastParams() i.broadcast(i.round, QUALITY_PHASE, i.proposal, false, nil) @@ -537,6 +541,7 @@ func (i *instance) beginConverge(justification *Justification) { } i.phase = CONVERGE_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.phaseTimeout = i.alarmAfterSynchrony() i.resetRebroadcastParams() @@ -599,6 +604,7 @@ func (i *instance) tryConverge() error { func (i *instance) beginPrepare(justification *Justification) { // Broadcast preparation of value and wait for everyone to respond. i.phase = PREPARE_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.phaseTimeout = i.alarmAfterSynchrony() i.resetRebroadcastParams() @@ -639,6 +645,7 @@ func (i *instance) tryPrepare() error { func (i *instance) beginCommit() { i.phase = COMMIT_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.phaseTimeout = i.alarmAfterSynchrony() i.resetRebroadcastParams() @@ -715,6 +722,7 @@ func (i *instance) tryCommit(round uint64) error { func (i *instance) beginDecide(round uint64) { i.phase = DECIDE_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.resetRebroadcastParams() var justification *Justification // Value cannot be empty here. @@ -740,10 +748,12 @@ func (i *instance) beginDecide(round uint64) { // The provided justification must justify the value being decided. 
func (i *instance) skipToDecide(value ECChain, justification *Justification) { i.phase = DECIDE_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.proposal = value i.value = i.proposal i.resetRebroadcastParams() i.broadcast(0, DECIDE_PHASE, i.value, false, justification) + metrics.phaseCounter.Add(context.TODO(), 1, metric.WithAttributes(attrDecidePhase)) metrics.currentPhase.Record(context.TODO(), int64(DECIDE_PHASE)) metrics.skipCounter.Add(context.TODO(), 1, metric.WithAttributes(attrSkipToDecide)) @@ -844,9 +854,11 @@ func (i *instance) addCandidate(c ECChain) bool { func (i *instance) terminate(decision *Justification) { i.log("✅ terminated %s during round %d", &i.value, i.round) i.phase = TERMINATED_PHASE + i.progress.NotifyProgress(i.instanceID, i.round, i.phase) i.value = decision.Vote.Value i.terminationValue = decision i.resetRebroadcastParams() + metrics.phaseCounter.Add(context.TODO(), 1, metric.WithAttributes(attrTerminatedPhase)) metrics.roundHistogram.Record(context.TODO(), int64(i.round)) metrics.currentPhase.Record(context.TODO(), int64(TERMINATED_PHASE)) diff --git a/gpbft/gpbft_test.go b/gpbft/gpbft_test.go index 4dd9d5f1..9fdd0458 100644 --- a/gpbft/gpbft_test.go +++ b/gpbft/gpbft_test.go @@ -281,7 +281,7 @@ func TestGPBFT_WithEvenPowerDistribution(t *testing.T) { t.Run("Queues future instance messages during current instance", func(t *testing.T) { instance, driver := newInstanceAndDriver(t) futureInstance := emulator.NewInstance(t, - 42, + 8, gpbft.PowerEntries{ gpbft.PowerEntry{ ID: 0, diff --git a/gpbft/options.go b/gpbft/options.go index c04fc7ea..a7b7a14e 100644 --- a/gpbft/options.go +++ b/gpbft/options.go @@ -8,11 +8,12 @@ import ( "time" ) -var ( +const ( defaultDelta = 3 * time.Second defaultDeltaBackOffExponent = 2.0 defaultMaxCachedInstances = 10 defaultMaxCachedMessagesPerInstance = 25_000 + defaultCommitteeLookback = 10 ) // Option represents a configurable parameter. @@ -22,6 +23,7 @@ type options struct { delta time.Duration deltaBackOffExponent float64 + committeeLookback uint64 maxLookaheadRounds uint64 rebroadcastAfter func(int) time.Duration @@ -36,6 +38,7 @@ func newOptions(o ...Option) (*options, error) { opts := &options{ delta: defaultDelta, deltaBackOffExponent: defaultDeltaBackOffExponent, + committeeLookback: defaultCommitteeLookback, rebroadcastAfter: defaultRebroadcastAfter, maxCachedInstances: defaultMaxCachedInstances, maxCachedMessagesPerInstance: defaultMaxCachedMessagesPerInstance, @@ -118,6 +121,15 @@ func WithMaxCachedMessagesPerInstance(v int) Option { } } +// WithCommitteeLookback sets the number of instances in the past from which the +// committee for the latest instance is derived. Defaults to 10 if unset. 
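+// The same lookback also bounds how far ahead of the current instance the
+// message validator will accept messages, since committees beyond it cannot
+// yet be derived.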
+func WithCommitteeLookback(lookback uint64) Option { + return func(o *options) error { + o.committeeLookback = lookback + return nil + } +} + var defaultRebroadcastAfter = exponentialBackoffer(1.3, 0.1, 3*time.Second, 30*time.Second) // WithRebroadcastBackoff sets the duration after the gPBFT timeout has elapsed, at diff --git a/gpbft/participant.go b/gpbft/participant.go index 147e4b5b..5b6436c4 100644 --- a/gpbft/participant.go +++ b/gpbft/participant.go @@ -1,17 +1,14 @@ package gpbft import ( - "bytes" "context" "errors" "fmt" - "math" "runtime/debug" "sort" "sync" "time" - "github.com/filecoin-project/go-f3/internal/caching" logging "github.com/ipfs/go-log/v2" "go.opentelemetry.io/otel/metric" ) @@ -51,10 +48,7 @@ type Participant struct { // protocol round for which a strong quorum of COMMIT messages was observed, // which may not be known to the participant. terminatedDuringRound uint64 - - // validationCache is a bounded cache of messages that have already been - // validated by the participant, grouped by instance. - validationCache *caching.GroupedSet + validator *cachingValidator } type validatedMessage struct { @@ -88,12 +82,13 @@ func NewParticipant(host Host, o ...Option) (*Participant, error) { if err != nil { return nil, err } + ccp := newCachedCommitteeProvider(host) return &Participant{ options: opts, host: host, - committeeProvider: newCachedCommitteeProvider(host), + committeeProvider: ccp, mqueue: newMessageQueue(opts.maxLookaheadRounds), - validationCache: caching.NewGroupedSet(opts.maxCachedInstances, opts.maxCachedMessagesPerInstance), + validator: newValidator(host, ccp, opts.committeeLookback, opts.maxCachedInstances, opts.maxCachedMessagesPerInstance), }, nil } @@ -147,223 +142,12 @@ func (p *Participant) CurrentInstance() uint64 { // returned. ErrValidationInvalid indicates that the message will never be valid // invalid and may be safely dropped. func (p *Participant) ValidateMessage(msg *GMessage) (valid ValidatedMessage, err error) { - // This method is not protected by the API mutex, it is intended for concurrent use. - // The instance mutex is taken when appropriate by inner methods. defer func() { if r := recover(); r != nil { err = newPanicError(r) } - if err != nil { - metrics.errorCounter.Add(context.TODO(), 1, metric.WithAttributes(metricAttributeFromError(err))) - } }() - - comt, err := p.fetchCommittee(msg.Vote.Instance, msg.Vote.Phase) - if err != nil { - return nil, err - } - - // TODO: Refactor validation into its own struct such that it encapsulates - // caching, metrics etc. - // See: https://github.com/filecoin-project/go-f3/issues/561 - - var buf bytes.Buffer - var cacheMessage bool - if err := msg.MarshalCBOR(&buf); err != nil { - log.Errorw("failed to marshal message for caching", "err", err) - } else if alreadyValidated, err := p.validationCache.Contains(msg.Vote.Instance, messageCacheNamespace, buf.Bytes()); err != nil { - log.Errorw("failed to check already validated messages", "err", err) - } else if alreadyValidated { - metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheHit, attrCacheKindMessage)) - return &validatedMessage{msg: msg}, nil - } else { - cacheMessage = true - metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheMiss, attrCacheKindMessage)) - } - - // Check sender is eligible. 
- senderPower, senderPubKey := comt.PowerTable.Get(msg.Sender) - if senderPower == 0 { - return nil, fmt.Errorf("sender %d with zero power or not in power table: %w", msg.Sender, ErrValidationInvalid) - } - - // Check that message value is a valid chain. - if err := msg.Vote.Value.Validate(); err != nil { - return nil, fmt.Errorf("invalid message vote value chain: %w: %w", err, ErrValidationInvalid) - } - - // Check phase-specific constraints. - switch msg.Vote.Phase { - case QUALITY_PHASE: - if msg.Vote.Round != 0 { - return nil, fmt.Errorf("unexpected round %d for quality phase: %w", msg.Vote.Round, ErrValidationInvalid) - } - if msg.Vote.Value.IsZero() { - return nil, fmt.Errorf("unexpected zero value for quality phase: %w", ErrValidationInvalid) - } - case CONVERGE_PHASE: - if msg.Vote.Round == 0 { - return nil, fmt.Errorf("unexpected round 0 for converge phase: %w", ErrValidationInvalid) - } - if msg.Vote.Value.IsZero() { - return nil, fmt.Errorf("unexpected zero value for converge phase: %w", ErrValidationInvalid) - } - if !VerifyTicket(p.host.NetworkName(), comt.Beacon, msg.Vote.Instance, msg.Vote.Round, senderPubKey, p.host, msg.Ticket) { - return nil, fmt.Errorf("failed to verify ticket from %v: %w", msg.Sender, ErrValidationInvalid) - } - case DECIDE_PHASE: - if msg.Vote.Round != 0 { - return nil, fmt.Errorf("unexpected non-zero round %d for decide phase: %w", msg.Vote.Round, ErrValidationInvalid) - } - if msg.Vote.Value.IsZero() { - return nil, fmt.Errorf("unexpected zero value for decide phase: %w", ErrValidationInvalid) - } - case PREPARE_PHASE, COMMIT_PHASE: - // No additional checks for PREPARE and COMMIT. - default: - return nil, fmt.Errorf("invalid vote phase: %d: %w", msg.Vote.Phase, ErrValidationInvalid) - } - - // Check vote signature. - sigPayload := p.host.MarshalPayloadForSigning(p.host.NetworkName(), &msg.Vote) - if err := p.host.Verify(senderPubKey, sigPayload, msg.Signature); err != nil { - return nil, fmt.Errorf("invalid signature on %v, %v: %w", msg, err, ErrValidationInvalid) - } - - // Check justification - needsJustification := !(msg.Vote.Phase == QUALITY_PHASE || - (msg.Vote.Phase == PREPARE_PHASE && msg.Vote.Round == 0) || - (msg.Vote.Phase == COMMIT_PHASE && msg.Vote.Value.IsZero())) - - if needsJustification { - if err := p.validateJustification(msg, comt); err != nil { - return nil, fmt.Errorf("%v: %w", err, ErrValidationInvalid) - } - } else if msg.Justification != nil { - return nil, fmt.Errorf("message %v has unexpected justification: %w", msg, ErrValidationInvalid) - } - - if cacheMessage { - if _, err := p.validationCache.Add(msg.Vote.Instance, messageCacheNamespace, buf.Bytes()); err != nil { - log.Warnw("failed to cache to already validated message", "err", err) - } - } - return &validatedMessage{msg: msg}, nil -} - -func (p *Participant) validateJustification(msg *GMessage, comt *Committee) error { - - if msg.Justification == nil { - return fmt.Errorf("message for phase %v round %v has no justification", msg.Vote.Phase, msg.Vote.Round) - } - - // Only cache the justification if: - // * marshalling it was successful - // * it is not already present in the cache. 
- var cacheJustification bool - var buf bytes.Buffer - if err := msg.Justification.MarshalCBOR(&buf); err != nil { - log.Errorw("failed to marshal justification for caching", "err", err) - } else if alreadyValidated, err := p.validationCache.Contains(msg.Vote.Instance, justificationCacheNamespace, buf.Bytes()); err != nil { - log.Warnw("failed to check if justification is already cached", "err", err) - } else if alreadyValidated { - metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheHit, attrCacheKindJustification)) - return nil - } else { - cacheJustification = true - metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheMiss, attrCacheKindJustification)) - } - - // Check that the justification is for the same instance. - if msg.Vote.Instance != msg.Justification.Vote.Instance { - return fmt.Errorf("message with instanceID %v has evidence from instanceID: %v", msg.Vote.Instance, msg.Justification.Vote.Instance) - } - if !msg.Vote.SupplementalData.Eq(&msg.Justification.Vote.SupplementalData) { - return fmt.Errorf("message and justification have inconsistent supplemental data: %v != %v", msg.Vote.SupplementalData, msg.Justification.Vote.SupplementalData) - } - // Check that justification vote value is a valid chain. - if err := msg.Justification.Vote.Value.Validate(); err != nil { - return fmt.Errorf("invalid justification vote value chain: %w", err) - } - - // Check every remaining field of the justification, according to the phase requirements. - // This map goes from the message phase to the expected justification phase(s), - // to the required vote values for justification by that phase. - // Anything else is disallowed. - expectations := map[Phase]map[Phase]struct { - Round uint64 - Value ECChain - }{ - // CONVERGE is justified by a strong quorum of COMMIT for bottom, - // or a strong quorum of PREPARE for the same value, from the previous round. - CONVERGE_PHASE: { - COMMIT_PHASE: {msg.Vote.Round - 1, ECChain{}}, - PREPARE_PHASE: {msg.Vote.Round - 1, msg.Vote.Value}, - }, - // PREPARE is justified by the same rules as CONVERGE (in rounds > 0). - PREPARE_PHASE: { - COMMIT_PHASE: {msg.Vote.Round - 1, ECChain{}}, - PREPARE_PHASE: {msg.Vote.Round - 1, msg.Vote.Value}, - }, - // COMMIT is justified by strong quorum of PREPARE from the same round with the same value. - COMMIT_PHASE: { - PREPARE_PHASE: {msg.Vote.Round, msg.Vote.Value}, - }, - // DECIDE is justified by strong quorum of COMMIT with the same value. - // The DECIDE message doesn't specify a round. - DECIDE_PHASE: { - COMMIT_PHASE: {math.MaxUint64, msg.Vote.Value}, - }, - } - - if expectedPhases, ok := expectations[msg.Vote.Phase]; ok { - if expected, ok := expectedPhases[msg.Justification.Vote.Phase]; ok { - if msg.Justification.Vote.Round != expected.Round && expected.Round != math.MaxUint64 { - return fmt.Errorf("message %v has justification from wrong round %d", msg, msg.Justification.Vote.Round) - } - if !msg.Justification.Vote.Value.Eq(expected.Value) { - return fmt.Errorf("message %v has justification for a different value: %v", msg, msg.Justification.Vote.Value) - } - } else { - return fmt.Errorf("message %v has justification with unexpected phase: %v", msg, msg.Justification.Vote.Phase) - } - } else { - return fmt.Errorf("message %v has unexpected phase for justification", msg) - } - - // Check justification power and signature. 
- var justificationPower int64 - signers := make([]int, 0) - if err := msg.Justification.Signers.ForEach(func(bit uint64) error { - if int(bit) >= len(comt.PowerTable.Entries) { - return fmt.Errorf("invalid signer index: %d", bit) - } - power := comt.PowerTable.ScaledPower[bit] - if power == 0 { - return fmt.Errorf("signer with ID %d has no power", comt.PowerTable.Entries[bit].ID) - } - justificationPower += power - signers = append(signers, int(bit)) - return nil - }); err != nil { - return fmt.Errorf("failed to iterate over signers: %w", err) - } - - if !IsStrongQuorum(justificationPower, comt.PowerTable.ScaledTotal) { - return fmt.Errorf("message %v has justification with insufficient power: %v", msg, justificationPower) - } - - payload := p.host.MarshalPayloadForSigning(p.host.NetworkName(), &msg.Justification.Vote) - if err := comt.AggregateVerifier.VerifyAggregate(signers, payload, msg.Justification.Signature); err != nil { - return fmt.Errorf("verification of the aggregate failed: %+v: %w", msg.Justification, err) - } - - if cacheJustification { - if _, err := p.validationCache.Add(msg.Vote.Instance, justificationCacheNamespace, buf.Bytes()); err != nil { - log.Warnw("failed to cache to already validated justification", "err", err) - } - } - return nil + return p.validator.ValidateMessage(msg) } // Receives a validated Granite message from some other participant. @@ -441,11 +225,11 @@ func (p *Participant) beginInstance() error { return fmt.Errorf("invalid canonical chain: %w", err) } - comt, err := p.fetchCommittee(p.currentInstance, INITIAL_PHASE) + comt, err := p.committeeProvider.GetCommittee(p.currentInstance) if err != nil { return err } - if p.gpbft, err = newInstance(p, p.currentInstance, chain, data, comt.PowerTable, comt.AggregateVerifier, comt.Beacon); err != nil { + if p.gpbft, err = newInstance(p, p.currentInstance, chain, data, comt.PowerTable, comt.AggregateVerifier, comt.Beacon, p.validator); err != nil { return fmt.Errorf("failed creating new gpbft instance: %w", err) } if err := p.gpbft.Start(); err != nil { @@ -465,25 +249,6 @@ func (p *Participant) beginInstance() error { return nil } -// Fetches the committee against which to validate messages for some instance. -func (p *Participant) fetchCommittee(instance uint64, phase Phase) (*Committee, error) { - p.instanceMutex.Lock() - defer p.instanceMutex.Unlock() - - switch { - // Accept all messages from the current and future instances. - case instance >= p.currentInstance: - // Accept messages from the previous instance, but only for decide messages. - case instance == p.currentInstance-1 && phase == DECIDE_PHASE: - // Reject all others as too old. 
- default: - return nil, fmt.Errorf("instance %d, current %d: %w", - instance, p.currentInstance, ErrValidationTooOld) - } - - return p.committeeProvider.GetCommittee(instance) -} - func (p *Participant) handleDecision() { if !p.terminated() { return @@ -504,9 +269,9 @@ func (p *Participant) finishCurrentInstance() *Justification { if p.gpbft != nil { decision = p.gpbft.terminationValue p.terminatedDuringRound = p.gpbft.round - p.validationCache.RemoveGroup(p.gpbft.instanceID) } p.gpbft = nil + p.validator.NotifyProgress(p.currentInstance, 0, INITIAL_PHASE) return decision } @@ -525,6 +290,7 @@ func (p *Participant) beginNextInstance(nextInstance uint64) { p.committeeProvider.EvictCommitteesBefore(nextInstance - 1) } p.currentInstance = nextInstance + p.validator.NotifyProgress(p.currentInstance, 0, INITIAL_PHASE) } func (p *Participant) terminated() bool { diff --git a/gpbft/participant_test.go b/gpbft/participant_test.go index 6206f7fb..d7d42dc0 100644 --- a/gpbft/participant_test.go +++ b/gpbft/participant_test.go @@ -82,6 +82,9 @@ func newParticipantTestSubject(t *testing.T, seed int64, instance uint64) *parti })) subject.host = gpbft.NewMockHost(t) + // Expect ad-hoc calls to getting network name as such calls bear no significance + // to correctness. + subject.host.On("NetworkName").Return(subject.networkName).Maybe() subject.Participant, err = gpbft.NewParticipant(subject.host, gpbft.WithTracer(subject), gpbft.WithDelta(delta), @@ -100,7 +103,6 @@ func (pt *participantTestSubject) expectBeginInstance() { pt.host.On("GetProposal", pt.instance).Return(pt.supplementalData, pt.canonicalChain, nil) pt.host.On("GetCommittee", pt.instance).Return(&gpbft.Committee{PowerTable: pt.powerTable, Beacon: pt.beacon}, nil).Once() pt.host.On("Time").Return(pt.time) - pt.host.On("NetworkName").Return(pt.networkName).Maybe() // We need to use `Maybe` here because `MarshalPayloadForSigning` may be called // an additional time for verification. // Without the `Maybe` the tests immediately fails here: @@ -517,7 +519,7 @@ func TestParticipant_ValidateMessage(t *testing.T) { }, } }, - wantErr: "committee not available", + wantErr: gpbft.ErrValidationNoCommittee.Error(), }, { name: "zero message is error", diff --git a/gpbft/validator.go b/gpbft/validator.go new file mode 100644 index 00000000..07d22156 --- /dev/null +++ b/gpbft/validator.go @@ -0,0 +1,294 @@ +package gpbft + +import ( + "bytes" + "context" + "fmt" + "math" + "sync/atomic" + + "github.com/filecoin-project/go-f3/internal/caching" + "go.opentelemetry.io/otel/metric" +) + +var ( + _ MessageValidator = (*cachingValidator)(nil) + _ ProgressObserver = (*cachingValidator)(nil) +) + +// ProgressObserver defines an interface for observing and being notified about +// the progress of a GPBFT instance as it advances through different instance, +// rounds or phases. +type ProgressObserver interface { + // NotifyProgress is called to notify the observer about the progress of GPBFT + // instance, round or phase. + NotifyProgress(instance, round uint64, phase Phase) +} + +type cachingValidator struct { + progress atomic.Pointer[progress] + // cache is a bounded cache of messages that have already been validated, grouped + // by instance. 
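+	// Entries for instances older than the previous instance are evicted as the
+	// validator observes progress via NotifyProgress.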
+ cache *caching.GroupedSet + committeeLookback uint64 + committeeProvider CommitteeProvider + networkName NetworkName + signing Signatures +} + +type progress struct { + id uint64 + round uint64 + phase Phase +} + +func newValidator(host Host, cp CommitteeProvider, committeeLookback uint64, maxCachedInstances int, maxCachedMsgsPerInstance int) *cachingValidator { + validator := cachingValidator{ + cache: caching.NewGroupedSet(maxCachedInstances, maxCachedMsgsPerInstance), + committeeProvider: cp, + committeeLookback: committeeLookback, + networkName: host.NetworkName(), + signing: host, + } + // Default to instance 0, round 0 and INITIAL phase. + validator.progress.Store(&progress{}) + return &validator +} + +func (v *cachingValidator) NotifyProgress(instance uint64, round uint64, phase Phase) { + v.progress.Store(&progress{ + id: instance, + round: round, + phase: phase, + }) + if instance > 0 { + // Remove cache of validated messages for instance that are older than the + // previous instance. + v.cache.RemoveGroupsLessThan(instance - 1) + } +} + +// ValidateMessage checks if the given message is valid. If invalid, an error is +// returned. ErrValidationInvalid indicates that the message will never be valid +// invalid and may be safely dropped. +func (v *cachingValidator) ValidateMessage(msg *GMessage) (valid ValidatedMessage, err error) { + if msg == nil { + return nil, ErrValidationInvalid + } + + // Infer whether to proceed validating the message relative to the current instance. + switch currentInstance := v.progress.Load().id; { + case msg.Vote.Instance >= currentInstance+v.committeeLookback: + // Message is beyond current + committee lookback. + return nil, ErrValidationNoCommittee + case msg.Vote.Instance >= currentInstance, + msg.Vote.Instance == currentInstance-1 && msg.Vote.Phase == DECIDE_PHASE: + // Only proceed to validate the message if it: + // * belongs to an instance within the range of current to current + committee lookback, or + // * is a DECIDE message belonging to previous instance. + default: + // Message belongs to an instance older than the previous instance. + return nil, ErrValidationTooOld + } + + var buf bytes.Buffer + var cacheMessage bool + if err := msg.MarshalCBOR(&buf); err != nil { + log.Errorw("failed to marshal message for caching", "err", err) + } else if alreadyValidated, err := v.cache.Contains(msg.Vote.Instance, messageCacheNamespace, buf.Bytes()); err != nil { + log.Errorw("failed to check already validated messages", "err", err) + } else if alreadyValidated { + metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheHit, attrCacheKindMessage)) + return &validatedMessage{msg: msg}, nil + } else { + cacheMessage = true + metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheMiss, attrCacheKindMessage)) + } + + comt, err := v.committeeProvider.GetCommittee(msg.Vote.Instance) + if err != nil { + return nil, ErrValidationNoCommittee + } + // Check sender is eligible. + senderPower, senderPubKey := comt.PowerTable.Get(msg.Sender) + if senderPower == 0 { + return nil, fmt.Errorf("sender %d with zero power or not in power table: %w", msg.Sender, ErrValidationInvalid) + } + + // Check that message value is a valid chain. + if err := msg.Vote.Value.Validate(); err != nil { + return nil, fmt.Errorf("invalid message vote value chain: %w: %w", err, ErrValidationInvalid) + } + + // Check phase-specific constraints. 
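+	// QUALITY and DECIDE messages must be for round zero and carry a non-zero
+	// value; CONVERGE must be for a non-zero round, carry a non-zero value and a
+	// verifiable ticket; PREPARE and COMMIT need no extra checks here.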
+ switch msg.Vote.Phase { + case QUALITY_PHASE: + if msg.Vote.Round != 0 { + return nil, fmt.Errorf("unexpected round %d for quality phase: %w", msg.Vote.Round, ErrValidationInvalid) + } + if msg.Vote.Value.IsZero() { + return nil, fmt.Errorf("unexpected zero value for quality phase: %w", ErrValidationInvalid) + } + case CONVERGE_PHASE: + if msg.Vote.Round == 0 { + return nil, fmt.Errorf("unexpected round 0 for converge phase: %w", ErrValidationInvalid) + } + if msg.Vote.Value.IsZero() { + return nil, fmt.Errorf("unexpected zero value for converge phase: %w", ErrValidationInvalid) + } + if !VerifyTicket(v.networkName, comt.Beacon, msg.Vote.Instance, msg.Vote.Round, senderPubKey, v.signing, msg.Ticket) { + return nil, fmt.Errorf("failed to verify ticket from %v: %w", msg.Sender, ErrValidationInvalid) + } + case DECIDE_PHASE: + if msg.Vote.Round != 0 { + return nil, fmt.Errorf("unexpected non-zero round %d for decide phase: %w", msg.Vote.Round, ErrValidationInvalid) + } + if msg.Vote.Value.IsZero() { + return nil, fmt.Errorf("unexpected zero value for decide phase: %w", ErrValidationInvalid) + } + case PREPARE_PHASE, COMMIT_PHASE: + // No additional checks for PREPARE and COMMIT. + default: + return nil, fmt.Errorf("invalid vote phase: %d: %w", msg.Vote.Phase, ErrValidationInvalid) + } + + // Check vote signature. + sigPayload := v.signing.MarshalPayloadForSigning(v.networkName, &msg.Vote) + if err := v.signing.Verify(senderPubKey, sigPayload, msg.Signature); err != nil { + return nil, fmt.Errorf("invalid signature on %v, %v: %w", msg, err, ErrValidationInvalid) + } + + // Check justification. + needsJustification := !(msg.Vote.Phase == QUALITY_PHASE || + (msg.Vote.Phase == PREPARE_PHASE && msg.Vote.Round == 0) || + (msg.Vote.Phase == COMMIT_PHASE && msg.Vote.Value.IsZero())) + + if needsJustification { + if err := v.validateJustification(msg, comt); err != nil { + return nil, fmt.Errorf("%v: %w", err, ErrValidationInvalid) + } + } else if msg.Justification != nil { + return nil, fmt.Errorf("message %v has unexpected justification: %w", msg, ErrValidationInvalid) + } + + if cacheMessage { + if _, err := v.cache.Add(msg.Vote.Instance, messageCacheNamespace, buf.Bytes()); err != nil { + log.Warnw("failed to cache to already validated message", "err", err) + } + } + return &validatedMessage{msg: msg}, nil +} + +func (v *cachingValidator) validateJustification(msg *GMessage, comt *Committee) error { + if msg.Justification == nil { + return fmt.Errorf("message for phase %v round %v has no justification", msg.Vote.Phase, msg.Vote.Round) + } + + // Only cache the justification if: + // * marshalling it was successful, and + // * it is not already present in the cache. + var cacheJustification bool + var buf bytes.Buffer + if err := msg.Justification.MarshalCBOR(&buf); err != nil { + log.Errorw("failed to marshal justification for caching", "err", err) + } else if alreadyValidated, err := v.cache.Contains(msg.Vote.Instance, justificationCacheNamespace, buf.Bytes()); err != nil { + log.Warnw("failed to check if justification is already cached", "err", err) + } else if alreadyValidated { + metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheHit, attrCacheKindJustification)) + return nil + } else { + cacheJustification = true + metrics.validationCache.Add(context.TODO(), 1, metric.WithAttributes(attrCacheMiss, attrCacheKindJustification)) + } + + // Check that the justification is for the same instance. 
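+	// It must also carry the same supplemental data as the message and a
+	// structurally valid vote value.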
+ if msg.Vote.Instance != msg.Justification.Vote.Instance { + return fmt.Errorf("message with instanceID %v has evidence from instanceID: %v", msg.Vote.Instance, msg.Justification.Vote.Instance) + } + if !msg.Vote.SupplementalData.Eq(&msg.Justification.Vote.SupplementalData) { + return fmt.Errorf("message and justification have inconsistent supplemental data: %v != %v", msg.Vote.SupplementalData, msg.Justification.Vote.SupplementalData) + } + // Check that justification vote value is a valid chain. + if err := msg.Justification.Vote.Value.Validate(); err != nil { + return fmt.Errorf("invalid justification vote value chain: %w", err) + } + + // Check every remaining field of the justification, according to the phase requirements. + // This map goes from the message phase to the expected justification phase(s), + // to the required vote values for justification by that phase. + // Anything else is disallowed. + expectations := map[Phase]map[Phase]struct { + Round uint64 + Value ECChain + }{ + // CONVERGE is justified by a strong quorum of COMMIT for bottom, + // or a strong quorum of PREPARE for the same value, from the previous round. + CONVERGE_PHASE: { + COMMIT_PHASE: {msg.Vote.Round - 1, ECChain{}}, + PREPARE_PHASE: {msg.Vote.Round - 1, msg.Vote.Value}, + }, + // PREPARE is justified by the same rules as CONVERGE (in rounds > 0). + PREPARE_PHASE: { + COMMIT_PHASE: {msg.Vote.Round - 1, ECChain{}}, + PREPARE_PHASE: {msg.Vote.Round - 1, msg.Vote.Value}, + }, + // COMMIT is justified by strong quorum of PREPARE from the same round with the same value. + COMMIT_PHASE: { + PREPARE_PHASE: {msg.Vote.Round, msg.Vote.Value}, + }, + // DECIDE is justified by strong quorum of COMMIT with the same value. + // The DECIDE message doesn't specify a round. + DECIDE_PHASE: { + COMMIT_PHASE: {math.MaxUint64, msg.Vote.Value}, + }, + } + + if expectedPhases, ok := expectations[msg.Vote.Phase]; ok { + if expected, ok := expectedPhases[msg.Justification.Vote.Phase]; ok { + if msg.Justification.Vote.Round != expected.Round && expected.Round != math.MaxUint64 { + return fmt.Errorf("message %v has justification from wrong round %d", msg, msg.Justification.Vote.Round) + } + if !msg.Justification.Vote.Value.Eq(expected.Value) { + return fmt.Errorf("message %v has justification for a different value: %v", msg, msg.Justification.Vote.Value) + } + } else { + return fmt.Errorf("message %v has justification with unexpected phase: %v", msg, msg.Justification.Vote.Phase) + } + } else { + return fmt.Errorf("message %v has unexpected phase for justification", msg) + } + + // Check justification power and signature. 
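+	// Signers must together hold a strong quorum of the committee's scaled power,
+	// and the aggregate signature must verify against the justification vote.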
+ var justificationPower int64 + signers := make([]int, 0) + if err := msg.Justification.Signers.ForEach(func(bit uint64) error { + if int(bit) >= len(comt.PowerTable.Entries) { + return fmt.Errorf("invalid signer index: %d", bit) + } + power := comt.PowerTable.ScaledPower[bit] + if power == 0 { + return fmt.Errorf("signer with ID %d has no power", comt.PowerTable.Entries[bit].ID) + } + justificationPower += power + signers = append(signers, int(bit)) + return nil + }); err != nil { + return fmt.Errorf("failed to iterate over signers: %w", err) + } + + if !IsStrongQuorum(justificationPower, comt.PowerTable.ScaledTotal) { + return fmt.Errorf("message %v has justification with insufficient power: %v", msg, justificationPower) + } + + payload := v.signing.MarshalPayloadForSigning(v.networkName, &msg.Justification.Vote) + if err := comt.AggregateVerifier.VerifyAggregate(signers, payload, msg.Justification.Signature); err != nil { + return fmt.Errorf("verification of the aggregate failed: %+v: %w", msg.Justification, err) + } + + if cacheJustification { + if _, err := v.cache.Add(msg.Vote.Instance, justificationCacheNamespace, buf.Bytes()); err != nil { + log.Warnw("failed to cache to already validated justification", "err", err) + } + } + return nil +} diff --git a/internal/caching/grouped_cache.go b/internal/caching/grouped_cache.go index e67741c9..16c94ded 100644 --- a/internal/caching/grouped_cache.go +++ b/internal/caching/grouped_cache.go @@ -79,10 +79,16 @@ func (gs *GroupedSet) Add(g uint64, namespace, v []byte) (bool, error) { return !contained, nil } -func (gs *GroupedSet) RemoveGroup(group uint64) bool { +func (gs *GroupedSet) RemoveGroupsLessThan(group uint64) bool { gs.mu.Lock() defer gs.mu.Unlock() - return gs.evict(group) + var evictedAtLeastOne bool + for g := range gs.groups { + if g < group { + evictedAtLeastOne = gs.evict(g) || evictedAtLeastOne + } + } + return evictedAtLeastOne } func (gs *GroupedSet) evict(group uint64) bool { diff --git a/internal/caching/grouped_cache_test.go b/internal/caching/grouped_cache_test.go index 5d2af6e3..19665c92 100644 --- a/internal/caching/grouped_cache_test.go +++ b/internal/caching/grouped_cache_test.go @@ -101,10 +101,10 @@ func TestGroupedSet(t *testing.T) { t.Run("explicit group removal is removed", func(t *testing.T) { // Assert group 1 exists and removed - require.True(t, subject.RemoveGroup(1)) + require.True(t, subject.RemoveGroupsLessThan(2)) // Assert group 1 is already removed - require.False(t, subject.RemoveGroup(1)) + require.False(t, subject.RemoveGroupsLessThan(2)) contains, err := subject.Contains(g1v1()) require.NoError(t, err)
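
Usage note: a minimal caller-side sketch (not part of this patch) of how the new
committee lookback option and the validation error sentinels might be used by an
embedding application; host, msg and the package name are assumptions supplied
for illustration only.

package example // hypothetical embedder package

import (
	"errors"

	"github.com/filecoin-project/go-f3/gpbft"
)

// classifyIncoming sketches constructing a participant with the committee
// lookback option and triaging the errors returned by ValidateMessage.
func classifyIncoming(host gpbft.Host, msg *gpbft.GMessage) error {
	p, err := gpbft.NewParticipant(host,
		gpbft.WithCommitteeLookback(10), // also bounds how far ahead messages are accepted
	)
	if err != nil {
		return err
	}
	validated, err := p.ValidateMessage(msg)
	switch {
	case err == nil:
		_ = validated // hand the validated message to the participant's receive path
	case errors.Is(err, gpbft.ErrValidationInvalid), errors.Is(err, gpbft.ErrValidationTooOld):
		// Will never become valid (or is too old to matter); safe to drop.
	default:
		// e.g. gpbft.ErrValidationNoCommittee: instance too far ahead; may retry later.
	}
	return nil
}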