diff --git a/cmd/bench/cmd/client.go b/cmd/bench/cmd/client.go index 4bd16441b..fc58d6211 100644 --- a/cmd/bench/cmd/client.go +++ b/cmd/bench/cmd/client.go @@ -115,7 +115,7 @@ func runClient() error { } return err } - rand.Read(txBytes) //nolint:gosec + rand.Read(txBytes) //nolint:gosec,staticcheck logger.Log(logging.LevelDebug, fmt.Sprintf("Submitting transaction #%d", i)) if err := client.SubmitTransaction(txBytes); err != nil { return err diff --git a/cmd/externalmodule-test-client/client.go b/cmd/externalmodule-test-client/client.go new file mode 100644 index 000000000..60ca2f879 --- /dev/null +++ b/cmd/externalmodule-test-client/client.go @@ -0,0 +1,54 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/filecoin-project/mir/pkg/externalmodule" + "github.com/filecoin-project/mir/stdevents" + "github.com/filecoin-project/mir/stdtypes" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + beroConn, err := externalmodule.Connect(ctx, "ws://localhost:8080/bero") + if err != nil { + panic(err) + } + + response, err := beroConn.Submit(ctx, stdtypes.ListOf( + stdevents.NewTestString("remote", "Ping"), + stdevents.NewRaw("remote", []byte{0, 1, 2, 3}), + )) + if err != nil { + panic(err) + } + + fmt.Printf("Bero received %d events in response.\n", response.Len()) + + cecoConn, err := externalmodule.Connect(ctx, "ws://localhost:8080/ceco") + if err != nil { + panic(err) + } + response, err = cecoConn.Submit(ctx, stdtypes.ListOf( + stdevents.NewTestString("remote", "Ping"), + stdevents.NewRaw("remote", []byte{0, 1, 2, 3}), + )) + if err != nil { + panic(err) + } + + fmt.Printf("Ceco received %d events in response.\n", response.Len()) + + err = beroConn.Close(ctx) + if err != nil { + panic(err) + } + err = cecoConn.Close(ctx) + if err != nil { + panic(err) + } +} diff --git a/cmd/externalmodule-test-server/server.go b/cmd/externalmodule-test-server/server.go new file mode 100644 index 000000000..b52637488 --- /dev/null +++ b/cmd/externalmodule-test-server/server.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "time" + + "github.com/filecoin-project/mir/pkg/externalmodule" + "github.com/filecoin-project/mir/stdevents" + "github.com/filecoin-project/mir/stdtypes" +) + +type EmptyModule struct { + prefix string +} + +func (e EmptyModule) ImplementsModule() {} + +func (e EmptyModule) ApplyEvents(events *stdtypes.EventList) (*stdtypes.EventList, error) { + fmt.Printf("%s: Received %d event(s).\n", e.prefix, events.Len()) + return stdtypes.ListOf(stdevents.NewTestString("anonymous-module", "Pong")), nil +} + +func main() { + s := externalmodule.NewServer( + externalmodule.NewHandler("bero", EmptyModule{"bero"}), + externalmodule.NewHandler("ceco", EmptyModule{"ceco"}), + ) + + time.AfterFunc(10*time.Second, func() { + err := s.Stop() + if err != nil { + fmt.Printf("Error stopping server: %v\n", err) + } + }) + + err := s.Serve("0.0.0.0:8080") + + if err != nil { + fmt.Println(err) + } else { + fmt.Println("Server stopped cleanly.") + } +} diff --git a/cmd/mircat/debug.go b/cmd/mircat/debug.go index 42073ebe9..0c5f7a509 100644 --- a/cmd/mircat/debug.go +++ b/cmd/mircat/debug.go @@ -111,7 +111,7 @@ func debug(args *arguments) error { for _, event := range entry.Events { // Set the index of the event in the event log. - metadata.index = uint64(index) + metadata.index = uint64(index) //nolint:gosec // If the event was selected by the user for inspection, pause before submitting it to the node. // The processing continues after the user's interactive confirmation. @@ -200,9 +200,6 @@ func debuggerNode(id stdtypes.NodeID, membership *trantorpbtypes.Membership) (*m "iss": protocol, "timer": timer.New(), } - if err != nil { - panic(fmt.Errorf("error initializing the Mir modules: %w", err)) - } node, err := mir.NewNode(id, mir.DefaultNodeConfig().WithLogger(logger), nodeModules, nil) if err != nil { diff --git a/cmd/mircat/display.go b/cmd/mircat/display.go index eab1de0dd..45fa494b1 100644 --- a/cmd/mircat/display.go +++ b/cmd/mircat/display.go @@ -69,7 +69,7 @@ func displayEvents(args *arguments) error { //nolint:gocognit } // getting events from entry for _, event := range entry.Events { - metadata.index = uint64(index) + metadata.index = uint64(index) //nolint:gosec _, validEvent := args.selectedEventNames[eventName(event)] _, validDest := args.selectedEventDests[event.DestModule] diff --git a/go.mod b/go.mod index 25b988e08..9aad394d0 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require github.com/golang/mock v1.6.0 require ( filippo.io/keygen v0.0.0-20230306160926-5201437acf8e + github.com/coder/websocket v1.8.12 github.com/dave/jennifer v1.5.1 github.com/drand/kyber v1.2.0 github.com/drand/kyber-bls12381 v0.3.0 diff --git a/go.sum b/go.sum index 10ba790bc..d6040dd52 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAu github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= github.com/cenkalti/backoff/v4 v4.0.0 h1:6VeaLF9aI+MAUQ95106HwWzYZgJJpZ4stumjj6RFYAU= github.com/cenkalti/backoff/v4 v4.0.0/go.mod h1:eEew/i+1Q6OrCDZh3WiXYv3+nJwBASZ8Bog/87DQnVg= +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/dave/jennifer v1.5.1 h1:AI8gaM02nCYRw6/WTH0W+S6UNck9YqPZ05xoIxQtuoE= github.com/dave/jennifer v1.5.1/go.mod h1:AxTG893FiZKqxy3FP1kL80VMshSMuz2G+EgvszgGRnk= diff --git a/node_test.go b/node_test.go index e3a0850d7..2603ea950 100644 --- a/node_test.go +++ b/node_test.go @@ -3,6 +3,7 @@ package mir import ( "context" "fmt" + "math" "sync" "sync/atomic" "testing" @@ -140,7 +141,7 @@ func TestNode_Backpressure(t *testing.T) { nodeConfig.Stats.Period = 100 * time.Millisecond // Set an input event rate that would fill the node's event buffers in one second in 10 batches. - blabberModule := newBlabber(uint64(nodeConfig.PauseInputThreshold/10), 100*time.Millisecond) + blabberModule := newBlabber(uint64(nodeConfig.PauseInputThreshold/10), 100*time.Millisecond) //nolint:gosec // Set the event consumption rate to 1/2 of the input rate (i.e., draining the buffer in 2 seconds) // and create the consumer module. @@ -181,8 +182,8 @@ func TestNode_Backpressure(t *testing.T) { fmt.Printf("Total submitted events: %d\n", atomic.LoadUint64(&blabberModule.totalSubmitted)) totalSubmitted := atomic.LoadUint64(&blabberModule.totalSubmitted) expectSubmitted := atomic.LoadUint64(&consumerModule.numProcessed) + - uint64(nodeConfig.PauseInputThreshold) + // Events left in the buffer - uint64(nodeConfig.MaxEventBatchSize) + // Events in the consumer's processing queue + uint64(nodeConfig.PauseInputThreshold) + //nolint:gosec // Events left in the buffer + uint64(nodeConfig.MaxEventBatchSize) + //nolint:gosec // Events in the consumer's processing queue 2*blabberModule.batchSize // one batch of overshooting, one batch waiting in the babbler's output channel. assert.LessOrEqual(t, totalSubmitted, expectSubmitted, "too many events submitted (node event buffer overflow)") } @@ -223,9 +224,12 @@ func (b *blabber) Go() { return default: } + if b.batchSize > math.MaxInt { + panic("batch size too big for int") + } evts := stdtypes.ListOf(sliceutil.Repeat( stdtypes.Event(stdevents.NewTestUint64("consumer", 0)), - int(b.batchSize), + int(b.batchSize), //nolint:gosec )...) select { case <-b.stop: diff --git a/pkg/availability/multisigcollector/multisigcollector.go b/pkg/availability/multisigcollector/multisigcollector.go index aa657b933..cecda6f14 100644 --- a/pkg/availability/multisigcollector/multisigcollector.go +++ b/pkg/availability/multisigcollector/multisigcollector.go @@ -1,6 +1,7 @@ package multisigcollector import ( + "fmt" "math" "google.golang.org/protobuf/proto" @@ -69,12 +70,17 @@ func NewReconfigurableModule(mc ModuleConfig, paramsTemplate ModuleParams, logge submc := mc submc.Self = mscID + // Check for integer overflow + if mscParams.MaxRequests > math.MaxInt { + return nil, fmt.Errorf("max requests too high for int type: %d", mscParams.MaxRequests) + } + // Fill in instance-specific parameters. moduleParams := paramsTemplate moduleParams.InstanceUID = []byte(mscID) moduleParams.EpochNr = mscParams.Epoch moduleParams.Membership = mscParams.Membership - moduleParams.MaxRequests = int(mscParams.MaxRequests) + moduleParams.MaxRequests = int(mscParams.MaxRequests) //nolint:gosec // TODO: Use InstanceUIDs properly. // (E.g., concatenate this with the instantiating protocol's InstanceUID when introduced.) diff --git a/pkg/checkpoint/chkpvalidator/conservativecv.go b/pkg/checkpoint/chkpvalidator/conservativecv.go index 3519af215..938b2d926 100644 --- a/pkg/checkpoint/chkpvalidator/conservativecv.go +++ b/pkg/checkpoint/chkpvalidator/conservativecv.go @@ -1,15 +1,16 @@ package chkpvalidator import ( - es "github.com/go-errors/errors" + "math" - t "github.com/filecoin-project/mir/stdtypes" + es "github.com/go-errors/errors" "github.com/filecoin-project/mir/pkg/checkpoint" "github.com/filecoin-project/mir/pkg/crypto" checkpointpbtypes "github.com/filecoin-project/mir/pkg/pb/checkpointpb/types" trantorpbtypes "github.com/filecoin-project/mir/pkg/pb/trantorpb/types" tt "github.com/filecoin-project/mir/pkg/trantor/types" + t "github.com/filecoin-project/mir/stdtypes" ) type ConservativeCV struct { @@ -55,14 +56,25 @@ func (ccv *ConservativeCV) Verify( return es.Errorf("nodeID not in membership") } + // Check if epoch is in integer bounds. + if sc.Epoch() > math.MaxInt || epochNr > math.MaxInt { + return es.Errorf("epoch number out of integer range") + } + // Check how far the received stable checkpoint is ahead of the local node's state. - chkpMembershipOffset := int(sc.Epoch()) - 1 - int(epochNr) + // Integer casting required here to prevent underflow. + chkpMembershipOffset := int(sc.Epoch()) - 1 - int(epochNr) //nolint:gosec if chkpMembershipOffset <= 0 { // Ignore stable checkpoints that are not far enough // ahead of the current state of the local node. return es.Errorf("checkpoint not far ahead enough") } + // Make sure ccv.configOffset is non-negative before conversion + if ccv.configOffset < 0 { + return es.Errorf("configOffset cannot be negative") + } + if chkpMembershipOffset > ccv.configOffset { // cannot verify checkpoint signatures, too far ahead return es.Errorf("checkpoint too far ahead") diff --git a/pkg/checkpoint/chkpvalidator/permissivecv.go b/pkg/checkpoint/chkpvalidator/permissivecv.go index d7527cbe1..e15d29896 100644 --- a/pkg/checkpoint/chkpvalidator/permissivecv.go +++ b/pkg/checkpoint/chkpvalidator/permissivecv.go @@ -1,17 +1,17 @@ package chkpvalidator import ( - es "github.com/go-errors/errors" - - t "github.com/filecoin-project/mir/stdtypes" + "math" - "github.com/filecoin-project/mir/pkg/logging" + es "github.com/go-errors/errors" "github.com/filecoin-project/mir/pkg/checkpoint" "github.com/filecoin-project/mir/pkg/crypto" + "github.com/filecoin-project/mir/pkg/logging" checkpointpbtypes "github.com/filecoin-project/mir/pkg/pb/checkpointpb/types" trantorpbtypes "github.com/filecoin-project/mir/pkg/pb/trantorpb/types" tt "github.com/filecoin-project/mir/pkg/trantor/types" + t "github.com/filecoin-project/mir/stdtypes" ) type PermissiveCV struct { @@ -50,13 +50,19 @@ func (pcv *PermissiveCV) Verify(chkp *checkpointpbtypes.StableCheckpoint, epochN return es.Errorf("nodeID not in membership") } + // Check if epoch is in integer bounds. + if sc.Epoch() > math.MaxInt || epochNr > math.MaxInt { + return es.Errorf("epoch number out of integer range") + } + // ATTENTION: We are using the membership contained in the checkpoint itself // as the one to verify its certificate against. // This is a vulnerability, since any the state of any node can be corrupted // simply by receiving a maliciously crafted checkpoint. // Thus, the permissive checker is a form of a stub and should not be used in production. chkpMembership := sc.PreviousMembership() - chkpMembershipOffset := int(sc.Epoch()) - 1 - int(epochNr) + // Integer casting required here to prevent underflow. + chkpMembershipOffset := int(sc.Epoch()) - 1 - int(epochNr) //nolint:gosec if chkpMembershipOffset > pcv.configOffset { // cannot verify checkpoint signatures, too far ahead diff --git a/pkg/deploytest/deployment.go b/pkg/deploytest/deployment.go index d61ee662e..33c4bf03b 100644 --- a/pkg/deploytest/deployment.go +++ b/pkg/deploytest/deployment.go @@ -179,8 +179,8 @@ func (d *Deployment) Run(ctx context.Context) (nodeErrors []error, heapObjects i <-ctx.Done() runtime.GC() runtime.ReadMemStats(&m2) - heapObjects = int64(m2.HeapObjects - m1.HeapObjects) - heapAlloc = int64(m2.HeapAlloc - m1.HeapAlloc) + heapObjects = int64(m2.HeapObjects - m1.HeapObjects) //nolint:gosec + heapAlloc = int64(m2.HeapAlloc - m1.HeapAlloc) //nolint:gosec cancel() }() diff --git a/pkg/deploytest/testreplica.go b/pkg/deploytest/testreplica.go index d037fc607..64d600de4 100644 --- a/pkg/deploytest/testreplica.go +++ b/pkg/deploytest/testreplica.go @@ -183,7 +183,7 @@ func (tr *TestReplica) submitFakeTransactions(ctx context.Context, node *mir.Nod destModule, []*trantorpbtypes.Transaction{{ ClientId: tt.NewClientIDFromInt(0), - TxNo: tt.TxNo(i), + TxNo: tt.TxNo(i), //nolint:gosec Data: []byte(fmt.Sprintf("Transaction %d", i)), }}, ).Pb()) diff --git a/pkg/dsl/dslmodule.go b/pkg/dsl/dslmodule.go index 9cfe75263..91d1d6b31 100644 --- a/pkg/dsl/dslmodule.go +++ b/pkg/dsl/dslmodule.go @@ -162,16 +162,19 @@ func (h Handle) RecoverAndCleanupContext(id ContextID) any { // The ImplementsModule method only serves the purpose of indicating that this is a Module and must not be called. func (m *dslModuleImpl) ImplementsModule() {} -// EmitEvent adds the event to the queue of output events -// NB: This function works with the (legacy) protoc-generated types and is likely to be -// removed in the future, with EmitMirEvent taking its place. +// EmitEvent adds the event to the queue of output events. func EmitEvent(m Module, ev stdtypes.Event) { m.DslHandle().impl.outputEvents.PushBack(ev) } -// EmitMirEvent adds the event to the queue of output events +// EmitEvents adds the events to the queue of output events. +func EmitEvents(m Module, events *stdtypes.EventList) { + m.DslHandle().impl.outputEvents.PushBackList(events) +} + +// EmitMirEvent adds the Mir-generated event to the queue of output events. // NB: this function works with the Mir-generated types. -// For the (legacy) protoc-generated types, EmitEvent can be used. +// For use with the general event type, see EmitEvent. func EmitMirEvent(m Module, ev *eventpbtypes.Event) { m.DslHandle().impl.outputEvents.PushBack(ev.Pb()) } diff --git a/pkg/dsl/test/dslmodule_test.go b/pkg/dsl/test/dslmodule_test.go index dc0da97d0..5415270a5 100644 --- a/pkg/dsl/test/dslmodule_test.go +++ b/pkg/dsl/test/dslmodule_test.go @@ -278,7 +278,7 @@ func newContextTestingModule(mc *contextTestingModuleModuleConfig) dsl.Module { // NB: avoid using primitive types as the context in the actual implementation, prefer named structs, // remember that the context type is used to match requests with responses. - cryptopbdsl.VerifySigs(m, mc.Crypto, sliceutil.Repeat(msg, int(u)), signatures, nodeIDs, &u) + cryptopbdsl.VerifySigs(m, mc.Crypto, sliceutil.Repeat(msg, int(u)), signatures, nodeIDs, &u) //nolint:gosec } return nil }) diff --git a/pkg/externalmodule/connection.go b/pkg/externalmodule/connection.go new file mode 100644 index 000000000..f956d0b94 --- /dev/null +++ b/pkg/externalmodule/connection.go @@ -0,0 +1,156 @@ +package externalmodule + +import ( + "context" + "fmt" + "sync" + + "github.com/coder/websocket" + + "github.com/filecoin-project/mir/stdevents" + "github.com/filecoin-project/mir/stdtypes" +) + +// Connection represents a connection to a particular module at a particular module server. +// It is used to send events to and receive events from it. +type Connection websocket.Conn + +// Connect establishes and returns a new connection +// to a module server at address addr (in the form of "ws://server:port/path"). +// The path component of the address is used to specify which module at the module server to connect to. +// When ctx is canceled before the connection is established, connecting aborts. +func Connect(ctx context.Context, addr string) (*Connection, error) { + + conn, _, err := websocket.Dial(ctx, addr, nil) + if err != nil { + return nil, err + } + + return (*Connection)(conn), nil +} + +// Submit sends the given events to the remote module, waits until the remote module processes them, and returns +// the resulting events produced by the remote module. +// One can see it as the proxy for the remote module's ApplyEvents method. +func (c *Connection) Submit(ctx context.Context, events *stdtypes.EventList) (*stdtypes.EventList, error) { + conn := (*websocket.Conn)(c) + ctx, cancel := context.WithCancel(ctx) + wg := sync.WaitGroup{} + wg.Add(1) + var sendErr error + + // We need to run sendEvents concurrently with receiveResponse to avoid a deadlock. + // If we first tried to send all the events and only then started receiving the response, the sending could be + // blocked by the server side blocked by the processing blocked by the sending of response events blocked by the + // client not having started receiving them. + go func() { + sendErr = sendEvents(ctx, conn, events) + if sendErr != nil { + cancel() // If sending fails, receiving of the response also must be aborted. + } + wg.Done() + }() + + response, err := receiveResponse(ctx, conn) + + // When reaching this line, sending events must have finished, as receiveResponse would otherwise not have returned. + // Waiting on the wait group is not necessary in this sense. + // Nevertheless, we still need to synchronize access to sendErr + // (and it's good practice to collect spawned goroutines before returning). + wg.Wait() + if sendErr != nil { + return nil, sendErr + } + + return response, err +} + +// Close closes the connection to the remote module. +func (c *Connection) Close(ctx context.Context) error { + conn := (*websocket.Conn)(c) + defer func() { _ = conn.CloseNow() }() + + err := conn.Write(ctx, websocket.MessageBinary, controlMessageClose().Bytes()) + if err != nil { + return err + } + + return conn.Close(websocket.StatusNormalClosure, "") +} + +// sendEvents writes a list of events to the raw websocket connection. +// All sent events are serialized and wrapped in a stdevents.Raw event. +// Thus, on the other side of the connection, only control messages and events of type stdevents.Raw can be expected. +func sendEvents(ctx context.Context, conn *websocket.Conn, events *stdtypes.EventList) error { + // Announce the number of events that will be sent. + err := conn.Write(ctx, websocket.MessageBinary, controlMessageEvents(events.Len()).Bytes()) + if err != nil { + return err + } + + // Send all the events, using one websocket message per event. + iter := events.Iterator() + for event := iter.Next(); event != nil; event = iter.Next() { + + rawEvent, err := stdevents.WrapInRaw(event) + if err != nil { + return err + } + + data, err := rawEvent.ToBytes() + if err != nil { + return err + } + + err = conn.Write(ctx, websocket.MessageBinary, data) + if err != nil { + return err + } + } + + return nil +} + +// receiveResponse reads the events the module server sends over the websocket and returns them in an EventList. +func receiveResponse(ctx context.Context, conn *websocket.Conn) (*stdtypes.EventList, error) { + // Read the number of resulting events returned from the remote module. + msgType, msgData, err := conn.Read(ctx) + if err != nil { + return nil, fmt.Errorf("could not read response data: %w", err) + } + + if msgType != websocket.MessageBinary { + return nil, fmt.Errorf("only binary message type is accepted for control messages") + } + command, err := controlMessageFromBytes(msgData) + if err != nil { + return nil, fmt.Errorf("could not load control message: %w", err) + } + if command.MsgType != MsgEvents { + return nil, fmt.Errorf("expected MSG_EVENTS control message type but got %v", command.MsgType) + } + + // Receive the resulting events. + resultEvents := stdtypes.EmptyList() + for i := 0; i < command.NumEvents; i++ { + + msgType, msgData, err := conn.Read(ctx) + if err != nil { + return nil, fmt.Errorf("could not read response data: %w", err) + } + if msgType != websocket.MessageBinary { + return nil, fmt.Errorf("only binary message type is accepted for events") + } + + // We can afford using stdevents.Deserialize because sendEvents (used on the other side of the websocket) + // only ever sends serialized events of type stdevent.RawEvent + event, err := stdevents.Deserialize(msgData) + if err != nil { + return nil, fmt.Errorf("could not deserialize event: %w", err) + } + + resultEvents.PushBack(event) + } + + return resultEvents, nil +} diff --git a/pkg/externalmodule/controlmessage.go b/pkg/externalmodule/controlmessage.go new file mode 100644 index 000000000..8eb64b40d --- /dev/null +++ b/pkg/externalmodule/controlmessage.go @@ -0,0 +1,40 @@ +package externalmodule + +import "encoding/json" + +type controlMessageType int + +const ( + MsgEvents = iota + MsgClose +) + +type ControlMessage struct { + MsgType controlMessageType + NumEvents int // Only used for EVENT_LIST type. +} + +func (cm *ControlMessage) Bytes() []byte { + data, err := json.Marshal(cm) + if err != nil { + panic(err) + } + return data +} + +func controlMessageFromBytes(data []byte) (*ControlMessage, error) { + var msg ControlMessage + err := json.Unmarshal(data, &msg) + if err != nil { + return nil, err + } + return &msg, nil +} + +func controlMessageEvents(numEvents int) *ControlMessage { + return &ControlMessage{MsgEvents, numEvents} +} + +func controlMessageClose() *ControlMessage { + return &ControlMessage{MsgClose, 0} +} diff --git a/pkg/externalmodule/modulehandler.go b/pkg/externalmodule/modulehandler.go new file mode 100644 index 000000000..d6e821bc9 --- /dev/null +++ b/pkg/externalmodule/modulehandler.go @@ -0,0 +1,147 @@ +package externalmodule + +import ( + "context" + "fmt" + "net/http" + "sync/atomic" + + "github.com/coder/websocket" + + "github.com/filecoin-project/mir/pkg/modules" + "github.com/filecoin-project/mir/stdevents" + "github.com/filecoin-project/mir/stdtypes" +) + +const ( + ConnActive = iota + ConnPending +) + +// ModuleHandler implements a handler function for an incoming connection at the module server for a PassiveModule. +type ModuleHandler struct { + + // URL path (after the domain) under which this module will be accessible + path string + + // The module logic + module modules.PassiveModule + + // Used to make sure there is only a single client talking to the handler. + // This is needed to prevent concurrent access to the module. + // The int32 type is somewhat arbitrary - it only needs to be supported by the CompareAndSwap + // family of functions in the atomic package. + connectionStatus int32 +} + +// NewHandler allocates and returns a pointer to a new ModuleHandler. +func NewHandler(path string, module modules.PassiveModule) *ModuleHandler { + return &ModuleHandler{ + path: path, + module: module, + connectionStatus: ConnPending, + } +} + +// handleConnection is the function that will be invoked by the HTTP server this handler is part of +// each time a connection to this handler's path is created. +// It reads websocket messages from the connection, passes them to the module logic, +// and writes back the generated events. +func (mh *ModuleHandler) handleConnection(writer http.ResponseWriter, request *http.Request) { + + // Only accept the first connection. + if !atomic.CompareAndSwapInt32(&mh.connectionStatus, ConnPending, ConnActive) { + writer.WriteHeader(http.StatusForbidden) + return + } + + // Only accept a websocket connection. + conn, err := websocket.Accept(writer, request, nil) + if err != nil { + writer.WriteHeader(http.StatusBadRequest) + return + } + + //TODO: Figure out a better way to deal with the context. + ctx := context.Background() + + // Main loop for reading incoming websocket messages. + var msgType websocket.MessageType + var msgData []byte + for msgType, msgData, err = conn.Read(ctx); err == nil; msgType, msgData, err = conn.Read(ctx) { + + // Only accept binary type messages. + if msgType != websocket.MessageBinary { + err = fmt.Errorf("only binary message type is accepted for control messages") + break + } + + // The first message must always be a control message, followed by a variable number of event messages. + command, loadingErr := controlMessageFromBytes(msgData) + if loadingErr != nil { + err = fmt.Errorf("could not load control message: %w", loadingErr) + break + } + + if command.MsgType == MsgEvents { + // If the control message announces the number of events that follow, + // process them all (mh.processEvents reads them from conn). + err = mh.processEvents(ctx, conn, command.NumEvents) + if err != nil { + break + } + } else if command.MsgType == MsgClose { + // If we received a closing message, stop processing. + break + } else { + err = fmt.Errorf("unknown control msg type: %v", command.MsgType) + break + } + + } + if err != nil { + fmt.Printf("Error processing incoming websocket message: %v\n", err) + } + + err = conn.Close(websocket.StatusNormalClosure, "") + if err != nil { + _ = conn.CloseNow() + } +} + +// processEvents reads messages from the connection, passes them to the module logic, +// and sends back the generated events. +func (mh *ModuleHandler) processEvents(ctx context.Context, conn *websocket.Conn, numEvents int) error { + resultEvents := stdtypes.EmptyList() + for ; numEvents > 0; numEvents-- { + newEvents, err := mh.processNextEvent(ctx, conn) + if err != nil { + return err + } + resultEvents.PushBackList(newEvents) + } + + return sendEvents(ctx, conn, resultEvents) +} + +// processNextEvent reads a single message from the given websocket connection, applies it to the module logic, +// and returns the resulting events. +func (mh *ModuleHandler) processNextEvent(ctx context.Context, conn *websocket.Conn) (*stdtypes.EventList, error) { + msgType, msgData, err := conn.Read(ctx) + if err != nil { + return nil, fmt.Errorf("could not read control message: %w", err) + } + + if msgType != websocket.MessageBinary { + return nil, fmt.Errorf("only binary message type is accepted for events") + } + + // We can afford using stdevents.Deserialize because sendEvents (used on the other side of the websocket) + // only ever sends serialized events of type stdevent.RawEvent + event, err := stdevents.Deserialize(msgData) + if err != nil { + return nil, fmt.Errorf("could not deserialize event: %w", err) + } + + return mh.module.ApplyEvents(stdtypes.ListOf(event)) +} diff --git a/pkg/externalmodule/proxymodule.go b/pkg/externalmodule/proxymodule.go new file mode 100644 index 000000000..bb22478e7 --- /dev/null +++ b/pkg/externalmodule/proxymodule.go @@ -0,0 +1,57 @@ +package externalmodule + +import ( + "context" + + "github.com/filecoin-project/mir/pkg/dsl" + "github.com/filecoin-project/mir/pkg/modules" + "github.com/filecoin-project/mir/stdevents" + "github.com/filecoin-project/mir/stdtypes" +) + +// NewProxyModule returns a new module that serves as a local proxy to an external module hosted on a module server. +// The addr parameter specifies the full URL (address and path) of the module at the server. +// The connection between the proxy and the module server is established when the module receives stdevents.Init, +// At which time the server must be running and accepting new connections. +func NewProxyModule(moduleID stdtypes.ModuleID, addr string) modules.PassiveModule { + m := dsl.NewModule(moduleID) + var connection *Connection + ctx := context.Background() + // TODO: Using a local context here might make the whole Mir node stuck if the connection gets stuck. + // There is no way to stop the module's operation from the outside - it can only stop by itself. + // This is a more general problem of passive modules. There is no way to force them to stop processing when the + // Mir node is shutting down. For most of the passive modules it is not an issue though, as they only locally + // process events and are guaranteed to eventually finish. For the proxy module, this is only the case if it can + // communicate with its corresponding server. + + // Upon Init, connect to the remote module and relay the Init event to it. + dsl.UponEvent(m, func(ev *stdevents.Init) error { + var err error + + // Create connection to module server. + connection, err = Connect(ctx, addr) + if err != nil { + return err + } + + // Relay Init event to remote module. + eventsOut, err := connection.Submit(ctx, stdtypes.ListOf(ev)) + if err != nil { + return err + } + dsl.EmitEvents(m, eventsOut) + return nil + }) + + // Simply relay all events (except for Init, which is handled separately) to the remote module. + dsl.UponOtherEvent(m, func(ev stdtypes.Event) error { + eventsOut, err := connection.Submit(ctx, stdtypes.ListOf(ev)) + if err != nil { + return err + } + dsl.EmitEvents(m, eventsOut) + return nil + }) + + return m +} diff --git a/pkg/externalmodule/server.go b/pkg/externalmodule/server.go new file mode 100644 index 000000000..15632f82e --- /dev/null +++ b/pkg/externalmodule/server.go @@ -0,0 +1,52 @@ +package externalmodule + +import ( + "net" + "net/http" + "time" +) + +// Server implements a HTTP server containing remote modules. +type Server struct { + serveMux http.ServeMux + httpServer *http.Server +} + +// NewServer returns a new Server containing the given moduleHandlers which define the modules operated by the server. +// Each handler associates a module with a path under which the module will be accessible. +func NewServer(moduleHandlers ...*ModuleHandler) *Server { + ms := Server{} + + for _, mh := range moduleHandlers { + ms.serveMux.HandleFunc("/"+mh.path, mh.handleConnection) + } + + return &ms +} + +// Serve starts the module server, making it listen to new connections +// at the given address and port (e.g., "0.0.0.0:8080"). +// Serve blocks until the Stop method is called. It is therefore expected to call Serve in a separate goroutine. +func (ms *Server) Serve(addrPort string) error { + l, err := net.Listen("tcp", addrPort) + if err != nil { + return err + } + + ms.httpServer = &http.Server{ + ReadHeaderTimeout: 2 * time.Second, + Handler: ms, + } + + return ms.httpServer.Serve(l) +} + +// Stop stops the server, making the call to the Serve method return. +func (ms *Server) Stop() error { + return ms.httpServer.Close() +} + +// ServeHTTP is a wrapper around the HTTP serveMux's method of the same name, so it can be used as a http.ServeMux. +func (ms *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ms.serveMux.ServeHTTP(w, r) +} diff --git a/pkg/iss/iss.go b/pkg/iss/iss.go index c933043f8..069b59c76 100644 --- a/pkg/iss/iss.go +++ b/pkg/iss/iss.go @@ -14,6 +14,7 @@ package iss import ( "encoding/binary" "fmt" + "math" es "github.com/go-errors/errors" "google.golang.org/protobuf/proto" @@ -342,7 +343,7 @@ func New( // Choose a leader for the new orderer instance. // TODO: Use the corresponding epoch's leader set to pick a leader, instead of just selecting one from all nodes. - leader := maputil.GetSortedKeys(membership.Nodes)[int(epoch)%len(membership.Nodes)] + leader := maputil.GetSortedKeys(membership.Nodes)[int(epoch)%len(membership.Nodes)] //nolint:gosec // Serialize checkpoint, so it can be proposed as a value. stableCheckpoint := checkpointpbtypes.StableCheckpoint{ @@ -408,7 +409,7 @@ func New( // that are not yet part of the system for those checkpoints. var delayed []stdtypes.NodeID for n := range membership.Nodes { - if epoch > iss.nodeEpochMap[n]+tt.EpochNr(iss.Params.RetainedEpochs) { + if epoch > iss.nodeEpochMap[n]+tt.EpochNr(iss.Params.RetainedEpochs) { //nolint:gosec delayed = append(delayed, n) } } @@ -438,7 +439,11 @@ func New( sc := checkpoint.StableCheckpointFromPb(chkp.Pb()) // Check how far the received stable checkpoint is ahead of the local node's state. - chkpMembershipOffset := int(sc.Epoch()) - 1 - int(iss.epoch.Nr()) + if sc.Epoch() > math.MaxInt || iss.epoch.Nr() > math.MaxInt { + return es.Errorf("epoch number out of integer range") + } + // Integer casting required here to prevent underflow. + chkpMembershipOffset := int(sc.Epoch()) - 1 - int(iss.epoch.Nr()) //nolint:gosec if chkpMembershipOffset <= 0 { // Ignore stable checkpoints that are not far enough // ahead of the current state of the local node. @@ -465,7 +470,11 @@ func New( } chkp := checkpoint.StableCheckpointFromPb(c.checkpoint.Pb()) - chkpMembershipOffset := int(chkp.Epoch()) - 1 - int(iss.epoch.Nr()) + if chkp.Epoch() > math.MaxInt || iss.epoch.Nr() > math.MaxInt { + return es.Errorf("epoch number out of integer range") + } + // Integer casting required here to prevent underflow. + chkpMembershipOffset := int(chkp.Epoch()) - 1 - int(iss.epoch.Nr()) //nolint:gosec if chkpMembershipOffset <= 0 { // Ignore stable checkpoints that have been lagged behind // during validation @@ -564,7 +573,7 @@ func InitialStateSnapshot( return nil, err } - firstEpochLength := uint64(params.SegmentLength * len(params.InitialMembership.Nodes)) + firstEpochLength := uint64(params.SegmentLength * len(params.InitialMembership.Nodes)) //nolint:gosec return &trantorpbtypes.StateSnapshot{ AppData: appState, EpochData: &trantorpbtypes.EpochData{ @@ -624,7 +633,7 @@ func (iss *ISS) initAvailability() { (*multisigcollector.InstanceParams)(&mscpbtypes.InstanceParams{ Epoch: iss.epoch.Nr(), Membership: iss.memberships[0], - MaxRequests: uint64(iss.Params.SegmentLength), + MaxRequests: uint64(iss.Params.SegmentLength), //nolint:gosec }), stdtypes.RetentionIndex(iss.epoch.Nr()), ) @@ -640,12 +649,12 @@ func (iss *ISS) initOrderers() error { // Create segment. // The sequence proposals are all set to nil, so that the orderer proposes new availability certificates. - proposals := freeProposals(iss.nextDeliveredSN+tt.SeqNr(i), tt.SeqNr(len(leaders)), iss.Params.SegmentLength) + proposals := freeProposals(iss.nextDeliveredSN+tt.SeqNr(i), tt.SeqNr(len(leaders)), iss.Params.SegmentLength) //nolint:gosec seg, err := common.NewSegment(leader, iss.epoch.Membership, proposals) if err != nil { return es.Errorf("error creating new segment: %w", err) } - iss.newEpochSN += tt.SeqNr(seg.Len()) + iss.newEpochSN += tt.SeqNr(seg.Len()) //nolint:gosec // Instantiate a new PBFT orderer. stddsl.NewSubmodule(iss.m, iss.moduleConfig.Ordering, @@ -792,7 +801,7 @@ func (iss *ISS) advanceEpoch() error { EpochConfig: &trantorpbtypes.EpochConfig{ // nolint:govet iss.epoch.Nr(), iss.epoch.FirstSN(), - uint64(iss.epoch.Len()), + uint64(iss.epoch.Len()), //nolint:gosec iss.memberships, }, }, @@ -904,8 +913,9 @@ func (iss *ISS) deliverCommonCheckpoint(chkpData []byte) error { // The state to prune is determined according to the retention index // which is derived from the epoch number the new // stable checkpoint is associated with. - pruneIndex := int(chkp.Epoch()) - iss.Params.RetainedEpochs - if pruneIndex > 0 { // "> 0" and not ">= 0", since only entries strictly smaller than the index are pruned. + // Integer casting required here to prevent underflow. + pruneIndex := int(chkp.Epoch()) - iss.Params.RetainedEpochs //nolint:gosec + if pruneIndex > 0 { // "> 0" and not ">= 0", since only entries strictly smaller than the index are pruned. // Prune timer, checkpointing, availability, orderers, and other modules. stddsl.GarbageCollect(iss.m, iss.moduleConfig.Timer, stdtypes.RetentionIndex(pruneIndex)) @@ -917,7 +927,7 @@ func (iss *ISS) deliverCommonCheckpoint(chkpData []byte) error { // Prune epoch state. for epoch := range iss.epochs { - if epoch < tt.EpochNr(pruneIndex) { + if epoch < tt.EpochNr(pruneIndex) { //nolint:gosec delete(iss.epochs, epoch) } } @@ -931,7 +941,7 @@ func (iss *ISS) deliverCommonCheckpoint(chkpData []byte) error { // Note that we are not using the current epoch number here, because it is not relevant for checkpoints. // Using pruneIndex makes sure that the re-transmission is stopped // on every stable checkpoint (when another one is started). - stdtypes.RetentionIndex(pruneIndex), + stdtypes.RetentionIndex(pruneIndex), //nolint:gosec isspbevents.PushCheckpoint(iss.moduleConfig.Self).Pb(), ) diff --git a/pkg/mempool/simplemempool/internal/parts/formbatchesint/formbatches.go b/pkg/mempool/simplemempool/internal/parts/formbatchesint/formbatches.go index 51dbb7c21..a7951d0b2 100644 --- a/pkg/mempool/simplemempool/internal/parts/formbatchesint/formbatches.go +++ b/pkg/mempool/simplemempool/internal/parts/formbatchesint/formbatches.go @@ -19,6 +19,9 @@ package formbatchesint import ( + "fmt" + "math" + "github.com/filecoin-project/mir/pkg/clientprogress" "github.com/filecoin-project/mir/pkg/dsl" "github.com/filecoin-project/mir/pkg/logging" @@ -156,7 +159,7 @@ func IncludeBatchCreation( // nolint:gocognit cutBatch(origin) } else { reqID := storePendingRequest(origin) - stddsl.TimerDelay(m, mc.Timer, params.BatchTimeout, mppbevents.BatchTimeout(mc.Self, uint64(reqID)).Pb()) + stddsl.TimerDelay(m, mc.Timer, params.BatchTimeout, mppbevents.BatchTimeout(mc.Self, uint64(reqID)).Pb()) //nolint:gosec } } @@ -266,6 +269,9 @@ func IncludeBatchCreation( // nolint:gocognit mppbdsl.UponBatchTimeout(m, func(batchReqID uint64) error { + if batchReqID > math.MaxInt { + return fmt.Errorf("batch request ID too big for an integer: %d", batchReqID) + } reqID := int(batchReqID) // Load the request origin. diff --git a/pkg/net/grpc/remoteconnection.go b/pkg/net/grpc/remoteconnection.go index 9c80869a9..c9ca66734 100644 --- a/pkg/net/grpc/remoteconnection.go +++ b/pkg/net/grpc/remoteconnection.go @@ -76,7 +76,7 @@ func (conn *remoteConnection) Send(msg *GrpcMessage) error { case conn.msgBuffer <- msg: return nil default: - return es.Errorf("send buffer full (" + conn.addr + ")") + return es.Errorf("send buffer full (%s)", conn.addr) } } diff --git a/pkg/orderers/common/segment.go b/pkg/orderers/common/segment.go index 2382a4f1a..b6d7135e5 100644 --- a/pkg/orderers/common/segment.go +++ b/pkg/orderers/common/segment.go @@ -1,6 +1,8 @@ package common import ( + "math" + es "github.com/go-errors/errors" "github.com/filecoin-project/mir/pkg/orderers/types" @@ -44,7 +46,10 @@ func (seg *Segment) NodeIDs() []t.NodeID { } func (seg *Segment) PrimaryNode(view types.ViewNr) t.NodeID { - return seg.NodeIDs()[(seg.LeaderIndex()+int(view))%len(seg.NodeIDs())] + if view > math.MaxInt { + panic("view number out of integer range") + } + return seg.NodeIDs()[(seg.LeaderIndex()+int(view))%len(seg.NodeIDs())] //nolint:gosec } func (seg *Segment) LeaderIndex() int { diff --git a/pkg/orderers/internal/common/common.go b/pkg/orderers/internal/common/common.go index ea28f26c2..f0e773eb6 100644 --- a/pkg/orderers/internal/common/common.go +++ b/pkg/orderers/internal/common/common.go @@ -97,8 +97,8 @@ type PbftProposalState struct { // ============================================================ // NumCommitted returns the number of slots that are already committed in the given view. -func (state *State) NumCommitted(view ot.ViewNr) int { - numCommitted := 0 +func (state *State) NumCommitted(view ot.ViewNr) uint64 { + numCommitted := uint64(0) for _, slot := range state.Slots[view] { if slot.Committed { numCommitted++ @@ -151,7 +151,7 @@ func (state *State) InitView( pbftpbevents.ViewChangeSNTimeout( moduleConfig.Self, view, - uint64(state.NumCommitted(view))).Pb(), + state.NumCommitted(view)).Pb(), //nolint:gosec ) stddsl.TimerDelay( m, @@ -168,7 +168,7 @@ func (state *State) InitView( // AllCommitted returns true if all slots of this pbftInstance in the current view are in the committed state // (i.e., have the committed flag set). func (state *State) AllCommitted() bool { - return state.NumCommitted(state.View) == len(state.Slots[state.View]) + return state.NumCommitted(state.View) == uint64(len(state.Slots[state.View])) } func (state *State) LookUpPreprepare(sn tt.SeqNr, digest []byte) *pbftpbtypes.Preprepare { diff --git a/pkg/orderers/internal/parts/goodcase/pbftgoodcase.go b/pkg/orderers/internal/parts/goodcase/pbftgoodcase.go index c69759546..16f4f2a11 100644 --- a/pkg/orderers/internal/parts/goodcase/pbftgoodcase.go +++ b/pkg/orderers/internal/parts/goodcase/pbftgoodcase.go @@ -2,6 +2,7 @@ package goodcase import ( "fmt" + "math" es "github.com/go-errors/errors" "google.golang.org/protobuf/proto" @@ -119,6 +120,9 @@ func IncludeGoodCase( }) pbftpbdsl.UponProposeTimeout(m, func(proposeTimeout uint64) error { + if proposeTimeout > math.MaxInt { + return es.Errorf("propose timeout too large (maximal allowed value: %d)", math.MaxInt) + } return applyProposeTimeout(m, state, params, moduleConfig, int(proposeTimeout), logger) }) @@ -255,7 +259,7 @@ func propose( // Set up a new timer for the next proposal. timeoutEvent := pbftpbevents.ProposeTimeout( moduleConfig.Self, - uint64(state.Proposal.ProposalsMade+1)) + uint64(state.Proposal.ProposalsMade+1)) //nolint:gosec stddsl.TimerDelay( m, @@ -548,7 +552,7 @@ func advanceSlotState( pbftpbevents.ViewChangeSNTimeout( moduleConfig.Self, state.View, - uint64(state.NumCommitted(state.View)), + state.NumCommitted(state.View), ).Pb(), ) } diff --git a/pkg/orderers/internal/parts/viewchange/pbftviewchange.go b/pkg/orderers/internal/parts/viewchange/pbftviewchange.go index 4567f2c6b..c610667ce 100644 --- a/pkg/orderers/internal/parts/viewchange/pbftviewchange.go +++ b/pkg/orderers/internal/parts/viewchange/pbftviewchange.go @@ -63,13 +63,17 @@ func IncludeViewChange( //nolint:gocognit return nil }) - cryptopbdsl.UponSigVerified(m, func(nodeID t.NodeID, error error, context *pbftpbtypes.SignedViewChange) error { + cryptopbdsl.UponSigVerified(m, func( + nodeID t.NodeID, + verificationError error, + context *pbftpbtypes.SignedViewChange, + ) error { // Ignore events with invalid signatures. - if error != nil { + if verificationError != nil { logger.Log(logging.LevelWarn, "Ignoring invalid signature, ignoring event.", "from", nodeID, - "error", error, + "error", verificationError, ) return nil } @@ -355,7 +359,7 @@ func applyViewChangeSNTimeout( // If the view is still the same as when the timer was set up, // if nothing has been committed since then, and if the segment-level checkpoint is not yet stable if view == state.View && - int(numCommitted) == state.NumCommitted(state.View) && + numCommitted == state.NumCommitted(state.View) && !state.SegmentCheckpoint.Stable(state.Segment.Membership) { // Start the view change sub-protocol. @@ -694,7 +698,7 @@ func latestPendingVCState(state *common.State) (*common.PbftViewChangeState, ot. // Find and return the view change state with the highest view number that received enough ViewChange messages. for v, s := range state.ViewChangeStates { - if s.EnoughViewChanges() && (state == nil || v > view) { + if s.EnoughViewChanges() && v > view { vcstate, view = s, v } } diff --git a/pkg/trantor/testing/smr_test.go b/pkg/trantor/testing/smr_test.go index 2819b4f39..75d72845d 100644 --- a/pkg/trantor/testing/smr_test.go +++ b/pkg/trantor/testing/smr_test.go @@ -198,7 +198,7 @@ func testIntegrationWithISS(tt *testing.T) { require.Error(tb, conf.ErrorExpected) for replica := range conf.NodeIDsWeight { app := deployment.TestConfig.FakeApps[replica] - require.Equal(tb, 0, int(app.TransactionsProcessed)) + require.Equal(tb, 0, int(app.TransactionsProcessed)) //nolint:gosec } }, }}, @@ -218,7 +218,7 @@ func testIntegrationWithISS(tt *testing.T) { require.Error(tb, conf.ErrorExpected) for replica := range conf.NodeIDsWeight { app := deployment.TestConfig.FakeApps[replica] - require.Equal(tb, conf.NumNetTXs+conf.NumFakeTXs, int(app.TransactionsProcessed)) + require.Equal(tb, conf.NumNetTXs+conf.NumFakeTXs, int(app.TransactionsProcessed)) //nolint:gosec } }, }}, @@ -379,7 +379,7 @@ func runIntegrationWithISSConfig(tb testing.TB, conf *TestConfig) (heapObjects i // Check if all transactions were delivered. for _, replica := range deployment.TestReplicas { app := deployment.TestConfig.FakeApps[replica.ID] - assert.Equal(tb, conf.NumNetTXs+conf.NumFakeTXs, int(app.TransactionsProcessed)) + assert.Equal(tb, conf.NumNetTXs+conf.NumFakeTXs, int(app.TransactionsProcessed)) //nolint:gosec } // If the test failed, keep the generated data. diff --git a/stdmodules/factory/factory_test.go b/stdmodules/factory/factory_test.go index b47b3de44..2fe172c46 100644 --- a/stdmodules/factory/factory_test.go +++ b/stdmodules/factory/factory_test.go @@ -102,7 +102,7 @@ func TestFactoryModule(t *testing.T) { }, "02 Instantiate many": func(t *testing.T) { - for i := 1; i <= 5; i++ { + for i := uint64(1); i <= 5; i++ { evOut, err := echoFactory.ApplyEvents(stdtypes.ListOf(stdevents.NewNewSubmodule( echoFactoryID, echoFactoryID.Then(stdtypes.ModuleID(fmt.Sprintf("inst%d", i))),