From 0347dbfcad9dad9a892f94fb3e48db048f4fc565 Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 15:11:46 -0500 Subject: [PATCH 1/8] Fingerprinting server, bugfixing microsoft telnet.exe and more (#112) 2.2.0 * bugfix: workaround for Microsoft Telnet client crash on ``SB NEW_ENVIRON SEND``, #24. Server now defers ``DO NEW_ENVIRON`` until TTYPE cycling identifies the client, skipping it entirely for MS Telnet (ANSI/VT100). * bugfix: in handling of LINEMODE FORWARDMASK command bytes. * bugfix: SLC fingerprinting byte handling. * bugfix: send IAC GA (Go-Ahead) after prompts when SGA is not negotiated. Fixes hanging for MUD clients like Mudlet. PTY shell uses a 500ms idle timer. Use ``--never-send-ga`` to suppress like old behavior. * performance: with 'smarter' negotiation, default ``connect_maxwait`` reduced from 4.0s to 1.5s. * performance: both client and server protocol data_received methods have approximately ~50x throughput improvement in bulk data transfers. * new: ``Server`` class returned by ``create_server()`` with ``wait_for_client()`` method and ``clients`` property for tracking connected clients. * new: ``TelnetWriter.wait_for()`` and ``wait_for_condition()`` methods for waiting on telnet option negotiation state. * new: ``telnetlib3.sync`` module with blocking (non-asyncio) APIs: ``TelnetConnection`` for clients, ``BlockingTelnetServer`` for servers. * new: ``pty_shell`` module and demonstrating ``telnetlib3-server --pty-exec`` CLI argument and related ``--pty-raw`` server CLI option for raw PTY mode, used by most programs that handle their own terminal I/O. * new: ``guard_shells`` module with ``--robot-check`` and ``--pty-fork-limit`` CLI arguments for connection limiting and bot detection. * new: ``fingerprinting`` module for telnet client identification and capability probing. * new: ``--send-environ`` client CLI option to control which environment variables are sent via NEW_ENVIRON. Default no longer includes HOME or SHELL. --- .github/workflows/ci.yml | 1 - README.rst | 43 +- bin/moderate_fingerprints.py | 284 ++++ bin/server_binary.py | 58 + docs/conf.py | 10 +- docs/guidebook.rst | 43 +- docs/history.rst | 24 +- pyproject.toml | 32 +- telnetlib3/__init__.py | 10 + telnetlib3/_types.py | 32 + telnetlib3/accessories.py | 51 +- telnetlib3/client.py | 289 ++-- telnetlib3/client_base.py | 138 +- telnetlib3/client_shell.py | 41 +- telnetlib3/fingerprinting.py | 979 +++++++++++++ telnetlib3/fingerprinting_display.py | 1383 ++++++++++++++++++ telnetlib3/guard_shells.py | 180 ++- telnetlib3/relay_server.py | 42 +- telnetlib3/server.py | 553 +++++-- telnetlib3/server_base.py | 116 +- telnetlib3/server_pty_shell.py | 320 ++-- telnetlib3/server_shell.py | 72 +- telnetlib3/slc.py | 104 +- telnetlib3/stream_reader.py | 102 +- telnetlib3/stream_writer.py | 601 ++++---- telnetlib3/sync.py | 83 +- telnetlib3/telnetlib.py | 6 +- telnetlib3/telopt.py | 15 +- telnetlib3/tests/accessories.py | 14 +- telnetlib3/tests/pty_helper.py | 10 +- telnetlib3/tests/test_accessories.py | 17 +- telnetlib3/tests/test_benchmarks.py | 1 + telnetlib3/tests/test_charset.py | 24 +- telnetlib3/tests/test_client_unit.py | 215 +++ telnetlib3/tests/test_core.py | 111 +- telnetlib3/tests/test_encoding.py | 91 +- telnetlib3/tests/test_environ.py | 132 +- telnetlib3/tests/test_fingerprinting.py | 1024 +++++++++++++ telnetlib3/tests/test_guard_integration.py | 24 +- telnetlib3/tests/test_linemode.py | 15 +- telnetlib3/tests/test_naws.py | 44 +- telnetlib3/tests/test_pty_shell.py | 235 +-- telnetlib3/tests/test_reader.py | 40 +- telnetlib3/tests/test_server_api.py | 5 +- telnetlib3/tests/test_server_cli.py | 37 +- telnetlib3/tests/test_server_shell_unit.py | 564 +++---- telnetlib3/tests/test_shell.py | 58 +- telnetlib3/tests/test_status_logger.py | 3 +- telnetlib3/tests/test_stream_reader_extra.py | 3 +- telnetlib3/tests/test_stream_writer_extra.py | 82 +- telnetlib3/tests/test_stream_writer_full.py | 57 +- telnetlib3/tests/test_sync.py | 40 +- telnetlib3/tests/test_telnetlib.py | 21 +- telnetlib3/tests/test_timeout.py | 22 +- telnetlib3/tests/test_tspeed.py | 14 +- telnetlib3/tests/test_ttype.py | 30 +- telnetlib3/tests/test_uvloop_integration.py | 6 +- telnetlib3/tests/test_writer.py | 102 +- telnetlib3/tests/test_xdisploc.py | 17 +- tox.ini | 18 +- 60 files changed, 6712 insertions(+), 1976 deletions(-) create mode 100755 bin/moderate_fingerprints.py create mode 100644 bin/server_binary.py create mode 100644 telnetlib3/_types.py create mode 100644 telnetlib3/fingerprinting.py create mode 100644 telnetlib3/fingerprinting_display.py create mode 100644 telnetlib3/tests/test_client_unit.py create mode 100644 telnetlib3/tests/test_fingerprinting.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b7dcdbcc..36c738f5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,7 +58,6 @@ jobs: - ubuntu-latest - windows-latest python-version: - - "3.8" - "3.9" - "3.10" - "3.11" diff --git a/README.rst b/README.rst index 4ae33047..82923b21 100644 --- a/README.rst +++ b/README.rst @@ -29,7 +29,7 @@ Introduction ============ -``telnetlib3`` is a full-featured Telnet Client and Server library for python3.8 and newer. +``telnetlib3`` is a full-featured Telnet Client and Server library for python3.9 and newer. Modern asyncio_ and legacy blocking API's are provided. @@ -74,6 +74,32 @@ program. telnetlib3-server 0.0.0.0 1984 --shell=bin.server_wargame.shell telnetlib3-server --pty-exec /bin/bash -- --login +Fingerprinting Server +--------------------- + +A built-in fingerprinting server shell is provided to uniquely identify telnet clients:: + + export TELNETLIB3_DATA_DIR=./data + telnetlib3-server --shell telnetlib3.fingerprinting_server_shell + +A public fingerprinting server you can try out yourself:: + + telnet 1984.ws 555 + +An optional post-fingerprint hook can process saved files. The hook is run as +``python -m ``. The built-in post-script pretty-prints the JSON +and integrates with ucs-detect_ for terminal capability probing:: + + export TELNETLIB3_DATA_DIR=./fingerprints + export TELNETLIB3_FINGERPRINT_POST_SCRIPT=telnetlib3.fingerprinting + telnetlib3-server --shell telnetlib3.fingerprinting_server_shell + +If ucs-detect_ is installed and available in PATH, the post-script automatically +runs it to probe terminal capabilities (colors, sixel, kitty graphics, etc.) and +adds the results to the fingerprint data as ``terminal-fingerprint-data``. + +.. _ucs-detect: https://github.com/jquast/ucs-detect + Legacy telnetlib ---------------- @@ -113,6 +139,21 @@ or CHARSET to negotiate about it. In this case, use ``--force-binary`` and ``--encoding`` when the encoding of the remote end is known. +Go-Ahead (GA) +-------------- + +When a client does not negotiate Suppress Go-Ahead (SGA), the server sends +``IAC GA`` after output to signal that the client may transmit. This is +correct behavior for MUD clients like Mudlet that expect prompt detection +via GA. + +If GA causes unwanted output for your use case, disable it:: + + telnetlib3-server --never-send-ga + +For PTY shells, GA is sent after 500ms of output idle time to avoid +injecting GA in the middle of streaming output. + Quick Example ============= diff --git a/bin/moderate_fingerprints.py b/bin/moderate_fingerprints.py new file mode 100755 index 00000000..06729cc7 --- /dev/null +++ b/bin/moderate_fingerprints.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python +# pylint: disable=cyclic-import +"""Moderate fingerprint name suggestions.""" + +# std imports +import os +import sys +import json +import shutil +import subprocess +import collections +from pathlib import Path + +_BAT = shutil.which("bat") or shutil.which("batcat") +_JQ = shutil.which("jq") +_UNKNOWN = "0" * 16 +_PROBES = { + "telnet-probe": ("telnet-client", "telnet-client-revision"), + "terminal-probe": ("terminal-emulator", "terminal-emulator-revision"), +} + + +def _iter_files(data_dir): + """Yield (path, data) for each client JSON file.""" + client_base = data_dir / "client" + if client_base.is_dir(): + for path in sorted(client_base.glob("*/*/*.json")): + try: + with open(path, encoding="utf-8") as f: + yield path, json.load(f) + except (OSError, json.JSONDecodeError): + continue + + +def _print_json(label, data): + """Print labeled JSON, colorized through bat or jq when available.""" + raw = json.dumps(data, indent=4, sort_keys=True) + if _BAT: + r = subprocess.run( + [_BAT, "-l", "json", "--style=plain", "--color=always"], + input=raw, + capture_output=True, + text=True, + check=False, + ) + if r.returncode == 0: + raw = r.stdout.rstrip("\n") + elif _JQ: + r = subprocess.run([_JQ, "-C", "."], input=raw, capture_output=True, text=True, check=False) + if r.returncode == 0: + raw = r.stdout.rstrip("\n") + print(f"{label} {raw}") + + +def _print_telnet_context(session_data): + """Print key telnet session fields for moderation context.""" + ttype_cycle = session_data.get("ttype_cycle", []) + if ttype_cycle: + print(f" ttype cycle: {' -> '.join(ttype_cycle)}") + + extra = session_data.get("extra", {}) + if extra: + for key in sorted(extra): + print(f" {key}: {extra[key]}") + + +def _print_terminal_context(session_data): + """Print key terminal session fields for moderation context.""" + software = session_data.get("software_name") + version = session_data.get("software_version") + if software: + sw_str = software + if version: + sw_str += f" {version}" + print(f" software: {sw_str}") + + aw = session_data.get("ambiguous_width") + if aw is not None: + print(f" ambiguous_width: {aw}") + + +def _print_paired(paired_hashes, label, names): + """Print paired fingerprint hashes with names when known.""" + if not paired_hashes: + return + other_label = "terminal" if label == "telnet" else "telnet" + parts = [] + for ph in sorted(paired_hashes): + name = names.get(ph) + if name: + parts.append(f"{name} ({ph[:8]})") + else: + parts.append(ph[:12]) + print(f" paired {other_label}: {', '.join(parts)}") + + +def _load_names(data_dir): + try: + with open(data_dir / "fingerprint_names.json", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return {} + + +def _save_names(data_dir, names): + path = data_dir / "fingerprint_names.json" + tmp = path.with_suffix(".json.new") + with open(tmp, "w", encoding="utf-8") as f: + json.dump(names, f, indent=2, sort_keys=True) + os.rename(tmp, path) + print(f"\nSaved {path}") + + +def _scan(data_dir, names, revise=False): + """Return entries for review. + + Each entry is ``(label, hash, suggestions, fp_data, session, paired)``. + """ + suggestions = collections.defaultdict(list) + fp_data = {} + labels = {} + sessions = {} + paired = collections.defaultdict(set) + + for _, data in _iter_files(data_dir): + file_sug = data.get("suggestions", {}) + for probe_key, (sug_key, rev_key) in _PROBES.items(): + h = data.get(probe_key, {}).get("fingerprint") + if not h or h == _UNKNOWN: + continue + labels.setdefault(h, probe_key.split("-", maxsplit=1)[0]) + fp_data.setdefault(h, data.get(probe_key, {}).get("fingerprint-data", {})) + sessions.setdefault(h, data.get(probe_key, {}).get("session_data", {})) + other = "terminal-probe" if probe_key == "telnet-probe" else "telnet-probe" + other_h = data.get(other, {}).get("fingerprint") + if other_h and other_h != _UNKNOWN: + paired[h].add(other_h) + look = rev_key if revise else sug_key + if look in file_sug: + suggestions[h].append(file_sug[look]) + + return [ + ( + labels[h], + h, + suggestions.get(h, []), + fp_data[h], + sessions.get(h, {}), + paired.get(h, set()), + ) + for h in sorted(fp_data) + if (h in names) == revise + ] + + +def _review(entries, names): + """Interactive review loop. Return True if any names were added.""" + updated = False + for label, h, sug_list, fpd, session_data, paired_hashes in entries: + current = names.get(h) + print(f"\n{'=' * 60}\n {label}: {h}") + if current: + print(f" current name: {current}") + + if fpd: + _print_json(" fingerprint-data:", fpd) + + if label == "telnet" and session_data: + _print_telnet_context(session_data) + elif label == "terminal" and session_data: + _print_terminal_context(session_data) + _print_paired(paired_hashes, label, names) + + default = "" + if sug_list: + counter = collections.Counter(sug_list) + default = counter.most_common(1)[0][0] + print(f" {sum(counter.values())} suggestion(s):") + for name, count in counter.most_common(): + print(f" {count}x {name}") + else: + print(" (no client suggestions)") + + suffix = f"for '{default}'" if default else "to skip" + try: + raw = input(f" Name (return {suffix}): ").strip() + except EOFError: + print() + continue + except KeyboardInterrupt: + print("\nAborted.") + return updated + + chosen = raw or default + if chosen and chosen != current: + names[h] = chosen + updated = True + print(f" -> {h} = {chosen}") + + return updated + + +def _relocate(data_dir): + """Move misplaced JSON files to match their internal fingerprint hashes.""" + client_base = data_dir / "client" + moved = 0 + stale = set() + for path, data in _iter_files(data_dir): + th = data.get("telnet-probe", {}).get("fingerprint") + tmh = data.get("terminal-probe", {}).get("fingerprint", _UNKNOWN) + if not th: + continue + if path.parent.parent.name == th and path.parent.name == tmh: + continue + target = client_base / th / tmh / path.name + if target.exists(): + continue + target.parent.mkdir(parents=True, exist_ok=True) + os.rename(path, target) + moved += 1 + stale.add(path.parent) + + for d in stale: + try: + d.rmdir() + d.parent.rmdir() + except OSError: + pass + return moved + + +def _prune(data_dir, names): + """Remove named hashes that have no data files.""" + hashes = set() + for path, _ in _iter_files(data_dir): + hashes.update({path.parent.parent.name, path.parent.name}) + orphaned = {h: n for h, n in names.items() if h not in hashes} + if not orphaned: + return False + + print(f"Found {len(orphaned)} orphaned hash(es):") + for h, name in sorted(orphaned.items(), key=lambda x: x[1]): + print(f" {h} {name}") + try: + if input("\nRemove? [y/N] ").strip().lower() != "y": + return False + except (EOFError, KeyboardInterrupt): + print() + return False + + for h in orphaned: + del names[h] + return True + + +def main(): + """CLI entry point for moderating fingerprint name suggestions.""" + data_dir_env = os.environ.get("TELNETLIB3_DATA_DIR") + if not data_dir_env: + print("Error: TELNETLIB3_DATA_DIR not set", file=sys.stderr) + sys.exit(1) + data_dir = Path(data_dir_env) + if not data_dir.exists(): + print(f"Error: {data_dir} does not exist", file=sys.stderr) + sys.exit(1) + + revise = "--check-revise" in sys.argv + relocated = _relocate(data_dir) + if relocated: + print(f"Relocated {relocated} file(s).\n") + + names = _load_names(data_dir) + if "--no-prune" not in sys.argv and _prune(data_dir, names): + _save_names(data_dir, names) + + entries = _scan(data_dir, names, revise) + if entries and _review(entries, names): + _save_names(data_dir, names) + elif not entries: + print("Nothing to review.") + + +if __name__ == "__main__": + main() diff --git a/bin/server_binary.py b/bin/server_binary.py new file mode 100644 index 00000000..88b34743 --- /dev/null +++ b/bin/server_binary.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +""" +Telnet server using binary (raw bytes) mode. + +This example demonstrates using ``encoding=False`` for a server that works +with raw bytes instead of Unicode strings. This is useful for protocol +bridging, binary data transfer, or custom protocols over telnet. + +When encoding is set (the default), the shell callback receives +``TelnetReaderUnicode`` and ``TelnetWriterUnicode``, which read and write +``str``. When ``encoding=False``, the shell receives ``TelnetReader`` and +``TelnetWriter``, which read and write ``bytes``. + +Run this server, then connect with: telnet localhost 6023 + +Example session:: + + $ telnet localhost 6023 + Escape character is '^]'. + [binary echo server] type something: + hello + hex: 68 65 6c 6c 6f 0d 0a + Connection closed by foreign host. +""" + +# std imports +import asyncio + +# local +import telnetlib3 # pylint: disable=cyclic-import + + +async def shell(reader, writer): + """Echo client input back as hex bytes.""" + writer.write(b"[binary echo server] type something:\r\n") + await writer.drain() + + data = await reader.read(128) + if data: + hex_str = " ".join(f"{b:02x}" for b in data) + writer.write(f"hex: {hex_str}\r\n".encode("ascii")) + await writer.drain() + writer.close() + + +async def main(): + """Start the telnet server in binary mode.""" + server = await telnetlib3.create_server( + host="127.0.0.1", port=6023, shell=shell, encoding=False + ) + print("Binary telnet server running on localhost:6023") + print("Connect with: telnet localhost 6023") + print("Press Ctrl+C to stop") + await server.wait_closed() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/conf.py b/docs/conf.py index 6ae464d1..7073fa86 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,6 +42,7 @@ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", "github", ] @@ -66,10 +67,10 @@ # built documents. # # The short X.Y version. -version = "2.1" +version = "2.2" # The full version, including alpha/beta/rc tags. -release = "2.1.0" # keep in sync with setup.py and telnetlib3/accessories.py !! +release = "2.2.0" # keep in sync with setup.py and telnetlib3/accessories.py !! # The language for content auto-generated by Sphinx. Refer to documentation # for a list of supported languages. @@ -243,6 +244,11 @@ # Ignore these references that can't be resolved (internal asyncio paths, etc.) nitpick_ignore = [ ("py:class", "asyncio.events.AbstractEventLoop"), + ("py:class", "asyncio.transports.BaseTransport"), + ("py:class", "asyncio.protocols.Protocol"), + ("py:class", "_asyncio.Task"), + ("py:class", "_asyncio.Future"), + ("py:data", "typing.Union"), ] # Both the class’ and the __init__ method’s docstring are concatenated and diff --git a/docs/guidebook.rst b/docs/guidebook.rst index 653450f6..5a150ac0 100644 --- a/docs/guidebook.rst +++ b/docs/guidebook.rst @@ -111,7 +111,6 @@ questions. Demonstrates the client shell callback pattern. :language: python :lines: 18-41 - Server API Reference -------------------- @@ -175,6 +174,48 @@ The ``wait_for_condition()`` method waits for a custom condition:: lambda w: w.mode == "kludge" and w.remote_option.enabled(ECHO) ) +Encoding and Binary Mode +------------------------ + +By default, telnetlib3 uses ``encoding="utf8"``, which means the shell +callback receives ``TelnetReaderUnicode`` and ``TelnetWriterUnicode``. +These work with Python ``str`` -- you read and write strings:: + + async def shell(reader, writer): + writer.write("Hello, world!\r\n") # str + data = await reader.read(1) # returns str + +To work with raw bytes instead, pass ``encoding=False`` to +``create_server()`` or ``open_connection()``. The shell then receives +``TelnetReader`` and ``TelnetWriter``, which work with ``bytes``:: + + async def binary_shell(reader, writer): + writer.write(b"Hello, world!\r\n") # bytes + data = await reader.read(1) # returns bytes + + await telnetlib3.create_server( + host="127.0.0.1", port=6023, + shell=binary_shell, encoding=False + ) + +Binary mode is useful for specific low-level conditions, like performing +xmodem transfers, or working with legacy systems that predate unicode +and utf-8 support. + +The same applies to clients -- ``open_connection(..., encoding=False)`` +returns a ``(TelnetReader, TelnetWriter)`` pair that works with ``bytes``. + +server_binary.py +~~~~~~~~~~~~~~~~ + +https://github.com/jquast/telnetlib3/blob/master/bin/server_binary.py + +A telnet server in binary mode that echoes client input as hex bytes. +Demonstrates using ``encoding=False`` for raw byte I/O. + +.. literalinclude:: ../bin/server_binary.py + :language: python + :lines: 34-51 Blocking Interface ================== diff --git a/docs/history.rst b/docs/history.rst index 2d8eb33a..bc861e8e 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -1,6 +1,19 @@ History ======= 2.2.0 + * bugfix: workaround for Microsoft Telnet client crash on + ``SB NEW_ENVIRON SEND``, :ghissue:`24`. Server now defers ``DO + NEW_ENVIRON`` until TTYPE cycling identifies the client, skipping it + entirely for MS Telnet (ANSI/VT100). + * bugfix: in handling of LINEMODE FORWARDMASK command bytes. + * bugfix: SLC fingerprinting byte handling. + * bugfix: send IAC GA (Go-Ahead) after prompts when SGA is not negotiated. + Fixes hanging for MUD clients like Mudlet. PTY shell uses a 500ms idle + timer. Use ``--never-send-ga`` to suppress like old behavior. + * performance: with 'smarter' negotiation, default ``connect_maxwait`` + reduced from 4.0s to 1.5s. + * performance: both client and server protocol data_received methods + have approximately ~50x throughput improvement in bulk data transfers. * new: ``Server`` class returned by ``create_server()`` with ``wait_for_client()`` method and ``clients`` property for tracking connected clients. @@ -9,8 +22,15 @@ History * new: ``telnetlib3.sync`` module with blocking (non-asyncio) APIs: ``TelnetConnection`` for clients, ``BlockingTelnetServer`` for servers. * new: ``pty_shell`` module and demonstrating ``telnetlib3-server --pty-exec`` CLI argument - * performance: both client and server protocol data_received methods were - optimized for ~50x throughput improvement in bulk data transfers. + and related ``--pty-raw`` server CLI option for raw PTY mode, used by most + programs that handle their own terminal I/O. + * new: ``guard_shells`` module with ``--robot-check`` and ``--pty-fork-limit`` + CLI arguments for connection limiting and bot detection. + * new: ``fingerprinting`` module for telnet client identification and + capability probing. + * new: ``--send-environ`` client CLI option to control which environment + variables are sent via NEW_ENVIRON. Default no longer includes HOME or + SHELL. 2.0.8 * bugfix: object has no attribute '_extra' :ghissue:`100` diff --git a/pyproject.toml b/pyproject.toml index 496bb8c6..36d517ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "hatchling.build" [project] name = "telnetlib3" -version = "2.1.0" -description = "Python 3 asyncio Telnet server and client Protocol library" +version = "2.2.0" +description = " Python Telnet server and client CLI and Protocol library" readme = "README.rst" license = "ISC" license-files = ["LICENSE"] @@ -30,7 +30,6 @@ classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: ISC License (ISCL)", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -43,12 +42,18 @@ classifiers = [ "Topic :: System :: Shells", "Topic :: Terminals :: Telnet", ] -requires-python = ">=3.8" +requires-python = ">=3.9" [project.optional-dependencies] docs = [ "Sphinx>3", "sphinx_rtd_theme", + "sphinx-autodoc-typehints", +] +extras = [ + "ucs-detect>=2", + "prettytable", + "pyyaml", ] [project.scripts] @@ -127,6 +132,23 @@ min-similarity-lines = 8 reports = false msg-template = "{path}:{line}: [{msg_id}({symbol}), {obj}] {msg}" +[tool.mypy] +python_version = "3.9" +strict = true +disallow_subclassing_any = false +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["telnetlib3.tests.*"] +disallow_untyped_defs = false +disallow_incomplete_defs = false +disallow_untyped_calls = false +warn_return_any = false + +[[tool.mypy.overrides]] +module = ["telnetlib3.telnetlib"] +ignore_errors = true + [tool.black] line-length = 100 -target-version = ["py38", "py39", "py310", "py311", "py312", "py313"] +target-version = ["py39", "py310", "py311", "py312", "py313"] diff --git a/telnetlib3/__init__.py b/telnetlib3/__init__.py index ed8b65ef..e759b876 100644 --- a/telnetlib3/__init__.py +++ b/telnetlib3/__init__.py @@ -1,5 +1,8 @@ """telnetlib3: an asyncio Telnet Protocol implemented in python.""" +# std imports +import sys + # flake8: noqa: F405 # fmt: off # isort: off @@ -18,6 +21,9 @@ from . import slc from . import telnetlib from . import guard_shells +from . import fingerprinting +if sys.platform != "win32": + from . import fingerprinting_display # noqa: F401 from . import sync from .server_base import * # noqa from .server import * # noqa @@ -30,6 +36,9 @@ from .slc import * # noqa from .telnetlib import * # noqa from .guard_shells import * # noqa +from .fingerprinting import * # noqa +if sys.platform != "win32": + from .fingerprinting_display import * # noqa from .sync import * # noqa try: from . import server_pty_shell @@ -49,6 +58,7 @@ + server.__all__ + server_shell.__all__ + guard_shells.__all__ + + fingerprinting.__all__ + (server_pty_shell.__all__ if PTY_SUPPORT else ()) # client, + client_base.__all__ diff --git a/telnetlib3/_types.py b/telnetlib3/_types.py new file mode 100644 index 00000000..06558e16 --- /dev/null +++ b/telnetlib3/_types.py @@ -0,0 +1,32 @@ +"""Shared type aliases for telnetlib3 public API.""" + +from __future__ import annotations + +# std imports +from typing import Any, Dict, Tuple, Union, Literal, Callable, Coroutine + +# local +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode + +# Encoding parameter: str for Unicode mode, False for binary mode. +Encoding = Union[str, Literal[False]] + +# Shell callback: async def shell(reader, writer) -> None. +ShellCallback = Callable[ + [Union[TelnetReader, TelnetReaderUnicode], Union[TelnetWriter, TelnetWriterUnicode]], + Coroutine[Any, Any, None], +] + +# Reader/writer union types. +ReaderType = Union[TelnetReader, TelnetReaderUnicode] +WriterType = Union[TelnetWriter, TelnetWriterUnicode] + +# Environment mapping. +EnvironMapping = Dict[str, str] + +# NAWS window dimensions (cols, rows). +WindowSize = Tuple[int, int] + +# Terminal speed (rx, tx). +TerminalSpeed = Tuple[int, int] diff --git a/telnetlib3/accessories.py b/telnetlib3/accessories.py index 9f0c428e..9dc12872 100644 --- a/telnetlib3/accessories.py +++ b/telnetlib3/accessories.py @@ -1,10 +1,17 @@ """Accessory functions.""" +from __future__ import annotations + # std imports import shlex import asyncio import logging import importlib +from typing import TYPE_CHECKING, Any, Dict, Union, Mapping, Callable, Optional + +if TYPE_CHECKING: # pragma: no cover + # local + from .stream_reader import TelnetReader, TelnetReaderUnicode __all__ = ( "encoding_from_lang", @@ -17,29 +24,37 @@ ) -def get_version(): +def get_version() -> str: """Return the current version of telnetlib3.""" - return "2.1.0" # keep in sync with setup.py and docs/conf.py !! + return "2.2.0" # keep in sync with setup.py and docs/conf.py !! -def encoding_from_lang(lang): +def encoding_from_lang(lang: str) -> Optional[str]: """ Parse encoding from LANG environment value. + Returns the encoding portion if present, or None if the LANG value + does not contain an encoding suffix (no '.' separator). + + :param lang: LANG environment value (e.g., 'en_US.UTF-8@misc') + :returns: Encoding string (e.g., 'UTF-8') or None if no encoding found. + Example:: >>> encoding_from_lang('en_US.UTF-8@misc') 'UTF-8' + >>> encoding_from_lang('en_IL') + None """ - encoding = lang - if "." in lang: - _, encoding = lang.split(".", 1) + if "." not in lang: + return None + _, encoding = lang.split(".", 1) if "@" in encoding: encoding, _ = encoding.split("@", 1) return encoding -def name_unicode(ucs): +def name_unicode(ucs: str) -> str: """Return 7-bit ascii printable of any string.""" # more or less the same as curses.ascii.unctrl -- but curses # module is conditionally excluded from many python distributions! @@ -56,7 +71,7 @@ def name_unicode(ucs): return rep -def eightbits(number): +def eightbits(number: int) -> str: """ Binary representation of ``number`` padded to 8 bits. @@ -75,11 +90,16 @@ def eightbits(number): ) -def make_logger(name, loglevel="info", logfile=None, logfmt=_DEFAULT_LOGFMT): +def make_logger( + name: str, + loglevel: str = "info", + logfile: Optional[str] = None, + logfmt: str = _DEFAULT_LOGFMT, +) -> logging.Logger: """Create and return simple logger for given arguments.""" lvl = getattr(logging, loglevel.upper()) - _cfg = {"format": logfmt} + _cfg: Dict[str, Any] = {"format": logfmt} if logfile: _cfg["filename"] = logfile logging.basicConfig(**_cfg) @@ -88,20 +108,23 @@ def make_logger(name, loglevel="info", logfile=None, logfmt=_DEFAULT_LOGFMT): return logging.getLogger(name) -def repr_mapping(mapping): +def repr_mapping(mapping: Mapping[str, Any]) -> str: """Return printable string, 'key=value [key=value ...]' for mapping.""" return " ".join(f"{key}={shlex.quote(str(value))}" for key, value in mapping.items()) -def function_lookup(pymod_path): +def function_lookup(pymod_path: str) -> Callable[..., Any]: """Return callable function target from standard module.function path.""" module_name, func_name = pymod_path.rsplit(".", 1) module = importlib.import_module(module_name) - shell_function = getattr(module, func_name) + shell_function: Callable[..., Any] = getattr(module, func_name) assert callable(shell_function), shell_function return shell_function -def make_reader_task(reader, size=2**12): +def make_reader_task( + reader: "Union[TelnetReader, TelnetReaderUnicode, asyncio.StreamReader]", + size: int = 2**12, +) -> "asyncio.Task[Any]": """Return asyncio task wrapping coroutine of reader.read(size).""" return asyncio.ensure_future(reader.read(size)) diff --git a/telnetlib3/client.py b/telnetlib3/client.py index 84508b7c..61719ae7 100755 --- a/telnetlib3/client.py +++ b/telnetlib3/client.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """Telnet Client API for the 'telnetlib3' python package.""" +from __future__ import annotations + # std imports import os import sys @@ -8,10 +10,22 @@ import struct import asyncio import argparse +from typing import ( + Any, + Dict, + List, + Tuple, + Union, + Callable, + Optional, + Sequence, +) # local -# local imports from telnetlib3 import accessories, client_base +from telnetlib3._types import ShellCallback +from telnetlib3.stream_reader import TelnetReader, TelnetReaderUnicode +from telnetlib3.stream_writer import TelnetWriter, TelnetWriterUnicode __all__ = ("TelnetClient", "TelnetTerminalClient", "open_connection") @@ -30,21 +44,43 @@ class TelnetClient(client_base.BaseClient): #: full default LANG value of 'en_US.utf8' DEFAULT_LOCALE = "en_US" - def __init__( # pylint: disable=too-many-positional-arguments,keyword-arg-before-vararg + #: Default environment variables to send via NEW_ENVIRON + DEFAULT_SEND_ENVIRON = ("TERM", "LANG", "COLUMNS", "LINES", "COLORTERM") + + def __init__( # pylint: disable=too-many-positional-arguments self, - term="unknown", - cols=80, - rows=25, - tspeed=(38400, 38400), - xdisploc="", - *args, - **kwargs, - ): + term: str = "unknown", + cols: int = 80, + rows: int = 25, + tspeed: Tuple[int, int] = (38400, 38400), + xdisploc: str = "", + send_environ: Optional[Sequence[str]] = None, + shell: Optional[ShellCallback] = None, + encoding: Union[str, bool] = "utf8", + encoding_errors: str = "strict", + force_binary: bool = False, + connect_minwait: float = 1.0, + connect_maxwait: float = 4.0, + limit: Optional[int] = None, + waiter_closed: Optional[asyncio.Future[None]] = None, + _waiter_connected: Optional[asyncio.Future[None]] = None, + ) -> None: """Initialize TelnetClient with terminal parameters.""" - super().__init__(*args, **kwargs) + super().__init__( + shell=shell, + encoding=encoding, + encoding_errors=encoding_errors, + force_binary=force_binary, + connect_minwait=connect_minwait, + connect_maxwait=connect_maxwait, + limit=limit, + waiter_closed=waiter_closed, + _waiter_connected=_waiter_connected, + ) + self._send_environ = set(send_environ or self.DEFAULT_SEND_ENVIRON) self._extra.update( { - "charset": kwargs["encoding"] or "", + "charset": encoding or "", # for our purposes, we only send the second part (encoding) of our # 'lang' variable, CHARSET negotiation does not provide locale # negotiation; this is better left to the real LANG variable @@ -53,11 +89,7 @@ def __init__( # pylint: disable=too-many-positional-arguments,keyword-arg-befor # So which locale should we represent? Rather than using the # locale.getpreferredencoding() method, we provide a deterministic # class value DEFAULT_LOCALE (en_US), derive and modify as needed. - "lang": ( - "C" - if not kwargs["encoding"] - else self.DEFAULT_LOCALE + "." + kwargs["encoding"] - ), + "lang": ("C" if not encoding else self.DEFAULT_LOCALE + "." + str(encoding)), "cols": cols, "rows": rows, "term": term, @@ -66,7 +98,7 @@ def __init__( # pylint: disable=too-many-positional-arguments,keyword-arg-befor } ) - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Handle connection made to server. @@ -78,6 +110,7 @@ def connection_made(self, transport): from telnetlib3.telopt import NAWS, TTYPE, TSPEED, CHARSET, XDISPLOC, NEW_ENVIRON super().connection_made(transport) + assert self.writer is not None # Wire extended rfc callbacks for requests of # terminal attributes, environment values, etc. @@ -93,55 +126,70 @@ def connection_made(self, transport): # Override the default handle_will method to detect when both sides support CHARSET original_handle_will = self.writer.handle_will + writer = self.writer - def enhanced_handle_will(opt): + def enhanced_handle_will(opt: bytes) -> None: result = original_handle_will(opt) # If this was a WILL CHARSET from the server, and we also have WILL CHARSET enabled, # log that both sides support CHARSET. The server should initiate the actual REQUEST. if ( opt == CHARSET - and self.writer.remote_option.enabled(CHARSET) - and self.writer.local_option.enabled(CHARSET) + and writer.remote_option.enabled(CHARSET) + and writer.local_option.enabled(CHARSET) ): self.log.debug("Both sides support CHARSET, ready for server to initiate REQUEST") return result - self.writer.handle_will = enhanced_handle_will + self.writer.handle_will = enhanced_handle_will # type: ignore[method-assign] - def send_ttype(self): + def send_ttype(self) -> str: """Callback for responding to TTYPE requests.""" - return self._extra["term"] + result: str = self._extra["term"] + return result - def send_tspeed(self): + def send_tspeed(self) -> Tuple[int, int]: """Callback for responding to TSPEED requests.""" - return tuple(map(int, self._extra["tspeed"].split(","))) + parts = self._extra["tspeed"].split(",") + return (int(parts[0]), int(parts[1])) - def send_xdisploc(self): + def send_xdisploc(self) -> str: """Callback for responding to XDISPLOC requests.""" - return self._extra["xdisploc"] + result: str = self._extra["xdisploc"] + return result - def send_env(self, keys): + def send_env(self, keys: Sequence[str]) -> Dict[str, Any]: """ Callback for responding to NEW_ENVIRON requests. - :param dict keys: Values are requested for the keys specified. When empty, all environment + Only sends variables listed in ``_send_environ`` (set via ``send_environ`` + parameter or ``--send-environ`` CLI option). + + :param keys: Values are requested for the keys specified. When empty, all environment values that wish to be volunteered should be returned. - :returns: dictionary of environment values requested, or an empty string for keys not + :returns: Environment values requested, or an empty string for keys not available. A return value must be given for each key requested. - :rtype: dict """ - env = { + # All available values + all_env = { + # Terminal info from connection parameters "LANG": self._extra["lang"], "TERM": self._extra["term"], - "DISPLAY": self._extra["xdisploc"], "LINES": self._extra["rows"], "COLUMNS": self._extra["cols"], + # Environment variables from os.environ + "COLORTERM": os.environ.get("COLORTERM", ""), + "USER": os.environ.get("USER", ""), + "HOME": os.environ.get("HOME", ""), + "SHELL": os.environ.get("SHELL", ""), + # Note: DISPLAY intentionally not available (security) } + # Filter to only allowed variables + env = {k: v for k, v in all_env.items() if k in self._send_environ} return {key: env.get(key, "") for key in keys} or env - def send_charset(self, offered): + def send_charset(self, offered: List[str]) -> str: """ Callback for responding to CHARSET requests. @@ -159,13 +207,13 @@ def send_charset(self, offered): - If no viable encodings found, reject - :param list offered: list of CHARSET options offered by server. - :returns: character encoding agreed to be used, or "" to reject. - :rtype: str + :param offered: CHARSET options offered by server. + :returns: Character encoding agreed to be used, or empty string to reject. """ # Get client's desired encoding canonical name desired_name = None if self.default_encoding: + assert isinstance(self.default_encoding, str) try: desired_name = codecs.lookup(self.default_encoding).name except LookupError: @@ -230,35 +278,38 @@ def send_charset(self, offered): self.log.warning("No suitable encoding offered by server: %s", offered) return "" - def send_naws(self): + def send_naws(self) -> Tuple[int, int]: """ Callback for responding to NAWS requests. - :rtype: (int, int) - :returns: client window size as (rows, columns). + :returns: Client window size as (rows, columns). """ return (self._extra["rows"], self._extra["cols"]) - def encoding(self, outgoing=None, incoming=None): + def encoding( + self, + outgoing: Optional[bool] = None, + incoming: Optional[bool] = None, + ) -> str: """ Return encoding for the given stream direction. - :param bool outgoing: Whether the return value is suitable for + :param outgoing: Whether the return value is suitable for encoding bytes for transmission to server. - :param bool incoming: Whether the return value is suitable for + :param incoming: Whether the return value is suitable for decoding bytes received by the client. - :raises TypeError: when a direction argument, either ``outgoing`` + :raises TypeError: When a direction argument, either ``outgoing`` or ``incoming``, was not set ``True``. :returns: ``'US-ASCII'`` for the directions indicated, unless ``BINARY`` :rfc:`856` has been negotiated for the direction - indicated or :attr`force_binary` is set ``True``. - :rtype: str + indicated or ``force_binary`` is set ``True``. """ if not (outgoing or incoming): raise TypeError( "encoding arguments 'outgoing' and 'incoming' are required: toggle at least one." ) + assert self.writer is not None # may we encode in the direction indicated? _outgoing_only = outgoing and not incoming _incoming_only = not outgoing and incoming @@ -274,42 +325,41 @@ def encoding(self, outgoing=None, incoming=None): # default_encoding, may be re-negotiated later. Only the CHARSET # negotiation method allows the server to select an encoding, so # this value is reflected here by a single return statement. - return self._extra["charset"] + result: str = self._extra["charset"] + return result return "US-ASCII" class TelnetTerminalClient(TelnetClient): """Telnet client for sessions with a network virtual terminal (NVT).""" - def send_naws(self): + def send_naws(self) -> Tuple[int, int]: """ Callback replies to request for window size, NAWS :rfc:`1073`. - :rtype: (int, int) - :returns: window dimensions by lines and columns + :returns: Window dimensions by lines and columns. """ return self._winsize() - def send_env(self, keys): + def send_env(self, keys: Sequence[str]) -> Dict[str, Any]: """ Callback replies to request for env values, NEW_ENVIRON :rfc:`1572`. - :rtype: dict - :returns: super class value updated with window LINES and COLUMNS. + :returns: Super class value updated with window LINES and COLUMNS. """ env = super().send_env(keys) env["LINES"], env["COLUMNS"] = self._winsize() return env @staticmethod - def _winsize(): + def _winsize() -> Tuple[int, int]: try: # std imports import fcntl # pylint: disable=import-outside-toplevel import termios # pylint: disable=import-outside-toplevel fmt = "hhhh" - buf = "\x00" * struct.calcsize(fmt) + buf = b"\x00" * struct.calcsize(fmt) val = fcntl.ioctl(sys.stdin.fileno(), termios.TIOCGWINSZ, buf) rows, cols, _, _ = struct.unpack(fmt, val) return rows, cols @@ -322,93 +372,96 @@ def _winsize(): async def open_connection( # pylint: disable=too-many-locals - host=None, - port=23, + host: Optional[str] = None, + port: int = 23, *, - client_factory=None, - family=0, - flags=0, - local_addr=None, - encoding="utf8", - encoding_errors="replace", - force_binary=False, - term="unknown", - cols=80, - rows=25, - tspeed=(38400, 38400), - xdisploc="", - shell=None, - connect_minwait=2.0, - connect_maxwait=3.0, - waiter_closed=None, - _waiter_connected=None, - limit=None, -): + client_factory: Optional[Callable[..., client_base.BaseClient]] = None, + family: int = 0, + flags: int = 0, + local_addr: Optional[Tuple[str, int]] = None, + encoding: Union[str, bool] = "utf8", + encoding_errors: str = "replace", + force_binary: bool = False, + term: str = "unknown", + cols: int = 80, + rows: int = 25, + tspeed: Tuple[int, int] = (38400, 38400), + xdisploc: str = "", + shell: Optional[ShellCallback] = None, + connect_minwait: float = 2.0, + connect_maxwait: float = 3.0, + waiter_closed: Optional[asyncio.Future[None]] = None, + _waiter_connected: Optional[asyncio.Future[None]] = None, + limit: Optional[int] = None, + send_environ: Optional[Sequence[str]] = None, +) -> Tuple[ + Union[TelnetReader, TelnetReaderUnicode], + Union[TelnetWriter, TelnetWriterUnicode], +]: """ Connect to a TCP Telnet server as a Telnet client. - :param str host: Remote Internet TCP Server host. - :param int port: Remote Internet host TCP port. - :param client_base.BaseClient client_factory: Client connection class - factory. When ``None``, :class:`TelnetTerminalClient` is used when - *stdin* is attached to a terminal, :class:`TelnetClient` otherwise. - :param int family: Same meaning as + :param host: Remote Internet TCP Server host. + :param port: Remote Internet host TCP port. + :param client_factory: Client connection class factory. When ``None``, + :class:`TelnetTerminalClient` is used when *stdin* is attached to a + terminal, :class:`TelnetClient` otherwise. + :param family: Same meaning as :meth:`asyncio.loop.create_connection`. - :param int flags: Same meaning as + :param flags: Same meaning as :meth:`asyncio.loop.create_connection`. - :param tuple local_addr: Same meaning as + :param local_addr: Same meaning as :meth:`asyncio.loop.create_connection`. - :param str encoding: The default assumed encoding, or ``False`` to disable + :param encoding: The default assumed encoding, or ``False`` to disable unicode support. This value is used for decoding bytes received by and encoding bytes transmitted to the Server. These values are preferred in response to NEW_ENVIRON :rfc:`1572` as environment value ``LANG``, and by CHARSET :rfc:`2066` negotiation. The server's attached ``reader, writer`` streams accept and return - unicode, unless this value explicitly set ``False``. In that case, the - attached streams interfaces are bytes-only. - :param str encoding_errors: Same meaning as :meth:`codecs.Codec.encode`. + unicode, unless this value is explicitly set ``False``. In that case, + the attached streams interfaces are bytes-only. + :param encoding_errors: Same meaning as :meth:`codecs.Codec.encode`. - :param str term: Terminal type sent for requests of TTYPE, :rfc:`930` or as + :param term: Terminal type sent for requests of TTYPE, :rfc:`930` or as Environment value TERM by NEW_ENVIRON negotiation, :rfc:`1672`. - :param int cols: Client window dimension sent as Environment value COLUMNS + :param cols: Client window dimension sent as Environment value COLUMNS by NEW_ENVIRON negotiation, :rfc:`1672` or NAWS :rfc:`1073`. - :param int rows: Client window dimension sent as Environment value LINES by + :param rows: Client window dimension sent as Environment value LINES by NEW_ENVIRON negotiation, :rfc:`1672` or NAWS :rfc:`1073`. - :param tuple tspeed: Tuple of client BPS line speed in form ``(rx, tx``) - for receive and transmit, respectively. Sent when requested by TSPEED, - :rfc:`1079`. - :param str xdisploc: String transmitted in response for request of + :param tspeed: Client BPS line speed in form ``(rx, tx)`` for receive and + transmit, respectively. Sent when requested by TSPEED, :rfc:`1079`. + :param xdisploc: String transmitted in response for request of XDISPLOC, :rfc:`1086` by server (X11). - :param float connect_minwait: The client allows any additional telnet + :param connect_minwait: The client allows any additional telnet negotiations to be demanded by the server within this period of time before launching the shell. Servers should assert desired negotiation on-connect and in response to 1 or 2 round trips. A server that does not make any telnet demands, such as a TCP server - that is not a telnet server will delay the execution of ``shell`` for + that is not a telnet server, will delay the execution of ``shell`` for exactly this amount of time. - :param float connect_maxwait: If the remote end is not complaint, or + :param connect_maxwait: If the remote end is not compliant, or otherwise confused by our demands, the shell continues anyway after the greater of this value has elapsed. A client that is not answering option negotiation will delay the start of the shell by this amount. - :param bool force_binary: When ``True``, the encoding is used regardless + :param force_binary: When ``True``, the encoding is used regardless of BINARY mode negotiation. - :param asyncio.Future waiter_closed: Future that completes when the - connection is closed. + :param waiter_closed: Future that completes when the connection is closed. :param shell: An async function that is called after negotiation completes, receiving arguments ``(reader, writer)``. - :param int limit: The buffer limit for reader stream. - :return (reader, writer): The reader is a :class:`~.TelnetReader` - instance, the writer is a :class:`~.TelnetWriter` instance. + :param limit: The buffer limit for reader stream. + :return: The reader is a :class:`~.TelnetReader` instance, the writer is a + :class:`~.TelnetWriter` instance. """ if client_factory is None: client_factory = TelnetClient if sys.platform != "win32" and sys.stdin.isatty(): client_factory = TelnetTerminalClient - def connection_factory(): + def connection_factory() -> client_base.BaseClient: + assert client_factory is not None return client_factory( encoding=encoding, encoding_errors=encoding_errors, @@ -424,11 +477,12 @@ def connection_factory(): waiter_closed=waiter_closed, _waiter_connected=_waiter_connected, limit=limit, + send_environ=send_environ, ) _, protocol = await asyncio.get_event_loop().create_connection( connection_factory, - host, + host or "localhost", port, family=family, flags=flags, @@ -437,10 +491,12 @@ def connection_factory(): await protocol._waiter_connected # pylint: disable=protected-access + assert protocol.reader is not None + assert protocol.writer is not None return protocol.reader, protocol.writer -async def run_client(): +async def run_client() -> None: """Command-line 'telnetlib3-client' entry point, via setuptools.""" args = _transform_args(_get_argument_parser().parse_args()) config_msg = f"Client configuration: {accessories.repr_mapping(args)}" @@ -462,16 +518,19 @@ async def run_client(): "force_binary": args["force_binary"], "encoding_errors": args["encoding_errors"], "connect_minwait": args["connect_minwait"], + "send_environ": args["send_environ"], } # connect _, writer = await open_connection(args["host"], args["port"], **connection_kwargs) # repl loop + assert writer.protocol is not None + assert isinstance(writer.protocol, client_base.BaseClient) await writer.protocol.waiter_closed -def _get_argument_parser(): +def _get_argument_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Telnet protocol client", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -505,10 +564,15 @@ def _get_argument_parser(): type=float, help="timeout for pending negotiation", ) + parser.add_argument( + "--send-environ", + default="TERM,LANG,COLUMNS,LINES,COLORTERM", + help="comma-separated environment variables to send (NEW_ENVIRON)", + ) return parser -def _transform_args(args): +def _transform_args(args: argparse.Namespace) -> Dict[str, Any]: return { "host": args.host, "port": args.port, @@ -522,13 +586,14 @@ def _transform_args(args): "force_binary": args.force_binary, "encoding_errors": args.encoding_errors, "connect_minwait": args.connect_minwait, + "send_environ": tuple(v.strip() for v in args.send_environ.split(",") if v.strip()), } -def main(): +def main() -> None: """Entry point for telnetlib3-client command.""" asyncio.run(run_client()) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/telnetlib3/client_base.py b/telnetlib3/client_base.py index dca0754c..a523bffb 100644 --- a/telnetlib3/client_base.py +++ b/telnetlib3/client_base.py @@ -1,15 +1,20 @@ """Module provides class BaseClient.""" +from __future__ import annotations + # std imports import sys +import types import asyncio import logging import weakref import datetime import traceback import collections +from typing import Any, Type, Union, Callable, Optional, cast # local +from ._types import ShellCallback from .telopt import theNULL, name_commands from .stream_reader import TelnetReader, TelnetReaderUnicode from .stream_writer import TelnetWriter, TelnetWriterUnicode @@ -23,28 +28,28 @@ class BaseClient(asyncio.streams.FlowControlMixin, asyncio.Protocol): """Base Telnet Client Protocol.""" - _when_connected = None - _last_received = None - _transport = None + _when_connected: Optional[datetime.datetime] = None + _last_received: Optional[datetime.datetime] = None + _transport: Optional[asyncio.Transport] = None _closing = False _reader_factory = TelnetReader _reader_factory_encoding = TelnetReaderUnicode _writer_factory = TelnetWriter _writer_factory_encoding = TelnetWriterUnicode - _check_later = None + _check_later: Optional[asyncio.Handle] = None def __init__( # pylint: disable=too-many-positional-arguments self, - shell=None, - encoding="utf8", - encoding_errors="strict", - force_binary=False, - connect_minwait=1.0, - connect_maxwait=4.0, - limit=None, - waiter_closed=None, - _waiter_connected=None, - ): + shell: Optional[ShellCallback] = None, + encoding: Union[str, bool] = "utf8", + encoding_errors: str = "strict", + force_binary: bool = False, + connect_minwait: float = 1.0, + connect_maxwait: float = 4.0, + limit: Optional[int] = None, + waiter_closed: Optional[asyncio.Future[None]] = None, + _waiter_connected: Optional[asyncio.Future[None]] = None, + ) -> None: """Class initializer.""" super().__init__() self.log = logging.getLogger("telnetlib3.client") @@ -53,24 +58,24 @@ def __init__( # pylint: disable=too-many-positional-arguments self.default_encoding = encoding self._encoding_errors = encoding_errors self.force_binary = force_binary - self._extra = {} + self._extra: dict[str, Any] = {} self.waiter_closed = waiter_closed or asyncio.Future() #: a future used for testing self._waiter_connected = _waiter_connected or asyncio.Future() - self._tasks = [] + self._tasks: list[Any] = [] self.shell = shell #: minimum duration for :meth:`check_negotiation`. self.connect_minwait = connect_minwait #: maximum duration for :meth:`check_negotiation`. self.connect_maxwait = connect_maxwait - self.reader = None - self.writer = None + self.reader: Optional[Union[TelnetReader, TelnetReaderUnicode]] = None + self.writer: Optional[Union[TelnetWriter, TelnetWriterUnicode]] = None self._limit = limit # High-throughput receive pipeline - self._rx_queue = collections.deque() + self._rx_queue: collections.deque[bytes] = collections.deque() self._rx_bytes = 0 - self._rx_task = None + self._rx_task: Optional[asyncio.Task[Any]] = None self._reading_paused = False # Apply backpressure to transport when our queue grows too large self._read_high = 512 * 1024 # pause_reading() above this many buffered bytes @@ -78,16 +83,16 @@ def __init__( # pylint: disable=too-many-positional-arguments # Base protocol methods - def eof_received(self): + def eof_received(self) -> None: """Called when the other end calls write_eof() or equivalent.""" self.log.debug("EOF from server, closing.") self.connection_lost(None) - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[Exception]) -> None: """ Called when the connection is lost or closed. - :param Exception exc: exception. ``None`` indicates + :param exc: Exception instance, or ``None`` to indicate a closing EOF sent by this end. """ if self._closing: @@ -95,6 +100,7 @@ def connection_lost(self, exc): self._closing = True # inform yielding readers about closed connection + assert self.reader is not None if exc is None: self.log.info("Connection closed to %s", self) self.reader.feed_eof() @@ -108,6 +114,7 @@ def connection_lost(self, exc): # close transport (may already be closed), set waiter_closed and # cancel Future _waiter_connected. + assert self._transport is not None self._transport.close() if not self._waiter_connected.done(): # strangely, for symmetry, our '_waiter_connected' must be set if @@ -124,21 +131,22 @@ def connection_lost(self, exc): # break circular references. self._transport = None - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Called when a connection is made. Ensure ``super().connection_made(transport)`` is called when derived. """ - self._transport = transport + _transport = cast(asyncio.Transport, transport) + self._transport = _transport self._when_connected = datetime.datetime.now() self._last_received = datetime.datetime.now() - reader_factory = self._reader_factory - writer_factory = self._writer_factory + reader_factory: type[TelnetReader] | type[TelnetReaderUnicode] = self._reader_factory + writer_factory: type[TelnetWriter] | type[TelnetWriterUnicode] = self._writer_factory - reader_kwds = {} - writer_kwds = {} + reader_kwds: dict[str, Any] = {} + writer_kwds: dict[str, Any] = {} if self.default_encoding: reader_kwds["fn_encoding"] = self.encoding @@ -154,13 +162,13 @@ def connection_made(self, transport): self.reader = reader_factory(**reader_kwds) # Attach transport so TelnetReader can apply pause_reading/resume_reading try: - self.reader.set_transport(transport) + self.reader.set_transport(_transport) except Exception: # pylint: disable=broad-exception-caught # Reader may not support transport coupling; ignore. pass self.writer = writer_factory( - transport=transport, + transport=_transport, protocol=self, reader=self.reader, client=True, @@ -172,12 +180,13 @@ def connection_made(self, transport): self._waiter_connected.add_done_callback(self.begin_shell) asyncio.get_event_loop().call_soon(self.begin_negotiation) - def begin_shell(self, future): + def begin_shell(self, future: asyncio.Future[None]) -> None: """Start the shell coroutine after negotiation completes.""" # Don't start shell if the connection was cancelled or errored if future.cancelled() or future.exception() is not None: return if self.shell is not None: + assert self.reader is not None and self.writer is not None coro = self.shell(self.reader, self.writer) if asyncio.iscoroutine(coro): # When a shell is defined as a coroutine, we must ensure @@ -199,7 +208,7 @@ def begin_shell(self, future): ) ) - def data_received(self, data): + def data_received(self, data: bytes) -> None: """ Process bytes received by transport. @@ -219,38 +228,41 @@ def data_received(self, data): # Pause reading if buffered bytes exceed high watermark if not self._reading_paused and self._rx_bytes >= self._read_high: - try: - self._transport.pause_reading() - self._reading_paused = True - except Exception: # pylint: disable=broad-exception-caught - # Some transports may not support pause_reading; ignore. - pass + if self._transport is not None: + try: + self._transport.pause_reading() + self._reading_paused = True + except Exception: # pylint: disable=broad-exception-caught + # Some transports may not support pause_reading; ignore. + pass # public properties @property - def duration(self): + def duration(self) -> float: """Time elapsed since client connected, in seconds as float.""" + assert self._when_connected is not None return (datetime.datetime.now() - self._when_connected).total_seconds() @property - def idle(self): + def idle(self) -> float: """Time elapsed since data last received, in seconds as float.""" + assert self._last_received is not None return (datetime.datetime.now() - self._last_received).total_seconds() # public protocol methods - def __repr__(self): + def __repr__(self) -> str: hostport = self.get_extra_info("peername", ["-", "closing"])[:2] return f"" - def get_extra_info(self, name, default=None): + def get_extra_info(self, name: str, default: Any = None) -> Any: """Get optional client protocol or transport information.""" if self._transport: default = self._transport.get_extra_info(name, default) return self._extra.get(name, default) - def begin_negotiation(self): + def begin_negotiation(self) -> None: """ Begin on-connect negotiation. @@ -264,7 +276,7 @@ def begin_negotiation(self): self._check_later = asyncio.get_event_loop().call_soon(self._check_negotiation_timer) self._tasks.append(self._check_later) - def encoding(self, outgoing=False, incoming=False): + def encoding(self, outgoing: bool = False, incoming: bool = False) -> Union[str, bool]: """ Encoding that should be used for the direction indicated. @@ -274,14 +286,13 @@ def encoding(self, outgoing=False, incoming=False): # pylint: disable=unused-argument return self.default_encoding or "US-ASCII" # pragma: no cover - def check_negotiation(self, final=False): + def check_negotiation(self, final: bool = False) -> bool: """ Callback, return whether negotiation is complete. - :param bool final: Whether this is the final time this callback + :param final: Whether this is the final time this callback will be requested to answer regarding protocol negotiation. :returns: Whether negotiation is over (client end is satisfied). - :rtype: bool Method is called on each new command byte processed until negotiation is considered final, or after :attr:`connect_maxwait` has elapsed, setting @@ -301,6 +312,7 @@ def check_negotiation(self, final=False): from .telopt import TTYPE, CHARSET, NEW_ENVIRON # First check if there are any pending options + assert self.writer is not None if any(self.writer.pending_option.values()): return False @@ -324,11 +336,13 @@ def check_negotiation(self, final=False): # private methods - def _process_chunk(self, data): # pylint: disable=too-many-branches,too-complex + def _process_chunk(self, data: bytes) -> bool: # pylint: disable=too-many-branches,too-complex """Process a chunk of received bytes; return True if any IAC/SB cmd observed.""" # This mirrors the previous optimized logic, but is called from an async task. self._last_received = datetime.datetime.now() + assert self.writer is not None + assert self.reader is not None writer = self.writer reader = self.reader @@ -349,7 +363,7 @@ def _process_chunk(self, data): # pylint: disable=too-many-branches,too-complex n = len(data) i = 0 out_start = 0 - feeding_oob = False + feeding_oob = bool(writer.is_oob) # Build set of special bytes for fast lookup special_bytes = frozenset({255} | (slc_vals or set())) @@ -399,7 +413,7 @@ def _process_chunk(self, data): # pylint: disable=too-many-branches,too-complex return cmd_received - async def _process_rx(self): + async def _process_rx(self) -> None: """Async processor for receive queue that yields control and applies backpressure.""" processed = 0 any_cmd = False @@ -414,11 +428,12 @@ async def _process_rx(self): # Resume reading when we've drained below low watermark if self._reading_paused and self._rx_bytes <= self._read_low: - try: - self._transport.resume_reading() - self._reading_paused = False - except Exception: # pylint: disable=broad-exception-caught - pass + if self._transport is not None: + try: + self._transport.resume_reading() + self._reading_paused = False + except Exception: # pylint: disable=broad-exception-caught + pass # Yield periodically to keep loop responsive without excessive context switching if processed >= 128 * 1024: @@ -430,7 +445,8 @@ async def _process_rx(self): if any_cmd and not self._waiter_connected.done(): self._check_negotiation_timer() - def _check_negotiation_timer(self): + def _check_negotiation_timer(self) -> None: + assert self._check_later is not None self._check_later.cancel() self._tasks.remove(self._check_later) @@ -442,6 +458,7 @@ def _check_negotiation_timer(self): self._waiter_connected.set_result(None) elif final: self.log.debug("negotiation failed after %1.2fs.", self.duration) + assert self.writer is not None _failed = [ name_commands(cmd_option) for (cmd_option, pending) in self.writer.pending_option.items() @@ -462,7 +479,12 @@ def _check_negotiation_timer(self): self._tasks.append(self._check_later) @staticmethod - def _log_exception(logger, e_type, e_value, e_tb): + def _log_exception( + logger: Callable[..., Any], + e_type: Optional[Type[BaseException]], + e_value: Optional[BaseException], + e_tb: Optional[types.TracebackType], + ) -> None: rows_tbk = [line for line in "\n".join(traceback.format_tb(e_tb)).split("\n") if line] rows_exc = [line.rstrip() for line in traceback.format_exception_only(e_type, e_value)] diff --git a/telnetlib3/client_shell.py b/telnetlib3/client_shell.py index 736af739..3d8f8e02 100644 --- a/telnetlib3/client_shell.py +++ b/telnetlib3/client_shell.py @@ -6,16 +6,22 @@ import sys import asyncio import collections +from typing import Any, Tuple, Union, Optional # local from . import accessories +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode __all__ = ("telnet_client_shell",) if sys.platform == "win32": - async def telnet_client_shell(telnet_reader, telnet_writer): + async def telnet_client_shell( + telnet_reader: Union[TelnetReader, TelnetReaderUnicode], + telnet_writer: Union[TelnetWriter, TelnetWriterUnicode], + ) -> None: """Win32 telnet client shell (not implemented).""" raise NotImplementedError("win32 not yet supported as telnet client. Please contribute!") @@ -37,33 +43,37 @@ class Terminal: "ModeDef", ["iflag", "oflag", "cflag", "lflag", "ispeed", "ospeed", "cc"] ) - def __init__(self, telnet_writer): + def __init__( + self, telnet_writer: Union[TelnetWriter, TelnetWriterUnicode] + ) -> None: self.telnet_writer = telnet_writer self._fileno = sys.stdin.fileno() self._istty = os.path.sameopenfile(0, 1) - self._save_mode = None + self._save_mode: Optional[Terminal.ModeDef] = None - def __enter__(self): + def __enter__(self) -> "Terminal": self._save_mode = self.get_mode() if self._istty: + assert self._save_mode is not None self.set_mode(self.determine_mode(self._save_mode)) return self - def __exit__(self, *_): + def __exit__(self, *_: Any) -> None: if self._istty: + assert self._save_mode is not None termios.tcsetattr(self._fileno, termios.TCSAFLUSH, list(self._save_mode)) - def get_mode(self): + def get_mode(self) -> Optional["Terminal.ModeDef"]: """Return current terminal mode if attached to a tty, otherwise None.""" if self._istty: return self.ModeDef(*termios.tcgetattr(self._fileno)) return None - def set_mode(self, mode): + def set_mode(self, mode: "Terminal.ModeDef") -> None: """Set terminal mode attributes.""" termios.tcsetattr(sys.stdin.fileno(), termios.TCSAFLUSH, list(mode)) - def determine_mode(self, mode): + def determine_mode(self, mode: "Terminal.ModeDef") -> "Terminal.ModeDef": """Return copy of 'mode' with changes suggested for telnet connection.""" if not self.telnet_writer.will_echo: # return mode as-is @@ -117,7 +127,9 @@ def determine_mode(self, mode): cc=cc, ) - async def make_stdio(self): + async def make_stdio( + self, + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: """Return (reader, writer) pair for sys.stdin, sys.stdout.""" reader = asyncio.StreamReader() reader_protocol = asyncio.StreamReaderProtocol(reader) @@ -146,7 +158,10 @@ async def make_stdio(self): return reader, writer # pylint: disable=too-many-locals,too-many-branches,too-many-statements,too-many-nested-blocks - async def telnet_client_shell(telnet_reader, telnet_writer): + async def telnet_client_shell( + telnet_reader: Union[TelnetReader, TelnetReaderUnicode], + telnet_writer: Union[TelnetWriter, TelnetWriterUnicode], + ) -> None: """ Minimal telnet client shell for POSIX terminals. @@ -169,12 +184,12 @@ async def telnet_client_shell(telnet_reader, telnet_writer): # Setup SIGWINCH handler to send NAWS on terminal resize (POSIX only). # We debounce to avoid flooding on continuous resizes. loop = asyncio.get_event_loop() - winch_pending = {"h": None} + winch_pending: dict[str, Optional[asyncio.TimerHandle]] = {"h": None} remove_winch = False if term._istty: # pylint: disable=protected-access try: - def _send_naws(): + def _send_naws() -> None: # local from .telopt import NAWS # pylint: disable=import-outside-toplevel @@ -187,7 +202,7 @@ def _send_naws(): except Exception: # pylint: disable=broad-exception-caught pass - def _on_winch(): + def _on_winch() -> None: h = winch_pending.get("h") if h is not None and not h.cancelled(): try: diff --git a/telnetlib3/fingerprinting.py b/telnetlib3/fingerprinting.py new file mode 100644 index 00000000..454c655d --- /dev/null +++ b/telnetlib3/fingerprinting.py @@ -0,0 +1,979 @@ +""" +Fingerprint shell for telnet client identification. + +This module probes telnet protocol capabilities, collects session data, +and saves fingerprint files. Display, REPL, and post-script code live +in :mod:`telnetlib3.fingerprinting_display`. +""" + +from __future__ import annotations + +# std imports +import os +import sys +import json +import time +import asyncio +import hashlib +import logging +import datetime +from typing import Any, Dict, List, Tuple, Union, Callable, Optional, cast + +# local +from . import slc +from .telopt import ( + BM, + DO, + DET, + EOR, + RCP, + RSP, + SGA, + TLS, + DONT, + ECHO, + GMCP, + NAMS, + NAOL, + NAOP, + NAWS, + RCTE, + LFLOW, + TTYPE, + X3PAD, + XAUTH, + BINARY, + KERMIT, + NAOCRD, + NAOFFD, + NAOHTD, + NAOHTS, + NAOLFD, + NAOVTD, + NAOVTS, + SNDLOC, + STATUS, + SUPDUP, + TSPEED, + TTYLOC, + CHARSET, + ENCRYPT, + TN3270E, + LINEMODE, + SEND_URL, + XDISPLOC, + FORWARD_X, + SSPI_LOGON, + NEW_ENVIRON, + PRAGMA_LOGON, + SUPDUPOUTPUT, + VT3270REGIME, + AUTHENTICATION, + COM_PORT_OPTION, + PRAGMA_HEARTBEAT, + SUPPRESS_LOCAL_ECHO, + theNULL, +) +from .accessories import encoding_from_lang +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode + +# Data directory for saving fingerprint data - None when unset (no saves) +DATA_DIR: Optional[str] = ( + os.environ["TELNETLIB3_DATA_DIR"] if os.environ.get("TELNETLIB3_DATA_DIR") else None +) + +# Maximum files per protocol-fingerprint folder +FINGERPRINT_MAX_FILES = int(os.environ.get("TELNETLIB3_FINGERPRINT_MAX_FILES", "200")) + +# Maximum number of unique fingerprint folders +FINGERPRINT_MAX_FINGERPRINTS = int( + os.environ.get("TELNETLIB3_FINGERPRINT_MAX_FINGERPRINTS", "1000") +) + +# Post-fingerprint Python module to execute with saved file path +# Example: TELNETLIB3_FINGERPRINT_POST_SCRIPT=telnetlib3.fingerprinting_display +FINGERPRINT_POST_SCRIPT = os.environ.get("TELNETLIB3_FINGERPRINT_POST_SCRIPT", "") + + +# Terminal types that uniquely identify specific telnet clients +PROTOCOL_MATCHED_TERMINALS = { + "syncterm", # SyncTERM BBS client +} + +# Terminal types associated with MUD clients, matched case-insensitively. +# These clients are likely to support extended options like GMCP. +MUD_TERMINALS = { + "mudlet", + "cmud", + "zmud", + "mushclient", + "atlantis", + "tintin++", + "tt++", + "blowtorch", + "mudrammer", + "kildclient", + "portal", + "beip", + "savitar", +} + +__all__ = ( + "fingerprinting_server_shell", + "fingerprinting_post_script", + "get_client_fingerprint", + "probe_client_capabilities", +) + +logger = logging.getLogger("telnetlib3.fingerprint") + +# Timeout for probe_client_capabilities in _run_probe (seconds) +_PROBE_TIMEOUT = 0.5 + +# Telnet options to probe, grouped by category +# Each entry is (option_bytes, name, description) +CORE_OPTIONS = [ + (BINARY, "BINARY", "8-bit binary mode"), + (SGA, "SGA", "Suppress Go Ahead"), + (ECHO, "ECHO", "Echo mode"), + (STATUS, "STATUS", "Option status reporting"), + (TTYPE, "TTYPE", "Terminal type"), + (TSPEED, "TSPEED", "Terminal speed"), + (LFLOW, "LFLOW", "Local flow control"), + (XDISPLOC, "XDISPLOC", "X display location"), + (NAWS, "NAWS", "Window size"), + (NEW_ENVIRON, "NEW_ENVIRON", "Environment variables"), + (CHARSET, "CHARSET", "Character set"), + (LINEMODE, "LINEMODE", "Line mode with SLC"), + (EOR, "EOR", "End of Record"), + # LOGOUT omitted - BSD client times out on this + (SNDLOC, "SNDLOC", "Send location"), +] + +MUD_OPTIONS = [ + (COM_PORT_OPTION, "COM_PORT", "Serial port control (RFC 2217)"), +] + +# Options with non-standard byte values (> 140) that crash some clients. +# icy_term (icy_net) only accepts option bytes 0-49, 138-140, and 255, +# returning a hard error for anything else. GMCP-capable MUD clients +# typically self-announce via IAC WILL GMCP, so probing is unnecessary. +EXTENDED_OPTIONS = [ + (GMCP, "GMCP", "Generic MUD Communication Protocol"), +] + +LEGACY_OPTIONS = [ + (AUTHENTICATION, "AUTHENTICATION", "Telnet authentication"), + (ENCRYPT, "ENCRYPT", "Encryption option"), + (TN3270E, "TN3270E", "3270 terminal emulation"), + (XAUTH, "XAUTH", "X authentication"), + (RSP, "RSP", "Remote serial port"), + (SUPPRESS_LOCAL_ECHO, "SUPPRESS_LOCAL_ECHO", "Local echo suppression"), + (TLS, "TLS", "TLS negotiation"), + (KERMIT, "KERMIT", "Kermit file transfer"), + (SEND_URL, "SEND_URL", "URL sending"), + (FORWARD_X, "FORWARD_X", "X11 forwarding"), + (PRAGMA_LOGON, "PRAGMA_LOGON", "Pragma logon"), + (SSPI_LOGON, "SSPI_LOGON", "SSPI logon"), + (PRAGMA_HEARTBEAT, "PRAGMA_HEARTBEAT", "Heartbeat"), + (X3PAD, "X3PAD", "X.3 PAD"), + (VT3270REGIME, "VT3270REGIME", "VT3270 regime"), + (TTYLOC, "TTYLOC", "Terminal location"), + (SUPDUP, "SUPDUP", "SUPDUP protocol"), + (SUPDUPOUTPUT, "SUPDUPOUTPUT", "SUPDUP output"), + (DET, "DET", "Data entry terminal"), + (BM, "BM", "Byte macro"), + (RCP, "RCP", "Reconnection"), + (NAMS, "NAMS", "NAMS"), + (RCTE, "RCTE", "Remote controlled transmit/echo"), + (NAOL, "NAOL", "Output line width"), + (NAOP, "NAOP", "Output page size"), + (NAOCRD, "NAOCRD", "Output CR disposition"), + (NAOHTS, "NAOHTS", "Output horiz tab stops"), + (NAOHTD, "NAOHTD", "Output horiz tab disposition"), + (NAOFFD, "NAOFFD", "Output formfeed disposition"), + (NAOVTS, "NAOVTS", "Output vert tabstops"), + (NAOVTD, "NAOVTD", "Output vert tab disposition"), + (NAOLFD, "NAOLFD", "Output LF disposition"), +] + +ALL_PROBE_OPTIONS = CORE_OPTIONS + MUD_OPTIONS + LEGACY_OPTIONS + +# All known options including extended, for display/name lookup only +_ALL_KNOWN_OPTIONS = ALL_PROBE_OPTIONS + EXTENDED_OPTIONS + +# Build mapping from hex string (e.g., "0x03") to option name (e.g., "SGA") +_OPT_BYTE_TO_NAME = {f"0x{opt[0]:02x}": name for opt, name, _ in _ALL_KNOWN_OPTIONS} + + +async def probe_client_capabilities( + writer: Union[TelnetWriter, TelnetWriterUnicode], + options: Optional[List[Tuple[bytes, str, str]]] = None, + progress_callback: Optional[Callable[[str, int, int, str], None]] = None, + timeout: float = 0.5, +) -> Dict[str, Dict[str, Any]]: + """ + Actively probe client for telnet capability support. + + Sends IAC DO for ALL options at once, waits for responses, then collects results. + + :param writer: TelnetWriter instance. + :param options: List of (opt_bytes, name, description) tuples to probe. Defaults to + ALL_PROBE_OPTIONS. + :param progress_callback: Optional callback(name, idx, total, status) called during result + collection. + :param timeout: Timeout in seconds to wait for all responses. + :returns: Dict mapping option name to {"status": "WILL"|"WONT"|"timeout", "opt": bytes, + "description": str}. + """ + if options is None: + options = ALL_PROBE_OPTIONS + + results = {} + to_probe = [] + + for opt, name, description in options: + if writer.remote_option.enabled(opt): + results[name] = { + "status": "WILL", + "opt": opt, + "description": description, + "already_negotiated": True, + } + elif writer.remote_option.get(opt) is False: + results[name] = { + "status": "WONT", + "opt": opt, + "description": description, + "already_negotiated": True, + } + else: + to_probe.append((opt, name, description)) + + for opt, name, description in to_probe: + writer.iac(DO, opt) + + await writer.drain() + + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + all_responded = all( + writer.remote_option.get(opt) is not None + for opt, name, desc in to_probe + if name not in results + ) + if all_responded: + break + await asyncio.sleep(0.05) + + for idx, (opt, name, description) in enumerate(to_probe, 1): + if name in results: + continue + + if progress_callback: + progress_callback(name, idx, len(to_probe), "") + + if writer.remote_option.enabled(opt): + results[name] = { + "status": "WILL", + "opt": opt, + "description": description, + } + elif writer.remote_option.get(opt) is False: + results[name] = { + "status": "WONT", + "opt": opt, + "description": description, + } + else: + results[name] = { + "status": "timeout", + "opt": opt, + "description": description, + } + + return results + + +# Keys to collect from extra_info +_EXTRA_INFO_KEYS = ( + "TERM", + "term", + "cols", + "rows", + "COLUMNS", + "LINES", + "charset", + "LANG", + "COLORTERM", + "peername", + "sockname", + "tspeed", + "xdisploc", + "DISPLAY", + "encoding", +) + tuple(f"ttype{n}" for n in range(1, 9)) + + +def get_client_fingerprint( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, Any]: + """ + Collect all available client information from writer. + + :param writer: TelnetWriter instance. + :returns: Dictionary of all negotiated client attributes. + """ + fingerprint = {} + + for key in _EXTRA_INFO_KEYS: + value = writer.get_extra_info(key) + if value is not None and value: + fingerprint[key] = value + + for env_key in ("USER", "SHELL", "HOME", "PATH", "LOGNAME", "MAIL"): + value = writer.get_extra_info(env_key) + if value is not None and value: + fingerprint[env_key] = value + + return fingerprint + + +async def _run_probe( + writer: Union[TelnetWriter, TelnetWriterUnicode], verbose: bool = True +) -> Tuple[Dict[str, Dict[str, Any]], float]: + """Run active probe, optionally extending to MUD options.""" + if _is_maybe_ms_telnet(writer): + probe_options = [opt for opt in CORE_OPTIONS + MUD_OPTIONS if opt[0] != NEW_ENVIRON] + logger.info( + "reduced probe for suspected MS telnet (ttype1=%r, ttype2=%r)", + writer.get_extra_info("ttype1"), + writer.get_extra_info("ttype2"), + ) + else: + probe_options = ALL_PROBE_OPTIONS + + total = len(probe_options) + _writer = cast(TelnetWriterUnicode, writer) + if verbose: + _writer.write(f"\rProbing {total} telnet options...\x1b[J") + await _writer.drain() + + start_time = time.time() + results = await probe_client_capabilities(writer, options=probe_options, timeout=_PROBE_TIMEOUT) + + if _is_maybe_mud(writer) and EXTENDED_OPTIONS: + ext_results = await probe_client_capabilities( + writer, options=EXTENDED_OPTIONS, timeout=_PROBE_TIMEOUT + ) + results.update(ext_results) + + elapsed = time.time() - start_time + + if verbose: + _writer.write("\r\x1b[K") + + return results, elapsed + + +def _get_protocol( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Any: + """Return the protocol object from a writer.""" + return getattr(writer, "_protocol", None) or getattr(writer, "protocol", None) + + +def _opt_byte_to_name(opt: bytes) -> str: + """Convert option bytes to name or hex string.""" + if isinstance(opt, bytes) and len(opt) > 0: + hex_key = f"0x{opt[0]:02x}" + return _OPT_BYTE_TO_NAME.get(hex_key, hex_key) + return str(opt) + + +def _collect_option_states( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, Dict[str, Any]]: + """Collect all telnet option states from writer.""" + options = {} + for label, opt_dict in [("remote", writer.remote_option), ("local", writer.local_option)]: + entries = {_opt_byte_to_name(opt): enabled for opt, enabled in opt_dict.items()} + if entries: + options[label] = entries + return options + + +def _collect_rejected_options( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, List[str]]: + """Collect rejected option offers from writer.""" + result: Dict[str, List[str]] = {} + if getattr(writer, "rejected_will", None): + result["will"] = sorted(_opt_byte_to_name(opt) for opt in writer.rejected_will) + if getattr(writer, "rejected_do", None): + result["do"] = sorted(_opt_byte_to_name(opt) for opt in writer.rejected_do) + return result + + +def _collect_extra_info( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, Any]: + """Collect all extra_info from writer, including private _extra dict.""" + extra: Dict[str, Any] = {} + + protocol = _get_protocol(writer) + if protocol and hasattr(protocol, "_extra"): + for key, value in protocol._extra.items(): # pylint: disable=protected-access + if isinstance(value, tuple): + extra[key] = list(value) + elif isinstance(value, bytes): + extra[key] = value.hex() + else: + extra[key] = value + + # Transport-level keys not in protocol._extra + for key in ("peername", "sockname", "timeout"): + if key not in extra: + if (value := writer.get_extra_info(key)) is not None: + extra[key] = list(value) if isinstance(value, tuple) else value + + # Clean up: prefer uppercase over lowercase redundant keys + if "TERM" in extra and "term" in extra: + del extra["term"] + if "COLUMNS" in extra and "cols" in extra: + del extra["cols"] + if "LINES" in extra and "rows" in extra: + del extra["rows"] + + # Remove ttype1, ttype2, etc. - collected separately in ttype_cycle + for i in range(1, 20): + extra.pop(f"ttype{i}", None) + + return extra + + +def _collect_ttype_cycle( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> List[str]: + """Collect the full TTYPE cycle responses.""" + ttype_list = [] + + protocol = _get_protocol(writer) + extra_dict = getattr(protocol, "_extra", {}) if protocol else {} + + for i in range(1, 20): + if value := (extra_dict.get(f"ttype{i}") or writer.get_extra_info(f"ttype{i}")): + ttype_list.append(value) + else: + break + return ttype_list + + +def _collect_protocol_timing( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, Any]: + """Collect timing information from protocol.""" + timing = {} + protocol = _get_protocol(writer) + if protocol: + if hasattr(protocol, "duration"): + timing["duration"] = protocol.duration + if hasattr(protocol, "idle"): + timing["idle"] = protocol.idle + if hasattr(protocol, "_connect_time"): + timing["connect_time"] = protocol._connect_time # pylint: disable=protected-access + return timing + + +def _collect_slc_tab( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, Any]: + """Collect non-default SLC entries when LINEMODE was negotiated.""" + slctab = getattr(writer, "slctab", None) + if not slctab: + return {} + + if not (hasattr(writer, "remote_option") and writer.remote_option.enabled(LINEMODE)): + return {} + + defaults = slc.generate_slctab(slc.BSD_SLC_TAB) + + result: Dict[str, Any] = {} + slc_set: Dict[str, Any] = {} + slc_unset: list[str] = [] + slc_nosupport: list[str] = [] + + for slc_func, slc_def in slctab.items(): + default_def = defaults.get(slc_func) + if ( + default_def is not None + and slc_def.mask == default_def.mask + and slc_def.val == default_def.val + ): + continue + + name = slc.name_slc_command(slc_func) + if slc_def.nosupport: + slc_nosupport.append(name) + elif slc_def.val == theNULL: + slc_unset.append(name) + else: + slc_set[name] = slc_def.val[0] if isinstance(slc_def.val, bytes) else slc_def.val + + if slc_set: + result["set"] = slc_set + if slc_unset: + result["unset"] = sorted(slc_unset) + if slc_nosupport: + result["nosupport"] = sorted(slc_nosupport) + + return result + + +def _create_protocol_fingerprint( + writer: Union[TelnetWriter, TelnetWriterUnicode], + probe_results: Dict[str, Dict[str, Any]], +) -> Dict[str, Any]: + """ + Create anonymized/summarized protocol fingerprint from session data. + + Fields are only included if negotiated. Environment variables are summarized as "True" (non- + empty value) or "None" (empty string). + + :param writer: TelnetWriter instance. + :param probe_results: Probe results from capability probing. + :returns: Dict with anonymized protocol fingerprint data. + """ + fingerprint: Dict[str, Any] = { + "probed-protocol": "client", + } + + protocol = _get_protocol(writer) + extra_dict = getattr(protocol, "_extra", {}) if protocol else {} + + for key in ("HOME", "USER", "SHELL"): + if key in extra_dict: + fingerprint[key] = "True" if extra_dict[key] else "None" + + # Encoding extracted from LANG + if lang := writer.get_extra_info("LANG"): + encoding = encoding_from_lang(lang) + fingerprint["encoding"] = encoding if encoding else "None" + else: + fingerprint["encoding"] = "None" + + # TERM categorization (inlined) + term = writer.get_extra_info("TERM") or writer.get_extra_info("term") + if not term: + fingerprint["TERM"] = "None" + elif (term_lower := term.lower()) in PROTOCOL_MATCHED_TERMINALS: + fingerprint["TERM"] = term_lower.capitalize() + elif "ansi" in term_lower: + fingerprint["TERM"] = "Yes-ansi" + else: + fingerprint["TERM"] = "Yes" + + charset = writer.get_extra_info("charset") + fingerprint["charset"] = charset if charset else "None" + + ttype_cycle = _collect_ttype_cycle(writer) + fingerprint["ttype-count"] = len(ttype_cycle) + + supported: list[str] = sorted( + [name for name, info in probe_results.items() if info["status"] == "WILL"] + ) + refused: list[str] = sorted( + [name for name, info in probe_results.items() if info["status"] in ("WONT", "timeout")] + ) + fingerprint["supported-options"] = supported + fingerprint["refused-options"] = refused + + rejected = _collect_rejected_options(writer) + if rejected.get("will"): + fingerprint["rejected-will"] = rejected["will"] + if rejected.get("do"): + fingerprint["rejected-do"] = rejected["do"] + + linemode_probed = any( + name == "LINEMODE" and info["status"] == "WILL" for name, info in probe_results.items() + ) + if linemode_probed: + slc_tab = _collect_slc_tab(writer) + if slc_tab: + fingerprint["slc"] = slc_tab + + return fingerprint + + +def _hash_fingerprint(data: Dict[str, Any]) -> str: + """Create deterministic 16-char SHA256 hash of a fingerprint dict.""" + canonical = json.dumps(data, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16] + + +def _count_protocol_folder_files(protocol_dir: str) -> int: + """Count JSON files in protocol fingerprint directory.""" + if not os.path.exists(protocol_dir): + return 0 + return sum(1 for f in os.listdir(protocol_dir) if f.endswith(".json")) + + +def _count_fingerprint_folders(data_dir: Optional[str] = None) -> int: + """Count unique telnet fingerprint folders in ``DATA_DIR/client/``.""" + _dir = data_dir if data_dir is not None else DATA_DIR + if _dir is None: + return 0 + client_dir = os.path.join(_dir, "client") + if not os.path.exists(client_dir): + return 0 + return sum(1 for f in os.listdir(client_dir) if os.path.isdir(os.path.join(client_dir, f))) + + +_UNKNOWN_TERMINAL_HASH = "0" * 16 +AMBIGUOUS_WIDTH_UNKNOWN = -1 + + +def _create_session_fingerprint( + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Dict[str, Any]: + """Create session identity fingerprint from stable client fields.""" + identity: Dict[str, Any] = {} + + if peername := writer.get_extra_info("peername"): + identity["client-ip"] = peername[0] + + if term := (writer.get_extra_info("TERM") or writer.get_extra_info("term")): + identity["TERM"] = term + + for key in ("USER", "HOME", "SHELL", "LANG", "charset"): + if (value := writer.get_extra_info(key)) is not None and value: + identity[key] = value + + return identity + + +def _load_fingerprint_names(data_dir: Optional[str] = None) -> Dict[str, str]: + """Load fingerprint hash-to-name mapping from ``fingerprint_names.json``.""" + _dir = data_dir if data_dir is not None else DATA_DIR + if _dir is None: + return {} + names_file = os.path.join(_dir, "fingerprint_names.json") + if not os.path.exists(names_file): + return {} + with open(names_file, encoding="utf-8") as f: + result: Dict[str, str] = json.load(f) + return result + + +def _resolve_hash_name(hash_val: str, names: Dict[str, str]) -> str: + """Return human-readable name for a hash, falling back to the hash itself.""" + return names.get(hash_val, hash_val) + + +def _validate_suggestion(text: str) -> Optional[str]: + """Validate a user-submitted fingerprint name suggestion.""" + cleaned = text.strip() + if not cleaned: + return None + for c in cleaned: + if ord(c) < 32 or ord(c) == 127: + return None + return cleaned + + +def _cooked_input(prompt: str) -> str: + """Call :func:`input` with echo and canonical mode temporarily enabled.""" + # std imports + import termios # pylint: disable=import-outside-toplevel + + fd = sys.stdin.fileno() + old_attrs = termios.tcgetattr(fd) + new_attrs = list(old_attrs) + new_attrs[3] |= termios.ECHO | termios.ICANON + termios.tcsetattr(fd, termios.TCSANOW, new_attrs) + try: + return input(prompt) + except EOFError: + return "" + finally: + termios.tcsetattr(fd, termios.TCSANOW, old_attrs) + + +def _atomic_json_write(filepath: str, data: Dict[str, Any]) -> None: + """Atomically write JSON data to file via write-to-new + rename.""" + tmp_path = os.path.splitext(filepath)[0] + ".json.new" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, sort_keys=True) + os.replace(tmp_path, filepath) + + +def _build_session_fingerprint( + writer: Union[TelnetWriter, TelnetWriterUnicode], + probe_results: Dict[str, Dict[str, Any]], + probe_time: float, +) -> Dict[str, Any]: + """Build the session fingerprint dict (raw detailed data).""" + extra = _collect_extra_info(writer) + extra.pop("peername", None) + extra.pop("sockname", None) + + ttype_cycle = _collect_ttype_cycle(writer) + option_states = _collect_option_states(writer) + timing = _collect_protocol_timing(writer) + + linemode_probed = probe_results.get("LINEMODE", {}).get("status") + slc_tab = _collect_slc_tab(writer) if linemode_probed == "WILL" else {} + + probe_by_status: Dict[str, Dict[str, int]] = {} + for name, info in probe_results.items(): + status = info["status"] + opt_byte = info["opt"][0] if isinstance(info["opt"], bytes) else info["opt"] + if status not in probe_by_status: + probe_by_status[status] = {} + probe_by_status[status][name] = opt_byte + + timing["probe"] = probe_time + + result = { + "extra": extra, + "ttype_cycle": ttype_cycle, + "option_states": option_states, + "probe": probe_by_status, + "timing": timing, + } + if slc_tab: + result["slc_tab"] = slc_tab + rejected = _collect_rejected_options(writer) + if rejected: + result["rejected"] = rejected + return result + + +def _save_fingerprint_data( # pylint: disable=too-many-locals,too-many-branches,too-complex + writer: Union[TelnetWriter, TelnetWriterUnicode], + probe_results: Dict[str, Dict[str, Any]], + probe_time: float, + session_fp: Optional[Dict[str, Any]] = None, +) -> Optional[str]: + """ + Save comprehensive fingerprint data to a JSON file. + + Creates directory structure: DATA_DIR//uuid4.json + Respects FINGERPRINT_MAX_FILES and FINGERPRINT_MAX_FINGERPRINTS limits. + + :param writer: TelnetWriter instance with full protocol access. + :param probe_results: Probe results from capability probing. + :param probe_time: Time taken for probing. + :param session_fp: Pre-built session fingerprint, or None to build it. + :returns: Path to saved file, or None if save skipped/failed. + """ + if DATA_DIR is None: + return None + if not os.path.isdir(DATA_DIR): + os.makedirs(DATA_DIR, exist_ok=True) + + if session_fp is None: + session_fp = _build_session_fingerprint(writer, probe_results, probe_time) + + protocol_fp = _create_protocol_fingerprint(writer, probe_results) + telnet_hash = _hash_fingerprint(protocol_fp) + + session_identity = _create_session_fingerprint(writer) + session_hash = _hash_fingerprint(session_identity) + + telnet_dir = os.path.join(DATA_DIR, "client", telnet_hash) + probe_dir = None + if os.path.exists(telnet_dir): + for name in os.listdir(telnet_dir): + candidate = os.path.join(telnet_dir, name) + if os.path.isdir(candidate) and name != _UNKNOWN_TERMINAL_HASH: + probe_dir = candidate + break + if probe_dir is None: + probe_dir = os.path.join(telnet_dir, _UNKNOWN_TERMINAL_HASH) + is_new_dir = not os.path.exists(probe_dir) + + if is_new_dir: + if _count_fingerprint_folders() >= FINGERPRINT_MAX_FINGERPRINTS: + logger.warning( + "max fingerprints (%d) exceeded, not saving %s", + FINGERPRINT_MAX_FINGERPRINTS, + telnet_hash, + ) + return None + try: + os.makedirs(probe_dir, exist_ok=True) + except OSError as exc: + logger.warning("failed to create directory %s: %s", probe_dir, exc) + return None + logger.info("new fingerprint %s", telnet_hash) + else: + file_count = _count_protocol_folder_files(probe_dir) + if file_count >= FINGERPRINT_MAX_FILES: + logger.warning( + "fingerprint %s at file limit (%d), not saving", + telnet_hash, + FINGERPRINT_MAX_FILES, + ) + return None + logger.info("connection for fingerprint %s", telnet_hash) + + filepath = os.path.join(probe_dir, f"{session_hash}.json") + + peername = writer.get_extra_info("peername") + now = datetime.datetime.now(datetime.timezone.utc) + session_entry = { + "ip": str(peername[0]) if peername else None, + "connected": now.isoformat(), + } + + if os.path.exists(filepath): + try: + with open(filepath, encoding="utf-8") as f: + data = json.load(f) + data["telnet-probe"]["session_data"] = session_fp + data["sessions"].append(session_entry) + except (OSError, json.JSONDecodeError, KeyError) as exc: + logger.warning("failed to read existing %s: %s", filepath, exc) + data = None + + if data is not None: + try: + _atomic_json_write(filepath, data) + return filepath + except OSError as exc: + logger.warning("failed to update fingerprint: %s", exc) + return None + + data = { + "telnet-probe": { + "fingerprint": telnet_hash, + "fingerprint-data": protocol_fp, + "session_data": session_fp, + }, + "sessions": [session_entry], + } + + try: + _atomic_json_write(filepath, data) + return filepath + except OSError as exc: + logger.warning("failed to save fingerprint: %s", exc) + return None + + +def _is_maybe_mud(writer: Union[TelnetWriter, TelnetWriterUnicode]) -> bool: + """Return whether the client looks like a MUD client.""" + term = (writer.get_extra_info("TERM") or "").lower() + if term in MUD_TERMINALS: + return True + for key in ("ttype1", "ttype2", "ttype3"): + if (writer.get_extra_info(key) or "").lower() in MUD_TERMINALS: + return True + return False + + +def _is_maybe_ms_telnet(writer: Union[TelnetWriter, TelnetWriterUnicode]) -> bool: + """ + Return whether the client looks like Microsoft Windows telnet. + + Microsoft telnet reports ttype1="ANSI", ttype2="VT100", refuses CHARSET, and sends unsolicited + WILL NAWS. The ttype cycle stalls after VT100. Sending a large NEW_ENVIRON sub-negotiation or + a burst of legacy IAC DO commands crashes the client. + + :param writer: TelnetWriter instance. + """ + ttype1 = (writer.get_extra_info("ttype1") or "").upper() + if ttype1 != "ANSI": + return False + ttype2 = (writer.get_extra_info("ttype2") or "").upper() + if ttype2 and ttype2 != "VT100": + return False + return True + + +async def fingerprinting_server_shell( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> None: + """ + Shell that probes client telnet capabilities and runs post-script. + + Immediately probes all telnet options on connect. If DATA_DIR is configured, saves fingerprint + data and runs the post-script through a PTY so it can probe the client's terminal with ucs- + detect. + + :param reader: TelnetReader instance. + :param writer: TelnetWriter instance. + """ + # pylint: disable=import-outside-toplevel + # local + from .server_pty_shell import pty_shell + + writer = cast(TelnetWriterUnicode, writer) + probe_results, probe_time = await _run_probe(writer, verbose=False) + + # Switch syncterm to Topaz (Amiga) font, just for fun why not + if (writer.get_extra_info("TERM") or "").lower() == "syncterm": + writer.write("\x1b[0;40 D") + await writer.drain() + + # Collect fingerprint data BEFORE disabling LINEMODE, so that + # _collect_slc_tab sees remote_option[LINEMODE] as True. + session_fp = _build_session_fingerprint(writer, probe_results, probe_time) + filepath = _save_fingerprint_data(writer, probe_results, probe_time, session_fp) + + # Disable LINEMODE if it was negotiated - stay in kludge mode (SGA+ECHO) + # for PTY shell. LINEMODE causes echo loops with GNU telnet when running + # ucs-detect (client's LIT_ECHO + PTY echo = feedback loop). + if probe_results.get("LINEMODE", {}).get("status") == "WILL": + writer.iac(DONT, LINEMODE) + await writer.drain() + await asyncio.sleep(0.1) + + if filepath is not None: + post_script = FINGERPRINT_POST_SCRIPT or "telnetlib3.fingerprinting_display" + await pty_shell( + reader, + writer, + sys.executable, + ["-W", "ignore::RuntimeWarning:runpy", "-m", post_script, str(filepath)], + raw_mode=True, + ) + else: + writer.close() + + +def fingerprinting_post_script(filepath: str) -> None: + """ + Post-fingerprint script that optionally runs ucs-detect for terminal probing. + + If ucs-detect is available in PATH, runs it to collect terminal capabilities + and merges the results into the fingerprint data. + + Can be used as the TELNETLIB3_FINGERPRINT_POST_SCRIPT target:: + + TELNETLIB3_FINGERPRINT_POST_SCRIPT=telnetlib3.fingerprinting + TELNETLIB3_DATA_DIR=./data + telnetlib3-server --shell fingerprinting_server_shell + + :param filepath: Path to the saved fingerprint JSON file. + """ + # local + # pylint: disable-next=import-outside-toplevel,cyclic-import + from .fingerprinting_display import fingerprinting_post_script as _fps + + _fps(filepath) + + +def main() -> None: + """CLI entry point for fingerprinting post-processing.""" + if len(sys.argv) != 2: + print(f"Usage: python -m {__name__} ", file=sys.stderr) + sys.exit(1) + fingerprinting_post_script(sys.argv[1]) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/telnetlib3/fingerprinting_display.py b/telnetlib3/fingerprinting_display.py new file mode 100644 index 00000000..fca110da --- /dev/null +++ b/telnetlib3/fingerprinting_display.py @@ -0,0 +1,1383 @@ +""" +Display, REPL, and post-script functions for telnet fingerprinting. + +This module contains all terminal display (blessed/prettytable), ucs-detect +integration, and interactive REPL code split from :mod:`fingerprinting`. +""" + +# std imports +import os +import sys +import copy +import json +import random +import shutil +import logging +import termios +import tempfile +import textwrap +import warnings +import functools +import contextlib +import subprocess +from typing import Any, Dict, List, Tuple, Optional, Generator + +# local +from .fingerprinting import ( + DATA_DIR, + _UNKNOWN_TERMINAL_HASH, + AMBIGUOUS_WIDTH_UNKNOWN, + _cooked_input, + _hash_fingerprint, + _atomic_json_write, + _resolve_hash_name, + _validate_suggestion, + _load_fingerprint_names, +) + +__all__ = ("fingerprinting_post_script",) + +logger = logging.getLogger("telnetlib3.fingerprint") + +_BAT = shutil.which("bat") or shutil.which("batcat") +_JQ = shutil.which("jq") + +echo = functools.partial(print, end="", flush=True) + + +def _run_ucs_detect() -> Optional[Dict[str, Any]]: + """Run ucs-detect if available and return terminal fingerprint data.""" + ucs_detect = shutil.which("ucs-detect") + if not ucs_detect: + return None + + patience_msg = random.choice( + [ + "Contemplate the virtue of patience", + "Endure delays with fortitude", + "To wait calmly requires discipline", + "Suspend expectations of imminence", + "The tide hastens for no man", + "Cultivate a stoic calmness", + "The tranquil mind eschews impatience", + "Deliberation is preferable to haste", + ] + ) + echo(f"{patience_msg}...\r\n") + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + tmp_path = tmp.name + + try: + try: + result = subprocess.run( + [ + ucs_detect, + "--limit-category-time=1", + "--limit-codepoints=10", + "--timeout-cps=2", + "--limit-errors=2", + "--probe-silently", + "--no-final-summary", + "--no-languages-test", + "--save-json", + tmp_path, + ], + timeout=20, + check=False, + ) + except subprocess.TimeoutExpired: + logger.warning("ucs-detect timed out (client unresponsive to probes)") + return None + + if result.returncode != 0: + return None + + if not os.path.exists(tmp_path): + logger.warning("ucs-detect did not create output file") + return None + + with open(tmp_path, encoding="utf-8") as f: + terminal_data = json.load(f) + + for key in ("python_version", "datetime", "system", "wcwidth_version"): + terminal_data.pop(key, None) + + parsed: Dict[str, Any] = terminal_data + return parsed + + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def _create_terminal_fingerprint(terminal_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Create anonymized terminal fingerprint for hashing. + + Distills static terminal-identity fields from ucs-detect output, excluding session-variable data + (colors, dimensions, timing). + """ + fingerprint: Dict[str, Any] = {} + + results = terminal_data.get("terminal_results", {}) + fingerprint["software_name"] = terminal_data.get("software_name", "unknown") + fingerprint["software_version"] = terminal_data.get("software_version", "unknown") + + fingerprint["number_of_colors"] = results.get("number_of_colors") + fingerprint["sixel"] = results.get("sixel", False) + fingerprint["iterm2_features"] = results.get("iterm2_features", {}) + + fingerprint["kitty_graphics"] = results.get("kitty_graphics", False) + fingerprint["kitty_clipboard_protocol"] = results.get("kitty_clipboard_protocol", False) + fingerprint["kitty_keyboard"] = results.get("kitty_keyboard", {}) + fingerprint["kitty_notifications"] = results.get("kitty_notifications", False) + fingerprint["kitty_pointer_shapes"] = results.get("kitty_pointer_shapes", False) + + fingerprint["text_sizing"] = results.get("text_sizing", {}) + + da = results.get("device_attributes", {}) + fingerprint["da_service_class"] = da.get("service_class") + fingerprint["da_extensions"] = sorted(da.get("extensions", [])) + + raw_modes = results.get("modes", {}) + distilled_modes = {} + for mode_num, mode_data in sorted(raw_modes.items(), key=lambda x: int(x[0])): + if isinstance(mode_data, dict): + distilled_modes[str(mode_num)] = { + "supported": mode_data.get("supported", False), + "changeable": mode_data.get("changeable", False), + "enabled": mode_data.get("enabled", False), + "value": mode_data.get("value", 0), + } + fingerprint["modes"] = distilled_modes + + fingerprint["xtgettcap"] = results.get("xtgettcap", {}) + fingerprint["ambiguous_width"] = terminal_data.get("ambiguous_width") + + raw_test_results = terminal_data.get("test_results", {}) + distilled_tests = {} + for category, versions in raw_test_results.items(): + if not versions or not isinstance(versions, dict): + continue + for ver, entry in versions.items(): + if isinstance(entry, dict): + distilled_tests[category] = { + "unicode_version": ver, + "n_errors": entry.get("n_errors", 0), + "n_total": entry.get("n_total", 0), + } + break + if distilled_tests: + fingerprint["test_results"] = distilled_tests + + return fingerprint + + +def _wrap_options(options: List[str], max_width: int = 30) -> str: + """Word-wrap a list of options to fit within max_width.""" + if not options: + return "" + return "\n".join(textwrap.wrap(", ".join(options), width=max_width)) + + +def _color_yes_no(term: Any, value: bool) -> str: + """Apply green/red coloring to boolean value.""" + if value: + return str(term.forestgreen("Yes")) + return str(term.firebrick1("No")) + + +def _format_ttype( + extra: Dict[str, Any], session_data: Dict[str, Any], wrap_width: int = 30 +) -> Optional[str]: + """Format terminal type from TTYPE cycle for compact display.""" + ttype_cycle = session_data.get("ttype_cycle", []) + term_type = extra.get("TERM") or extra.get("term") + if not term_type and not ttype_cycle: + return None + primary = ttype_cycle[0] if ttype_cycle else term_type + primary_lower = primary.lower() if primary else "" + others = [] + seen = {primary_lower} + for ttype_val in ttype_cycle[1:]: + t_lower = ttype_val.lower() + if t_lower not in seen: + seen.add(t_lower) + others.append(t_lower) + type_str = primary or "" + if others: + suffix = ", ".join(others) + if len(type_str) + len(suffix) + 3 > wrap_width: + wrapped = "\n".join(textwrap.wrap(suffix, width=wrap_width - 2)) + type_str += f" ({wrapped})" + else: + type_str += f" ({suffix})" + return type_str + + +def _is_utf8_charset(value: str) -> bool: + """Test whether a charset or encoding string refers to UTF-8.""" + return value.lower().replace("-", "").replace("_", "") in ( + "utf8", + "unicode11utf8", + ) + + +def _format_encoding( + extra: Dict[str, Any], + proto_data: Dict[str, Any], + ambiguous_width: Optional[int] = None, +) -> Optional[Tuple[str, str]]: + """Consolidate LANG, charset, and encoding into a single key-value pair.""" + lang_val = extra.get("LANG") + charset_val = extra.get("charset") + encoding_val = proto_data.get("encoding") + + no_unicode = ambiguous_width == AMBIGUOUS_WIDTH_UNKNOWN + + if charset_val and no_unicode and _is_utf8_charset(charset_val): + charset_val = "unknown (ascii-only)" + + if lang_val and charset_val: + return ("LANG (Charset)", f"{lang_val} ({charset_val})") + if lang_val: + return ("LANG", lang_val) + if charset_val: + return ("Charset", charset_val) + if encoding_val and encoding_val != "None": + return ("Encoding", encoding_val) + return None + + +# pylint: disable-next=too-complex,too-many-locals,too-many-branches,too-many-statements +def _build_terminal_rows(term: Any, data: Dict[str, Any]) -> List[Tuple[str, str]]: + """Build (key, value) tuples for terminal capabilities table.""" + pairs: List[Tuple[str, str]] = [] + terminal_probe = data.get("terminal-probe", {}) + terminal_data = terminal_probe.get("session_data", {}) + terminal_results = terminal_data.get("terminal_results", {}) + if not terminal_data: + return pairs + + if fp_hash := terminal_probe.get("fingerprint"): + pairs.append(("Fingerprint", fp_hash)) + + if software := terminal_data.get("software_name"): + if ver := terminal_data.get("software_version"): + software += f" {ver}" + if len(software) > 15: + software = software[:14] + ("\u2026" if _has_unicode(data) else "..") + pairs.append(("Software", software)) + + telnet_probe = data.get("telnet-probe", {}) + session_data = telnet_probe.get("session_data", {}) + extra = session_data.get("extra", {}) + cols = extra.get("cols") or extra.get("COLUMNS") + rows = extra.get("rows") or extra.get("LINES") + if cols and rows: + size_str = f"{cols}x{rows}" + cell_w = terminal_results.get("cell_width") + cell_h = terminal_results.get("cell_height") + if cell_w and cell_h: + size_str += f" (*{cell_w}x{cell_h})" + pairs.append(("Size", size_str)) + + if (n_colors := terminal_results.get("number_of_colors")) is not None: + if n_colors >= 16777216: + color_str = term.forestgreen("24-bit") + elif n_colors <= 256: + color_str = term.firebrick1(f"{n_colors}") + else: + color_str = term.darkorange(f"{n_colors}") + pairs.append(("Colors", color_str)) + + has_fg = terminal_results.get("foreground_color_hex") is not None + has_bg = terminal_results.get("background_color_hex") is not None + if has_fg or has_bg: + pairs.append(("fg/bg colors", _color_yes_no(term, has_fg and has_bg))) + + has_kitty_gfx = terminal_results.get("kitty_graphics", False) + has_iterm2_gfx = (terminal_results.get("iterm2_features") or {}).get("supported", False) + has_sixel = terminal_results.get("sixel", False) + if has_kitty_gfx or has_iterm2_gfx: + protocols = [] + if has_kitty_gfx: + protocols.append("Kitty") + if has_iterm2_gfx: + protocols.append("iTerm2") + if has_sixel: + protocols.append("Sixel") + pairs.append(("Graphics", term.forestgreen(", ".join(protocols)))) + elif has_sixel: + pairs.append(("Graphics", term.darkorange("Sixel"))) + elif any(k in terminal_results for k in ("sixel", "kitty_graphics", "iterm2_features")): + pairs.append(("Graphics", term.firebrick1("No"))) + + if da := terminal_results.get("device_attributes"): + if (sc := da.get("service_class")) is not None: + class_names = { + 1: "VT100", + 2: "VT200", + 18: "VT330", + 41: "VT420", + 61: "VT500", + 62: "VT500", + 64: "VT500", + 65: "VT500", + } + pairs.append(("Device Class", class_names.get(sc, f"Class {sc}"))) + + screen_ratio = terminal_results.get("screen_ratio") + if screen_ratio: + ratio_name = terminal_results.get("screen_ratio_name", "") + if ratio_name: + pairs.append(("Aspect Ratio", f"{screen_ratio} ({ratio_name})")) + else: + pairs.append(("Aspect Ratio", screen_ratio)) + + ambiguous_width = terminal_data.get("ambiguous_width") + if ambiguous_width == 2: + pairs.append(("Ambiguous Width", "wide (2)")) + + modes = terminal_results.get("modes", {}) + mode_2027 = modes.get(2027, modes.get("2027")) + if mode_2027 is not None: + gc_value = _color_yes_no(term, mode_2027.get("supported")) + pairs.append(("Graphemes(2027)", gc_value)) + elif modes: + pairs.append(("Graphemes(2027)", term.darkorange("N/A"))) + + test_results = terminal_data.get("test_results", {}) + _emoji_keys = ( + "unicode_wide_results", + "emoji_zwj_results", + "emoji_vs16_results", + "emoji_vs15_results", + ) + all_pcts = [] + for key in _emoji_keys: + for entry in test_results.get(key, {}).values(): + if (pct := entry.get("pct_success")) is not None: + all_pcts.append(pct) + if all_pcts: + avg = sum(all_pcts) / len(all_pcts) + if avg >= 99.0: + pairs.append(("Emoji", term.forestgreen("Yes"))) + elif avg >= 33.3: + pairs.append(("Emoji", term.darkorange("Partial"))) + else: + pairs.append(("Emoji", term.firebrick1("No"))) + + return pairs + + +def _build_telnet_rows( # pylint: disable=too-many-locals,unused-argument + term: Any, data: Dict[str, Any] +) -> List[Tuple[str, str]]: + """Build (key, value) tuples for telnet protocol table.""" + pairs: List[Tuple[str, str]] = [] + telnet_probe = data.get("telnet-probe", {}) + proto_data = telnet_probe.get("fingerprint-data", {}) + session_data = telnet_probe.get("session_data", {}) + extra = session_data.get("extra", {}) + + if fp_hash := telnet_probe.get("fingerprint"): + pairs.append(("Fingerprint", fp_hash)) + + wrap_width = 30 + if type_str := _format_ttype(extra, session_data, wrap_width): + pairs.append(("Terminal Type", type_str)) + + terminal_probe = data.get("terminal-probe", {}) + aw = terminal_probe.get("session_data", {}).get("ambiguous_width") + if encoding_pair := _format_encoding(extra, proto_data, aw): + pairs.append(encoding_pair) + + if supported := proto_data.get("supported-options"): + pairs.append(("Options", _wrap_options(supported, wrap_width))) + + if rejected_will := proto_data.get("rejected-will"): + pairs.append( + ( + "Rejected", + _wrap_options(rejected_will, wrap_width), + ) + ) + + slc_tab = session_data.get("slc_tab", {}) + if slc_tab: + slc_set = slc_tab.get("set", {}) + slc_unset = slc_tab.get("unset", []) + slc_nosupport = slc_tab.get("nosupport", []) + parts = [] + if slc_set: + parts.append(f"{len(slc_set)} set") + if slc_unset: + parts.append(f"{len(slc_unset)} unset") + if slc_nosupport: + parts.append(f"{len(slc_nosupport)} nosupport") + if parts: + pairs.append(("SLC", ", ".join(parts))) + + env_vars = [] + for key in ("USER", "HOME", "SHELL"): + if proto_data.get(key) == "True": + env_vars.append(key) + if env_vars: + pairs.append(("Environment", ", ".join(env_vars))) + + if tspeed := extra.get("tspeed"): + pairs.append(("Speed", tspeed)) + + return pairs + + +def _make_terminal(**kwargs: Any) -> Any: + """Create a blessed Terminal, falling back to ``ansi`` on setupterm failure.""" + # 3rd party + from blessed import Terminal # pylint: disable=import-outside-toplevel,import-error + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + term = Terminal(**kwargs) + if any("setupterm" in str(w.message) for w in caught): + kwargs["kind"] = "ansi" + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + term = Terminal(**kwargs) + return term + + +@contextlib.contextmanager +def _disable_isig() -> Generator[None, None, None]: + """Disable ``ISIG`` so that ``^C`` and ``^Z`` are ignored.""" + fd = sys.stdin.fileno() + old = termios.tcgetattr(fd) + new = list(old) + new[3] &= ~termios.ISIG + termios.tcsetattr(fd, termios.TCSANOW, new) + try: + yield + finally: + termios.tcsetattr(fd, termios.TCSANOW, old) + + +def _has_unicode(data: Dict[str, Any]) -> bool: + """Return whether the terminal supports unicode rendering.""" + aw = ( + data.get("terminal-probe", {}) + .get("session_data", {}) + .get("ambiguous_width", AMBIGUOUS_WIDTH_UNKNOWN) + ) + return bool(aw >= 1) + + +def _sync_timeout(data: Dict[str, Any]) -> float: + """Return synchronized output timeout based on measured RTT.""" + cps = data.get("terminal-probe", {}).get("session_data", {}).get("cps_summary", {}) + if (rtt_max := cps.get("rtt_max_ms")) and rtt_max > 0: + return float(rtt_max * 1.1 / 1000.0) + return 1.0 + + +def _setup_term_environ(data: Dict[str, Any]) -> None: + """ + Set ``TERM`` and ``COLORTERM`` based on probe data. + + Overrides ``TERM`` to ``ansi`` for Microsoft telnet clients whose + ``vtnt`` terminfo contains ``$`` padding sequences displayed as + literal text. Sets ``COLORTERM=truecolor`` when 24-bit color was + confirmed by the terminal probe, removes it otherwise to prevent the + server's own stale value from leaking through. + """ + ttype_cycle = data.get("telnet-probe", {}).get("session_data", {}).get("ttype_cycle", []) + # Microsoft telnet cycles ANSI -> VT100 -> VT52 -> VTNT -> VTNT. + # The "vtnt" terminfo contains $ padding sequences that MS telnet + # displays as literal text. Override to "ansi" which has no padding. + if ttype_cycle == ["ANSI", "VT100", "VT52", "VTNT", "VTNT"]: + os.environ["TERM"] = "ansi" + + if _has_truecolor(data): + os.environ["COLORTERM"] = "truecolor" + else: + os.environ.pop("COLORTERM", None) + + +def _has_truecolor(data: Dict[str, Any]) -> bool: + """Return whether the terminal supports 24-bit color.""" + n = ( + data.get("terminal-probe", {}) + .get("session_data", {}) + .get("terminal_results", {}) + .get("number_of_colors") + ) + return n is not None and n >= 16777216 + + +def _hotkey(term: Any, key: str) -> str: + """Format a hotkey as ``key-`` with key and dash in magenta.""" + return f"{term.bold_magenta(key)}{term.bold_magenta('-')}" + + +def _bracket_key(term: Any, key: str) -> str: + """Format a hotkey as ``[key]`` with brackets in cyan, key in magenta.""" + return f"{term.cyan('[')}{term.bold_magenta(key)}{term.cyan(']')}" + + +def _apply_unicode_borders(tbl: Any) -> None: + """Apply double-line box-drawing characters to a PrettyTable.""" + tbl.horizontal_char = "\u2550" + tbl.vertical_char = "\u2551" + tbl.junction_char = "\u256c" + tbl.top_junction_char = "\u2566" + tbl.bottom_junction_char = "\u2569" + tbl.left_junction_char = "\u2560" + tbl.right_junction_char = "\u2563" + tbl.top_left_junction_char = "\u2554" + tbl.top_right_junction_char = "\u2557" + tbl.bottom_left_junction_char = "\u255a" + tbl.bottom_right_junction_char = "\u255d" + + +def _display_compact_summary( # pylint: disable=too-complex,too-many-branches + data: Dict[str, Any], term: Any = None +) -> bool: + """Display compact fingerprint summary using prettytable.""" + try: + # 3rd party + from ucs_detect import ( # pylint: disable=import-outside-toplevel + _collect_side_by_side_lines, + ) + from prettytable import PrettyTable # pylint: disable=import-outside-toplevel + except ImportError: + return False + + if term is None: + term = _make_terminal() + + has_unicode = _has_unicode(data) + + def make_table(title: str, pairs: List[Tuple[str, str]]) -> str: + tbl = PrettyTable() + if has_unicode: + _apply_unicode_borders(tbl) + tbl.title = term.magenta(title) + tbl.field_names = ["Attribute", "Value"] + tbl.align["Attribute"] = "r" + tbl.align["Value"] = "l" + tbl.header = False + tbl.max_table_width = max(40, (term.width or 80) - 1) + for key, value in pairs: + tbl.add_row([key or "", value]) + return str(tbl) + + table_strings = [] + + terminal_rows = _build_terminal_rows(term, data) + if terminal_rows: + table_strings.append(make_table("Terminal", terminal_rows)) + + telnet_rows = _build_telnet_rows(term, data) + if telnet_rows: + table_strings.append(make_table("Telnet", telnet_rows)) + + if not table_strings: + return False + + timeout = _sync_timeout(data) + + echo(term.normal) + + widths = [len(s.split("\n", 1)[0]) for s in table_strings] + side_by_side = len(widths) < 2 or sum(widths) + 1 < (term.width or 80) + + if side_by_side: + all_lines = _collect_side_by_side_lines(term, table_strings) + if has_unicode: + with term.synchronized_output(timeout=timeout): + for line in all_lines: + echo(line + "\n") + else: + for line in all_lines: + echo(line + "\n") + else: + for tbl in table_strings: + lines = tbl.split("\n") + if has_unicode: + with term.synchronized_output(timeout=timeout): + for line in lines: + echo(line + "\n") + echo("\n") + else: + for line in lines: + echo(line + "\n") + echo("\n") + return True + + +def _fingerprint_similarity(a: Dict[str, Any], b: Dict[str, Any]) -> float: + """ + Compute field-by-field similarity score between two fingerprint dicts. + + :returns: Similarity as a float 0.0-1.0. + """ + _skip = {"probed-protocol"} + all_keys = (set(a) | set(b)) - _skip + if not all_keys: + return 1.0 + + scores: List[float] = [] + for key in all_keys: + va, vb = a.get(key), b.get(key) + if va is None and vb is None: + continue + if va is None or vb is None: + scores.append(0.0) + continue + if va == vb: + scores.append(1.0) + continue + if isinstance(va, list) and isinstance(vb, list): + sa, sb = set(map(str, va)), set(map(str, vb)) + union = sa | sb + scores.append(len(sa & sb) / len(union) if union else 1.0) + elif isinstance(va, dict) and isinstance(vb, dict): + scores.append(_fingerprint_similarity(va, vb)) + else: + scores.append(0.0) + + return sum(scores) / len(scores) if scores else 1.0 + + +def _load_known_fingerprints( # pylint: disable=too-complex + probe_type: str, +) -> Dict[str, Dict[str, Any]]: + """ + Load one fingerprint-data dict per unique hash from the data directory. + + :param probe_type: ``"telnet-probe"`` or ``"terminal-probe"``. + :returns: Dict mapping hash string to fingerprint-data dict. + """ + if DATA_DIR is None: + return {} + client_dir = os.path.join(DATA_DIR, "client") + if not os.path.isdir(client_dir): + return {} + + seen: Dict[str, Dict[str, Any]] = {} + is_telnet = probe_type == "telnet-probe" + + for telnet_hash in os.listdir(client_dir): + telnet_path = os.path.join(client_dir, telnet_hash) + if not os.path.isdir(telnet_path): + continue + if is_telnet and telnet_hash in seen: + continue + for terminal_hash in os.listdir(telnet_path): + if terminal_hash == _UNKNOWN_TERMINAL_HASH: + continue + terminal_path = os.path.join(telnet_path, terminal_hash) + if not os.path.isdir(terminal_path): + continue + if not is_telnet and terminal_hash in seen: + continue + target_hash = telnet_hash if is_telnet else terminal_hash + if target_hash in seen: + continue + for fname in os.listdir(terminal_path): + if not fname.endswith(".json"): + continue + try: + with open(os.path.join(terminal_path, fname), encoding="utf-8") as f: + file_data = json.load(f) + fp_data = file_data.get(probe_type, {}).get("fingerprint-data") + if fp_data: + seen[target_hash] = fp_data + except (OSError, json.JSONDecodeError, KeyError): + pass + break + return seen + + +def _find_nearest_match( + fp_data: Dict[str, Any], + probe_type: str, + names: Dict[str, str], +) -> Optional[Tuple[str, float]]: + """ + Find the most similar named fingerprint. + + :returns: ``(name, similarity)`` tuple or None if no candidates or best < 50%. + """ + known = _load_known_fingerprints(probe_type) + best_name: Optional[str] = None + best_score = 0.0 + for h, known_fp in known.items(): + if h not in names: + continue + score = _fingerprint_similarity(fp_data, known_fp) + if score > best_score: + best_score = score + best_name = names[h] + if best_name is None or best_score < 0.50: + return None + return (best_name, best_score) + + +def _build_seen_counts( # pylint: disable=too-many-locals + data: Dict[str, Any], + names: Optional[Dict[str, str]] = None, + term: Any = None, +) -> str: + """Build friendly "seen before" text from folder and session counts.""" + if DATA_DIR is None or not os.path.exists(DATA_DIR): + return "" + + telnet_probe = data.get("telnet-probe", {}) + if not (telnet_hash := telnet_probe.get("fingerprint")): + return "" + + terminal_probe = data.get("terminal-probe", {}) + terminal_hash = terminal_probe.get("fingerprint", _UNKNOWN_TERMINAL_HASH) + + _names = names or {} + telnet_name = _resolve_hash_name(telnet_hash, _names) + terminal_known = terminal_hash != _UNKNOWN_TERMINAL_HASH + terminal_name = _resolve_hash_name(terminal_hash, _names) if terminal_known else None + + if term is not None: + g = term.forestgreen + telnet_name = g(telnet_name) + if terminal_name is not None: + terminal_name = g(terminal_name) + + telnet_dir = os.path.join(DATA_DIR, "client", telnet_hash) + like_count = 0 + if os.path.isdir(telnet_dir): + for sub in os.listdir(telnet_dir): + sub_path = os.path.join(telnet_dir, sub) + if os.path.isdir(sub_path): + like_count += sum(1 for f in os.listdir(sub_path) if f.endswith(".json")) + + visit_count = len(data.get("sessions", [])) + + extra = telnet_probe.get("session_data", {}).get("extra", {}) + username = extra.get("USER") or extra.get("LOGNAME") + + lines: List[str] = [] + if like_count > 1: + others = like_count - 1 + noun = "client" if others == 1 else "clients" + lines.append(f"I've seen {others} other {noun} with your configuration.") + if visit_count > 1: + times = "time" if visit_count - 1 == 1 else "times" + lines.append(f"I've seen your exact fingerprint {visit_count - 1} {times} before.") + + who = f" {username}" if username else "" + terminal_suffix = ( + f" and {terminal_name}" if terminal_name and terminal_name != telnet_name else "" + ) + if visit_count <= 1: + lines.append(f"Welcome{who}! Detected {telnet_name}{terminal_suffix}.") + else: + lines.append(f"Welcome back{who}! Detected {telnet_name}{terminal_suffix}.") + + telnet_unknown = telnet_hash not in _names + terminal_unknown = terminal_known and terminal_hash not in _names + if (telnet_unknown or terminal_unknown) and _names: + match_lines = _nearest_match_lines( + data, + _names, + term, + telnet_unknown=telnet_unknown, + terminal_unknown=terminal_unknown, + ) + if match_lines: + lines.extend(match_lines) + + if lines: + return "\n".join(lines) + "\n\n" + return "" + + +def _color_match(term: Any, name: str, score: float) -> str: + """ + Color a nearest-match result by confidence threshold. + + :param score: Similarity as a float 0.0-1.0. + """ + pct = score * 100 + label = f"{name} ({pct:.0f}%)" + if term is None: + return label + if pct >= 95: + return str(term.forestgreen(label)) + if pct >= 75: + return str(term.darkorange(label)) + return str(term.firebrick1(label)) + + +def _nearest_match_lines( + data: Dict[str, Any], + names: Dict[str, str], + term: Any, + telnet_unknown: bool = False, + terminal_unknown: bool = False, +) -> List[str]: + """Build nearest-match text lines for unknown fingerprints.""" + result_lines: List[str] = [] + if telnet_unknown: + fp_data = data.get("telnet-probe", {}).get("fingerprint-data") + if fp_data: + result = _find_nearest_match(fp_data, "telnet-probe", names) + if result: + result_lines.append(f"Nearest telnet match: {_color_match(term, *result)}") + else: + result_lines.append("Nearest telnet match: (none)") + + if terminal_unknown: + fp_data = data.get("terminal-probe", {}).get("fingerprint-data") + if fp_data: + result = _find_nearest_match(fp_data, "terminal-probe", names) + if result: + result_lines.append(f"Nearest terminal match: {_color_match(term, *result)}") + else: + result_lines.append("Nearest terminal match: (none)") + return result_lines + + +def _repl_prompt(term: Any) -> None: + """Write the REPL prompt with hotkey legend.""" + bk = _bracket_key + legend = ( + f"{bk(term, 't')}erminal or te{bk(term, 'l')}net details, " # codespell:ignore te + f"{bk(term, 's')}ummarize or {bk(term, 'u')}pdate database: " + ) + echo(f"\r{term.clear_eos}{term.normal}{legend}") + + +def _paginate(term: Any, text: str, **_kw: Any) -> None: # pylint: disable=unused-argument + """Display text.""" + for line in text.split("\n"): + echo(line + "\n") + + +def _colorize_json(data: Any, term: Any = None) -> str: + """ + Format JSON with color, preferring bat/batcat over jq. + + :param term: blessed Terminal instance for ``TERM`` kind. + """ + json_str = json.dumps(data, indent=2, sort_keys=True) + if _BAT: + result = subprocess.run( + [_BAT, "-l", "json", "--style=plain", "--color=always"], + input=json_str, + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + return result.stdout.rstrip("\n") + if _JQ: + env = { + "TERM": getattr(term, "kind", None) or "dumb", + "COLUMNS": str(term.width or 80), + "LINES": str(term.height or 25), + } + if term.number_of_colors == 1 << 24: + env["COLORTERM"] = "truecolor" + result = subprocess.run( + [_JQ, "-C", 'walk(if type=="number" then (.*100|round)/100 else . end)'], + input=json_str, + capture_output=True, + text=True, + env=env, + check=False, + ) + if result.returncode == 0: + return result.stdout.rstrip("\n") + return json_str + + +def _strip_empty_features(d: Dict[str, Any]) -> None: + """Remove empty kitty/iterm2 feature keys from a dict in-place.""" + for key in list(d): + if key.startswith(("kitty_", "iterm2_")): + val = d[key] + if not val or (isinstance(val, dict) and not any(val.values())): + del d[key] + + +def _normalize_color_hex(hex_color: str) -> str: + """Normalize X11 color hex to standard 6-digit format.""" + # 3rd party + from blessed.colorspace import ( # pylint: disable=import-outside-toplevel,import-error + hex_to_rgb, + rgb_to_hex, + ) + + r, g, b = hex_to_rgb(hex_color) + return rgb_to_hex(r, g, b) + + +def _filter_terminal_detail( # pylint: disable=too-complex,too-many-branches + detail: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + """Filter terminal session data for display.""" + if not detail: + return detail + result = dict(detail) + + for key in ("session_arguments", "height", "width"): + result.pop(key, None) + + aw = result.get("ambiguous_width") + if aw is not None and aw != 2: + del result["ambiguous_width"] + + if "text_sizing" in result: + result["kitty_text_sizing"] = result.pop("text_sizing") + + _strip_empty_features(result) + + terminal_results = result.get("terminal_results") + if terminal_results is not None: + terminal_results = dict(terminal_results) + if "text_sizing" in terminal_results: + terminal_results["kitty_text_sizing"] = terminal_results.pop("text_sizing") + for key in ("foreground_color_rgb", "background_color_rgb"): + terminal_results.pop(key, None) + _strip_empty_features(terminal_results) + modes = terminal_results.pop("modes", None) + if modes: + dec_modes = {} + for _num, mode in modes.items(): + if isinstance(mode, dict) and mode.get("supported"): + name = mode.get("mode_name", str(_num)) + dec_modes[name] = { + "changeable": mode.get("changeable", False), + "enabled": mode.get("enabled", False), + } + if dec_modes: + terminal_results["dec_private_modes"] = dec_modes + for key in ("foreground_color_hex", "background_color_hex"): + if key in terminal_results: + terminal_results[key] = _normalize_color_hex(terminal_results[key]) + result["terminal_results"] = terminal_results + + test_results = result.get("test_results") + if test_results is not None: + filtered = {} + for k, v in test_results.items(): + if not v: + continue + if isinstance(v, dict): + reduced = {} + for ver, data in v.items(): + if isinstance(data, dict): + reduced[ver] = { + sk: sv for sk, sv in data.items() if sk in ("pct_success", "n_total") + } + else: + reduced[ver] = data + if reduced: + filtered[k] = reduced + else: + filtered[k] = v + if filtered: + result["test_results"] = filtered + else: + del result["test_results"] + return result + + +def _filter_telnet_detail( + detail: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + """Filter telnet probe data for display.""" + if not detail: + return detail + result = copy.deepcopy(detail) + + if session_data := result.get("session_data"): + for key in ("probe", "option_states"): + session_data.pop(key, None) + + if fp_data := result.get("fingerprint-data"): + fp_data.pop("refused-options", None) + + return result + + +def _show_detail(term: Any, data: Dict[str, Any], section: str) -> None: + """Show detailed JSON for a fingerprint section with pagination.""" + if section == "terminal": + terminal_probe = data.get("terminal-probe", {}) + detail = _filter_terminal_detail(terminal_probe.get("session_data")) + title = "Terminal Probe Results" + else: + detail = _filter_telnet_detail(data.get("telnet-probe")) + title = "Telnet Probe Data" + + underline = term.cyan("=" * len(title)) + if detail: + text = f"{term.magenta(title)}\n" f"{underline}\n" f"\n" f"{_colorize_json(detail, term)}" + _paginate(term, text) + else: + echo(f"{term.magenta(title)}\n{underline}\n\n(no data)\n") + + +def _client_ip(data: Dict[str, Any]) -> str: + """Extract client IP from fingerprint data.""" + sessions = data.get("sessions", []) + if sessions: + ip = sessions[-1].get("ip") + if ip: + return str(ip) + return "unknown" + + +def _build_database_entries( # pylint: disable=too-many-locals + names: Optional[Dict[str, str]] = None, +) -> List[Tuple[str, str, int, int]]: + """ + Scan fingerprint directories and build sorted database entries. + + :returns: List of ``(type, display_name, file_count, session_count)`` tuples sorted by session + count descending. + """ + client_dir = os.path.join(DATA_DIR, "client") if DATA_DIR else None + if not client_dir or not os.path.isdir(client_dir): + return [] + + _names = names or {} + telnet_counts: Dict[str, List[int]] = {} + terminal_counts: Dict[str, List[int]] = {} + for telnet_hash in os.listdir(client_dir): + telnet_path = os.path.join(client_dir, telnet_hash) + if not os.path.isdir(telnet_path): + continue + for terminal_hash in os.listdir(telnet_path): + terminal_path = os.path.join(telnet_path, terminal_hash) + if not os.path.isdir(terminal_path): + continue + for fname in os.listdir(terminal_path): + if not fname.endswith(".json"): + continue + n_sessions = 1 + fpath = os.path.join(terminal_path, fname) + try: + with open(fpath, encoding="utf-8") as f: + fdata = json.load(f) + n_sessions = len(fdata.get("sessions", [1])) + except (OSError, json.JSONDecodeError): + pass + telnet_counts.setdefault(telnet_hash, [0, 0]) + telnet_counts[telnet_hash][0] += 1 + telnet_counts[telnet_hash][1] += n_sessions + if terminal_hash != _UNKNOWN_TERMINAL_HASH: + terminal_counts.setdefault(terminal_hash, [0, 0]) + terminal_counts[terminal_hash][0] += 1 + terminal_counts[terminal_hash][1] += n_sessions + + merged: Dict[Tuple[str, str], List[int]] = {} + for h, (files, sessions) in telnet_counts.items(): + key = ("Telnet", _resolve_hash_name(h, _names)) + prev = merged.get(key, [0, 0]) + merged[key] = [prev[0] + files, prev[1] + sessions] + for h, (files, sessions) in terminal_counts.items(): + key = ("Terminal", _resolve_hash_name(h, _names)) + prev = merged.get(key, [0, 0]) + merged[key] = [prev[0] + files, prev[1] + sessions] + + entries = [(kind, name, files, sessions) for (kind, name), (files, sessions) in merged.items()] + entries.sort(key=lambda e: e[3], reverse=True) + return entries + + +def _show_database( + term: Any, + data: Dict[str, Any], + entries: List[Tuple[str, str, int, int]], +) -> None: + """Display scrollable database of all known fingerprints.""" + try: + # 3rd party + from prettytable import PrettyTable # pylint: disable=import-outside-toplevel + except ImportError: + echo("prettytable not installed.\n") + return + + if not entries: + echo("No fingerprints in database.\n") + return + + has_unicode = _has_unicode(data) + + tbl = PrettyTable() + if has_unicode: + _apply_unicode_borders(tbl) + tbl.title = term.magenta(f"Database ({len(entries)} fingerprints)") + tbl.field_names = ["Type", "Fingerprint", "Clients", "Calls"] + tbl.align["Type"] = "l" + tbl.align["Fingerprint"] = "l" + tbl.align["Clients"] = "r" + tbl.align["Calls"] = "r" + tbl.max_table_width = max(40, (term.width or 80) - 1) + for kind, display_name, files, sessions in entries: + tbl.add_row( + [ + kind, + term.forestgreen(display_name), + str(files), + str(sessions), + ] + ) + + _paginate(term, str(tbl)) + + +def _fingerprint_repl( + term: Any, + data: Dict[str, Any], + seen_counts: str = "", + filepath: Optional[str] = None, + names: Optional[Dict[str, str]] = None, +) -> None: + """Interactive REPL for exploring fingerprint data.""" + ip = _client_ip(data) + _commands = { + "q": "logoff", + "t": "terminal-detail", + "l": "telnet-detail", + "s": "database", + "u": "update", + "\x0c": "refresh", + } + + db_cache = None + + while True: + _repl_prompt(term) + with term.cbreak(): + while term.inkey(timeout=0): + pass # drain pending input (e.g. \r\n after keypress) + key = term.inkey(timeout=None) + + key_str = key.name or str(key) + if key_str in _commands: + echo(str(key) + "\n") + logger.info("%s: repl %s", ip, _commands[key_str]) + elif key_str not in ("KEY_ENTER", "\r", "\n"): + logger.info("%s: repl unknown key %r", ip, key_str) + + if key == "q" or key.name == "KEY_ESCAPE" or not key: + logger.info("%s: repl logoff", ip) + echo(f"\n{term.normal}") + break + if key == "t": + _show_detail(term, data, "terminal") + elif key == "l": + _show_detail(term, data, "telnet") + elif key == "s": + if db_cache is None: + db_cache = _build_database_entries(names) + _show_database(term, data, db_cache) + elif key == "u" and filepath is not None: + _names = names if names is not None else {} + _prompt_fingerprint_identification(term, data, filepath, _names) + names = _load_fingerprint_names() + seen_counts = _build_seen_counts(data, names, term) + elif key == "\x0c": + echo(term.normal + term.clear) + _display_compact_summary(data, term) + if seen_counts: + echo(seen_counts) + + +def _has_unknown_hashes(data: Dict[str, Any], names: Dict[str, str]) -> bool: + """Return True if either telnet or terminal hash is not yet named.""" + telnet_hash = data.get("telnet-probe", {}).get("fingerprint", "") + terminal_hash = data.get("terminal-probe", {}).get("fingerprint", _UNKNOWN_TERMINAL_HASH) + if telnet_hash not in names: + return True + if terminal_hash != _UNKNOWN_TERMINAL_HASH and terminal_hash not in names: + return True + return False + + +def _prompt_fingerprint_identification( # pylint: disable=too-many-branches + term: Any, data: Dict[str, Any], filepath: str, names: Dict[str, str] +) -> None: + """Prompt user to identify unknown fingerprint hashes.""" + telnet_probe = data.get("telnet-probe", {}) + telnet_hash = telnet_probe.get("fingerprint", "") + terminal_probe = data.get("terminal-probe", {}) + terminal_hash = terminal_probe.get("fingerprint", _UNKNOWN_TERMINAL_HASH) + + telnet_known = telnet_hash in names + terminal_known = terminal_hash in names or terminal_hash == _UNKNOWN_TERMINAL_HASH + all_known = telnet_known and terminal_known + + if all_known: + echo(f"\n{term.bold_magenta}Suggest a revision{term.normal}\n") + else: + echo(f"{term.bold_magenta}Help our database!{term.normal}\n") + + suggestions: Dict[str, str] = data.get("suggestions", {}) + revised = False + + if terminal_hash != _UNKNOWN_TERMINAL_HASH: + if not terminal_known: + software_name = terminal_probe.get("session_data", {}).get("software_name") + default = software_name or "" + if default: + prompt = f"Terminal emulator name" f' (press return for "{default}"): ' + else: + prompt = f"Terminal emulator name for {terminal_hash}: " + raw = _cooked_input(prompt) + if not raw and default: + raw = default + validated = _validate_suggestion(raw) + if validated: + suggestions["terminal-emulator"] = validated + elif all_known: + current_name = names.get(terminal_hash) + prompt = f"Terminal emulator name" f' (press return for "{current_name}"): ' + raw = _cooked_input(prompt).strip() + validated = _validate_suggestion(raw) if raw else None + if validated and validated != current_name: + suggestions["terminal-emulator-revision"] = validated + revised = True + + if not telnet_known: + raw = _cooked_input(f"Telnet client name for {telnet_hash}: ") + validated = _validate_suggestion(raw) + if validated: + suggestions["telnet-client"] = validated + elif all_known: + current_name = names.get(telnet_hash) + prompt = f"Telnet client name" f' (press return for "{current_name}"): ' + raw = _cooked_input(prompt).strip() + validated = _validate_suggestion(raw) if raw else None + if validated and validated != current_name: + suggestions["telnet-client-revision"] = validated + revised = True + + if suggestions: + data["suggestions"] = suggestions + _atomic_json_write(filepath, data) + + if revised: + echo("Your submission is under review.\n") + + echo("\n") + + +def _client_requires_ga(data: Dict[str, Any]) -> bool: + """Return True when the client refused SGA (e.g. MUD clients like Mudlet).""" + probe = data.get("telnet-probe", {}).get("session_data", {}).get("probe", {}) + return "SGA" not in probe.get("WILL", {}) + + +def _process_client_fingerprint(filepath: str, data: Dict[str, Any]) -> None: + """Process client fingerprint: run ucs-detect if available, update file.""" + if _client_requires_ga(data): + logger.info("skipping ucs-detect: client requires GA (MUD client)") + terminal_data = None + else: + terminal_data = _run_ucs_detect() + + if terminal_data: + terminal_fp = _create_terminal_fingerprint(terminal_data) + terminal_hash = _hash_fingerprint(terminal_fp) + + data["terminal-probe"] = { + "fingerprint": terminal_hash, + "fingerprint-data": terminal_fp, + "session_data": terminal_data, + } + + old_dir = os.path.dirname(filepath) + if os.path.basename(old_dir) != terminal_hash: + new_dir = os.path.join(os.path.dirname(old_dir), terminal_hash) + try: + os.makedirs(new_dir, exist_ok=True) + new_filepath = os.path.join(new_dir, os.path.basename(filepath)) + os.rename(filepath, new_filepath) + filepath = new_filepath + if not os.listdir(old_dir): + os.rmdir(old_dir) + except OSError as exc: + logger.warning("failed to move %s -> %s: %s", filepath, new_dir, exc) + + _atomic_json_write(filepath, data) + + _setup_term_environ(data) + + try: + # 3rd party + import blessed # noqa: F401 # pylint: disable=import-outside-toplevel,unused-import + except ImportError: + print(json.dumps(data, indent=2, sort_keys=True)) + return + + term = _make_terminal() + names = _load_fingerprint_names() + seen_counts = _build_seen_counts(data, names, term) + if not _display_compact_summary(data, term): + print(json.dumps(data, indent=2, sort_keys=True)) + if seen_counts: + echo(seen_counts) + + if term.is_a_tty: + with term.cbreak(), _disable_isig(): + if _has_unknown_hashes(data, names): + _prompt_fingerprint_identification(term, data, filepath, names) + _fingerprint_repl(term, data, seen_counts, filepath, names) + + +def fingerprinting_post_script(filepath: str) -> None: + """ + Post-fingerprint script that optionally runs ucs-detect for terminal probing. + + If ucs-detect is available in PATH, runs it to collect terminal capabilities + and merges the results into the fingerprint data. + + Can be used as the TELNETLIB3_FINGERPRINT_POST_SCRIPT target:: + + export TELNETLIB3_FINGERPRINT_POST_SCRIPT=telnetlib3.fingerprinting_display + export TELNETLIB3_DATA_DIR=./data + telnetlib3-server --shell telnetlib3.fingerprinting_server_shell + + :param filepath: Path to the saved fingerprint JSON file. + """ + filepath = str(filepath) + if not os.path.exists(filepath): + logger.warning("Post-script file not found: %s", filepath) + return + + with open(filepath, encoding="utf-8") as f: + data = json.load(f) + + telnet_probe = data.get("telnet-probe", {}) + probed_protocol = telnet_probe.get("fingerprint-data", {}).get("probed-protocol") + + if probed_protocol == "client": + _process_client_fingerprint(filepath, data) + else: + logger.warning("Unknown probed-protocol: %s", probed_protocol) + + +def main() -> None: + """CLI entry point for fingerprinting display post-processing.""" + if len(sys.argv) != 2: + print(f"Usage: python -m {__name__} ", file=sys.stderr) + sys.exit(1) + fingerprinting_post_script(sys.argv[1]) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/telnetlib3/guard_shells.py b/telnetlib3/guard_shells.py index 24e2bba2..c3edcf94 100644 --- a/telnetlib3/guard_shells.py +++ b/telnetlib3/guard_shells.py @@ -1,14 +1,29 @@ """ Guard shells for connection limiting and robot detection. -These shells are used when normal shell access is denied due to connection limits or failed robot -checks. +When running a telnet server on a public IPv4 address, or even on large private networks, +various network scanners, scrapers, worms, bots, and other automatons will connect. + +The ``robot_check`` function can reliably detect whether the remote end is a real terminal +emulator by measuring the rendered width of a wide Unicode character. Real terminals +render it as width 2, while bots typically see width 1 or timeout. + +These shells are used when normal shell access is denied due to connection limits or +failed robot checks. """ +from __future__ import annotations + # std imports import re import asyncio import logging +from typing import Tuple, Union, Optional, cast + +# local +from .server_shell import readline2 +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode __all__ = ("robot_check", "robot_shell", "busy_shell", "ConnectionCounter") @@ -27,7 +42,7 @@ class ConnectionCounter: """Simple shared counter for limiting concurrent connections.""" - def __init__(self, limit): + def __init__(self, limit: int) -> None: """ Initialize connection counter. @@ -36,7 +51,7 @@ def __init__(self, limit): self.limit = limit self._count = 0 - def try_acquire(self): + def try_acquire(self) -> bool: """ Try to acquire a connection slot. @@ -47,22 +62,26 @@ def try_acquire(self): return True return False - def release(self): + def release(self) -> None: """Release a connection slot.""" if self._count > 0: self._count -= 1 @property - def count(self): + def count(self) -> int: """Current connection count.""" return self._count -async def _read_line_inner(reader, max_len): +async def _read_line_inner( + reader: Union[TelnetReader, TelnetReaderUnicode], + max_len: int, +) -> str: """Inner loop for _read_line, separated for wait_for compatibility.""" + _reader = cast(TelnetReaderUnicode, reader) buf = "" while len(buf) < max_len: - char = await reader.read(1) + char = await _reader.read(1) if not char: break if char in ("\r", "\n"): @@ -71,7 +90,11 @@ async def _read_line_inner(reader, max_len): return buf -async def _read_line(reader, timeout, max_len=_MAX_INPUT): +async def _read_line( + reader: Union[TelnetReader, TelnetReaderUnicode], + timeout: float, + max_len: int = _MAX_INPUT, +) -> Optional[str]: """Read a line with timeout and length limit.""" try: return await asyncio.wait_for(_read_line_inner(reader, max_len), timeout) @@ -79,11 +102,29 @@ async def _read_line(reader, timeout, max_len=_MAX_INPUT): return None -async def _read_cpr_response(reader): +async def _readline_with_echo( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + timeout: float, +) -> Optional[str]: + """Read a line with echo and timeout, using readline2 from server_shell.""" + try: + return await asyncio.wait_for(readline2(reader, writer), timeout) + except asyncio.TimeoutError: + return None + + +async def _read_cpr_response( + reader: Union[TelnetReader, TelnetReaderUnicode], +) -> Optional[Tuple[int, int]]: """Read CPR response bytes until 'R' terminator.""" buf = b"" while True: - data = await reader.read(1) + try: + data = await reader.read(1) + except UnicodeDecodeError: + # Bot sent garbage bytes that can't be decoded + return None if not data: return None if isinstance(data, str): @@ -95,14 +136,19 @@ async def _read_cpr_response(reader): return (int(match.group(1)), int(match.group(2))) -async def _get_cursor_position(reader, writer, timeout=2.0): +async def _get_cursor_position( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + timeout: float = 2.0, +) -> Tuple[Optional[int], Optional[int]]: """ Query cursor position using DSR/CPR. :returns: (row, col) tuple or (None, None) on timeout/failure. """ # Send Device Status Report request - writer.write("\x1b[6n") + _writer = cast(TelnetWriterUnicode, writer) + _writer.write("\x1b[6n") await writer.drain() # Read response: ESC [ row ; col R @@ -113,78 +159,118 @@ async def _get_cursor_position(reader, writer, timeout=2.0): return (None, None) -async def _measure_width(reader, writer, text, timeout=2.0): +async def _measure_width( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + text: str, + timeout: float = 2.0, +) -> Optional[int]: """ Measure rendered width of text using cursor position. :returns: Width in columns, or None on failure. """ + _writer = cast(TelnetWriterUnicode, writer) _, x1 = await _get_cursor_position(reader, writer, timeout) if x1 is None: return None - writer.write(text) - await writer.drain() + _writer.write(text) + await _writer.drain() _, x2 = await _get_cursor_position(reader, writer, timeout) if x2 is None: return None # Clear the test character - writer.write(f"\x1b[{x1}G" + " " * (x2 - x1) + f"\x1b[{x1}G") - await writer.drain() + _writer.write(f"\x1b[{x1}G" + " " * (x2 - x1) + f"\x1b[{x1}G") + await _writer.drain() return x2 - x1 -async def robot_check(reader, writer, timeout=5.0): +async def robot_check( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + timeout: float = 5.0, +) -> bool: """ Check if client can render wide characters. :returns: True if client passes (renders wide char as width 2). """ width = await _measure_width(reader, writer, _WIDE_TEST_CHAR, timeout) - return width == 2 + return bool(width == 2) -async def robot_shell(reader, writer): - """ - Shell for failed robot checks. - - Asks philosophical questions, logs responses, and disconnects. - """ - logger.info("robot_shell: connection from %s", writer.get_extra_info("peername")) - - writer.write("Do robots dream of electric sheep? [yn] ") - await writer.drain() - - line1 = await _read_line(reader, timeout=10.0) - if line1 is None: - logger.info("robot_shell: timeout waiting for response") - return - - logger.info("robot denied, line1=%r", line1) +async def _ask_question( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + prompt: str, + timeout: float = 10.0, +) -> Optional[str]: + """Ask a question, echoing input and repeating prompt on blank input.""" + _writer = cast(TelnetWriterUnicode, writer) + while True: + _writer.write(prompt) + await _writer.drain() - writer.write("\r\nHave you ever wondered, who are the windowmakers? ") - await writer.drain() + line = await _readline_with_echo(reader, writer, timeout) + if line is None: + return None - line2 = await _read_line(reader, timeout=10.0) - if line2 is None: - logger.info("robot_shell: timeout on second question") - return + if line.strip(): + return line + # Blank input - repeat prompt + _writer.write("\r\n") - logger.info("robot denied, line2=%r", line2) - writer.write("\r\n") - await writer.drain() +async def robot_shell( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> None: + """ + Shell for failed robot checks. + Asks philosophical questions, logs responses, and disconnects. + """ + writer = cast(TelnetWriterUnicode, writer) + peername = writer.get_extra_info("peername") + logger.info("robot_shell: connection from %s", peername) -async def busy_shell(reader, writer): + answers = [] + try: + line1 = await _ask_question(reader, writer, "Do robots dream of electric sheep? [yn] ") + if line1 is None: + logger.info("robot_shell: timeout waiting for response") + return + answers.append(line1) + + line2 = await _ask_question( + reader, writer, "\r\nHave you ever wondered, who are the windowmakers? " + ) + if line2 is None: + logger.info("robot_shell: timeout on second question") + return + answers.append(line2) + + writer.write("\r\n") + await writer.drain() + finally: + if answers: + logger.info("robot denied, answers=%r", answers) + + +async def busy_shell( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> None: """ Shell for when connection limit is reached. Displays busy message, logs any input, and disconnects. """ + writer = cast(TelnetWriterUnicode, writer) logger.info( "busy_shell: connection from %s (limit reached)", writer.get_extra_info("peername"), diff --git a/telnetlib3/relay_server.py b/telnetlib3/relay_server.py index 94f8e24e..91986fd6 100644 --- a/telnetlib3/relay_server.py +++ b/telnetlib3/relay_server.py @@ -3,11 +3,14 @@ # std imports import asyncio import logging +from typing import Any, Set, Union, cast # local from .client import open_connection from .accessories import make_reader_task from .server_shell import readline +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode CR, LF, NUL = ("\r", "\n", "\x00") @@ -16,7 +19,10 @@ # local -async def relay_shell(client_reader, client_writer): +async def relay_shell( + client_reader: Union[TelnetReader, TelnetReaderUnicode], + client_writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> None: """ Example telnet relay shell for use with telnetlib3.create_server. @@ -28,26 +34,28 @@ async def relay_shell(client_reader, client_writer): type and environment variable of value COLORTERM. """ log = logging.getLogger("relay_server") + _reader = cast(TelnetReaderUnicode, client_reader) + _writer = cast(TelnetWriterUnicode, client_writer) - password_prompt = readline(client_reader, client_writer) - password_prompt.send(None) + password_prompt = readline(_reader, _writer) + next(password_prompt) - client_writer.write("Telnet Relay shell ready." + CR + LF + CR + LF) + _writer.write("Telnet Relay shell ready." + CR + LF + CR + LF) client_passcode = "867-5309" num_tries = 3 next_host, next_port = "1984.ws", 23 passcode = None for _ in range(num_tries): - client_writer.write("Passcode: ") + _writer.write("Passcode: ") while passcode is None: - inp = await client_reader.read(1) + inp = await _reader.read(1) if not inp: log.info("EOF from client") return passcode = password_prompt.send(inp) await asyncio.sleep(1) - client_writer.write(CR + LF) + _writer.write(CR + LF) if passcode == client_passcode: log.info("passcode accepted") break @@ -56,21 +64,21 @@ async def relay_shell(client_reader, client_writer): # wrong passcode after 3 tires if passcode is None: log.info("passcode failed after %s tries", num_tries) - client_writer.close() + _writer.close() return # connect to another telnet server (next_host, next_port) - client_writer.write(f"Connecting to {next_host}:{next_port} ... ") + _writer.write(f"Connecting to {next_host}:{next_port} ... ") server_reader, server_writer = await open_connection( next_host, next_port, - cols=client_writer.get_extra_info("cols"), - rows=client_writer.get_extra_info("rows"), + cols=_writer.get_extra_info("cols"), + rows=_writer.get_extra_info("rows"), ) - client_writer.write("connected!" + CR + LF) + _writer.write("connected!" + CR + LF) - done = [] - client_stdin = make_reader_task(client_reader) + done: Set["asyncio.Task[Any]"] = set() + client_stdin = make_reader_task(_reader) server_stdout = make_reader_task(server_reader) wait_for = {client_stdin, server_stdout} while wait_for: @@ -82,7 +90,7 @@ async def relay_shell(client_reader, client_writer): inp = task.result() if inp: server_writer.write(inp) - client_stdin = make_reader_task(client_reader) + client_stdin = make_reader_task(_reader) wait_for.add(client_stdin) else: log.info("EOF from client") @@ -90,10 +98,10 @@ async def relay_shell(client_reader, client_writer): elif task == server_stdout: out = task.result() if out: - client_writer.write(out) + _writer.write(out) server_stdout = make_reader_task(server_reader) wait_for.add(server_stdout) else: log.info("EOF from server") - client_writer.close() + _writer.close() log.info("No more tasks: relay server complete") diff --git a/telnetlib3/server.py b/telnetlib3/server.py index c121fcc0..5485e540 100755 --- a/telnetlib3/server.py +++ b/telnetlib3/server.py @@ -11,16 +11,36 @@ after an idle period. """ +from __future__ import annotations + # std imports +import os +import sys +import codecs import signal +import socket import asyncio import logging import argparse -from typing import Callable, Optional, NamedTuple +from typing import ( + Any, + Dict, + List, + Type, + Tuple, + Union, + Callable, + Optional, + Sequence, + NamedTuple, +) # local from . import accessories, server_base +from ._types import ShellCallback from .telopt import name_commands +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode # Check if PTY support is available (Unix-only modules: pty, termios, fcntl) try: @@ -44,20 +64,21 @@ class CONFIG(NamedTuple): loglevel: str = "info" logfile: Optional[str] = None logfmt: str = accessories._DEFAULT_LOGFMT # pylint: disable=protected-access - shell: Callable = accessories.function_lookup("telnetlib3.telnet_server_shell") + shell: Callable[..., Any] = accessories.function_lookup("telnetlib3.telnet_server_shell") encoding: str = "utf8" force_binary: bool = False timeout: int = 300 - connect_maxwait: float = 4.0 + connect_maxwait: float = 1.5 pty_exec: Optional[str] = None - pty_args: Optional[str] = None + pty_args: Optional[List[str]] = None + pty_raw: bool = False robot_check: bool = False pty_fork_limit: int = 0 status_interval: int = 20 + never_send_ga: bool = False # Default config instance - use this to access default values -# (accessing CONFIG.field directly returns _tuplegetter in Python 3.8) _config = CONFIG() logger = logging.getLogger("telnetlib3.server") @@ -73,26 +94,56 @@ class TelnetServer(server_base.BaseServer): # Derived methods from base class - def __init__( - self, term="unknown", cols=80, rows=25, timeout=300, *args, **kwargs - ): # pylint: disable=keyword-arg-before-vararg + def __init__( # pylint: disable=too-many-positional-arguments + self, + term: str = "unknown", + cols: int = 80, + rows: int = 25, + timeout: int = 300, + shell: Optional[ShellCallback] = None, + _waiter_connected: Optional[asyncio.Future[None]] = None, + encoding: Union[str, bool] = "utf8", + encoding_errors: str = "strict", + force_binary: bool = False, + never_send_ga: bool = False, + connect_maxwait: float = 4.0, + limit: Optional[int] = None, + reader_factory: type = TelnetReader, + reader_factory_encoding: type = TelnetReaderUnicode, + writer_factory: type = TelnetWriter, + writer_factory_encoding: type = TelnetWriterUnicode, + ) -> None: """Initialize TelnetServer with terminal parameters.""" - super().__init__(*args, **kwargs) - self.waiter_encoding = asyncio.Future() + super().__init__( + shell=shell, + _waiter_connected=_waiter_connected, + encoding=encoding, + encoding_errors=encoding_errors, + force_binary=force_binary, + never_send_ga=never_send_ga, + connect_maxwait=connect_maxwait, + limit=limit, + reader_factory=reader_factory, + reader_factory_encoding=reader_factory_encoding, + writer_factory=writer_factory, + writer_factory_encoding=writer_factory_encoding, + ) + self._environ_requested = False + self.waiter_encoding: asyncio.Future[bool] = asyncio.Future() self._tasks.append(self.waiter_encoding) self._ttype_count = 1 - self._timer = None + self._timer: Optional[asyncio.TimerHandle] = None self._extra.update( { "term": term, - "charset": kwargs.get("encoding", ""), + "charset": encoding or "", "cols": cols, "rows": rows, "timeout": timeout, } ) - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle new connection and wire up telnet option callbacks.""" # local from .telopt import ( # pylint: disable=import-outside-toplevel @@ -105,20 +156,22 @@ def connection_made(self, transport): ) super().connection_made(transport) + assert self.writer is not None # begin timeout timer self.set_timeout() # Wire extended rfc callbacks for responses to # requests of terminal attributes, environment values, etc. - for tel_opt, callback_fn in [ + _ext_callbacks: List[Tuple[bytes, Callable[..., Any]]] = [ (NAWS, self.on_naws), (NEW_ENVIRON, self.on_environ), (TSPEED, self.on_tspeed), (TTYPE, self.on_ttype), (XDISPLOC, self.on_xdisploc), (CHARSET, self.on_charset), - ]: + ] + for tel_opt, callback_fn in _ext_callbacks: self.writer.set_ext_callback(tel_opt, callback_fn) # Wire up a callbacks that return definitions for requests. @@ -128,21 +181,28 @@ def connection_made(self, transport): ]: self.writer.set_ext_send_callback(tel_opt, callback_fn) - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Process received data and reset timeout timer.""" self.set_timeout() super().data_received(data) - def begin_negotiation(self): + def begin_negotiation(self) -> None: """Begin telnet negotiation by requesting terminal type.""" # local from .telopt import DO, TTYPE # pylint: disable=import-outside-toplevel super().begin_negotiation() + assert self.writer is not None self.writer.iac(DO, TTYPE) - def begin_advanced_negotiation(self): - """Request advanced telnet options from client.""" + def begin_advanced_negotiation(self) -> None: + """ + Request advanced telnet options from client. + + ``DO NEW_ENVIRON`` is deferred until the TTYPE cycle completes + so that Microsoft telnet (ANSI + VT100) can be detected first. + See ``_negotiate_environ`` and GitHub issue #24. + """ # local from .telopt import ( # pylint: disable=import-outside-toplevel DO, @@ -152,29 +212,45 @@ def begin_advanced_negotiation(self): WILL, BINARY, CHARSET, - NEW_ENVIRON, ) super().begin_advanced_negotiation() + assert self.writer is not None self.writer.iac(WILL, SGA) self.writer.iac(WILL, ECHO) self.writer.iac(WILL, BINARY) - self.writer.iac(DO, NEW_ENVIRON) + # DO NEW_ENVIRON is deferred -- see _negotiate_environ() self.writer.iac(DO, NAWS) if self.default_encoding: - # Request client capability to negotiate character set self.writer.iac(DO, CHARSET) - def check_negotiation(self, final=False): + def check_negotiation(self, final: bool = False) -> bool: """Check if negotiation is complete including encoding.""" # local from .telopt import ( # pylint: disable=import-outside-toplevel + DO, SB, TTYPE, CHARSET, NEW_ENVIRON, ) + assert self.writer is not None + # If TTYPE cycle stalled or client refused TTYPE, trigger + # deferred NEW_ENVIRON negotiation now. Only when advanced + # negotiation is active -- a raw TCP client that WONTs TTYPE + # should not be sent DO NEW_ENVIRON. + if not self._environ_requested and self._advanced: + ttype_refused = self.writer.remote_option.get(TTYPE) is False + ttype_do_pending = self.writer.pending_option.get(DO + TTYPE) + ttype_sb_pending = self.writer.pending_option.get(SB + TTYPE) + if ttype_refused or final: + self._negotiate_environ() + elif not ttype_do_pending and not ttype_sb_pending: + # TTYPE fully resolved but on_ttype never called + # _negotiate_environ (shouldn't happen, but be safe) + self._negotiate_environ() + # Debug log to see which options are still pending pending = [ (name_commands(opt), val) for opt, val in self.writer.pending_option.items() if val @@ -182,8 +258,7 @@ def check_negotiation(self, final=False): if pending: logger.debug("Pending options: %r", pending) - # Check if we're waiting for important subnegotiations -- environment or charset information - # These are critical for proper encoding determination + # Check if we're waiting for important subnegotiations waiting_for_environ = ( SB + NEW_ENVIRON in self.writer.pending_option and self.writer.pending_option[SB + NEW_ENVIRON] @@ -239,20 +314,23 @@ def check_negotiation(self, final=False): # new methods - def encoding(self, outgoing=None, incoming=None): + def encoding( + self, + outgoing: Optional[bool] = None, + incoming: Optional[bool] = None, + ) -> Union[str, bool]: """ Return encoding for the given stream direction. - :param bool outgoing: Whether the return value is suitable for + :param outgoing: Whether the return value is suitable for encoding bytes for transmission to client end. - :param bool incoming: Whether the return value is suitable for + :param incoming: Whether the return value is suitable for decoding bytes received from the client. :raises TypeError: when a direction argument, either ``outgoing`` or ``incoming``, was not set ``True``. :returns: ``'US-ASCII'`` for the directions indicated, unless ``BINARY`` :rfc:`856` has been negotiated for the direction - indicated or :attr`force_binary` is set ``True``. - :rtype: str + indicated or ``force_binary`` is set ``True``. """ if not (outgoing or incoming): raise TypeError( @@ -260,6 +338,7 @@ def encoding(self, outgoing=None, incoming=None): ) # may we encode in the direction indicated? + assert self.writer is not None _outgoing_only = outgoing and not incoming _incoming_only = not outgoing and incoming _bidirectional = outgoing and incoming @@ -275,18 +354,24 @@ def encoding(self, outgoing=None, incoming=None): # negotiation. _lang = self.get_extra_info("LANG", "") if _lang and _lang != "C": - return accessories.encoding_from_lang(_lang) + candidate = accessories.encoding_from_lang(_lang) + if candidate: + try: + codecs.lookup(candidate) + return candidate + except LookupError: + pass # fall through to charset or default # otherwise, less common CHARSET negotiation may be found in many # East-Asia BBS and Western MUD systems. return self.get_extra_info("charset") or self.default_encoding return "US-ASCII" - def set_timeout(self, duration=-1): + def set_timeout(self, duration: int = -1) -> None: """ Restart or unset timeout for client. - :param int duration: When specified as a positive integer, + :param duration: When specified as a positive integer, schedules Future for callback of :meth:`on_timeout`. When ``-1``, the value of ``self.get_extra_info('timeout')`` is used. When non-True, it is canceled. @@ -305,34 +390,33 @@ def set_timeout(self, duration=-1): # Callback methods - def on_timeout(self): + def on_timeout(self) -> None: """ Callback received on session timeout. Default implementation writes "Timeout." bound by CRLF and closes. This can be disabled by calling :meth:`set_timeout` with - ``duration` value of ``0``. + ``duration`` value of ``0``. """ logger.debug("Timeout after %1.2fs", self.idle) - # try to write timeout using encoding, - try: + assert self.writer is not None + if isinstance(self.writer, TelnetWriterUnicode): self.writer.write("\r\nTimeout.\r\n") - except TypeError: - # unless server was started with encoding=False, we must send as binary! + else: self.writer.write(b"\r\nTimeout.\r\n") self.timeout_connection() - def on_naws(self, rows, cols): + def on_naws(self, rows: int, cols: int) -> None: """ Callback receives NAWS response, :rfc:`1073`. - :param int rows: screen size, by number of cells in height. - :param int cols: screen size, by number of cells in width. + :param rows: screen size, by number of cells in height. + :param cols: screen size, by number of cells in width. """ self._extra.update({"rows": rows, "cols": cols}) - def on_request_environ(self): + def on_request_environ(self) -> List[Union[str, bytes]]: """ Definition for NEW_ENVIRON request of client, :rfc:`1572`. @@ -340,13 +424,12 @@ def on_request_environ(self): first entered on receipt of (WILL, NEW_ENVIRON) by server. The return value *defines the request made to the client* for environment values. - :rtype list: a list of unicode character strings of US-ASCII - characters, indicating the environment keys the server requests - of the client. If this list contains the special byte constants, - ``USERVAR`` or ``VAR``, the client is allowed to volunteer any - other additional user or system values. - - Any empty return value indicates that no request should be made. + :returns: A list of US-ASCII character strings indicating the + environment keys the server requests of the client. If this list + contains the special byte constants, ``USERVAR`` or ``VAR``, the + client is allowed to volunteer any other additional user or system + values. An empty return value indicates that no request should be + made. The default return value is:: @@ -356,18 +439,58 @@ def on_request_environ(self): # local from .telopt import VAR, USERVAR # pylint: disable=import-outside-toplevel + # Parse additional keys from environment variable (comma-delimited) + additional = os.environ.get("TELNETLIB3_FINGERPRINT_ENVIRON_ADDITIONAL", "") + additional_keys = [k.strip() for k in additional.split(",") if k.strip()] + return [ + # Well-known VAR (RFC 1572) + "USER", + "DISPLAY", + # USERVAR - common environment variables "LANG", "TERM", "COLUMNS", "LINES", - "DISPLAY", "COLORTERM", + "HOME", + "SHELL", + # SSH/remote connection info + "SSH_CLIENT", + "SSH_TTY", + # System info + "LOGNAME", + "HOSTNAME", + "HOSTTYPE", + "OSTYPE", + "PWD", + # Editor preferences + "EDITOR", + "VISUAL", + # Terminal multiplexers + "TMUX", + "STY", + # Locale settings + "LC_ALL", + "LC_CTYPE", + "LC_MESSAGES", + "LC_COLLATE", + "LC_TIME", + # Container/remote + "DOCKER_HOST", + # Shell history + "HISTFILE", + # Cloud + "AWS_PROFILE", + "AWS_REGION", + # Additional keys from TELNETLIB3_FINGERPRINT_ENVIRON_ADDITIONAL + *additional_keys, + # Request any other VAR/USERVAR the client wants to send VAR, USERVAR, ] - def on_environ(self, mapping): + def on_environ(self, mapping: Dict[str, str]) -> None: """Callback receives NEW_ENVIRON response, :rfc:`1572`.""" # A well-formed client responds with empty values for variables to # mean "no value". They might have it, they just may not wish to @@ -386,7 +509,7 @@ def on_environ(self, mapping): self._extra.update(u_mapping) - def on_request_charset(self): + def on_request_charset(self) -> List[str]: """ Definition for CHARSET request by client, :rfc:`2066`. @@ -394,11 +517,9 @@ def on_request_charset(self): first entered on receipt of (WILL, CHARSET) by server. The return value *defines the request made to the client* for encodings. - :rtype list: a list of unicode character strings of US-ASCII - characters, indicating the encodings offered by the server in - its preferred order. - - Any empty return value indicates that no encodings are offered. + :returns: A list of US-ASCII character strings indicating the + encodings offered by the server in its preferred order. An empty + return value indicates that no encodings are offered. The default return value includes common encodings for both Western and Eastern scripts:: @@ -425,16 +546,17 @@ def on_request_charset(self): "US-ASCII", # Basic ASCII ] - def on_charset(self, charset): + def on_charset(self, charset: str) -> None: """Callback for CHARSET response, :rfc:`2066`.""" self._extra["charset"] = charset - def on_tspeed(self, rx, tx): + def on_tspeed(self, rx: str, tx: str) -> None: """Callback for TSPEED response, :rfc:`1079`.""" self._extra["tspeed"] = f"{rx},{tx}" - def on_ttype(self, ttype): + def on_ttype(self, ttype: str) -> None: """Callback for TTYPE response, :rfc:`930`.""" + assert self.writer is not None # TTYPE may be requested multiple times, we honor this system and # attempt to cause the client to cycle, as their first response may # not be their most significant. All responses held as 'ttype{n}', @@ -449,38 +571,89 @@ def on_ttype(self, ttype): _lastval = self.get_extra_info(f"ttype{self._ttype_count - 1}") + # After ttype1: send DO NEW_ENVIRON now unless ttype1 is "ANSI", + # in which case we defer until ttype2 to detect Microsoft telnet + # (ANSI + VT100) which crashes on NEW_ENVIRON (issue #24). + if key == "ttype1" and ttype != "ANSI": + self._negotiate_environ() + elif key == "ttype2" and not self._environ_requested: + self._negotiate_environ() + if key != "ttype1" and ttype == self.get_extra_info("ttype1", None): # cycle has looped, stop logger.debug("ttype cycle stop at %s: %s, looped.", key, ttype) + self._negotiate_environ() elif not ttype or self._ttype_count > self.TTYPE_LOOPMAX: # empty reply string or too many responses! logger.warning("ttype cycle stop at %s: %s.", key, ttype) + self._negotiate_environ() elif self._ttype_count == 3 and ttype.upper().startswith("MTTS "): val = self.get_extra_info("ttype2") - logger.debug("ttype cycle stop at %s: %s, using %s from ttype2.", key, ttype, val) + logger.debug( + "ttype cycle stop at %s: %s, using %s from ttype2.", + key, + ttype, + val, + ) self._extra["TERM"] = val + self._negotiate_environ() elif ttype == _lastval: logger.debug("ttype cycle stop at %s: %s, repeated.", key, ttype) + self._negotiate_environ() else: logger.debug("ttype cycle cont at %s: %s.", key, ttype) self._ttype_count += 1 self.writer.request_ttype() - def on_xdisploc(self, xdisploc): + def on_xdisploc(self, xdisploc: str) -> None: """Callback for XDISPLOC response, :rfc:`1096`.""" self._extra["xdisploc"] = xdisploc # private methods - def _check_encoding(self): + def _negotiate_environ(self) -> None: + """ + Send ``DO NEW_ENVIRON`` unless the client is Microsoft telnet. + + Called from :meth:`on_ttype` as soon as we have enough information: + + - After ``ttype1`` when it is not ``"ANSI"``. + - After ``ttype2`` when ``ttype1`` *is* ``"ANSI"`` -- if ``ttype2`` + is ``"VT100"`` the client is Microsoft Windows telnet and + ``NEW_ENVIRON`` is skipped entirely (GitHub issue #24). + - From :meth:`check_negotiation` when TTYPE stalls or is refused. + """ + if self._environ_requested: + return + self._environ_requested = True + + # local + from .telopt import DO, NEW_ENVIRON # pylint: disable=import-outside-toplevel + + ttype1 = self.get_extra_info("ttype1") or "" + ttype2 = self.get_extra_info("ttype2") or "" + + if ttype1 == "ANSI" and ttype2 == "VT100": + logger.info( + "skipping NEW_ENVIRON for Microsoft telnet (ttype1=%r, ttype2=%r)", + ttype1, + ttype2, + ) + return + + assert self.writer is not None + self.writer.iac(DO, NEW_ENVIRON) + + def _check_encoding(self) -> bool: # Periodically check for completion of ``waiter_encoding``. # local from .telopt import DO, SB, BINARY, CHARSET # pylint: disable=import-outside-toplevel + assert self.writer is not None # Check if we need to request client to use BINARY mode for client-to-server communication if ( self.writer.outbinary @@ -513,14 +686,15 @@ class Server: Returned by :func:`create_server`. """ - def __init__(self, server): + def __init__(self, server: Optional[asyncio.Server]) -> None: """Initialize wrapper around asyncio.Server.""" - self._server = server - self._protocols = [] - self._new_client = asyncio.Queue() + self._server: Optional[asyncio.Server] = server + self._protocols: List[server_base.BaseServer] = [] + self._new_client: asyncio.Queue[server_base.BaseServer] = asyncio.Queue() - def close(self): + def close(self) -> None: """Close the server, stop accepting new connections, and close all clients.""" + assert self._server is not None self._server.close() # Close all connected client transports for protocol in list(self._protocols): @@ -528,8 +702,9 @@ def close(self): if hasattr(protocol, "_transport") and protocol._transport is not None: protocol._transport.close() - async def wait_closed(self): + async def wait_closed(self) -> None: """Wait until the server and all client connections are closed.""" + assert self._server is not None await self._server.wait_closed() # Yield to event loop for pending close callbacks await asyncio.sleep(0) @@ -537,16 +712,18 @@ async def wait_closed(self): self._protocols.clear() @property - def sockets(self): + def sockets(self) -> Optional[Tuple["socket.socket", ...]]: """Return list of socket objects the server is listening on.""" + assert self._server is not None return self._server.sockets - def is_serving(self): + def is_serving(self) -> bool: """Return True if the server is accepting new connections.""" + assert self._server is not None return self._server.is_serving() @property - def clients(self): + def clients(self) -> List[server_base.BaseServer]: """ List of connected client protocol instances. @@ -556,7 +733,7 @@ def clients(self): self._protocols = [p for p in self._protocols if not getattr(p, "_closing", False)] return list(self._protocols) - async def wait_for_client(self): + async def wait_for_client(self) -> server_base.BaseServer: r""" Wait for a client to connect and complete negotiation. @@ -570,10 +747,10 @@ async def wait_for_client(self): """ return await self._new_client.get() - def _register_protocol(self, protocol): + def _register_protocol(self, protocol: asyncio.Protocol) -> None: """Register a new protocol instance (called by factory).""" # pylint: disable=protected-access - self._protocols.append(protocol) + self._protocols.append(protocol) # type: ignore[arg-type] # Only register callbacks if protocol has the required waiters # (custom protocols like plain asyncio.Protocol won't have these) if hasattr(protocol, "_waiter_connected"): @@ -585,7 +762,7 @@ def _register_protocol(self, protocol): class StatusLogger: """Periodic status logger for connected clients.""" - def __init__(self, server, interval): + def __init__(self, server: Server, interval: int) -> None: """ Initialize status logger. @@ -594,10 +771,10 @@ def __init__(self, server, interval): """ self._server = server self._interval = interval - self._task = None - self._last_status = None + self._task: Optional["asyncio.Task[None]"] = None + self._last_status: Optional[Dict[str, Any]] = None - def _get_status(self): + def _get_status(self) -> Dict[str, Any]: """Get current status snapshot using IP:port pairs for change detection.""" clients = self._server.clients client_data = [] @@ -609,27 +786,29 @@ def _get_status(self): "port": peername[1], "rx": getattr(client, "rx_bytes", 0), "tx": getattr(client, "tx_bytes", 0), + "idle": int(getattr(client, "idle", 0)), } ) client_data.sort(key=lambda x: (x["ip"], x["port"])) return {"count": len(clients), "clients": client_data} - def _status_changed(self, current): + def _status_changed(self, current: Dict[str, Any]) -> bool: """Check if status differs from last logged.""" if self._last_status is None: - return current["count"] > 0 + return bool(current["count"] > 0) return current != self._last_status - def _format_status(self, status): + def _format_status(self, status: Dict[str, Any]) -> str: """Format status for logging.""" if status["count"] == 0: return "0 clients connected" client_info = ", ".join( - f"{c['ip']}:{c['port']} (rx={c['rx']}, tx={c['tx']})" for c in status["clients"] + f"{c['ip']}:{c['port']} (rx={c['rx']}, tx={c['tx']}, idle={c['idle']})" + for c in status["clients"] ) return f"{status['count']} client(s): {client_info}" - async def _run(self): + async def _run(self) -> None: """Run periodic status logging.""" while True: await asyncio.sleep(self._interval) @@ -638,46 +817,60 @@ async def _run(self): logger.info("Status: %s", self._format_status(status)) self._last_status = status - def start(self): + def start(self) -> None: """Start the status logging task.""" if self._interval > 0: self._task = asyncio.create_task(self._run()) - def stop(self): + def stop(self) -> None: """Stop the status logging task.""" if self._task: self._task.cancel() -async def create_server(host=None, port=23, protocol_factory=TelnetServer, **kwds): +async def create_server( # pylint: disable=too-many-positional-arguments + host: Optional[Union[str, Sequence[str]]] = None, + port: int = 23, + protocol_factory: Optional[Type[asyncio.Protocol]] = TelnetServer, + shell: Optional[ShellCallback] = None, + encoding: Union[str, bool] = "utf8", + encoding_errors: str = "strict", + force_binary: bool = False, + never_send_ga: bool = False, + connect_maxwait: float = 4.0, + limit: Optional[int] = None, + term: str = "unknown", + cols: int = 80, + rows: int = 25, + timeout: int = 300, +) -> Server: """ Create a TCP Telnet server. - :param str host: The host parameter can be a string, in that case the TCP + :param host: The host parameter can be a string, in that case the TCP server is bound to host and port. The host parameter can also be a sequence of strings, and in that case the TCP server is bound to all hosts of the sequence. - :param int port: listen port for TCP Server. - :param server_base.BaseServer protocol_factory: An alternate protocol - factory for the server, when unspecified, :class:`TelnetServer` is - used. + :param port: Listen port for TCP server. + :param protocol_factory: An alternate protocol factory for the server. + When unspecified, :class:`TelnetServer` is used. :param shell: An async function that is called after negotiation completes, receiving arguments ``(reader, writer)``. Default is :func:`~.telnet_server_shell`. The reader is a :class:`~.TelnetReader` instance, the writer is a :class:`~.TelnetWriter` instance. - :param str encoding: The default assumed encoding, or ``False`` to disable - unicode support. Encoding may be negotiation to another value by + :param encoding: The default assumed encoding, or ``False`` to disable + unicode support. Encoding may be negotiated to another value by the client through NEW_ENVIRON :rfc:`1572` by sending environment value of ``LANG``, or by any legal value for CHARSET :rfc:`2066` negotiation. The server's attached ``reader, writer`` streams accept and return - unicode, or natural strings, "hello world", unless this value explicitly - set ``False``. In that case, the attached streams interfaces are - bytes-only, b"hello world". - :param str encoding_errors: Same meaning as :meth:`codecs.Codec.encode`. + unicode, or natural strings, "hello world", unless this value is + explicitly set to ``False``. In that case, the attached stream + interfaces are bytes-only, b"hello world". + :param encoding_errors: Same meaning as :meth:`codecs.Codec.encode`. Default value is ``strict``. - :param bool force_binary: When ``True``, the encoding specified is + :param force_binary: When ``True``, the encoding specified is used for both directions even when BINARY mode, :rfc:`856`, is not negotiated for the direction specified. This parameter has no effect when ``encoding=False``. @@ -691,25 +884,24 @@ async def create_server(host=None, port=23, protocol_factory=TelnetServer, **kwd may be no problem at all. If an encoding is assumed, as in many MUD and BBS systems, the combination of ``force_binary`` with a default ``encoding`` is often preferred. - :param str term: Value returned for ``writer.get_extra_info('term')`` + :param term: Value returned for ``writer.get_extra_info('term')`` until negotiated by TTYPE :rfc:`930`, or NAWS :rfc:`1572`. Default value is ``'unknown'``. - :param int cols: Value returned for ``writer.get_extra_info('cols')`` + :param cols: Value returned for ``writer.get_extra_info('cols')`` until negotiated by NAWS :rfc:`1572`. Default value is 80 columns. - :param int rows: Value returned for ``writer.get_extra_info('rows')`` + :param rows: Value returned for ``writer.get_extra_info('rows')`` until negotiated by NAWS :rfc:`1572`. Default value is 25 rows. - :param int timeout: Causes clients to disconnect if idle for this duration, + :param timeout: Causes clients to disconnect if idle for this duration, in seconds. This ensures resources are freed on busy servers. When explicitly set to ``False``, clients will not be disconnected for timeout. Default value is 300 seconds (5 minutes). - :param float connect_maxwait: If the remote end is not complaint, or + :param connect_maxwait: If the remote end is not compliant, or otherwise confused by our demands, the shell continues anyway after the greater of this value has elapsed. A client that is not answering option negotiation will delay the start of the shell by this amount. - :param int limit: The buffer limit for the reader stream. - :param kwds: Additional keyword arguments passed to the protocol factory. + :param limit: The buffer limit for the reader stream. - :return Server: A :class:`Server` instance that wraps the asyncio.Server + :return: A :class:`Server` instance that wraps the asyncio.Server and provides access to connected client protocols via :meth:`Server.wait_for_client` and :attr:`Server.clients`. """ @@ -718,8 +910,34 @@ async def create_server(host=None, port=23, protocol_factory=TelnetServer, **kwd telnet_server = Server(None) - def factory(): - protocol = protocol_factory(**kwds) + def factory() -> asyncio.Protocol: + protocol: asyncio.Protocol + if issubclass(protocol_factory, TelnetServer): + protocol = protocol_factory( + shell=shell, + encoding=encoding, + encoding_errors=encoding_errors, + force_binary=force_binary, + never_send_ga=never_send_ga, + connect_maxwait=connect_maxwait, + limit=limit, + term=term, + cols=cols, + rows=rows, + timeout=timeout, + ) + elif issubclass(protocol_factory, server_base.BaseServer): + protocol = protocol_factory( + shell=shell, + encoding=encoding, + encoding_errors=encoding_errors, + force_binary=force_binary, + never_send_ga=never_send_ga, + connect_maxwait=connect_maxwait, + limit=limit, + ) + else: + protocol = protocol_factory() telnet_server._register_protocol(protocol) # pylint: disable=protected-access return protocol @@ -729,7 +947,7 @@ def factory(): return telnet_server -async def _sigterm_handler(server, _log): +async def _sigterm_handler(server: Server, _log: logging.Logger) -> None: logger.info("SIGTERM received, closing server.") # This signals the completion of the server.wait_closed() Future, @@ -737,11 +955,8 @@ async def _sigterm_handler(server, _log): server.close() -def parse_server_args(): +def parse_server_args() -> Dict[str, Any]: """Parse command-line arguments for telnet server.""" - # std imports - import sys # pylint: disable=import-outside-toplevel - # Extract arguments after '--' for PTY program before argparse sees them argv = sys.argv[1:] pty_args = [] @@ -793,6 +1008,13 @@ def parse_server_args(): default=_config.pty_fork_limit, help="limit concurrent PTY connections (0 disables)", ) + parser.add_argument( + "--pty-raw", + action="store_true", + default=_config.pty_raw, + help="raw mode for --pty-exec: disable PTY echo for programs that " + "handle their own terminal I/O (curses, blessed, ucs-detect)", + ) parser.add_argument( "--robot-check", action="store_true", @@ -809,31 +1031,42 @@ def parse_server_args(): "status only logged when connected clients has changed." ), ) + parser.add_argument( + "--never-send-ga", + action="store_true", + default=_config.never_send_ga, + help="never send IAC GA (Go-Ahead). Default sends GA when SGA is " + "not negotiated, which is correct for MUD clients but may " + "confuse some other clients.", + ) result = vars(parser.parse_args(argv)) result["pty_args"] = pty_args if PTY_SUPPORT else None if not PTY_SUPPORT: result["pty_exec"] = None result["pty_fork_limit"] = 0 + result["pty_raw"] = False return result async def run_server( # pylint: disable=too-many-positional-arguments,too-many-locals - host=_config.host, - port=_config.port, - loglevel=_config.loglevel, - logfile=_config.logfile, - logfmt=_config.logfmt, - shell=_config.shell, - encoding=_config.encoding, - force_binary=_config.force_binary, - timeout=_config.timeout, - connect_maxwait=_config.connect_maxwait, - pty_exec=_config.pty_exec, - pty_args=_config.pty_args, - robot_check=_config.robot_check, - pty_fork_limit=_config.pty_fork_limit, - status_interval=_config.status_interval, -): + host: str = _config.host, + port: int = _config.port, + loglevel: str = _config.loglevel, + logfile: Optional[str] = _config.logfile, + logfmt: str = _config.logfmt, + shell: Callable[..., Any] = _config.shell, + encoding: Union[str, bool] = _config.encoding, + force_binary: bool = _config.force_binary, + timeout: int = _config.timeout, + connect_maxwait: float = _config.connect_maxwait, + pty_exec: Optional[str] = _config.pty_exec, + pty_args: Optional[List[str]] = _config.pty_args, + pty_raw: bool = _config.pty_raw, + robot_check: bool = _config.robot_check, + pty_fork_limit: int = _config.pty_fork_limit, + status_interval: int = _config.status_interval, + never_send_ga: bool = _config.never_send_ga, +) -> None: """ Program entry point for server daemon. @@ -850,7 +1083,7 @@ async def run_server( # pylint: disable=too-many-positional-arguments,too-many- # local from .server_pty_shell import make_pty_shell # pylint: disable=import-outside-toplevel - shell = make_pty_shell(pty_exec, pty_args) + shell = make_pty_shell(pty_exec, pty_args, raw_mode=pty_raw) # Wrap shell with guards if enabled if robot_check or pty_fork_limit: @@ -863,31 +1096,40 @@ async def run_server( # pylint: disable=too-many-positional-arguments,too-many- counter = ConnectionCounter(pty_fork_limit) if pty_fork_limit else None inner_shell = shell - async def guarded_shell(reader, writer): - # Check connection limit first - if counter and not counter.try_acquire(): - try: - await busy_shell(reader, writer) - finally: - if not writer.is_closing(): - writer.close() - return - + async def guarded_shell( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + ) -> None: try: - # Check robot if enabled - if robot_check: - passed = await do_robot_check(reader, writer) - if not passed: - await robot_shell(reader, writer) + # Check connection limit first + if counter and not counter.try_acquire(): + try: + await busy_shell(reader, writer) + finally: if not writer.is_closing(): writer.close() - return + return - # Run actual shell - await inner_shell(reader, writer) - finally: - if counter: - counter.release() + try: + # Check robot if enabled + if robot_check: + passed = await do_robot_check(reader, writer) + if not passed: + await robot_shell(reader, writer) + if not writer.is_closing(): + writer.close() + return + + # Run actual shell + await inner_shell(reader, writer) + finally: + if counter: + counter.release() + except (ConnectionResetError, BrokenPipeError, EOFError): + logger.debug( + "Connection lost in guarded_shell: %s", + writer.get_extra_info("peername", "unknown"), + ) shell = guarded_shell @@ -905,6 +1147,7 @@ async def guarded_shell(reader, writer): shell=shell, encoding=encoding, force_binary=force_binary, + never_send_ga=never_send_ga, timeout=timeout, connect_maxwait=connect_maxwait, ) @@ -933,10 +1176,10 @@ async def guarded_shell(reader, writer): logger.info("Server stop.") -def main(): +def main() -> None: """Entry point for telnetlib3-server command.""" asyncio.run(run_server(**parse_server_args())) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/telnetlib3/server_base.py b/telnetlib3/server_base.py index 46c81d12..527ca9b0 100644 --- a/telnetlib3/server_base.py +++ b/telnetlib3/server_base.py @@ -1,13 +1,18 @@ """Module provides class BaseServer.""" +from __future__ import annotations + # std imports import sys +import types import asyncio import logging import datetime import traceback +from typing import Any, Type, Union, Callable, Optional # local +from ._types import ShellCallback from .telopt import theNULL from .stream_reader import TelnetReader, TelnetReaderUnicode from .stream_writer import TelnetWriter, TelnetWriterUnicode @@ -24,8 +29,8 @@ class BaseServer(asyncio.streams.FlowControlMixin, asyncio.Protocol): """Base Telnet Server Protocol.""" - _when_connected = None - _last_received = None + _when_connected: Optional[datetime.datetime] = None + _last_received: Optional[datetime.datetime] = None _transport = None _advanced = False _closing = False @@ -35,24 +40,26 @@ class BaseServer(asyncio.streams.FlowControlMixin, asyncio.Protocol): def __init__( # pylint: disable=too-many-positional-arguments self, - shell=None, - _waiter_connected=None, - encoding="utf8", - encoding_errors="strict", - force_binary=False, - connect_maxwait=4.0, - limit=None, - reader_factory=TelnetReader, - reader_factory_encoding=TelnetReaderUnicode, - writer_factory=TelnetWriter, - writer_factory_encoding=TelnetWriterUnicode, - ): + shell: Optional[ShellCallback] = None, + _waiter_connected: Optional[asyncio.Future[None]] = None, + encoding: Union[str, bool] = "utf8", + encoding_errors: str = "strict", + force_binary: bool = False, + never_send_ga: bool = False, + connect_maxwait: float = 4.0, + limit: Optional[int] = None, + reader_factory: type = TelnetReader, + reader_factory_encoding: type = TelnetReaderUnicode, + writer_factory: type = TelnetWriter, + writer_factory_encoding: type = TelnetWriterUnicode, + ) -> None: """Class initializer.""" super().__init__() self.default_encoding = encoding self._encoding_errors = encoding_errors self.force_binary = force_binary - self._extra = {} + self.never_send_ga = never_send_ga + self._extra: dict[str, Any] = {} self._reader_factory = reader_factory self._reader_factory_encoding = reader_factory_encoding @@ -61,22 +68,24 @@ def __init__( # pylint: disable=too-many-positional-arguments #: a future used for testing self._waiter_connected = _waiter_connected or asyncio.Future() - self._tasks = [self._waiter_connected] + self._tasks: list[Any] = [self._waiter_connected] self.shell = shell - self.reader = None - self.writer = None + self.reader: Optional[Union[TelnetReader, TelnetReaderUnicode]] = None + self.writer: Optional[Union[TelnetWriter, TelnetWriterUnicode]] = None #: maximum duration for :meth:`check_negotiation`. self.connect_maxwait = connect_maxwait self._limit = limit - def timeout_connection(self): + def timeout_connection(self) -> None: """Close the connection due to timeout.""" + assert self.reader is not None + assert self.writer is not None self.reader.feed_eof() self.writer.close() # Base protocol methods - def eof_received(self): + def eof_received(self) -> None: """ Called when the other end calls write_eof() or equivalent. @@ -85,15 +94,16 @@ def eof_received(self): logger.debug("EOF from client, closing.") self.connection_lost(None) - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[Exception]) -> None: """ Called when the connection is lost or closed. - :param Exception exc: exception. ``None`` indicates close by EOF. + :param exc: Exception instance, or ``None`` to indicate close by EOF. """ if self._closing: return self._closing = True + assert self.reader is not None # inform yielding readers about closed connection if exc is None: @@ -132,7 +142,7 @@ def connection_lost(self, exc): # for inspection by tests after close. self._transport = None - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Called when a connection is made. @@ -147,8 +157,8 @@ def connection_made(self, transport): reader_factory = self._reader_factory writer_factory = self._writer_factory - reader_kwds = {} - writer_kwds = {} + reader_kwds: dict[str, Any] = {} + writer_kwds: dict[str, Any] = {} if self.default_encoding: reader_kwds["fn_encoding"] = self.encoding @@ -176,18 +186,19 @@ def connection_made(self, transport): self._waiter_connected.add_done_callback(self.begin_shell) asyncio.get_event_loop().call_soon(self.begin_negotiation) - def begin_shell(self, future): + def begin_shell(self, future: asyncio.Future[None]) -> None: """Start the shell coroutine after negotiation completes.""" # Don't start shell if the connection was cancelled or errored if future.cancelled() or future.exception() is not None: return if self.shell is not None: + assert self.reader is not None and self.writer is not None coro = self.shell(self.reader, self.writer) if asyncio.iscoroutine(coro): loop = asyncio.get_event_loop() loop.create_task(coro) - def data_received(self, data): + def data_received(self, data: bytes) -> None: """ Process bytes received by transport. @@ -206,6 +217,8 @@ def data_received(self, data): # self._last_received = datetime.datetime.now() self._rx_bytes += len(data) + assert self.writer is not None + assert self.reader is not None writer = self.writer reader = self.reader @@ -220,7 +233,7 @@ def data_received(self, data): n = len(data) i = 0 out_start = 0 - feeding_oob = False + feeding_oob = bool(writer.is_oob) while i < n: if not feeding_oob: @@ -267,38 +280,40 @@ def data_received(self, data): # public properties @property - def duration(self): + def duration(self) -> float: """Time elapsed since client connected, in seconds as float.""" + assert self._when_connected is not None return (datetime.datetime.now() - self._when_connected).total_seconds() @property - def idle(self): + def idle(self) -> float: """Time elapsed since data last received, in seconds as float.""" + assert self._last_received is not None return (datetime.datetime.now() - self._last_received).total_seconds() @property - def rx_bytes(self): + def rx_bytes(self) -> int: """Total bytes received from client.""" return self._rx_bytes @property - def tx_bytes(self): + def tx_bytes(self) -> int: """Total bytes sent to client.""" return self._tx_bytes # public protocol methods - def __repr__(self): + def __repr__(self) -> str: hostport = self.get_extra_info("peername", ["-", "closing"])[:2] return f"" - def get_extra_info(self, name, default=None): + def get_extra_info(self, name: str, default: Any = None) -> Any: """Get optional server protocol or transport information.""" if self._transport: default = self._transport.get_extra_info(name, default) return self._extra.get(name, default) - def begin_negotiation(self): + def begin_negotiation(self) -> None: """ Begin on-connect negotiation. @@ -309,7 +324,7 @@ def begin_negotiation(self): self._check_later = asyncio.get_event_loop().call_soon(self._check_negotiation_timer) self._tasks.append(self._check_later) - def begin_advanced_negotiation(self): + def begin_advanced_negotiation(self) -> None: """ Begin advanced negotiation. @@ -322,7 +337,7 @@ def begin_advanced_negotiation(self): at least one negotiation option to be affirmatively acknowledged. """ - def encoding(self, outgoing=False, incoming=False): + def encoding(self, outgoing: bool = False, incoming: bool = False) -> Union[str, bool]: """ Encoding that should be used for the direction indicated. @@ -332,12 +347,11 @@ def encoding(self, outgoing=False, incoming=False): # pylint: disable=unused-argument return self.default_encoding or "US-ASCII" - def negotiation_should_advance(self): + def negotiation_should_advance(self) -> bool: """ Whether advanced negotiation should commence. - :rtype: bool - :returns: True if advanced negotiation should be permitted. + :returns: ``True`` if advanced negotiation should be permitted. The base implementation returns True if any negotiation options were affirmatively acknowledged by client, more than likely @@ -345,18 +359,18 @@ def negotiation_should_advance(self): """ # Generally, this separates a bare TCP connect() from a True # RFC-compliant telnet client with responding IAC interpreter. + assert self.writer is not None server_do = sum(enabled for _, enabled in self.writer.remote_option.items()) client_will = sum(enabled for _, enabled in self.writer.local_option.items()) return bool(server_do or client_will) - def check_negotiation(self, final=False): # pylint: disable=unused-argument + def check_negotiation(self, final: bool = False) -> bool: # pylint: disable=unused-argument """ Callback, return whether negotiation is complete. - :param bool final: Whether this is the final time this callback + :param final: Whether this is the final time this callback will be requested to answer regarding protocol negotiation. :returns: Whether negotiation is over (server end is satisfied). - :rtype: bool Method is called on each new command byte processed until negotiation is considered final, or after ``connect_maxwait`` has elapsed, setting @@ -372,13 +386,16 @@ def check_negotiation(self, final=False): # pylint: disable=unused-argument # negotiation is complete (returns True) when all negotiation options # that have been requested have been acknowledged. + assert self.writer is not None return not any(self.writer.pending_option.values()) # private methods - def _check_negotiation_timer(self): - self._check_later.cancel() - self._tasks.remove(self._check_later) + def _check_negotiation_timer(self) -> None: + if self._check_later is not None: + self._check_later.cancel() + if self._check_later in self._tasks: + self._tasks.remove(self._check_later) later = self.connect_maxwait - self.duration final = bool(later < 0) @@ -397,7 +414,12 @@ def _check_negotiation_timer(self): self._tasks.append(self._check_later) @staticmethod - def _log_exception(log, e_type, e_value, e_tb): + def _log_exception( + log: Callable[..., Any], + e_type: Optional[Type[BaseException]], + e_value: Optional[BaseException], + e_tb: Optional[types.TracebackType], + ) -> None: rows_tbk = [line for line in "\n".join(traceback.format_tb(e_tb)).split("\n") if line] rows_exc = [line.rstrip() for line in traceback.format_exception_only(e_type, e_value)] diff --git a/telnetlib3/server_pty_shell.py b/telnetlib3/server_pty_shell.py index 1a0821d1..c480a615 100644 --- a/telnetlib3/server_pty_shell.py +++ b/telnetlib3/server_pty_shell.py @@ -5,28 +5,37 @@ each telnet connection, with proper terminal negotiation forwarding. """ +from __future__ import annotations + # std imports import os -import pty import sys import time -import errno -import fcntl import codecs -import signal import struct import asyncio import logging -import termios +from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Awaitable, cast # local -from .telopt import NAWS +from .telopt import ECHO, NAWS, WONT +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode __all__ = ("make_pty_shell", "pty_shell", "PTYSpawnError") # Delay between termination signals (seconds) _TERMINATE_DELAY = 0.1 +# Debounce delay for NAWS updates (seconds) +_NAWS_DEBOUNCE = 0.2 + +# Idle delay before sending IAC GA (seconds) +_GA_IDLE = 0.5 + +# Polling interval for _wait_for_terminal_info (seconds) +_TERMINAL_INFO_POLL = 0.05 + class PTYSpawnError(Exception): """Raised when PTY child process fails to exec.""" @@ -40,7 +49,7 @@ class PTYSpawnError(Exception): _ESU = b"\x1b[?2026l" # End Synchronized Update -def _platform_check(): +def _platform_check() -> None: """Verify platform supports PTY operations.""" if sys.platform == "win32": raise NotImplementedError("PTY support is not available on Windows") @@ -49,7 +58,16 @@ def _platform_check(): class PTYSession: """Manages a PTY session lifecycle.""" - def __init__(self, reader, writer, program, args, *, preexec_fn=None): + def __init__( + self, + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + program: str, + args: Optional[List[str]], + *, + preexec_fn: Optional[Callable[[], None]] = None, + raw_mode: bool = False, + ) -> None: """ Initialize PTY session. @@ -60,28 +78,36 @@ def __init__(self, reader, writer, program, args, *, preexec_fn=None): :param preexec_fn: Optional callable to run in child before exec. Called with no arguments after fork but before _setup_child. Useful for test coverage tracking in the forked child process. + :param raw_mode: If True, disable PTY echo and canonical mode. Use for programs that handle + their own terminal I/O (e.g., blessed, curses, ucs-detect). """ self.reader = reader self.writer = writer self.program = program self.args = args or [] self.preexec_fn = preexec_fn - self.master_fd = None - self.child_pid = None + self.raw_mode = raw_mode + self.master_fd: Optional[int] = None + self.child_pid: Optional[int] = None self._closing = False self._output_buffer = b"" self._in_sync_update = False - self._decoder = None - self._decoder_charset = None - self._naws_pending = None - self._naws_timer = None + self._decoder: Optional[codecs.IncrementalDecoder] = None + self._decoder_charset: Optional[str] = None + self._naws_pending: Optional[Tuple[int, int]] = None + self._naws_timer: Optional[asyncio.TimerHandle] = None + self._ga_timer: Optional[asyncio.TimerHandle] = None - def start(self): + def start(self) -> None: """ Fork PTY, configure environment, and exec program. :raises PTYSpawnError: If the child process fails to exec. """ + # std imports + import pty + import fcntl + _platform_check() env = self._build_environment() @@ -128,14 +154,14 @@ def start(self): if pid: logger.warning("child already exited: status=%d", status) - def _write_exec_error(self, pipe_fd, exc): + def _write_exec_error(self, pipe_fd: int, exc: Exception) -> None: """Write exception info to pipe for parent to read.""" ename = type(exc).__name__ msg = f"{ename}:{getattr(exc, 'errno', 0)}:{exc}" os.write(pipe_fd, msg.encode("utf-8", errors="replace")) os.close(pipe_fd) - def _handle_exec_error(self, data): + def _handle_exec_error(self, data: bytes) -> None: """Parse exec error from child and raise appropriate exception.""" try: parts = data.decode("utf-8", errors="replace").split(":", 2) @@ -148,7 +174,7 @@ def _handle_exec_error(self, data): except Exception as exc: raise PTYSpawnError(f"Exec failed: {data!r}") from exc - def _build_environment(self): + def _build_environment(self) -> Dict[str, str]: """Build environment dict from negotiated values.""" env = os.environ.copy() @@ -180,27 +206,47 @@ def _build_environment(self): return env - def _get_window_size(self): + def _get_window_size(self) -> Tuple[int, int]: """Get window size from negotiated values.""" - rows = self.writer.get_extra_info("rows", 25) - cols = self.writer.get_extra_info("cols", 80) + rows: int = self.writer.get_extra_info("rows", 25) + cols: int = self.writer.get_extra_info("cols", 80) return rows, cols - def _setup_child(self, env, rows, cols, exec_err_pipe, *, child_cov=None): + def _setup_child( + self, + env: Dict[str, str], + rows: int, + cols: int, + exec_err_pipe: int, + *, + child_cov: Any = None, + ) -> None: """Child process setup before exec.""" # Note: pty.fork() already calls setsid() for the child, so we don't need to + # std imports + import fcntl + import termios if rows and cols: winsize = struct.pack("HHHH", rows, cols, 0, 0) fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize) - # Configure PTY for telnet's character-at-a-time mode (WILL SGA, WILL ECHO). - # Disable local echo and canonical mode, but keep output processing so - # newlines are translated to CR-LF properly. attrs = termios.tcgetattr(sys.stdin.fileno()) - # c_lflag: disable ECHO (telnet handles echo) and ICANON (char-at-a-time) - attrs[3] &= ~(termios.ECHO | termios.ICANON) - # Keep c_oflag intact - OPOST and ONLCR translate \n to \r\n + + if self.raw_mode: + # Raw mode: disable echo and canonical mode for programs that handle + # their own terminal I/O (blessed, curses, ucs-detect). This prevents + # terminal responses from being echoed back through the PTY. + attrs[3] &= ~(termios.ECHO | termios.ICANON) + else: + # Normal mode: Keep ECHO and ICANON enabled for proper input() + # behavior. We sent WONT ECHO to the client, so the PTY handles echo + # with proper output translation (ONLCR: \n → \r\n). + pass + + # Set VERASE to ^H (0x08) since many telnet clients send ^H for backspace + # (default PTY ERASE is often ^? which won't work for those clients). + attrs[6][termios.VERASE] = 8 # ^H termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, attrs) # Save coverage data before exec replaces the process @@ -215,26 +261,30 @@ def _setup_child(self, env, rows, cols, exec_err_pipe, *, child_cov=None): self._write_exec_error(exec_err_pipe, err) os._exit(os.EX_OSERR) - def _setup_parent(self): + def _setup_parent(self) -> None: """Parent process setup after fork.""" + # std imports + import fcntl + + assert self.master_fd is not None flags = fcntl.fcntl(self.master_fd, fcntl.F_GETFL) fcntl.fcntl(self.master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) self.writer.set_ext_callback(NAWS, self._on_naws) - def _on_naws(self, rows, cols): + def _on_naws(self, rows: int, cols: int) -> None: """Handle NAWS updates by resizing PTY with debouncing.""" self.writer.protocol.on_naws(rows, cols) self._schedule_naws_update(rows, cols) - def _schedule_naws_update(self, rows, cols): + def _schedule_naws_update(self, rows: int, cols: int) -> None: """Schedule debounced NAWS update to avoid signal storms during rapid resize.""" self._naws_pending = (rows, cols) if self._naws_timer is not None: self._naws_timer.cancel() loop = asyncio.get_event_loop() - self._naws_timer = loop.call_later(0.2, self._fire_naws_update) + self._naws_timer = loop.call_later(_NAWS_DEBOUNCE, self._fire_naws_update) - def _fire_naws_update(self): + def _fire_naws_update(self) -> None: """Fire the pending NAWS update after debounce delay.""" if self._naws_pending is not None: rows, cols = self._naws_pending @@ -242,8 +292,13 @@ def _fire_naws_update(self): self._naws_timer = None self._set_window_size(rows, cols) - def _set_window_size(self, rows, cols): + def _set_window_size(self, rows: int, cols: int) -> None: """Set PTY window size and send SIGWINCH to child.""" + # std imports + import fcntl + import signal + import termios + if self.master_fd is None or self.child_pid is None: return winsize = struct.pack("HHHH", rows, cols, 0, 0) @@ -253,26 +308,33 @@ def _set_window_size(self, rows, cols): except ProcessLookupError: pass - async def run(self): + async def run(self) -> None: """Bridge loop between telnet and PTY.""" + # std imports + import errno + loop = asyncio.get_event_loop() pty_read_event = asyncio.Event() - pty_data_queue = asyncio.Queue() + pty_data_queue: asyncio.Queue[bytes] = asyncio.Queue() + assert self.child_pid is not None + assert self.master_fd is not None pid, _ = os.waitpid(self.child_pid, os.WNOHANG) if pid: return - def pty_readable(): + master_fd = self.master_fd + + def pty_readable() -> None: """Callback when PTY has data to read.""" # Drain available data to reduce tearing, but cap at 256KB to avoid # buffering forever on continuous output (e.g., cat large_file) - chunks = [] + chunks: list[bytes] = [] total = 0 max_batch = 262144 # 256KB while total < max_batch: try: - data = os.read(self.master_fd, 65536) + data = os.read(master_fd, 65536) if data: chunks.append(data) total += len(data) @@ -292,21 +354,29 @@ def pty_readable(): pty_data_queue.put_nowait(b"".join(chunks)) pty_read_event.set() - loop.add_reader(self.master_fd, pty_readable) + loop.add_reader(master_fd, pty_readable) try: await self._bridge_loop(pty_read_event, pty_data_queue) finally: try: - loop.remove_reader(self.master_fd) + loop.remove_reader(master_fd) except (ValueError, KeyError): pass - async def _bridge_loop(self, pty_read_event, pty_data_queue): + async def _bridge_loop( + self, + pty_read_event: asyncio.Event, + pty_data_queue: asyncio.Queue[bytes], + ) -> None: """Main bridge loop transferring data between telnet and PTY.""" while not self._closing and not self.writer.is_closing(): - telnet_task = asyncio.create_task(self.reader.read(4096)) - pty_task = asyncio.create_task(pty_read_event.wait()) + telnet_task: asyncio.Task[Union[bytes, str]] = asyncio.create_task( + self.reader.read(4096) + ) + pty_task: asyncio.Task[bool] = asyncio.create_task( + pty_read_event.wait() + ) done, pending = await asyncio.wait( {telnet_task, pty_task}, @@ -320,43 +390,51 @@ async def _bridge_loop(self, pty_read_event, pty_data_queue): except asyncio.CancelledError: pass - for task in done: - try: - if task is telnet_task: - data = task.result() - if data: - self._write_to_pty(data) - else: - self._closing = True - break - - elif task is pty_task: - task.result() - while not pty_data_queue.empty(): - data = pty_data_queue.get_nowait() - self._write_to_telnet(data) - # EAGAIN was hit - flush any remaining partial line - self._flush_remaining() - pty_read_event.clear() - except Exception as e: # pylint: disable=broad-exception-caught - logger.debug("bridge loop error: %s", e) - self._closing = True - break + try: + if telnet_task in done: + telnet_data = telnet_task.result() + if telnet_data: + self._write_to_pty(telnet_data) + else: + self._closing = True + continue + + if pty_task in done: + pty_task.result() + while not pty_data_queue.empty(): + pty_data = pty_data_queue.get_nowait() + self._write_to_telnet(pty_data) + # EAGAIN was hit - flush any remaining partial line + self._flush_remaining() + pty_read_event.clear() + except Exception as e: # pylint: disable=broad-exception-caught + logger.debug("bridge loop error: %s", e) + self._closing = True + break + + def _write_to_pty(self, data: Union[str, bytes]) -> None: + """ + Write data from telnet to PTY. - def _write_to_pty(self, data): - """Write data from telnet to PTY.""" + Translates DEL (0x7F) to ``^H`` (0x08) so that both backspace encodings work with the PTY's + VERASE setting (``^H``). + """ if self.master_fd is None: return if isinstance(data, str): charset = self.writer.get_extra_info("charset") or "utf-8" data = data.encode(charset, errors="replace") + data = data.replace(b"\x7f", b"\x08") try: os.write(self.master_fd, data) except OSError: self._closing = True - def _write_to_telnet(self, data): + def _write_to_telnet(self, data: bytes) -> None: """Write data from PTY to telnet, respecting synchronized update boundaries.""" + if self._ga_timer is not None: + self._ga_timer.cancel() + self._ga_timer = None self._output_buffer += data # Process buffer, flushing on ESU or newline boundaries @@ -396,7 +474,7 @@ def _write_to_telnet(self, data): # next sync boundary, or when more data arrives with EAGAIN) break - def _flush_output(self, data, final=False): + def _flush_output(self, data: bytes, final: bool = False) -> None: """Send data to telnet client using incremental decoder.""" if not data: return @@ -410,15 +488,34 @@ def _flush_output(self, data, final=False): # Decode using incremental decoder - it buffers incomplete sequences text = self._decoder.decode(data, final) if text: - self.writer.write(text) + cast(TelnetWriterUnicode, self.writer).write(text) - def _flush_remaining(self): + def _flush_remaining(self) -> None: """Flush remaining buffer after EAGAIN (partial lines, prompts, etc.).""" if self._output_buffer and not self._in_sync_update: self._flush_output(self._output_buffer) self._output_buffer = b"" + self._schedule_ga() + + def _schedule_ga(self) -> None: + """Schedule IAC GA after 500ms idle, for clients that refuse SGA.""" + if self._ga_timer is not None: + self._ga_timer.cancel() + self._ga_timer = None + if self.raw_mode: + return + if getattr(self.writer.protocol, "never_send_ga", False): + return + loop = asyncio.get_event_loop() + self._ga_timer = loop.call_later(_GA_IDLE, self._fire_ga) - def _isalive(self): + def _fire_ga(self) -> None: + """Send IAC GA if writer is still open.""" + self._ga_timer = None + if not self.writer.is_closing(): + self.writer.send_ga() + + def _isalive(self) -> bool: """Check if child process is still running.""" if self.child_pid is None: return False @@ -428,7 +525,7 @@ def _isalive(self): except ChildProcessError: return False - def _terminate(self, force=False): + def _terminate(self, force: bool = False) -> bool: """ Terminate child with signal escalation (ptyprocess pattern). @@ -437,9 +534,13 @@ def _terminate(self, force=False): :param force: If True, use SIGKILL as last resort. :returns: True if child was terminated, False otherwise. """ + # std imports + import signal + if not self._isalive(): return True + assert self.child_pid is not None signals = [signal.SIGHUP, signal.SIGCONT, signal.SIGINT] if force: signals.append(signal.SIGKILL) @@ -455,9 +556,12 @@ def _terminate(self, force=False): return not self._isalive() - def cleanup(self): + def cleanup(self) -> None: """Kill child process and close PTY fd.""" - # Cancel any pending NAWS timer + # Cancel any pending timers + if self._ga_timer is not None: + self._ga_timer.cancel() + self._ga_timer = None if self._naws_timer is not None: self._naws_timer.cancel() self._naws_timer = None @@ -484,7 +588,10 @@ def cleanup(self): self.child_pid = None -async def _wait_for_terminal_info(writer, timeout=2.0): +async def _wait_for_terminal_info( + writer: Union[TelnetWriter, TelnetWriterUnicode], + timeout: float = 2.0, +) -> None: """ Wait for TERM and window size to be negotiated. @@ -499,24 +606,43 @@ async def _wait_for_terminal_info(writer, timeout=2.0): rows = writer.get_extra_info("rows") if term and rows: return - await asyncio.sleep(0.1) + await asyncio.sleep(_TERMINAL_INFO_POLL) -async def pty_shell(reader, writer, program, args=None, preexec_fn=None): +async def pty_shell( # pylint: disable=too-many-positional-arguments + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + program: str, + args: Optional[List[str]] = None, + preexec_fn: Optional[Callable[[], None]] = None, + raw_mode: bool = False, +) -> None: """ PTY shell callback for telnet server. - :param TelnetReader reader: TelnetReader instance. - :param TelnetWriter writer: TelnetWriter instance. - :param str program: Path to program to execute. - :param list args: List of arguments for the program. + :param reader: TelnetReader instance. + :param writer: TelnetWriter instance. + :param program: Path to program to execute. + :param args: List of arguments for the program. :param preexec_fn: Optional callable to run in child before exec. + :param raw_mode: If True, disable PTY echo and canonical mode. Use for programs that handle + their own terminal I/O (e.g., blessed, curses, ucs-detect). """ _platform_check() await _wait_for_terminal_info(writer, timeout=2.0) - session = PTYSession(reader, writer, program, args, preexec_fn=preexec_fn) + # Echo handling depends on raw_mode: + # - Normal mode: Send WONT ECHO so client does local echo, PTY handles + # echo with proper ONLCR translation (\n → \r\n) for input() display. + # - Raw mode: Keep WILL ECHO so client doesn't local-echo, but PTY echo + # is disabled. This prevents terminal responses (CPR, etc.) from being + # echoed back. The program handles its own output. + if not raw_mode and writer.will_echo: + writer.iac(WONT, ECHO) + await writer.drain() + + session = PTYSession(reader, writer, program, args, preexec_fn=preexec_fn, raw_mode=raw_mode) try: session.start() await session.run() @@ -526,14 +652,27 @@ async def pty_shell(reader, writer, program, args=None, preexec_fn=None): writer.close() -def make_pty_shell(program, args=None, preexec_fn=None): +def make_pty_shell( + program: str, + args: Optional[List[str]] = None, + preexec_fn: Optional[Callable[[], None]] = None, + raw_mode: bool = False, +) -> Callable[ + [ + Union[TelnetReader, TelnetReaderUnicode], + Union[TelnetWriter, TelnetWriterUnicode], + ], + Awaitable[None], +]: """ Factory returning a shell callback for PTY execution. - :param str program: Path to program to execute. - :param list args: List of arguments for the program. + :param program: Path to program to execute. + :param args: List of arguments for the program. :param preexec_fn: Optional callable to run in child before exec. Useful for test coverage tracking in the forked child process. + :param raw_mode: If True, disable PTY echo and canonical mode. Use for programs + that handle their own terminal I/O (e.g., blessed, curses, ucs-detect). :returns: Async shell callback suitable for use with create_server(). Example usage:: @@ -547,7 +686,10 @@ def make_pty_shell(program, args=None, preexec_fn=None): ) """ - async def shell(reader, writer): - await pty_shell(reader, writer, program, args, preexec_fn=preexec_fn) + async def shell( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], + ) -> None: + await pty_shell(reader, writer, program, args, preexec_fn=preexec_fn, raw_mode=raw_mode) return shell diff --git a/telnetlib3/server_shell.py b/telnetlib3/server_shell.py index 34a645de..1bb1944d 100644 --- a/telnetlib3/server_shell.py +++ b/telnetlib3/server_shell.py @@ -1,17 +1,25 @@ """Telnet server shell implementations.""" +from __future__ import annotations + # std imports import types import asyncio +from typing import Union, Optional, Generator, cast # local from . import slc, telopt, accessories +from .stream_reader import TelnetReader, TelnetReaderUnicode +from .stream_writer import TelnetWriter, TelnetWriterUnicode CR, LF, NUL = ("\r", "\n", "\x00") ESC = "\x1b" -async def filter_ansi(reader, _writer): +async def filter_ansi( + reader: TelnetReaderUnicode, + _writer: TelnetWriterUnicode, +) -> str: """ Read and return the next non-ANSI-escape character from reader. @@ -38,17 +46,20 @@ async def filter_ansi(reader, _writer): __all__ = ("telnet_server_shell",) -async def telnet_server_shell( - reader, writer -): # pylint: disable=too-complex,too-many-branches,too-many-statements +async def telnet_server_shell( # pylint: disable=too-complex,too-many-branches,too-many-statements + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> None: """ A default telnet shell, appropriate for use with telnetlib3.create_server. This shell provides a very simple REPL, allowing introspection and state toggling of the connected client session. """ - linereader = readline(reader, writer) - linereader.send(None) + _reader = cast(TelnetReaderUnicode, reader) + writer = cast(TelnetWriterUnicode, writer) + linereader = readline(_reader, writer) + next(linereader) writer.write("Ready." + CR + LF) @@ -57,12 +68,14 @@ async def telnet_server_shell( if command: writer.write(CR + LF) writer.write("tel:sh> ") + if not getattr(writer.protocol, "never_send_ga", False): + writer.send_ga() await writer.drain() command = None while command is None: await writer.drain() - inp = await reader.read(1) + inp = await _reader.read(1) if not inp: # close/eof by client at prompt return @@ -117,9 +130,8 @@ async def telnet_server_shell( do_close = command.split()[4].lower() == "close" except IndexError: do_close = False - writer.write( - f"kb_limit={kb_limit}, delay={delay}, drain={drain}, do_close={do_close}:\r\n" - ) + msg = f"kb_limit={kb_limit}, delay={delay}, drain={drain}, do_close={do_close}:\r\n" + writer.write(msg) for lineout in character_dump(kb_limit): if writer.is_closing(): break @@ -138,7 +150,7 @@ async def telnet_server_shell( writer.close() -def character_dump(kb_limit): +def character_dump(kb_limit: int) -> Generator[str, None, None]: """Generate character dump output up to kb_limit kilobytes.""" num_bytes = 0 while (num_bytes) < (kb_limit * 1024): @@ -149,11 +161,15 @@ def character_dump(kb_limit): yield "\033[1G" + "wrote " + str(num_bytes) + " bytes" -async def get_next_ascii(reader, writer): +async def get_next_ascii( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Optional[str]: """Accept the next non-ANSI-escape character from reader.""" + _reader = cast(TelnetReaderUnicode, reader) escape_sequence = False while not writer.is_closing(): - next_char = await reader.read(1) + next_char = await _reader.read(1) if next_char == "\x1b": escape_sequence = True elif escape_sequence: @@ -165,7 +181,10 @@ async def get_next_ascii(reader, writer): @types.coroutine -def readline(_reader, writer): +def readline( + _reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Generator[Optional[str], str, None]: """ A very crude readline coroutine interface. @@ -173,6 +192,7 @@ def readline(_reader, writer): designed for Python 3.4 and remains here for compatibility, superseded by :func:`~.readline2` """ + _writer = cast(TelnetWriterUnicode, writer) command, inp, last_inp = "", "", "" inp = yield None while True: @@ -190,19 +210,22 @@ def readline(_reader, writer): # backspace over input if command: command = command[:-1] - writer.echo("\b \b") + _writer.echo("\b \b") last_inp = inp inp = yield None else: # buffer and echo input command += inp - writer.echo(inp) + _writer.echo(inp) last_inp = inp inp = yield None -async def readline2(reader, writer): +async def readline2( + reader: Union[TelnetReader, TelnetReaderUnicode], + writer: Union[TelnetWriter, TelnetWriterUnicode], +) -> Optional[str]: """ Async readline interface that filters ANSI escape sequences. @@ -212,9 +235,11 @@ async def readline2(reader, writer): However, this function does not handle all possible types of carriage returns and so it is not used by default shell, :func:`telnet_server_shell`. """ + _reader = cast(TelnetReaderUnicode, reader) + _writer = cast(TelnetWriterUnicode, writer) command = "" while True: - next_char = await filter_ansi(reader, writer) + next_char = await filter_ansi(_reader, _writer) if next_char == CR: return command @@ -226,17 +251,17 @@ async def readline2(reader, writer): # backspace over input if len(command) > 0: command = command[:-1] - writer.echo("\b \b") + _writer.echo("\b \b") elif not next_char: return None else: command += next_char - writer.echo(next_char) + _writer.echo(next_char) -def get_slcdata(writer): +def get_slcdata(writer: Union[TelnetWriter, TelnetWriterUnicode]) -> str: """Display Special Line Editing (SLC) characters.""" _slcs = sorted( [ @@ -270,7 +295,10 @@ def get_slcdata(writer): ) -def do_toggle(writer, option): +def do_toggle( + writer: Union[TelnetWriter, TelnetWriterUnicode], + option: Optional[str], +) -> str: """Display or toggle telnet session parameters.""" tbl_opt = { "echo": writer.local_option.enabled(telopt.ECHO), diff --git a/telnetlib3/slc.py b/telnetlib3/slc.py index 3432c8cd..453bccea 100644 --- a/telnetlib3/slc.py +++ b/telnetlib3/slc.py @@ -1,5 +1,10 @@ """Special Line Character support for Telnet Linemode Option (:rfc:`1184`).""" +from __future__ import annotations + +# std imports +from typing import Any, Dict, List, Tuple, Union, Callable, Optional + # local from .telopt import theNULL from .accessories import eightbits, name_unicode @@ -37,6 +42,7 @@ "SLC_VARIABLE", "SLC_XON", "snoop", + "theNULL", ) SLC_NOSUPPORT, SLC_CANTCHANGE, SLC_VARIABLE, SLC_DEFAULT = ( @@ -91,7 +97,7 @@ class SLC: """Defines the willingness to support a Special Linemode Character.""" - def __init__(self, mask=SLC_DEFAULT, value=theNULL): + def __init__(self, mask: bytes = SLC_DEFAULT, value: bytes = theNULL) -> None: """ Initialize SLC with the given mask and value. @@ -107,61 +113,61 @@ def __init__(self, mask=SLC_DEFAULT, value=theNULL): self.val = value @property - def level(self): + def level(self) -> bytes: """Returns SLC level of support.""" return bytes([ord(self.mask) & SLC_LEVELBITS]) @property - def nosupport(self): + def nosupport(self) -> bool: """Returns True if SLC level is SLC_NOSUPPORT.""" return self.level == SLC_NOSUPPORT @property - def cantchange(self): + def cantchange(self) -> bool: """Returns True if SLC level is SLC_CANTCHANGE.""" return self.level == SLC_CANTCHANGE @property - def variable(self): + def variable(self) -> bool: """Returns True if SLC level is SLC_VARIABLE.""" return self.level == SLC_VARIABLE @property - def default(self): + def default(self) -> bool: """Returns True if SLC level is SLC_DEFAULT.""" return self.level == SLC_DEFAULT @property - def ack(self): + def ack(self) -> int: """Returns True if SLC_ACK bit is set.""" return ord(self.mask) & ord(SLC_ACK) @property - def flushin(self): + def flushin(self) -> int: """Returns True if SLC_FLUSHIN bit is set.""" return ord(self.mask) & ord(SLC_FLUSHIN) @property - def flushout(self): - """Returns True if SLC_FLUSHIN bit is set.""" + def flushout(self) -> int: + """Returns True if SLC_FLUSHOUT bit is set.""" return ord(self.mask) & ord(SLC_FLUSHOUT) - def set_value(self, value): + def set_value(self, value: bytes) -> None: """Set SLC keyboard ascii value to ``byte``.""" assert isinstance(value, bytes) and len(value) == 1, value self.val = value - def set_mask(self, mask): + def set_mask(self, mask: bytes) -> None: """Set SLC option mask, ``mask``.""" assert isinstance(mask, bytes) and len(mask) == 1 self.mask = mask - def set_flag(self, flag): + def set_flag(self, flag: bytes) -> None: """Set SLC option flag, ``flag``.""" assert isinstance(flag, bytes) and len(flag) == 1 self.mask = bytes([ord(self.mask) | ord(flag)]) - def __str__(self): + def __str__(self) -> str: """SLC definition as string '(value, flag(|s))'.""" flags = [] for flag in ( @@ -175,14 +181,16 @@ def __str__(self): ): if getattr(self, flag): flags.append(flag) - value_str = name_unicode(self.val) if self.val != _POSIX_VDISABLE else "(DISABLED:\\xff)" + value_str = ( + name_unicode(chr(self.val[0])) if self.val != _POSIX_VDISABLE else "(DISABLED:\\xff)" + ) return f"({value_str}, {'|'.join(flags)})" class SLC_nosupport(SLC): # pylint: disable=invalid-name """SLC definition inferring our unwillingness to support the option.""" - def __init__(self): + def __init__(self) -> None: """Initialize SLC_nosupport with NOSUPPORT level and disabled value.""" SLC.__init__(self, SLC_NOSUPPORT, _POSIX_VDISABLE) @@ -222,7 +230,9 @@ def __init__(self): } -def generate_slctab(tabset=None): +def generate_slctab( + tabset: Optional[Dict[bytes, SLC]] = None, +) -> Dict[bytes, SLC]: """ Returns full 'SLC Tab' for definitions found using ``tabset``. @@ -238,7 +248,11 @@ def generate_slctab(tabset=None): return _slctab -def generate_forwardmask(binary_mode, tabset, ack=False): +def generate_forwardmask( + binary_mode: bool, + tabset: Dict[bytes, SLC], + ack: bool = False, +) -> "Forwardmask": """ Generate a Forwardmask instance. @@ -256,7 +270,7 @@ def generate_forwardmask(binary_mode, tabset, ack=False): byte = theNULL for char in range(start, last + 1): func, _, slc_def = snoop(bytes([char]), tabset, {}) - if func is not None and not slc_def.nosupport: + if func is not None and slc_def is not None and not slc_def.nosupport: # set bit for this character, it is a supported slc char byte = bytes([ord(byte) | 1]) if char != last: @@ -267,7 +281,11 @@ def generate_forwardmask(binary_mode, tabset, ack=False): return Forwardmask(b"".join(mask32), ack) -def snoop(byte, slctab, slc_callbacks): +def snoop( + byte: bytes, + slctab: Dict[bytes, SLC], + slc_callbacks: Dict[bytes, Callable[..., Any]], +) -> Tuple[Optional[Callable[..., Any]], Optional[bytes], Optional[SLC]]: """ Scan ``slctab`` for matching ``byte`` values. @@ -289,7 +307,7 @@ class Linemode: that editing is performed on the remote side. """ - def __init__(self, mask=b"\x00"): + def __init__(self, mask: bytes = b"\x00") -> None: """ Initialize Linemode with the given mask. @@ -299,48 +317,50 @@ def __init__(self, mask=b"\x00"): assert isinstance(mask, bytes) and len(mask) == 1, (repr(mask), mask) self.mask = mask - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Compare by another Linemode (LMODE_MODE_ACK ignored).""" # the inverse OR(|) of acknowledge bit UNSET in comparator, # would be the AND OR(& ~) to compare modes without acknowledge # bit set. + if not isinstance(other, Linemode): + return NotImplemented return (ord(self.mask) | ord(LMODE_MODE_ACK)) == (ord(other.mask) | ord(LMODE_MODE_ACK)) @property - def local(self): + def local(self) -> bool: """True if linemode is local.""" return bool(ord(self.mask) & ord(LMODE_MODE_LOCAL)) @property - def remote(self): + def remote(self) -> bool: """True if linemode is remote.""" return not self.local @property - def trapsig(self): + def trapsig(self) -> bool: """True if signals are trapped by client.""" return bool(ord(self.mask) & ord(LMODE_MODE_TRAPSIG)) @property - def ack(self): + def ack(self) -> bool: """Returns True if mode has been acknowledged.""" return bool(ord(self.mask) & ord(LMODE_MODE_ACK)) @property - def soft_tab(self): + def soft_tab(self) -> bool: r"""Returns True if client will expand horizontal tab (``\x09``).""" return bool(ord(self.mask) & ord(LMODE_MODE_SOFT_TAB)) @property - def lit_echo(self): + def lit_echo(self) -> bool: """Returns True if non-printable characters are displayed as-is.""" return bool(ord(self.mask) & ord(LMODE_MODE_LIT_ECHO)) - def __str__(self): + def __str__(self) -> str: """Returns string representation of line mode, for debugging.""" return "remote" if self.remote else "local" - def __repr__(self): + def __repr__(self) -> str: props = ", ".join( f"{prop}:{getattr(self, prop)}" for prop in ("lit_echo", "soft_tab", "ack", "trapsig", "remote", "local") @@ -351,11 +371,11 @@ def __repr__(self): class Forwardmask: """Forwardmask object using the bytemask value received by server.""" - def __init__(self, value, ack=False): + def __init__(self, value: Union[bytes, bytearray], ack: bool = False) -> None: """ Initialize Forwardmask with the given value. - :param bytes value: bytemask ``value`` received by server after ``IAC SB + :param value: Bytemask ``value`` received by server after ``IAC SB LINEMODE DO FORWARDMASK``. It must be a bytearray of length 16 or 32. """ assert isinstance(value, (bytes, bytearray)), value @@ -363,19 +383,19 @@ def __init__(self, value, ack=False): self.value = value self.ack = ack - def description_table(self): + def description_table(self) -> List[str]: """Returns list of strings describing obj as a tabular ASCII map.""" - result = [] + result: List[str] = [] mrk_cont = "(...)" - def continuing(): - return len(result) and result[-1] == mrk_cont + def continuing() -> bool: + return bool(len(result) and result[-1] == mrk_cont) - def is_last(mask): + def is_last(mask: int) -> bool: return mask == len(self.value) - 1 - def same_as_last(row): - return len(result) and result[-1].endswith(row.split()[-1]) + def same_as_last(row: str) -> bool: + return bool(len(result) and result[-1].endswith(row.split()[-1])) for mask, byte in enumerate(self.value): if byte == 0: @@ -395,12 +415,12 @@ def same_as_last(row): result.append(f"[{mask:2d}] {eightbits(byte)} {characters}") return result - def __str__(self): + def __str__(self) -> str: """Returns single string of binary 0 and 1 describing obj.""" bits = "".join(value for (_, value) in [eightbits(byte).split("b") for byte in self.value]) return f"0b{bits}" - def __contains__(self, number): + def __contains__(self, number: int) -> bool: """Whether forwardmask contains keycode ``number``.""" mask, flag = number // 8, 2 ** (7 - (number % 8)) return bool(self.value[mask] & flag) @@ -446,6 +466,6 @@ def __contains__(self, number): } -def name_slc_command(byte): +def name_slc_command(byte: bytes) -> str: """Given an SLC ``byte``, return global mnemonic as string.""" return repr(byte) if byte not in _DEBUG_SLC_OPTS else _DEBUG_SLC_OPTS[byte] diff --git a/telnetlib3/stream_reader.py b/telnetlib3/stream_reader.py index 5edc31bd..bc0b9ba4 100644 --- a/telnetlib3/stream_reader.py +++ b/telnetlib3/stream_reader.py @@ -1,5 +1,7 @@ """Module provides class TelnetReader and TelnetReaderUnicode.""" +from __future__ import annotations + # std imports import re import sys @@ -7,6 +9,7 @@ import asyncio import logging import warnings +from typing import Callable, Optional from asyncio import format_helpers __all__ = ( @@ -26,7 +29,7 @@ class TelnetReader: _source_traceback = None - def __init__(self, limit=_DEFAULT_LIMIT): + def __init__(self, limit: int = _DEFAULT_LIMIT) -> None: """Initialize TelnetReader with optional buffer size limit.""" self.log = logging.getLogger(__name__) # The line length limit is a security feature; @@ -38,9 +41,9 @@ def __init__(self, limit=_DEFAULT_LIMIT): self._limit = limit self._buffer = bytearray() self._eof = False # Whether we're done. - self._waiter = None # A future used by _wait_for_data() - self._exception = None - self._transport = None + self._waiter: Optional[asyncio.Future[None]] = None + self._exception: Optional[Exception] = None + self._transport: Optional[asyncio.BaseTransport] = None self._paused = False try: loop = asyncio.get_running_loop() @@ -49,7 +52,7 @@ def __init__(self, limit=_DEFAULT_LIMIT): except RuntimeError: pass - def __repr__(self): + def __repr__(self) -> str: """Description of stream encoding state.""" info = [type(self).__name__] if self._buffer: @@ -69,11 +72,11 @@ def __repr__(self): info.append("encoding=False") return f"<{' '.join(info)}>" - def exception(self): + def exception(self) -> Optional[Exception]: """Return the exception if set, otherwise None.""" return self._exception - def set_exception(self, exc): + def set_exception(self, exc: Exception) -> None: """Set the exception and wake up any waiting coroutine.""" self._exception = exc @@ -83,7 +86,7 @@ def set_exception(self, exc): if not waiter.cancelled(): waiter.set_exception(exc) - def _wakeup_waiter(self): + def _wakeup_waiter(self) -> None: """Wakeup read*() functions waiting for data or EOF.""" waiter = self._waiter if waiter is not None: @@ -91,17 +94,18 @@ def _wakeup_waiter(self): if not waiter.cancelled(): waiter.set_result(None) - def set_transport(self, transport): + def set_transport(self, transport: asyncio.BaseTransport) -> None: """Set the transport for flow control.""" assert self._transport is None, "Transport already set" self._transport = transport - def _maybe_resume_transport(self): + def _maybe_resume_transport(self) -> None: if self._paused and len(self._buffer) <= self._limit: self._paused = False - self._transport.resume_reading() + assert self._transport is not None + self._transport.resume_reading() # type: ignore[attr-defined] - def feed_eof(self): + def feed_eof(self) -> None: """ Mark EOF on the reader and wake any pending readers. @@ -119,11 +123,11 @@ def feed_eof(self): self._eof = True self._wakeup_waiter() - def at_eof(self): + def at_eof(self) -> bool: """Return True if the buffer is empty and 'feed_eof' was called.""" return self._eof and not self._buffer - def feed_data(self, data): + def feed_data(self, data: bytes) -> None: """Feed data bytes to the reader buffer.""" assert not self._eof, "feed_data after feed_eof" @@ -135,7 +139,7 @@ def feed_data(self, data): if self._transport is not None and not self._paused and len(self._buffer) > 2 * self._limit: try: - self._transport.pause_reading() + self._transport.pause_reading() # type: ignore[attr-defined] except NotImplementedError: # The transport can't be paused. # We'll just have to buffer all data. @@ -144,7 +148,7 @@ def feed_data(self, data): else: self._paused = True - async def _wait_for_data(self, func_name): + async def _wait_for_data(self, func_name: str) -> None: """ Wait until feed_data() or feed_eof() is called. @@ -166,7 +170,8 @@ async def _wait_for_data(self, func_name): # This is essential for readexactly(n) for case when n > self._limit. if self._paused: self._paused = False - self._transport.resume_reading() + assert self._transport is not None + self._transport.resume_reading() # type: ignore[attr-defined] self._waiter = asyncio.get_running_loop().create_future() try: @@ -174,7 +179,7 @@ async def _wait_for_data(self, func_name): finally: self._waiter = None - async def readuntil(self, separator=b"\n"): + async def readuntil(self, separator: bytes = b"\n") -> bytes: """ Read data from the stream until ``separator`` is found. @@ -262,12 +267,12 @@ async def readuntil(self, separator=b"\n"): "Separator is found, but chunk is longer than limit", isep ) - chunk = self._buffer[: isep + seplen] + result = bytes(self._buffer[: isep + seplen]) del self._buffer[: isep + seplen] self._maybe_resume_transport() - return bytes(chunk) + return result - async def readuntil_pattern(self, pattern: re.Pattern) -> bytes: + async def readuntil_pattern(self, pattern: re.Pattern[bytes]) -> bytes: """ Read data from the stream until ``pattern`` is found. @@ -329,15 +334,15 @@ async def readuntil_pattern(self, pattern: re.Pattern) -> bytes: # raise an exception with the partial data. This is checked after # searching the buffer, as the last received chunk might complete the pattern. if self._eof: - chunk = bytes(self._buffer) + partial = bytes(self._buffer) self._buffer.clear() - raise asyncio.IncompleteReadError(chunk, None) + raise asyncio.IncompleteReadError(partial, None) # Wait for more data to arrive since the pattern was not found and # we are not at EOF. await self._wait_for_data("readuntil_pattern") - async def read(self, n=-1): + async def read(self, n: int = -1) -> bytes: """ Read up to `n` bytes from the stream. @@ -387,7 +392,7 @@ async def read(self, n=-1): self._maybe_resume_transport() return data - async def readexactly(self, n): + async def readexactly(self, n: int) -> bytes: """ Read exactly `n` bytes. @@ -432,10 +437,10 @@ async def readexactly(self, n): self._maybe_resume_transport() return data - def __aiter__(self): + def __aiter__(self) -> "TelnetReader": return self - async def __anext__(self): + async def __anext__(self) -> bytes: val = await self.readline() if val == b"": raise StopAsyncIteration @@ -445,7 +450,7 @@ async def __anext__(self): # instead of the commit 260dd63a that introduced a close() method on a # reader. @property - def connection_closed(self): + def connection_closed(self) -> bool: """Deprecated: use at_eof() instead.""" warnings.warn( "connection_closed property removed, use at_eof() instead", @@ -454,7 +459,7 @@ def connection_closed(self): ) return self._eof - def close(self): + def close(self) -> None: """ Deprecated: use feed_eof() instead. @@ -471,7 +476,7 @@ def close(self): ) self.feed_eof() - async def readline(self): + async def readline(self) -> bytes: r""" Read one line. @@ -567,7 +572,13 @@ class TelnetReaderUnicode(TelnetReader): #: practice, however. _decoder = None - def __init__(self, fn_encoding, *, limit=_DEFAULT_LIMIT, encoding_errors="replace"): + def __init__( + self, + fn_encoding: Callable[..., str], + *, + limit: int = _DEFAULT_LIMIT, + encoding_errors: str = "replace", + ) -> None: """ A Unicode StreamReader interface for Telnet protocol. @@ -581,7 +592,7 @@ def __init__(self, fn_encoding, *, limit=_DEFAULT_LIMIT, encoding_errors="replac self.fn_encoding = fn_encoding self.encoding_errors = encoding_errors - def decode(self, buf, final=False): + def decode(self, buf: bytes, final: bool = False) -> str: """Decode bytes ``buf`` using preferred encoding.""" if buf == b"": return "" # EOF @@ -589,14 +600,13 @@ def decode(self, buf, final=False): encoding = self.fn_encoding(incoming=True) # late-binding, - # pylint: disable=protected-access - if self._decoder is None or encoding != self._decoder._encoding: + if self._decoder is None or encoding != getattr(self._decoder, "_encoding", ""): self._decoder = codecs.getincrementaldecoder(encoding)(errors=self.encoding_errors) - self._decoder._encoding = encoding + setattr(self._decoder, "_encoding", encoding) return self._decoder.decode(buf, final) - async def readline(self): + async def readline(self) -> str: # type: ignore[override] """ Read one line. @@ -605,15 +615,14 @@ async def readline(self): buf = await super().readline() return self.decode(buf) - async def read(self, n=-1): + async def read(self, n: int = -1) -> str: # type: ignore[override] """ Read up to *n* bytes. If the EOF was received and the internal buffer is empty, return an empty string. - :param int n: If *n* is not provided, or set to -1, read until EOF and return all characters - as one large string. - :rtype: str + :param n: If *n* is not provided, or set to -1, read until EOF and return all characters as + one large string. """ if self._exception is not None: raise self._exception @@ -650,31 +659,32 @@ async def read(self, n=-1): self._maybe_resume_transport() return u_data - async def readexactly(self, n): + async def readexactly(self, n: int) -> str: # type: ignore[override] """ Read exactly *n* unicode characters. :raises asyncio.IncompleteReadError: if the end of the stream is - reached before *n* can be read. the + reached before *n* can be read. The :attr:`asyncio.IncompleteReadError.partial` attribute of the exception contains the partial read characters. - :rtype: str """ if self._exception is not None: raise self._exception - blocks = [] + blocks: list[str] = [] while n > 0: block = await self.read(n) if not block: partial = "".join(blocks) - raise asyncio.IncompleteReadError(partial, len(partial) + n) + raise asyncio.IncompleteReadError( + partial, len(partial) + n # type: ignore[arg-type] + ) blocks.append(block) n -= len(block) return "".join(blocks) - def __repr__(self): + def __repr__(self) -> str: """Description of stream encoding state.""" encoding = None if callable(self.fn_encoding): diff --git a/telnetlib3/stream_writer.py b/telnetlib3/stream_writer.py index a2e3a86a..704f9dba 100644 --- a/telnetlib3/stream_writer.py +++ b/telnetlib3/stream_writer.py @@ -1,16 +1,22 @@ """Module provides :class:`TelnetWriter` and :class:`TelnetWriterUnicode`.""" -# pylint: disable=too-many-lines -# pylint: disable=duplicate-code +from __future__ import annotations # std imports import struct import asyncio import logging import collections +from typing import TYPE_CHECKING, Any, Dict, Callable, Optional, Sequence + +# pylint: disable=too-many-lines +# pylint: disable=duplicate-code + + +if TYPE_CHECKING: # pragma: no cover + from .stream_reader import TelnetReader # local -# local imports from . import slc from .telopt import ( AO, @@ -105,7 +111,7 @@ class TelnetWriter: #: Whether the last byte received by :meth:`~.feed_byte` begins an IAC #: command sequence. - cmd_received = None + cmd_received: bytes | tuple[bytes, bytes] | bool | None = None #: Whether the last byte received by :meth:`~.feed_byte` is a matching #: special line character value, if negotiated. @@ -126,13 +132,13 @@ class TelnetWriter: def __init__( self, - transport, - protocol, + transport: asyncio.Transport, + protocol: Any, *, - client=False, - server=False, - reader=None, - ): + client: bool = False, + server: bool = False, + reader: Optional["TelnetReader"] = None, + ) -> None: """ Initialize TelnetWriter. @@ -151,9 +157,9 @@ def __init__( :meth:`~.feed_byte`, which returns True to indicate the given byte should be forwarded to a Protocol reader method. - :param bool client: Whether the IAC interpreter should react from + :param client: Whether the IAC interpreter should react from the client point of view. - :param bool server: Whether the IAC interpreter should react from + :param server: Whether the IAC interpreter should react from the server point of view. """ self._transport = transport @@ -166,7 +172,7 @@ def __init__( reader, ) self._reader = reader - self._closed_fut = None + self._closed_fut: Optional[asyncio.Future[None]] = None if not any((client, server)) or all((client, server)): raise TypeError("keyword arguments `client', and `server' are mutually exclusive.") @@ -174,7 +180,7 @@ def __init__( self.log = logging.getLogger(__name__) #: List of (predicate, future) tuples for wait_for functionality - self._waiters = [] + self._waiters: list[tuple[Callable[[], bool], asyncio.Future[bool]]] = [] #: Dictionary of telnet option byte(s) that follow an #: IAC-DO or IAC-DONT command, and contains a value of ``True`` @@ -191,11 +197,19 @@ def __init__( #: indicating state of remote capabilities. self.remote_option = Option("remote_option", self.log, on_change=self._check_waiters) + #: Set of option byte(s) for WILL received from remote end + #: that were rejected with DONT (unhandled options). + self.rejected_will: set[bytes] = set() + + #: Set of option byte(s) for DO received from remote end + #: that were rejected with WONT (unsupported options). + self.rejected_do: set[bytes] = set() + #: Sub-negotiation buffer - self._sb_buffer = collections.deque() + self._sb_buffer: collections.deque[bytes] = collections.deque() #: SLC buffer - self._slc_buffer = collections.deque() + self._slc_buffer: collections.deque[bytes] = collections.deque() #: SLC Tab (SLC Functions and their support level, and ascii value) self.slctab = slc.generate_slctab(self.default_slc_tab) @@ -209,7 +223,7 @@ def __init__( # Set default callback handlers to local methods. A base protocol # wishing not to wire any callbacks at all may simply allow our stream # to gracefully log and do nothing about in most cases. - self._iac_callback = {} + self._iac_callback: dict[bytes, Callable[..., Any]] = {} for iac_cmd, key in ( (BRK, "brk"), (IP, "ip"), @@ -228,7 +242,7 @@ def __init__( ): self.set_iac_callback(cmd=iac_cmd, func=getattr(self, f"handle_{key}")) - self._slc_callback = {} + self._slc_callback: dict[bytes, Callable[..., Any]] = {} for slc_cmd, key in ( (slc.SLC_SYNCH, "dm"), (slc.SLC_BRK, "brk"), @@ -249,7 +263,7 @@ def __init__( ): self.set_slc_callback(slc_byte=slc_cmd, func=getattr(self, f"handle_{key}")) - self._ext_callback = {} + self._ext_callback: dict[bytes, Callable[..., Any]] = {} for ext_cmd, key in ( (LOGOUT, "logout"), (SNDLOC, "sndloc"), @@ -262,7 +276,7 @@ def __init__( ): self.set_ext_callback(cmd=ext_cmd, func=getattr(self, f"handle_{key}")) - self._ext_send_callback = {} + self._ext_send_callback: dict[bytes, Callable[..., Any]] = {} for ext_cmd, key in ( (TTYPE, "ttype"), (TSPEED, "tspeed"), @@ -277,18 +291,18 @@ def __init__( self.set_ext_send_callback(cmd=ext_cmd, func=getattr(self, _cbname + key)) @property - def connection_closed(self): + def connection_closed(self) -> bool: """Return True if connection has been closed.""" return self._connection_closed # Base protocol methods @property - def transport(self): + def transport(self) -> Optional[asyncio.BaseTransport]: """Return the underlying transport.""" return self._transport - def close(self): + def close(self) -> None: """Close the connection and release resources.""" if self.connection_closed: return @@ -310,13 +324,13 @@ def close(self): self._slc_callback.clear() self._iac_callback.clear() self._protocol = None - self._transport = None + self._transport = None # type: ignore[assignment] self._connection_closed = True # Signal that the connection is closed if self._closed_fut is not None and not self._closed_fut.done(): self._closed_fut.set_result(None) - def is_closing(self): + def is_closing(self) -> bool: """Return True if the connection is closing or already closed.""" if self._transport is not None: if self._transport.is_closing(): @@ -325,7 +339,7 @@ def is_closing(self): return True return False - async def wait_closed(self): + async def wait_closed(self) -> None: """ Wait until the underlying connection has completed closing. @@ -340,26 +354,32 @@ async def wait_closed(self): self._closed_fut = asyncio.get_running_loop().create_future() await self._closed_fut - def _check_waiters(self): + def _check_waiters(self) -> None: """Check all registered waiters and resolve those whose conditions are met.""" for check, fut in self._waiters[:]: if not fut.done() and check(): fut.set_result(True) - def _cancel_waiters(self): + def _cancel_waiters(self) -> None: """Cancel all pending waiters, typically called on connection close.""" for _check, fut in self._waiters[:]: if not fut.done(): fut.cancel() self._waiters.clear() - async def wait_for(self, *, remote=None, local=None, pending=None): + async def wait_for( + self, + *, + remote: Optional[Dict[str, bool]] = None, + local: Optional[Dict[str, bool]] = None, + pending: Optional[Dict[str, bool]] = None, + ) -> bool: """ Wait for negotiation state conditions to be met. - :param dict remote: Dict of option_name -> bool for remote_option checks. - :param dict local: Dict of option_name -> bool for local_option checks. - :param dict pending: Dict of option_name -> bool for pending_option checks. + :param remote: Dict of option_name -> bool for remote_option checks. + :param local: Dict of option_name -> bool for local_option checks. + :param pending: Dict of option_name -> bool for pending_option checks. :returns: True when all conditions are met. :raises KeyError: If an option name is not recognized. :raises asyncio.CancelledError: If connection closes while waiting. @@ -383,7 +403,7 @@ async def wait_for(self, *, remote=None, local=None, pending=None): opt = option_from_name(name) conditions.append((option_dict, opt, expected)) - def check(): + def check() -> bool: for option_dict, opt, expected in conditions: if expected: if not option_dict.enabled(opt): @@ -395,15 +415,16 @@ def check(): if check(): return True - fut = asyncio.get_running_loop().create_future() + fut: asyncio.Future[bool] = asyncio.get_running_loop().create_future() self._waiters.append((check, fut)) try: - return await fut + result: bool = await fut + return result finally: self._waiters = [(c, f) for c, f in self._waiters if f is not fut] - async def wait_for_condition(self, predicate): + async def wait_for_condition(self, predicate: Callable[["TelnetWriter"], bool]) -> bool: """ Wait for a custom condition to be met. @@ -418,18 +439,19 @@ async def wait_for_condition(self, predicate): if predicate(self): return True - def check(): + def check() -> bool: return predicate(self) - fut = asyncio.get_running_loop().create_future() + fut: asyncio.Future[bool] = asyncio.get_running_loop().create_future() self._waiters.append((check, fut)) try: - return await fut + result: bool = await fut + return result finally: self._waiters = [(c, f) for c, f in self._waiters if f is not fut] - def __repr__(self): + def __repr__(self) -> str: """Description of stream encoding state.""" info = ["TelnetWriter"] if self.server: @@ -476,14 +498,14 @@ def __repr__(self): return f"<{' '.join(info)}>" - def write(self, data): + def write(self, data: bytes) -> None: """Write a bytes object to the protocol transport.""" if self.connection_closed: self.log.debug("write after close, ignored %s bytes", len(data)) return self._write(data) - def writelines(self, lines): + def writelines(self, lines: Sequence[bytes]) -> None: """ Write unicode strings to transport. @@ -492,15 +514,15 @@ def writelines(self, lines): """ self.write(b"".join(lines)) - def write_eof(self): + def write_eof(self) -> None: """Write EOF to the transport.""" return self._transport.write_eof() - def can_write_eof(self): + def can_write_eof(self) -> bool: """Return True if the transport supports write_eof().""" return self._transport.can_write_eof() - async def drain(self): + async def drain(self) -> None: """ Flush the write buffer. @@ -529,14 +551,14 @@ async def drain(self): # proprietary write helper # pylint: disable=too-many-branches,too-many-statements,too-complex - def feed_byte(self, byte): + def feed_byte(self, byte: bytes) -> bool: """ Feed a single byte into Telnet option state machine. - :param int byte: an 8-bit byte value as integer (0-255), or + :param byte: an 8-bit byte value as integer (0-255), or a bytes array. When a bytes array, it must be of length 1. - :rtype bool: Whether the given ``byte`` is "in band", that is, should + :returns: Whether the given ``byte`` is "in band", that is, should be duplicated to a connected terminal or device. ``False`` is returned for an ``IAC`` command for each byte until its completion. :raises ValueError: When an illegal IAC command is received. @@ -595,6 +617,7 @@ def feed_byte(self, byte): elif self.cmd_received: # parse 3rd and final byte of IAC DO, DONT, WILL, WONT. + assert isinstance(self.cmd_received, bytes) cmd, opt = self.cmd_received, byte self.log.debug("recv IAC %s %s", name_command(cmd), name_command(opt)) try: @@ -646,6 +669,7 @@ def feed_byte(self, byte): # Inform caller which SLC function occurred by this attribute. self.slc_received = slc_name if callback: + assert slc_name is not None self.log.debug( "slc.snoop(%r): %s, callback is %s.", byte, @@ -659,53 +683,59 @@ def feed_byte(self, byte): # Our protocol methods - def get_extra_info(self, name, default=None): + def get_extra_info(self, name: str, default: Any = None) -> Any: """Get optional server protocol information.""" # StreamWriter uses self._transport.get_extra_info, so we mix it in - # here, but _protocol has all of the interesting telnet effects - return self._protocol.get_extra_info(name, default) or self._transport.get_extra_info( - name, default - ) + # here, but _protocol has all of the interesting telnet effects. + # Handle case where protocol/transport may be None (connection closed). + _missing = object() + if self._protocol is not None: + result = self._protocol.get_extra_info(name, _missing) + if result is not _missing: + return result + if self._transport is not None: + return self._transport.get_extra_info(name, default) + return default @property - def protocol(self): + def protocol(self) -> Any: """The (Telnet) protocol attached to this stream.""" return self._protocol @property - def server(self): + def server(self) -> bool: """Whether this stream is of the server's point of view.""" return bool(self._server) @property - def client(self): + def client(self) -> bool: """Whether this stream is of the client's point of view.""" return bool(not self._server) @property - def inbinary(self): + def inbinary(self) -> bool: """Whether binary data is expected to be received on reader, :rfc:`856`.""" return self.remote_option.enabled(BINARY) @property - def outbinary(self): + def outbinary(self) -> bool: """Whether binary data may be written to the writer, :rfc:`856`.""" return self.local_option.enabled(BINARY) - def echo(self, data): + def echo(self, data: bytes) -> None: """ Conditionally write ``data`` to transport when "remote echo" enabled. - :param bytes data: string received as input, conditionally written. The default - implementation depends on telnet negotiation willingness for local echo, only an RFC- - compliant telnet client will correctly set or unset echo accordingly by demand. + :param data: bytes received as input, conditionally written. The default implementation + depends on telnet negotiation willingness for local echo, only an RFC- compliant telnet + client will correctly set or unset echo accordingly by demand. """ assert self.server, "Client never performs echo of input received." if self.will_echo: self.write(data=data) @property - def will_echo(self): + def will_echo(self) -> bool: """ Whether Server end is expected to echo back input sent by client. @@ -714,29 +744,29 @@ def will_echo(self): their input has been received. From client perspective: the server will not echo our input, we should - chose to duplicate our input to standard out ourselves. + choose to duplicate our input to standard out ourselves. """ return (self.server and self.local_option.enabled(ECHO)) or ( self.client and self.remote_option.enabled(ECHO) ) @property - def mode(self): + def mode(self) -> str: """ String describing NVT mode. - :rtype str: One of: + One of: - ``kludge``: Client acknowledges WILL-ECHO, WILL-SGA. character-at- - a-time and remote line editing may be provided. + ``kludge``: Client acknowledges WILL-ECHO, WILL-SGA. Character-at- + a-time and remote line editing may be provided. - ``local``: Default NVT half-duplex mode, client performs line - editing and transmits only after pressing send (usually CR) + ``local``: Default NVT half-duplex mode, client performs line + editing and transmits only after pressing send (usually CR). - ``remote``: Client supports advanced remote line editing, using - mixed-mode local line buffering (optionally, echoing) until - send, but also transmits buffer up to and including special - line characters (SLCs). + ``remote``: Client supports advanced remote line editing, using + mixed-mode local line buffering (optionally, echoing) until + send, but also transmits buffer up to and including special + line characters (SLCs). """ if self.remote_option.enabled(LINEMODE): if self._linemode.local: @@ -751,12 +781,12 @@ def mode(self): return "local" @property - def is_oob(self): + def is_oob(self) -> bool: """The previous byte should not be received by the API stream.""" - return self.iac_received or self.cmd_received + return bool(self.iac_received or self.cmd_received) @property - def linemode(self): + def linemode(self) -> slc.Linemode: """ Linemode instance for stream. @@ -769,13 +799,13 @@ def linemode(self): """ return self._linemode - def send_iac(self, buf): + def send_iac(self, buf: bytes) -> None: """ Send a command starting with IAC (base 10 byte value 255). No transformations of bytes are performed. Normally, if the byte value 255 is sent, it is escaped as ``IAC + IAC``. This - method ensures it is not escaped,. + method ensures it is not escaped. """ assert isinstance(buf, (bytes, bytearray)), buf assert buf and buf.startswith(IAC), buf @@ -784,7 +814,7 @@ def send_iac(self, buf): if hasattr(self._protocol, "_tx_bytes"): self._protocol._tx_bytes += len(buf) - def iac(self, cmd, opt=b""): + def iac(self, cmd: bytes, opt: bytes = b"") -> bool: """ Send Is-A-Command 3-byte negotiation command. @@ -840,7 +870,7 @@ def iac(self, cmd, opt=b""): # Public methods for transmission signaling # - def send_ga(self): + def send_ga(self) -> bool: """ Transmit IAC GA (Go-Ahead). @@ -855,7 +885,7 @@ def send_ga(self): self.send_iac(IAC + GA) return True - def send_eor(self): + def send_eor(self) -> bool: """ Transmit IAC CMD_EOR (End-of-Record), :rfc:`885`. @@ -873,7 +903,7 @@ def send_eor(self): # Public methods for notifying about, or soliciting state options. # - def request_status(self): + def request_status(self) -> bool: """ Send ``IAC-SB-STATUS-SEND`` sub-negotiation (:rfc:`859`). @@ -892,7 +922,7 @@ def request_status(self): self.log.info("cannot send SB STATUS SEND, request pending.") return False - def request_tspeed(self): + def request_tspeed(self) -> bool: """ Send IAC-SB-TSPEED-SEND sub-negotiation, :rfc:`1079`. @@ -912,7 +942,7 @@ def request_tspeed(self): self.log.debug("cannot send SB TSPEED SEND, request pending.") return False - def request_charset(self): + def request_charset(self) -> bool: """ Request sub-negotiation CHARSET, :rfc:`2066`. @@ -938,7 +968,7 @@ def request_charset(self): codepages = self._ext_send_callback[CHARSET]() sep = " " - response = collections.deque() + response: collections.deque[bytes] = collections.deque() response.extend([IAC, SB, CHARSET, REQUEST]) response.extend([bytes(sep, "ascii")]) response.extend([bytes(sep.join(codepages), "ascii")]) @@ -948,7 +978,7 @@ def request_charset(self): self.pending_option[SB + CHARSET] = True return True - def request_environ(self): + def request_environ(self) -> bool: """ Request sub-negotiation NEW_ENVIRON, :rfc:`1572`. @@ -972,7 +1002,7 @@ def request_environ(self): self.log.debug("cannot send SB NEW_ENVIRON SEND IS, request pending.") return False - response = collections.deque() + response: collections.deque[bytes] = collections.deque() response.extend([IAC, SB, NEW_ENVIRON, SEND]) for env_key in request_list: @@ -990,7 +1020,7 @@ def request_environ(self): self.send_iac(b"".join(response)) return True - def request_xdisploc(self): + def request_xdisploc(self) -> bool: """ Send XDISPLOC, SEND sub-negotiation, :rfc:`1086`. @@ -1010,7 +1040,7 @@ def request_xdisploc(self): self.log.debug("cannot send SB XDISPLOC SEND, request pending.") return False - def request_ttype(self): + def request_ttype(self) -> bool: """ Send TTYPE SEND sub-negotiation, :rfc:`930`. @@ -1029,7 +1059,7 @@ def request_ttype(self): self.log.debug("cannot send SB TTYPE SEND, request pending.") return False - def request_forwardmask(self, fmask=None): + def request_forwardmask(self, fmask: Optional[slc.Forwardmask] = None) -> bool: """ Request the client forward their terminal control characters. @@ -1066,7 +1096,7 @@ def request_forwardmask(self, fmask=None): return True return False - def send_lineflow_mode(self): + def send_lineflow_mode(self) -> Optional[bool]: """ Send LFLOW mode sub-negotiation, :rfc:`1372`. @@ -1086,7 +1116,7 @@ def send_lineflow_mode(self): return True return False - def send_linemode(self, linemode=None): + def send_linemode(self, linemode: Optional[slc.Linemode] = None) -> None: """ Set and Inform other end to agree to change to linemode, ``linemode``. @@ -1111,7 +1141,7 @@ def send_linemode(self, linemode=None): # Public is-a-command (IAC) callbacks # - def set_iac_callback(self, cmd, func): + def set_iac_callback(self, cmd: bytes, func: Callable[..., Any]) -> None: """ Register callable ``func`` as callback for IAC ``cmd``. @@ -1139,21 +1169,21 @@ def set_iac_callback(self, cmd, func): ), name_command(cmd) self._iac_callback[cmd] = func - def handle_nop(self, cmd): # pylint:disable=unused-argument + def handle_nop(self, cmd: bytes) -> None: # pylint:disable=unused-argument """Handle IAC No-Operation (NOP).""" self.log.debug("IAC NOP: Null Operation (unhandled).") - def handle_ga(self, cmd): # pylint:disable=unused-argument + def handle_ga(self, cmd: bytes) -> None: # pylint:disable=unused-argument """Handle IAC Go-Ahead (GA).""" self.log.debug("IAC GA: Go-Ahead (unhandled).") - def handle_dm(self, cmd): # pylint:disable=unused-argument + def handle_dm(self, cmd: bytes) -> None: # pylint:disable=unused-argument """Handle IAC Data-Mark (DM).""" self.log.debug("IAC DM: Data-Mark (unhandled).") # Public mixed-mode SLC and IAC callbacks # - def handle_el(self, _byte): + def handle_el(self, _byte: bytes) -> None: """ Handle IAC Erase Line (EL, SLC_EL). @@ -1162,11 +1192,11 @@ def handle_el(self, _byte): """ self.log.debug("IAC EL: Erase Line (unhandled).") - def handle_eor(self, _byte): + def handle_eor(self, _byte: bytes) -> None: """Handle IAC End of Record (CMD_EOR, SLC_EOR).""" self.log.debug("IAC EOR: End of Record (unhandled).") - def handle_abort(self, _byte): + def handle_abort(self, _byte: bytes) -> None: """ Handle IAC Abort (ABORT, SLC_ABORT). @@ -1175,11 +1205,11 @@ def handle_abort(self, _byte): """ self.log.debug("IAC ABORT: Abort (unhandled).") - def handle_eof(self, _byte): + def handle_eof(self, _byte: bytes) -> None: """Handle IAC End of Record (EOF, SLC_EOF).""" self.log.debug("IAC EOF: End of File (unhandled).") - def handle_susp(self, _byte): + def handle_susp(self, _byte: bytes) -> None: """ Handle IAC Suspend Process (SUSP, SLC_SUSP). @@ -1191,16 +1221,16 @@ def handle_susp(self, _byte): """ self.log.debug("IAC SUSP: Suspend (unhandled).") - def handle_brk(self, _byte): + def handle_brk(self, _byte: bytes) -> None: """ Handle IAC Break (BRK, SLC_BRK). Sent by clients to indicate BREAK keypress. This is not the same as IP (^c), but a means to - map sysystem-dependent break key such as found on an IBM Systems. + map system-dependent break key such as found on an IBM Systems. """ self.log.debug("IAC BRK: Break (unhandled).") - def handle_ayt(self, _byte): + def handle_ayt(self, _byte: bytes) -> None: """ Handle IAC Are You There (AYT, SLC_AYT). @@ -1209,11 +1239,11 @@ def handle_ayt(self, _byte): """ self.log.debug("IAC AYT: Are You There? (unhandled).") - def handle_ip(self, _byte): + def handle_ip(self, _byte: bytes) -> None: """Handle IAC Interrupt Process (IP, SLC_IP).""" self.log.debug("IAC IP: Interrupt Process (unhandled).") - def handle_ao(self, _byte): + def handle_ao(self, _byte: bytes) -> None: """ Handle IAC Abort Output (AO) or SLC_AO. @@ -1225,7 +1255,7 @@ def handle_ao(self, _byte): """ self.log.debug("IAC AO: Abort Output, unhandled.") - def handle_ec(self, _byte): + def handle_ec(self, _byte: bytes) -> None: """ Handle IAC Erase Character (EC, SLC_EC). @@ -1234,7 +1264,7 @@ def handle_ec(self, _byte): """ self.log.debug("IAC EC: Erase Character (unhandled).") - def handle_tm(self, cmd): + def handle_tm(self, cmd: bytes) -> None: """ Handle IAC (WILL, WONT, DO, DONT) Timing Mark (TM). @@ -1246,11 +1276,11 @@ def handle_tm(self, cmd): # public Special Line Mode (SLC) callbacks # - def set_slc_callback(self, slc_byte, func): + def set_slc_callback(self, slc_byte: bytes, func: Callable[..., Any]) -> None: """ Register ``func`` as callable for receipt of ``slc_byte``. - :param bytes slc_byte: any of SLC_SYNCH, SLC_BRK, SLC_IP, SLC_AO, + :param slc_byte: any of SLC_SYNCH, SLC_BRK, SLC_IP, SLC_AO, SLC_AYT, SLC_EOR, SLC_ABORT, SLC_EOF, SLC_SUSP, SLC_EC, SLC_EL, SLC_EW, SLC_RP, SLC_XON, SLC_XOFF ... :param func: Callback receiving a single argument: the SLC function byte @@ -1264,7 +1294,7 @@ def set_slc_callback(self, slc_byte, func): ), f"Unknown SLC byte: {slc_byte!r}" self._slc_callback[slc_byte] = func - def handle_ew(self, _slc): + def handle_ew(self, _slc: bytes) -> None: """ Handle SLC_EW (Erase Word). @@ -1273,31 +1303,31 @@ def handle_ew(self, _slc): """ self.log.debug("SLC EC: Erase Word (unhandled).") - def handle_rp(self, _slc): + def handle_rp(self, _slc: bytes) -> None: """Handle SLC Repaint (RP).""" self.log.debug("SLC RP: Repaint (unhandled).") - def handle_lnext(self, _slc): + def handle_lnext(self, _slc: bytes) -> None: """Handle SLC Literal Next (LNEXT) (Next character is received raw).""" self.log.debug("SLC LNEXT: Literal Next (unhandled)") - def handle_xon(self, _byte): + def handle_xon(self, _byte: bytes) -> None: """Handle SLC Transmit-On (XON).""" self.log.debug("SLC XON: Transmit On (unhandled).") - def handle_xoff(self, _byte): + def handle_xoff(self, _byte: bytes) -> None: """Handle SLC Transmit-Off (XOFF).""" self.log.debug("SLC XOFF: Transmit Off.") # public Telnet extension callbacks # - def set_ext_send_callback(self, cmd, func): + def set_ext_send_callback(self, cmd: bytes, func: Callable[..., Any]) -> None: """ - Register callback for inquires of sub-negotiation of ``cmd``. + Register callback for inquiries of sub-negotiation of ``cmd``. :param func: A callable function for the given ``cmd`` byte. Note that the return type must match those documented. - :param bytes cmd: These callbacks must return any number of arguments, + :param cmd: These callbacks must return any number of arguments, for each registered ``cmd`` byte, respectively: * SNDLOC: for clients, returning one argument: the string @@ -1327,16 +1357,16 @@ def set_ext_send_callback(self, cmd, func): assert callable(func), "Argument func must be callable" self._ext_send_callback[cmd] = func - def set_ext_callback(self, cmd, func): + def set_ext_callback(self, cmd: bytes, func: Callable[..., Any]) -> None: """ Register ``func`` as callback for receipt of ``cmd`` negotiation. - :param bytes cmd: One of the following listed bytes: + :param cmd: One of the following listed bytes: * ``LOGOUT``: for servers and clients, receiving one argument. Server end may receive DO or DONT as argument ``cmd``, indicating client's wish to disconnect, or a response to WILL, LOGOUT, - indicating it's wish not to be automatically disconnected. Client + indicating its wish not to be automatically disconnected. Client end may receive WILL or WONT, indicating server's wish to disconnect, or acknowledgment that the client will not be disconnected. @@ -1378,27 +1408,27 @@ def set_ext_callback(self, cmd, func): assert callable(func), "Argument func must be callable" self._ext_callback[cmd] = func - def handle_xdisploc(self, xdisploc): + def handle_xdisploc(self, xdisploc: str) -> None: """Receive XDISPLAY value ``xdisploc``, :rfc:`1096`.""" # xdisploc string format is ':[.]'. self.log.debug("X Display is %s", xdisploc) - def handle_send_xdisploc(self): + def handle_send_xdisploc(self) -> str: """Send XDISPLAY value ``xdisploc``, :rfc:`1096`.""" # xdisploc string format is ':[.]'. self.log.warning("X Display requested, sending empty string.") return "" - def handle_sndloc(self, location): + def handle_sndloc(self, location: str) -> None: """Receive LOCATION value ``location``, :rfc:`779`.""" self.log.debug("Location is %s", location) - def handle_send_sndloc(self): + def handle_send_sndloc(self) -> str: """Send LOCATION value ``location``, :rfc:`779`.""" self.log.warning("Location requested, sending empty response.") return "" - def handle_ttype(self, ttype): + def handle_ttype(self, ttype: str) -> None: """ Receive TTYPE value ``ttype``, :rfc:`1091`. @@ -1408,25 +1438,27 @@ def handle_ttype(self, ttype): """ self.log.debug("Terminal type is %r", ttype) - def handle_send_ttype(self): + def handle_send_ttype(self) -> str: """Send TTYPE value ``ttype``, :rfc:`1091`.""" self.log.warning("Terminal type requested, sending empty string.") return "" - def handle_naws(self, width, height): + def handle_naws(self, width: int, height: int) -> None: """Receive window size ``width`` and ``height``, :rfc:`1073`.""" self.log.debug("Terminal cols=%s, rows=%s", width, height) - def handle_send_naws(self): + def handle_send_naws(self) -> tuple[int, int]: """Send window size ``width`` and ``height``, :rfc:`1073`.""" self.log.warning("Terminal size requested, sending 80x24.") return 80, 24 - def handle_environ(self, env): + def handle_environ(self, env: dict[str, str]) -> None: """Receive environment variables as dict, :rfc:`1572`.""" self.log.debug("Environment values are %r", env) - def handle_send_client_environ(self, _keys): + def handle_send_client_environ( + self, _keys: Any + ) -> dict[str, str]: """ Send environment variables as dict, :rfc:`1572`. @@ -1437,25 +1469,25 @@ def handle_send_client_environ(self, _keys): self.log.debug("Environment values requested, sending {{}}.") return {} - def handle_send_server_environ(self): + def handle_send_server_environ(self) -> list[str]: """Server requests environment variables as list, :rfc:`1572`.""" self.log.debug("Environment values offered, requesting [].") return [] - def handle_tspeed(self, rx, tx): + def handle_tspeed(self, rx: int, tx: int) -> None: """Receive terminal speed from TSPEED as int, :rfc:`1079`.""" self.log.debug("Terminal Speed rx:%s, tx:%s", rx, tx) - def handle_send_tspeed(self): + def handle_send_tspeed(self) -> tuple[int, int]: """Send terminal speed from TSPEED as int, :rfc:`1079`.""" self.log.debug("Terminal Speed requested, sending 9600,9600.") return 9600, 9600 - def handle_charset(self, charset): + def handle_charset(self, charset: str) -> None: """Receive character set as string, :rfc:`2066`.""" self.log.debug("Character set: %s", charset) - def handle_send_client_charset(self, _charsets): + def handle_send_client_charset(self, _charsets: list[str]) -> str: """ Send character set selection as string, :rfc:`2066`. @@ -1466,12 +1498,12 @@ def handle_send_client_charset(self, _charsets): self.log.debug("Character Set requested") return "" - def handle_send_server_charset(self): + def handle_send_server_charset(self) -> list[str]: """Send character set (encodings) offered to client, :rfc:`2066`.""" assert self.server return ["UTF-8"] - def handle_logout(self, cmd): + def handle_logout(self, cmd: bytes) -> None: """ Handle (IAC, (DO | DONT | WILL | WONT), LOGOUT), :rfc:`727`. @@ -1497,7 +1529,7 @@ def handle_logout(self, cmd): # public derivable methods DO, DONT, WILL, and WONT negotiation # - def handle_do(self, opt): + def handle_do(self, opt: bytes) -> bool: """ Process byte 3 of series (IAC, DO, opt) received by remote end. @@ -1598,12 +1630,13 @@ def handle_do(self, opt): else: self.log.debug("DO %s not supported.", name_command(opt)) - if self.local_option.get(opt, None) is None: + self.rejected_do.add(opt) + if not self.local_option.enabled(opt): self.iac(WONT, opt) return False return True - def handle_dont(self, opt): + def handle_dont(self, opt: bytes) -> None: """ Process byte 3 of series (IAC, DONT, opt) received by remote end. @@ -1621,9 +1654,9 @@ def handle_dont(self, opt): # affirm in the negative. # pylint: disable=too-many-branches,too-complex - def handle_will(self, opt): + def handle_will(self, opt: bytes) -> None: """ - Process byte 3 of series (IAC, DONT, opt) received by remote end. + Process byte 3 of series (IAC, WILL, opt) received by remote end. The remote end requests we perform any number of capabilities. Most implementations require an answer in the affirmative with DO, unless @@ -1724,15 +1757,13 @@ def handle_will(self, opt): }[opt]() else: - # option value of -1 toggles opt.unsupported() self.iac(DONT, opt) - self.remote_option[opt] = -1 + self.rejected_will.add(opt) self.log.warning("Unhandled: WILL %s.", name_command(opt)) - self.local_option[opt] = -1 if self.pending_option.enabled(DO + opt): self.pending_option[DO + opt] = False - def handle_wont(self, opt): + def handle_wont(self, opt: bytes) -> None: """ Process byte 3 of series (IAC, WONT, opt) received by remote end. @@ -1761,7 +1792,7 @@ def handle_wont(self, opt): # public derivable Sub-Negotation parsing # - def handle_subnegotiation(self, buf): + def handle_subnegotiation(self, buf: collections.deque[bytes]) -> None: """ Callback for end of sub-negotiation buffer. @@ -1808,16 +1839,16 @@ def handle_subnegotiation(self, buf): # Our Private API methods @staticmethod - def _escape_iac(buf): + def _escape_iac(buf: bytes) -> bytes: r"""Replace bytes in buf ``IAC`` (``b'\xff'``) by ``IAC IAC``.""" return buf.replace(IAC, IAC + IAC) - def _write(self, buf, escape_iac=True): + def _write(self, buf: bytes, escape_iac: bool = True) -> None: """ Write bytes to transport, conditionally escaping IAC. - :param bytes buf: bytes to write to transport. - :param bool escape_iac: whether bytes in buffer ``buf`` should be + :param buf: bytes to write to transport. + :param escape_iac: whether bytes in buffer ``buf`` should be escaped of byte ``IAC``. This should be set ``False`` for direct writes of ``IAC`` commands. """ @@ -1837,7 +1868,7 @@ def _write(self, buf, escape_iac=True): # Private sub-negotiation (SB) routines - def _handle_sb_charset(self, buf): + def _handle_sb_charset(self, buf: collections.deque[bytes]) -> None: cmd = buf.popleft() assert cmd == CHARSET opt = buf.popleft() @@ -1853,7 +1884,7 @@ def _handle_sb_charset(self, buf): self.log.debug("send IAC SB CHARSET REJECTED IAC SE") self.send_iac(IAC + SB + CHARSET + REJECTED + IAC + SE) else: - response = collections.deque() + response: collections.deque[bytes] = collections.deque() response.extend([IAC, SB, CHARSET, ACCEPTED]) response.extend([bytes(selected, "ascii")]) response.extend([IAC, SE]) @@ -1872,7 +1903,7 @@ def _handle_sb_charset(self, buf): else: raise ValueError(f"Illegal option follows IAC SB CHARSET: {opt!r}.") - def _handle_sb_tspeed(self, buf): + def _handle_sb_tspeed(self, buf: collections.deque[bytes]) -> None: """Callback handles IAC-SB-TSPEED--SE.""" cmd = buf.popleft() opt = buf.popleft() @@ -1883,29 +1914,29 @@ def _handle_sb_tspeed(self, buf): if opt == IS: assert self.server, f"SE: cannot recv from server: {name_command(cmd)} {opt_kind}" - rx, tx = str(), str() + rx_str, tx_str = str(), str() while len(buf): value = buf.popleft() if value == b",": break - rx += value.decode("ascii") + rx_str += value.decode("ascii") while len(buf): value = buf.popleft() if value == b",": break - tx += value.decode("ascii") - self.log.debug("sb_tspeed: %s, %s", rx, tx) + tx_str += value.decode("ascii") + self.log.debug("sb_tspeed: %s, %s", rx_str, tx_str) try: - rx, tx = int(rx), int(tx) + rx_int, tx_int = int(rx_str), int(tx_str) except ValueError as err: self.log.error( "illegal TSPEED values received (rx=%r, tx=%r): %s", - rx, - tx, + rx_str, + tx_str, err, ) return - self._ext_callback[TSPEED](rx, tx) + self._ext_callback[TSPEED](rx_int, tx_int) elif opt == SEND: assert self.client, f"SE: cannot recv from client: {name_command(cmd)} {opt_kind}" rx, tx = self._ext_send_callback[TSPEED]() @@ -1924,8 +1955,8 @@ def _handle_sb_tspeed(self, buf): if self.pending_option.enabled(WILL + TSPEED): self.pending_option[WILL + TSPEED] = False - def _handle_sb_xdisploc(self, buf): - """Callback handles IAC-SB-XIDISPLOC--SE.""" + def _handle_sb_xdisploc(self, buf: collections.deque[bytes]) -> None: + """Callback handles IAC-SB-XDISPLOC--SE.""" cmd = buf.popleft() opt = buf.popleft() @@ -1935,12 +1966,12 @@ def _handle_sb_xdisploc(self, buf): self.log.debug("recv %s %s: %r", name_command(cmd), opt_kind, b"".join(buf)) if opt == IS: - assert self.server, f"SE: cannot recv from server: {name_command(cmd)} {opt}" + assert self.server, f"SE: cannot recv from server: {name_command(cmd)} {opt!r}" xdisploc_str = b"".join(buf).decode("ascii") self.log.debug("recv IAC SB XDISPLOC IS %r IAC SE", xdisploc_str) self._ext_callback[XDISPLOC](xdisploc_str) elif opt == SEND: - assert self.client, f"SE: cannot recv from client: {name_command(cmd)} {opt}" + assert self.client, f"SE: cannot recv from client: {name_command(cmd)} {opt!r}" xdisploc_str = self._ext_send_callback[XDISPLOC]().encode("ascii") response = [IAC, SB, XDISPLOC, IS, xdisploc_str, IAC, SE] self.log.debug("send IAC SB XDISPLOC IS %r IAC SE", xdisploc_str) @@ -1948,7 +1979,7 @@ def _handle_sb_xdisploc(self, buf): if self.pending_option.enabled(WILL + XDISPLOC): self.pending_option[WILL + XDISPLOC] = False - def _handle_sb_ttype(self, buf): + def _handle_sb_ttype(self, buf: collections.deque[bytes]) -> None: """Callback handles IAC-SB-TTYPE--SE.""" cmd = buf.popleft() opt = buf.popleft() @@ -1959,12 +1990,12 @@ def _handle_sb_ttype(self, buf): self.log.debug("recv %s %s: %r", name_command(cmd), opt_kind, b"".join(buf)) if opt == IS: - assert self.server, f"SE: cannot recv from server: {name_command(cmd)} {opt}" + assert self.server, f"SE: cannot recv from server: {name_command(cmd)} {opt!r}" ttype_str = b"".join(buf).decode("ascii") self.log.debug("recv IAC SB TTYPE IS %r", ttype_str) self._ext_callback[TTYPE](ttype_str) elif opt == SEND: - assert self.client, f"SE: cannot recv from client: {name_command(cmd)} {opt}" + assert self.client, f"SE: cannot recv from client: {name_command(cmd)} {opt!r}" ttype_str = self._ext_send_callback[TTYPE]().encode("ascii") response = [IAC, SB, TTYPE, IS, ttype_str, IAC, SE] self.log.debug("send IAC SB TTYPE IS %r IAC SE", ttype_str) @@ -1972,7 +2003,7 @@ def _handle_sb_ttype(self, buf): if self.pending_option.enabled(WILL + TTYPE): self.pending_option[WILL + TTYPE] = False - def _handle_sb_environ(self, buf): + def _handle_sb_environ(self, buf: collections.deque[bytes]) -> None: """ Callback handles (IAC, SB, NEW_ENVIRON, , SE), :rfc:`1572`. @@ -2020,13 +2051,13 @@ def _handle_sb_environ(self, buf): if self.pending_option.enabled(WILL + TTYPE): self.pending_option[WILL + TTYPE] = False - def _handle_sb_sndloc(self, buf): + def _handle_sb_sndloc(self, buf: collections.deque[bytes]) -> None: """Fire callback for IAC-SB-SNDLOC--SE (:rfc:`779`).""" assert buf.popleft() == SNDLOC location_str = b"".join(buf).decode("ascii") self._ext_callback[SNDLOC](location_str) - def _send_naws(self): + def _send_naws(self) -> None: """Fire callback for IAC-DO-NAWS from server.""" # Similar to the callback method order fired by _handle_sb_naws(), # we expect our parameters in order of (rows, cols), matching the @@ -2048,14 +2079,16 @@ def _send_naws(self): self.log.debug("send IAC SB NAWS (rows=%s, cols=%s) IAC SE", rows, cols) self.send_iac(b"".join(response)) - def _handle_sb_naws(self, buf): + def _handle_sb_naws(self, buf: collections.deque[bytes]) -> None: """Fire callback for IAC-SB-NAWS--SE (:rfc:`1073`).""" cmd = buf.popleft() assert cmd == NAWS, name_command(cmd) assert len(buf) == 4, f"bad NAWS length {len(buf)}: {buf!r}" - assert self.remote_option.enabled( - NAWS - ), "received IAC SB NAWS without receipt of IAC WILL NAWS" + if not self.remote_option.enabled(NAWS): + self.log.info( + "received IAC SB NAWS without receipt of IAC WILL NAWS -- assuming NAWS-enabled" + ) + self.remote_option[NAWS] = True # note a similar formula: # # cols, rows = ((256 * buf[0]) + buf[1], @@ -2069,7 +2102,7 @@ def _handle_sb_naws(self, buf): # structure, which also matches the terminfo(5) capability, 'cup'. self._ext_callback[NAWS](rows, cols) - def _handle_sb_lflow(self, buf): + def _handle_sb_lflow(self, buf: collections.deque[bytes]) -> None: """Callback responds to IAC SB LFLOW, :rfc:`1372`.""" buf.popleft() # LFLOW if not self.local_option.enabled(LFLOW): @@ -2088,7 +2121,7 @@ def _handle_sb_lflow(self, buf): else: raise ValueError(f"Unknown IAC SB LFLOW option received: {buf!r}") - def _handle_sb_status(self, buf): + def _handle_sb_status(self, buf: collections.deque[bytes]) -> None: """ Callback responds to IAC SB STATUS, :rfc:`859`. @@ -2104,15 +2137,15 @@ def _handle_sb_status(self, buf): else: raise ValueError(f"Illegal byte following IAC SB STATUS: {opt!r}, expected SEND or IS.") - def _receive_status(self, buf): + def _receive_status(self, buf: collections.deque[bytes]) -> None: """ Callback responds to IAC SB STATUS IS, :rfc:`859`. - :param bytes buf: sub-negotiation byte buffer containing status data. This implementation - does its best to analyze our perspective's state to the state options given. Any - discrepancies are reported to the error log, but no action is taken. This implementation - handles malformed STATUS data gracefully by skipping invalid command bytes and - continuing to process the remaining data. + :param buf: sub-negotiation byte buffer containing status data. This implementation does its + best to analyze our perspective's state to the state options given. Any discrepancies + are reported to the error log, but no action is taken. This implementation handles + malformed STATUS data gracefully by skipping invalid command bytes and continuing to + process the remaining data. """ # Convert deque to list for processing buf_list = list(buf) @@ -2174,12 +2207,12 @@ def _receive_status(self, buf): # Move to next pair i += 2 - def _send_status(self): + def _send_status(self) -> None: """Callback responds to IAC SB STATUS SEND, :rfc:`859`.""" if not (self.pending_option.enabled(WILL + STATUS) or self.local_option.enabled(STATUS)): raise ValueError("Only sender of IAC WILL STATUS may reply by IAC SB STATUS IS.") - response = collections.deque() + response: collections.deque[bytes] = collections.deque() response.extend([IAC, SB, STATUS, IS]) for opt, status in self.local_option.items(): # status is 'WILL' for local option states that are True, @@ -2209,7 +2242,7 @@ def _send_status(self): # Special Line Character and other LINEMODE functions. # - def _handle_sb_linemode(self, buf): + def _handle_sb_linemode(self, buf: collections.deque[bytes]) -> None: """Callback responds to bytes following IAC SB LINEMODE.""" buf.popleft() opt = buf.popleft() @@ -2225,19 +2258,21 @@ def _handle_sb_linemode(self, buf): "expected LMODE_FORWARDMASK." ) self.log.debug("recv IAC SB LINEMODE %s LMODE_FORWARDMASK,", name_command(opt)) - self._handle_sb_forwardmask(LINEMODE, buf) + self._handle_sb_forwardmask(opt, buf) else: raise ValueError(f"Illegal IAC SB LINEMODE option {opt!r}") - def _handle_sb_linemode_mode(self, mode): + def _handle_sb_linemode_mode(self, mode: collections.deque[bytes]) -> None: """ Callback handles mode following IAC SB LINEMODE LINEMODE_MODE. - :param bytes mode: a single byte + :param mode: a single byte Result of agreement to enter ``mode`` given applied by setting the value of ``self.linemode``, and sending acknowledgment if necessary. """ + if not mode: + raise ValueError("IAC SB LINEMODE LINEMODE-MODE: missing mode byte") suggest_mode = slc.Linemode(mode[0]) self.log.debug("recv IAC SB LINEMODE LINEMODE-MODE %r IAC SE", suggest_mode.mask) @@ -2290,7 +2325,7 @@ def _handle_sb_linemode_mode(self, mode): self._linemode = suggest_mode - def _handle_sb_linemode_slc(self, buf): + def _handle_sb_linemode_slc(self, buf: collections.deque[bytes]) -> None: """ Callback handles IAC-SB-LINEMODE-SLC-. @@ -2309,7 +2344,7 @@ def _handle_sb_linemode_slc(self, buf): self._slc_end() self.request_forwardmask() - def _slc_end(self): + def _slc_end(self) -> None: """Transmit SLC commands buffered by :meth:`_slc_send`.""" if len(self._slc_buffer): self.log.debug("send (slc_end): %r", b"".join(self._slc_buffer)) @@ -2320,16 +2355,16 @@ def _slc_end(self): self.log.debug("slc_end: [..] IAC SE") self.send_iac(IAC + SE) - def _slc_start(self): + def _slc_start(self) -> None: """Send IAC SB LINEMODE SLC header.""" self.log.debug("slc_start: IAC SB LINEMODE SLC [..]") self.send_iac(IAC + SB + LINEMODE + slc.LMODE_SLC) - def _slc_send(self, slctab=None): + def _slc_send(self, slctab: Optional[dict[bytes, slc.SLC]] = None) -> None: """ Send supported SLC characters of current tabset, or specified tabset. - :param dict slctab: SLC byte tabset as dictionary, such as slc.BSD_SLC_TAB. + :param slctab: SLC byte tabset as dictionary, such as slc.BSD_SLC_TAB. """ send_count = 0 slctab = slctab or self.slctab @@ -2347,7 +2382,7 @@ def _slc_send(self, slctab=None): send_count += 1 self.log.debug("slc_send: %s functions queued.", send_count) - def _slc_add(self, func, slc_def=None): + def _slc_add(self, func: bytes, slc_def: Optional[slc.SLC] = None) -> None: """ Prepare slc triplet response (function, flag, value) for transmission. @@ -2362,7 +2397,7 @@ def _slc_add(self, func, slc_def=None): raise ValueError("SLC: buffer full!") self._slc_buffer.extend([func, slc_def.mask, slc_def.val]) - def _slc_process(self, func, slc_def): + def _slc_process(self, func: bytes, slc_def: slc.SLC) -> None: """ Process an SLC definition provided by remote end. @@ -2412,7 +2447,7 @@ def _slc_process(self, func, slc_def): return self._slc_change(func, slc_def) - def _slc_change(self, func, slc_def): + def _slc_change(self, func: bytes, slc_def: slc.SLC) -> None: """ Update SLC tabset with SLC definition provided by remote end. @@ -2439,7 +2474,9 @@ def _slc_change(self, func, slc_def): self.slctab[func].set_mask(slc.SLC_NOSUPPORT) else: # set current flag to the flag indicated in default tab - self.slctab[func].set_mask(self.default_slc_tab.get(func).mask) + default_slc = self.default_slc_tab.get(func) + assert default_slc is not None + self.slctab[func].set_mask(default_slc.mask) # set current value to value indicated in default tab self.default_slc_tab.get(func, slc.SLC_nosupport()) self.slctab[func].set_value(slc_def.val) @@ -2476,37 +2513,43 @@ def _slc_change(self, func, slc_def): self.slctab[func].val = slc_def.val self._slc_add(func) - def _handle_sb_forwardmask(self, cmd, buf): + def _handle_sb_forwardmask(self, cmd: bytes, buf: collections.deque[bytes]) -> None: """ Callback handles request for LINEMODE LMODE_FORWARDMASK. - :param bytes cmd: one of DO, DONT, WILL, WONT. - :param bytes buf: bytes following IAC SB LINEMODE DO FORWARDMASK. + :param cmd: one of DO, DONT, WILL, WONT. + :param buf: bytes following IAC SB LINEMODE DO FORWARDMASK. """ # set and report about pending options by 2-byte opt, # not well tested, no known implementations exist ! if self.server: - assert self.remote_option.enabled(LINEMODE), ( - f"cannot recv LMODE_FORWARDMASK {cmd} ({buf!r}) " - "without first sending DO LINEMODE." - ) - assert cmd not in ( - DO, - DONT, - ), f"cannot recv {name_command(cmd)} LMODE_FORWARDMASK on server end" + if not self.remote_option.enabled(LINEMODE): + self.log.info( + "receive and accept LMODE_FORWARDMASK %s without LINEMODE enabled", + name_command(cmd), + ) + if cmd in (DO, DONT): + self.log.warning( + "cannot recv %s LMODE_FORWARDMASK on server end", name_command(cmd) + ) + return if self.client: - assert self.local_option.enabled(LINEMODE), ( - f"cannot recv {name_command(cmd)} LMODE_FORWARDMASK " - "without first sending WILL LINEMODE." - ) - assert cmd not in ( - WILL, - WONT, - ), f"cannot recv {name_command(cmd)} LMODE_FORWARDMASK on client end" - assert ( - cmd not in (DONT,) or len(buf) == 0 - ), f"Illegal bytes follow DONT LMODE_FORWARDMASK: {buf!r}" - assert cmd not in (DO,) and len(buf), "bytes must follow DO LMODE_FORWARDMASK" + if not self.local_option.enabled(LINEMODE): + self.log.info( + "receive and accept LMODE_FORWARDMASK %s without LINEMODE enabled", + name_command(cmd), + ) + if cmd in (WILL, WONT): + self.log.warning( + "cannot recv %s LMODE_FORWARDMASK on client end", name_command(cmd) + ) + return + if cmd == DONT and len(buf) > 0: + self.log.warning("Illegal bytes follow DONT LMODE_FORWARDMASK: %r", buf) + return + if cmd == DO and len(buf) == 0: + self.log.warning("bytes must follow DO LMODE_FORWARDMASK") + return opt = SB + LINEMODE + slc.LMODE_FORWARDMASK if cmd in ( @@ -2522,32 +2565,32 @@ def _handle_sb_forwardmask(self, cmd, buf): if cmd == DO: self._handle_do_forwardmask(buf) - def _handle_sb_comport(self, buf): + def _handle_sb_comport(self, buf: collections.deque[bytes]) -> None: """ Callback handles IAC-SB-COM-PORT-OPTION. This callback simply logs the subnegotiation but does not perform any action. - :param bytes buf: bytes following IAC SB LINEMODE DO FORWARDMASK. + :param buf: bytes following IAC SB COM-PORT-OPTION. """ self.log.debug("SB unhandled: cmd=%s, buf=%r", name_command(COM_PORT_OPTION), buf) - def _handle_sb_gmcp(self, buf): + def _handle_sb_gmcp(self, buf: collections.deque[bytes]) -> None: """ Callback handles request for Generic Mud Communication Protocol (GMCP). This callback simply logs the subnegotiation but does not perform any action. - :param bytes buf: bytes following IAC SB GMCP. + :param buf: bytes following IAC SB GMCP. """ self.log.debug("SB unhandled: cmd=%s, buf=%r", name_command(GMCP), b"".join(buf)) - def _handle_do_forwardmask(self, buf): + def _handle_do_forwardmask(self, buf: collections.deque[bytes]) -> None: """ Callback handles request for LINEMODE DO FORWARDMASK. - :param bytes buf: bytes following IAC SB LINEMODE DO FORWARDMASK. :raises - NotImplementedError + :param buf: bytes following IAC SB LINEMODE DO FORWARDMASK. + :raises NotImplementedError: """ raise NotImplementedError @@ -2566,22 +2609,40 @@ class TelnetWriterUnicode(TelnetWriter): # pylint: disable=abstract-method discovered by ``LANG`` environment variables by NEW_ENVIRON, :rfc:`1572`. """ - def __init__(self, transport, protocol, fn_encoding, *, encoding_errors="strict", **kwds): + def __init__( + self, + transport: asyncio.Transport, + protocol: Any, + fn_encoding: Callable[..., str], + *, + encoding_errors: str = "strict", + client: bool = False, + server: bool = False, + reader: Optional["TelnetReader"] = None, + ) -> None: """Initialize TelnetWriterUnicode with encoding callback.""" self.fn_encoding = fn_encoding self.encoding_errors = encoding_errors - super().__init__(transport, protocol, **kwds) + super().__init__( + transport, + protocol, + client=client, + server=server, + reader=reader, + ) - def encode(self, string, errors): + def encode(self, string: str, errors: Optional[str] = None) -> bytes: """ Encode ``string`` using protocol-preferred encoding. - :param str string: unicode string to encode. - :param str errors: same as meaning in :meth:`codecs.Codec.encode`, when + :param string: unicode string to encode. + :param errors: same as meaning in :meth:`codecs.Codec.encode`, when ``None`` (default), value of class initializer keyword argument, ``encoding_errors``. - .. note: though a unicode interface, when ``outbinary`` mode has not + .. note:: + + Though a unicode interface, when ``outbinary`` mode has not been protocol negotiated, ``fn_encoding`` strictly enforces 7-bit ASCII range (ordinal byte values less than 128), as a strict compliance of the telnet RFC. @@ -2589,17 +2650,19 @@ def encode(self, string, errors): encoding = self.fn_encoding(outgoing=True) return bytes(string, encoding, errors or self.encoding_errors) - def write(self, string, errors=None): # pylint: disable=arguments-renamed + def write( # type: ignore[override] # pylint: disable=arguments-renamed + self, string: str, errors: Optional[str] = None + ) -> None: """ Write unicode string to transport, using protocol-preferred encoding. If the connection is closed, nothing is done. - :param str string: unicode string text to write to endpoint using the + :param string: unicode string text to write to endpoint using the protocol's preferred encoding. When the protocol ``encoding`` keyword is explicitly set to ``False``, the given string should be only raw ``b'bytes'``. - :param str errors: same as meaning in :meth:`codecs.Codec.encode`, when + :param errors: same as meaning in :meth:`codecs.Codec.encode`, when ``None`` (default), value of class initializer keyword argument, ``encoding_errors``. """ @@ -2608,7 +2671,9 @@ def write(self, string, errors=None): # pylint: disable=arguments-renamed errors = errors or self.encoding_errors self._write(self.encode(string, errors)) - def writelines(self, lines, errors=None): + def writelines( # type: ignore[override] + self, lines: Sequence[str], errors: Optional[str] = None + ) -> None: """ Write unicode strings to transport. @@ -2617,12 +2682,14 @@ def writelines(self, lines, errors=None): """ self.write(string="".join(lines), errors=errors) - def echo(self, string, errors=None): # pylint: disable=arguments-renamed + def echo( # type: ignore[override] # pylint: disable=arguments-renamed + self, string: str, errors: Optional[str] = None + ) -> None: """ Conditionally write ``string`` to transport when "remote echo" enabled. - :param str string: string received as input, conditionally written. - :param str errors: same as meaning in :meth:`codecs.Codec.encode`. + :param string: string received as input, conditionally written. + :param errors: same as meaning in :meth:`codecs.Codec.encode`. This method may only be called from the server perspective. The default implementation depends on telnet negotiation willingness for @@ -2634,7 +2701,7 @@ def echo(self, string, errors=None): # pylint: disable=arguments-renamed self.write(string=string, errors=errors) -class Option(dict): +class Option(dict[bytes, bool]): """ Telnet option state negotiation helper class. @@ -2642,11 +2709,16 @@ class Option(dict): telnet option negotiation. """ - def __init__(self, name, log, on_change=None): + def __init__( + self, + name: str, + log: logging.Logger, + on_change: Optional[Callable[[], None]] = None, + ) -> None: """ Class initializer. - :param str name: decorated name representing option class, such as 'local', 'remote', or + :param name: decorated name representing option class, such as 'local', 'remote', or 'pending'. :param on_change: optional callback invoked when option state changes. """ @@ -2654,16 +2726,15 @@ def __init__(self, name, log, on_change=None): self._on_change = on_change dict.__init__(self) - def enabled(self, key): + def enabled(self, key: bytes) -> bool: """ Return True if option is enabled. - :param bytes key: telnet option - :rtype: bool + :param key: telnet option byte(s). """ return bool(self.get(key, None) is True) - def __setitem__(self, key, value): + def __setitem__(self, key: bytes, value: bool) -> None: # the real purpose of this class, tracking state negotiation. if value != dict.get(self, key, None): descr = " + ".join( @@ -2675,41 +2746,38 @@ def __setitem__(self, key, value): self._on_change() -def _escape_environ(buf): +def _escape_environ(buf: bytes) -> bytes: """ Return new buffer with VAR and USERVAR escaped, if present in ``buf``. - :param bytes buf: given bytes buffer - :returns: bytes buffer with escape characters inserted. - :rtype: bytes + :param buf: given bytes buffer. + :returns: buffer with escape characters inserted. """ return buf.replace(VAR, ESC + VAR).replace(USERVAR, ESC + USERVAR) -def _unescape_environ(buf): +def _unescape_environ(buf: bytes) -> bytes: """ Return new buffer with escape characters removed for VAR and USERVAR. - :param bytes buf: given bytes buffer - :returns: bytes buffer with escape characters removed. - :rtype: bytes + :param buf: given bytes buffer. + :returns: buffer with escape characters removed. """ return buf.replace(ESC + VAR, VAR).replace(ESC + USERVAR, USERVAR) -def _encode_env_buf(env): +def _encode_env_buf(env: dict[str, str]) -> bytes: """ Encode dictionary for transmission as environment variables, :rfc:`1572`. - :param bytes buf: dictionary of environment values. - :returns: bytes buffer meant to follow sequence IAC SB NEW_ENVIRON IS. + :param env: dictionary of environment values. + :returns: buffer meant to follow sequence IAC SB NEW_ENVIRON IS. It is not terminated by IAC SE. - :rtype: bytes Returns bytes array ``buf`` for use in sequence (IAC, SB, NEW_ENVIRON, IS, , IAC, SE) as set forth in :rfc:`1572`. """ - buf = collections.deque() + buf: collections.deque[bytes] = collections.deque() for key, value in env.items(): buf.append(VAR) buf.extend([_escape_environ(key.encode("ascii"))]) @@ -2718,14 +2786,13 @@ def _encode_env_buf(env): return b"".join(buf) -def _decode_env_buf(buf): +def _decode_env_buf(buf: bytes) -> dict[str, str]: """ Decode environment values to dictionary, :rfc:`1572`. - :param bytes buf: bytes array following sequence IAC SB NEW_ENVIRON + :param buf: bytes array following sequence IAC SB NEW_ENVIRON SEND or IS up to IAC SE. :returns: dictionary representing the environment values decoded from buf. - :rtype: dict This implementation does not distinguish between ``USERVAR`` and ``VAR``. """ diff --git a/telnetlib3/sync.py b/telnetlib3/sync.py index e77f6f88..201d8d37 100644 --- a/telnetlib3/sync.py +++ b/telnetlib3/sync.py @@ -27,6 +27,8 @@ def handler(conn): server.serve_forever() """ +from __future__ import annotations + # std imports import time import queue @@ -53,10 +55,10 @@ class TelnetConnection: Wraps async ``telnetlib3.open_connection()`` with blocking methods. The asyncio event loop runs in a daemon thread. - :param str host: Remote server hostname or IP address. - :param int port: Remote server port (default 23). - :param float timeout: Default timeout for operations in seconds. - :param str encoding: Character encoding (default 'utf8'). + :param host: Remote server hostname or IP address. + :param port: Remote server port (default 23). + :param timeout: Default timeout for operations in seconds. + :param encoding: Character encoding (default 'utf8'). :param kwargs: Additional arguments passed to ``telnetlib3.open_connection()``. Example:: @@ -124,11 +126,20 @@ def _run_loop(self) -> None: async def _async_connect(self) -> None: """Async connection coroutine.""" + kwargs = dict(self._kwargs) + # Default to TelnetClient (not TelnetTerminalClient) — the blocking API + # is programmatic, not a terminal app, so it should use the cols/rows + # parameters rather than reading the real terminal size. + if "client_factory" not in kwargs: + # local + from .client import TelnetClient # pylint: disable=import-outside-toplevel + + kwargs["client_factory"] = TelnetClient self._reader, self._writer = await _open_connection( self._host, self._port, encoding=self._encoding, - **self._kwargs, + **kwargs, ) self._connected.set() @@ -145,8 +156,8 @@ def read(self, n: int = -1, timeout: Optional[float] = None) -> Union[str, bytes Blocks until data is available or timeout expires. - :param int n: Maximum bytes to read (-1 for any available data). - :param float timeout: Timeout in seconds (uses default if None). + :param n: Maximum bytes to read (-1 for any available data). + :param timeout: Timeout in seconds (uses default if None). :returns: Data read from connection. :raises TimeoutError: If timeout expires before data available. :raises EOFError: If connection closed. @@ -170,7 +181,7 @@ def read_some(self, timeout: Optional[float] = None) -> Union[str, bytes]: Alias for :meth:`read` for compatibility with old telnetlib. - :param float timeout: Timeout in seconds. + :param timeout: Timeout in seconds. :returns: Data read from connection. """ return self.read(-1, timeout=timeout) @@ -181,7 +192,7 @@ def readline(self, timeout: Optional[float] = None) -> Union[str, bytes]: Blocks until a complete line is received or timeout expires. - :param float timeout: Timeout in seconds (uses default if None). + :param timeout: Timeout in seconds (uses default if None). :returns: Line including terminator. :raises TimeoutError: If timeout expires. :raises EOFError: If connection closed before line complete. @@ -208,7 +219,7 @@ def read_until( Like old telnetlib's read_until method. :param match: String or bytes to match. - :param float timeout: Timeout in seconds (uses default if None). + :param timeout: Timeout in seconds (uses default if None). :returns: Data up to and including match. :raises TimeoutError: If timeout expires before match found. :raises EOFError: If connection closed before match found. @@ -239,7 +250,8 @@ def write(self, data: Union[str, bytes]) -> None: """ self._ensure_connected() assert self._loop is not None and self._writer is not None - self._loop.call_soon_threadsafe(self._writer.write, data) + # writer may be TelnetWriter (bytes) or TelnetWriterUnicode (str) + self._loop.call_soon_threadsafe(self._writer.write, data) # type: ignore[arg-type] def flush(self, timeout: Optional[float] = None) -> None: """ @@ -247,7 +259,7 @@ def flush(self, timeout: Optional[float] = None) -> None: Blocks until all buffered data has been sent. - :param float timeout: Timeout in seconds (uses default if None). + :param timeout: Timeout in seconds (uses default if None). :raises TimeoutError: If timeout expires. """ self._ensure_connected() @@ -305,7 +317,7 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: - ``'peername'``: Remote address tuple (host, port) - ``'LANG'``: Language/locale setting - :param str name: Information key. + :param name: Information key. :param default: Default value if key not found. :returns: Information value or default. """ @@ -315,9 +327,9 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: def wait_for( self, - remote: Optional[dict] = None, - local: Optional[dict] = None, - pending: Optional[dict] = None, + remote: Optional[dict[str, bool]] = None, + local: Optional[dict[str, bool]] = None, + pending: Optional[dict[str, bool]] = None, timeout: Optional[float] = None, ) -> None: """ @@ -396,8 +408,8 @@ class BlockingTelnetServer: Wraps async ``telnetlib3.create_server()`` with a blocking interface. Each client connection can be handled in a separate thread. - :param str host: Address to bind to. - :param int port: Port to bind to (default 6023). + :param host: Address to bind to. + :param port: Port to bind to (default 6023). :param handler: Function called for each client connection. Receives a :class:`TelnetConnection`-like object as argument. :param kwargs: Additional arguments passed to ``telnetlib3.create_server()``. @@ -437,7 +449,7 @@ def __init__( self._loop: Optional[asyncio.AbstractEventLoop] = None self._thread: Optional[threading.Thread] = None self._server: Optional[Server] = None - self._client_queue: queue.Queue = queue.Queue() + self._client_queue: queue.Queue[ServerConnection] = queue.Queue() self._started = threading.Event() self._shutdown = threading.Event() @@ -493,7 +505,7 @@ def accept(self, timeout: Optional[float] = None) -> "ServerConnection": Blocks until a client connects. - :param float timeout: Timeout in seconds (None for no timeout). + :param timeout: Timeout in seconds (None for no timeout). :returns: Connection object for the client. :raises TimeoutError: If timeout expires. :raises RuntimeError: If server not started. @@ -617,8 +629,8 @@ def read(self, n: int = -1, timeout: Optional[float] = None) -> Union[str, bytes """ Read up to n bytes/characters from the connection. - :param int n: Maximum bytes to read (-1 for any available data). - :param float timeout: Timeout in seconds. + :param n: Maximum bytes to read (-1 for any available data). + :param timeout: Timeout in seconds. :returns: Data read from connection. :raises RuntimeError: If connection already closed. :raises TimeoutError: If timeout expires. @@ -643,7 +655,7 @@ def read_some(self, timeout: Optional[float] = None) -> Union[str, bytes]: Alias for :meth:`read` for compatibility with old telnetlib. - :param float timeout: Timeout in seconds. + :param timeout: Timeout in seconds. :returns: Data read from connection. """ return self.read(-1, timeout=timeout) @@ -652,7 +664,7 @@ def readline(self, timeout: Optional[float] = None) -> Union[str, bytes]: """ Read one line from the connection. - :param float timeout: Timeout in seconds. + :param timeout: Timeout in seconds. :returns: Line including terminator. :raises RuntimeError: If connection already closed. :raises TimeoutError: If timeout expires. @@ -678,7 +690,7 @@ def read_until( Read until match is found. :param match: String or bytes to match. - :param float timeout: Timeout in seconds. + :param timeout: Timeout in seconds. :returns: Data up to and including match. :raises RuntimeError: If connection already closed. :raises TimeoutError: If timeout expires. @@ -709,13 +721,13 @@ def write(self, data: Union[str, bytes]) -> None: """ if self._closed: raise RuntimeError("Connection closed") - self._loop.call_soon_threadsafe(self._writer.write, data) + self._loop.call_soon_threadsafe(self._writer.write, data) # type: ignore[arg-type] def flush(self, timeout: Optional[float] = None) -> None: """ Flush buffered data to the connection. - :param float timeout: Timeout in seconds. + :param timeout: Timeout in seconds. :raises RuntimeError: If connection already closed. :raises TimeoutError: If timeout expires. """ @@ -747,7 +759,7 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: - ``'rows'``: Terminal height in rows - ``'peername'``: Remote address tuple (host, port) - :param str name: Information key. + :param name: Information key. :param default: Default value if key not found. :returns: Information value or default. """ @@ -755,9 +767,9 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: def wait_for( self, - remote: Optional[dict] = None, - local: Optional[dict] = None, - pending: Optional[dict] = None, + remote: Optional[dict[str, bool]] = None, + local: Optional[dict[str, bool]] = None, + pending: Optional[dict[str, bool]] = None, timeout: Optional[float] = None, ) -> None: """ @@ -831,17 +843,20 @@ def port(self) -> int: @property def terminal_type(self) -> str: """Client terminal type (miniboa-compatible).""" - return self.get_extra_info("TERM", "unknown") + result: str = self.get_extra_info("TERM", "unknown") + return result @property def columns(self) -> int: """Terminal width (miniboa-compatible).""" - return self.get_extra_info("cols", 80) + result: int = self.get_extra_info("cols", 80) + return result @property def rows(self) -> int: """Terminal height (miniboa-compatible).""" - return self.get_extra_info("rows", 24) + result: int = self.get_extra_info("rows", 24) + return result @property def connect_time(self) -> float: diff --git a/telnetlib3/telnetlib.py b/telnetlib3/telnetlib.py index 12eb639a..41f8bfa7 100644 --- a/telnetlib3/telnetlib.py +++ b/telnetlib3/telnetlib.py @@ -36,6 +36,7 @@ # Imported modules import sys import socket +import _thread import selectors from time import monotonic as _time @@ -636,9 +637,6 @@ def interact(self): def mt_interact(self): """Multithreaded version of interact().""" - # std imports - import _thread # pylint: disable=import-outside-toplevel - _thread.start_new_thread(self.listener, ()) while 1: line = sys.stdin.readline() @@ -754,5 +752,5 @@ def main(): tn.interact() -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/telnetlib3/telopt.py b/telnetlib3/telopt.py index 65513a18..b0d32426 100644 --- a/telnetlib3/telopt.py +++ b/telnetlib3/telopt.py @@ -1,5 +1,8 @@ """Telnet option constants exported from the deprecated telnetlib module.""" +# std imports +from typing import Dict + # Exported from the telnetlib module, which is marked for deprecation in version # 3.11 and removal in 3.13 LINEMODE = b'"' @@ -181,7 +184,7 @@ GMCP = bytes([201]) #: List of globals that may match an iac command option bytes -_DEBUG_OPTS = { +_DEBUG_OPTS: Dict[bytes, str] = { value: key for key, value in globals().items() if key @@ -268,25 +271,25 @@ } #: Reverse mapping of option names to option bytes -_NAME_TO_OPT = {name: opt for opt, name in _DEBUG_OPTS.items()} +_NAME_TO_OPT: Dict[str, bytes] = {name: opt for opt, name in _DEBUG_OPTS.items()} -def option_from_name(name): +def option_from_name(name: str) -> bytes: """ Return option bytes for a given option name. - :param str name: Option name (e.g., "NAWS", "TTYPE") + :param name: Option name (e.g., "NAWS", "TTYPE") :returns: Option bytes :raises KeyError: If name is not a known telnet option """ return _NAME_TO_OPT[name.upper()] -def name_command(byte): +def name_command(byte: bytes) -> str: """Return string description for (maybe) telnet command byte.""" return _DEBUG_OPTS.get(byte, repr(byte)) -def name_commands(cmds, sep=" "): +def name_commands(cmds: bytes, sep: str = " ") -> str: """Return string description for array of (maybe) telnet command bytes.""" return sep.join([name_command(bytes([byte])) for byte in cmds]) diff --git a/telnetlib3/tests/accessories.py b/telnetlib3/tests/accessories.py index 18edecb6..262f7956 100644 --- a/telnetlib3/tests/accessories.py +++ b/telnetlib3/tests/accessories.py @@ -75,7 +75,7 @@ async def connection_context(reader, writer): @contextlib.asynccontextmanager async def create_server(*args, **kwargs): """Create a telnetlib3 server with automatic cleanup.""" - # local - avoid circular import + # local import to avoid circular import # local import telnetlib3 @@ -90,10 +90,20 @@ async def create_server(*args, **kwargs): @contextlib.asynccontextmanager async def open_connection(*args, **kwargs): """Open a telnetlib3 connection with automatic cleanup.""" - # local - avoid circular import + # local import to avoid circular import # local import telnetlib3 + # Force deterministic client: TelnetTerminalClient reads the real terminal + # size via TIOCGWINSZ, ignoring cols/rows parameters. Use TelnetClient so + # tests get consistent behavior regardless of whether stdin is a TTY. + if "client_factory" not in kwargs: + # local import to avoid circular import + # local + from telnetlib3.client import TelnetClient + + kwargs["client_factory"] = TelnetClient + reader, writer = await telnetlib3.open_connection(*args, **kwargs) try: yield reader, writer diff --git a/telnetlib3/tests/pty_helper.py b/telnetlib3/tests/pty_helper.py index 9e7685e4..1891e83c 100644 --- a/telnetlib3/tests/pty_helper.py +++ b/telnetlib3/tests/pty_helper.py @@ -20,6 +20,7 @@ # std imports import os import sys +import time def cat_mode(): @@ -43,6 +44,7 @@ def echo_mode(args): def stty_size_mode(): """Print terminal size.""" + # imported locally to avoid error on import with windows systems # std imports import fcntl import struct @@ -75,9 +77,6 @@ def env_mode(args): def sleep_mode(args): """Sleep for specified seconds.""" - # std imports - import time - seconds = float(args[0]) if args else 60 time.sleep(seconds) @@ -103,9 +102,6 @@ def partial_utf8_mode(): """Output incomplete UTF-8 then complete it.""" sys.stdout.buffer.write(b"hello\xc3") sys.stdout.buffer.flush() - # std imports - import time - time.sleep(0.1) sys.stdout.buffer.write(b"\xa9world\n") sys.stdout.buffer.flush() @@ -140,5 +136,5 @@ def main(): modes[mode]() -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/telnetlib3/tests/test_accessories.py b/telnetlib3/tests/test_accessories.py index 04017b2a..4092fabc 100644 --- a/telnetlib3/tests/test_accessories.py +++ b/telnetlib3/tests/test_accessories.py @@ -48,12 +48,21 @@ def test_encoding_from_lang(): "en_US.UTF-8": "UTF-8", "abc.def": "def", ".def@ghi": "def", - "def@": "def", - "UTF-8": "UTF-8", } for given, expected in sorted(given_expected.items()): - # exercise, result = encoding_from_lang(given) + assert result == expected - # verify, + +def test_encoding_from_lang_no_encoding(): + """Test LANG values without encoding suffix return None.""" + given_expected = { + "en_IL": None, + "en_US": None, + "C": None, + "POSIX": None, + "UTF-8": None, + } + for given, expected in sorted(given_expected.items()): + result = encoding_from_lang(given) assert result == expected diff --git a/telnetlib3/tests/test_benchmarks.py b/telnetlib3/tests/test_benchmarks.py index aae71476..b51a7b4f 100644 --- a/telnetlib3/tests/test_benchmarks.py +++ b/telnetlib3/tests/test_benchmarks.py @@ -192,6 +192,7 @@ async def shell(reader, writer): encoding=False, connect_minwait=0.05, connect_maxwait=0.1, + client_factory=telnetlib3.TelnetClient, ) await server_ready.wait() diff --git a/telnetlib3/tests/test_charset.py b/telnetlib3/tests/test_charset.py index daa54bb2..08f6e9b7 100644 --- a/telnetlib3/tests/test_charset.py +++ b/telnetlib3/tests/test_charset.py @@ -22,11 +22,13 @@ from telnetlib3.stream_writer import TelnetWriter from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, + create_server, + asyncio_server, + open_connection, unused_tcp_port, + asyncio_connection, ) -# local imports - # --- Common Mock Classes --- @@ -95,9 +97,6 @@ def send_charset(self, offered): async def test_telnet_server_on_charset(bind_host, unused_tcp_port): """Test Server's callback method on_charset().""" - # local - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_charset = "KOI8-U" @@ -124,9 +123,6 @@ def on_charset(self, charset): async def test_telnet_client_send_charset(bind_host, unused_tcp_port): """Test Client's callback method send_charset() selection for illegals.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - _waiter = asyncio.Future() server_instance = {"protocol": None} @@ -165,9 +161,6 @@ def send_charset(self, offered): async def test_telnet_client_no_charset(bind_host, unused_tcp_port): """Test Client's callback method send_charset() does not select.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - _waiter = asyncio.Future() server_instance = {"protocol": None} @@ -418,9 +411,6 @@ def test_unit_charset_negotiation_sequence(): async def test_charset_send_unknown_encoding(bind_host, unused_tcp_port): """Test client with unknown encoding value.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection - async with asyncio_server(asyncio.Protocol, bind_host, unused_tcp_port): async with open_connection( client_factory=lambda **kwargs: CustomTelnetClient( @@ -435,9 +425,6 @@ async def test_charset_send_unknown_encoding(bind_host, unused_tcp_port): async def test_charset_send_no_viable_offers(bind_host, unused_tcp_port): """Test client with no viable encoding offers.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection - async with asyncio_server(asyncio.Protocol, bind_host, unused_tcp_port): async with open_connection( client_factory=lambda **kwargs: CustomTelnetClient( @@ -453,9 +440,6 @@ async def test_charset_send_no_viable_offers(bind_host, unused_tcp_port): async def test_charset_explicit_non_latin1_encoding(bind_host, unused_tcp_port): """Test client rejecting offered encodings when explicit non-latin1 is set.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection - async with asyncio_server(asyncio.Protocol, bind_host, unused_tcp_port): async with open_connection( client_factory=lambda **kwargs: CustomTelnetClient( diff --git a/telnetlib3/tests/test_client_unit.py b/telnetlib3/tests/test_client_unit.py new file mode 100644 index 00000000..7ec1b8ae --- /dev/null +++ b/telnetlib3/tests/test_client_unit.py @@ -0,0 +1,215 @@ +# std imports +import sys + +# 3rd party +import pytest + +# local +from telnetlib3 import client as cl +from telnetlib3.tests.accessories import ( # noqa: F401 # pylint: disable=unused-import + bind_host, + create_server, +) + +_CLIENT_DEFAULTS = { + "encoding": "utf8", + "encoding_errors": "strict", + "force_binary": False, + "connect_minwait": 0.01, + "connect_maxwait": 0.02, +} + + +def _make_client(**kwargs): + return cl.TelnetClient(**{**_CLIENT_DEFAULTS, **kwargs}) + + +def _make_terminal_client(**kwargs): + return cl.TelnetTerminalClient(**{**_CLIENT_DEFAULTS, **kwargs}) + + +@pytest.mark.parametrize( + "offered,encoding,expected", + [ + pytest.param(["utf-8"], "utf8", "utf-8", id="exact_match"), + pytest.param(["latin-1"], "utf8", "", id="no_match_reject"), + pytest.param(["utf-8", "latin-1"], "latin-1", "latin-1", id="latin1_exact_match"), + pytest.param(["utf-8"], False, "utf-8", id="no_encoding_accepts_viable"), + pytest.param(["utf-8"], "not-a-real-encoding-xyz", "utf-8", id="unknown_encoding"), + pytest.param( + ["iso-8859-1", "utf-8"], + "not-a-real-encoding-xyz", + "iso-8859-1", + id="no_pref_first_viable", + ), + pytest.param( + ["zzz-fake-1", "zzz-fake-2"], + "not-a-real-encoding-xyz", + "", + id="no_viable_encodings", + ), + pytest.param(["utf-8"], "latin-1", "utf-8", id="latin1_weak_default"), + ], +) +@pytest.mark.asyncio +async def test_send_charset(offered, encoding, expected): + c = _make_client(encoding=encoding) + assert c.send_charset(offered) == expected + + +@pytest.mark.asyncio +async def test_send_charset_null_default(): + c = _make_client() + c.default_encoding = None + assert not c.send_charset(["zzz-fake-1"]) + assert c.send_charset(["utf-8"]) == "utf-8" + + +@pytest.mark.asyncio +async def test_send_env(): + c = _make_client(term="xterm", cols=132, rows=43) + env = c.send_env(["TERM", "LANG"]) + assert env["TERM"] == "xterm" + assert "utf8" in env["LANG"] + + c2 = _make_client() + env2 = c2.send_env([]) + assert "TERM" in env2 and "LANG" in env2 + + +@pytest.mark.asyncio +async def test_send_naws(): + assert _make_client(rows=24, cols=80).send_naws() == (24, 80) + + +@pytest.mark.asyncio +async def test_send_ttype(): + assert _make_client(term="vt220").send_ttype() == "vt220" + + +@pytest.mark.asyncio +async def test_send_tspeed(): + assert _make_client(tspeed=(9600, 9600)).send_tspeed() == (9600, 9600) + + +@pytest.mark.asyncio +async def test_send_xdisploc(): + assert _make_client(xdisploc="myhost:0.0").send_xdisploc() == "myhost:0.0" + + +@pytest.mark.skipif(sys.platform == "win32", reason="requires fcntl") +def test_terminal_client_winsize_success(monkeypatch): + # std imports + import fcntl + import struct + + fake_data = struct.pack("hhhh", 42, 120, 0, 0) + monkeypatch.setattr(fcntl, "ioctl", lambda fd, req, buf: fake_data) + assert cl.TelnetTerminalClient._winsize() == (42, 120) + + +@pytest.mark.skipif(sys.platform == "win32", reason="requires fcntl") +def test_terminal_client_winsize_ioerror(monkeypatch): + # std imports + import fcntl + + monkeypatch.setenv("LINES", "30") + monkeypatch.setenv("COLUMNS", "100") + + def _raise(*args, **kwargs): + raise IOError("not a tty") + + monkeypatch.setattr(fcntl, "ioctl", _raise) + assert cl.TelnetTerminalClient._winsize() == (30, 100) + + +@pytest.mark.skipif(sys.platform == "win32", reason="requires fcntl") +@pytest.mark.asyncio +async def test_terminal_client_send_naws(monkeypatch): + # std imports + import fcntl + + monkeypatch.setenv("LINES", "48") + monkeypatch.setenv("COLUMNS", "160") + monkeypatch.setattr(fcntl, "ioctl", lambda *a, **kw: (_ for _ in ()).throw(IOError)) + assert _make_terminal_client().send_naws() == (48, 160) + + +@pytest.mark.skipif(sys.platform == "win32", reason="requires fcntl") +@pytest.mark.asyncio +async def test_terminal_client_send_env(monkeypatch): + # std imports + import fcntl + + def _raise(*args, **kwargs): + raise IOError("not a tty") + + monkeypatch.setenv("LINES", "48") + monkeypatch.setenv("COLUMNS", "160") + monkeypatch.setattr(fcntl, "ioctl", _raise) + env = _make_terminal_client().send_env(["LINES", "COLUMNS"]) + assert env["LINES"] == 48 and env["COLUMNS"] == 160 + + +def test_argument_parser(): + parser = cl._get_argument_parser() + args = parser.parse_args(["example.com", "2323"]) + assert args.host == "example.com" and args.port == 2323 and args.encoding == "utf8" + + defaults = parser.parse_args(["myhost"]) + assert defaults.port == 23 and defaults.force_binary is True and defaults.speed == 38400 + + +def test_transform_args(): + parser = cl._get_argument_parser() + result = cl._transform_args( + parser.parse_args(["myhost", "5555", "--encoding", "latin-1", "--speed", "9600"]) + ) + assert result["host"] == "myhost" and result["port"] == 5555 + assert result["encoding"] == "latin-1" and result["tspeed"] == (9600, 9600) + assert callable(result["shell"]) and "TERM" in result["send_environ"] + + result2 = cl._transform_args(parser.parse_args(["host", "--send-environ", "TERM,LANG"])) + assert result2["send_environ"] == ("TERM", "LANG") + + +@pytest.mark.asyncio +async def test_open_connection_default_factory(bind_host, unused_tcp_port, monkeypatch): + monkeypatch.setattr(sys.stdin, "isatty", lambda: False) + + async with create_server( + host=bind_host, + port=unused_tcp_port, + connect_maxwait=0.05, + ): + reader, writer = await cl.open_connection( + host=bind_host, + port=unused_tcp_port, + connect_minwait=0.05, + connect_maxwait=0.1, + encoding=False, + ) + assert isinstance(writer.protocol, cl.TelnetClient) + assert not isinstance(writer.protocol, cl.TelnetTerminalClient) + writer.close() + + +@pytest.mark.skipif(sys.platform == "win32", reason="TTY factory not used on win32") +@pytest.mark.asyncio +async def test_open_connection_tty_factory(bind_host, unused_tcp_port, monkeypatch): + monkeypatch.setattr(sys.stdin, "isatty", lambda: True) + + async with create_server( + host=bind_host, + port=unused_tcp_port, + connect_maxwait=0.05, + ): + reader, writer = await cl.open_connection( + host=bind_host, + port=unused_tcp_port, + connect_minwait=0.05, + connect_maxwait=0.1, + encoding=False, + ) + assert isinstance(writer.protocol, cl.TelnetTerminalClient) + writer.close() diff --git a/telnetlib3/tests/test_core.py b/telnetlib3/tests/test_core.py index 4014e26f..f60acfb2 100644 --- a/telnetlib3/tests/test_core.py +++ b/telnetlib3/tests/test_core.py @@ -13,28 +13,41 @@ import pexpect # local -# local imports import telnetlib3 +from telnetlib3.telopt import ( + DO, + IS, + SB, + SE, + IAC, + SGA, + ECHO, + NAWS, + WILL, + WONT, + TTYPE, + BINARY, + CHARSET, + NEW_ENVIRON, +) from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, + create_server, + asyncio_server, + open_connection, unused_tcp_port, + asyncio_connection, ) async def test_create_server(bind_host, unused_tcp_port): """Test telnetlib3.create_server basic instantiation.""" - # local - from telnetlib3.tests.accessories import create_server - async with create_server(host=bind_host, port=unused_tcp_port): pass async def test_create_server_conditionals(bind_host, unused_tcp_port): """Test telnetlib3.create_server conditionals.""" - # local - from telnetlib3.tests.accessories import create_server - async with create_server( protocol_factory=lambda: telnetlib3.TelnetServer, host=bind_host, @@ -45,9 +58,6 @@ async def test_create_server_conditionals(bind_host, unused_tcp_port): async def test_create_server_on_connect(bind_host, unused_tcp_port): """Test on_connect() anonymous function callback of create_server.""" - # local - from telnetlib3.tests.accessories import create_server, asyncio_connection - call_tracker = {"called": False, "transport": None} class TrackingProtocol(asyncio.Protocol): @@ -71,10 +81,6 @@ def connection_made(self, transport): async def test_telnet_server_open_close(bind_host, unused_tcp_port): """Test telnetlib3.TelnetServer() instantiation and connection_made().""" - # local - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server(host=bind_host, port=unused_tcp_port) as server: async with asyncio_connection(bind_host, unused_tcp_port) as ( stream_reader, @@ -92,9 +98,6 @@ async def test_telnet_server_open_close(bind_host, unused_tcp_port): async def test_telnet_client_open_close_by_write(bind_host, unused_tcp_port): """Exercise BaseClient.connection_lost() on writer closed.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection - async with asyncio_server(asyncio.Protocol, bind_host, unused_tcp_port): async with open_connection(host=bind_host, port=unused_tcp_port, connect_minwait=0.05) as ( reader, @@ -108,8 +111,6 @@ async def test_telnet_client_open_close_by_write(bind_host, unused_tcp_port): async def test_telnet_client_open_closed_by_peer(bind_host, unused_tcp_port): """Exercise BaseClient.connection_lost().""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection class DisconnecterProtocol(asyncio.Protocol): def connection_made(self, transport): @@ -128,22 +129,6 @@ def connection_made(self, transport): async def test_telnet_server_advanced_negotiation(bind_host, unused_tcp_port): """Test telnetlib3.TelnetServer() advanced negotiation.""" - # local - from telnetlib3.telopt import ( - DO, - SB, - IAC, - SGA, - ECHO, - NAWS, - WILL, - TTYPE, - BINARY, - CHARSET, - NEW_ENVIRON, - ) - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() class ServerTestAdvanced(telnetlib3.TelnetServer): @@ -165,7 +150,7 @@ def begin_advanced_negotiation(self): # server's request for TTYPE value unreplied SB + TTYPE: True, # remaining unreplied values from begin_advanced_negotiation() - DO + NEW_ENVIRON: True, + # DO NEW_ENVIRON is deferred until TTYPE cycle completes DO + CHARSET: True, DO + NAWS: True, WILL + SGA: True, @@ -176,10 +161,6 @@ def begin_advanced_negotiation(self): async def test_telnet_server_closed_by_client(bind_host, unused_tcp_port): """Exercise TelnetServer.connection_lost.""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server(host=bind_host, port=unused_tcp_port) as server: async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): # Read server's negotiation request and send minimal reply @@ -204,10 +185,6 @@ async def test_telnet_server_closed_by_client(bind_host, unused_tcp_port): async def test_telnet_server_eof_by_client(bind_host, unused_tcp_port): """Exercise TelnetServer.eof_received().""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server(host=bind_host, port=unused_tcp_port) as server: async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): # Read server's negotiation request and send minimal reply @@ -231,10 +208,6 @@ async def test_telnet_server_eof_by_client(bind_host, unused_tcp_port): async def test_telnet_server_closed_by_server(bind_host, unused_tcp_port): """Exercise TelnetServer.connection_lost by close().""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -267,10 +240,6 @@ async def test_telnet_server_closed_by_server(bind_host, unused_tcp_port): async def test_telnet_server_idle_duration(bind_host, unused_tcp_port): """Exercise TelnetServer.idle property.""" - # local - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -285,9 +254,6 @@ async def test_telnet_server_idle_duration(bind_host, unused_tcp_port): async def test_telnet_client_idle_duration_minwait(bind_host, unused_tcp_port): """Exercise TelnetClient.idle property and minimum connection time.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection - async with asyncio_server(asyncio.Protocol, bind_host, unused_tcp_port): given_minwait = 0.100 @@ -308,10 +274,6 @@ async def test_telnet_client_idle_duration_minwait(bind_host, unused_tcp_port): async def test_telnet_server_closed_by_error(bind_host, unused_tcp_port): """Exercise TelnetServer.connection_lost by exception.""" - # local - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -332,8 +294,6 @@ class CustomException(Exception): async def test_telnet_client_open_close_by_error(bind_host, unused_tcp_port): """Exercise BaseClient.connection_lost() on error.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection class GivenException(Exception): pass @@ -350,10 +310,6 @@ class GivenException(Exception): async def test_telnet_server_negotiation_fail(bind_host, unused_tcp_port): """Test telnetlib3.TelnetServer() negotiation failure with client.""" - # local - from telnetlib3.telopt import DO, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -376,14 +332,9 @@ async def test_telnet_server_negotiation_fail(bind_host, unused_tcp_port): async def test_telnet_client_negotiation_fail(bind_host, unused_tcp_port): """Test telnetlib3.TelnetCLient() negotiation failure with server.""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection class ClientNegotiationFail(telnetlib3.TelnetClient): def connection_made(self, transport): - # local - from telnetlib3.telopt import WILL, TTYPE - super().connection_made(transport) self.writer.iac(WILL, TTYPE) @@ -421,9 +372,6 @@ async def test_telnet_server_as_module(): @pytest.mark.skipif(sys.platform == "win32", reason="Signal handlers not supported on Windows") async def test_telnet_server_cmdline(bind_host, unused_tcp_port): """Test executing telnetlib3/server.py as server.""" - # local - from telnetlib3.tests.accessories import asyncio_connection - prog = pexpect.which("telnetlib3-server") args = [ prog, @@ -476,9 +424,6 @@ async def test_telnet_client_as_module(): @pytest.mark.skipif(sys.platform == "win32", reason="Client shell not implemented on Windows") async def test_telnet_client_cmdline(bind_host, unused_tcp_port): """Test executing telnetlib3/client.py as client.""" - # local - from telnetlib3.tests.accessories import asyncio_server - prog = pexpect.which("telnetlib3-client") args = [ prog, @@ -497,7 +442,10 @@ def connection_made(self, transport): async with asyncio_server(HelloServer, bind_host, unused_tcp_port): proc = await asyncio.create_subprocess_exec( - *args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) line = await asyncio.wait_for(proc.stdout.readline(), 1.5) @@ -508,6 +456,9 @@ def connection_made(self, transport): out, err = await asyncio.wait_for(proc.communicate(), 1) assert out == b"\x1b[m\nConnection closed by foreign host.\n" + stderr_text = err.decode() + assert "Connected to = len(self._data): + return "" + result = self._data[self._idx] + self._idx += 1 + return result + + +class MockTerm: + normal = "" + clear = "" + civis = "" + cnorm = "" + height = 50 + width = 80 + forestgreen = staticmethod(lambda x: x) + firebrick1 = staticmethod(lambda x: x) + darkorange = staticmethod(lambda x: x) + bold_magenta = "" + + def magenta(self, s): + return s + + def cyan(self, s): + return s + + def clear_eol(self): + return "" + + +@pytest.mark.asyncio +async def test_probe_client_capabilities(): + options = [(fps.BINARY, "BINARY", ""), (fps.SGA, "SGA", "")] + writer = MockWriter(will_options=[fps.BINARY], wont_options=[fps.SGA]) + results = await fps.probe_client_capabilities(writer, options=options, timeout=0.001) + assert results["BINARY"]["status"] == "WILL" + assert results["SGA"]["status"] == "WONT" + + +def test_save_fingerprint_data(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + writer = MockWriter(extra={"peername": ("10.0.0.1", 9999), "TERM": "xterm"}) + writer._protocol = MockProtocol({"TERM": "xterm", "ttype1": "xterm", "ttype2": "xterm-256"}) + probe = { + "BINARY": {"status": "WILL", "opt": fps.BINARY}, + "SGA": {"status": "WONT", "opt": fps.SGA}, + } + filepath = fps._save_fingerprint_data(writer, probe, 0.5) + assert filepath is not None and Path(filepath).exists() + + with open(filepath, encoding="utf-8") as f: + data = json.load(f) + tp = data["telnet-probe"] + assert tp["fingerprint-data"]["probed-protocol"] == "client" + assert "BINARY" in tp["fingerprint-data"]["supported-options"] + assert tp["session_data"]["ttype_cycle"] == ["xterm", "xterm-256"] + assert "peername" not in tp["session_data"]["extra"] + assert data["sessions"][0]["ip"] == "10.0.0.1" + assert Path(filepath).parent.name == fps._UNKNOWN_TERMINAL_HASH + + monkeypatch.setattr(fps, "DATA_DIR", None) + assert fps._save_fingerprint_data(writer, {}, 0.5) is None + + +def test_save_fingerprint_appends_session(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + writer = MockWriter(extra={"peername": ("10.0.0.1", 9999), "TERM": "xterm"}) + writer._protocol = MockProtocol({"TERM": "xterm"}) + + fp1 = fps._save_fingerprint_data(writer, _BINARY_PROBE, 0.5) + fp2 = fps._save_fingerprint_data(writer, _BINARY_PROBE, 0.5) + assert fp1 == fp2 + with open(fp2, encoding="utf-8") as f: + assert len(json.load(f)["sessions"]) == 2 + + +def test_protocol_fingerprint(): + w = MockWriter(extra={"TERM": "xterm", "HOME": "/home/user"}) + w._protocol = MockProtocol({"HOME": "/home/user"}) + assert fps._create_protocol_fingerprint(w, _BINARY_PROBE)["HOME"] == "True" + + probe2 = { + "TTYPE": {"status": "WILL", "opt": fps.TTYPE}, + "BINARY": {"status": "WILL", "opt": fps.BINARY}, + "SGA": {"status": "WONT", "opt": fps.SGA}, + } + fp = fps._create_protocol_fingerprint(MockWriter(), probe2) + assert fp["supported-options"] == ["BINARY", "TTYPE"] + assert fp["refused-options"] == ["SGA"] + + +def test_protocol_hash_consistency(): + w1 = MockWriter(extra={"TERM": "xterm", "HOME": "/home/alice"}) + w1._protocol = MockProtocol({"HOME": "/home/alice"}) + w2 = MockWriter(extra={"TERM": "xterm", "HOME": "/home/bob"}) + w2._protocol = MockProtocol({"HOME": "/home/bob"}) + + h1 = fps._hash_fingerprint(fps._create_protocol_fingerprint(w1, _BINARY_PROBE)) + h2 = fps._hash_fingerprint(fps._create_protocol_fingerprint(w2, _BINARY_PROBE)) + assert h1 == h2 and len(h1) == 16 + + +@pytest.mark.parametrize( + "text,expected", + [ + ("Ghostty", "Ghostty"), + (" Ghostty ", "Ghostty"), + ("", None), + (" ", None), + ("bad\x00name", None), + ("bad\x1bname", None), + ("bad\x7fname", None), + ("good name 123", "good name 123"), + ], +) +def test_validate_suggestion(text, expected): + assert fps._validate_suggestion(text) == expected + + +@requires_unix +def test_prompt_stores_suggestions(tmp_path, monkeypatch, capsys): + filepath = tmp_path / "test.json" + data = { + "telnet-probe": {"fingerprint": "aaa"}, + "terminal-probe": {"fingerprint": "bbbb"}, + "sessions": [], + } + filepath.write_text(json.dumps(data)) + + inputs = iter(["Ghostty", "GNU Telnet"]) + # pylint: disable=possibly-used-before-assignment + monkeypatch.setattr(fpd, "_cooked_input", lambda prompt: next(inputs)) + fpd._prompt_fingerprint_identification(MockTerm(), data, str(filepath), {}) + assert data["suggestions"]["terminal-emulator"] == "Ghostty" + assert data["suggestions"]["telnet-client"] == "GNU Telnet" + assert "Help our database!" in capsys.readouterr().out + + with open(filepath, encoding="utf-8") as f: + assert json.load(f)["suggestions"]["terminal-emulator"] == "Ghostty" + + +@requires_unix +def test_prompt_stores_revision(tmp_path, monkeypatch, capsys): + filepath = tmp_path / "test.json" + data = { + "telnet-probe": {"fingerprint": "aaa"}, + "terminal-probe": {"fingerprint": "bbbb"}, + "sessions": [], + } + filepath.write_text(json.dumps(data)) + + inputs = iter(["Ghostty2", "inetutils-2.5"]) + monkeypatch.setattr(fpd, "_cooked_input", lambda prompt: next(inputs)) + names = {"aaa": "GNU Telnet", "bbbb": "Ghostty"} + fpd._prompt_fingerprint_identification(MockTerm(), data, str(filepath), names) + assert data["suggestions"]["terminal-emulator-revision"] == "Ghostty2" + assert data["suggestions"]["telnet-client-revision"] == "inetutils-2.5" + captured = capsys.readouterr() + assert "Suggest a revision" in captured.out + assert "Your submission is under review." in captured.out + + +@requires_unix +@pytest.mark.asyncio +async def test_server_shell(monkeypatch): + monkeypatch.setattr(fps.asyncio, "sleep", _noop) + monkeypatch.setattr(fps, "DATA_DIR", None) + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.05) + + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "TERM": "xterm"}, will_options=[fps.BINARY] + ) + await fps.fingerprinting_server_shell(MockReader([]), writer) + assert writer._closing + + +@requires_unix +def test_create_terminal_fingerprint(): + terminal_data = { + "software_name": "foot", + "software_version": "1.16.2", + "ambiguous_width": 1, + "terminal_results": { + "number_of_colors": 16777216, + "sixel": True, + "kitty_graphics": False, + "kitty_clipboard_protocol": False, + "iterm2_features": {"supported": False, "features": {}}, + "kitty_keyboard": { + "disambiguate": False, + "report_all_keys": False, + "report_alternates": False, + "report_events": False, + "report_text": False, + }, + "kitty_notifications": False, + "kitty_pointer_shapes": False, + "text_sizing": {"width": False, "scale": False}, + "device_attributes": {"service_class": 62, "extensions": [22, 4]}, + "modes": { + "2027": { + "supported": True, + "changeable": True, + "enabled": True, + "value": 1, + "mode_name": "GRAPHEME_CLUSTERING", + "mode_description": "Grapheme Clustering", + "value_description": "SET", + }, + "5522": { + "supported": False, + "changeable": False, + "enabled": False, + "value": 0, + "mode_name": "UNKNOWN", + "mode_description": "Unknown mode", + "value_description": "NOT_RECOGNIZED", + }, + }, + "xtgettcap": {"supported": True, "capabilities": {"TN": "foot", "Co": "256"}}, + "foreground_color_hex": "#ffffffffffff", + "cell_width": 6, + "cell_height": 16, + "width": 170, + "height": 46, + }, + "test_results": { + "unicode_wide_results": { + "17.0.0": {"n_errors": 0, "n_total": 10, "pct_success": 100.0, "cps": 8.7}, + }, + "emoji_vs16_results": { + "9.0.0": {"n_errors": 2, "n_total": 12, "pct_success": 83.3, "cps": 9.2}, + }, + "language_results": None, + }, + } + + fp = fpd._create_terminal_fingerprint(terminal_data) + assert fp["software_name"] == "foot" and fp["software_version"] == "1.16.2" + assert fp["number_of_colors"] == 16777216 and fp["sixel"] is True + assert fp["kitty_graphics"] is False and fp["kitty_clipboard_protocol"] is False + assert fp["iterm2_features"] == {"supported": False, "features": {}} + assert fp["kitty_keyboard"]["disambiguate"] is False + assert fp["kitty_notifications"] is False and fp["kitty_pointer_shapes"] is False + assert fp["text_sizing"] == {"width": False, "scale": False} + assert fp["da_service_class"] == 62 and fp["da_extensions"] == [4, 22] + assert fp["ambiguous_width"] == 1 + assert fp["modes"]["2027"] == { + "supported": True, + "changeable": True, + "enabled": True, + "value": 1, + } + assert fp["modes"]["5522"]["supported"] is False + assert "mode_name" not in fp["modes"]["2027"] + assert fp["xtgettcap"]["supported"] is True + assert fp["xtgettcap"]["capabilities"]["TN"] == "foot" + assert fp["test_results"]["unicode_wide_results"] == { + "unicode_version": "17.0.0", + "n_errors": 0, + "n_total": 10, + } + assert fp["test_results"]["emoji_vs16_results"] == { + "unicode_version": "9.0.0", + "n_errors": 2, + "n_total": 12, + } + assert "language_results" not in fp["test_results"] + for key in ("foreground_color_hex", "cell_width", "width", "height"): + assert key not in fp + + +@requires_unix +def test_terminal_fingerprint_hash_excludes_session_vars(): + base = { + "software_name": "foot", + "software_version": "1.16.2", + "ambiguous_width": 1, + "terminal_results": { + "number_of_colors": 16777216, + "sixel": True, + "kitty_graphics": False, + "kitty_clipboard_protocol": False, + "device_attributes": {"service_class": 62, "extensions": [4, 22]}, + "modes": {}, + "xtgettcap": {"supported": True, "capabilities": {"TN": "foot"}}, + "text_sizing": {"width": False, "scale": False}, + "foreground_color_hex": "#000000000000", + "width": 80, + "height": 24, + "cell_width": 6, + "cell_height": 16, + }, + "test_results": {}, + } + data1 = copy.deepcopy(base) + data2 = copy.deepcopy(base) + data2["terminal_results"]["foreground_color_hex"] = "#ffffffffffff" + data2["terminal_results"]["width"] = 200 + data2["terminal_results"]["height"] = 50 + + fp1 = fpd._create_terminal_fingerprint(data1) + fp2 = fpd._create_terminal_fingerprint(data2) + assert fps._hash_fingerprint(fp1) == fps._hash_fingerprint(fp2) + + +@pytest.mark.asyncio +async def test_fingerprint_probe_integration(bind_host, unused_tcp_port): + + async with create_server( + host=bind_host, + port=unused_tcp_port, + shell=fps.fingerprinting_server_shell, + connect_maxwait=0.5, + ): + async with open_connection( + host=bind_host, + port=unused_tcp_port, + connect_minwait=0.2, + connect_maxwait=0.5, + ) as (reader, writer): + try: + await asyncio.wait_for(reader.read(100), timeout=1.0) + except asyncio.TimeoutError: + pass + + +@pytest.mark.parametrize( + "ttype1,ttype2,expected", + [ + ("ANSI", "VT100", True), + ("ANSI", "", True), + ("ANSI", None, True), + ("ansi", "vt100", True), + ("xterm", "xterm-256color", False), + ("ANSI", "xterm", False), + ("VT100", "ANSI", False), + ("TINTIN++", "xterm-ghostty", False), + ], +) +def test_is_maybe_ms_telnet(ttype1, ttype2, expected): + extra = {"peername": ("127.0.0.1", 12345)} + if ttype1 is not None: + extra["ttype1"] = ttype1 + if ttype2 is not None: + extra["ttype2"] = ttype2 + assert fps._is_maybe_ms_telnet(MockWriter(extra=extra)) is expected + + +@pytest.mark.asyncio +async def test_run_probe_ms_telnet_reduced(monkeypatch): + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.05) + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "ttype1": "ANSI", "ttype2": "VT100"}, + wont_options=[fps.BINARY, fps.SGA], + ) + results, elapsed = await fps._run_probe(writer, verbose=False) + probed_names = set(results.keys()) + legacy_names = {name for _, name, _ in fps.LEGACY_OPTIONS} + assert not probed_names.intersection(legacy_names) + assert "NEW_ENVIRON" not in probed_names + + +@pytest.mark.asyncio +async def test_run_probe_normal_client_full(monkeypatch): + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.05) + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "ttype1": "xterm", "ttype2": "xterm-256color"}, + wont_options=[fps.BINARY, fps.SGA], + ) + results, elapsed = await fps._run_probe(writer, verbose=False) + probed_names = set(results.keys()) + legacy_names = {name for _, name, _ in fps.LEGACY_OPTIONS} + assert probed_names.issuperset(legacy_names) and "NEW_ENVIRON" in probed_names + + +def _make_ttype_data(ttype_cycle): + return {"telnet-probe": {"session_data": {"ttype_cycle": ttype_cycle}}} + + +@requires_unix +@pytest.mark.parametrize( + "ttype_cycle,expected_term", + [ + (["ANSI", "VT100", "VT52", "VTNT", "VTNT"], "ansi"), + (["ANSI", "ANSI"], "xterm-256color"), + (["xterm-256color", "xterm-256color"], "xterm-256color"), + ([], "xterm-256color"), + ], +) +def test_setup_term_environ_ms_telnet(ttype_cycle, expected_term, monkeypatch): + monkeypatch.setenv("TERM", "xterm-256color") + fpd._setup_term_environ(_make_ttype_data(ttype_cycle)) + assert os.environ["TERM"] == expected_term + + +@requires_unix +def test_setup_term_environ_no_ttype_cycle(monkeypatch): + monkeypatch.setenv("TERM", "vt220") + fpd._setup_term_environ({}) + assert os.environ["TERM"] == "vt220" + + +@requires_unix +@pytest.mark.parametrize( + "probe,expected", + [ + ({"WILL": {"SGA": 3}}, False), + ({"WONT": {"SGA": 3}}, True), + ({"timeout": {"SGA": 3}}, True), + ({}, True), + ], +) +def test_client_requires_ga(probe, expected): + data = {"telnet-probe": {"session_data": {"probe": probe}}} + assert fpd._client_requires_ga(data) is expected + + +@requires_unix +def test_client_requires_ga_missing_keys(): + assert fpd._client_requires_ga({}) is True + assert fpd._client_requires_ga({"telnet-probe": {}}) is True + + +@requires_unix +def test_run_ucs_detect_timeout(monkeypatch, capsys): + + def fake_run(*args, **kwargs): + raise subprocess.TimeoutExpired(cmd="ucs-detect", timeout=20) + + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/ucs-detect") + monkeypatch.setattr("subprocess.run", fake_run) + assert fpd._run_ucs_detect() is None + assert capsys.readouterr().out.endswith("...\r\n") + + +@pytest.mark.asyncio +async def test_probe_default_options(monkeypatch): + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.01) + writer = MockWriter(wont_options=[fps.BINARY]) + results = await fps.probe_client_capabilities(writer, timeout=0.01) + assert "BINARY" in results and len(results) == len(fps.ALL_PROBE_OPTIONS) + + +@pytest.mark.parametrize( + "opt,value,name,expected_status", + [ + pytest.param(fps.SGA, False, "SGA", "WONT", id="already_wont"), + pytest.param(fps.BINARY, True, "BINARY", "WILL", id="already_will"), + ], +) +@pytest.mark.asyncio +async def test_probe_already_negotiated(opt, value, name, expected_status): + writer = MockWriter() + writer.remote_option[opt] = value + results = await fps.probe_client_capabilities( + writer, + options=[(opt, name, "test")], + timeout=0.01, + ) + assert results[name]["status"] == expected_status + assert results[name]["already_negotiated"] is True + + +@pytest.mark.asyncio +async def test_probe_with_progress_callback(): + called = [] + writer = MockWriter(will_options=[fps.BINARY]) + await fps.probe_client_capabilities( + writer, + options=[(fps.BINARY, "BINARY", "test")], + timeout=0.01, + progress_callback=lambda name, idx, total, status: called.append((name, idx)), + ) + assert called == [("BINARY", 1)] + + +def test_get_client_fingerprint(): + writer = MockWriter( + extra={ + "TERM": "xterm-256color", + "peername": ("10.0.0.1", 5555), + "charset": "utf-8", + "LANG": "en_US.UTF-8", + "ttype1": "xterm", + "USER": "testuser", + } + ) + fp = fps.get_client_fingerprint(writer) + assert fp["TERM"] == "xterm-256color" and fp["peername"] == ("10.0.0.1", 5555) + assert fp["charset"] == "utf-8" and fp["ttype1"] == "xterm" and fp["USER"] == "testuser" + + writer2 = MockWriter(extra={"peername": None}) + writer2._extra = {"peername": None} + assert not fps.get_client_fingerprint(writer2) + + +@pytest.mark.asyncio +async def test_run_probe_verbose(monkeypatch): + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.01) + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "ttype1": "xterm"}, + wont_options=[fps.BINARY], + ) + await fps._run_probe(writer, verbose=True) + written = "".join(writer.written) + assert "Probing" in written and "\r\x1b[K" in written + + +@pytest.mark.asyncio +async def test_run_probe_mud_extended(monkeypatch): + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.01) + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "TERM": "mudlet"}, + wont_options=[fps.BINARY, fps.GMCP], + ) + results, _ = await fps._run_probe(writer, verbose=False) + assert "GMCP" in results + + +@pytest.mark.parametrize( + "input_val,expected", + [ + pytest.param(fps.BINARY, "BINARY", id="known_binary"), + pytest.param(fps.SGA, "SGA", id="known_sga"), + pytest.param(b"\xfe", "0xfe", id="unknown_bytes"), + pytest.param(42, "42", id="int"), + pytest.param("", "", id="empty_str"), + ], +) +def test_opt_byte_to_name(input_val, expected): + assert fps._opt_byte_to_name(input_val) == expected + + +def test_collect_rejected_options_with_data(): + writer = MockWriter() + writer.rejected_will = {fps.BINARY, fps.SGA} + writer.rejected_do = {fps.ECHO} + result = fps._collect_rejected_options(writer) + assert len(result["will"]) == 2 and len(result["do"]) == 1 + + +def test_collect_extra_info_tuples_and_bytes(): + writer = MockWriter(extra={"peername": ("1.2.3.4", 99)}) + writer._protocol = MockProtocol( + { + "tspeed": (38400, 38400), + "raw_data": b"\x01\x02\x03", + "name": "test", + } + ) + result = fps._collect_extra_info(writer) + assert result["tspeed"] == [38400, 38400] + assert result["raw_data"] == "010203" and result["name"] == "test" + + +def test_collect_extra_info_removes_duplicate_keys(): + writer = MockWriter(extra={}) + writer._protocol = MockProtocol( + { + "TERM": "xterm", + "term": "xterm", + "COLUMNS": 80, + "cols": 80, + "LINES": 24, + "rows": 24, + "ttype1": "xterm", + } + ) + result = fps._collect_extra_info(writer) + for key in ("term", "cols", "rows", "ttype1"): + assert key not in result + assert result["TERM"] == "xterm" + + +def test_collect_ttype_cycle(): + writer = MockWriter(extra={"ttype1": "xterm", "ttype2": "xterm-256color", "ttype3": "vt100"}) + writer._protocol = MockProtocol( + { + "ttype1": "xterm", + "ttype2": "xterm-256color", + "ttype3": "vt100", + } + ) + assert fps._collect_ttype_cycle(writer) == ["xterm", "xterm-256color", "vt100"] + + writer2 = MockWriter(extra={}) + writer2._protocol = MockProtocol({}) + assert not fps._collect_ttype_cycle(writer2) + + +def test_collect_protocol_timing(): + writer = MockWriter() + writer._protocol.duration = 2.5 + writer._protocol.idle = 0.3 + writer._protocol._connect_time = 1234567890.0 + timing = fps._collect_protocol_timing(writer) + assert timing["duration"] == 2.5 and timing["idle"] == 0.3 + assert timing["connect_time"] == 1234567890.0 + + writer2 = MockWriter() + writer2._protocol = type("P", (), {})() + assert not fps._collect_protocol_timing(writer2) + + +def test_collect_slc_tab_with_data(): + + writer = MockWriter() + writer.remote_option[fps.LINEMODE] = True + tab = dict(slc.generate_slctab(slc.BSD_SLC_TAB)) + tab[slc.SLC_SYNCH] = slc.SLC(mask=slc.SLC_NOSUPPORT, value=slc.theNULL) + tab[slc.SLC_EC] = slc.SLC(mask=slc.SLC_DEFAULT, value=slc.theNULL) + tab[slc.SLC_IP] = slc.SLC(mask=slc.SLC_DEFAULT, value=b"\x04") + writer.slctab = tab + result = fps._collect_slc_tab(writer) + assert "nosupport" in result and "unset" in result and "set" in result + + +def test_collect_slc_tab_empty(): + writer = MockWriter() + writer.slctab = {"something": True} + assert not fps._collect_slc_tab(writer) + assert not fps._collect_slc_tab(MockWriter()) + + +@pytest.mark.parametrize( + "extra,expected_term,expected_encoding", + [ + pytest.param({"LANG": "en_US.UTF-8"}, "None", "UTF-8", id="lang_encoding"), + pytest.param({}, "None", "None", id="no_lang"), + pytest.param({"TERM": "syncterm"}, "Syncterm", "None", id="term_syncterm"), + pytest.param({"TERM": "ansi-color"}, "Yes-ansi", "None", id="term_ansi"), + pytest.param({"TERM": "xterm"}, "Yes", "None", id="term_normal"), + ], +) +def test_create_protocol_fingerprint_term_encoding(extra, expected_term, expected_encoding): + writer = MockWriter(extra=extra) + writer._protocol = MockProtocol({}) + fp = fps._create_protocol_fingerprint(writer, {}) + assert fp["TERM"] == expected_term and fp["encoding"] == expected_encoding + + +def test_create_protocol_fingerprint_with_rejected_options(): + writer = MockWriter(extra={"TERM": "xterm"}) + writer._protocol = MockProtocol({}) + writer.rejected_will = {fps.BINARY} + writer.rejected_do = {fps.ECHO} + fp = fps._create_protocol_fingerprint(writer, {}) + assert "rejected-will" in fp and "rejected-do" in fp + + +def test_create_protocol_fingerprint_with_linemode_slc(): + + writer = MockWriter(extra={"TERM": "xterm"}) + writer._protocol = MockProtocol({}) + writer.remote_option[fps.LINEMODE] = True + tab = dict(slc.generate_slctab(slc.BSD_SLC_TAB)) + tab[slc.SLC_IP] = slc.SLC(mask=slc.SLC_DEFAULT, value=b"\x04") + writer.slctab = tab + probe = {"LINEMODE": {"status": "WILL", "opt": fps.LINEMODE}} + assert "slc" in fps._create_protocol_fingerprint(writer, probe) + + +def test_count_protocol_folder_files(tmp_path): + assert fps._count_protocol_folder_files(str(tmp_path / "nonexistent")) == 0 + (tmp_path / "a.json").write_text("{}") + (tmp_path / "b.json").write_text("{}") + (tmp_path / "c.txt").write_text("nope") + assert fps._count_protocol_folder_files(str(tmp_path)) == 2 + + +def test_count_fingerprint_folders(tmp_path): + assert fps._count_fingerprint_folders(data_dir=str(tmp_path)) == 0 + client_dir = tmp_path / "client" + client_dir.mkdir() + (client_dir / "hash1").mkdir() + (client_dir / "hash2").mkdir() + (client_dir / "not_a_dir.txt").write_text("") + assert fps._count_fingerprint_folders(data_dir=str(tmp_path)) == 2 + assert fps._count_fingerprint_folders(data_dir=None) == 0 + + +def test_create_session_fingerprint(): + writer = MockWriter( + extra={ + "peername": ("10.0.0.1", 5555), + "TERM": "xterm", + "USER": "alice", + "HOME": "/home/alice", + "LANG": "en_US.UTF-8", + "charset": "utf-8", + } + ) + fp = fps._create_session_fingerprint(writer) + assert fp["client-ip"] == "10.0.0.1" and fp["TERM"] == "xterm" + assert fp["USER"] == "alice" and fp["LANG"] == "en_US.UTF-8" + + writer2 = MockWriter(extra={"peername": None}) + writer2._extra = {"peername": None} + assert not fps._create_session_fingerprint(writer2) + + assert fps._create_session_fingerprint(MockWriter(extra={"term": "vt100"}))["TERM"] == "vt100" + + +def test_load_fingerprint_names(tmp_path): + names_file = tmp_path / "fingerprint_names.json" + names_file.write_text(json.dumps({"abc123": "Ghostty", "def456": "iTerm2"})) + assert fps._load_fingerprint_names(data_dir=str(tmp_path)) == { + "abc123": "Ghostty", + "def456": "iTerm2", + } + assert fps._load_fingerprint_names(data_dir=str(tmp_path / "nope")) == {} + assert fps._load_fingerprint_names(data_dir=None) == {} + + +def test_resolve_hash_name(): + names = {"abc": "Ghostty"} + assert fps._resolve_hash_name("abc", names) == "Ghostty" + assert fps._resolve_hash_name("unknown", names) == "unknown" + + +def test_save_fingerprint_data_makedirs(tmp_path, monkeypatch): + new_dir = str(tmp_path / "new_data") + monkeypatch.setattr(fps, "DATA_DIR", new_dir) + filepath = fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) + assert filepath is not None and os.path.exists(filepath) and os.path.isdir(new_dir) + + +def test_save_fingerprint_data_max_fingerprints(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + monkeypatch.setattr(fps, "FINGERPRINT_MAX_FINGERPRINTS", 0) + assert fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) is None + + +def test_save_fingerprint_data_max_files(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + assert fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) is not None + + monkeypatch.setattr(fps, "FINGERPRINT_MAX_FILES", 0) + assert ( + fps._save_fingerprint_data( + _probe_writer(peername=("10.0.0.2", 9999)), + _BINARY_PROBE, + 0.5, + ) + is None + ) + + +def test_save_fingerprint_data_corrupt_existing(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + fp1 = fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) + assert fp1 is not None + with open(fp1, "w", encoding="utf-8") as f: + f.write("not json {{{") + fp2 = fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) + assert fp2 is not None + with open(fp2, encoding="utf-8") as f: + assert len(json.load(f)["sessions"]) == 1 + + +def test_save_fingerprint_data_mkdir_oserror(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + original_makedirs = os.makedirs + + def failing_makedirs(path, **kwargs): + if "client" in path and fps._UNKNOWN_TERMINAL_HASH in path: + raise OSError("permission denied") + return original_makedirs(path, **kwargs) + + monkeypatch.setattr(os, "makedirs", failing_makedirs) + assert fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) is None + + +def test_save_fingerprint_data_write_oserror(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + monkeypatch.setattr( + fps, + "_atomic_json_write", + lambda fp, data: (_ for _ in ()).throw(OSError("disk full")), + ) + assert fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) is None + + +def test_save_fingerprint_data_update_oserror(tmp_path, monkeypatch): + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + assert fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) is not None + + monkeypatch.setattr( + fps, + "_atomic_json_write", + lambda fp, data: (_ for _ in ()).throw(OSError("disk full")), + ) + assert fps._save_fingerprint_data(_probe_writer(), _BINARY_PROBE, 0.5) is None + + +@pytest.mark.parametrize( + "extra,expected", + [ + pytest.param({"TERM": "mudlet"}, True, id="mud_term"), + pytest.param({"TERM": "xterm", "ttype1": "ZMUD"}, True, id="mud_ttype"), + pytest.param({"TERM": "xterm"}, False, id="not_mud"), + ], +) +def test_is_maybe_mud(extra, expected): + assert fps._is_maybe_mud(MockWriter(extra=extra)) is expected + + +def test_build_session_fingerprint_with_slc(): + + w = _probe_writer() + w.remote_option[fps.LINEMODE] = True + tab = dict(slc.generate_slctab(slc.BSD_SLC_TAB)) + tab[slc.SLC_IP] = slc.SLC(mask=slc.SLC_DEFAULT, value=b"\x04") + w.slctab = tab + probe = {"LINEMODE": {"status": "WILL", "opt": fps.LINEMODE}} + assert "slc_tab" in fps._build_session_fingerprint(w, probe, 0.1) + + +def test_build_session_fingerprint_with_rejected(): + w = _probe_writer() + w.rejected_will = {fps.BINARY} + probe = {"BINARY": {"status": "WONT", "opt": fps.BINARY}} + assert "rejected" in fps._build_session_fingerprint(w, probe, 0.1) + + +@requires_unix +@pytest.mark.asyncio +async def test_server_shell_syncterm(monkeypatch): + monkeypatch.setattr(fps.asyncio, "sleep", _noop) + monkeypatch.setattr(fps, "DATA_DIR", None) + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.05) + + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "TERM": "syncterm"}, + will_options=[fps.BINARY], + ) + await fps.fingerprinting_server_shell(MockReader([]), writer) + assert "\x1b[0;40 D" in "".join(writer.written) and writer._closing + + +@requires_unix +@pytest.mark.asyncio +async def test_server_shell_with_post_script(monkeypatch, tmp_path): + monkeypatch.setattr(fps.asyncio, "sleep", _noop) + monkeypatch.setattr(fps, "DATA_DIR", str(tmp_path)) + monkeypatch.setattr(fps, "_PROBE_TIMEOUT", 0.05) + + pty_called = [] + + async def fake_pty_shell(reader, writer, exe, args, raw_mode=False): + pty_called.append((exe, args, raw_mode)) + + monkeypatch.setattr(server_pty_shell, "pty_shell", fake_pty_shell) + + writer = MockWriter( + extra={"peername": ("127.0.0.1", 12345), "TERM": "xterm"}, + will_options=[fps.BINARY], + ) + writer._protocol = MockProtocol({"TERM": "xterm"}) + await fps.fingerprinting_server_shell(MockReader([]), writer) + assert len(pty_called) == 1 and pty_called[0][2] is True + + +@requires_unix +@pytest.mark.parametrize( + "input_fn,expected", + [ + pytest.param( + lambda prompt: (_ for _ in ()).throw(EOFError), + "", + id="eof", + ), + pytest.param(lambda prompt: "hello", "hello", id="normal"), + ], +) +def test_cooked_input(monkeypatch, input_fn, expected): + # std imports + import termios + + fake_attrs = [0, 0, 0, 0, 0, 0, [b"\x00"] * 32] + monkeypatch.setattr(termios, "tcgetattr", lambda fd: list(fake_attrs)) + monkeypatch.setattr(termios, "tcsetattr", lambda fd, when, attrs: None) + monkeypatch.setattr("builtins.input", input_fn) + assert fps._cooked_input("test> ") == expected + + +def test_fingerprinting_main(monkeypatch, tmp_path): + called = [] + monkeypatch.setattr(sys, "argv", ["fingerprinting", str(tmp_path / "test.json")]) + monkeypatch.setattr(fps, "fingerprinting_post_script", called.append) + fps.main() + assert called == [str(tmp_path / "test.json")] + + +def test_fingerprinting_main_usage(monkeypatch, capsys): + monkeypatch.setattr(sys, "argv", ["fingerprinting"]) + with pytest.raises(SystemExit, match="1"): + fps.main() + assert "Usage:" in capsys.readouterr().err + + +@requires_unix +def test_process_client_fingerprint_skips_ucs_detect_for_mud(monkeypatch, tmp_path, capsys): + ucs_called = [] + monkeypatch.setattr(fpd, "_run_ucs_detect", lambda: ucs_called.append(1) or None) + + data = {"telnet-probe": {"session_data": {"probe": {"WONT": {"SGA": 3}}}}} + filepath = str(tmp_path / "test.json") + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f) + + monkeypatch.setattr(fpd, "_setup_term_environ", lambda d: None) + monkeypatch.setattr(fpd, "_make_terminal", lambda: None) + try: + fpd._process_client_fingerprint(filepath, data) + except (ImportError, AttributeError, TypeError): + pass + assert not ucs_called diff --git a/telnetlib3/tests/test_guard_integration.py b/telnetlib3/tests/test_guard_integration.py index c222b7af..38aafbc1 100644 --- a/telnetlib3/tests/test_guard_integration.py +++ b/telnetlib3/tests/test_guard_integration.py @@ -1,11 +1,11 @@ # std imports import asyncio +# local +from telnetlib3.guard_shells import ConnectionCounter, busy_shell, robot_shell -async def test_connection_counter_integration(): - # local - from telnetlib3.guard_shells import ConnectionCounter +async def test_connection_counter_integration(): counter = ConnectionCounter(2) assert counter.try_acquire() @@ -25,9 +25,6 @@ async def test_connection_counter_integration(): async def test_counter_release_on_completion(): - # local - from telnetlib3.guard_shells import ConnectionCounter - counter = ConnectionCounter(1) async def shell_with_finally(): @@ -49,9 +46,6 @@ async def shell_with_finally(): async def test_counter_release_in_guarded_pattern(): - # local - from telnetlib3.guard_shells import ConnectionCounter - counter = ConnectionCounter(2) results = [] @@ -85,9 +79,6 @@ async def guarded_shell(name): async def test_guarded_shell_pattern_busy_shell(): - # local - from telnetlib3.guard_shells import ConnectionCounter, busy_shell - counter = ConnectionCounter(1) shell_calls = [] busy_shell_calls = [] @@ -161,9 +152,6 @@ async def guarded_shell(reader, writer): async def test_guarded_shell_pattern_robot_check(): # pylint: disable=too-complex - # local - from telnetlib3.guard_shells import ConnectionCounter - counter = ConnectionCounter(5) shell_calls = [] robot_shell_calls = [] @@ -243,9 +231,6 @@ async def guarded_shell(reader, writer): async def test_full_guarded_shell_flow(): # pylint: disable=too-complex - # local - from telnetlib3.guard_shells import ConnectionCounter, busy_shell, robot_shell - counter = ConnectionCounter(2) shell_calls = [] busy_calls = [] @@ -259,6 +244,9 @@ def __init__(self): def write(self, data): self.output.append(data) + def echo(self, data): + self.output.append(data) + async def drain(self): pass diff --git a/telnetlib3/tests/test_linemode.py b/telnetlib3/tests/test_linemode.py index 433e5d91..312503a9 100644 --- a/telnetlib3/tests/test_linemode.py +++ b/telnetlib3/tests/test_linemode.py @@ -4,23 +4,21 @@ import asyncio # local -# local imports import telnetlib3 import telnetlib3.stream_writer +from telnetlib3.slc import LMODE_MODE, LMODE_MODE_ACK, LMODE_MODE_LOCAL +from telnetlib3.telopt import DO, SB, SE, IAC, WILL, LINEMODE from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, + create_server, unused_tcp_port, + asyncio_connection, ) async def test_server_demands_remote_linemode_client_agrees( # pylint: disable=too-many-locals bind_host, unused_tcp_port ): - # local - from telnetlib3.slc import LMODE_MODE, LMODE_MODE_ACK - from telnetlib3.telopt import DO, SB, SE, IAC, WILL, LINEMODE - from telnetlib3.tests.accessories import create_server, asyncio_connection - class ServerTestLinemode(telnetlib3.BaseServer): def begin_negotiation(self): super().begin_negotiation() @@ -75,11 +73,6 @@ def begin_negotiation(self): async def test_server_demands_remote_linemode_client_demands_local( # pylint: disable=too-many-locals bind_host, unused_tcp_port ): - # local - from telnetlib3.slc import LMODE_MODE, LMODE_MODE_ACK, LMODE_MODE_LOCAL - from telnetlib3.telopt import DO, SB, SE, IAC, WILL, LINEMODE - from telnetlib3.tests.accessories import create_server, asyncio_connection - class ServerTestLinemode(telnetlib3.BaseServer): def begin_negotiation(self): super().begin_negotiation() diff --git a/telnetlib3/tests/test_naws.py b/telnetlib3/tests/test_naws.py index 8c0b1300..dc45e5d7 100644 --- a/telnetlib3/tests/test_naws.py +++ b/telnetlib3/tests/test_naws.py @@ -15,20 +15,19 @@ import pexpect # local -# local imports import telnetlib3 -from telnetlib3.tests.accessories import ( # pylint: disable=unused-import +from telnetlib3.telopt import SB, SE, IAC, NAWS, WILL +from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, + create_server, + open_connection, unused_tcp_port, + asyncio_connection, ) async def test_telnet_server_on_naws(bind_host, unused_tcp_port): """Test Server's Negotiate about window size (NAWS).""" - # local - from telnetlib3.telopt import SB, SE, IAC, NAWS, WILL - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_cols, given_rows = 40, 20 @@ -54,9 +53,6 @@ def on_naws(self, rows, cols): async def test_telnet_client_send_naws(bind_host, unused_tcp_port): """Test Client's NAWS of callback method send_naws().""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - _waiter = asyncio.Future() given_cols, given_rows = 40, 20 @@ -93,9 +89,6 @@ def on_naws(self, rows, cols): ) async def test_telnet_client_send_tty_naws(bind_host, unused_tcp_port): """Test Client's NAWS of callback method send_naws().""" - # local - from telnetlib3.tests.accessories import create_server - _waiter = asyncio.Future() given_cols, given_rows = 40, 20 prog, args = "telnetlib3-client", [ @@ -129,9 +122,6 @@ def on_naws(self, rows, cols): async def test_telnet_client_send_naws_65534(bind_host, unused_tcp_port): """Test Client's NAWS boundary values.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - _waiter = asyncio.Future() given_cols, given_rows = 9999999, -999999 expect_cols, expect_rows = 65535, 0 @@ -157,3 +147,27 @@ def on_naws(self, rows, cols): recv_cols, recv_rows = await asyncio.wait_for(_waiter, 0.5) assert recv_cols == expect_cols assert recv_rows == expect_rows + + +async def test_naws_without_will(bind_host, unused_tcp_port): + """NAWS subnegotiation received without prior WILL NAWS is tolerated.""" + _waiter = asyncio.Future() + given_cols, given_rows = 132, 43 + + class ServerTestNaws(telnetlib3.TelnetServer): + def on_naws(self, rows, cols): + super().on_naws(rows, cols) + _waiter.set_result(self) + + async with create_server( + protocol_factory=ServerTestNaws, + host=bind_host, + port=unused_tcp_port, + connect_maxwait=0.05, + ): + async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): + writer.write(IAC + SB + NAWS + struct.pack("!HH", given_cols, given_rows) + IAC + SE) + + srv_instance = await asyncio.wait_for(_waiter, 0.5) + assert srv_instance.get_extra_info("cols") == given_cols + assert srv_instance.get_extra_info("rows") == given_rows diff --git a/telnetlib3/tests/test_pty_shell.py b/telnetlib3/tests/test_pty_shell.py index 07be0bdb..a89bcf95 100644 --- a/telnetlib3/tests/test_pty_shell.py +++ b/telnetlib3/tests/test_pty_shell.py @@ -3,15 +3,29 @@ # std imports import os import sys +import time +import struct import asyncio +from unittest.mock import MagicMock, patch # 3rd party import pytest # local import telnetlib3 +from telnetlib3 import server_pty_shell as sps +from telnetlib3.server_pty_shell import ( + _BSU, + _ESU, + PTYSession, + PTYSpawnError, + _platform_check, + _wait_for_terminal_info, +) from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, + create_server, + open_connection, unused_tcp_port, make_preexec_coverage, ) @@ -42,11 +56,6 @@ def require_no_capture(request): @pytest.fixture def mock_session(): """Create a mock PTYSession for unit testing.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import PTYSession def _create(extra_info=None, capture_writes=False): reader = MagicMock() @@ -71,7 +80,6 @@ async def test_pty_shell_integration(bind_host, unused_tcp_port, require_no_capt """Test PTY shell with various helper modes: cat, env, stty_size.""" # local from telnetlib3 import make_pty_shell - from telnetlib3.tests.accessories import create_server, open_connection # Test 1: cat mode - echo input back _waiter = asyncio.Future() @@ -89,7 +97,7 @@ def begin_shell(self, result): shell=make_pty_shell( sys.executable, [PTY_HELPER, "cat"], preexec_fn=make_preexec_coverage() ), - connect_maxwait=0.5, + connect_maxwait=0.15, ): async with open_connection( host=bind_host, @@ -99,7 +107,7 @@ def begin_shell(self, result): connect_minwait=0.05, ) as (reader, writer): await asyncio.wait_for(_waiter, 2.0) - await asyncio.sleep(0.3) + await asyncio.sleep(0.1) writer.write("hello world\n") await writer.drain() @@ -113,7 +121,7 @@ def begin_shell(self, result): async def client_shell(reader, writer): await _waiter - await asyncio.sleep(0.5) + await asyncio.sleep(0.15) output = await asyncio.wait_for(reader.read(100), 2.0) _output.set_result(output) @@ -124,7 +132,7 @@ async def client_shell(reader, writer): shell=make_pty_shell( sys.executable, [PTY_HELPER, "env", "TERM"], preexec_fn=make_preexec_coverage() ), - connect_maxwait=0.5, + connect_maxwait=0.15, ): async with open_connection( host=bind_host, @@ -148,7 +156,7 @@ async def client_shell(reader, writer): shell=make_pty_shell( sys.executable, [PTY_HELPER, "stty_size"], preexec_fn=make_preexec_coverage() ), - connect_maxwait=0.5, + connect_maxwait=0.15, ): async with open_connection( host=bind_host, @@ -158,7 +166,7 @@ async def client_shell(reader, writer): connect_minwait=0.05, ) as (reader, writer): await asyncio.wait_for(_waiter, 2.0) - await asyncio.sleep(0.3) + await asyncio.sleep(0.1) output = await asyncio.wait_for(reader.read(50), 2.0) assert "25 80" in output @@ -169,7 +177,6 @@ async def test_pty_shell_lifecycle(bind_host, unused_tcp_port, require_no_captur """Test PTY shell lifecycle: child exit and client disconnect.""" # local from telnetlib3 import make_pty_shell - from telnetlib3.tests.accessories import create_server, open_connection # Test 1: child exit closes connection gracefully _waiter = asyncio.Future() @@ -187,7 +194,7 @@ def begin_shell(self, result): shell=make_pty_shell( sys.executable, [PTY_HELPER, "exit_code", "0"], preexec_fn=make_preexec_coverage() ), - connect_maxwait=0.5, + connect_maxwait=0.15, ): async with open_connection( host=bind_host, @@ -197,7 +204,7 @@ def begin_shell(self, result): connect_minwait=0.05, ) as (reader, writer): await asyncio.wait_for(_waiter, 2.0) - await asyncio.sleep(0.3) + await asyncio.sleep(0.1) result = await asyncio.wait_for(reader.read(100), 3.0) assert "done" in result @@ -227,7 +234,7 @@ def connection_lost(self, exc): shell=make_pty_shell( sys.executable, [PTY_HELPER, "cat"], preexec_fn=make_preexec_coverage() ), - connect_maxwait=0.5, + connect_maxwait=0.15, ): async with open_connection( host=bind_host, @@ -237,16 +244,13 @@ def connection_lost(self, exc): connect_minwait=0.05, ) as (reader, writer): await asyncio.wait_for(_waiter, 2.0) - await asyncio.sleep(0.3) + await asyncio.sleep(0.1) await asyncio.wait_for(_closed, 3.0) def test_platform_check_not_windows(): """Test that platform check raises on Windows.""" - # local - from telnetlib3.server_pty_shell import _platform_check - original_platform = sys.platform try: sys.platform = "win32" @@ -302,11 +306,9 @@ async def test_pty_session_build_environment(mock_session): assert env["LANG"] == "en_US.ISO-8859-1" -async def test_pty_session_naws_behavior(mock_session): +async def test_pty_session_naws_behavior(mock_session, monkeypatch): """Test NAWS debouncing, latest value usage, and cleanup cancellation.""" - # std imports - import struct - from unittest.mock import MagicMock, patch + monkeypatch.setattr(sps, "_NAWS_DEBOUNCE", 0.05) session, _ = mock_session() session.master_fd = 1 @@ -331,7 +333,7 @@ def mock_ioctl(fd, cmd, data): session._on_naws(50, 150) assert len(signal_calls) == 0 - await asyncio.sleep(0.25) + await asyncio.sleep(0.1) assert len(signal_calls) == 1 assert len(ioctl_calls) == 1 @@ -357,18 +359,17 @@ def mock_killpg_winch(pgid, sig): "os.killpg", side_effect=mock_killpg_winch ), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch("os.close"), patch( "fcntl.ioctl" + ), patch( + "time.sleep" ): session._on_naws(25, 80) session.cleanup() - await asyncio.sleep(0.25) + await asyncio.sleep(0.1) assert len(winch_calls) == 0 async def test_pty_session_write_to_telnet_buffering(mock_session): """Test _write_to_telnet line buffering, BSU/ESU handling, and overflow protection.""" - # local - from telnetlib3.server_pty_shell import _BSU, _ESU - # Line buffering: buffers until newline session, written = mock_session({"charset": "utf-8"}, capture_writes=True) session._write_to_telnet(b"hello") @@ -407,12 +408,6 @@ async def test_pty_session_write_to_telnet_buffering(mock_session): async def test_pty_session_flush_output_behavior(mock_session): """Test flush_output charset handling and incomplete UTF-8 buffering.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import PTYSession - # Charset change recreates decoder reader = MagicMock() writer = MagicMock() @@ -443,9 +438,6 @@ async def test_pty_session_flush_output_behavior(mock_session): async def test_pty_session_write_to_pty_behavior(mock_session): """Test _write_to_pty encoding, error handling, and None fd guard.""" - # std imports - from unittest.mock import patch - # String encoding session, _ = mock_session({"charset": "utf-8"}) session.master_fd = 99 @@ -478,12 +470,6 @@ def mock_write(fd, data): async def test_pty_session_cleanup_flushes_remaining_buffer(): """Test that cleanup flushes remaining buffer with final=True.""" - # std imports - from unittest.mock import MagicMock, patch - - # local - from telnetlib3.server_pty_shell import PTYSession - reader = MagicMock() writer = MagicMock() written = [] @@ -495,7 +481,9 @@ async def test_pty_session_cleanup_flushes_remaining_buffer(): session.master_fd = 99 session.child_pid = 12345 - with patch("os.close"), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)): + with patch("os.close"), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch( + "time.sleep" + ): session.cleanup() assert len(written) == 1 @@ -503,15 +491,8 @@ async def test_pty_session_cleanup_flushes_remaining_buffer(): assert session._output_buffer == b"" -async def test_wait_for_terminal_info_behavior(): +async def test_wait_for_terminal_info_behavior(monkeypatch): """Test _wait_for_terminal_info early return, timeout, and polling behavior.""" - # std imports - import time - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import _wait_for_terminal_info - # Returns early when TERM and rows available writer = MagicMock() writer.get_extra_info = MagicMock(side_effect={"TERM": "xterm", "rows": 25}.get) @@ -521,8 +502,8 @@ async def test_wait_for_terminal_info_behavior(): writer = MagicMock() writer.get_extra_info = MagicMock(return_value=None) start = time.time() - await _wait_for_terminal_info(writer, timeout=0.3) - assert time.time() - start >= 0.25 + await _wait_for_terminal_info(writer, timeout=0.05) + assert time.time() - start >= 0.04 # Polls until rows become available call_count = [0] @@ -545,9 +526,6 @@ def get_info(key): async def test_pty_session_set_window_size_behavior(mock_session): """Test _set_window_size guards and error handling.""" - # std imports - from unittest.mock import patch - # No fd does nothing session, _ = mock_session() session.master_fd = None @@ -581,12 +559,6 @@ async def test_pty_session_cleanup_error_recovery( close_effect, kill_effect, waitpid_effect, check_attr ): """Test cleanup handles various error conditions gracefully.""" - # std imports - from unittest.mock import MagicMock, patch - - # local - from telnetlib3.server_pty_shell import PTYSession - reader = MagicMock() writer = MagicMock() writer.get_extra_info = MagicMock(return_value="utf-8") @@ -601,7 +573,7 @@ async def test_pty_session_cleanup_error_recovery( waitpid_return = None if isinstance(waitpid_effect, Exception) else waitpid_effect waitpid_patch = patch("os.waitpid", side_effect=waitpid_side, return_value=waitpid_return) - with close_patch, kill_patch, waitpid_patch: + with close_patch, kill_patch, waitpid_patch, patch("time.sleep"): session.cleanup() assert getattr(session, check_attr) is None @@ -618,12 +590,6 @@ async def test_pty_session_flush_remaining_scenarios( in_sync_update, expected_writes, expected_buffer ): """Test _flush_remaining behavior based on sync update state.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import PTYSession - reader = MagicMock() writer = MagicMock() written = [] @@ -644,12 +610,6 @@ async def test_pty_session_flush_remaining_scenarios( async def test_pty_session_flush_output_empty_data(): """Test _flush_output does nothing with empty data.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import PTYSession - reader = MagicMock() writer = MagicMock() written = [] @@ -666,12 +626,6 @@ async def test_pty_session_flush_output_empty_data(): async def test_pty_session_write_to_telnet_pre_bsu_content(): """Test content before BSU is flushed.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import _BSU, _ESU, PTYSession - reader = MagicMock() writer = MagicMock() written = [] @@ -688,9 +642,6 @@ async def test_pty_session_write_to_telnet_pre_bsu_content(): async def test_pty_spawn_error(): """Test PTYSpawnError exception class.""" - # local - from telnetlib3.server_pty_shell import PTYSpawnError - err = PTYSpawnError("test error") assert str(err) == "test error" assert isinstance(err, Exception) @@ -705,12 +656,6 @@ async def test_pty_spawn_error(): ) async def test_pty_session_exec_error_parsing(error_data, expected_substrings): """Test _handle_exec_error parses various error formats.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import PTYSession, PTYSpawnError - reader = MagicMock() writer = MagicMock() writer.get_extra_info = MagicMock(return_value=None) @@ -734,12 +679,6 @@ async def test_pty_session_exec_error_parsing(error_data, expected_substrings): ) async def test_pty_session_isalive_scenarios(child_pid, waitpid_behavior, expected): """Test _isalive returns correct values for various child states.""" - # std imports - from unittest.mock import MagicMock, patch - - # local - from telnetlib3.server_pty_shell import PTYSession - reader = MagicMock() writer = MagicMock() writer.get_extra_info = MagicMock(return_value=None) @@ -761,10 +700,6 @@ async def test_pty_session_terminate_scenarios(): """Test _terminate handles various termination scenarios.""" # std imports import signal - from unittest.mock import MagicMock, patch - - # local - from telnetlib3.server_pty_shell import PTYSession reader = MagicMock() writer = MagicMock() @@ -810,3 +745,99 @@ def mock_isalive_2(): result = session._terminate() assert result is True + + +async def test_pty_session_ga_timer_fires_after_idle(mock_session, monkeypatch): + """GA is sent after _flush_remaining when SGA not negotiated.""" + monkeypatch.setattr(sps, "_GA_IDLE", 0.05) + + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + protocol = MagicMock() + protocol.never_send_ga = False + session.writer.protocol = protocol + session.writer.is_closing = MagicMock(return_value=False) + ga_calls = [] + session.writer.send_ga = lambda: ga_calls.append(True) + + session._output_buffer = b"prompt> " + session._flush_remaining() + assert session._ga_timer is not None + assert len(ga_calls) == 0 + + await asyncio.sleep(0.1) + assert len(ga_calls) == 1 + assert session._ga_timer is None + + +async def test_pty_session_ga_timer_cancelled_by_new_output(mock_session, monkeypatch): + """GA timer is cancelled when new PTY output arrives.""" + monkeypatch.setattr(sps, "_GA_IDLE", 0.05) + + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + protocol = MagicMock() + protocol.never_send_ga = False + session.writer.protocol = protocol + session.writer.is_closing = MagicMock(return_value=False) + ga_calls = [] + session.writer.send_ga = lambda: ga_calls.append(True) + + session._output_buffer = b"prompt> " + session._flush_remaining() + assert session._ga_timer is not None + + session._write_to_telnet(b"more output\n") + assert session._ga_timer is None + + await asyncio.sleep(0.1) + assert len(ga_calls) == 0 + + +async def test_pty_session_ga_timer_suppressed_by_never_send_ga(mock_session): + """GA timer is not scheduled when never_send_ga is set.""" + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + protocol = MagicMock() + protocol.never_send_ga = True + session.writer.protocol = protocol + + session._output_buffer = b"prompt> " + session._flush_remaining() + assert session._ga_timer is None + + +async def test_pty_session_ga_timer_suppressed_in_raw_mode(mock_session): + """GA timer is not scheduled in raw_mode (e.g. fingerprinting display).""" + session, _ = mock_session({"charset": "utf-8"}, capture_writes=True) + protocol = MagicMock() + protocol.never_send_ga = False + session.writer.protocol = protocol + session.raw_mode = True + + session._output_buffer = b"prompt> " + session._flush_remaining() + assert session._ga_timer is None + + +async def test_pty_session_ga_timer_cancelled_on_cleanup(mock_session, monkeypatch): + """GA timer is cancelled during cleanup.""" + monkeypatch.setattr(sps, "_GA_IDLE", 0.05) + + session, _ = mock_session({"charset": "utf-8"}) + protocol = MagicMock() + protocol.never_send_ga = False + session.writer.protocol = protocol + session.writer.is_closing = MagicMock(return_value=False) + session.writer.send_ga = MagicMock() + session.master_fd = 99 + session.child_pid = 12345 + + session._schedule_ga() + assert session._ga_timer is not None + + with patch("os.close"), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch( + "time.sleep" + ): + session.cleanup() + + assert session._ga_timer is None + await asyncio.sleep(0.1) + session.writer.send_ga.assert_not_called() diff --git a/telnetlib3/tests/test_reader.py b/telnetlib3/tests/test_reader.py index 9d59cc37..9c5f4347 100644 --- a/telnetlib3/tests/test_reader.py +++ b/telnetlib3/tests/test_reader.py @@ -8,8 +8,10 @@ # local import telnetlib3 -from telnetlib3.tests.accessories import ( # pylint: disable=unused-import +from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, + create_server, + open_connection, unused_tcp_port, ) @@ -78,9 +80,6 @@ def fn_encoding(incoming): async def test_telnet_reader_using_readline_unicode(bind_host, unused_tcp_port): """Ensure strict RFC interpretation of newlines in readline method.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given_expected = { "alpha\r\x00": "alpha\r", "bravo\r\n": "bravo\r\n", @@ -93,9 +92,10 @@ async def test_telnet_reader_using_readline_unicode(bind_host, unused_tcp_port): "xxxxxxxxxxx": "xxxxxxxxxxx", } - def shell(reader, writer): + async def shell(reader, writer): for item in sorted(given_expected): writer.write(item) + await writer.drain() writer.close() async with create_server( @@ -115,9 +115,6 @@ def shell(reader, writer): async def test_telnet_reader_using_readline_bytes(bind_host, unused_tcp_port): """Ensure strict RFC interpretation of newlines in readline method.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given_expected = { b"alpha\r\x00": b"alpha\r", b"bravo\r\n": b"bravo\r\n", @@ -155,9 +152,6 @@ def shell(reader, writer): async def test_telnet_reader_read_exactly_unicode(bind_host, unused_tcp_port): """Ensure TelnetReader.readexactly, especially IncompleteReadError.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given = "☭---------" given_partial = "💉-" @@ -186,9 +180,6 @@ def shell(reader, writer): async def test_telnet_reader_read_exactly_bytes(bind_host, unused_tcp_port): """Ensure TelnetReader.readexactly, especially IncompleteReadError.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given = string.ascii_letters.encode("ascii") given_partial = b"zzz" @@ -235,9 +226,6 @@ def fn_encoding(incoming): async def test_telnet_reader_read_beyond_limit_unicode(bind_host, unused_tcp_port): """Ensure ability to read(-1) beyond segment sizes of reader._limit.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - limit = 10 def shell(reader, writer): @@ -263,9 +251,6 @@ def shell(reader, writer): async def test_telnet_reader_read_beyond_limit_bytes(bind_host, unused_tcp_port): """Ensure ability to read(-1) beyond segment sizes of reader._limit.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - limit = 10 def shell(reader, writer): @@ -296,9 +281,6 @@ def shell(reader, writer): async def test_telnet_reader_readuntil_pattern_success(bind_host, unused_tcp_port): """Test successful pattern matching with readuntil_pattern.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given_shell_banner = b""" Router> enable Router# configure terminal @@ -349,9 +331,6 @@ async def test_telnet_reader_readuntil_pattern_limit_overrun_chunk_too_large( bind_host, unused_tcp_port ): """Test LimitOverrunError when pattern is found but chunk exceeds limit.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given_shell_banner = b""" Router> enable Router# configure terminal which is a very long command line that exceeds our limit @@ -408,9 +387,6 @@ async def test_telnet_reader_readuntil_pattern_limit_overrun_buffer_full( bind_host, unused_tcp_port ): """Test LimitOverrunError when buffer exceeds limit and pattern not found.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - # Create data that will exceed the limit when searching for non-existent pattern long_data = b"x" * 50 # exceeds limit of 30 given_shell_banner = b"Router> " + long_data @@ -453,9 +429,6 @@ async def shell(_, writer): async def test_telnet_reader_readuntil_pattern_incomplete_read_eof(bind_host, unused_tcp_port): """Test IncompleteReadError when EOF occurs before pattern is found.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given_shell_banner = b"Router> some incomplete data\n" pattern = re.compile(rb"\S+[>#]") @@ -514,9 +487,6 @@ async def test_telnet_reader_readuntil_pattern_invalid_arguments(): async def test_telnet_reader_readuntil_pattern_cancelled_error(bind_host, unused_tcp_port): """Test CancelledError handling in readuntil_pattern.""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - given_shell_banner = b"Router> " pattern = re.compile(rb"\S+[>#]") diff --git a/telnetlib3/tests/test_server_api.py b/telnetlib3/tests/test_server_api.py index d819b7b2..b00e61ef 100644 --- a/telnetlib3/tests/test_server_api.py +++ b/telnetlib3/tests/test_server_api.py @@ -3,7 +3,7 @@ import asyncio # local -from telnetlib3.telopt import IAC, WONT, TTYPE +from telnetlib3.telopt import IAC, WILL, WONT, TTYPE, BINARY from telnetlib3.tests.accessories import bind_host # pytest fixture from telnetlib3.tests.accessories import unused_tcp_port # pytest fixture from telnetlib3.tests.accessories import ( @@ -85,9 +85,6 @@ async def test_server_sockets(bind_host, unused_tcp_port): async def test_server_with_wait_for(bind_host, unused_tcp_port): """Test integration of Server.wait_for_client() with writer.wait_for().""" - # local - from telnetlib3.telopt import WILL, BINARY - async with create_server( host=bind_host, port=unused_tcp_port, diff --git a/telnetlib3/tests/test_server_cli.py b/telnetlib3/tests/test_server_cli.py index 1dd72305..9a827f20 100644 --- a/telnetlib3/tests/test_server_cli.py +++ b/telnetlib3/tests/test_server_cli.py @@ -2,17 +2,19 @@ # std imports import sys +import asyncio from unittest import mock # 3rd party import pytest +# local +import telnetlib3 +from telnetlib3 import server + def test_pty_support_detection_with_modules(): """PTY_SUPPORT is True when all required modules are available.""" - # local - from telnetlib3 import server - if sys.platform == "win32": assert server.PTY_SUPPORT is False else: @@ -21,9 +23,6 @@ def test_pty_support_detection_with_modules(): def test_parse_server_args_includes_pty_options_when_supported(): """CLI parser includes --pty-exec when PTY is supported.""" - # local - from telnetlib3 import server - if not server.PTY_SUPPORT: pytest.skip("PTY not supported on this platform") @@ -35,9 +34,6 @@ def test_parse_server_args_includes_pty_options_when_supported(): def test_parse_server_args_excludes_pty_options_when_not_supported(): """CLI parser sets PTY options to defaults when PTY is not supported.""" - # local - from telnetlib3 import server - original_support = server.PTY_SUPPORT try: server.PTY_SUPPORT = False @@ -52,16 +48,10 @@ def test_parse_server_args_excludes_pty_options_when_not_supported(): def test_run_server_raises_on_pty_exec_without_support(): """run_server raises NotImplementedError when pty_exec is used without PTY support.""" - # local - from telnetlib3 import server - original_support = server.PTY_SUPPORT try: server.PTY_SUPPORT = False with pytest.raises(NotImplementedError, match="PTY support is not available"): - # std imports - import asyncio - asyncio.run(server.run_server(pty_exec="/bin/bash")) finally: server.PTY_SUPPORT = original_support @@ -69,21 +59,26 @@ def test_run_server_raises_on_pty_exec_without_support(): def test_telnetlib3_import_exposes_pty_support(): """Telnetlib3 package exposes PTY_SUPPORT flag.""" - # local - import telnetlib3 - assert hasattr(telnetlib3, "PTY_SUPPORT") assert isinstance(telnetlib3.PTY_SUPPORT, bool) def test_telnetlib3_pty_shell_exports_conditional(): """pty_shell exports are only in __all__ when PTY is supported.""" - # local - import telnetlib3 - if telnetlib3.PTY_SUPPORT: assert "make_pty_shell" in telnetlib3.__all__ assert "pty_shell" in telnetlib3.__all__ else: assert "make_pty_shell" not in telnetlib3.__all__ assert "pty_shell" not in telnetlib3.__all__ + + +def test_parse_server_args_never_send_ga(): + """--never-send-ga flag is parsed correctly.""" + with mock.patch.object(sys, "argv", ["server"]): + result = server.parse_server_args() + assert result["never_send_ga"] is False + + with mock.patch.object(sys, "argv", ["server", "--never-send-ga"]): + result = server.parse_server_args() + assert result["never_send_ga"] is True diff --git a/telnetlib3/tests/test_server_shell_unit.py b/telnetlib3/tests/test_server_shell_unit.py index b9d631ae..dc716d82 100644 --- a/telnetlib3/tests/test_server_shell_unit.py +++ b/telnetlib3/tests/test_server_shell_unit.py @@ -17,7 +17,6 @@ class DummyWriter: def __init__(self, slctab=None): self.echos = [] self.slctab = slctab or slc_mod.generate_slctab() - # minimal attributes for do_toggle (unused here) self.local_option = types.SimpleNamespace(enabled=lambda opt: False) self.outbinary = False self.inbinary = False @@ -29,10 +28,8 @@ def echo(self, data): def _run_readline(sequence): - """Drive ss.readline coroutine with given sequence and return list of commands produced.""" w = DummyWriter() gen = ss.readline(None, w) - # prime the coroutine gen.send(None) cmds = [] for ch in sequence: @@ -42,158 +39,6 @@ def _run_readline(sequence): return cmds, w.echos -def test_readline_basic_and_crlf_and_backspace(): - # simple command, CR terminator - cmds, echos = _run_readline("foo\r") - assert cmds == ["foo"] - assert "".join(echos).endswith("foo") # echoed chars - - # CRLF pair: the LF after CR should be consumed and not yield an empty command - cmds, echos = _run_readline("bar\r\n") - assert cmds == ["bar"] - - # LF as terminator alone - cmds, _ = _run_readline("baz\n") - assert cmds == ["baz"] - - # CR NUL should be treated like CRLF (LF/NUL consumed after CR) - cmds, _ = _run_readline("zip\r\x00zap\r\n") - assert cmds == ["zip", "zap"] - - # backspace handling (^H and DEL): 'help' after correction - cmds, echos = _run_readline("\bhel\blp\r") - assert cmds == ["help"] - # ensure backspace echoing placed sequence '\b \b' - assert "\b \b" in "".join(echos) - - -def test_character_dump_yields_patterns_and_summary(): - it = ss.character_dump(1) # enter loop - first = next(it) - second = next(it) - assert first.startswith("/" * 80) - assert second.startswith("\\" * 80) - - # when kb_limit is 0, no loop, only the summary line is yielded - summary = list(ss.character_dump(0))[-1] - assert summary.endswith("wrote 0 bytes") - - -def test_get_slcdata_contains_expected_sections(): - writer = DummyWriter(slctab=slc_mod.generate_slctab()) - out = ss.get_slcdata(writer) - assert "Special Line Characters:" in out - # a known supported mapping should appear (like SLC_EC) - assert "SLC_EC" in out - # and known unset entries should be listed - assert "Unset by client:" in out and "SLC_BRK" in out - # and some not-supported entries section is present - assert "Not supported by server:" in out - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.platform == "win32", reason="Terminal class not available on Windows") -async def test_terminal_determine_mode_no_echo_returns_same(monkeypatch): - # Build a dummy telnet_writer with will_echo False - class TW: - will_echo = False - log = types.SimpleNamespace(debug=lambda *a, **k: None) - - # pytest captures stdin; provide a fake with fileno() for Terminal.__init__ - monkeypatch.setattr(sys, "stdin", types.SimpleNamespace(fileno=lambda: 0)) - - term = cs.Terminal(TW()) - ModeDef = cs.Terminal.ModeDef - - # construct a plausible mode tuple (values aren't important here) - base_mode = ModeDef( - iflag=0xFFFF, - oflag=0xFFFF, - cflag=0xFFFF, - lflag=0xFFFF, - ispeed=38400, - ospeed=38400, - cc=[0] * 32, - ) - - result = term.determine_mode(base_mode) - # must be the exact same object when will_echo is False - assert result is base_mode - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.platform == "win32", reason="Terminal class not available on Windows") -async def test_terminal_determine_mode_will_echo_adjusts_flags(monkeypatch): - # Build a dummy telnet_writer with will_echo True - class TW: - will_echo = True - log = types.SimpleNamespace(debug=lambda *a, **k: None) - - # pytest captures stdin; provide a fake with fileno() for Terminal.__init__ - monkeypatch.setattr(sys, "stdin", types.SimpleNamespace(fileno=lambda: 0)) - - term = cs.Terminal(TW()) - ModeDef = cs.Terminal.ModeDef - t = cs.termios - - # Start with flags that should be cleared by determine_mode - iflag = 0 - for flag in (t.BRKINT, t.ICRNL, t.INPCK, t.ISTRIP, t.IXON): - iflag |= flag - - # oflag clears OPOST and ONLCR - oflag = t.OPOST | getattr(t, "ONLCR", 0) - - # cflag: set PARENB and a size other than CS8 to ensure it flips - cflag = t.PARENB | getattr(t, "CS7", 0) | getattr(t, "CREAD", 0) - - # lflag: will clear ICANON | IEXTEN | ISIG | ECHO - lflag = t.ICANON | t.IEXTEN | t.ISIG | t.ECHO - - # cc array with different VMIN/VTIME values that should be overridden - cc = [0] * 32 - cc[t.VMIN] = 0 - cc[t.VTIME] = 1 - - base_mode = ModeDef( - iflag=iflag, - oflag=oflag, - cflag=cflag, - lflag=lflag, - ispeed=38400, - ospeed=38400, - cc=list(cc), - ) - - new_mode = term.determine_mode(base_mode) - - # verify input flags cleared - for flag in (t.BRKINT, t.ICRNL, t.INPCK, t.ISTRIP, t.IXON): - assert not new_mode.iflag & flag - - # verify output flags cleared - assert not new_mode.oflag & t.OPOST - if hasattr(t, "ONLCR"): - assert not new_mode.oflag & t.ONLCR - - # verify cflag: PARENB cleared, CS8 set, CSIZE cleared except CS8 - assert not new_mode.cflag & t.PARENB - assert new_mode.cflag & t.CS8 - # CSIZE mask bits should be exactly CS8 now - assert (new_mode.cflag & t.CSIZE) == t.CS8 - # CREAD (if present) should remain unchanged - if hasattr(t, "CREAD") and (cflag & t.CREAD): - assert new_mode.cflag & t.CREAD - - # verify lflag cleared for ICANON, IEXTEN, ISIG, ECHO - for flag in (t.ICANON, t.IEXTEN, t.ISIG, t.ECHO): - assert not new_mode.lflag & flag - - # cc changes - assert new_mode.cc[t.VMIN] == 1 - assert new_mode.cc[t.VTIME] == 0 - - class MockReader: def __init__(self, data): self._data = list(data) @@ -213,11 +58,18 @@ async def read(self, n): return "" +class _MockProtocol: + def __init__(self, never_send_ga=False): + self.never_send_ga = never_send_ga + + class MockWriter: - def __init__(self): + def __init__(self, protocol=None): self.written = [] self._closing = False self._extra = {"peername": ("127.0.0.1", 12345)} + self.protocol = protocol or _MockProtocol() + self.ga_calls = [] def write(self, data): self.written.append(data) @@ -234,6 +86,85 @@ def is_closing(self): def echo(self, data): self.written.append(data) + def close(self): + self._closing = True + + def send_ga(self): + self.ga_calls.append(True) + return True + + +def test_readline_basic_and_crlf_and_backspace(): + cmds, echos = _run_readline("foo\r") + assert cmds == ["foo"] + assert "".join(echos).endswith("foo") + + cmds, _ = _run_readline("bar\r\n") + assert cmds == ["bar"] + + cmds, _ = _run_readline("baz\n") + assert cmds == ["baz"] + + cmds, _ = _run_readline("zip\r\x00zap\r\n") + assert cmds == ["zip", "zap"] + + cmds, echos = _run_readline("\bhel\blp\r") + assert cmds == ["help"] + assert "\b \b" in "".join(echos) + + +def test_character_dump_yields_patterns_and_summary(): + it = ss.character_dump(1) + assert next(it).startswith("/" * 80) + assert next(it).startswith("\\" * 80) + assert list(ss.character_dump(0))[-1].endswith("wrote 0 bytes") + + +def test_get_slcdata_contains_expected_sections(): + out = ss.get_slcdata(DummyWriter(slctab=slc_mod.generate_slctab())) + assert "Special Line Characters:" in out + assert "SLC_EC" in out + assert "Unset by client:" in out and "SLC_BRK" in out + assert "Not supported by server:" in out + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.platform == "win32", reason="requires termios") +async def test_terminal_determine_mode(monkeypatch): + monkeypatch.setattr(sys, "stdin", types.SimpleNamespace(fileno=lambda: 0)) + tw = types.SimpleNamespace( + will_echo=False, + log=types.SimpleNamespace(debug=lambda *a, **k: None), + ) + term = cs.Terminal(tw) + mode = cs.Terminal.ModeDef(0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 38400, 38400, [0] * 32) + assert term.determine_mode(mode) is mode + + tw.will_echo = True + t = cs.termios + cc = [0] * 32 + cc[t.VMIN] = 0 + cc[t.VTIME] = 1 + mode = cs.Terminal.ModeDef( + t.BRKINT | t.ICRNL | t.INPCK | t.ISTRIP | t.IXON, + t.OPOST | getattr(t, "ONLCR", 0), + t.PARENB | getattr(t, "CS7", 0), + t.ICANON | t.IEXTEN | t.ISIG | t.ECHO, + 38400, + 38400, + list(cc), + ) + new = term.determine_mode(mode) + for flag in (t.BRKINT, t.ICRNL, t.INPCK, t.ISTRIP, t.IXON): + assert not new.iflag & flag + assert not new.oflag & t.OPOST + assert not new.cflag & t.PARENB + assert new.cflag & t.CS8 + for flag in (t.ICANON, t.IEXTEN, t.ISIG, t.ECHO): + assert not new.lflag & flag + assert new.cc[t.VMIN] == 1 + assert new.cc[t.VTIME] == 0 + @pytest.mark.parametrize( "limit,acquires,expected_count,expected_results", @@ -275,29 +206,19 @@ def test_connection_counter_release(): ) @pytest.mark.asyncio async def test_read_line_inner(input_data, max_len, expected): - reader = MockReader(list(input_data)) - result = await gs._read_line_inner(reader, max_len) - assert result == expected + assert await gs._read_line_inner(MockReader(list(input_data)), max_len) == expected @pytest.mark.asyncio -async def test_read_line_with_timeout_success(): - reader = MockReader(list("hello\r")) - result = await gs._read_line(reader, timeout=5.0) - assert result == "hello" - - -@pytest.mark.asyncio -async def test_read_line_with_timeout_expires(): - result = await gs._read_line(SlowReader(), timeout=0.01) - assert result is None +async def test_read_line_with_timeout(): + assert await gs._read_line(MockReader(list("hello\r")), timeout=5.0) == "hello" + assert await gs._read_line(SlowReader(), timeout=0.01) is None @pytest.mark.asyncio async def test_robot_shell_full_conversation(): - reader = MockReader(["y", "\r", "n", "o", "\r"]) writer = MockWriter() - await gs.robot_shell(reader, writer) + await gs.robot_shell(MockReader(["y", "\r", "n", "o", "\r"]), writer) written = "".join(writer.written) assert "Do robots dream of electric sheep?" in written assert "windowmakers" in written @@ -305,9 +226,8 @@ async def test_robot_shell_full_conversation(): @pytest.mark.asyncio async def test_busy_shell_full_conversation(): - reader = MockReader(["h", "i", "\r", "x", "\r"]) writer = MockWriter() - await gs.busy_shell(reader, writer) + await gs.busy_shell(MockReader(["h", "i", "\r", "x", "\r"]), writer) written = "".join(writer.written) assert "Machine is busy" in written assert "distant explosion" in written @@ -326,10 +246,7 @@ async def test_busy_shell_full_conversation(): ) @pytest.mark.asyncio async def test_filter_ansi(input_chars, expected): - reader = MockReader(input_chars) - writer = MockWriter() - result = await ss.filter_ansi(reader, writer) - assert result == expected + assert await ss.filter_ansi(MockReader(input_chars), MockWriter()) == expected @pytest.mark.parametrize( @@ -339,14 +256,13 @@ async def test_filter_ansi(input_chars, expected): pytest.param(["h", "x", "\x7f", "i", "\r"], "hi", id="with_backspace"), pytest.param(["\n", "\x00", "a", "\r"], "a", id="ignores_initial_lf_nul"), pytest.param(["a", ""], None, id="returns_none_on_eof"), + pytest.param(["\x7f", "a", "\r"], "a", id="backspace_on_empty"), + pytest.param(["\b", "\b", "x", "\r"], "x", id="multiple_backspace_on_empty"), ], ) @pytest.mark.asyncio async def test_readline2(input_chars, expected): - reader = MockReader(input_chars) - writer = MockWriter() - result = await ss.readline2(reader, writer) - assert result == expected + assert await ss.readline2(MockReader(input_chars), MockWriter()) == expected @pytest.mark.parametrize( @@ -355,252 +271,202 @@ async def test_readline2(input_chars, expected): pytest.param(["a"], False, "a", id="normal"), pytest.param(["\x1b", "A", "x"], False, "x", id="skips_escape"), pytest.param([], True, None, id="returns_none_when_closing"), + pytest.param(["\x1b", "1", "A", "x"], False, "x", id="escape_non_letter"), ], ) @pytest.mark.asyncio async def test_get_next_ascii(input_chars, closing, expected): - reader = MockReader(input_chars) writer = MockWriter() writer._closing = closing - result = await ss.get_next_ascii(reader, writer) - assert result == expected - - -class CPRReader: - def __init__(self, data): - self._data = list(data) - self._idx = 0 - - async def read(self, n): - if self._idx >= len(self._data): - return b"" - result = self._data[self._idx] - self._idx += 1 - return result + assert await ss.get_next_ascii(MockReader(input_chars), writer) == expected @pytest.mark.parametrize( "input_data,expected", [ - pytest.param([b"\x1b", b"[", b"1", b"0", b";", b"2", b"0", b"R"], (10, 20), id="valid_cpr"), + pytest.param([b"\x1b", b"[", b"1", b"0", b";", b"2", b"0", b"R"], (10, 20), id="valid"), pytest.param([b"\x1b", b"[", b"1", b";", b"1", b"R"], (1, 1), id="single_digit"), - pytest.param( - [b"\x1b", b"[", b"2", b"5", b";", b"8", b"0", b"R"], (25, 80), id="typical_size" - ), + pytest.param([b"\x1b", b"[", b"2", b"5", b";", b"8", b"0", b"R"], (25, 80), id="typical"), pytest.param([b""], None, id="eof"), pytest.param( [b"g", b"a", b"r", b"\x1b", b"[", b"5", b";", b"3", b"R"], (5, 3), id="garbage_prefix" ), + pytest.param([b"R", b"\x1b", b"[", b"3", b";", b"7", b"R"], (3, 7), id="R_without_match"), + pytest.param(list("\x1b[5;10R"), (5, 10), id="string_input"), ], ) @pytest.mark.asyncio async def test_read_cpr_response(input_data, expected): - reader = CPRReader(input_data) - result = await gs._read_cpr_response(reader) - assert result == expected + assert await gs._read_cpr_response(MockReader(input_data)) == expected @pytest.mark.asyncio -async def test_read_cpr_response_string_input(): - class StringReader: +async def test_read_cpr_response_unicode_decode_error(): + class BadReader: def __init__(self): - self._data = list("\x1b[5;10R") - self._idx = 0 + self._call = 0 async def read(self, n): - if self._idx >= len(self._data): - return "" - result = self._data[self._idx] - self._idx += 1 - return result - - reader = StringReader() - result = await gs._read_cpr_response(reader) - assert result == (5, 10) - - -class CPRMockWriter: - def __init__(self): - self.written = [] + self._call += 1 + if self._call == 1: + raise UnicodeDecodeError("utf-8", b"\xff", 0, 1, "invalid") + return b"" - def write(self, data): - self.written.append(data) - - async def drain(self): - pass + assert await gs._read_cpr_response(BadReader()) is None @pytest.mark.asyncio async def test_get_cursor_position_success(): - reader = CPRReader([b"\x1b", b"[", b"1", b"0", b";", b"2", b"0", b"R"]) - writer = CPRMockWriter() - result = await gs._get_cursor_position(reader, writer, timeout=1.0) - assert result == (10, 20) + reader = MockReader([b"\x1b", b"[", b"1", b"0", b";", b"2", b"0", b"R"]) + writer = MockWriter() + assert await gs._get_cursor_position(reader, writer, timeout=1.0) == (10, 20) assert "\x1b[6n" in writer.written @pytest.mark.asyncio -async def test_get_cursor_position_timeout(): - result = await gs._get_cursor_position(SlowReader(), CPRMockWriter(), timeout=0.01) - assert result == (None, None) - - -@pytest.mark.asyncio -async def test_get_cursor_position_eof(): - reader = CPRReader([b""]) - writer = CPRMockWriter() - result = await gs._get_cursor_position(reader, writer, timeout=1.0) - assert result == (None, None) +async def test_get_cursor_position_failure(): + assert await gs._get_cursor_position(SlowReader(), MockWriter(), timeout=0.01) == (None, None) + assert await gs._get_cursor_position(MockReader([b""]), MockWriter(), timeout=1.0) == ( + None, + None, + ) +@pytest.mark.parametrize( + "responses,expected", + [ + pytest.param([(1, 5), (1, 7)], 2, id="success"), + pytest.param([(None, None)], None, id="first_cpr_fails"), + pytest.param([(1, 5), (None, None)], None, id="second_cpr_fails"), + ], +) @pytest.mark.asyncio -async def test_measure_width_success(monkeypatch): - positions = iter([(1, 5), (1, 7)]) +async def test_measure_width(monkeypatch, responses, expected): + responses_iter = iter(responses) - async def mock_get_cursor_position(reader, writer, timeout): - return next(positions) + async def mock_gcp(reader, writer, timeout): + return next(responses_iter) - monkeypatch.setattr(gs, "_get_cursor_position", mock_get_cursor_position) - writer = CPRMockWriter() - result = await gs._measure_width(None, writer, "ab", timeout=1.0) - assert result == 2 - assert any("\x1b[5G" in w for w in writer.written) + monkeypatch.setattr(gs, "_get_cursor_position", mock_gcp) + assert await gs._measure_width(None, MockWriter(), "ab", timeout=1.0) == expected +@pytest.mark.parametrize( + "width,expected", + [ + pytest.param(2, True, id="width_2"), + pytest.param(1, False, id="width_1"), + pytest.param(None, False, id="width_none"), + ], +) @pytest.mark.asyncio -async def test_measure_width_first_cpr_fails(monkeypatch): - async def mock_get_cursor_position(reader, writer, timeout): - return (None, None) +async def test_robot_check(monkeypatch, width, expected): + async def mock_measure(r, w, text, timeout): + return width - monkeypatch.setattr(gs, "_get_cursor_position", mock_get_cursor_position) - result = await gs._measure_width(None, CPRMockWriter(), "x", timeout=1.0) - assert result is None + monkeypatch.setattr(gs, "_measure_width", mock_measure) + assert await gs.robot_check(None, None, timeout=1.0) is expected @pytest.mark.asyncio -async def test_measure_width_second_cpr_fails(monkeypatch): - call_count = [0] - - async def mock_get_cursor_position(reader, writer, timeout): - call_count[0] += 1 - if call_count[0] == 1: - return (1, 5) - return (None, None) - - monkeypatch.setattr(gs, "_get_cursor_position", mock_get_cursor_position) - result = await gs._measure_width(None, CPRMockWriter(), "x", timeout=1.0) - assert result is None - - -@pytest.mark.asyncio -async def test_robot_check_returns_true_when_width_is_2(monkeypatch): - async def mock_measure_width(reader, writer, text, timeout): - return 2 - - monkeypatch.setattr(gs, "_measure_width", mock_measure_width) - result = await gs.robot_check(None, None, timeout=1.0) - assert result is True +async def test_readline_with_echo_timeout(): + assert await gs._readline_with_echo(SlowReader(), MockWriter(), timeout=0.01) is None +@pytest.mark.parametrize( + "timeout_at", + [pytest.param(1, id="first_question"), pytest.param(2, id="second_question")], +) @pytest.mark.asyncio -async def test_robot_check_returns_false_when_width_is_not_2(monkeypatch): - async def mock_measure_width(reader, writer, text, timeout): - return 1 - - monkeypatch.setattr(gs, "_measure_width", mock_measure_width) - result = await gs.robot_check(None, None, timeout=1.0) - assert result is False - +async def test_robot_shell_timeout(monkeypatch, timeout_at): + call_count = [0] -@pytest.mark.asyncio -async def test_robot_check_returns_false_when_width_is_none(monkeypatch): - async def mock_measure_width(reader, writer, text, timeout): - return None + async def mock_readline_with_echo(reader, writer, timeout): + call_count[0] += 1 + if call_count[0] == timeout_at: + return None + return "y" - monkeypatch.setattr(gs, "_measure_width", mock_measure_width) - result = await gs.robot_check(None, None, timeout=1.0) - assert result is False + monkeypatch.setattr(gs, "_readline_with_echo", mock_readline_with_echo) + await gs.robot_shell(MockReader([]), MockWriter()) + assert call_count[0] == timeout_at @pytest.mark.asyncio -async def test_robot_shell_timeout_on_first_question(monkeypatch): +async def test_busy_shell_timeout(monkeypatch): call_count = [0] - original_read_line = gs._read_line async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): call_count[0] += 1 - if call_count[0] == 1: - return None - return await original_read_line(reader, timeout, max_len) + return None if call_count[0] == 1 else "hi" monkeypatch.setattr(gs, "_read_line", mock_read_line) - - writer = MockWriter() - await gs.robot_shell(MockReader([]), writer) - written = "".join(writer.written) - assert "Do robots dream of electric sheep?" in written - assert call_count[0] == 1 + await gs.busy_shell(MockReader([]), MockWriter()) + assert call_count[0] == 2 @pytest.mark.asyncio -async def test_robot_shell_timeout_on_second_question(monkeypatch): +async def test_ask_question_blank_then_answer(monkeypatch): call_count = [0] - original_read_line = gs._read_line - async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): + async def mock_readline_with_echo(reader, writer, timeout): call_count[0] += 1 - if call_count[0] == 1: - return "y" - if call_count[0] == 2: - return None - return await original_read_line(reader, timeout, max_len) - - monkeypatch.setattr(gs, "_read_line", mock_read_line) + return " " if call_count[0] == 1 else "answer" - writer = MockWriter() - await gs.robot_shell(MockReader([]), writer) - written = "".join(writer.written) - assert "Do robots dream of electric sheep?" in written - assert "windowmakers" in written + monkeypatch.setattr(gs, "_readline_with_echo", mock_readline_with_echo) + assert await gs._ask_question(None, MockWriter(), "q? ", timeout=5.0) == "answer" assert call_count[0] == 2 +@pytest.mark.parametrize( + "never_send_ga,expect_ga", + [ + pytest.param(False, True, id="ga_sent"), + pytest.param(True, False, id="ga_suppressed"), + ], +) @pytest.mark.asyncio -async def test_busy_shell_timeout_on_first_input(monkeypatch): - call_count = [0] - original_read_line = gs._read_line - - async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): - call_count[0] += 1 - if call_count[0] == 1: - return None - return await original_read_line(reader, timeout, max_len) +async def test_telnet_server_shell_ga(never_send_ga, expect_ga): + reader = MockReader(list("quit\r")) + writer = MockWriter(protocol=_MockProtocol(never_send_ga=never_send_ga)) + await ss.telnet_server_shell(reader, writer) + assert (len(writer.ga_calls) >= 1) == expect_ga - monkeypatch.setattr(gs, "_read_line", mock_read_line) +@pytest.mark.asyncio +async def test_telnet_server_shell_dump_with_delay(monkeypatch): + slept = [] + _real_sleep = asyncio.sleep + monkeypatch.setattr(asyncio, "sleep", lambda d: slept.append(d) or _real_sleep(0)) + reader = MockReader(list("dump 0 1000\r") + list("quit\r")) writer = MockWriter() - await gs.busy_shell(MockReader([]), writer) + await ss.telnet_server_shell(reader, writer) written = "".join(writer.written) - assert "Machine is busy" in written + assert "kb_limit=0" in written + assert "delay=1" in written @pytest.mark.asyncio -async def test_busy_shell_timeout_on_second_input(monkeypatch): - call_count = [0] - original_read_line = gs._read_line - - async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): - call_count[0] += 1 - if call_count[0] == 1: - return "hi" - if call_count[0] == 2: - return None - return await original_read_line(reader, timeout, max_len) - - monkeypatch.setattr(gs, "_read_line", mock_read_line) - +async def test_telnet_server_shell_dump_with_explicit_kb(): writer = MockWriter() - await gs.busy_shell(MockReader([]), writer) + await ss.telnet_server_shell(MockReader(list("dump 0\r") + list("quit\r")), writer) written = "".join(writer.written) - assert "Machine is busy" in written - assert "distant explosion" in written + assert "kb_limit=0" in written + assert "wrote 0 bytes" in written + + +@pytest.mark.asyncio +async def test_telnet_server_shell_dump_closing(): + class _ClosingWriter(MockWriter): + def write(self, data): + super().write(data) + if "kb_limit=" in data: + self._closing = True + + w1 = _ClosingWriter() + await ss.telnet_server_shell(MockReader(list("dump\r") + list("quit\r")), w1) + assert "kb_limit=1000" in "".join(w1.written) + + w2 = _ClosingWriter() + await ss.telnet_server_shell(MockReader(list("dump 1\r") + list("quit\r")), w2) + assert "1 OK" not in "".join(w2.written) diff --git a/telnetlib3/tests/test_shell.py b/telnetlib3/tests/test_shell.py index a248727e..e1417527 100644 --- a/telnetlib3/tests/test_shell.py +++ b/telnetlib3/tests/test_shell.py @@ -5,19 +5,20 @@ import logging # local -# local imports +from telnetlib3 import accessories, telnet_server_shell +from telnetlib3.telopt import DO, IAC, SGA, ECHO, WILL, WONT, TTYPE, BINARY from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, + create_server, + asyncio_server, + open_connection, unused_tcp_port, + asyncio_connection, ) async def test_telnet_server_shell_as_coroutine(bind_host, unused_tcp_port): """Test callback shell(reader, writer) as coroutine of create_server().""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() send_input = "Alpha" expect_output = "Beta" @@ -58,9 +59,6 @@ async def shell(reader, writer): async def test_telnet_client_shell_as_coroutine(bind_host, unused_tcp_port): """Test callback shell(reader, writer) as coroutine of create_server().""" - # local - from telnetlib3.tests.accessories import asyncio_server, open_connection - _waiter = asyncio.Future() async def shell(reader, writer): @@ -80,10 +78,6 @@ async def shell(reader, writer): async def test_telnet_server_shell_make_coro_by_function(bind_host, unused_tcp_port): """Test callback shell(reader, writer) as function, for create_server().""" - # local - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() def shell(reader, writer): @@ -99,10 +93,6 @@ def shell(reader, writer): async def test_telnet_server_no_shell(bind_host, unused_tcp_port): """Test telnetlib3.TelnetServer() instantiation and connection_made().""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - client_expected = IAC + DO + TTYPE + b"beta" async with create_server(host=bind_host, port=unused_tcp_port) as server: @@ -124,11 +114,6 @@ async def test_telnet_server_given_shell( bind_host, unused_tcp_port ): # pylint: disable=too-many-locals """Iterate all state-reading commands of default telnet_server_shell.""" - # local - from telnetlib3 import telnet_server_shell - from telnetlib3.telopt import DO, IAC, SGA, ECHO, WILL, WONT, TTYPE, BINARY - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -136,6 +121,7 @@ async def test_telnet_server_given_shell( connect_maxwait=0.05, timeout=1.25, limit=13377, + never_send_ga=True, ) as server: async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): expected = IAC + DO + TTYPE @@ -297,11 +283,6 @@ async def test_telnet_server_given_shell( async def test_telnet_server_shell_eof(bind_host, unused_tcp_port): """Test EOF in telnet_server_shell().""" - # local - from telnetlib3 import telnet_server_shell - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -318,11 +299,6 @@ async def test_telnet_server_shell_eof(bind_host, unused_tcp_port): async def test_telnet_server_shell_version_command(bind_host, unused_tcp_port): """Test version command in telnet_server_shell.""" - # local - from telnetlib3 import accessories, telnet_server_shell - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -361,11 +337,6 @@ async def test_telnet_server_shell_version_command(bind_host, unused_tcp_port): async def test_telnet_server_shell_dump_with_kb_limit(bind_host, unused_tcp_port): """Test dump command with explicit kb_limit.""" - # local - from telnetlib3 import telnet_server_shell - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -401,11 +372,6 @@ async def test_telnet_server_shell_dump_with_kb_limit(bind_host, unused_tcp_port async def test_telnet_server_shell_dump_with_all_options(bind_host, unused_tcp_port): """Test dump command with all options including close.""" - # local - from telnetlib3 import telnet_server_shell - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -440,11 +406,6 @@ async def test_telnet_server_shell_dump_with_all_options(bind_host, unused_tcp_p async def test_telnet_server_shell_dump_nodrain(bind_host, unused_tcp_port): """Test dump command with nodrain option.""" - # local - from telnetlib3 import telnet_server_shell - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, @@ -480,11 +441,6 @@ async def test_telnet_server_shell_dump_nodrain(bind_host, unused_tcp_port): async def test_telnet_server_shell_dump_large_output(bind_host, unused_tcp_port): """Test dump command with larger output.""" - # local - from telnetlib3 import telnet_server_shell - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server( host=bind_host, port=unused_tcp_port, diff --git a/telnetlib3/tests/test_status_logger.py b/telnetlib3/tests/test_status_logger.py index ab3eeb7f..fd5f73ba 100644 --- a/telnetlib3/tests/test_status_logger.py +++ b/telnetlib3/tests/test_status_logger.py @@ -116,13 +116,14 @@ def clients(self): status_one = { "count": 1, - "clients": [{"ip": "127.0.0.1", "port": 12345, "rx": 100, "tx": 200}], + "clients": [{"ip": "127.0.0.1", "port": 12345, "rx": 100, "tx": 200, "idle": 5}], } formatted = status_logger._format_status(status_one) assert "1 client(s)" in formatted assert "127.0.0.1:12345" in formatted assert "rx=100" in formatted assert "tx=200" in formatted + assert "idle=5" in formatted async def test_status_logger_start_stop(bind_host, unused_tcp_port): diff --git a/telnetlib3/tests/test_stream_reader_extra.py b/telnetlib3/tests/test_stream_reader_extra.py index 2ebdcd9c..5db88818 100644 --- a/telnetlib3/tests/test_stream_reader_extra.py +++ b/telnetlib3/tests/test_stream_reader_extra.py @@ -113,8 +113,7 @@ async def test_pause_and_resume_transport_based_on_buffer_limit(): async def test_anext_iterates_lines_and_stops_on_eof(): r = TelnetReader() r.feed_data(b"Line1\nLine2\n") - # first line (using __anext__ for Python 3.8/3.9 compat; anext() is 3.10+) - one = await r.__anext__() + one = await r.__anext__() # anext() is 3.10+ assert one == b"Line1\n" # second line two = await r.__anext__() diff --git a/telnetlib3/tests/test_stream_writer_extra.py b/telnetlib3/tests/test_stream_writer_extra.py index b0ccf742..3ef3e01f 100644 --- a/telnetlib3/tests/test_stream_writer_extra.py +++ b/telnetlib3/tests/test_stream_writer_extra.py @@ -6,7 +6,7 @@ import pytest # local -from telnetlib3 import slc +from telnetlib3 import slc, client_base from telnetlib3.telopt import ( DO, IS, @@ -35,7 +35,9 @@ LFLOW_RESTART_ANY, LFLOW_RESTART_XON, ) -from telnetlib3.stream_writer import TelnetWriter +from telnetlib3.client_base import BaseClient +from telnetlib3.server_base import BaseServer +from telnetlib3.stream_writer import TelnetWriter, _encode_env_buf class MockTransport: @@ -53,6 +55,12 @@ def is_closing(self): def get_extra_info(self, name, default=None): return self.extra.get(name, default) + def pause_reading(self): + pass + + def resume_reading(self): + pass + def close(self): self._closing = True @@ -284,9 +292,6 @@ def test_handle_sb_ttype_is_and_send(): def _encode_env(env): """Helper to encode env dict like _encode_env_buf would, for tests.""" - # local - from telnetlib3.stream_writer import _encode_env_buf - return _encode_env_buf(env) @@ -461,15 +466,21 @@ def test_handle_sb_status_send_and_is(): ws2._handle_sb_status(buf2) -def test_handle_sb_forwardmask_assertions_and_do_raises_notimplemented(): - # client end receiving DO FORWARDMASK must have WILL LINEMODE True +def test_handle_sb_forwardmask_do_raises_notimplemented(): wc, _, _ = new_writer(server=False, client=True) wc.local_option[LINEMODE] = True - # DO with some bytes must call _handle_do_forwardmask -> NotImplementedError - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): wc._handle_sb_forwardmask(DO, collections.deque([b"x", b"y"])) +def test_handle_sb_linemode_mode_empty_buffer(): + ws, _, _ = new_writer(server=True) + ws.local_option[LINEMODE] = True + ws.remote_option[LINEMODE] = True + with pytest.raises(ValueError, match="missing mode byte"): + ws._handle_sb_linemode_mode(collections.deque()) + + def test_handle_sb_linemode_switches(): ws, ts, _ = new_writer(server=True) @@ -508,3 +519,56 @@ def test_handle_subnegotiation_dispatch_and_unhandled(): ws._handle_sb_naws(buf) # unhandled command + with pytest.raises(ValueError, match="SB unhandled"): + ws.handle_subnegotiation(collections.deque([b"\x99", b"\x00"])) + + +async def test_server_data_received_split_sb_linemode(): + class NoNegServer(BaseServer): + def begin_negotiation(self): + pass + + def _check_negotiation_timer(self): + pass + + transport = MockTransport() + server = NoNegServer(encoding=False) + server.connection_made(transport) + + server.writer.remote_option[LINEMODE] = True + server.writer.local_option[LINEMODE] = True + + transport.writes.clear() + + chunk1 = IAC + SB + LINEMODE + slc.LMODE_MODE + server.data_received(chunk1) + assert server.writer.is_oob + + mask_byte = b"\x10" + chunk2 = mask_byte + IAC + SE + server.data_received(chunk2) + + response = b"".join(transport.writes) + assert IAC + SB + LINEMODE + slc.LMODE_MODE in response + + +async def test_client_process_chunk_split_sb_linemode(): + transport = MockTransport() + client = BaseClient(encoding=False) + client.connection_made(transport) + + client.writer.remote_option[LINEMODE] = True + client.writer.local_option[LINEMODE] = True + + transport.writes.clear() + + chunk1 = IAC + SB + LINEMODE + slc.LMODE_MODE + client._process_chunk(chunk1) + assert client.writer.is_oob + + mask_byte = b"\x10" + chunk2 = mask_byte + IAC + SE + client._process_chunk(chunk2) + + response = b"".join(transport.writes) + assert IAC + SB + LINEMODE + slc.LMODE_MODE in response diff --git a/telnetlib3/tests/test_stream_writer_full.py b/telnetlib3/tests/test_stream_writer_full.py index a3cef99c..40d5282e 100644 --- a/telnetlib3/tests/test_stream_writer_full.py +++ b/telnetlib3/tests/test_stream_writer_full.py @@ -14,9 +14,11 @@ SB, SE, TM, + ESC, IAC, NOP, SGA, + VAR, DONT, ECHO, GMCP, @@ -34,6 +36,7 @@ TSPEED, CHARSET, REQUEST, + USERVAR, ACCEPTED, LINEMODE, REJECTED, @@ -267,14 +270,24 @@ def test_handle_will_invalid_cases_and_else_unhandled(): w3.set_ext_callback(LOGOUT, lambda cmd: seen.setdefault("v", cmd)) w3.handle_will(LOGOUT) assert seen["v"] == WILL - # ELSE branch (unhandled) -> DONT sent, options set -1, pending cleared + # ELSE branch (unhandled) -> DONT sent, pending cleared, rejected tracked w4, t4, _ = new_writer(server=True) w4.pending_option[DO + GMCP] = True w4.handle_will(GMCP) assert t4.writes[-1] == IAC + DONT + GMCP - assert w4.remote_option[GMCP] == -1 - assert w4.local_option[GMCP] == -1 assert not w4.pending_option.get(DO + GMCP, False) + assert GMCP in w4.rejected_will + + +def test_handle_will_then_do_unsupported_sends_both_dont_and_wont(): + """WILL then DO for unsupported option must send DONT and WONT.""" + w, t, _ = new_writer(server=True) + w.handle_will(COM_PORT_OPTION) + assert t.writes[-1] == IAC + DONT + COM_PORT_OPTION + assert COM_PORT_OPTION in w.rejected_will + w.handle_do(COM_PORT_OPTION) + assert t.writes[-1] == IAC + WONT + COM_PORT_OPTION + assert COM_PORT_OPTION in w.rejected_do def test_handle_wont_tm_and_logout_paths(): @@ -497,9 +510,6 @@ def test_option_enabled_and_setitem_debug_path(): def test_escape_unescape_and_env_encode_decode_roundtrip(): # escaping VAR/USERVAR - # local - from telnetlib3.telopt import ESC, VAR, USERVAR - buf = b"A" + VAR + b"B" + USERVAR + b"C" esc = _escape_environ(buf) assert VAR in esc and USERVAR in esc and esc.count(ESC) == 2 @@ -752,13 +762,44 @@ def test_handle_sb_forwardmask_server_will_and_client_do(): opt = SB + LINEMODE + slc.LMODE_FORWARDMASK assert ws.remote_option[opt] is True - # client DO path currently asserts that bytes must follow DO (pre-check) + # client DO path -> _handle_do_forwardmask -> NotImplementedError wc, tc, pc = new_writer(server=False, client=True) wc.local_option[LINEMODE] = True - with pytest.raises(AssertionError): + with pytest.raises(NotImplementedError): wc._handle_sb_forwardmask(DO, collections.deque([b"x"])) +def test_handle_sb_forwardmask_server_without_linemode(): + ws, ts, ps = new_writer(server=True) + ws._handle_sb_forwardmask(WILL, collections.deque()) + opt = SB + LINEMODE + slc.LMODE_FORWARDMASK + assert ws.remote_option[opt] is True + + +def test_handle_sb_forwardmask_server_rejects_do_dont(): + ws, ts, ps = new_writer(server=True) + ws.remote_option[LINEMODE] = True + ws._handle_sb_forwardmask(DO, collections.deque()) + opt = SB + LINEMODE + slc.LMODE_FORWARDMASK + assert opt not in ws.remote_option + + +def test_handle_sb_forwardmask_client_without_linemode(): + wc, tc, pc = new_writer(server=False, client=True) + wc._handle_sb_forwardmask(DONT, collections.deque()) + opt = SB + LINEMODE + slc.LMODE_FORWARDMASK + assert wc.local_option[opt] is False + + +def test_handle_sb_linemode_passes_opt_to_forwardmask(): + ws, ts, ps = new_writer(server=True) + ws.remote_option[LINEMODE] = True + buf = collections.deque([LINEMODE, WONT, slc.LMODE_FORWARDMASK]) + ws._handle_sb_linemode(buf) + opt = SB + LINEMODE + slc.LMODE_FORWARDMASK + assert ws.remote_option[opt] is False + + def test_slc_add_buffer_full_raises(): w, t, p = new_writer(server=True) # fill buffer to maximum diff --git a/telnetlib3/tests/test_sync.py b/telnetlib3/tests/test_sync.py index 5c90e704..c12799b2 100644 --- a/telnetlib3/tests/test_sync.py +++ b/telnetlib3/tests/test_sync.py @@ -49,7 +49,7 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: conn.write("hello") @@ -69,7 +69,7 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: line = conn.readline(timeout=5) @@ -88,7 +88,7 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: data = conn.read_until(">>> ", timeout=5) @@ -107,7 +107,7 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: assert "test" in conn.read_some(timeout=5) @@ -150,7 +150,7 @@ def test_server_accept(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5): time.sleep(0.5) @@ -175,7 +175,7 @@ def handler(conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.2) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: conn.write("test") @@ -215,7 +215,7 @@ def test_server_connection_read_write(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: conn.write("hello") conn.flush() @@ -264,7 +264,7 @@ def test_server_connection_miniboa_properties(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5): time.sleep(0.5) @@ -293,7 +293,7 @@ def test_server_connection_miniboa_methods(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: conn.write("test\r\n") conn.flush() @@ -327,7 +327,7 @@ def test_server_connection_send_converts_newlines(bind_host, unused_tcp_port): received = [] def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: received.append(conn.read(20, timeout=5)) @@ -367,7 +367,7 @@ def test_server_connection_writer_property(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5): time.sleep(0.5) @@ -437,7 +437,7 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: with pytest.raises(TimeoutError, match="Read timed out"): @@ -457,11 +457,11 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: with pytest.raises(TimeoutError, match="Readline timed out"): - conn.readline(timeout=0.2) + conn.readline(timeout=0.1) server.shutdown() @@ -479,7 +479,7 @@ def test_server_connection_timeout(bind_host, unused_tcp_port, method, args, err server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5): time.sleep(2) @@ -499,7 +499,7 @@ def test_server_connection_read_until_timeout(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: conn.write("no match here") conn.flush() @@ -510,7 +510,7 @@ def client_thread(): conn = server.accept(timeout=5) with pytest.raises(TimeoutError, match="Read until timed out"): - conn.read_until(">>> ", timeout=0.2) + conn.read_until(">>> ", timeout=0.1) conn.close() server.shutdown() @@ -521,7 +521,7 @@ def test_server_connection_wait_for_timeout(bind_host, unused_tcp_port): server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5): time.sleep(1.0) @@ -550,7 +550,7 @@ def test_server_connection_methods_closed_error(bind_host, unused_tcp_port, meth server.start() def client_thread(): - time.sleep(0.1) + time.sleep(0.05) with TelnetConnection(bind_host, unused_tcp_port, timeout=5): time.sleep(0.2) @@ -584,7 +584,7 @@ def handler(server_conn): server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() - time.sleep(0.1) + server._started.wait(timeout=5) with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: with pytest.raises(EOFError, match="Connection closed before match found"): diff --git a/telnetlib3/tests/test_telnetlib.py b/telnetlib3/tests/test_telnetlib.py index ba439d63..922ba23b 100644 --- a/telnetlib3/tests/test_telnetlib.py +++ b/telnetlib3/tests/test_telnetlib.py @@ -1,6 +1,5 @@ # jdq(2025): This file was modified from cpython 3.12 test_telnetlib.py, to make it compatible # with more versions of python, and, to use pytest instead of unittest. -# pylint: skip-file # std imports import io import re @@ -19,6 +18,7 @@ except OSError: pytest.skip("Working socket required", allow_module_level=True) +# pylint:disable=consider-using-from-import # local import telnetlib3.telnetlib as telnetlib # noqa: E402 @@ -67,7 +67,7 @@ def captured_stdout(): yield buf -class SocketStub(object): +class SocketStub: """A socket proxy that re-defines sendall()""" def __init__(self, reads=()): @@ -93,6 +93,7 @@ def fileno(self): """Provide a real OS-level file descriptor so selectors and any code that calls fileno() can work, even though the network I/O is mocked.""" s = getattr(self, "_fileno_sock", None) + # pylint: disable=attribute-defined-outside-init if s is None: try: s1, s2 = socket.socketpair() @@ -108,6 +109,7 @@ def fileno(self): def close(self): # Close the internal fileno() provider sockets, but leave the mocked self.sock alone + # pylint: disable=attribute-defined-outside-init try: if getattr(self, "_fileno_sock", None) is not None: try: @@ -130,7 +132,6 @@ def msg(self, msg, *args): with captured_stdout() as out: telnetlib.Telnet.msg(self, msg, *args) self._messages += out.getvalue() - return class MockSelector(selectors.BaseSelector): @@ -157,8 +158,7 @@ def select(self, timeout=None): break if block: return [] - else: - return [(key, key.events) for key in self.keys.values()] + return [(key, key.events) for key in self.keys.values()] def get_map(self): return self.keys @@ -169,21 +169,21 @@ def mocktest_socket(reads): def new_conn(*ignored): return SocketStub(reads) + old_conn = socket.create_connection try: - old_conn = socket.create_connection socket.create_connection = new_conn yield None finally: socket.create_connection = old_conn - return def make_telnet(reads=(), cls=TelnetAlike): """Return a telnetlib.Telnet object that uses a SocketStub with reads queued up to be read.""" for x in reads: - assert type(x) is bytes, x + assert isinstance(x, bytes) with mocktest_socket(reads): telnet = cls("dummy", 0) + # pylint: disable=attribute-defined-outside-init telnet._messages = "" # debuglevel output return telnet @@ -332,7 +332,7 @@ def test_read_lazy(self): assert data == want -class nego_collector(object): +class nego_collector: def __init__(self, sb_getter=None): self.seen = b"" self.sb_getter = sb_getter @@ -390,7 +390,7 @@ def test_IAC_commands(self): self._test_command([b"x" * 100, tl.IAC, cmd, b"y" * 100]) self._test_command([b"x" * 10, tl.IAC, cmd, b"y" * 10]) # all at once - self._test_command([tl.IAC + cmd for (cmd) in self.cmds]) + self._test_command([tl.IAC + cmd for cmd in self.cmds]) def test_SB_commands(self): # RFC 855, subnegotiations portion @@ -440,6 +440,7 @@ def test_debug_accepts_str_port(self): # Issue 10695 with mocktest_socket([]): telnet = TelnetAlike("dummy", "0") + # pylint: disable=attribute-defined-outside-init telnet._messages = "" telnet.set_debuglevel(1) telnet.msg("test") diff --git a/telnetlib3/tests/test_timeout.py b/telnetlib3/tests/test_timeout.py index 27246145..4beea2c4 100644 --- a/telnetlib3/tests/test_timeout.py +++ b/telnetlib3/tests/test_timeout.py @@ -5,19 +5,17 @@ import asyncio # local -# local imports -from telnetlib3.tests.accessories import ( # pylint: disable=unused-import +from telnetlib3.telopt import DO, IAC, WONT, TTYPE +from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, + create_server, unused_tcp_port, + asyncio_connection, ) async def test_telnet_server_default_timeout(bind_host, unused_tcp_port): """Test callback on_timeout() as coroutine of create_server().""" - # local - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - given_timeout = 19.29 async with create_server( @@ -37,10 +35,6 @@ async def test_telnet_server_default_timeout(bind_host, unused_tcp_port): async def test_telnet_server_set_timeout(bind_host, unused_tcp_port): """Test callback on_timeout() as coroutine of create_server().""" - # local - from telnetlib3.telopt import IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - async with create_server(host=bind_host, port=unused_tcp_port) as server: async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): writer.write(IAC + WONT + TTYPE) @@ -56,10 +50,6 @@ async def test_telnet_server_set_timeout(bind_host, unused_tcp_port): async def test_telnet_server_waitfor_timeout(bind_host, unused_tcp_port): """Test callback on_timeout() as coroutine of create_server().""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - expected_output = IAC + DO + TTYPE + b"\r\nTimeout.\r\n" async with create_server(host=bind_host, port=unused_tcp_port, timeout=0.050): @@ -75,10 +65,6 @@ async def test_telnet_server_waitfor_timeout(bind_host, unused_tcp_port): async def test_telnet_server_binary_mode(bind_host, unused_tcp_port): """Test callback on_timeout() in BINARY mode when encoding=False is used.""" - # local - from telnetlib3.telopt import DO, IAC, WONT, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - expected_output = IAC + DO + TTYPE + b"\r\nTimeout.\r\n" async with create_server(host=bind_host, port=unused_tcp_port, timeout=0.150, encoding=False): diff --git a/telnetlib3/tests/test_tspeed.py b/telnetlib3/tests/test_tspeed.py index 85d992d9..ad9bb96a 100644 --- a/telnetlib3/tests/test_tspeed.py +++ b/telnetlib3/tests/test_tspeed.py @@ -4,21 +4,20 @@ import asyncio # local -# local imports import telnetlib3 import telnetlib3.stream_writer -from telnetlib3.tests.accessories import ( # pylint: disable=unused-import +from telnetlib3.telopt import DO, IS, SB, SE, IAC, WILL, TSPEED +from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, + create_server, + open_connection, unused_tcp_port, + asyncio_connection, ) async def test_telnet_server_on_tspeed(bind_host, unused_tcp_port): """Test Server's callback method on_tspeed().""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TSPEED - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() class ServerTestTspeed(telnetlib3.TelnetServer): @@ -39,9 +38,6 @@ def on_tspeed(self, rx, tx): async def test_telnet_client_send_tspeed(bind_host, unused_tcp_port): """Test Client's callback method send_tspeed().""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - _waiter = asyncio.Future() given_rx, given_tx = 1337, 1919 diff --git a/telnetlib3/tests/test_ttype.py b/telnetlib3/tests/test_ttype.py index c2a793c2..6429e644 100644 --- a/telnetlib3/tests/test_ttype.py +++ b/telnetlib3/tests/test_ttype.py @@ -4,21 +4,19 @@ import asyncio # local -# local imports import telnetlib3 import telnetlib3.stream_writer -from telnetlib3.tests.accessories import ( # pylint: disable=unused-import +from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE +from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, + create_server, unused_tcp_port, + asyncio_connection, ) async def test_telnet_server_on_ttype(bind_host, unused_tcp_port): """Test Server's callback method on_ttype().""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() class ServerTestTtype(telnetlib3.TelnetServer): @@ -42,10 +40,6 @@ def on_ttype(self, ttype): async def test_telnet_server_on_ttype_beyond_max(bind_host, unused_tcp_port): """Test Server's callback method on_ttype() with long list.""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_ttypes = ( "ALPHA", @@ -90,10 +84,6 @@ def on_ttype(self, ttype): async def test_telnet_server_on_ttype_empty(bind_host, unused_tcp_port): """Test Server's callback method on_ttype(): empty value is ignored.""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_ttypes = ("ALPHA", "", "BETA") @@ -119,10 +109,6 @@ def on_ttype(self, ttype): async def test_telnet_server_on_ttype_looped(bind_host, unused_tcp_port): """Test Server's callback method on_ttype() when value looped.""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_ttypes = ("ALPHA", "BETA", "GAMMA", "ALPHA") @@ -153,10 +139,6 @@ def on_ttype(self, ttype): async def test_telnet_server_on_ttype_repeated(bind_host, unused_tcp_port): """Test Server's callback method on_ttype() when value repeats.""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_ttypes = ("ALPHA", "BETA", "GAMMA", "GAMMA") @@ -187,10 +169,6 @@ def on_ttype(self, ttype): async def test_telnet_server_on_ttype_mud(bind_host, unused_tcp_port): """Test Server's callback method on_ttype() for MUD clients (MTTS).""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, TTYPE - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_ttypes = ("ALPHA", "BETA", "MTTS 137") diff --git a/telnetlib3/tests/test_uvloop_integration.py b/telnetlib3/tests/test_uvloop_integration.py index a12ce45b..ef323e67 100644 --- a/telnetlib3/tests/test_uvloop_integration.py +++ b/telnetlib3/tests/test_uvloop_integration.py @@ -50,7 +50,11 @@ async def test_uvloop_telnet_integration(bind_host, unused_tcp_port): ) # Connect client - reader, writer = await telnetlib3.open_connection(host=bind_host, port=unused_tcp_port) + reader, writer = await telnetlib3.open_connection( + host=bind_host, + port=unused_tcp_port, + client_factory=telnetlib3.TelnetClient, + ) # Read response and verify connection works data = await reader.read(1024) diff --git a/telnetlib3/tests/test_writer.py b/telnetlib3/tests/test_writer.py index 4d96064d..8bb964f5 100644 --- a/telnetlib3/tests/test_writer.py +++ b/telnetlib3/tests/test_writer.py @@ -1,14 +1,38 @@ # std imports import asyncio +import threading +from unittest.mock import MagicMock # 3rd party import pytest # local import telnetlib3 +from telnetlib3.telopt import ( + DO, + GA, + SB, + SE, + TM, + EOR, + IAC, + NOP, + SGA, + DONT, + ECHO, + NAWS, + WILL, + WONT, + TTYPE, + CMD_EOR, + option_from_name, +) from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, + create_server, + open_connection, unused_tcp_port, + asyncio_connection, ) @@ -71,9 +95,6 @@ def test_repr(): def test_illegal_2byte_iac(): """Given an illegal 2byte IAC command, raise ValueError.""" writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) - # local - from telnetlib3.telopt import IAC, SGA - writer.feed_byte(IAC) with pytest.raises(ValueError): # IAC SGA(b'\x03'): not a legal 2-byte cmd @@ -82,12 +103,6 @@ def test_illegal_2byte_iac(): def test_legal_2byte_iac(): """Nothing special about a 2-byte IAC, test wiring a callback.""" - # std imports - import threading - - # local - from telnetlib3.telopt import IAC, NOP - called = threading.Event() def callback(cmd): @@ -115,9 +130,6 @@ def test_sb_interrupted(): # instead of awaiting the unlikely SE, and throwing all intermediary bytes # out, we just clear what we have received so far within this so called # 'SB', and exit the sb buffering state. - # local - from telnetlib3.telopt import SB, SE, TM, IAC - writer = telnetlib3.TelnetWriter( transport=None, protocol=None, @@ -151,9 +163,6 @@ def test_sb_interrupted(): async def test_iac_do_twice_replies_once(bind_host, unused_tcp_port): """WILL/WONT replied only once for repeated DO.""" - # local - from telnetlib3.telopt import DO, IAC, ECHO, WILL - from telnetlib3.tests.accessories import create_server, asyncio_connection async def shell(reader, writer): writer.close() @@ -180,9 +189,6 @@ async def shell(reader, writer): async def test_iac_dont_dont(bind_host, unused_tcp_port): """WILL/WONT replied only once for repeated DO.""" - # local - from telnetlib3.telopt import IAC, DONT, ECHO - from telnetlib3.tests.accessories import create_server, asyncio_connection async def shell(reader, writer): writer.close() @@ -209,10 +215,6 @@ async def shell(reader, writer): async def test_send_iac_dont_dont(bind_host, unused_tcp_port): """Try a DONT and ensure it cannot be sent twice.""" - # local - from telnetlib3.telopt import DONT, ECHO - from telnetlib3.tests.accessories import create_server, open_connection - async with create_server( protocol_factory=telnetlib3.BaseServer, host=bind_host, @@ -243,9 +245,7 @@ async def test_send_iac_dont_dont(bind_host, unused_tcp_port): async def test_slc_simul(bind_host, unused_tcp_port): """Test SLC control characters are simulated in kludge mode.""" # For example, ^C is simulated as IP (Interrupt Process) callback. - # local - from telnetlib3.telopt import DO, IAC, SGA, ECHO, WILL - + # # First, change server state into kludge mode -- Then, send all control # characters. We ensure all of our various callbacks that are simulated # by control characters were 'fired', as well as the raw bytes received @@ -298,10 +298,6 @@ async def shell(reader, writer): async def test_unhandled_do_sends_wont(bind_host, unused_tcp_port): """An unhandled DO is denied by WONT.""" - # local - from telnetlib3.telopt import DO, IAC, NOP, WONT - from telnetlib3.tests.accessories import create_server, asyncio_connection - given_input_outband = IAC + DO + NOP expected_output = IAC + WONT + NOP @@ -323,9 +319,6 @@ async def test_unhandled_do_sends_wont(bind_host, unused_tcp_port): async def test_writelines_bytes(bind_host, unused_tcp_port): """Exercise bytes-only interface of writer.writelines() function.""" - # local - from telnetlib3.tests.accessories import create_server, asyncio_connection - given = (b"a", b"b", b"c", b"d") expected = b"abcd" @@ -352,9 +345,6 @@ async def shell(reader, writer): async def test_writelines_unicode(bind_host, unused_tcp_port): """Exercise unicode interface of writer.writelines() function.""" - # local - from telnetlib3.tests.accessories import create_server, asyncio_connection - given = ("a", "b", "c", "d") expected = b"abcd" @@ -381,9 +371,6 @@ async def shell(reader, writer): def test_bad_iac(): """Test using writer.iac for something outside of DO/DONT/WILL/WONT.""" - # local - from telnetlib3.telopt import NOP - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) with pytest.raises(ValueError): writer.iac(NOP) @@ -391,10 +378,6 @@ def test_bad_iac(): async def test_send_ga(bind_host, unused_tcp_port): """Writer sends IAC + GA when SGA is not negotiated.""" - # local - from telnetlib3.telopt import GA, IAC - from telnetlib3.tests.accessories import create_server, asyncio_connection - expected = IAC + GA async def shell(reader, writer): @@ -420,10 +403,6 @@ async def shell(reader, writer): async def test_not_send_ga(bind_host, unused_tcp_port): """Writer does not send IAC + GA when SGA is negotiated.""" - # local - from telnetlib3.telopt import DO, IAC, SGA, WILL - from telnetlib3.tests.accessories import create_server, asyncio_connection - # we require IAC + DO + SGA, and expect a confirming reply. We also # call writer.send_ga() from the shell, whose result should be False # (not sent). The reader never receives an IAC + GA. @@ -453,9 +432,6 @@ async def shell(reader, writer): async def test_not_send_eor(bind_host, unused_tcp_port): """Writer does not send IAC + EOR when un-negotiated.""" - # local - from telnetlib3.tests.accessories import create_server, asyncio_connection - expected = b"" async def shell(reader, writer): @@ -481,10 +457,6 @@ async def shell(reader, writer): async def test_send_eor(bind_host, unused_tcp_port): """Writer sends IAC + EOR if client requests by DO.""" - # local - from telnetlib3.telopt import DO, EOR, IAC, WILL, CMD_EOR - from telnetlib3.tests.accessories import create_server, asyncio_connection - given = IAC + DO + EOR expected = IAC + WILL + EOR + b"<" + IAC + CMD_EOR + b">" @@ -574,9 +546,6 @@ async def _drain_helper(self): def test_option_from_name(): """Test option_from_name returns correct option bytes.""" - # local - from telnetlib3.telopt import ECHO, NAWS, TTYPE, option_from_name - assert option_from_name("NAWS") == NAWS assert option_from_name("naws") == NAWS assert option_from_name("TTYPE") == TTYPE @@ -588,9 +557,6 @@ def test_option_from_name(): async def test_wait_for_immediate_return(): """Test wait_for returns immediately when conditions already met.""" - # local - from telnetlib3.telopt import ECHO - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) writer.remote_option[ECHO] = True @@ -600,9 +566,6 @@ async def test_wait_for_immediate_return(): async def test_wait_for_remote_option(): """Test wait_for waits for remote option to become true.""" - # local - from telnetlib3.telopt import ECHO - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) async def set_option_later(): @@ -617,9 +580,6 @@ async def set_option_later(): async def test_wait_for_local_option(): """Test wait_for waits for local option to become true.""" - # local - from telnetlib3.telopt import ECHO - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) async def set_option_later(): @@ -634,9 +594,6 @@ async def set_option_later(): async def test_wait_for_pending_false(): """Test wait_for waits for pending option to become false.""" - # local - from telnetlib3.telopt import DO, TTYPE - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) writer.pending_option[DO + TTYPE] = True @@ -652,9 +609,6 @@ async def clear_pending_later(): async def test_wait_for_combined_conditions(): """Test wait_for with multiple conditions.""" - # local - from telnetlib3.telopt import ECHO, NAWS - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) async def set_options_later(): @@ -703,9 +657,6 @@ async def test_wait_for_condition_immediate(): async def test_wait_for_condition_waits(): """Test wait_for_condition waits for condition to become true.""" - # local - from telnetlib3.telopt import ECHO - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) async def set_option_later(): @@ -722,9 +673,6 @@ async def set_option_later(): async def test_wait_for_cleanup_on_success(): """Test that waiters are cleaned up after successful completion.""" - # local - from telnetlib3.telopt import ECHO - writer = telnetlib3.TelnetWriter(transport=None, protocol=None, server=True) async def set_option_later(): diff --git a/telnetlib3/tests/test_xdisploc.py b/telnetlib3/tests/test_xdisploc.py index d9df8086..e927f362 100644 --- a/telnetlib3/tests/test_xdisploc.py +++ b/telnetlib3/tests/test_xdisploc.py @@ -4,21 +4,20 @@ import asyncio # local -# local imports import telnetlib3 import telnetlib3.stream_writer -from telnetlib3.tests.accessories import ( # pylint: disable=unused-import +from telnetlib3.telopt import DO, IS, SB, SE, IAC, WILL, XDISPLOC +from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, + create_server, + open_connection, unused_tcp_port, + asyncio_connection, ) async def test_telnet_server_on_xdisploc(bind_host, unused_tcp_port): """Test Server's callback method on_xdisploc().""" - # local - from telnetlib3.telopt import IS, SB, SE, IAC, WILL, XDISPLOC - from telnetlib3.tests.accessories import create_server, asyncio_connection - _waiter = asyncio.Future() given_xdisploc = "alpha:0" @@ -40,9 +39,6 @@ def on_xdisploc(self, xdisploc): async def test_telnet_client_send_xdisploc(bind_host, unused_tcp_port): """Test Client's callback method send_xdisploc().""" - # local - from telnetlib3.tests.accessories import create_server, open_connection - _waiter = asyncio.Future() given_xdisploc = "alpha" @@ -52,9 +48,6 @@ def on_xdisploc(self, xdisploc): _waiter.set_result(xdisploc) def begin_advanced_negotiation(self): - # local - from telnetlib3.telopt import DO, XDISPLOC - super().begin_advanced_negotiation() self.writer.iac(DO, XDISPLOC) diff --git a/tox.ini b/tox.ini index acd9c978..f43e3c7b 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ ignore_basepython_conflict = True skip_missing_interpreters = True envlist = - py{38,39,310,311,312,313,314} + py{39,310,311,312,313,314} black docformatter isort @@ -102,7 +102,7 @@ commands = deps = mypy commands = - mypy --config-file=tox.ini telnetlib3 + mypy telnetlib3 [testenv:pydocstyle] deps = @@ -144,7 +144,7 @@ commands = deps = codespell commands = - codespell --skip="*.pyc,htmlcov,_build,build,*.egg-info,.tox,.git" \ + codespell --skip="*.pyc,htmlcov*,_build,build,*.egg-info,.tox,.git" \ --ignore-words-list="wont,nams,flushin,thirdparty,lient" \ --uri-ignore-words-list "*" \ --summary --count @@ -189,18 +189,16 @@ commands = ####### Tool Configs ####### -[mypy] -warn_unused_configs = true -warn_redundant_casts = true -warn_unused_ignores = true -ignore_missing_imports = true - [coverage:run] branch = True parallel = True source = telnetlib3 omit = telnetlib3/tests/* telnetlib3/telnetlib.py + telnetlib3/_types.py + # Too complex for good coverage, though some tests exists, it + # is tested by running a public server: telnet 1984.ws 555 + telnetlib3/fingerprinting_display.py relative_files = True [coverage:report] @@ -209,6 +207,8 @@ exclude_lines = pragma: no cover omit = telnetlib3/tests/* telnetlib3/telnetlib.py + telnetlib3/_types.py + telnetlib3/fingerprinting_display.py [coverage:paths] source = telnetlib3/ From 8f9cc3d7370e905feaec0fa18e2e5fcf7e3791a4 Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 15:13:31 -0500 Subject: [PATCH 2/8] add missing new docs requirement, sphinx_autodoc_typehints --- requirements-docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-docs.txt b/requirements-docs.txt index 4d928304..c3c36819 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,4 @@ Sphinx sphinx-rtd-theme sphinx-paramlinks +sphinx-autodoc-typehints From a054d1b6f80674b36c110e8583f64ea2520b73d0 Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 15:16:29 -0500 Subject: [PATCH 3/8] don't require PyYAML in [extras], and document it --- README.rst | 10 +++++++++- pyproject.toml | 1 - 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 82923b21..6f546aba 100644 --- a/README.rst +++ b/README.rst @@ -77,7 +77,14 @@ program. Fingerprinting Server --------------------- -A built-in fingerprinting server shell is provided to uniquely identify telnet clients:: +A built-in fingerprinting server shell is provided to uniquely identify telnet clients. + +Install with optional dependencies for full fingerprinting support (prettytable_ +and ucs-detect_):: + + pip install telnetlib3[extras] + +Usage:: export TELNETLIB3_DATA_DIR=./data telnetlib3-server --shell telnetlib3.fingerprinting_server_shell @@ -99,6 +106,7 @@ runs it to probe terminal capabilities (colors, sixel, kitty graphics, etc.) and adds the results to the fingerprint data as ``terminal-fingerprint-data``. .. _ucs-detect: https://github.com/jquast/ucs-detect +.. _prettytable: https://pypi.org/project/prettytable/ Legacy telnetlib ---------------- diff --git a/pyproject.toml b/pyproject.toml index 36d517ad..ee0696af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,6 @@ docs = [ extras = [ "ucs-detect>=2", "prettytable", - "pyyaml", ] [project.scripts] From c79b3f6ab4179a1b9c17a16bfad24faac32e3f5f Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 15:18:31 -0500 Subject: [PATCH 4/8] lint 10/10! --- telnetlib3/server_pty_shell.py | 6 ++++++ telnetlib3/slc.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/telnetlib3/server_pty_shell.py b/telnetlib3/server_pty_shell.py index c480a615..861b13a6 100644 --- a/telnetlib3/server_pty_shell.py +++ b/telnetlib3/server_pty_shell.py @@ -104,6 +104,7 @@ def start(self) -> None: :raises PTYSpawnError: If the child process fails to exec. """ + # pylint: disable=import-outside-toplevel # std imports import pty import fcntl @@ -223,6 +224,7 @@ def _setup_child( ) -> None: """Child process setup before exec.""" # Note: pty.fork() already calls setsid() for the child, so we don't need to + # pylint: disable=import-outside-toplevel # std imports import fcntl import termios @@ -263,6 +265,7 @@ def _setup_child( def _setup_parent(self) -> None: """Parent process setup after fork.""" + # pylint: disable=import-outside-toplevel # std imports import fcntl @@ -294,6 +297,7 @@ def _fire_naws_update(self) -> None: def _set_window_size(self, rows: int, cols: int) -> None: """Set PTY window size and send SIGWINCH to child.""" + # pylint: disable=import-outside-toplevel # std imports import fcntl import signal @@ -310,6 +314,7 @@ def _set_window_size(self, rows: int, cols: int) -> None: async def run(self) -> None: """Bridge loop between telnet and PTY.""" + # pylint: disable=import-outside-toplevel # std imports import errno @@ -534,6 +539,7 @@ def _terminate(self, force: bool = False) -> bool: :param force: If True, use SIGKILL as last resort. :returns: True if child was terminated, False otherwise. """ + # pylint: disable=import-outside-toplevel # std imports import signal diff --git a/telnetlib3/slc.py b/telnetlib3/slc.py index 453bccea..28440705 100644 --- a/telnetlib3/slc.py +++ b/telnetlib3/slc.py @@ -389,13 +389,13 @@ def description_table(self) -> List[str]: mrk_cont = "(...)" def continuing() -> bool: - return bool(len(result) and result[-1] == mrk_cont) + return bool(result and result[-1] == mrk_cont) def is_last(mask: int) -> bool: return mask == len(self.value) - 1 def same_as_last(row: str) -> bool: - return bool(len(result) and result[-1].endswith(row.split()[-1])) + return bool(result and result[-1].endswith(row.split()[-1])) for mask, byte in enumerate(self.value): if byte == 0: From c42d35137e6f3359782a8d133bbacddbc243911b Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 15:20:24 -0500 Subject: [PATCH 5/8] version 2.2.0 --- docs/conf.py | 2 +- telnetlib3/accessories.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7073fa86..1c73938f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,7 +70,7 @@ version = "2.2" # The full version, including alpha/beta/rc tags. -release = "2.2.0" # keep in sync with setup.py and telnetlib3/accessories.py !! +release = "2.2.0" # keep in sync with pyproject.toml and telnetlib3/accessories.py !! # The language for content auto-generated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/telnetlib3/accessories.py b/telnetlib3/accessories.py index 9dc12872..7b42f05b 100644 --- a/telnetlib3/accessories.py +++ b/telnetlib3/accessories.py @@ -26,7 +26,7 @@ def get_version() -> str: """Return the current version of telnetlib3.""" - return "2.2.0" # keep in sync with setup.py and docs/conf.py !! + return "2.2.0" # keep in sync with pyproject.toml and docs/conf.py !! def encoding_from_lang(lang: str) -> Optional[str]: From 45c117f9e33fca44a3f563445be847f5fa6f880d Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 16:30:24 -0500 Subject: [PATCH 6/8] new simpler tagline, remove redundant heading --- README.rst | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 6f546aba..230e89ff 100644 --- a/README.rst +++ b/README.rst @@ -29,14 +29,11 @@ Introduction ============ -``telnetlib3`` is a full-featured Telnet Client and Server library for python3.9 and newer. +``telnetlib3`` is a feature-rich Telnet Server and Client and Protocol library for Python 3.9 and newer. -Modern asyncio_ and legacy blocking API's are provided. +This library supports both modern asyncio_ *and* legacy `Blocking API`_. -The python telnetlib.py_ module removed by Python 3.13 is also re-distributed as a backport. - -Overview -======== +The python telnetlib.py_ module removed by Python 3.13 is also re-distributed as-is, as a backport. telnetlib3 provides multiple interfaces for working with the Telnet protocol: From c3bcaa44f5eecbb8e08105a71bbaff1a9db10b85 Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 16:36:40 -0500 Subject: [PATCH 7/8] add rfc930 TTYPE is missing --- README.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.rst b/README.rst index 230e89ff..2733686c 100644 --- a/README.rst +++ b/README.rst @@ -202,6 +202,7 @@ The following RFC specifications are implemented: * `rfc-859`_, "Telnet Status Option", May 1983. * `rfc-860`_, "Telnet Timing mark Option", May 1983. * `rfc-885`_, "Telnet End of Record Option", Dec 1983. +* `rfc-930`_, "Telnet Terminal Type Option", Jan 1984. * `rfc-1073`_, "Telnet Window Size Option", Oct 1988. * `rfc-1079`_, "Telnet Terminal Speed Option", Dec 1988. * `rfc-1091`_, "Telnet Terminal-Type Option", Feb 1989. @@ -224,6 +225,7 @@ The following RFC specifications are implemented: .. _rfc-859: https://www.rfc-editor.org/rfc/rfc859.txt .. _rfc-860: https://www.rfc-editor.org/rfc/rfc860.txt .. _rfc-885: https://www.rfc-editor.org/rfc/rfc885.txt +.. _rfc-930: https://www.rfc-editor.org/rfc/rfc930.txt .. _rfc-1073: https://www.rfc-editor.org/rfc/rfc1073.txt .. _rfc-1079: https://www.rfc-editor.org/rfc/rfc1079.txt .. _rfc-1091: https://www.rfc-editor.org/rfc/rfc1091.txt From 4565981ea4f3bc040f6da6836e49b4db27140857 Mon Sep 17 00:00:00 2001 From: Jeff Quast Date: Fri, 6 Feb 2026 16:59:44 -0500 Subject: [PATCH 8/8] telnetlib3-client --connect-timeout (#113) * --connect-timeout, closes #30 * threw in typing fixes, too! --- README.rst | 3 +- docs/history.rst | 3 ++ pyproject.toml | 5 +- telnetlib3/client.py | 38 +++++++++++--- telnetlib3/client_shell.py | 4 +- telnetlib3/fingerprinting_display.py | 2 +- telnetlib3/server_pty_shell.py | 4 +- telnetlib3/stream_writer.py | 4 +- telnetlib3/sync.py | 2 + telnetlib3/tests/test_core.py | 3 -- telnetlib3/tests/test_environ.py | 2 +- telnetlib3/tests/test_fingerprinting.py | 5 +- telnetlib3/tests/test_linemode.py | 4 +- telnetlib3/tests/test_pty_shell.py | 55 +++++++++++++------- telnetlib3/tests/test_shell.py | 2 +- telnetlib3/tests/test_stream_writer_extra.py | 2 +- telnetlib3/tests/test_sync.py | 18 +++++++ telnetlib3/tests/test_timeout.py | 47 +++++++++++++++++ telnetlib3/tests/test_tspeed.py | 3 -- telnetlib3/tests/test_writer.py | 1 - 20 files changed, 151 insertions(+), 56 deletions(-) diff --git a/README.rst b/README.rst index 2733686c..d8b326cd 100644 --- a/README.rst +++ b/README.rst @@ -29,7 +29,8 @@ Introduction ============ -``telnetlib3`` is a feature-rich Telnet Server and Client and Protocol library for Python 3.9 and newer. +``telnetlib3`` is a feature-rich Telnet Server and Client Protocol library +for Python 3.9 and newer. This library supports both modern asyncio_ *and* legacy `Blocking API`_. diff --git a/docs/history.rst b/docs/history.rst index bc861e8e..15793de5 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -1,5 +1,8 @@ History ======= +*unreleased* + * new: ``connect_timeout`` arguments for client and ``--connect-timeout`` Client CLI argument. + 2.2.0 * bugfix: workaround for Microsoft Telnet client crash on ``SB NEW_ENVIRON SEND``, :ghissue:`24`. Server now defers ``DO diff --git a/pyproject.toml b/pyproject.toml index ee0696af..80604732 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,10 +139,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = ["telnetlib3.tests.*"] -disallow_untyped_defs = false -disallow_incomplete_defs = false -disallow_untyped_calls = false -warn_return_any = false +ignore_errors = true [[tool.mypy.overrides]] module = ["telnetlib3.telnetlib"] diff --git a/telnetlib3/client.py b/telnetlib3/client.py index 61719ae7..83b1895e 100755 --- a/telnetlib3/client.py +++ b/telnetlib3/client.py @@ -390,6 +390,7 @@ async def open_connection( # pylint: disable=too-many-locals shell: Optional[ShellCallback] = None, connect_minwait: float = 2.0, connect_maxwait: float = 3.0, + connect_timeout: Optional[float] = None, waiter_closed: Optional[asyncio.Future[None]] = None, _waiter_connected: Optional[asyncio.Future[None]] = None, limit: Optional[int] = None, @@ -445,6 +446,11 @@ async def open_connection( # pylint: disable=too-many-locals otherwise confused by our demands, the shell continues anyway after the greater of this value has elapsed. A client that is not answering option negotiation will delay the start of the shell by this amount. + :param connect_timeout: Timeout in seconds for the TCP connection to be + established. When ``None`` (default), no timeout is applied and the + connection attempt may block indefinitely. When specified, a + :exc:`ConnectionError` is raised if the connection is not established + within the given time. :param force_binary: When ``True``, the encoding is used regardless of BINARY mode negotiation. @@ -480,14 +486,22 @@ def connection_factory() -> client_base.BaseClient: send_environ=send_environ, ) - _, protocol = await asyncio.get_event_loop().create_connection( - connection_factory, - host or "localhost", - port, - family=family, - flags=flags, - local_addr=local_addr, - ) + try: + _, protocol = await asyncio.wait_for( + asyncio.get_event_loop().create_connection( + connection_factory, + host or "localhost", + port, + family=family, + flags=flags, + local_addr=local_addr, + ), + timeout=connect_timeout, + ) + except asyncio.TimeoutError as exc: + raise ConnectionError( + f"TCP connection to {host or 'localhost'}:{port}" f" timed out after {connect_timeout}s" + ) from exc await protocol._waiter_connected # pylint: disable=protected-access @@ -518,6 +532,7 @@ async def run_client() -> None: "force_binary": args["force_binary"], "encoding_errors": args["encoding_errors"], "connect_minwait": args["connect_minwait"], + "connect_timeout": args["connect_timeout"], "send_environ": args["send_environ"], } @@ -564,6 +579,12 @@ def _get_argument_parser() -> argparse.ArgumentParser: type=float, help="timeout for pending negotiation", ) + parser.add_argument( + "--connect-timeout", + default=None, + type=float, + help="timeout for TCP connection (seconds, default: no timeout)", + ) parser.add_argument( "--send-environ", default="TERM,LANG,COLUMNS,LINES,COLORTERM", @@ -586,6 +607,7 @@ def _transform_args(args: argparse.Namespace) -> Dict[str, Any]: "force_binary": args.force_binary, "encoding_errors": args.encoding_errors, "connect_minwait": args.connect_minwait, + "connect_timeout": args.connect_timeout, "send_environ": tuple(v.strip() for v in args.send_environ.split(",") if v.strip()), } diff --git a/telnetlib3/client_shell.py b/telnetlib3/client_shell.py index 3d8f8e02..663cd17d 100644 --- a/telnetlib3/client_shell.py +++ b/telnetlib3/client_shell.py @@ -43,9 +43,7 @@ class Terminal: "ModeDef", ["iflag", "oflag", "cflag", "lflag", "ispeed", "ospeed", "cc"] ) - def __init__( - self, telnet_writer: Union[TelnetWriter, TelnetWriterUnicode] - ) -> None: + def __init__(self, telnet_writer: Union[TelnetWriter, TelnetWriterUnicode]) -> None: self.telnet_writer = telnet_writer self._fileno = sys.stdin.fileno() self._istty = os.path.sameopenfile(0, 1) diff --git a/telnetlib3/fingerprinting_display.py b/telnetlib3/fingerprinting_display.py index fca110da..5f273cd1 100644 --- a/telnetlib3/fingerprinting_display.py +++ b/telnetlib3/fingerprinting_display.py @@ -919,7 +919,7 @@ def _normalize_color_hex(hex_color: str) -> str: ) r, g, b = hex_to_rgb(hex_color) - return rgb_to_hex(r, g, b) + return str(rgb_to_hex(r, g, b)) def _filter_terminal_detail( # pylint: disable=too-complex,too-many-branches diff --git a/telnetlib3/server_pty_shell.py b/telnetlib3/server_pty_shell.py index 861b13a6..31594c72 100644 --- a/telnetlib3/server_pty_shell.py +++ b/telnetlib3/server_pty_shell.py @@ -379,9 +379,7 @@ async def _bridge_loop( telnet_task: asyncio.Task[Union[bytes, str]] = asyncio.create_task( self.reader.read(4096) ) - pty_task: asyncio.Task[bool] = asyncio.create_task( - pty_read_event.wait() - ) + pty_task: asyncio.Task[bool] = asyncio.create_task(pty_read_event.wait()) done, pending = await asyncio.wait( {telnet_task, pty_task}, diff --git a/telnetlib3/stream_writer.py b/telnetlib3/stream_writer.py index 704f9dba..cce80f50 100644 --- a/telnetlib3/stream_writer.py +++ b/telnetlib3/stream_writer.py @@ -1456,9 +1456,7 @@ def handle_environ(self, env: dict[str, str]) -> None: """Receive environment variables as dict, :rfc:`1572`.""" self.log.debug("Environment values are %r", env) - def handle_send_client_environ( - self, _keys: Any - ) -> dict[str, str]: + def handle_send_client_environ(self, _keys: Any) -> dict[str, str]: """ Send environment variables as dict, :rfc:`1572`. diff --git a/telnetlib3/sync.py b/telnetlib3/sync.py index 201d8d37..6a68a56e 100644 --- a/telnetlib3/sync.py +++ b/telnetlib3/sync.py @@ -59,6 +59,8 @@ class TelnetConnection: :param port: Remote server port (default 23). :param timeout: Default timeout for operations in seconds. :param encoding: Character encoding (default 'utf8'). + :param connect_timeout: Timeout in seconds for the TCP connection to be + established. Passed to ``telnetlib3.open_connection()``. :param kwargs: Additional arguments passed to ``telnetlib3.open_connection()``. Example:: diff --git a/telnetlib3/tests/test_core.py b/telnetlib3/tests/test_core.py index f60acfb2..5e51918f 100644 --- a/telnetlib3/tests/test_core.py +++ b/telnetlib3/tests/test_core.py @@ -16,9 +16,7 @@ import telnetlib3 from telnetlib3.telopt import ( DO, - IS, SB, - SE, IAC, SGA, ECHO, @@ -28,7 +26,6 @@ TTYPE, BINARY, CHARSET, - NEW_ENVIRON, ) from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, diff --git a/telnetlib3/tests/test_environ.py b/telnetlib3/tests/test_environ.py index 2a54f26e..70627771 100644 --- a/telnetlib3/tests/test_environ.py +++ b/telnetlib3/tests/test_environ.py @@ -9,7 +9,7 @@ # local import telnetlib3 import telnetlib3.stream_writer -from telnetlib3.telopt import DO, IS, SB, SE, IAC, VAR, WILL, WONT, TTYPE, USERVAR, NEW_ENVIRON +from telnetlib3.telopt import DO, IS, SB, SE, IAC, VAR, WILL, TTYPE, USERVAR, NEW_ENVIRON from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, create_server, diff --git a/telnetlib3/tests/test_fingerprinting.py b/telnetlib3/tests/test_fingerprinting.py index 5de9126d..2585a189 100644 --- a/telnetlib3/tests/test_fingerprinting.py +++ b/telnetlib3/tests/test_fingerprinting.py @@ -15,8 +15,11 @@ from telnetlib3 import fingerprinting as fps if sys.platform != "win32": - from telnetlib3 import fingerprinting_display as fpd + # local from telnetlib3 import server_pty_shell + from telnetlib3 import fingerprinting_display as fpd +else: + server_pty_shell = None # type: ignore[assignment] # local from telnetlib3.tests.accessories import ( # noqa: F401 # pylint: disable=unused-import diff --git a/telnetlib3/tests/test_linemode.py b/telnetlib3/tests/test_linemode.py index 312503a9..885c2a4f 100644 --- a/telnetlib3/tests/test_linemode.py +++ b/telnetlib3/tests/test_linemode.py @@ -16,7 +16,7 @@ ) -async def test_server_demands_remote_linemode_client_agrees( # pylint: disable=too-many-locals +async def test_server_demands_remote_linemode_client_agrees( bind_host, unused_tcp_port ): class ServerTestLinemode(telnetlib3.BaseServer): @@ -70,7 +70,7 @@ def begin_negotiation(self): assert srv_instance.writer.linemode.lit_echo is True -async def test_server_demands_remote_linemode_client_demands_local( # pylint: disable=too-many-locals +async def test_server_demands_remote_linemode_client_demands_local( bind_host, unused_tcp_port ): class ServerTestLinemode(telnetlib3.BaseServer): diff --git a/telnetlib3/tests/test_pty_shell.py b/telnetlib3/tests/test_pty_shell.py index a89bcf95..ce8b26c8 100644 --- a/telnetlib3/tests/test_pty_shell.py +++ b/telnetlib3/tests/test_pty_shell.py @@ -324,9 +324,11 @@ def mock_killpg(pgid, sig): def mock_ioctl(fd, cmd, data): ioctl_calls.append((fd, cmd, data)) - with patch("os.getpgid", return_value=12345), patch( - "os.killpg", side_effect=mock_killpg - ), patch("fcntl.ioctl", side_effect=mock_ioctl): + with ( + patch("os.getpgid", return_value=12345), + patch("os.killpg", side_effect=mock_killpg), + patch("fcntl.ioctl", side_effect=mock_ioctl), + ): # Rapid updates should be debounced - only one signal after delay session._on_naws(25, 80) session._on_naws(30, 90) @@ -355,12 +357,14 @@ def mock_killpg_winch(pgid, sig): if sig == signal_mod.SIGWINCH: winch_calls.append((pgid, sig)) - with patch("os.getpgid", return_value=12345), patch( - "os.killpg", side_effect=mock_killpg_winch - ), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch("os.close"), patch( - "fcntl.ioctl" - ), patch( - "time.sleep" + with ( + patch("os.getpgid", return_value=12345), + patch("os.killpg", side_effect=mock_killpg_winch), + patch("os.kill"), + patch("os.waitpid", return_value=(0, 0)), + patch("os.close"), + patch("fcntl.ioctl"), + patch("time.sleep"), ): session._on_naws(25, 80) session.cleanup() @@ -481,8 +485,11 @@ async def test_pty_session_cleanup_flushes_remaining_buffer(): session.master_fd = 99 session.child_pid = 12345 - with patch("os.close"), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch( - "time.sleep" + with ( + patch("os.close"), + patch("os.kill"), + patch("os.waitpid", return_value=(0, 0)), + patch("time.sleep"), ): session.cleanup() @@ -541,8 +548,10 @@ async def test_pty_session_set_window_size_behavior(mock_session): session, _ = mock_session() session.master_fd = 99 session.child_pid = 12345 - with patch("fcntl.ioctl"), patch("os.getpgid", return_value=12345), patch( - "os.killpg", side_effect=ProcessLookupError("process gone") + with ( + patch("fcntl.ioctl"), + patch("os.getpgid", return_value=12345), + patch("os.killpg", side_effect=ProcessLookupError("process gone")), ): session._set_window_size(25, 80) @@ -722,9 +731,11 @@ def mock_kill(pid, sig): def mock_isalive(): return isalive_calls.pop(0) if isalive_calls else False - with patch.object(os, "kill", side_effect=mock_kill), patch.object( - session, "_isalive", side_effect=mock_isalive - ), patch("time.sleep"): + with ( + patch.object(os, "kill", side_effect=mock_kill), + patch.object(session, "_isalive", side_effect=mock_isalive), + patch("time.sleep"), + ): result = session._terminate() assert result is True @@ -739,8 +750,9 @@ def mock_isalive(): def mock_isalive_2(): return isalive_returns.pop(0) if isalive_returns else False - with patch.object(os, "kill", side_effect=ProcessLookupError), patch.object( - session, "_isalive", side_effect=mock_isalive_2 + with ( + patch.object(os, "kill", side_effect=ProcessLookupError), + patch.object(session, "_isalive", side_effect=mock_isalive_2), ): result = session._terminate() @@ -833,8 +845,11 @@ async def test_pty_session_ga_timer_cancelled_on_cleanup(mock_session, monkeypat session._schedule_ga() assert session._ga_timer is not None - with patch("os.close"), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch( - "time.sleep" + with ( + patch("os.close"), + patch("os.kill"), + patch("os.waitpid", return_value=(0, 0)), + patch("time.sleep"), ): session.cleanup() diff --git a/telnetlib3/tests/test_shell.py b/telnetlib3/tests/test_shell.py index e1417527..f00bf0bf 100644 --- a/telnetlib3/tests/test_shell.py +++ b/telnetlib3/tests/test_shell.py @@ -112,7 +112,7 @@ async def test_telnet_server_no_shell(bind_host, unused_tcp_port): async def test_telnet_server_given_shell( bind_host, unused_tcp_port -): # pylint: disable=too-many-locals +): """Iterate all state-reading commands of default telnet_server_shell.""" async with create_server( host=bind_host, diff --git a/telnetlib3/tests/test_stream_writer_extra.py b/telnetlib3/tests/test_stream_writer_extra.py index 3ef3e01f..fcd52857 100644 --- a/telnetlib3/tests/test_stream_writer_extra.py +++ b/telnetlib3/tests/test_stream_writer_extra.py @@ -6,7 +6,7 @@ import pytest # local -from telnetlib3 import slc, client_base +from telnetlib3 import slc from telnetlib3.telopt import ( DO, IS, diff --git a/telnetlib3/tests/test_sync.py b/telnetlib3/tests/test_sync.py index c12799b2..f739d8b8 100644 --- a/telnetlib3/tests/test_sync.py +++ b/telnetlib3/tests/test_sync.py @@ -591,3 +591,21 @@ def handler(server_conn): conn.read_until(">>> ", timeout=2) server.shutdown() + + +def test_client_connect_timeout(bind_host, unused_tcp_port): + """TelnetConnection connect_timeout raises ConnectionError on unreachable port.""" + conn = TelnetConnection(bind_host, unused_tcp_port, timeout=5, connect_timeout=0.1) + with pytest.raises(ConnectionError): + conn.connect() + + +def test_client_connect_timeout_success(bind_host, unused_tcp_port): + """TelnetConnection connect_timeout does not interfere with successful connection.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + + with TelnetConnection(bind_host, unused_tcp_port, timeout=5, connect_timeout=5.0) as conn: + assert conn._connected.is_set() + + server.shutdown() diff --git a/telnetlib3/tests/test_timeout.py b/telnetlib3/tests/test_timeout.py index 4beea2c4..200af650 100644 --- a/telnetlib3/tests/test_timeout.py +++ b/telnetlib3/tests/test_timeout.py @@ -4,11 +4,16 @@ import time import asyncio +# 3rd party +import pytest + # local +from telnetlib3.client import _transform_args, _get_argument_parser from telnetlib3.telopt import DO, IAC, WONT, TTYPE from telnetlib3.tests.accessories import ( # pylint: disable=unused-import; pylint: disable=unused-import, bind_host, create_server, + open_connection, unused_tcp_port, asyncio_connection, ) @@ -76,3 +81,45 @@ async def test_telnet_server_binary_mode(bind_host, unused_tcp_port): elapsed = time.time() - stime assert 0.050 <= round(elapsed, 3) <= 0.200 assert output == expected_output + + +async def test_open_connection_connect_timeout(bind_host, unused_tcp_port): + """Test connect_timeout raises ConnectionError on unreachable port.""" + with pytest.raises(ConnectionError): + async with open_connection( + bind_host, + unused_tcp_port, + connect_timeout=0.1, + encoding=False, + ): + pass + + +async def test_open_connection_connect_timeout_success(bind_host, unused_tcp_port): + """Test connect_timeout does not interfere with successful connection.""" + async with create_server(host=bind_host, port=unused_tcp_port): + async with open_connection( + bind_host, + unused_tcp_port, + connect_timeout=5.0, + encoding=False, + connect_minwait=0.05, + connect_maxwait=0.5, + ): + pass + + +def test_cli_connect_timeout_arg(): + """Test --connect-timeout CLI argument is parsed.""" + parser = _get_argument_parser() + args = parser.parse_args(["example.com", "--connect-timeout", "2.5"]) + result = _transform_args(args) + assert result["connect_timeout"] == 2.5 + + +def test_cli_connect_timeout_default(): + """Test --connect-timeout defaults to None.""" + parser = _get_argument_parser() + args = parser.parse_args(["example.com"]) + result = _transform_args(args) + assert result["connect_timeout"] is None diff --git a/telnetlib3/tests/test_tspeed.py b/telnetlib3/tests/test_tspeed.py index ad9bb96a..21880c62 100644 --- a/telnetlib3/tests/test_tspeed.py +++ b/telnetlib3/tests/test_tspeed.py @@ -47,9 +47,6 @@ def on_tspeed(self, rx, tx): _waiter.set_result((rx, tx)) def begin_advanced_negotiation(self): - # local - from telnetlib3.telopt import DO, TSPEED - super().begin_advanced_negotiation() self.writer.iac(DO, TSPEED) diff --git a/telnetlib3/tests/test_writer.py b/telnetlib3/tests/test_writer.py index 8bb964f5..878f6c33 100644 --- a/telnetlib3/tests/test_writer.py +++ b/telnetlib3/tests/test_writer.py @@ -1,7 +1,6 @@ # std imports import asyncio import threading -from unittest.mock import MagicMock # 3rd party import pytest