diff --git a/README.rst b/README.rst index 2b992a05..4ae33047 100644 --- a/README.rst +++ b/README.rst @@ -36,77 +36,49 @@ Modern asyncio_ and legacy blocking API's are provided. The python telnetlib.py_ module removed by Python 3.13 is also re-distributed as a backport. Overview --------- +======== telnetlib3 provides multiple interfaces for working with the Telnet protocol: -**Legacy telnetlib** - An unadulterated copy of Python 3.12's telnetlib.py_ See `Legacy telnetlib`_ below. - -**Asyncio Protocol** - Modern async/await interface for both client and server, supporting concurrent - connections. See the `Guidebook`_ for examples and the `API documentation`_. - -**Command-line Utilities** - Two CLI tools are included: ``telnetlib3-client`` for connecting to servers - and ``telnetlib3-server`` for hosting. See `Command-line`_ below. - -**Blocking API** - A synchronous interface modeled after telnetlib (client) and miniboa_ (server), - with enhancements. - - See `sync API documentation`_. - - Enhancements over Python 3.11 telnetlib (client): - - - Full RFC 854 protocol negotiation (NAWS, TTYPE, BINARY, ECHO, SGA) - - `wait_for()`_ method to block until specific option states are negotiated - - `get_extra_info()`_ for terminal type, size, and other metadata - - Context manager support (``with TelnetConnection(...) as conn:``) - - Thread-safe operation with asyncio_ running in background +Asyncio Protocol +---------------- - Enhancements over miniboa (server): +Modern async/await interface for both client and server, supporting concurrent +connections. See the `Guidebook`_ for examples and the `API documentation`_. - - Thread-per-connection model with blocking I/O (vs poll-based) - - `readline()`_ and `read_until()`_ blocking methods - - Full telnet option negotiation and inspection - - miniboa-compatible properties: `active`_, `address`_, `terminal_type`_, - `columns`_, `rows`_, `idle()`_, `duration()`_, `deactivate()`_ +Blocking API +------------ -Quick Example -------------- +A traditional synchronous interface modeled after telnetlib.py_ (client) and miniboa_ (server), +with various enhancements in protocol negotiation is provided. Blocking API calls for complex +arrangements of clients and servers typically require threads. -A simple telnet server: +See `sync API documentation`_ for more. -.. code-block:: python +Command-line Utilities +---------------------- - import asyncio - import telnetlib3 +Two CLI tools are included: ``telnetlib3-client`` for connecting to servers +and ``telnetlib3-server`` for hosting a server. - async def shell(reader, writer): - writer.write('\r\nWould you like to play a game? ') - inp = await reader.read(1) - if inp: - writer.echo(inp) - writer.write('\r\nThey say the only way to win ' - 'is to not play at all.\r\n') - await writer.drain() - writer.close() - - async def main(): - server = await telnetlib3.create_server(port=6023, shell=shell) - await server.wait_closed() +Both tools argument ``--shell=my_module.fn_shell`` describing a python +module path to a function of signature ``async def shell(reader, writer)``. +The server also provides ``--pty-exec`` argument to host a stand-alone +program. - asyncio.run(main()) +:: -More examples are available in the `Guidebook`_ and the ``bin/`` directory. + telnetlib3-client nethack.alt.org + telnetlib3-client xibalba.l33t.codes 44510 + telnetlib3-client --shell bin.client_wargame.shell 1984.ws 666 + telnetlib3-server 0.0.0.0 1984 --shell=bin.server_wargame.shell + telnetlib3-server --pty-exec /bin/bash -- --login Legacy telnetlib ---------------- -This library *also* contains a copy of telnetlib.py_ from the standard library of -Python 3.12 before it was removed in Python 3.13. asyncio_ is not required to use -it. +This library contains an unadulterated copy of Python 3.12's telnetlib.py_, +from the standard library before it was removed in Python 3.13. To migrate code, change import statements: @@ -118,39 +90,56 @@ To migrate code, change import statements: # NEW imports: import telnetlib3 -Command-line ------------- - -Two command-line scripts are distributed with this package, -``telnetlib3-client`` and ``telnetlib3-server``. - -Both accept argument ``--shell=my_module.fn_shell`` describing a python -module path to a function of signature ``async def shell(reader, writer)``. - -:: +``telnetlib3`` did not provide server support, while this library also provides +both client and server support through a similar Blocking API interface. - telnetlib3-client nethack.alt.org - telnetlib3-client xibalba.l33t.codes 44510 - telnetlib3-client --shell bin.client_wargame.shell 1984.ws 666 - telnetlib3-server --pty-exec /bin/bash -- --login - telnetlib3-server 0.0.0.0 6023 --shell='bin.server_wargame.shell +See `sync API documentation`_ for details. Encoding -------- -Use ``--encoding`` and ``--force-binary`` for non-ASCII terminals:: +Often required, ``--encoding`` and ``--force-binary``:: - telnetlib3-client --encoding=cp437 --force-binary blackflag.acid.org + telnetlib3-client --encoding=cp437 --force-binary 20forbeers.com 1337 -The default encoding is UTF-8, but all text is limited to ASCII until BINARY -mode is agreed by compliance of their respective RFCs. +The default encoding is the system locale, usually UTF-8, but all Telnet +protocol text *should* be limited to ASCII until BINARY mode is agreed by +compliance of their respective RFCs. However, many clients and servers that are capable of non-ascii encodings like -utf-8 or cp437 may not be capable of negotiating about BINARY, NEW_ENVIRON, -or CHARSET to demand about it. +UTF-8 or CP437 may not be capable of negotiating about BINARY, NEW_ENVIRON, +or CHARSET to negotiate about it. + +In this case, use ``--force-binary`` and ``--encoding`` when the encoding of +the remote end is known. + +Quick Example +============= + +A simple telnet server: + +.. code-block:: python + + import asyncio + import telnetlib3 + + async def shell(reader, writer): + writer.write('\r\nWould you like to play a game? ') + inp = await reader.read(1) + if inp: + writer.echo(inp) + writer.write('\r\nThey say the only way to win ' + 'is to not play at all.\r\n') + await writer.drain() + writer.close() + + async def main(): + server = await telnetlib3.create_server(port=6023, shell=shell) + await server.wait_closed() + + asyncio.run(main()) -In this case, use ``--force-binary`` argument for clients and servers to -enforce that the specified ``--encoding`` is always used, no matter what. +More examples are available in the `Guidebook`_ and the `bin/`_ directory of the repository. Features -------- @@ -200,6 +189,7 @@ The following RFC specifications are implemented: .. _rfc-1571: https://www.rfc-editor.org/rfc/rfc1571.txt .. _rfc-1572: https://www.rfc-editor.org/rfc/rfc1572.txt .. _rfc-2066: https://www.rfc-editor.org/rfc/rfc2066.txt +.. _`bin/`: https://github.com/jquast/telnetlib3/tree/master/bin .. _telnetlib.py: https://docs.python.org/3.12/library/telnetlib.html .. _Guidebook: https://telnetlib3.readthedocs.io/en/latest/guidebook.html .. _API documentation: https://telnetlib3.readthedocs.io/en/latest/api.html diff --git a/telnetlib3/server_pty_shell.py b/telnetlib3/server_pty_shell.py index 44ee9b61..1a0821d1 100644 --- a/telnetlib3/server_pty_shell.py +++ b/telnetlib3/server_pty_shell.py @@ -9,6 +9,7 @@ import os import pty import sys +import time import errno import fcntl import codecs @@ -21,7 +22,15 @@ # local from .telopt import NAWS -__all__ = ("make_pty_shell", "pty_shell") +__all__ = ("make_pty_shell", "pty_shell", "PTYSpawnError") + +# Delay between termination signals (seconds) +_TERMINATE_DELAY = 0.1 + + +class PTYSpawnError(Exception): + """Raised when PTY child process fails to exec.""" + logger = logging.getLogger("telnetlib3.server_pty_shell") @@ -40,7 +49,7 @@ def _platform_check(): class PTYSession: """Manages a PTY session lifecycle.""" - def __init__(self, reader, writer, program, args): + def __init__(self, reader, writer, program, args, *, preexec_fn=None): """ Initialize PTY session. @@ -48,11 +57,15 @@ def __init__(self, reader, writer, program, args): :param writer: TelnetWriter instance. :param program: Path to program to execute. :param args: List of arguments for the program. + :param preexec_fn: Optional callable to run in child before exec. Called with no arguments + after fork but before _setup_child. Useful for test coverage tracking in the forked + child process. """ self.reader = reader self.writer = writer self.program = program self.args = args or [] + self.preexec_fn = preexec_fn self.master_fd = None self.child_pid = None self._closing = False @@ -64,17 +77,46 @@ def __init__(self, reader, writer, program, args): self._naws_timer = None def start(self): - """Fork PTY, configure environment, and exec program.""" + """ + Fork PTY, configure environment, and exec program. + + :raises PTYSpawnError: If the child process fails to exec. + """ _platform_check() env = self._build_environment() rows, cols = self._get_window_size() + # Create pipe for exec error detection (ptyprocess pattern). + # Child sets close-on-exec; successful exec closes pipe automatically. + # If exec fails, child writes error through pipe before exiting. + exec_err_pipe_read, exec_err_pipe_write = os.pipe() + self.child_pid, self.master_fd = pty.fork() if self.child_pid == 0: - self._setup_child(env, rows, cols) + # Child process + os.close(exec_err_pipe_read) + fcntl.fcntl(exec_err_pipe_write, fcntl.F_SETFD, fcntl.FD_CLOEXEC) + + # Coverage object from preexec_fn, saved before exec + child_cov = None + if self.preexec_fn is not None: + try: + child_cov = self.preexec_fn() + except Exception as e: # pylint: disable=broad-exception-caught + self._write_exec_error(exec_err_pipe_write, e) + os._exit(1) + self._setup_child(env, rows, cols, exec_err_pipe_write, child_cov=child_cov) else: + # Parent process + os.close(exec_err_pipe_write) + exec_err_data = os.read(exec_err_pipe_read, 4096) + os.close(exec_err_pipe_read) + + if exec_err_data: + self._handle_exec_error(exec_err_data) + logger.debug( "forked PTY: program=%s pid=%d fd=%d", self.program, @@ -86,6 +128,26 @@ def start(self): if pid: logger.warning("child already exited: status=%d", status) + def _write_exec_error(self, pipe_fd, exc): + """Write exception info to pipe for parent to read.""" + ename = type(exc).__name__ + msg = f"{ename}:{getattr(exc, 'errno', 0)}:{exc}" + os.write(pipe_fd, msg.encode("utf-8", errors="replace")) + os.close(pipe_fd) + + def _handle_exec_error(self, data): + """Parse exec error from child and raise appropriate exception.""" + try: + parts = data.decode("utf-8", errors="replace").split(":", 2) + if len(parts) == 3: + errclass, _errno_s, errmsg = parts + raise PTYSpawnError(f"{errclass}: {errmsg}") + raise PTYSpawnError(f"Exec failed: {data!r}") + except PTYSpawnError: + raise + except Exception as exc: + raise PTYSpawnError(f"Exec failed: {data!r}") from exc + def _build_environment(self): """Build environment dict from negotiated values.""" env = os.environ.copy() @@ -124,7 +186,7 @@ def _get_window_size(self): cols = self.writer.get_extra_info("cols", 80) return rows, cols - def _setup_child(self, env, rows, cols): + def _setup_child(self, env, rows, cols, exec_err_pipe, *, child_cov=None): """Child process setup before exec.""" # Note: pty.fork() already calls setsid() for the child, so we don't need to @@ -141,8 +203,17 @@ def _setup_child(self, env, rows, cols): # Keep c_oflag intact - OPOST and ONLCR translate \n to \r\n termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, attrs) + # Save coverage data before exec replaces the process + if child_cov is not None: + child_cov.stop() + child_cov.save() + argv = [self.program] + self.args - os.execvpe(self.program, argv, env) + try: + os.execvpe(self.program, argv, env) + except OSError as err: + self._write_exec_error(exec_err_pipe, err) + os._exit(os.EX_OSERR) def _setup_parent(self): """Parent process setup after fork.""" @@ -347,6 +418,43 @@ def _flush_remaining(self): self._flush_output(self._output_buffer) self._output_buffer = b"" + def _isalive(self): + """Check if child process is still running.""" + if self.child_pid is None: + return False + try: + pid, _status = os.waitpid(self.child_pid, os.WNOHANG) + return pid == 0 + except ChildProcessError: + return False + + def _terminate(self, force=False): + """ + Terminate child with signal escalation (ptyprocess pattern). + + Tries SIGHUP, SIGCONT, SIGINT in sequence. If force=True, also tries SIGKILL. + + :param force: If True, use SIGKILL as last resort. + :returns: True if child was terminated, False otherwise. + """ + if not self._isalive(): + return True + + signals = [signal.SIGHUP, signal.SIGCONT, signal.SIGINT] + if force: + signals.append(signal.SIGKILL) + + for sig in signals: + try: + os.kill(self.child_pid, sig) + except ProcessLookupError: + return True + time.sleep(_TERMINATE_DELAY) + if not self._isalive(): + return True + + return not self._isalive() + def cleanup(self): """Kill child process and close PTY fd.""" # Cancel any pending NAWS timer @@ -368,11 +476,7 @@ def cleanup(self): self.master_fd = None if self.child_pid is not None: - try: - os.kill(self.child_pid, signal.SIGTERM) - except ProcessLookupError: - pass - + self._terminate(force=True) try: os.waitpid(self.child_pid, os.WNOHANG) except ChildProcessError: @@ -398,7 +502,7 @@ async def _wait_for_terminal_info(writer, timeout=2.0): await asyncio.sleep(0.1) -async def pty_shell(reader, writer, program, args=None): +async def pty_shell(reader, writer, program, args=None, preexec_fn=None): """ PTY shell callback for telnet server. @@ -406,12 +510,13 @@ async def pty_shell(reader, writer, program, args=None): :param TelnetWriter writer: TelnetWriter instance. :param str program: Path to program to execute. :param list args: List of arguments for the program. + :param preexec_fn: Optional callable to run in child before exec. """ _platform_check() await _wait_for_terminal_info(writer, timeout=2.0) - session = PTYSession(reader, writer, program, args) + session = PTYSession(reader, writer, program, args, preexec_fn=preexec_fn) try: session.start() await session.run() @@ -421,12 +526,14 @@ async def pty_shell(reader, writer, program, args=None): writer.close() -def make_pty_shell(program, args=None): +def make_pty_shell(program, args=None, preexec_fn=None): """ Factory returning a shell callback for PTY execution. :param str program: Path to program to execute. :param list args: List of arguments for the program. + :param preexec_fn: Optional callable to run in child before exec. + Useful for test coverage tracking in the forked child process. :returns: Async shell callback suitable for use with create_server(). Example usage:: @@ -441,6 +548,6 @@ def make_pty_shell(program, args=None): """ async def shell(reader, writer): - await pty_shell(reader, writer, program, args) + await pty_shell(reader, writer, program, args, preexec_fn=preexec_fn) return shell diff --git a/telnetlib3/sync.py b/telnetlib3/sync.py index 599877e8..e77f6f88 100644 --- a/telnetlib3/sync.py +++ b/telnetlib3/sync.py @@ -32,6 +32,7 @@ def handler(conn): import queue import asyncio import threading +import concurrent.futures from typing import Any, Union, Callable, Optional # local @@ -108,7 +109,7 @@ def connect(self) -> None: future = asyncio.run_coroutine_threadsafe(self._async_connect(), self._loop) try: future.result(timeout=self._timeout) - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: self._cleanup() raise TimeoutError("Connection timed out") from exc except Exception: @@ -159,7 +160,7 @@ def read(self, n: int = -1, timeout: Optional[float] = None) -> Union[str, bytes if not result: raise EOFError("Connection closed") return result - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Read timed out") from exc @@ -194,7 +195,7 @@ def readline(self, timeout: Optional[float] = None) -> Union[str, bytes]: if not result: raise EOFError("Connection closed") return result - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Readline timed out") from exc @@ -221,7 +222,7 @@ def read_until( future = asyncio.run_coroutine_threadsafe(self._reader.readuntil(match), self._loop) try: return future.result(timeout=timeout) - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Read until timed out") from exc except asyncio.IncompleteReadError as exc: @@ -255,7 +256,7 @@ def flush(self, timeout: Optional[float] = None) -> None: future = asyncio.run_coroutine_threadsafe(self._writer.drain(), self._loop) try: future.result(timeout=timeout) - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Flush timed out") from exc @@ -358,7 +359,7 @@ def wait_for( ) try: future.result(timeout=timeout) - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Wait for negotiation timed out") from exc @@ -632,7 +633,7 @@ def read(self, n: int = -1, timeout: Optional[float] = None) -> Union[str, bytes raise EOFError("Connection closed") self._last_input_time = time.time() return result - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Read timed out") from exc @@ -666,7 +667,7 @@ def readline(self, timeout: Optional[float] = None) -> Union[str, bytes]: raise EOFError("Connection closed") self._last_input_time = time.time() return result - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Readline timed out") from exc @@ -693,7 +694,7 @@ def read_until( result = future.result(timeout=timeout) self._last_input_time = time.time() return result - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Read until timed out") from exc except asyncio.IncompleteReadError as exc: @@ -723,7 +724,7 @@ def flush(self, timeout: Optional[float] = None) -> None: future = asyncio.run_coroutine_threadsafe(self._writer.drain(), self._loop) try: future.result(timeout=timeout) - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Flush timed out") from exc @@ -786,7 +787,7 @@ def wait_for( ) try: future.result(timeout=timeout) - except asyncio.TimeoutError as exc: + except concurrent.futures.TimeoutError as exc: future.cancel() raise TimeoutError("Wait for negotiation timed out") from exc diff --git a/telnetlib3/tests/accessories.py b/telnetlib3/tests/accessories.py index e88571a1..18edecb6 100644 --- a/telnetlib3/tests/accessories.py +++ b/telnetlib3/tests/accessories.py @@ -1,6 +1,7 @@ """Test accessories for telnetlib3 project.""" # std imports +import os import asyncio import contextlib @@ -9,6 +10,42 @@ from pytest_asyncio.plugin import unused_tcp_port +def init_subproc_coverage(run_note=None): + """ + Initialize coverage tracking in a forked subprocess. + + Derived from blessed library's test accessories. + + :param run_note: Optional note (unused, for compatibility). + :returns: Coverage instance or None. + """ + try: + # 3rd party + import coverage + except ImportError: + return None + + coveragerc = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "tox.ini") + cov = coverage.Coverage(config_file=coveragerc) + cov.start() + return cov + + +def make_preexec_coverage(): + """ + Create a preexec_fn for PTY coverage tracking. + + Derived from blessed library's test accessories. + + :returns: Callable that starts and returns coverage in forked child. + """ + + def preexec(): + return init_subproc_coverage() + + return preexec + + @pytest.fixture(scope="module", params=["127.0.0.1"]) def bind_host(request): """Localhost bind address.""" @@ -121,6 +158,8 @@ class TrackingProtocol(_TrackingProtocol, protocol_factory): "bind_host", "connection_context", "create_server", + "init_subproc_coverage", + "make_preexec_coverage", "open_connection", "server_context", "unused_tcp_port", diff --git a/telnetlib3/tests/pty_helper.py b/telnetlib3/tests/pty_helper.py new file mode 100644 index 00000000..9e7685e4 --- /dev/null +++ b/telnetlib3/tests/pty_helper.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +""" +Simple PTY test programs for telnetlib3 tests. + +These programs are designed to be run via PTY for testing purposes. +Usage: python -m telnetlib3.tests.pty_helper [args...] + +Modes: + cat - Echo stdin to stdout (like /bin/cat) + echo - Print arguments and exit + stty_size - Print terminal size as "rows cols" + exit_code - Exit with given code (default 0) + env - Print specified environment variable + sleep - Sleep for N seconds (default 60) + env_all - Print all environment variables + sync_output - Output with BSU/ESU synchronized update sequences + partial_utf8 - Output incomplete UTF-8 then complete it +""" + +# std imports +import os +import sys + + +def cat_mode(): + """Echo stdin to stdout until EOF.""" + try: + while True: + data = sys.stdin.read(1) + if not data: + break + sys.stdout.write(data) + sys.stdout.flush() + except (EOFError, KeyboardInterrupt): + pass + + +def echo_mode(args): + """Print arguments to stdout.""" + print(" ".join(args)) + sys.stdout.flush() + + +def stty_size_mode(): + """Print terminal size.""" + # std imports + import fcntl + import struct + import termios + + try: + winsize = fcntl.ioctl(sys.stdin.fileno(), termios.TIOCGWINSZ, b"\x00" * 8) + rows, cols = struct.unpack("HHHH", winsize)[:2] + print(f"{rows} {cols}") + except OSError: + print("unknown") + sys.stdout.flush() + + +def exit_code_mode(args): + """Exit with specified code.""" + code = int(args[0]) if args else 0 + print("done") + sys.stdout.flush() + sys.exit(code) + + +def env_mode(args): + """Print environment variable.""" + var_name = args[0] if args else "TERM" + value = os.environ.get(var_name, "") + print(value) + sys.stdout.flush() + + +def sleep_mode(args): + """Sleep for specified seconds.""" + # std imports + import time + + seconds = float(args[0]) if args else 60 + time.sleep(seconds) + + +def env_all_mode(): + """Print all environment variables.""" + for key in sorted(os.environ.keys()): + print(f"{key}={os.environ[key]}") + sys.stdout.flush() + + +def sync_output_mode(): + """Output with BSU/ESU synchronized update sequences.""" + bsu = b"\x1b[?2026h" + esu = b"\x1b[?2026l" + sys.stdout.buffer.write(b"before\n") + sys.stdout.buffer.write(bsu + b"synchronized content" + esu) + sys.stdout.buffer.write(b"\nafter\n") + sys.stdout.buffer.flush() + + +def partial_utf8_mode(): + """Output incomplete UTF-8 then complete it.""" + sys.stdout.buffer.write(b"hello\xc3") + sys.stdout.buffer.flush() + # std imports + import time + + time.sleep(0.1) + sys.stdout.buffer.write(b"\xa9world\n") + sys.stdout.buffer.flush() + + +def main(): + """Entry point for PTY test helper.""" + if len(sys.argv) < 2: + print("Usage: python -m telnetlib3.tests.pty_helper [args...]", file=sys.stderr) + sys.exit(1) + + mode = sys.argv[1] + args = sys.argv[2:] + + modes = { + "cat": cat_mode, + "echo": lambda: echo_mode(args), + "stty_size": stty_size_mode, + "exit_code": lambda: exit_code_mode(args), + "env": lambda: env_mode(args), + "sleep": lambda: sleep_mode(args), + "env_all": env_all_mode, + "sync_output": sync_output_mode, + "partial_utf8": partial_utf8_mode, + } + + if mode not in modes: + print(f"Unknown mode: {mode}", file=sys.stderr) + print(f"Available modes: {', '.join(modes.keys())}", file=sys.stderr) + sys.exit(1) + + modes[mode]() + + +if __name__ == "__main__": + main() diff --git a/telnetlib3/tests/test_guard_integration.py b/telnetlib3/tests/test_guard_integration.py new file mode 100644 index 00000000..c222b7af --- /dev/null +++ b/telnetlib3/tests/test_guard_integration.py @@ -0,0 +1,332 @@ +# std imports +import asyncio + + +async def test_connection_counter_integration(): + # local + from telnetlib3.guard_shells import ConnectionCounter + + counter = ConnectionCounter(2) + + assert counter.try_acquire() + assert counter.count == 1 + + assert counter.try_acquire() + assert counter.count == 2 + + assert not counter.try_acquire() + assert counter.count == 2 + + counter.release() + assert counter.count == 1 + + assert counter.try_acquire() + assert counter.count == 2 + + +async def test_counter_release_on_completion(): + # local + from telnetlib3.guard_shells import ConnectionCounter + + counter = ConnectionCounter(1) + + async def shell_with_finally(): + if not counter.try_acquire(): + raise RuntimeError("Counter should have allowed acquire") + try: + raise ValueError("Simulated error") + finally: + counter.release() + + assert counter.count == 0 + + try: + await shell_with_finally() + except ValueError: + pass + + assert counter.count == 0 + + +async def test_counter_release_in_guarded_pattern(): + # local + from telnetlib3.guard_shells import ConnectionCounter + + counter = ConnectionCounter(2) + + results = [] + + async def guarded_shell(name): + if not counter.try_acquire(): + results.append(f"{name}: rejected") + return + + try: + results.append(f"{name}: acquired (count={counter.count})") + await asyncio.sleep(0.05) + finally: + counter.release() + results.append(f"{name}: released (count={counter.count})") + + await asyncio.gather( + guarded_shell("client1"), + guarded_shell("client2"), + guarded_shell("client3"), + ) + + acquired_count = sum(1 for r in results if "acquired" in r) + released_count = sum(1 for r in results if "released" in r) + rejected_count = sum(1 for r in results if "rejected" in r) + + assert acquired_count == 2 + assert released_count == 2 + assert rejected_count == 1 + assert counter.count == 0 + + +async def test_guarded_shell_pattern_busy_shell(): + # local + from telnetlib3.guard_shells import ConnectionCounter, busy_shell + + counter = ConnectionCounter(1) + shell_calls = [] + busy_shell_calls = [] + shell_done = asyncio.Event() + + class MockWriter: + def __init__(self): + self._closing = False + + def write(self, data): + pass + + async def drain(self): + pass + + def is_closing(self): + return self._closing + + def close(self): + self._closing = True + + def get_extra_info(self, key, default=None): + return ("127.0.0.1", 12345) if key == "peername" else default + + class MockReader: + def __init__(self): + self._data = list("response\r") + self._idx = 0 + + async def read(self, n): + if self._idx >= len(self._data): + return "" + result = self._data[self._idx] + self._idx += 1 + return result + + async def inner_shell(reader, writer): + shell_calls.append(True) + await shell_done.wait() + + async def guarded_shell(reader, writer): + if not counter.try_acquire(): + busy_shell_calls.append(True) + await busy_shell(reader, writer) + if not writer.is_closing(): + writer.close() + return + + try: + await inner_shell(reader, writer) + finally: + counter.release() + + writer1 = MockWriter() + writer2 = MockWriter() + reader1 = MockReader() + reader2 = MockReader() + + task1 = asyncio.create_task(guarded_shell(reader1, writer1)) + await asyncio.sleep(0.01) + task2 = asyncio.create_task(guarded_shell(reader2, writer2)) + + await asyncio.sleep(0.01) + shell_done.set() + + await asyncio.gather(task1, task2) + + assert len(shell_calls) == 1 + assert len(busy_shell_calls) == 1 + assert counter.count == 0 + + +async def test_guarded_shell_pattern_robot_check(): # pylint: disable=too-complex + # local + from telnetlib3.guard_shells import ConnectionCounter + + counter = ConnectionCounter(5) + shell_calls = [] + robot_shell_calls = [] + + class MockWriter: + def __init__(self): + self._closing = False + + def write(self, data): + pass + + async def drain(self): + pass + + def is_closing(self): + return self._closing + + def close(self): + self._closing = True + + def get_extra_info(self, key, default=None): + return ("127.0.0.1", 12345) if key == "peername" else default + + class MockReader: + def __init__(self): + self._data = list("response\r") + self._idx = 0 + + async def read(self, n): + if self._idx >= len(self._data): + return "" + result = self._data[self._idx] + self._idx += 1 + return result + + robot_check_results = [True, False, True] + robot_check_idx = [0] + + async def mock_robot_check(reader, writer): + idx = robot_check_idx[0] + robot_check_idx[0] += 1 + return robot_check_results[idx % len(robot_check_results)] + + async def mock_robot_shell(reader, writer): + robot_shell_calls.append(True) + + async def inner_shell(reader, writer): + shell_calls.append(True) + + async def guarded_shell(reader, writer): + if not counter.try_acquire(): + return + + try: + passed = await mock_robot_check(reader, writer) + if not passed: + await mock_robot_shell(reader, writer) + if not writer.is_closing(): + writer.close() + return + + await inner_shell(reader, writer) + finally: + counter.release() + + tasks = [] + for i in range(3): + reader = MockReader() + writer = MockWriter() + tasks.append(asyncio.create_task(guarded_shell(reader, writer))) + + await asyncio.gather(*tasks) + + assert len(shell_calls) == 2 + assert len(robot_shell_calls) == 1 + assert counter.count == 0 + + +async def test_full_guarded_shell_flow(): # pylint: disable=too-complex + # local + from telnetlib3.guard_shells import ConnectionCounter, busy_shell, robot_shell + + counter = ConnectionCounter(2) + shell_calls = [] + busy_calls = [] + robot_calls = [] + + class MockWriter: + def __init__(self): + self._closing = False + self.output = [] + + def write(self, data): + self.output.append(data) + + async def drain(self): + pass + + def is_closing(self): + return self._closing + + def close(self): + self._closing = True + + def get_extra_info(self, key, default=None): + return ("127.0.0.1", 12345) if key == "peername" else default + + class MockReader: + def __init__(self, responses=None): + self._data = responses or list("response\r") + self._idx = 0 + + async def read(self, n): + if self._idx >= len(self._data): + return "" + result = self._data[self._idx] + self._idx += 1 + return result + + robot_check_results = [True, False, True, True] + robot_check_idx = [0] + + async def mock_robot_check(reader, writer): + idx = robot_check_idx[0] + robot_check_idx[0] += 1 + return robot_check_results[idx % len(robot_check_results)] + + async def inner_shell(reader, writer): + shell_calls.append(True) + writer.write("Shell active") + + async def guarded_shell(reader, writer, do_robot_check=True): + if not counter.try_acquire(): + busy_calls.append(True) + await busy_shell(reader, writer) + if not writer.is_closing(): + writer.close() + return + + try: + if do_robot_check: + passed = await mock_robot_check(reader, writer) + if not passed: + robot_calls.append(True) + await robot_shell(reader, writer) + if not writer.is_closing(): + writer.close() + return + + await inner_shell(reader, writer) + finally: + counter.release() + + writers = [MockWriter() for _ in range(4)] + readers = [MockReader(list("y\ryes\r")) for _ in range(4)] + + await asyncio.gather( + guarded_shell(readers[0], writers[0]), + guarded_shell(readers[1], writers[1]), + guarded_shell(readers[2], writers[2]), + guarded_shell(readers[3], writers[3]), + ) + + assert len(shell_calls) >= 1 + assert len(robot_calls) >= 1 + assert counter.count == 0 diff --git a/telnetlib3/tests/test_pty_shell.py b/telnetlib3/tests/test_pty_shell.py index 85b95f1a..07be0bdb 100644 --- a/telnetlib3/tests/test_pty_shell.py +++ b/telnetlib3/tests/test_pty_shell.py @@ -1,6 +1,7 @@ """Tests for PTY shell functionality.""" # std imports +import os import sys import asyncio @@ -12,12 +13,23 @@ from telnetlib3.tests.accessories import ( # pylint: disable=unused-import bind_host, unused_tcp_port, + make_preexec_coverage, ) pytestmark = [ pytest.mark.skipif(sys.platform == "win32", reason="PTY not supported on Windows"), ] +PTY_HELPER = os.path.join(os.path.dirname(__file__), "pty_helper.py") + +# Python 3.15+ emits DeprecationWarning when forkpty() is called in a multi-threaded +# process. The warning is valid (forking in threaded processes can deadlock), but +# pytest itself uses threads, so we can't avoid it. The PTY code still works fine - +# we just suppress the warning in tests rather than skipping them entirely. +_ignore_forkpty_deprecation = pytest.mark.filterwarnings( + "ignore:This process.*is multi-threaded, use of forkpty:DeprecationWarning" +) + @pytest.fixture def require_no_capture(request): @@ -27,12 +39,41 @@ def require_no_capture(request): pytest.skip("PTY tests require --capture=no or -s flag") -async def test_pty_shell_basic_cat(bind_host, unused_tcp_port, require_no_capture): - """Test basic echo with /bin/cat.""" +@pytest.fixture +def mock_session(): + """Create a mock PTYSession for unit testing.""" + # std imports + from unittest.mock import MagicMock + + # local + from telnetlib3.server_pty_shell import PTYSession + + def _create(extra_info=None, capture_writes=False): + reader = MagicMock() + writer = MagicMock() + written = [] if capture_writes else None + if capture_writes: + writer.write = written.append + if extra_info is None: + writer.get_extra_info = MagicMock(return_value=None) + elif callable(extra_info): + writer.get_extra_info = MagicMock(side_effect=extra_info) + else: + writer.get_extra_info = MagicMock(side_effect=lambda k, d=None: extra_info.get(k, d)) + session = PTYSession(reader, writer, "/nonexistent.program", []) + return session, written + + return _create + + +@_ignore_forkpty_deprecation +async def test_pty_shell_integration(bind_host, unused_tcp_port, require_no_capture): + """Test PTY shell with various helper modes: cat, env, stty_size.""" # local from telnetlib3 import make_pty_shell from telnetlib3.tests.accessories import create_server, open_connection + # Test 1: cat mode - echo input back _waiter = asyncio.Future() class ServerWithWaiter(telnetlib3.TelnetServer): @@ -45,7 +86,9 @@ def begin_shell(self, result): protocol_factory=ServerWithWaiter, host=bind_host, port=unused_tcp_port, - shell=make_pty_shell("/bin/cat"), + shell=make_pty_shell( + sys.executable, [PTY_HELPER, "cat"], preexec_fn=make_preexec_coverage() + ), connect_maxwait=0.5, ): async with open_connection( @@ -64,22 +107,10 @@ def begin_shell(self, result): result = await asyncio.wait_for(reader.read(50), 2.0) assert "hello world" in result - -async def test_pty_shell_term_propagation(bind_host, unused_tcp_port, require_no_capture): - """Test TERM environment propagation.""" - # local - from telnetlib3 import make_pty_shell - from telnetlib3.tests.accessories import create_server, open_connection - + # Test 2: env mode - verify TERM propagation _waiter = asyncio.Future() _output = asyncio.Future() - class ServerWithWaiter(telnetlib3.TelnetServer): - def begin_shell(self, result): - super().begin_shell(result) - if not _waiter.done(): - _waiter.set_result(self) - async def client_shell(reader, writer): await _waiter await asyncio.sleep(0.5) @@ -90,7 +121,9 @@ async def client_shell(reader, writer): protocol_factory=ServerWithWaiter, host=bind_host, port=unused_tcp_port, - shell=make_pty_shell("/bin/sh", ["-c", "echo $TERM"]), + shell=make_pty_shell( + sys.executable, [PTY_HELPER, "env", "TERM"], preexec_fn=make_preexec_coverage() + ), connect_maxwait=0.5, ): async with open_connection( @@ -105,28 +138,16 @@ async def client_shell(reader, writer): output = await asyncio.wait_for(_output, 5.0) assert "vt220" in output or "xterm" in output - -async def test_pty_shell_child_exit_closes_connection( - bind_host, unused_tcp_port, require_no_capture -): - """Test that child exit closes connection gracefully.""" - # local - from telnetlib3 import make_pty_shell - from telnetlib3.tests.accessories import create_server, open_connection - + # Test 3: stty_size mode - verify NAWS propagation _waiter = asyncio.Future() - class ServerWithWaiter(telnetlib3.TelnetServer): - def begin_shell(self, result): - super().begin_shell(result) - if not _waiter.done(): - _waiter.set_result(self) - async with create_server( protocol_factory=ServerWithWaiter, host=bind_host, port=unused_tcp_port, - shell=make_pty_shell("/bin/sh", ["-c", "echo done; exit 0"]), + shell=make_pty_shell( + sys.executable, [PTY_HELPER, "stty_size"], preexec_fn=make_preexec_coverage() + ), connect_maxwait=0.5, ): async with open_connection( @@ -139,23 +160,19 @@ def begin_shell(self, result): await asyncio.wait_for(_waiter, 2.0) await asyncio.sleep(0.3) - result = await asyncio.wait_for(reader.read(100), 3.0) - assert "done" in result - - remaining = await asyncio.wait_for(reader.read(), 3.0) - assert not remaining + output = await asyncio.wait_for(reader.read(50), 2.0) + assert "25 80" in output -async def test_pty_shell_client_disconnect_kills_child( - bind_host, unused_tcp_port, require_no_capture -): - """Test that client disconnect kills child process.""" +@_ignore_forkpty_deprecation +async def test_pty_shell_lifecycle(bind_host, unused_tcp_port, require_no_capture): + """Test PTY shell lifecycle: child exit and client disconnect.""" # local from telnetlib3 import make_pty_shell from telnetlib3.tests.accessories import create_server, open_connection + # Test 1: child exit closes connection gracefully _waiter = asyncio.Future() - _closed = asyncio.Future() class ServerWithWaiter(telnetlib3.TelnetServer): def begin_shell(self, result): @@ -163,16 +180,13 @@ def begin_shell(self, result): if not _waiter.done(): _waiter.set_result(self) - def connection_lost(self, exc): - super().connection_lost(exc) - if not _closed.done(): - _closed.set_result(True) - async with create_server( protocol_factory=ServerWithWaiter, host=bind_host, port=unused_tcp_port, - shell=make_pty_shell("/bin/cat"), + shell=make_pty_shell( + sys.executable, [PTY_HELPER, "exit_code", "0"], preexec_fn=make_preexec_coverage() + ), connect_maxwait=0.5, ): async with open_connection( @@ -185,28 +199,34 @@ def connection_lost(self, exc): await asyncio.wait_for(_waiter, 2.0) await asyncio.sleep(0.3) - await asyncio.wait_for(_closed, 3.0) - + result = await asyncio.wait_for(reader.read(100), 3.0) + assert "done" in result -async def test_pty_shell_naws_resize(bind_host, unused_tcp_port, require_no_capture): - """Test NAWS resize forwarding.""" - # local - from telnetlib3 import make_pty_shell - from telnetlib3.tests.accessories import create_server, open_connection + remaining = await asyncio.wait_for(reader.read(), 3.0) + assert not remaining + # Test 2: client disconnect kills child process _waiter = asyncio.Future() + _closed = asyncio.Future() - class ServerWithWaiter(telnetlib3.TelnetServer): + class ServerWithCloseWaiter(telnetlib3.TelnetServer): def begin_shell(self, result): super().begin_shell(result) if not _waiter.done(): _waiter.set_result(self) + def connection_lost(self, exc): + super().connection_lost(exc) + if not _closed.done(): + _closed.set_result(True) + async with create_server( - protocol_factory=ServerWithWaiter, + protocol_factory=ServerWithCloseWaiter, host=bind_host, port=unused_tcp_port, - shell=make_pty_shell("/bin/sh", ["-c", "stty size"]), + shell=make_pty_shell( + sys.executable, [PTY_HELPER, "cat"], preexec_fn=make_preexec_coverage() + ), connect_maxwait=0.5, ): async with open_connection( @@ -219,8 +239,7 @@ def begin_shell(self, result): await asyncio.wait_for(_waiter, 2.0) await asyncio.sleep(0.3) - output = await asyncio.wait_for(reader.read(50), 2.0) - assert "25 80" in output + await asyncio.wait_for(_closed, 3.0) def test_platform_check_not_windows(): @@ -242,36 +261,26 @@ def test_make_pty_shell_returns_callable(): # local from telnetlib3 import make_pty_shell - shell = make_pty_shell("/bin/sh") + shell = make_pty_shell(sys.executable) assert callable(shell) - shell_with_args = make_pty_shell("/bin/sh", ["-c", "echo hello"]) + shell_with_args = make_pty_shell(sys.executable, [PTY_HELPER, "echo", "hello"]) assert callable(shell_with_args) -async def test_pty_session_build_environment(): - """Test PTYSession environment building.""" - # std imports - from unittest.mock import MagicMock - - # local - from telnetlib3.server_pty_shell import PTYSession - - reader = MagicMock() - writer = MagicMock() - writer.get_extra_info = MagicMock( - side_effect=lambda k, d=None: { +async def test_pty_session_build_environment(mock_session): + """Test PTYSession environment building with various configurations.""" + # Test with full environment info + session, _ = mock_session( + { "TERM": "xterm-256color", "rows": 30, "cols": 100, "LANG": "en_US.UTF-8", "DISPLAY": ":0", - }.get(k, d) + } ) - - session = PTYSession(reader, writer, "/bin/sh", []) env = session._build_environment() - assert env["TERM"] == "xterm-256color" assert env["LINES"] == "30" assert env["COLUMNS"] == "100" @@ -279,35 +288,299 @@ async def test_pty_session_build_environment(): assert env["LC_ALL"] == "en_US.UTF-8" assert env["DISPLAY"] == ":0" + # Test charset fallback when no LANG + session, _ = mock_session( + { + "TERM": "vt100", + "rows": 24, + "cols": 80, + "charset": "ISO-8859-1", + } + ) + env = session._build_environment() + assert env["TERM"] == "vt100" + assert env["LANG"] == "en_US.ISO-8859-1" + + +async def test_pty_session_naws_behavior(mock_session): + """Test NAWS debouncing, latest value usage, and cleanup cancellation.""" + # std imports + import struct + from unittest.mock import MagicMock, patch + + session, _ = mock_session() + session.master_fd = 1 + session.child_pid = 12345 + session.writer.protocol = MagicMock() + + signal_calls = [] + ioctl_calls = [] + + def mock_killpg(pgid, sig): + signal_calls.append((pgid, sig)) + + def mock_ioctl(fd, cmd, data): + ioctl_calls.append((fd, cmd, data)) -async def test_pty_session_build_environment_charset_fallback(): - """Test PTYSession environment building with charset fallback.""" + with patch("os.getpgid", return_value=12345), patch( + "os.killpg", side_effect=mock_killpg + ), patch("fcntl.ioctl", side_effect=mock_ioctl): + # Rapid updates should be debounced - only one signal after delay + session._on_naws(25, 80) + session._on_naws(30, 90) + session._on_naws(50, 150) + assert len(signal_calls) == 0 + + await asyncio.sleep(0.25) + assert len(signal_calls) == 1 + assert len(ioctl_calls) == 1 + + # Should use latest values (50, 150) + expected_winsize = struct.pack("HHHH", 50, 150, 0, 0) + assert ioctl_calls[0][2] == expected_winsize + + # Test cleanup cancels pending NAWS timer + session, _ = mock_session() + session.master_fd = 1 + session.child_pid = 12345 + session.writer.protocol = MagicMock() + winch_calls = [] + + def mock_killpg_winch(pgid, sig): + # std imports + import signal as signal_mod + + if sig == signal_mod.SIGWINCH: + winch_calls.append((pgid, sig)) + + with patch("os.getpgid", return_value=12345), patch( + "os.killpg", side_effect=mock_killpg_winch + ), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)), patch("os.close"), patch( + "fcntl.ioctl" + ): + session._on_naws(25, 80) + session.cleanup() + await asyncio.sleep(0.25) + assert len(winch_calls) == 0 + + +async def test_pty_session_write_to_telnet_buffering(mock_session): + """Test _write_to_telnet line buffering, BSU/ESU handling, and overflow protection.""" + # local + from telnetlib3.server_pty_shell import _BSU, _ESU + + # Line buffering: buffers until newline + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + session._write_to_telnet(b"hello") + assert len(written) == 0 + assert session._output_buffer == b"hello" + + session._write_to_telnet(b" world\nmore") + assert len(written) == 1 + assert "hello world\n" in written[0] + assert session._output_buffer == b"more" + + # BSU/ESU: complete sequence flushes immediately + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + session._write_to_telnet(_BSU + b"content" + _ESU) + assert len(written) == 1 + assert session._in_sync_update is False + + # BSU waits for ESU + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + session._write_to_telnet(_BSU + b"partial") + assert len(written) == 0 + assert session._in_sync_update is True + session._write_to_telnet(b" content" + _ESU) + assert len(written) == 1 + assert session._in_sync_update is False + + # Buffer overflow protection (256KB) + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + session._in_sync_update = True + session._output_buffer = b"x" * 300000 + + session._write_to_telnet(b"") + assert len(written) == 1 + assert session._output_buffer == b"" + + +async def test_pty_session_flush_output_behavior(mock_session): + """Test flush_output charset handling and incomplete UTF-8 buffering.""" # std imports from unittest.mock import MagicMock # local from telnetlib3.server_pty_shell import PTYSession + # Charset change recreates decoder reader = MagicMock() writer = MagicMock() + written = [] + writer.write = written.append + charset_values = ["utf-8"] writer.get_extra_info = MagicMock( - side_effect=lambda k, d=None: { - "TERM": "vt100", - "rows": 24, - "cols": 80, - "charset": "ISO-8859-1", - }.get(k, d) + side_effect=lambda k, d=None: charset_values[0] if k == "charset" else d ) + session = PTYSession(reader, writer, "/nonexistent.program", []) + session._flush_output(b"hello") + original_decoder = session._decoder + assert session._decoder_charset == "utf-8" + charset_values[0] = "latin-1" + session._flush_output(b"world") + assert session._decoder is not original_decoder + assert session._decoder_charset == "latin-1" + + # Incomplete UTF-8 sequences are buffered + session, written = mock_session({"charset": "utf-8"}, capture_writes=True) + session._flush_output(b"hello\xc3") + assert len(written) == 1 + assert written[0] == "hello" + session._flush_output(b"\xa9", final=True) + assert len(written) == 2 + assert written[1] == "\xe9" + + +async def test_pty_session_write_to_pty_behavior(mock_session): + """Test _write_to_pty encoding, error handling, and None fd guard.""" + # std imports + from unittest.mock import patch + + # String encoding + session, _ = mock_session({"charset": "utf-8"}) + session.master_fd = 99 + written_data = [] + + def mock_write(fd, data): + written_data.append((fd, data)) + return len(data) + + with patch("os.write", side_effect=mock_write): + session._write_to_pty("hello") + assert written_data == [(99, b"hello")] + + # OSError sets _closing flag + session, _ = mock_session({"charset": "utf-8"}) + session.master_fd = 99 + session._closing = False + with patch("os.write", side_effect=OSError("broken pipe")): + session._write_to_pty(b"data") + assert session._closing is True + + # None fd does nothing + session, _ = mock_session({"charset": "utf-8"}) + session.master_fd = None + write_calls = [] + with patch("os.write", side_effect=lambda fd, data: write_calls.append((fd, data))): + session._write_to_pty(b"data") + assert len(write_calls) == 0 + + +async def test_pty_session_cleanup_flushes_remaining_buffer(): + """Test that cleanup flushes remaining buffer with final=True.""" + # std imports + from unittest.mock import MagicMock, patch - session = PTYSession(reader, writer, "/bin/sh", []) - env = session._build_environment() + # local + from telnetlib3.server_pty_shell import PTYSession - assert env["TERM"] == "vt100" - assert env["LANG"] == "en_US.ISO-8859-1" + reader = MagicMock() + writer = MagicMock() + written = [] + writer.write = written.append + writer.get_extra_info = MagicMock(return_value="utf-8") + + session = PTYSession(reader, writer, "/nonexistent.program", []) + session._output_buffer = b"remaining data" + session.master_fd = 99 + session.child_pid = 12345 + + with patch("os.close"), patch("os.kill"), patch("os.waitpid", return_value=(0, 0)): + session.cleanup() + + assert len(written) == 1 + assert written[0] == "remaining data" + assert session._output_buffer == b"" + + +async def test_wait_for_terminal_info_behavior(): + """Test _wait_for_terminal_info early return, timeout, and polling behavior.""" + # std imports + import time + from unittest.mock import MagicMock + + # local + from telnetlib3.server_pty_shell import _wait_for_terminal_info + + # Returns early when TERM and rows available + writer = MagicMock() + writer.get_extra_info = MagicMock(side_effect={"TERM": "xterm", "rows": 25}.get) + await _wait_for_terminal_info(writer, timeout=2.0) + + # Times out when info not available + writer = MagicMock() + writer.get_extra_info = MagicMock(return_value=None) + start = time.time() + await _wait_for_terminal_info(writer, timeout=0.3) + assert time.time() - start >= 0.25 + + # Polls until rows become available + call_count = [0] + + def get_info(key): + call_count[0] += 1 + if key == "TERM": + return "xterm" + if key == "rows" and call_count[0] > 4: + return 25 + return None + + writer = MagicMock() + writer.get_extra_info = MagicMock(side_effect=get_info) + start = time.time() + await _wait_for_terminal_info(writer, timeout=2.0) + assert time.time() - start < 1.0 + assert call_count[0] > 2 + + +async def test_pty_session_set_window_size_behavior(mock_session): + """Test _set_window_size guards and error handling.""" + # std imports + from unittest.mock import patch + # No fd does nothing + session, _ = mock_session() + session.master_fd = None + session.child_pid = None + ioctl_calls = [] + with patch( + "fcntl.ioctl", side_effect=lambda fd, cmd, data: ioctl_calls.append((fd, cmd, data)) + ): + session._set_window_size(25, 80) + assert len(ioctl_calls) == 0 -async def test_pty_session_naws_debouncing(): - """Test that rapid NAWS updates are debounced.""" + # Handles ProcessLookupError gracefully + session, _ = mock_session() + session.master_fd = 99 + session.child_pid = 12345 + with patch("fcntl.ioctl"), patch("os.getpgid", return_value=12345), patch( + "os.killpg", side_effect=ProcessLookupError("process gone") + ): + session._set_window_size(25, 80) + + +@pytest.mark.parametrize( + "close_effect,kill_effect,waitpid_effect,check_attr", + [ + (None, None, ChildProcessError("already reaped"), "child_pid"), + (OSError("bad fd"), None, (0, 0), "master_fd"), + (None, ProcessLookupError("already dead"), (0, 0), "child_pid"), + ], +) +async def test_pty_session_cleanup_error_recovery( + close_effect, kill_effect, waitpid_effect, check_attr +): + """Test cleanup handles various error conditions gracefully.""" # std imports from unittest.mock import MagicMock, patch @@ -316,87 +589,178 @@ async def test_pty_session_naws_debouncing(): reader = MagicMock() writer = MagicMock() - protocol = MagicMock() - writer.protocol = protocol - writer.get_extra_info = MagicMock(return_value=None) + writer.get_extra_info = MagicMock(return_value="utf-8") - session = PTYSession(reader, writer, "/bin/sh", []) - session.master_fd = 1 + session = PTYSession(reader, writer, "/nonexistent.program", []) + session.master_fd = 99 session.child_pid = 12345 - signal_calls = [] + close_patch = patch("os.close", side_effect=close_effect) if close_effect else patch("os.close") + kill_patch = patch("os.kill", side_effect=kill_effect) if kill_effect else patch("os.kill") + waitpid_side = waitpid_effect if isinstance(waitpid_effect, Exception) else None + waitpid_return = None if isinstance(waitpid_effect, Exception) else waitpid_effect + waitpid_patch = patch("os.waitpid", side_effect=waitpid_side, return_value=waitpid_return) - def mock_killpg(pgid, sig): - signal_calls.append((pgid, sig)) + with close_patch, kill_patch, waitpid_patch: + session.cleanup() - with patch("os.getpgid", return_value=12345), \ - patch("os.killpg", side_effect=mock_killpg), \ - patch("fcntl.ioctl"): - session._on_naws(25, 80) - session._on_naws(30, 90) - session._on_naws(35, 100) + assert getattr(session, check_attr) is None - assert len(signal_calls) == 0 - await asyncio.sleep(0.25) +@pytest.mark.parametrize( + "in_sync_update,expected_writes,expected_buffer", + [ + (False, 1, b""), + (True, 0, b"partial line"), + ], +) +async def test_pty_session_flush_remaining_scenarios( + in_sync_update, expected_writes, expected_buffer +): + """Test _flush_remaining behavior based on sync update state.""" + # std imports + from unittest.mock import MagicMock - assert len(signal_calls) == 1 + # local + from telnetlib3.server_pty_shell import PTYSession - signal_calls.clear() - session._on_naws(40, 120) - session._on_naws(45, 130) + reader = MagicMock() + writer = MagicMock() + written = [] + writer.write = written.append + writer.get_extra_info = MagicMock(return_value="utf-8") - assert len(signal_calls) == 0 + session = PTYSession(reader, writer, "/nonexistent.program", []) + session._output_buffer = b"partial line" + session._in_sync_update = in_sync_update - await asyncio.sleep(0.25) + session._flush_remaining() - assert len(signal_calls) == 1 + assert len(written) == expected_writes + if expected_writes > 0: + assert written[0] == "partial line" + assert session._output_buffer == expected_buffer -async def test_pty_session_naws_debounce_uses_latest_values(): - """Test that debounced NAWS uses the latest values.""" +async def test_pty_session_flush_output_empty_data(): + """Test _flush_output does nothing with empty data.""" # std imports - from unittest.mock import MagicMock, call, patch + from unittest.mock import MagicMock # local from telnetlib3.server_pty_shell import PTYSession reader = MagicMock() writer = MagicMock() - protocol = MagicMock() - writer.protocol = protocol + written = [] + writer.write = written.append + writer.get_extra_info = MagicMock(return_value="utf-8") + + session = PTYSession(reader, writer, "/nonexistent.program", []) + + session._flush_output(b"") + session._flush_output(b"", final=True) + + assert len(written) == 0 + + +async def test_pty_session_write_to_telnet_pre_bsu_content(): + """Test content before BSU is flushed.""" + # std imports + from unittest.mock import MagicMock + + # local + from telnetlib3.server_pty_shell import _BSU, _ESU, PTYSession + + reader = MagicMock() + writer = MagicMock() + written = [] + writer.write = written.append + writer.get_extra_info = MagicMock(return_value="utf-8") + + session = PTYSession(reader, writer, "/nonexistent.program", []) + + session._write_to_telnet(b"before\n" + _BSU + b"during" + _ESU) + assert len(written) == 2 + assert "before\n" in written[0] + assert session._in_sync_update is False + + +async def test_pty_spawn_error(): + """Test PTYSpawnError exception class.""" + # local + from telnetlib3.server_pty_shell import PTYSpawnError + + err = PTYSpawnError("test error") + assert str(err) == "test error" + assert isinstance(err, Exception) + + +@pytest.mark.parametrize( + "error_data,expected_substrings", + [ + (b"FileNotFoundError:2:No such file", ["FileNotFoundError", "No such file"]), + (b"just some error text", ["Exec failed"]), + ], +) +async def test_pty_session_exec_error_parsing(error_data, expected_substrings): + """Test _handle_exec_error parses various error formats.""" + # std imports + from unittest.mock import MagicMock + + # local + from telnetlib3.server_pty_shell import PTYSession, PTYSpawnError + + reader = MagicMock() + writer = MagicMock() writer.get_extra_info = MagicMock(return_value=None) - session = PTYSession(reader, writer, "/bin/sh", []) - session.master_fd = 1 - session.child_pid = 12345 + session = PTYSession(reader, writer, "/nonexistent.program", []) - ioctl_calls = [] + with pytest.raises(PTYSpawnError) as exc_info: + session._handle_exec_error(error_data) - def mock_ioctl(fd, cmd, data): - ioctl_calls.append((fd, cmd, data)) + for substring in expected_substrings: + assert substring in str(exc_info.value) - with patch("os.getpgid", return_value=12345), \ - patch("os.killpg"), \ - patch("fcntl.ioctl", side_effect=mock_ioctl): - session._on_naws(25, 80) - session._on_naws(30, 90) - session._on_naws(50, 150) - await asyncio.sleep(0.25) +@pytest.mark.parametrize( + "child_pid,waitpid_behavior,expected", + [ + (None, None, False), + (99999, ChildProcessError, False), + (12345, (0, 0), True), + ], +) +async def test_pty_session_isalive_scenarios(child_pid, waitpid_behavior, expected): + """Test _isalive returns correct values for various child states.""" + # std imports + from unittest.mock import MagicMock, patch - assert len(ioctl_calls) == 1 - # std imports - import struct - import termios + # local + from telnetlib3.server_pty_shell import PTYSession - expected_winsize = struct.pack("HHHH", 50, 150, 0, 0) - assert ioctl_calls[0][2] == expected_winsize + reader = MagicMock() + writer = MagicMock() + writer.get_extra_info = MagicMock(return_value=None) + + session = PTYSession(reader, writer, "/nonexistent.program", []) + session.child_pid = child_pid + if waitpid_behavior is None: + assert session._isalive() is expected + elif isinstance(waitpid_behavior, type) and issubclass(waitpid_behavior, Exception): + with patch.object(os, "waitpid", side_effect=waitpid_behavior): + assert session._isalive() is expected + else: + with patch.object(os, "waitpid", return_value=waitpid_behavior): + assert session._isalive() is expected -async def test_pty_session_naws_cleanup_cancels_pending(): - """Test that cleanup cancels pending NAWS timer.""" + +async def test_pty_session_terminate_scenarios(): + """Test _terminate handles various termination scenarios.""" # std imports + import signal from unittest.mock import MagicMock, patch # local @@ -404,33 +768,45 @@ async def test_pty_session_naws_cleanup_cancels_pending(): reader = MagicMock() writer = MagicMock() - protocol = MagicMock() - writer.protocol = protocol writer.get_extra_info = MagicMock(return_value=None) - session = PTYSession(reader, writer, "/bin/sh", []) - session.master_fd = 1 + # Scenario 1: No child pid - returns True immediately + session = PTYSession(reader, writer, "/nonexistent.program", []) + session.child_pid = None + assert session._terminate() is True + + # Scenario 2: Child alive, sends signals, child dies + session = PTYSession(reader, writer, "/nonexistent.program", []) session.child_pid = 12345 + kill_calls = [] + isalive_calls = [True, True, False] - signal_calls = [] + def mock_kill(pid, sig): + kill_calls.append((pid, sig)) - def mock_killpg(pgid, sig): - # std imports - import signal as signal_mod + def mock_isalive(): + return isalive_calls.pop(0) if isalive_calls else False - if sig == signal_mod.SIGWINCH: - signal_calls.append((pgid, sig)) - - with patch("os.getpgid", return_value=12345), \ - patch("os.killpg", side_effect=mock_killpg), \ - patch("os.kill"), \ - patch("os.waitpid", return_value=(0, 0)), \ - patch("os.close"), \ - patch("fcntl.ioctl"): - session._on_naws(25, 80) + with patch.object(os, "kill", side_effect=mock_kill), patch.object( + session, "_isalive", side_effect=mock_isalive + ), patch("time.sleep"): + result = session._terminate() - session.cleanup() + assert result is True + assert len(kill_calls) >= 1 + assert kill_calls[0][1] == signal.SIGHUP - await asyncio.sleep(0.25) + # Scenario 3: ProcessLookupError - child already gone + session = PTYSession(reader, writer, "/nonexistent.program", []) + session.child_pid = 12345 + isalive_returns = [True] - assert len(signal_calls) == 0 + def mock_isalive_2(): + return isalive_returns.pop(0) if isalive_returns else False + + with patch.object(os, "kill", side_effect=ProcessLookupError), patch.object( + session, "_isalive", side_effect=mock_isalive_2 + ): + result = session._terminate() + + assert result is True diff --git a/telnetlib3/tests/test_server_shell_unit.py b/telnetlib3/tests/test_server_shell_unit.py index 1b5c5b10..b9d631ae 100644 --- a/telnetlib3/tests/test_server_shell_unit.py +++ b/telnetlib3/tests/test_server_shell_unit.py @@ -1,6 +1,7 @@ # std imports import sys import types +import asyncio # 3rd party import pytest @@ -8,6 +9,7 @@ # local from telnetlib3 import slc as slc_mod from telnetlib3 import client_shell as cs +from telnetlib3 import guard_shells as gs from telnetlib3 import server_shell as ss @@ -190,3 +192,415 @@ class TW: # cc changes assert new_mode.cc[t.VMIN] == 1 assert new_mode.cc[t.VTIME] == 0 + + +class MockReader: + def __init__(self, data): + self._data = list(data) + self._idx = 0 + + async def read(self, n): + if self._idx >= len(self._data): + return "" + result = self._data[self._idx] + self._idx += 1 + return result + + +class SlowReader: + async def read(self, n): + await asyncio.sleep(1.0) + return "" + + +class MockWriter: + def __init__(self): + self.written = [] + self._closing = False + self._extra = {"peername": ("127.0.0.1", 12345)} + + def write(self, data): + self.written.append(data) + + async def drain(self): + pass + + def get_extra_info(self, key, default=None): + return self._extra.get(key, default) + + def is_closing(self): + return self._closing + + def echo(self, data): + self.written.append(data) + + +@pytest.mark.parametrize( + "limit,acquires,expected_count,expected_results", + [ + pytest.param(1, 1, 1, [True], id="single_acquire"), + pytest.param(1, 2, 1, [True, False], id="over_limit"), + pytest.param(3, 3, 3, [True, True, True], id="at_limit"), + pytest.param(2, 3, 2, [True, True, False], id="over_limit_by_one"), + ], +) +def test_connection_counter_acquire(limit, acquires, expected_count, expected_results): + counter = gs.ConnectionCounter(limit) + results = [counter.try_acquire() for _ in range(acquires)] + assert results == expected_results + assert counter.count == expected_count + + +def test_connection_counter_release(): + counter = gs.ConnectionCounter(2) + assert counter.try_acquire() + assert counter.try_acquire() + assert counter.count == 2 + counter.release() + assert counter.count == 1 + counter.release() + counter.release() + assert counter.count == 0 + + +@pytest.mark.parametrize( + "input_data,max_len,expected", + [ + pytest.param("hi\r", 100, "hi", id="cr_terminator"), + pytest.param("hi\n", 100, "hi", id="lf_terminator"), + pytest.param("ab", 100, "ab", id="eof_no_terminator"), + pytest.param("", 100, "", id="empty_input"), + pytest.param("abcdefgh", 5, "abcde", id="truncated_at_max_len"), + ], +) +@pytest.mark.asyncio +async def test_read_line_inner(input_data, max_len, expected): + reader = MockReader(list(input_data)) + result = await gs._read_line_inner(reader, max_len) + assert result == expected + + +@pytest.mark.asyncio +async def test_read_line_with_timeout_success(): + reader = MockReader(list("hello\r")) + result = await gs._read_line(reader, timeout=5.0) + assert result == "hello" + + +@pytest.mark.asyncio +async def test_read_line_with_timeout_expires(): + result = await gs._read_line(SlowReader(), timeout=0.01) + assert result is None + + +@pytest.mark.asyncio +async def test_robot_shell_full_conversation(): + reader = MockReader(["y", "\r", "n", "o", "\r"]) + writer = MockWriter() + await gs.robot_shell(reader, writer) + written = "".join(writer.written) + assert "Do robots dream of electric sheep?" in written + assert "windowmakers" in written + + +@pytest.mark.asyncio +async def test_busy_shell_full_conversation(): + reader = MockReader(["h", "i", "\r", "x", "\r"]) + writer = MockWriter() + await gs.busy_shell(reader, writer) + written = "".join(writer.written) + assert "Machine is busy" in written + assert "distant explosion" in written + + +@pytest.mark.parametrize( + "input_chars,expected", + [ + pytest.param(["\x1b", "[", "A", "x"], "x", id="csi_sequence"), + pytest.param(["\x1b", "X"], "X", id="esc_non_bracket"), + pytest.param(["a"], "a", id="normal_char"), + pytest.param([""], "", id="eof"), + pytest.param(["\x1b", "[", "1", ";", "2", "H", "z"], "z", id="csi_with_params"), + pytest.param(["\x1b", "[", ""], "", id="csi_no_final_byte"), + ], +) +@pytest.mark.asyncio +async def test_filter_ansi(input_chars, expected): + reader = MockReader(input_chars) + writer = MockWriter() + result = await ss.filter_ansi(reader, writer) + assert result == expected + + +@pytest.mark.parametrize( + "input_chars,expected", + [ + pytest.param(["h", "e", "l", "l", "o", "\r"], "hello", id="basic"), + pytest.param(["h", "x", "\x7f", "i", "\r"], "hi", id="with_backspace"), + pytest.param(["\n", "\x00", "a", "\r"], "a", id="ignores_initial_lf_nul"), + pytest.param(["a", ""], None, id="returns_none_on_eof"), + ], +) +@pytest.mark.asyncio +async def test_readline2(input_chars, expected): + reader = MockReader(input_chars) + writer = MockWriter() + result = await ss.readline2(reader, writer) + assert result == expected + + +@pytest.mark.parametrize( + "input_chars,closing,expected", + [ + pytest.param(["a"], False, "a", id="normal"), + pytest.param(["\x1b", "A", "x"], False, "x", id="skips_escape"), + pytest.param([], True, None, id="returns_none_when_closing"), + ], +) +@pytest.mark.asyncio +async def test_get_next_ascii(input_chars, closing, expected): + reader = MockReader(input_chars) + writer = MockWriter() + writer._closing = closing + result = await ss.get_next_ascii(reader, writer) + assert result == expected + + +class CPRReader: + def __init__(self, data): + self._data = list(data) + self._idx = 0 + + async def read(self, n): + if self._idx >= len(self._data): + return b"" + result = self._data[self._idx] + self._idx += 1 + return result + + +@pytest.mark.parametrize( + "input_data,expected", + [ + pytest.param([b"\x1b", b"[", b"1", b"0", b";", b"2", b"0", b"R"], (10, 20), id="valid_cpr"), + pytest.param([b"\x1b", b"[", b"1", b";", b"1", b"R"], (1, 1), id="single_digit"), + pytest.param( + [b"\x1b", b"[", b"2", b"5", b";", b"8", b"0", b"R"], (25, 80), id="typical_size" + ), + pytest.param([b""], None, id="eof"), + pytest.param( + [b"g", b"a", b"r", b"\x1b", b"[", b"5", b";", b"3", b"R"], (5, 3), id="garbage_prefix" + ), + ], +) +@pytest.mark.asyncio +async def test_read_cpr_response(input_data, expected): + reader = CPRReader(input_data) + result = await gs._read_cpr_response(reader) + assert result == expected + + +@pytest.mark.asyncio +async def test_read_cpr_response_string_input(): + class StringReader: + def __init__(self): + self._data = list("\x1b[5;10R") + self._idx = 0 + + async def read(self, n): + if self._idx >= len(self._data): + return "" + result = self._data[self._idx] + self._idx += 1 + return result + + reader = StringReader() + result = await gs._read_cpr_response(reader) + assert result == (5, 10) + + +class CPRMockWriter: + def __init__(self): + self.written = [] + + def write(self, data): + self.written.append(data) + + async def drain(self): + pass + + +@pytest.mark.asyncio +async def test_get_cursor_position_success(): + reader = CPRReader([b"\x1b", b"[", b"1", b"0", b";", b"2", b"0", b"R"]) + writer = CPRMockWriter() + result = await gs._get_cursor_position(reader, writer, timeout=1.0) + assert result == (10, 20) + assert "\x1b[6n" in writer.written + + +@pytest.mark.asyncio +async def test_get_cursor_position_timeout(): + result = await gs._get_cursor_position(SlowReader(), CPRMockWriter(), timeout=0.01) + assert result == (None, None) + + +@pytest.mark.asyncio +async def test_get_cursor_position_eof(): + reader = CPRReader([b""]) + writer = CPRMockWriter() + result = await gs._get_cursor_position(reader, writer, timeout=1.0) + assert result == (None, None) + + +@pytest.mark.asyncio +async def test_measure_width_success(monkeypatch): + positions = iter([(1, 5), (1, 7)]) + + async def mock_get_cursor_position(reader, writer, timeout): + return next(positions) + + monkeypatch.setattr(gs, "_get_cursor_position", mock_get_cursor_position) + writer = CPRMockWriter() + result = await gs._measure_width(None, writer, "ab", timeout=1.0) + assert result == 2 + assert any("\x1b[5G" in w for w in writer.written) + + +@pytest.mark.asyncio +async def test_measure_width_first_cpr_fails(monkeypatch): + async def mock_get_cursor_position(reader, writer, timeout): + return (None, None) + + monkeypatch.setattr(gs, "_get_cursor_position", mock_get_cursor_position) + result = await gs._measure_width(None, CPRMockWriter(), "x", timeout=1.0) + assert result is None + + +@pytest.mark.asyncio +async def test_measure_width_second_cpr_fails(monkeypatch): + call_count = [0] + + async def mock_get_cursor_position(reader, writer, timeout): + call_count[0] += 1 + if call_count[0] == 1: + return (1, 5) + return (None, None) + + monkeypatch.setattr(gs, "_get_cursor_position", mock_get_cursor_position) + result = await gs._measure_width(None, CPRMockWriter(), "x", timeout=1.0) + assert result is None + + +@pytest.mark.asyncio +async def test_robot_check_returns_true_when_width_is_2(monkeypatch): + async def mock_measure_width(reader, writer, text, timeout): + return 2 + + monkeypatch.setattr(gs, "_measure_width", mock_measure_width) + result = await gs.robot_check(None, None, timeout=1.0) + assert result is True + + +@pytest.mark.asyncio +async def test_robot_check_returns_false_when_width_is_not_2(monkeypatch): + async def mock_measure_width(reader, writer, text, timeout): + return 1 + + monkeypatch.setattr(gs, "_measure_width", mock_measure_width) + result = await gs.robot_check(None, None, timeout=1.0) + assert result is False + + +@pytest.mark.asyncio +async def test_robot_check_returns_false_when_width_is_none(monkeypatch): + async def mock_measure_width(reader, writer, text, timeout): + return None + + monkeypatch.setattr(gs, "_measure_width", mock_measure_width) + result = await gs.robot_check(None, None, timeout=1.0) + assert result is False + + +@pytest.mark.asyncio +async def test_robot_shell_timeout_on_first_question(monkeypatch): + call_count = [0] + original_read_line = gs._read_line + + async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): + call_count[0] += 1 + if call_count[0] == 1: + return None + return await original_read_line(reader, timeout, max_len) + + monkeypatch.setattr(gs, "_read_line", mock_read_line) + + writer = MockWriter() + await gs.robot_shell(MockReader([]), writer) + written = "".join(writer.written) + assert "Do robots dream of electric sheep?" in written + assert call_count[0] == 1 + + +@pytest.mark.asyncio +async def test_robot_shell_timeout_on_second_question(monkeypatch): + call_count = [0] + original_read_line = gs._read_line + + async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): + call_count[0] += 1 + if call_count[0] == 1: + return "y" + if call_count[0] == 2: + return None + return await original_read_line(reader, timeout, max_len) + + monkeypatch.setattr(gs, "_read_line", mock_read_line) + + writer = MockWriter() + await gs.robot_shell(MockReader([]), writer) + written = "".join(writer.written) + assert "Do robots dream of electric sheep?" in written + assert "windowmakers" in written + assert call_count[0] == 2 + + +@pytest.mark.asyncio +async def test_busy_shell_timeout_on_first_input(monkeypatch): + call_count = [0] + original_read_line = gs._read_line + + async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): + call_count[0] += 1 + if call_count[0] == 1: + return None + return await original_read_line(reader, timeout, max_len) + + monkeypatch.setattr(gs, "_read_line", mock_read_line) + + writer = MockWriter() + await gs.busy_shell(MockReader([]), writer) + written = "".join(writer.written) + assert "Machine is busy" in written + + +@pytest.mark.asyncio +async def test_busy_shell_timeout_on_second_input(monkeypatch): + call_count = [0] + original_read_line = gs._read_line + + async def mock_read_line(reader, timeout, max_len=gs._MAX_INPUT): + call_count[0] += 1 + if call_count[0] == 1: + return "hi" + if call_count[0] == 2: + return None + return await original_read_line(reader, timeout, max_len) + + monkeypatch.setattr(gs, "_read_line", mock_read_line) + + writer = MockWriter() + await gs.busy_shell(MockReader([]), writer) + written = "".join(writer.written) + assert "Machine is busy" in written + assert "distant explosion" in written diff --git a/telnetlib3/tests/test_shell.py b/telnetlib3/tests/test_shell.py index 266c2ebf..a248727e 100644 --- a/telnetlib3/tests/test_shell.py +++ b/telnetlib3/tests/test_shell.py @@ -314,3 +314,205 @@ async def test_telnet_server_shell_eof(bind_host, unused_tcp_port): # Wait for server to process client disconnect await asyncio.sleep(0.05) assert srv_instance._closing + + +async def test_telnet_server_shell_version_command(bind_host, unused_tcp_port): + """Test version command in telnet_server_shell.""" + # local + from telnetlib3 import accessories, telnet_server_shell + from telnetlib3.telopt import DO, IAC, WONT, TTYPE + from telnetlib3.tests.accessories import create_server, asyncio_connection + + async with create_server( + host=bind_host, + port=unused_tcp_port, + shell=telnet_server_shell, + connect_maxwait=0.05, + ): + async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): + expected = IAC + DO + TTYPE + result = await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + assert result == expected + + writer.write(IAC + WONT + TTYPE) + + expected = b"Ready.\r\ntel:sh> " + result = await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + assert result == expected + + writer.write(b"version\r") + await asyncio.sleep(0.05) + + result = b"" + while True: + try: + chunk = await asyncio.wait_for(reader.read(100), 0.2) + if not chunk: + break + result += chunk + if b"tel:sh>" in result: + break + except asyncio.TimeoutError: + break + + expected_version = accessories.get_version() + assert expected_version.encode("ascii") in result + + +async def test_telnet_server_shell_dump_with_kb_limit(bind_host, unused_tcp_port): + """Test dump command with explicit kb_limit.""" + # local + from telnetlib3 import telnet_server_shell + from telnetlib3.telopt import DO, IAC, WONT, TTYPE + from telnetlib3.tests.accessories import create_server, asyncio_connection + + async with create_server( + host=bind_host, + port=unused_tcp_port, + shell=telnet_server_shell, + connect_maxwait=0.05, + ): + async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): + expected = IAC + DO + TTYPE + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + writer.write(IAC + WONT + TTYPE) + + expected = b"Ready.\r\ntel:sh> " + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + + writer.write(b"dump 0\r") + await asyncio.sleep(0.05) + + result = b"" + while True: + try: + chunk = await asyncio.wait_for(reader.read(200), 0.2) + if not chunk: + break + result += chunk + if b"wrote 0 bytes" in result: + break + except asyncio.TimeoutError: + break + + assert b"kb_limit=0" in result + assert b"wrote 0 bytes" in result + + +async def test_telnet_server_shell_dump_with_all_options(bind_host, unused_tcp_port): + """Test dump command with all options including close.""" + # local + from telnetlib3 import telnet_server_shell + from telnetlib3.telopt import DO, IAC, WONT, TTYPE + from telnetlib3.tests.accessories import create_server, asyncio_connection + + async with create_server( + host=bind_host, + port=unused_tcp_port, + shell=telnet_server_shell, + connect_maxwait=0.05, + ): + async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): + expected = IAC + DO + TTYPE + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + writer.write(IAC + WONT + TTYPE) + + expected = b"Ready.\r\ntel:sh> " + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + + writer.write(b"dump 0 0 nodrain close\r") + await asyncio.sleep(0.05) + + result = b"" + while True: + try: + chunk = await asyncio.wait_for(reader.read(300), 0.2) + if not chunk: + break + result += chunk + except asyncio.TimeoutError: + break + + assert b"kb_limit=0" in result + assert b"do_close=True" in result + assert b"drain=True" in result + + +async def test_telnet_server_shell_dump_nodrain(bind_host, unused_tcp_port): + """Test dump command with nodrain option.""" + # local + from telnetlib3 import telnet_server_shell + from telnetlib3.telopt import DO, IAC, WONT, TTYPE + from telnetlib3.tests.accessories import create_server, asyncio_connection + + async with create_server( + host=bind_host, + port=unused_tcp_port, + shell=telnet_server_shell, + connect_maxwait=0.05, + ): + async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): + expected = IAC + DO + TTYPE + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + writer.write(IAC + WONT + TTYPE) + + expected = b"Ready.\r\ntel:sh> " + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + + writer.write(b"dump 0 0 drain\r") + await asyncio.sleep(0.05) + + result = b"" + while True: + try: + chunk = await asyncio.wait_for(reader.read(200), 0.2) + if not chunk: + break + result += chunk + if b"drain=False" in result: + break + except asyncio.TimeoutError: + break + + assert b"kb_limit=0" in result + assert b"drain=False" in result + + +async def test_telnet_server_shell_dump_large_output(bind_host, unused_tcp_port): + """Test dump command with larger output.""" + # local + from telnetlib3 import telnet_server_shell + from telnetlib3.telopt import DO, IAC, WONT, TTYPE + from telnetlib3.tests.accessories import create_server, asyncio_connection + + async with create_server( + host=bind_host, + port=unused_tcp_port, + shell=telnet_server_shell, + connect_maxwait=0.05, + ): + async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): + expected = IAC + DO + TTYPE + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + writer.write(IAC + WONT + TTYPE) + + expected = b"Ready.\r\ntel:sh> " + await asyncio.wait_for(reader.readexactly(len(expected)), 0.5) + + writer.write(b"dump 1\r") + await asyncio.sleep(0.05) + + result = b"" + while True: + try: + chunk = await asyncio.wait_for(reader.read(4096), 0.5) + if not chunk: + break + result += chunk + if b"wrote" in result and b"bytes" in result: + break + except asyncio.TimeoutError: + break + + assert b"kb_limit=1" in result + assert b"/" in result or b"\\" in result diff --git a/telnetlib3/tests/test_status_logger.py b/telnetlib3/tests/test_status_logger.py index f2f50eb2..ab3eeb7f 100644 --- a/telnetlib3/tests/test_status_logger.py +++ b/telnetlib3/tests/test_status_logger.py @@ -62,7 +62,7 @@ async def test_status_logger_get_status(bind_host, unused_tcp_port): status_logger = StatusLogger(server, 60) status = status_logger._get_status() assert status["count"] == 0 - assert status["clients"] == [] + assert not status["clients"] async with asyncio_connection(bind_host, unused_tcp_port) as (reader, writer): writer.write(IAC + WONT + TTYPE) diff --git a/telnetlib3/tests/test_sync.py b/telnetlib3/tests/test_sync.py index dbc691ff..5c90e704 100644 --- a/telnetlib3/tests/test_sync.py +++ b/telnetlib3/tests/test_sync.py @@ -401,3 +401,193 @@ def test_client_get_extra_info(bind_host, unused_tcp_port): assert conn.get_extra_info("nonexistent", "default") == "default" server.shutdown() + + +def test_client_operations_after_close_raise(bind_host, unused_tcp_port): + """Operations fail after connection is closed.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + + conn = TelnetConnection(bind_host, unused_tcp_port, timeout=5) + conn.connect() + conn.close() + + with pytest.raises(RuntimeError, match="Connection closed"): + conn.read() + with pytest.raises(RuntimeError, match="Connection closed"): + conn.readline() + with pytest.raises(RuntimeError, match="Connection closed"): + conn.write("test") + with pytest.raises(RuntimeError, match="Connection closed"): + conn.flush() + with pytest.raises(RuntimeError, match="Connection closed"): + conn.get_extra_info("peername") + with pytest.raises(RuntimeError, match="Connection closed"): + conn.wait_for(remote={"NAWS": True}) + + server.shutdown() + + +def test_client_read_timeout(bind_host, unused_tcp_port): + """TelnetConnection.read times out when no data available.""" + + def handler(server_conn): + time.sleep(5) + + server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + time.sleep(0.1) + + with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: + with pytest.raises(TimeoutError, match="Read timed out"): + conn.read(1, timeout=0.1) + + server.shutdown() + + +def test_client_readline_timeout(bind_host, unused_tcp_port): + """TelnetConnection.readline times out when no line available.""" + + def handler(server_conn): + server_conn.write("no newline here") + server_conn.flush(timeout=5) + time.sleep(5) + + server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + time.sleep(0.1) + + with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: + with pytest.raises(TimeoutError, match="Readline timed out"): + conn.readline(timeout=0.2) + + server.shutdown() + + +@pytest.mark.parametrize( + "method,args,error_match", + [ + pytest.param("read", (1,), "Read timed out", id="read"), + pytest.param("readline", (), "Readline timed out", id="readline"), + ], +) +def test_server_connection_timeout(bind_host, unused_tcp_port, method, args, error_match): + """ServerConnection methods time out when no data available.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + + def client_thread(): + time.sleep(0.1) + with TelnetConnection(bind_host, unused_tcp_port, timeout=5): + time.sleep(2) + + thread = threading.Thread(target=client_thread, daemon=True) + thread.start() + + conn = server.accept(timeout=5) + with pytest.raises(TimeoutError, match=error_match): + getattr(conn, method)(*args, timeout=0.1) + conn.close() + server.shutdown() + + +def test_server_connection_read_until_timeout(bind_host, unused_tcp_port): + """ServerConnection.read_until times out when match not found.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + + def client_thread(): + time.sleep(0.1) + with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: + conn.write("no match here") + conn.flush() + time.sleep(2) + + thread = threading.Thread(target=client_thread, daemon=True) + thread.start() + + conn = server.accept(timeout=5) + with pytest.raises(TimeoutError, match="Read until timed out"): + conn.read_until(">>> ", timeout=0.2) + conn.close() + server.shutdown() + + +def test_server_connection_wait_for_timeout(bind_host, unused_tcp_port): + """ServerConnection.wait_for times out when conditions not met.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + + def client_thread(): + time.sleep(0.1) + with TelnetConnection(bind_host, unused_tcp_port, timeout=5): + time.sleep(1.0) + + thread = threading.Thread(target=client_thread, daemon=True) + thread.start() + + conn = server.accept(timeout=5) + with pytest.raises(TimeoutError, match="Wait for negotiation timed out"): + conn.wait_for(remote={"LINEMODE": True}, timeout=0.1) + conn.close() + server.shutdown() + + +@pytest.mark.parametrize( + "method,args,kwargs", + [ + pytest.param("wait_for", (), {"remote": {"NAWS": True}}, id="wait_for"), + pytest.param("read_until", (">>> ",), {}, id="read_until"), + pytest.param("flush", (), {}, id="flush"), + pytest.param("readline", (), {}, id="readline"), + ], +) +def test_server_connection_methods_closed_error(bind_host, unused_tcp_port, method, args, kwargs): + """ServerConnection methods raise RuntimeError when called after close.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + + def client_thread(): + time.sleep(0.1) + with TelnetConnection(bind_host, unused_tcp_port, timeout=5): + time.sleep(0.2) + + thread = threading.Thread(target=client_thread, daemon=True) + thread.start() + + conn = server.accept(timeout=5) + conn.close() + with pytest.raises(RuntimeError, match="Connection closed"): + getattr(conn, method)(*args, **kwargs) + server.shutdown() + + +def test_server_already_started_error(bind_host, unused_tcp_port): + """Server start raises if already started.""" + server = BlockingTelnetServer(bind_host, unused_tcp_port) + server.start() + with pytest.raises(RuntimeError, match="Server already started"): + server.start() + server.shutdown() + + +def test_client_read_until_eof(bind_host, unused_tcp_port): + """TelnetConnection.read_until raises EOFError when connection closes before match.""" + + def handler(server_conn): + server_conn.write("partial data") + server_conn.flush(timeout=5) + server_conn.close() + + server = BlockingTelnetServer(bind_host, unused_tcp_port, handler=handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + time.sleep(0.1) + + with TelnetConnection(bind_host, unused_tcp_port, timeout=5) as conn: + with pytest.raises(EOFError, match="Connection closed before match found"): + conn.read_until(">>> ", timeout=2) + + server.shutdown() diff --git a/telnetlib3/tests/test_telnetlib.py b/telnetlib3/tests/test_telnetlib.py index 12903452..ba439d63 100644 --- a/telnetlib3/tests/test_telnetlib.py +++ b/telnetlib3/tests/test_telnetlib.py @@ -1,6 +1,6 @@ # jdq(2025): This file was modified from cpython 3.12 test_telnetlib.py, to make it compatible # with more versions of python, and, to use pytest instead of unittest. -# pylint: disable-all +# pylint: skip-file # std imports import io import re diff --git a/telnetlib3/tests/test_tspeed.py b/telnetlib3/tests/test_tspeed.py index 721a5664..85d992d9 100644 --- a/telnetlib3/tests/test_tspeed.py +++ b/telnetlib3/tests/test_tspeed.py @@ -33,7 +33,7 @@ def on_tspeed(self, rx, tx): writer.write(IAC + WILL + TSPEED) writer.write(IAC + SB + TSPEED + IS + b"123,456" + IAC + SE) - srv_instance = await asyncio.wait_for(_waiter, 0.5) + srv_instance = await asyncio.wait_for(_waiter, 3.0) assert srv_instance.get_extra_info("tspeed") == "123,456" @@ -66,6 +66,6 @@ def begin_advanced_negotiation(self): tspeed=(given_rx, given_tx), connect_minwait=0.05, ) as (reader, writer): - recv_rx, recv_tx = await asyncio.wait_for(_waiter, 0.5) + recv_rx, recv_tx = await asyncio.wait_for(_waiter, 3.0) assert recv_rx == given_rx assert recv_tx == given_tx diff --git a/telnetlib3/tests/test_writer.py b/telnetlib3/tests/test_writer.py index 5fe1d045..4d96064d 100644 --- a/telnetlib3/tests/test_writer.py +++ b/telnetlib3/tests/test_writer.py @@ -230,11 +230,11 @@ async def test_send_iac_dont_dont(bind_host, unused_tcp_port): result = client_writer.iac(DONT, ECHO) assert result is False - srv_instance = await asyncio.wait_for(server.wait_for_client(), 0.5) + srv_instance = await asyncio.wait_for(server.wait_for_client(), 3.0) server_writer = srv_instance.writer # Wait for server to process client disconnect - await asyncio.sleep(0.05) + await asyncio.sleep(0.1) assert client_writer.remote_option[ECHO] is False, client_writer.remote_option assert server_writer.local_option[ECHO] is False, server_writer.local_option diff --git a/tox.ini b/tox.ini index 275f2f19..acd9c978 100644 --- a/tox.ini +++ b/tox.ini @@ -200,6 +200,7 @@ branch = True parallel = True source = telnetlib3 omit = telnetlib3/tests/* + telnetlib3/telnetlib.py relative_files = True [coverage:report] @@ -207,6 +208,7 @@ precision = 1 exclude_lines = pragma: no cover omit = telnetlib3/tests/* + telnetlib3/telnetlib.py [coverage:paths] source = telnetlib3/ @@ -247,9 +249,11 @@ norecursedirs = .git .tox build asyncio_mode = auto log_level = debug log_format = %(levelname)8s %(filename)s:%(lineno)s %(message)s +# PTY tests require --capture=no to function (fork/pty breaks with capture) addopts = --strict-markers --verbose + --capture=no --color=yes --cov --cov-append