From 81d617a4f8e2cc24d99d02671c8edf1c03608096 Mon Sep 17 00:00:00 2001 From: Marco Melloni <98281551+marcomelloni@users.noreply.github.com> Date: Wed, 23 Jul 2025 10:23:02 +0200 Subject: [PATCH] refactor interactive module --- .../matlab/matlab_agent/docs/INTERACTIVE.md | 12 ++++ .../matlab_agent/src/core/interactive.py | 66 +++++++++++++++---- .../matlab_agent/src/utils/constants.py | 15 +++++ 3 files changed, 79 insertions(+), 14 deletions(-) create mode 100644 agents/matlab/matlab_agent/docs/INTERACTIVE.md create mode 100644 agents/matlab/matlab_agent/src/utils/constants.py diff --git a/agents/matlab/matlab_agent/docs/INTERACTIVE.md b/agents/matlab/matlab_agent/docs/INTERACTIVE.md new file mode 100644 index 00000000..278a0a4f --- /dev/null +++ b/agents/matlab/matlab_agent/docs/INTERACTIVE.md @@ -0,0 +1,12 @@ +# Interactive Frame Format + +Frames exchanged during an interactive simulation are YAML documents with the following structure: + +```yaml +simulation: + inputs: + : +``` + +The `inputs` section contains key/value pairs that are provided to the MATLAB simulation. Each frame sent from the message broker to the agent should conform to this format. + diff --git a/agents/matlab/matlab_agent/src/core/interactive.py b/agents/matlab/matlab_agent/src/core/interactive.py index 2613b8d7..7b199738 100644 --- a/agents/matlab/matlab_agent/src/core/interactive.py +++ b/agents/matlab/matlab_agent/src/core/interactive.py @@ -12,20 +12,33 @@ import psutil import yaml +import shlex from ..comm.interfaces import IMessageBroker from ..utils.create_response import create_response from ..utils.logger import get_logger from ..utils.performance_monitor import PerformanceMonitor +from ..utils.constants import ( + ACCEPT_TIMEOUT, + BUFFER_SIZE, + DEFAULT_INPUT_PORT, + DEFAULT_OUTPUT_PORT, + BYTES_IN_MB, +) logger = get_logger() def _parse_frame(body: bytes) -> Dict[str, Any]: - """Decode a YAML frame received from RabbitMQ.""" + """Decode a YAML frame received from RabbitMQ. + + The frame is expected to be a YAML encoded dictionary describing the + simulation inputs. Invalid YAML results in an empty dictionary being + returned and a log entry emitted. + """ try: return yaml.safe_load(body) - except Exception as exc: # pragma: no cover - logging only + except yaml.YAMLError as exc: # pragma: no cover - logging only logger.error("[INTERACTIVE] Bad frame: %s", exc) return {} @@ -48,7 +61,7 @@ def start(self) -> None: def accept(self) -> None: if not self._srv: raise RuntimeError("Server not started") - ready = select([self._srv], [], [], 60) + ready = select([self._srv], [], [], ACCEPT_TIMEOUT) if ready[0]: self._conn, _ = self._srv.accept() self._conn.setblocking(False) @@ -66,7 +79,7 @@ def recv_all(self) -> list[Dict[str, Any]]: if not self._conn or not select([self._conn], [], [], 0)[0]: return [] - self._buffer += self._conn.recv(4096) + self._buffer += self._conn.recv(BUFFER_SIZE) lines = self._buffer.split(b"\n") self._buffer = lines[-1] messages: list[Dict[str, Any]] = [] @@ -76,8 +89,9 @@ def recv_all(self) -> list[Dict[str, Any]]: continue try: messages.append(json.loads(line.decode())) - except json.JSONDecodeError as exc: # pragma: no cover - logging only + except json.JSONDecodeError as exc: # pragma: no cover - logs error and skips invalid message logger.error("[INTERACTIVE] Invalid JSON: %s", exc) + messages.append({"error": f"Invalid JSON: {str(exc)}"}) return messages def close(self) -> None: @@ -107,6 +121,8 @@ def __init__( agent_id: str = "agent", ) -> None: self.sim_path = Path(path).resolve() + if len(file) > 100: + raise ValueError("Simulation file name too long") self.sim_file = file if not (self.sim_path / self.sim_file).exists(): raise FileNotFoundError(self.sim_file) @@ -120,11 +136,11 @@ def __init__( self.out_srv = _TcpServer( tcp_cfg.get("host", "localhost"), - tcp_cfg.get("output_port", 5678), + tcp_cfg.get("output_port", DEFAULT_OUTPUT_PORT), ) self.in_srv = _TcpServer( tcp_cfg.get("host", "localhost"), - tcp_cfg.get("input_port", 5679), + tcp_cfg.get("input_port", DEFAULT_INPUT_PORT), ) self.start_time: Optional[float] = None @@ -132,13 +148,12 @@ def __init__( # ------------------------------------------------------------------ def _start_matlab(self) -> None: + safe_path = shlex.quote(str(self.sim_path)) + safe_file = shlex.quote(self.sim_file) cmd = [ "matlab", "-batch", - f"addpath('{ - self.sim_path}');cd('{ - self.sim_path}');run('{ - self.sim_file}');", + f"addpath('{safe_path}');cd('{safe_path}');run('{safe_file}');", ] self.out_srv.matlab_proc = subprocess.Popen( cmd, @@ -187,7 +202,7 @@ def _relay(self, payload: Dict[str, Any]) -> None: @staticmethod def _only_inputs(frame: Dict[str, Any]) -> Dict[str, Any]: - """Extract only the inputs from the frame.""" + """Extract only the ``inputs`` section from a simulation frame.""" if isinstance(frame, dict): sim = frame.get("simulation") if isinstance(sim, dict) and "inputs" in sim: @@ -216,6 +231,9 @@ def run(self, pm: PerformanceMonitor, msg_dict: Dict[str, Any]) -> None: try: while True: + if self.out_srv.matlab_proc and self.out_srv.matlab_proc.poll() is not None: + logger.debug("[INTERACTIVE] MATLAB process ended, stopping loop") + break method, properties, body = ch.basic_get( queue=qname, auto_ack=True) while method: @@ -228,8 +246,14 @@ def run(self, pm: PerformanceMonitor, msg_dict: Dict[str, Any]) -> None: # Receive Responses from MATLAB for resp in self.out_srv.recv_all(): + if resp.get("status") == "completed": + self._relay(resp) + logger.debug("[INTERACTIVE] Received completion signal") + return # Send the response to the broker self._relay(resp) + except KeyboardInterrupt: # pragma: no cover - manual interruption + logger.info("[INTERACTIVE] Interrupted by user") finally: pm.record_simulation_complete() @@ -243,7 +267,7 @@ def metadata(self) -> Dict[str, Any]: if self.start_time: meta["execution_time"] = time.time() - self.start_time meta["memory_usage"] = psutil.Process( - ).memory_info().rss // (1024 * 1024) + ).memory_info().rss // BYTES_IN_MB return meta @@ -275,7 +299,7 @@ def handle_interactive_simulation( try: controller.start(pm) controller.run(pm, msg_dict) - except Exception as exc: # pragma: no cover - runtime errors + except (KeyError, ValueError, RuntimeError) as exc: # pragma: no cover - handled errors logger.error("[INTERACTIVE] Fatal: %s", exc) rabbitmq_manager.send_result( source, @@ -289,6 +313,20 @@ def handle_interactive_simulation( error={"message": str(exc), "type": "execution_error"}, ), ) + except Exception as exc: # pragma: no cover - unexpected errors + logger.exception("[INTERACTIVE] Unexpected error: %s", exc) + rabbitmq_manager.send_result( + source, + create_response( + "error", + sim.get("file", ""), + "interactive", + response_templates, + bridge_meta=sim.get("bridge_meta", "unknown"), + request_id=sim.get("request_id", "unknown"), + error={"message": str(exc), "type": "execution_error"}, + ), + ) finally: pm.complete_operation() controller.close() diff --git a/agents/matlab/matlab_agent/src/utils/constants.py b/agents/matlab/matlab_agent/src/utils/constants.py new file mode 100644 index 00000000..49b974d2 --- /dev/null +++ b/agents/matlab/matlab_agent/src/utils/constants.py @@ -0,0 +1,15 @@ +"""Common constants used across the MATLAB agent.""" + +# Timeouts +ACCEPT_TIMEOUT = 60 # seconds to wait for TCP connection + +# Network buffer size +BUFFER_SIZE = 4096 + +# Default TCP ports +DEFAULT_OUTPUT_PORT = 5678 +DEFAULT_INPUT_PORT = 5679 + +# Memory usage divisor for converting bytes to MB +BYTES_IN_MB = 1024 * 1024 +