diff --git a/lib/dartssh2.dart b/lib/dartssh2.dart index 956142a..031459f 100644 --- a/lib/dartssh2.dart +++ b/lib/dartssh2.dart @@ -1,4 +1,5 @@ export 'src/ssh_algorithm.dart' show SSHAlgorithms; +export 'src/ssh_agent.dart'; export 'src/ssh_client.dart'; export 'src/ssh_errors.dart'; export 'src/ssh_forward.dart'; diff --git a/lib/src/message/msg_channel.dart b/lib/src/message/msg_channel.dart index cab4111..14733cb 100644 --- a/lib/src/message/msg_channel.dart +++ b/lib/src/message/msg_channel.dart @@ -541,6 +541,7 @@ abstract class SSHChannelRequestType { static const shell = 'shell'; static const exec = 'exec'; static const subsystem = 'subsystem'; + static const authAgent = 'auth-agent-req@openssh.com'; static const windowChange = 'window-change'; static const xon = 'xon-xoff'; static const signal = 'signal'; diff --git a/lib/src/ssh_agent.dart b/lib/src/ssh_agent.dart new file mode 100644 index 0000000..b89317f --- /dev/null +++ b/lib/src/ssh_agent.dart @@ -0,0 +1,234 @@ +import 'dart:async'; +import 'dart:typed_data'; + +import 'package:dartssh2/src/hostkey/hostkey_rsa.dart'; +import 'package:dartssh2/src/ssh_channel.dart'; +import 'package:dartssh2/src/ssh_hostkey.dart'; +import 'package:dartssh2/src/ssh_key_pair.dart'; +import 'package:dartssh2/src/ssh_message.dart'; +import 'package:dartssh2/src/ssh_transport.dart'; +import 'package:pointycastle/api.dart' hide Signature; +import 'package:pointycastle/asymmetric/api.dart' as asymmetric; +import 'package:pointycastle/digests/sha1.dart'; +import 'package:pointycastle/digests/sha256.dart'; +import 'package:pointycastle/digests/sha512.dart'; +import 'package:pointycastle/signers/rsa_signer.dart'; + +abstract class SSHAgentHandler { + Future handleRequest(Uint8List request); +} + +class SSHKeyPairAgent implements SSHAgentHandler { + SSHKeyPairAgent(this._identities, {this.comment}); + + final List _identities; + final String? comment; + + @override + Future handleRequest(Uint8List request) async { + if (request.isEmpty) { + return _failure(); + } + final reader = SSHMessageReader(request); + final messageType = reader.readUint8(); + switch (messageType) { + case SSHAgentProtocol.requestIdentities: + return _handleRequestIdentities(); + case SSHAgentProtocol.signRequest: + return _handleSignRequest(reader); + default: + return _failure(); + } + } + + Uint8List _handleRequestIdentities() { + final writer = SSHMessageWriter(); + writer.writeUint8(SSHAgentProtocol.identitiesAnswer); + writer.writeUint32(_identities.length); + for (final identity in _identities) { + final publicKey = identity.toPublicKey().encode(); + writer.writeString(publicKey); + writer.writeUtf8(comment ?? ''); + } + return writer.takeBytes(); + } + + Uint8List _handleSignRequest(SSHMessageReader reader) { + final keyBlob = reader.readString(); + final data = reader.readString(); + final flags = reader.readUint32(); + + final identity = _findIdentity(keyBlob); + if (identity == null) { + return _failure(); + } + + final signature = _sign(identity, data, flags); + final writer = SSHMessageWriter(); + writer.writeUint8(SSHAgentProtocol.signResponse); + writer.writeString(signature.encode()); + return writer.takeBytes(); + } + + SSHSignature _sign(SSHKeyPair identity, Uint8List data, int flags) { + if (identity is OpenSSHRsaKeyPair || identity is RsaPrivateKey) { + final signatureType = _rsaSignatureTypeForFlags(flags); + return _signRsa(identity, data, signatureType); + } + return identity.sign(data); + } + + String _rsaSignatureTypeForFlags(int flags) { + if (flags & SSHAgentProtocol.rsaSha2_512 != 0) { + return SSHRsaSignatureType.sha512; + } + if (flags & SSHAgentProtocol.rsaSha2_256 != 0) { + return SSHRsaSignatureType.sha256; + } + return SSHRsaSignatureType.sha1; + } + + SSHRsaSignature _signRsa( + SSHKeyPair identity, + Uint8List data, + String signatureType, + ) { + final key = _rsaKeyFrom(identity); + if (key == null) { + return identity.sign(data) as SSHRsaSignature; + } + + final signer = _rsaSignerFor(signatureType); + signer.init(true, PrivateKeyParameter(key)); + return SSHRsaSignature(signatureType, signer.generateSignature(data).bytes); + } + + asymmetric.RSAPrivateKey? _rsaKeyFrom(SSHKeyPair identity) { + if (identity is OpenSSHRsaKeyPair) { + return asymmetric.RSAPrivateKey(identity.n, identity.d, identity.p, identity.q); + } + if (identity is RsaPrivateKey) { + return asymmetric.RSAPrivateKey(identity.n, identity.d, identity.p, identity.q); + } + return null; + } + + RSASigner _rsaSignerFor(String signatureType) { + switch (signatureType) { + case SSHRsaSignatureType.sha1: + return RSASigner(SHA1Digest(), '06052b0e03021a'); + case SSHRsaSignatureType.sha256: + return RSASigner(SHA256Digest(), '0609608648016503040201'); + case SSHRsaSignatureType.sha512: + return RSASigner(SHA512Digest(), '0609608648016503040203'); + default: + return RSASigner(SHA256Digest(), '0609608648016503040201'); + } + } + + SSHKeyPair? _findIdentity(Uint8List keyBlob) { + for (final identity in _identities) { + final publicKey = identity.toPublicKey().encode(); + if (_bytesEqual(publicKey, keyBlob)) { + return identity; + } + } + return null; + } + + Uint8List _failure() { + final writer = SSHMessageWriter(); + writer.writeUint8(SSHAgentProtocol.failure); + return writer.takeBytes(); + } + + bool _bytesEqual(Uint8List a, Uint8List b) { + if (a.length != b.length) return false; + for (var i = 0; i < a.length; i++) { + if (a[i] != b[i]) return false; + } + return true; + } +} + +class SSHAgentChannel { + SSHAgentChannel(this._channel, this._handler, {this.printDebug}) { + _subscription = _channel.stream.listen( + _handleData, + onDone: _handleDone, + onError: (_, __) => _handleDone(), + ); + } + + final SSHChannel _channel; + final SSHAgentHandler _handler; + final SSHPrintHandler? printDebug; + + StreamSubscription? _subscription; + Uint8List _buffer = Uint8List(0); + bool _processing = false; + + void _handleDone() { + _subscription?.cancel(); + } + + void _handleData(SSHChannelData data) { + _buffer = _appendBytes(_buffer, data.bytes); + _drainRequests(); + } + + void _drainRequests() { + if (_processing) return; + _processing = true; + _processQueue().whenComplete(() => _processing = false); + } + + Future _processQueue() async { + while (_buffer.length >= 4) { + final length = ByteData.sublistView(_buffer, 0, 4).getUint32(0); + if (_buffer.length < 4 + length) return; + final payload = _buffer.sublist(4, 4 + length); + _buffer = _buffer.sublist(4 + length); + Uint8List response; + try { + response = await _handler.handleRequest(payload); + } catch (error) { + printDebug?.call('SSH agent handler error: $error'); + response = _failureResponse(); + } + _sendResponse(response); + } + } + + Uint8List _failureResponse() { + final writer = SSHMessageWriter(); + writer.writeUint8(SSHAgentProtocol.failure); + return writer.takeBytes(); + } + + void _sendResponse(Uint8List payload) { + final writer = SSHMessageWriter(); + writer.writeUint32(payload.length); + writer.writeBytes(payload); + _channel.addData(writer.takeBytes()); + } + + Uint8List _appendBytes(Uint8List a, Uint8List b) { + if (a.isEmpty) return b; + if (b.isEmpty) return a; + final combined = Uint8List(a.length + b.length); + combined.setAll(0, a); + combined.setAll(a.length, b); + return combined; + } +} + +abstract class SSHAgentProtocol { + static const int failure = 5; + static const int requestIdentities = 11; + static const int identitiesAnswer = 12; + static const int signRequest = 13; + static const int signResponse = 14; + static const int rsaSha2_256 = 2; + static const int rsaSha2_512 = 4; +} diff --git a/lib/src/ssh_channel.dart b/lib/src/ssh_channel.dart index 20f0fa8..4dec940 100644 --- a/lib/src/ssh_channel.dart +++ b/lib/src/ssh_channel.dart @@ -120,6 +120,17 @@ class SSHChannelController { return await _requestReplyQueue.next; } + Future sendAgentForwardingRequest() async { + sendMessage( + SSH_Message_Channel_Request( + recipientChannel: remoteId, + requestType: SSHChannelRequestType.authAgent, + wantReply: true, + ), + ); + return await _requestReplyQueue.next; + } + Future sendSubsystem(String subsystem) async { sendMessage( SSH_Message_Channel_Request.subsystem( diff --git a/lib/src/ssh_client.dart b/lib/src/ssh_client.dart index ee841f5..7c4cfd8 100644 --- a/lib/src/ssh_client.dart +++ b/lib/src/ssh_client.dart @@ -5,6 +5,7 @@ import 'dart:typed_data'; import 'package:dartssh2/src/http/http_client.dart'; import 'package:dartssh2/src/sftp/sftp_client.dart'; import 'package:dartssh2/src/ssh_algorithm.dart'; +import 'package:dartssh2/src/ssh_agent.dart'; import 'package:dartssh2/src/ssh_channel.dart'; import 'package:dartssh2/src/ssh_channel_id.dart'; import 'package:dartssh2/src/ssh_errors.dart'; @@ -122,6 +123,9 @@ class SSHClient { /// Function called when authentication is complete. final SSHAuthenticatedHandler? onAuthenticated; + /// Optional handler for SSH agent forwarding requests. + final SSHAgentHandler? agentHandler; + /// The interval at which to send a keep-alive message through the [ping] /// method. Set this to null to disable automatic keep-alive messages. final Duration? keepAliveInterval; @@ -154,6 +158,7 @@ class SSHClient { this.onUserInfoRequest, this.onUserauthBanner, this.onAuthenticated, + this.agentHandler, this.keepAliveInterval = const Duration(seconds: 10), this.disableHostkeyVerification = false, }) { @@ -314,6 +319,10 @@ class SSHClient { } } + if (agentHandler != null) { + await channelController.sendAgentForwardingRequest(); + } + if (pty != null) { final ptyOk = await channelController.sendPtyReq( terminalType: pty.type, @@ -353,6 +362,10 @@ class SSHClient { } } + if (agentHandler != null) { + await channelController.sendAgentForwardingRequest(); + } + if (pty != null) { final ok = await channelController.sendPtyReq( terminalType: pty.type, @@ -676,6 +689,8 @@ class SSHClient { switch (message.channelType) { case 'forwarded-tcpip': return _handleForwardedTcpipChannelOpen(message); + case 'auth-agent@openssh.com': + return _handleAgentChannelOpen(message); } printDebug?.call('unknown channelType: ${message.channelType}'); @@ -740,6 +755,42 @@ class SSHClient { ); } + void _handleAgentChannelOpen(SSH_Message_Channel_Open message) { + final handler = agentHandler; + if (handler == null) { + final reply = SSH_Message_Channel_Open_Failure( + recipientChannel: message.senderChannel, + reasonCode: SSH_Message_Channel_Open_Failure.codeUnknownChannelType, + description: 'agent forwarding not enabled', + ); + _sendMessage(reply); + return; + } + + final localChannelId = _channelIdAllocator.allocate(); + final confirmation = SSH_Message_Channel_Confirmation( + recipientChannel: message.senderChannel, + senderChannel: localChannelId, + initialWindowSize: _initialWindowSize, + maximumPacketSize: _maximumPacketSize, + data: Uint8List(0), + ); + _sendMessage(confirmation); + + final channelController = _acceptChannel( + localChannelId: localChannelId, + remoteChannelId: message.senderChannel, + remoteInitialWindowSize: message.initialWindowSize, + remoteMaximumPacketSize: message.maximumPacketSize, + ); + + SSHAgentChannel( + channelController.channel, + handler, + printDebug: printDebug, + ); + } + /// Finds a remote forward that matches the given host and port. SSHRemoteForward? _findRemoteForward(String host, int port) { final result = _remoteForwards.where( diff --git a/test/src/ssh_agent_test.dart b/test/src/ssh_agent_test.dart new file mode 100644 index 0000000..2f36987 --- /dev/null +++ b/test/src/ssh_agent_test.dart @@ -0,0 +1,71 @@ +import 'dart:typed_data'; + +import 'package:dartssh2/dartssh2.dart'; +import 'package:dartssh2/src/hostkey/hostkey_rsa.dart'; +import 'package:dartssh2/src/ssh_message.dart'; +import 'package:test/test.dart'; + +import '../test_utils.dart'; + +void main() { + final rsaPrivate = fixture('ssh-rsa/id_rsa'); + + SSHKeyPair rsaIdentity() { + return SSHKeyPair.fromPem(rsaPrivate).single; + } + + Uint8List buildRequestIdentities() { + final writer = SSHMessageWriter(); + writer.writeUint8(SSHAgentProtocol.requestIdentities); + return writer.takeBytes(); + } + + Uint8List buildSignRequest(SSHKeyPair identity, Uint8List data, int flags) { + final writer = SSHMessageWriter(); + writer.writeUint8(SSHAgentProtocol.signRequest); + writer.writeString(identity.toPublicKey().encode()); + writer.writeString(data); + writer.writeUint32(flags); + return writer.takeBytes(); + } + + test('SSHKeyPairAgent returns identities', () async { + final identity = rsaIdentity(); + final agent = SSHKeyPairAgent([identity], comment: 'test-key'); + + final response = await agent.handleRequest(buildRequestIdentities()); + final reader = SSHMessageReader(response); + + expect(reader.readUint8(), SSHAgentProtocol.identitiesAnswer); + expect(reader.readUint32(), 1); + final keyBlob = reader.readString(); + final comment = reader.readUtf8(); + + expect(keyBlob, identity.toPublicKey().encode()); + expect(comment, 'test-key'); + }); + + test('SSHKeyPairAgent signs RSA with expected signature type', () async { + final identity = rsaIdentity(); + final agent = SSHKeyPairAgent([identity]); + final data = Uint8List.fromList('sign-me'.codeUnits); + + final cases = { + SSHAgentProtocol.rsaSha2_256: SSHRsaSignatureType.sha256, + SSHAgentProtocol.rsaSha2_512: SSHRsaSignatureType.sha512, + 0: SSHRsaSignatureType.sha1, + }; + + for (final entry in cases.entries) { + final response = await agent.handleRequest( + buildSignRequest(identity, data, entry.key), + ); + final reader = SSHMessageReader(response); + expect(reader.readUint8(), SSHAgentProtocol.signResponse); + + final signatureBlob = reader.readString(); + final signature = SSHRsaSignature.decode(signatureBlob); + expect(signature.type, entry.value); + } + }); +}