diff --git a/dot/parachain/collator-protocol/messages/protocol_messages.go b/dot/parachain/collator-protocol/messages/protocol_messages.go index 9747a4f94f..5c09cccf15 100644 --- a/dot/parachain/collator-protocol/messages/protocol_messages.go +++ b/dot/parachain/collator-protocol/messages/protocol_messages.go @@ -39,7 +39,7 @@ func (mvdt *CollationProtocol) SetValue(value any) (err error) { } } -func (mvdt CollationProtocol) IndexValue() (index uint, value any, err error) { +func (mvdt *CollationProtocol) IndexValue() (index uint, value any, err error) { switch mvdt.inner.(type) { case CollatorProtocolMessage: return 0, mvdt.inner, nil @@ -48,12 +48,12 @@ func (mvdt CollationProtocol) IndexValue() (index uint, value any, err error) { return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue } -func (mvdt CollationProtocol) Value() (value any, err error) { +func (mvdt *CollationProtocol) Value() (value any, err error) { _, value, err = mvdt.IndexValue() return } -func (mvdt CollationProtocol) ValueAt(index uint) (value any, err error) { +func (mvdt *CollationProtocol) ValueAt(index uint) (value any, err error) { switch index { case 0: return *new(CollatorProtocolMessage), nil @@ -68,7 +68,7 @@ func NewCollationProtocol() CollationProtocol { } type CollatorProtocolMessageValues interface { - Declare | AdvertiseCollation | CollationSeconded + Declare | AdvertiseCollation | AdvertiseCollationV2 | CollationSeconded } // CollatorProtocolMessage represents Network messages used by the collator protocol subsystem @@ -90,6 +90,10 @@ func (mvdt *CollatorProtocolMessage) SetValue(value any) (err error) { setCollatorProtocolMessage(mvdt, value) return + case AdvertiseCollationV2: + setCollatorProtocolMessage(mvdt, value) + return + case CollationSeconded: setCollatorProtocolMessage(mvdt, value) return @@ -99,7 +103,7 @@ func (mvdt *CollatorProtocolMessage) SetValue(value any) (err error) { } } -func (mvdt CollatorProtocolMessage) IndexValue() (index uint, value any, err error) { +func (mvdt *CollatorProtocolMessage) IndexValue() (index uint, value any, err error) { switch mvdt.inner.(type) { case Declare: return 0, mvdt.inner, nil @@ -107,6 +111,9 @@ func (mvdt CollatorProtocolMessage) IndexValue() (index uint, value any, err err case AdvertiseCollation: return 1, mvdt.inner, nil + case AdvertiseCollationV2: + return 2, mvdt.inner, nil + case CollationSeconded: return 4, mvdt.inner, nil @@ -114,12 +121,12 @@ func (mvdt CollatorProtocolMessage) IndexValue() (index uint, value any, err err return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue } -func (mvdt CollatorProtocolMessage) Value() (value any, err error) { +func (mvdt *CollatorProtocolMessage) Value() (value any, err error) { _, value, err = mvdt.IndexValue() return } -func (mvdt CollatorProtocolMessage) ValueAt(index uint) (value any, err error) { +func (mvdt *CollatorProtocolMessage) ValueAt(index uint) (value any, err error) { switch index { case 0: return *new(Declare), nil @@ -127,6 +134,9 @@ func (mvdt CollatorProtocolMessage) ValueAt(index uint) (value any, err error) { case 1: return *new(AdvertiseCollation), nil + case 2: + return *new(AdvertiseCollationV2), nil + case 4: return *new(CollationSeconded), nil @@ -153,6 +163,12 @@ type Declare struct { // It can only be sent once the peer has declared that they are a collator with given ID type AdvertiseCollation common.Hash +type AdvertiseCollationV2 struct { + RelayParent common.Hash `scale:"1"` + CandidateHash parachaintypes.CandidateHash `scale:"2"` + ParentHeadDataHash common.Hash `scale:"3"` +} + // CollationSeconded represents that a collation sent to a validator was seconded. type CollationSeconded struct { RelayParent common.Hash `scale:"1"` @@ -160,12 +176,12 @@ type CollationSeconded struct { } // Type returns CollationMsgType -func (CollationProtocol) Type() network.MessageType { +func (*CollationProtocol) Type() network.MessageType { return network.CollationMsgType } // Hash returns the hash of the CollationProtocolV1 -func (cp CollationProtocol) Hash() (common.Hash, error) { +func (cp *CollationProtocol) Hash() (common.Hash, error) { // scale encode each extrinsic encMsg, err := cp.Encode() if err != nil { @@ -176,7 +192,7 @@ func (cp CollationProtocol) Hash() (common.Hash, error) { } // Encode a collator protocol message using scale encode -func (cp CollationProtocol) Encode() ([]byte, error) { +func (cp *CollationProtocol) Encode() ([]byte, error) { enc, err := scale.Marshal(cp) if err != nil { return nil, err diff --git a/dot/parachain/collator-protocol/validator-side/collation_fetching.go b/dot/parachain/collator-protocol/validator-side/collation_fetching.go index 25cec0a8eb..6e079ba385 100644 --- a/dot/parachain/collator-protocol/validator-side/collation_fetching.go +++ b/dot/parachain/collator-protocol/validator-side/collation_fetching.go @@ -19,9 +19,11 @@ const ( collationFetchingMaxResponseSize = maxPoVSize + 10000 // 10MB ) +type CollationFetchingRequest = CollationFetchingRequestV1 + // CollationFetchingRequest represents a request to retrieve // the advertised collation at the specified relay chain block. -type CollationFetchingRequest struct { +type CollationFetchingRequestV1 struct { // Relay parent we want a collation for RelayParent common.Hash `scale:"1"` @@ -29,8 +31,23 @@ type CollationFetchingRequest struct { ParaID parachaintypes.ParaID `scale:"2"` } +// CollationFetchingRequestV2 represents the enhanced request format +// with candidate hash +type CollationFetchingRequestV2 struct { + // Relay parent we want a collation for + RelayParent common.Hash `scale:"1"` + // Parachain id of the collation + ParaID parachaintypes.ParaID `scale:"2"` + // Hash of the candidate we want a collation for + CandidateHash common.Hash `scale:"3"` +} + // Encode returns the SCALE encoding of the CollationFetchingRequest -func (c CollationFetchingRequest) Encode() ([]byte, error) { +func (c CollationFetchingRequestV1) Encode() ([]byte, error) { + return scale.Marshal(c) +} + +func (c CollationFetchingRequestV2) Encode() ([]byte, error) { return scale.Marshal(c) } @@ -58,7 +75,7 @@ func (mvdt *CollationFetchingResponse) SetValue(value any) (err error) { } } -func (mvdt CollationFetchingResponse) IndexValue() (index uint, value any, err error) { +func (mvdt *CollationFetchingResponse) IndexValue() (index uint, value any, err error) { switch mvdt.inner.(type) { case parachaintypes.Collation: return 0, mvdt.inner, nil @@ -67,12 +84,12 @@ func (mvdt CollationFetchingResponse) IndexValue() (index uint, value any, err e return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue } -func (mvdt CollationFetchingResponse) Value() (value any, err error) { +func (mvdt *CollationFetchingResponse) Value() (value any, err error) { _, value, err = mvdt.IndexValue() return } -func (mvdt CollationFetchingResponse) ValueAt(index uint) (value any, err error) { +func (mvdt *CollationFetchingResponse) ValueAt(index uint) (value any, err error) { switch index { case 0: return *new(parachaintypes.Collation), nil diff --git a/dot/parachain/collator-protocol/validator-side/message.go b/dot/parachain/collator-protocol/validator-side/message.go index e1b41dcd5e..3fde9b9653 100644 --- a/dot/parachain/collator-protocol/validator-side/message.go +++ b/dot/parachain/collator-protocol/validator-side/message.go @@ -17,7 +17,6 @@ import ( "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" - "github.com/ChainSafe/gossamer/pkg/scale" "github.com/libp2p/go-libp2p/core/peer" ) @@ -27,18 +26,6 @@ const ( CollationSeconded ) -//nolint:unused -func decodeCollationMessage(in []byte) (network.NotificationsMessage, error) { - collationMessage := collatorprotocolmessages.CollationProtocol{} - - err := scale.Unmarshal(in, &collationMessage) - if err != nil { - return nil, fmt.Errorf("cannot decode message: %w", err) - } - - return &collationMessage, nil -} - type ProspectiveCandidate struct { CandidateHash parachaintypes.CandidateHash ParentHeadDataHash common.Hash @@ -156,9 +143,17 @@ func (cpvs *CollatorProtocolValidatorSide) fetchCollation(pendingCollation Pendi return ErrNotAdvertised } - // TODO #4711 + // Convert parachaintypes.CandidateHash to *common.Hash for requestCollation + var candidateHashCommon *common.Hash + if candidateHash != nil { + candidateHashCommon = &candidateHash.Value // Extract the common.Hash from CandidateHash + } + // TODO: Add it to collation_fetch_timeouts if we can't process this in timeout time. + // state + // .collation_fetch_timeouts + // .push(timeout(id.clone(), candidate_hash, relay_parent).boxed()); collation, err := cpvs.requestCollation(pendingCollation.RelayParent, pendingCollation.ParaID, - pendingCollation.PeerID) + pendingCollation.PeerID, candidateHashCommon) if err != nil { return fmt.Errorf("requesting collation: %w", err) } @@ -423,8 +418,28 @@ func (cpvs *CollatorProtocolValidatorSide) processCollatorProtocolMessage(sender if err != nil { return fmt.Errorf("handling v1 advertisement: %w", err) } + // TODO: + // - tracks advertisements received and the source (peer id) of the advertisement + // - accept one advertisement per collator per source per relay-parent + case 2: // AdvertiseCollationV2 + advertiseCollationV2Message, ok := collatorProtocolMessageV.(collatorprotocolmessages.AdvertiseCollationV2) + if !ok { + return errors.New("expected message to be advertise collation v2") + } + prospectiveCandidate := &ProspectiveCandidate{ + CandidateHash: advertiseCollationV2Message.CandidateHash, + ParentHeadDataHash: advertiseCollationV2Message.ParentHeadDataHash, + } - case CollationSeconded: + err := cpvs.handleAdvertisement(advertiseCollationV2Message.RelayParent, sender, prospectiveCandidate) + if err != nil { + return fmt.Errorf("handling v2 advertisement: %w", err) + } + + logger.Debugf("Peer %s sent V2 advertisement, upgrading to ProtocolV2", sender) + cpvs.setPeerProtocolVersion(sender, ProtocolV2) + + case 4: // CollationSeconded logger.Errorf("unexpected collation seconded message from peer %s, decreasing its reputation", sender) cpvs.SubSystemToOverseer <- networkbridgemessages.ReportPeer{ PeerID: sender, @@ -438,20 +453,10 @@ func (cpvs *CollatorProtocolValidatorSide) processCollatorProtocolMessage(sender return nil } -//nolint:unused -func getCollatorHandshake() (network.Handshake, error) { - return &collatorHandshake{}, nil -} - func decodeCollatorHandshake(_ []byte) (network.Handshake, error) { return &collatorHandshake{}, nil } -//nolint:unused -func validateCollatorHandshake(_ peer.ID, _ network.Handshake) error { - return nil -} - type collatorHandshake struct{} // String formats a collatorHandshake as a string diff --git a/dot/parachain/collator-protocol/validator-side/validator_side.go b/dot/parachain/collator-protocol/validator-side/validator_side.go index 01be466b94..d635aabaae 100644 --- a/dot/parachain/collator-protocol/validator-side/validator_side.go +++ b/dot/parachain/collator-protocol/validator-side/validator_side.go @@ -75,13 +75,16 @@ func New(net Network, protocolID protocol.ID, overseerChan chan<- any, Keystore: ks, SubSystemToOverseer: overseerChan, collationFetchingReqResProtocol: collationFetchingReqResProtocol, + collationRequests: make(chan CollationRequestInfo, 100), peerData: make(map[peer.ID]PeerData), + peerVersions: make(map[peer.ID]PeerProtocolVersion), currentAssignments: make(map[parachaintypes.ParaID]uint), perRelayParent: make(map[common.Hash]PerRelayParent), BlockedAdvertisements: make(map[string][]blockedAdvertisement), implicitView: util.NewBackingImplicitView(blockState, nil), activeLeaves: make(map[common.Hash]parachaintypes.ProspectiveParachainsMode), fetchedCandidates: make(map[string]CollationEvent), + requestCompletions: make(chan string, 100), } } @@ -89,6 +92,10 @@ func (cpvs *CollatorProtocolValidatorSide) Run( ctx context.Context, overseerToSubSystem <-chan any) { inactivityTicker := time.NewTicker(activityPoll) + //Track active requests for timeout handling + requestCleanupTicker := time.NewTicker(10 * time.Millisecond) + activeRequests := make(map[string]CollationRequestInfo) + for { select { // TODO: #4697: use util.ReputationAggregator @@ -105,13 +112,47 @@ func (cpvs *CollatorProtocolValidatorSide) Run( case <-inactivityTicker.C: // TODO: disconnect inactive peers, Issue #4256 + case requestInfo := <-cpvs.collationRequests: + // For now, just log that we received a collation request + logger.Debugf("Tracking collation request: %s for para %d from peer %s", + requestInfo.RequestID, requestInfo.ParaID, requestInfo.PeerID) + activeRequests[requestInfo.RequestID] = requestInfo + + case <-requestCleanupTicker.C: + now := time.Now() + for requestID, requestInfo := range activeRequests { + // Check if request is older than maxUnsharedDownloadTime + if now.Sub(requestInfo.RequestTime) > maxUnsharedDownloadTime { + logger.Debugf("Request %s expired after %v, cancelling network request", + requestID, now.Sub(requestInfo.RequestTime)) + requestInfo.Cancel() + delete(activeRequests, requestID) + } + } + + case requestID := <-cpvs.requestCompletions: + // Remove completed request from tracking + delete(activeRequests, requestID) + logger.Debugf("Request %s completed successfully", requestID) + case unfetchedCollation := <-cpvs.unfetchedCollation: + // TODO: If we can't get the collation from given collator within MAX_UNSHARED_DOWNLOAD_TIME, + // we will start another one from the next collator. + var candidateHash *common.Hash + if unfetchedCollation.PendingCollation.ProspectiveCandidate != nil { + candidateHash = &unfetchedCollation.PendingCollation.ProspectiveCandidate.CandidateHash.Value + } + + var candidateHashParam *parachaintypes.CandidateHash + if candidateHash != nil { + candidateHashParam = ¶chaintypes.CandidateHash{Value: *candidateHash} + } // check if this peer id has advertised this relay parent peerData := cpvs.peerData[unfetchedCollation.PendingCollation.PeerID] - if peerData.HasAdvertised(unfetchedCollation.PendingCollation.RelayParent, nil) { + if peerData.HasAdvertised(unfetchedCollation.PendingCollation.RelayParent, candidateHashParam) { // if so request collation from this peer id collation, err := cpvs.requestCollation(unfetchedCollation.PendingCollation.RelayParent, - unfetchedCollation.PendingCollation.ParaID, unfetchedCollation.PendingCollation.PeerID) + unfetchedCollation.PendingCollation.ParaID, unfetchedCollation.PendingCollation.PeerID, candidateHash) if err != nil { logger.Errorf("fetching collation: %w", err) } @@ -288,6 +329,19 @@ func (cpvs *CollatorProtocolValidatorSide) assignIncoming(relayParent common.Has return nil } +func (cpvs *CollatorProtocolValidatorSide) getPeerProtocolVersion(peerID peer.ID) PeerProtocolVersion { + if version, exists := cpvs.peerVersions[peerID]; exists { + return version + } + // Default to V1 for backward compatibility + return ProtocolV1 +} + +// Add method to detect peer version during handshake or connection +func (cpvs *CollatorProtocolValidatorSide) setPeerProtocolVersion(peerID peer.ID, version PeerProtocolVersion) { + cpvs.peerVersions[peerID] = version +} + func findValidatorGroup(validatorIndex parachaintypes.ValidatorIndex, validatorGroups parachaintypes.ValidatorGroups, ) (parachaintypes.GroupIndex, bool) { for groupIndex, validatorGroup := range validatorGroups.Validators { @@ -352,23 +406,81 @@ func (*CollatorProtocolValidatorSide) Stop() { // - check if the requested collation is in our view // TODO: #4711 func (cpvs *CollatorProtocolValidatorSide) requestCollation(relayParent common.Hash, - paraID parachaintypes.ParaID, peerID peer.ID) (*parachaintypes.Collation, error) { + paraID parachaintypes.ParaID, peerID peer.ID, candidateHash *common.Hash) (*parachaintypes.Collation, error) { _, ok := cpvs.perRelayParent[relayParent] if !ok { return nil, ErrOutOfView } - // make collation fetching request - collationFetchingRequest := CollationFetchingRequest{ + ctx, cancel := context.WithTimeout(context.Background(), maxUnsharedDownloadTime) // MAX_UNSHARED_DOWNLOAD_TIME + defer cancel() + + requestInfo := CollationRequestInfo{ + PeerID: peerID, RelayParent: relayParent, ParaID: paraID, + RequestTime: time.Now(), + RequestID: fmt.Sprintf("%s-%d-%s", relayParent.String(), paraID, peerID.String()), + Cancel: cancel, + } + + // Try to send to channel (non-blocking) + select { + case cpvs.collationRequests <- requestInfo: + //Successfully sent + default: + // Channel full - cancel and return error + cancel() + return nil, fmt.Errorf("collation requests channel is full") + } + + peerVersion := cpvs.getPeerProtocolVersion(peerID) + + var requestMessage network.Message + + switch peerVersion { + case ProtocolV1: + requestMessage = CollationFetchingRequestV1{ + RelayParent: relayParent, + ParaID: paraID, + } + case ProtocolV2: + // For V2, we need the candidate hash - this should come from the advertisement + // For now, use zero hash as placeholder (you'll need to pass this as parameter) + if candidateHash != nil { + requestMessage = CollationFetchingRequestV2{ + RelayParent: relayParent, + ParaID: paraID, + CandidateHash: *candidateHash, // TODO: Get from advertisement + } + } else { + // Should this be an error instead? + return nil, fmt.Errorf("V2 peer requires candidate hash") + } + default: + // Fallback to V1 + requestMessage = CollationFetchingRequestV1{ + RelayParent: relayParent, + ParaID: paraID, + } } collationFetchingResponse := NewCollationFetchingResponse() - err := cpvs.collationFetchingReqResProtocol.Do(peerID, collationFetchingRequest, &collationFetchingResponse) - if err != nil { - return nil, fmt.Errorf("collation fetching request failed: %w", err) + + done := make(chan error, 1) + go func() { + err := cpvs.collationFetchingReqResProtocol.Do(peerID, requestMessage, &collationFetchingResponse) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return nil, fmt.Errorf("collation fetching request failed: %w", err) + } + case <-ctx.Done(): + return nil, fmt.Errorf("collation fetching request timed out after %v", maxUnsharedDownloadTime) } v, err := collationFetchingResponse.Value() @@ -380,6 +492,13 @@ func (cpvs *CollatorProtocolValidatorSide) requestCollation(relayParent common.H return nil, fmt.Errorf("collation fetching response value expected: CollationVDT, got: %T", v) } + // Try to notify completion (non-blocking) + select { + case cpvs.requestCompletions <- requestInfo.RequestID: + default: + // Channel full, but that's ok - cleanup will handle it + } + return &collation, nil } @@ -401,6 +520,13 @@ type PeerData struct { state PeerStateInfo } +type PeerProtocolVersion int + +const ( + ProtocolV1 PeerProtocolVersion = 1 + ProtocolV2 PeerProtocolVersion = 2 +) + func (peerData *PeerData) HasAdvertised( relayParent common.Hash, mayBeCandidateHash *parachaintypes.CandidateHash) bool { @@ -578,6 +704,14 @@ type CollatorProtocolValidatorSide struct { // track all active collators and their data peerData map[peer.ID]PeerData + // Track protocol versions for each peer + peerVersions map[peer.ID]PeerProtocolVersion + + // Channel that gets populated when new collation requests are sent + collationRequests chan CollationRequestInfo + + requestCompletions chan string + // Parachains we're currently assigned to. With async backing enabled // this includes assignments from the implicit view. currentAssignments map[parachaintypes.ParaID]uint @@ -614,6 +748,15 @@ type CollatorProtocolValidatorSide struct { fetchedCandidates map[string]CollationEvent } +type CollationRequestInfo struct { + PeerID peer.ID + RelayParent common.Hash + ParaID parachaintypes.ParaID + RequestTime time.Time + RequestID string + Cancel context.CancelFunc +} + // Identifier of a fetched collation type fetchedCollationInfo struct { // Candidate's relay parent @@ -718,6 +861,7 @@ func (cpvs *CollatorProtocolValidatorSide) handleNetworkBridgeEvents(msg any) er } case networkbridgeevents.PeerDisconnected: delete(cpvs.peerData, msg.PeerID) + delete(cpvs.peerVersions, msg.PeerID) case networkbridgeevents.NewGossipTopology: // NOTE: This won't happen case networkbridgeevents.PeerViewChange: diff --git a/dot/parachain/collator-protocol/validator-side/validator_side_test.go b/dot/parachain/collator-protocol/validator-side/validator_side_test.go index 135a6cf85e..5f7adebb41 100644 --- a/dot/parachain/collator-protocol/validator-side/validator_side_test.go +++ b/dot/parachain/collator-protocol/validator-side/validator_side_test.go @@ -6,6 +6,7 @@ package validatorside import ( "sync" "testing" + "time" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/parachain/backing" @@ -508,3 +509,174 @@ func TestPeerViewChange(t *testing.T) { _, ok := cpvs.peerData[peer.ID("peer1")].state.CollatingPeerState.advertisements[common.Hash{0x01}] require.False(t, ok) } + +type testRequestMaker struct { + delay time.Duration +} + +func (s *testRequestMaker) Do(peerID peer.ID, message network.Message, responseMessage network.ResponseMessage) error { + time.Sleep(s.delay) + return nil +} + +func TestRequestCollation_Timeout(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Use the existing MockNetwork + mockNet := NewMockNetwork(ctrl) + + // Create a mock RequestMaker that delays + mockRequestMaker := &testRequestMaker{ + delay: 2 * time.Second, + } + + // Set up the mock to return our slow RequestMaker + mockNet.EXPECT(). + GetRequestResponseProtocol(gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockRequestMaker). + AnyTimes() + + // Create the validator side with the mock + protocolID := "/test/collations/1" + subsystemToOverseer := make(chan any, 10) + + cpvs := New(mockNet, protocol.ID(protocolID), subsystemToOverseer, nil, nil) + + // Add a relay parent to pass the view check + cpvs.perRelayParent = map[common.Hash]PerRelayParent{ + {0x01}: {}, + } + + relayParent := common.Hash{0x01} + paraID := parachaintypes.ParaID(123) + peerID := peer.ID("test-peer") + + start := time.Now() + + // This should timeout after ~1 second + collation, err := cpvs.requestCollation(relayParent, paraID, peerID, nil) + + elapsed := time.Since(start) + + // Verify timeout behavior + require.Nil(t, collation) + require.Error(t, err) + require.Contains(t, err.Error(), "timed out") + require.Greater(t, elapsed, 90*time.Millisecond) + require.Less(t, elapsed, 1200*time.Millisecond) +} + +func TestRequestCollation_Success(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Use the existing MockNetwork + mockNet := NewMockNetwork(ctrl) + + // Create a mock RequestMaker that responds quickly + mockRequestMaker := &testRequestMaker{ + delay: 100 * time.Millisecond, // Much less than 1 second + } + + // Set up the mock to return our fast RequestMaker + mockNet.EXPECT(). + GetRequestResponseProtocol(gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockRequestMaker). + AnyTimes() + + // Create the validator side with the mock + protocolID := "/test/collations/1" + subsystemToOverseer := make(chan any, 10) + + cpvs := New(mockNet, protocol.ID(protocolID), subsystemToOverseer, nil, nil) + + // Add a relay parent to pass the view check + cpvs.perRelayParent = map[common.Hash]PerRelayParent{ + {0x01}: {}, + } + + relayParent := common.Hash{0x01} + paraID := parachaintypes.ParaID(123) + peerID := peer.ID("test-peer") + + start := time.Now() + _, err := cpvs.requestCollation(relayParent, paraID, peerID, nil) + elapsed := time.Since(start) + + // Test that it completed quickly (didn't timeout) + require.Less(t, elapsed, 500*time.Millisecond, "should complete quickly") + require.Error(t, err) + require.Contains(t, err.Error(), "getting value of collation fetching response") + require.NotContains(t, err.Error(), "timed out", "should not be a timeout error") +} + +func TestRequestCollation_OutOfView(t *testing.T) { + t.Parallel() + + cpvs := &CollatorProtocolValidatorSide{ + perRelayParent: map[common.Hash]PerRelayParent{}, // Empty - no relay parents + } + + relayParent := common.Hash{0x01} + paraID := parachaintypes.ParaID(123) + peerID := peer.ID("test-peer") + + collation, err := cpvs.requestCollation(relayParent, paraID, peerID, nil) + + require.Nil(t, collation) + require.Equal(t, ErrOutOfView, err) +} + +func TestPeerVersionManagement(t *testing.T) { + t.Parallel() + + cpvs := &CollatorProtocolValidatorSide{ + peerVersions: make(map[peer.ID]PeerProtocolVersion), + } + + peerID := peer.ID("test-peer") + + // Test default version + version := cpvs.getPeerProtocolVersion(peerID) + require.Equal(t, ProtocolV1, version, "New peer should default to V1") + + // Test setting V2 + cpvs.setPeerProtocolVersion(peerID, ProtocolV2) + version = cpvs.getPeerProtocolVersion(peerID) + require.Equal(t, ProtocolV2, version, "Peer should be upgraded to V2") +} + +func TestV2PeerRequiresCandidateHash(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockNet := NewMockNetwork(ctrl) + mockNet.EXPECT(). + GetRequestResponseProtocol(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&network.RequestResponseProtocol{}). + AnyTimes() + + cpvs := New(mockNet, protocol.ID("/test/collations/1"), make(chan any, 10), nil, nil) + + // Setup + relayParent := common.Hash{0x01} + cpvs.perRelayParent = map[common.Hash]PerRelayParent{relayParent: {}} + cpvs.collationRequests = make(chan CollationRequestInfo, 100) + + peerID := peer.ID("v2-peer") + + // Set peer to V2 + cpvs.setPeerProtocolVersion(peerID, ProtocolV2) + + // Test: V2 peer without candidate hash should fail + _, err := cpvs.requestCollation(relayParent, 123, peerID, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "V2 peer requires candidate hash") +}