Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions agents/matlab/matlab_agent/docs/INTERACTIVE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Interactive Frame Format

Frames exchanged during an interactive simulation are YAML documents with the following structure:

```yaml
simulation:
inputs:
<key>: <value>
```

The `inputs` section contains key/value pairs that are provided to the MATLAB simulation. Each frame sent from the message broker to the agent should conform to this format.

66 changes: 52 additions & 14 deletions agents/matlab/matlab_agent/src/core/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,33 @@

import psutil
import yaml
import shlex

from ..comm.interfaces import IMessageBroker
from ..utils.create_response import create_response
from ..utils.logger import get_logger
from ..utils.performance_monitor import PerformanceMonitor
from ..utils.constants import (
ACCEPT_TIMEOUT,
BUFFER_SIZE,
DEFAULT_INPUT_PORT,
DEFAULT_OUTPUT_PORT,
BYTES_IN_MB,
)

logger = get_logger()


def _parse_frame(body: bytes) -> Dict[str, Any]:
"""Decode a YAML frame received from RabbitMQ."""
"""Decode a YAML frame received from RabbitMQ.

The frame is expected to be a YAML encoded dictionary describing the
simulation inputs. Invalid YAML results in an empty dictionary being
returned and a log entry emitted.
"""
try:
return yaml.safe_load(body)
except Exception as exc: # pragma: no cover - logging only
except yaml.YAMLError as exc: # pragma: no cover - logging only
logger.error("[INTERACTIVE] Bad frame: %s", exc)
return {}

Expand All @@ -48,7 +61,7 @@ def start(self) -> None:
def accept(self) -> None:
if not self._srv:
raise RuntimeError("Server not started")
ready = select([self._srv], [], [], 60)
ready = select([self._srv], [], [], ACCEPT_TIMEOUT)
if ready[0]:
self._conn, _ = self._srv.accept()
self._conn.setblocking(False)
Expand All @@ -66,7 +79,7 @@ def recv_all(self) -> list[Dict[str, Any]]:
if not self._conn or not select([self._conn], [], [], 0)[0]:
return []

self._buffer += self._conn.recv(4096)
self._buffer += self._conn.recv(BUFFER_SIZE)
lines = self._buffer.split(b"\n")
self._buffer = lines[-1]
messages: list[Dict[str, Any]] = []
Expand All @@ -76,8 +89,9 @@ def recv_all(self) -> list[Dict[str, Any]]:
continue
try:
messages.append(json.loads(line.decode()))
except json.JSONDecodeError as exc: # pragma: no cover - logging only
except json.JSONDecodeError as exc: # pragma: no cover - logs error and skips invalid message
logger.error("[INTERACTIVE] Invalid JSON: %s", exc)
messages.append({"error": f"Invalid JSON: {str(exc)}"})
return messages

def close(self) -> None:
Expand Down Expand Up @@ -107,6 +121,8 @@ def __init__(
agent_id: str = "agent",
) -> None:
self.sim_path = Path(path).resolve()
if len(file) > 100:
raise ValueError("Simulation file name too long")
self.sim_file = file
if not (self.sim_path / self.sim_file).exists():
raise FileNotFoundError(self.sim_file)
Expand All @@ -120,25 +136,24 @@ def __init__(

self.out_srv = _TcpServer(
tcp_cfg.get("host", "localhost"),
tcp_cfg.get("output_port", 5678),
tcp_cfg.get("output_port", DEFAULT_OUTPUT_PORT),
)
self.in_srv = _TcpServer(
tcp_cfg.get("host", "localhost"),
tcp_cfg.get("input_port", 5679),
tcp_cfg.get("input_port", DEFAULT_INPUT_PORT),
)

self.start_time: Optional[float] = None
self.sequence = 0

# ------------------------------------------------------------------
def _start_matlab(self) -> None:
safe_path = shlex.quote(str(self.sim_path))
safe_file = shlex.quote(self.sim_file)
cmd = [
"matlab",
"-batch",
f"addpath('{
self.sim_path}');cd('{
self.sim_path}');run('{
self.sim_file}');",
f"addpath('{safe_path}');cd('{safe_path}');run('{safe_file}');",
]
self.out_srv.matlab_proc = subprocess.Popen(
cmd,
Expand Down Expand Up @@ -187,7 +202,7 @@ def _relay(self, payload: Dict[str, Any]) -> None:

@staticmethod
def _only_inputs(frame: Dict[str, Any]) -> Dict[str, Any]:
"""Extract only the inputs from the frame."""
"""Extract only the ``inputs`` section from a simulation frame."""
if isinstance(frame, dict):
sim = frame.get("simulation")
if isinstance(sim, dict) and "inputs" in sim:
Expand Down Expand Up @@ -216,6 +231,9 @@ def run(self, pm: PerformanceMonitor, msg_dict: Dict[str, Any]) -> None:

try:
while True:
if self.out_srv.matlab_proc and self.out_srv.matlab_proc.poll() is not None:
logger.debug("[INTERACTIVE] MATLAB process ended, stopping loop")
break
method, properties, body = ch.basic_get(
queue=qname, auto_ack=True)
while method:
Expand All @@ -228,8 +246,14 @@ def run(self, pm: PerformanceMonitor, msg_dict: Dict[str, Any]) -> None:

# Receive Responses from MATLAB
for resp in self.out_srv.recv_all():
if resp.get("status") == "completed":
self._relay(resp)
logger.debug("[INTERACTIVE] Received completion signal")
return
# Send the response to the broker
self._relay(resp)
except KeyboardInterrupt: # pragma: no cover - manual interruption
logger.info("[INTERACTIVE] Interrupted by user")
finally:
pm.record_simulation_complete()

Expand All @@ -243,7 +267,7 @@ def metadata(self) -> Dict[str, Any]:
if self.start_time:
meta["execution_time"] = time.time() - self.start_time
meta["memory_usage"] = psutil.Process(
).memory_info().rss // (1024 * 1024)
).memory_info().rss // BYTES_IN_MB
return meta


Expand Down Expand Up @@ -275,7 +299,7 @@ def handle_interactive_simulation(
try:
controller.start(pm)
controller.run(pm, msg_dict)
except Exception as exc: # pragma: no cover - runtime errors
except (KeyError, ValueError, RuntimeError) as exc: # pragma: no cover - handled errors
logger.error("[INTERACTIVE] Fatal: %s", exc)
rabbitmq_manager.send_result(
source,
Expand All @@ -289,6 +313,20 @@ def handle_interactive_simulation(
error={"message": str(exc), "type": "execution_error"},
),
)
except Exception as exc: # pragma: no cover - unexpected errors
logger.exception("[INTERACTIVE] Unexpected error: %s", exc)
rabbitmq_manager.send_result(
source,
create_response(
"error",
sim.get("file", ""),
"interactive",
response_templates,
bridge_meta=sim.get("bridge_meta", "unknown"),
request_id=sim.get("request_id", "unknown"),
error={"message": str(exc), "type": "execution_error"},
),
)
finally:
pm.complete_operation()
controller.close()
15 changes: 15 additions & 0 deletions agents/matlab/matlab_agent/src/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Common constants used across the MATLAB agent."""

# Timeouts
ACCEPT_TIMEOUT = 60 # seconds to wait for TCP connection

# Network buffer size
BUFFER_SIZE = 4096

# Default TCP ports
DEFAULT_OUTPUT_PORT = 5678
DEFAULT_INPUT_PORT = 5679

# Memory usage divisor for converting bytes to MB
BYTES_IN_MB = 1024 * 1024

Loading