diff --git a/device/device.go b/device/device.go index 86dff0d7e..5b2348564 100644 --- a/device/device.go +++ b/device/device.go @@ -368,10 +368,10 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - device.ipcMutex.Lock() - defer device.ipcMutex.Unlock() device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() if device.isClosed() { return } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 2d8f98426..cb4dedb11 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -6,6 +6,7 @@ package device import ( + "encoding/binary" "errors" "fmt" "sync" @@ -115,6 +116,98 @@ type MessageCookieReply struct { Cookie [blake2s.Size128 + poly1305.TagSize]byte } +var errMessageLengthMismatch = errors.New("message length mismatch") + +func (msg *MessageInitiation) unmarshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Ephemeral[:], b[8:]) + copy(msg.Static[:], b[8+len(msg.Ephemeral):]) + copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) + copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):]) + copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageInitiation) marshal(b []byte) error { + if len(b) != MessageInitiationSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + copy(b[8:], msg.Ephemeral[:]) + copy(b[8+len(msg.Ephemeral):], msg.Static[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + +func (msg *MessageResponse) unmarshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Sender = binary.LittleEndian.Uint32(b[4:]) + msg.Receiver = binary.LittleEndian.Uint32(b[8:]) + copy(msg.Ephemeral[:], b[12:]) + copy(msg.Empty[:], b[12+len(msg.Ephemeral):]) + copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):]) + copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):]) + + return nil +} + +func (msg *MessageResponse) marshal(b []byte) error { + if len(b) != MessageResponseSize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Sender) + binary.LittleEndian.PutUint32(b[8:], msg.Receiver) + copy(b[12:], msg.Ephemeral[:]) + copy(b[12+len(msg.Ephemeral):], msg.Empty[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:]) + copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:]) + + return nil +} + +func (msg *MessageCookieReply) unmarshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + msg.Type = binary.LittleEndian.Uint32(b) + msg.Receiver = binary.LittleEndian.Uint32(b[4:]) + copy(msg.Nonce[:], b[8:]) + copy(msg.Cookie[:], b[8+len(msg.Nonce):]) + + return nil +} + +func (msg *MessageCookieReply) marshal(b []byte) error { + if len(b) != MessageCookieReplySize { + return errMessageLengthMismatch + } + + binary.LittleEndian.PutUint32(b, msg.Type) + binary.LittleEndian.PutUint32(b[4:], msg.Receiver) + copy(b[8:], msg.Nonce[:]) + copy(b[8+len(msg.Nonce):], msg.Cookie[:]) + + return nil +} + type Handshake struct { state handshakeState mutex sync.RWMutex diff --git a/device/receive.go b/device/receive.go index af2db44e8..bc37f915e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -6,7 +6,6 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -287,8 +286,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal packet var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) + err := reply.unmarshal(elem.packet) if err != nil { device.log.Verbosef("Failed to decode cookie reply") goto skip @@ -353,8 +351,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode initiation message") goto skip @@ -386,8 +383,7 @@ func (device *Device) RoutineHandshake(id int) { // unmarshal var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { device.log.Errorf("Failed to decode response message") goto skip @@ -447,6 +443,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { elemsContainer.Lock() validTailPacket := -1 dataPacketReceived := false + rxBytesLen := uint64(0) for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed @@ -463,7 +460,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) @@ -512,6 +509,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + + peer.rxBytes.Add(rxBytesLen) if validTailPacket >= 0 { peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) peer.keepKeyFreshReceiving() diff --git a/device/send.go b/device/send.go index 8ed2e5f6c..7900f577f 100644 --- a/device/send.go +++ b/device/send.go @@ -6,7 +6,6 @@ package device import ( - "bytes" "encoding/binary" "errors" "net" @@ -124,10 +123,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - var buf [MessageInitiationSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() + packet := make([]byte, MessageInitiationSize) + _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() @@ -155,10 +152,8 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - var buf [MessageResponseSize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() + packet := make([]byte, MessageResponseSize) + _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() @@ -189,11 +184,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - var buf [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buf[:0]) - binary.Write(writer, binary.LittleEndian, reply) + packet := make([]byte, MessageCookieReplySize) + _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) + return nil }