diff --git a/Dockerfile_processor b/Dockerfile_processor index 1323c71..6960ba6 100644 --- a/Dockerfile_processor +++ b/Dockerfile_processor @@ -3,7 +3,7 @@ FROM python:3.11 WORKDIR /app COPY . /app -RUN pip install ./arroyo +# RUN pip install ./arroyo RUN pip install --no-cache-dir --upgrade . CMD ["python", "-m", "tr_ap_xps.apps.processor_cli"] diff --git a/pyproject.toml b/pyproject.toml index cf68b2d..ff66dbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ classifiers = [ ] dependencies = [ - # "arroyo", + "arroyopy", "astropy", "dynaconf", "python-dotenv", @@ -22,7 +22,8 @@ dependencies = [ "numpy", "Pillow", "pyzmq", - "tiled[client] @ git+https://github.com/bluesky/tiled.git", + "scipy==1.14.1", + "tiled[client]", "tqdm", "typer", "websockets", @@ -34,6 +35,7 @@ dependencies = [ # the documentation) but not necessarily required for _using_ it. dev = [ "flake8", + "fakeredis", "pre-commit", "pytest-asyncio", "pytest-mock", diff --git a/settings.yaml b/settings.yaml index aef0431..f4bd73b 100644 --- a/settings.yaml +++ b/settings.yaml @@ -1,4 +1,4 @@ -xps: +xps_operator: log_level: "INFO" tiled_uri: "http://localhost:8000" lv_zmq_listener: diff --git a/src/_tests/test_calculate.py b/src/_tests/test_calculate.py index 65e445e..4c8806d 100644 --- a/src/_tests/test_calculate.py +++ b/src/_tests/test_calculate.py @@ -60,8 +60,7 @@ def test_peak_fit(test_array): def test_fft_items(test_array): """Test the FFT calculation functionality.""" - vfft, sum, ifft = calculate_fft_items(test_array) + vfft, ifft = calculate_fft_items(test_array) assert vfft.shape == test_array.shape, "vfft shape mismatch" - assert len(sum.shape) == 1, "sum should be a 1D array" assert ifft.shape == test_array.shape, "ifft shape mismatch" diff --git a/src/_tests/test_integration.py b/src/_tests/test_integration.py index f7cb505..d903e34 100644 --- a/src/_tests/test_integration.py +++ b/src/_tests/test_integration.py @@ -9,10 +9,10 @@ def start_processor_cli(): subprocess.run(["python", "processor_cli.py"]) -def start_zmq_publisher(): +def start_zmq_publisher(port): context = zmq.Context() socket = context.socket(zmq.PUB) - socket.bind("tcp://*:5555") + socket.bind(f"tcp://*:{port}") while True: socket.send_string("test message") time.sleep(1) @@ -28,8 +28,14 @@ def test_integration(): processor_cli_process.start() time.sleep(2) # Give it time to start - # Start zmq publisher in a background process - zmq_publisher_process = Process(target=start_zmq_publisher) + # Dynamically assign a port for ZMQ publisher + context = zmq.Context() + temp_socket = context.socket(zmq.PUB) + port = temp_socket.bind_to_random_port("tcp://*") + temp_socket.close() + + # Start zmq publisher in a background process with the random port + zmq_publisher_process = Process(target=start_zmq_publisher, args=(port,)) zmq_publisher_process.start() time.sleep(2) # Give it time to start @@ -41,7 +47,7 @@ def test_integration(): # Set up zmq subscriber to receive messages context = zmq.Context() socket = context.socket(zmq.SUB) - socket.connect("tcp://localhost:5555") + socket.connect(f"tcp://localhost:{port}") socket.setsockopt_string(zmq.SUBSCRIBE, "") # # Check if messages are received and processed diff --git a/src/_tests/test_listener.py b/src/_tests/test_listener.py index 9bc35b1..61f5765 100644 --- a/src/_tests/test_listener.py +++ b/src/_tests/test_listener.py @@ -24,12 +24,13 @@ async def run_simulator(num_frames: int = 1): @pytest.fixture -async def mock_operator(): +def mock_operator(): return AsyncMock() @pytest.mark.asyncio -async def test_listen_zmq_interface(mock_operator): +async def test_listen_zmq_interface(mock_operator, monkeypatch): + # monkeypatch.setattr("tr_ap_xps.labview.app_settings.lv_zmq_listener.zmq_pub_port", "6000") zmq_socket = setup_zmq() # Ensure setup_zmq supports async if needed async with run_simulator(num_frames=1): diff --git a/src/tr_ap_xps/apps/processor_cli.py b/src/tr_ap_xps/apps/processor_cli.py index e15901e..9d52d11 100644 --- a/src/tr_ap_xps/apps/processor_cli.py +++ b/src/tr_ap_xps/apps/processor_cli.py @@ -17,8 +17,7 @@ logger = logging.getLogger("tr_ap_xps") setup_logger(logger) -app_settings = settings.xps - +app_settings = settings.xps_operator def tiled_runs_container() -> Container: try: @@ -48,14 +47,11 @@ async def listen() -> None: # setup websocket server operator = XPSOperator() - ws_publisher = XPSWSResultPublisher( - host=app_settings.websockets_publisher.host, - port=app_settings.websockets_publisher.port, - ) - tiled_pub = TiledPublisher(tiled_runs_container()) + ws_publisher = XPSWSResultPublisher(app_settings.websocket_url) + # tiled_pub = TiledPublisher(tiled_runs_container()) operator.add_publisher(ws_publisher) - operator.add_publisher(tiled_pub) + # operator.add_publisher(tiled_pub) # connect to labview zmq lv_zmq_socket = setup_zmq() diff --git a/src/tr_ap_xps/labview.py b/src/tr_ap_xps/labview.py index b82424e..1050c20 100644 --- a/src/tr_ap_xps/labview.py +++ b/src/tr_ap_xps/labview.py @@ -5,7 +5,7 @@ import numpy as np import zmq.asyncio -from arroyo.zmq import ZMQListener +from arroyopy.zmq import ZMQListener from .config import settings from .schemas import NumpyArrayModel, XPSImageInfo, XPSRawEvent, XPSStart, XPSStop @@ -26,7 +26,7 @@ "Double Float": np.dtype(np.double).newbyteorder(">"), } -app_settings = settings.xps +app_settings = settings.xps_operator logger = logging.getLogger(__name__) diff --git a/src/tr_ap_xps/pipeline/xps_operator.py b/src/tr_ap_xps/pipeline/xps_operator.py index 230f7f4..b2c6a09 100644 --- a/src/tr_ap_xps/pipeline/xps_operator.py +++ b/src/tr_ap_xps/pipeline/xps_operator.py @@ -1,8 +1,8 @@ import asyncio import logging -from arroyo.operator import Operator -from arroyo.schemas import Message +from arroyopy.operator import Operator +from arroyopy.schemas import Message from ..schemas import DataFrameModel, XPSRawEvent, XPSResultStop, XPSStart, XPSStop from ..timing import timer diff --git a/src/tr_ap_xps/pipeline/xps_processor.py b/src/tr_ap_xps/pipeline/xps_processor.py index 78a5177..f601569 100644 --- a/src/tr_ap_xps/pipeline/xps_processor.py +++ b/src/tr_ap_xps/pipeline/xps_processor.py @@ -71,15 +71,13 @@ def process_frame(self, message: XPSRawEvent) -> None: # Things to do with every shot (a "shot" is a complete cycle of frames) if ( message.image_info.frame_number != 0 - and message.image_info.frame_number % self.frames_per_cycle == 0 + and (message.image_info.frame_number + 1) % self.frames_per_cycle == 0 ): self.shot_num += 1 - self.shot_recent = self.shot_cache self._compute_rolling_values(self.shot_cache) - logger.info(f"Processing frame {message.image_info.frame_number}") # Peak detection on new_integrated_frame detected_peaks_df = peak_fit(new_integrated_frame) @@ -87,7 +85,6 @@ def process_frame(self, message: XPSRawEvent) -> None: vfft_np, ifft_np = calculate_fft_items( self.integrated_frames, repeat_factor=20, width=0 ) - result = XPSResult( frame_number=message.image_info.frame_number, integrated_frames=NumpyArrayModel(array=self.integrated_frames), diff --git a/src/tr_ap_xps/schemas.py b/src/tr_ap_xps/schemas.py index fb84315..c3a195b 100644 --- a/src/tr_ap_xps/schemas.py +++ b/src/tr_ap_xps/schemas.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -from arroyo.schemas import DataFrameModel, Event, Message, NumpyArrayModel, Start, Stop +from arroyopy.schemas import DataFrameModel, Event, Message, NumpyArrayModel, Start, Stop """ This module defines schemas for XPS (X-ray Photoelectron Spectroscopy) messages and events using diff --git a/src/tr_ap_xps/tiled.py b/src/tr_ap_xps/tiled.py index f39b6dc..2b6e9f3 100644 --- a/src/tr_ap_xps/tiled.py +++ b/src/tr_ap_xps/tiled.py @@ -11,12 +11,12 @@ from tiled.structures.data_source import DataSource from tiled.structures.table import TableStructure -from arroyo.publisher import Publisher +from arroyopy.publisher import Publisher from .config import settings from .schemas import XPSResult, XPSResultStop, XPSStart -app_settings = settings.xps +app_settings = settings.xps_operator logger = logging.getLogger(__name__) diff --git a/src/tr_ap_xps/websockets.py b/src/tr_ap_xps/websockets.py index cb6bce4..19b0eee 100644 --- a/src/tr_ap_xps/websockets.py +++ b/src/tr_ap_xps/websockets.py @@ -1,4 +1,5 @@ import asyncio +from urllib.parse import urlparse import json import logging from typing import Union @@ -8,7 +9,7 @@ import pandas as pd import websockets -from arroyo.publisher import Publisher +from arroyopy.publisher import Publisher from .schemas import XPSResult, XPSResultStop, XPSStart @@ -25,21 +26,21 @@ class XPSWSResultPublisher(Publisher): connected_clients = set() current_start_message = None - def __init__(self, host: str = "localhost", port: int = 8001): + def __init__(self, ws_url: str = "ws://localhost:8765"): super().__init__() - self.host = host - self.port = port + self.ws_url = ws_url async def start( self, ): # Use partial to bind `self` while matching the expected handler signature + parsed_url = urlparse(self.ws_url) server = await websockets.serve( self.websocket_handler, - self.host, - self.port, + parsed_url.hostname, + parsed_url.port, ) - logger.info(f"Websocket server started at ws://{self.host}:{self.port}") + logger.info(f"Websocket server started at ws://{parsed_url.hostname}:{parsed_url.port}") await server.wait_closed() async def publish(self, message: XPSResult) -> None: @@ -79,7 +80,7 @@ async def publish_ws( async def websocket_handler(self, websocket): logger.info(f"New connection from {websocket.remote_address}") - if websocket.request.path != "/simImages": + if websocket.request.path != "/xps_operator": logger.info( f"Invalid path: {websocket.request.path}, we only support /simImages" ) @@ -98,14 +99,17 @@ def convert_to_uint8(image: np.ndarray) -> bytes: """ Convert an image to uint8, scaling image """ + # scaled = (image - image.min()) / (image.max() - image.min()) * 255 # return scaled.astype(np.uint8).tobytes() - + + if np.allclose(image, 0): + return image.astype(np.uint8).tobytes() + image_normalized = (image - image.min()) / (image.max() - image.min()) # Apply logarithmic stretch log_stretched = np.log1p(image_normalized) # log(1 + x) to handle near-zero values - # Normalize the log-stretched image to [0, 1] again log_stretched_normalized = (log_stretched - log_stretched.min()) / ( log_stretched.max() - log_stretched.min() @@ -135,11 +139,11 @@ def pack_images(message: XPSResult) -> bytes: """ return msgpack.packb( { - "raw": convert_to_uint8(message.integrated_frames.array), - "vfft": convert_to_uint8(message.vfft.array), - "ifft": convert_to_uint8(message.ifft.array), - "width": message.integrated_frames.array.shape[0], - "height": message.integrated_frames.array.shape[1], + # "raw": convert_to_uint8(message.integrated_frames.array), + # "vfft": convert_to_uint8(message.vfft.array), + # "ifft": convert_to_uint8(message.ifft.array), + "width": message.shot_mean.array.shape[0], + "height": message.shot_mean.array.shape[1], "fitted": json.dumps(peaks_output(message.detected_peaks.df)), "shot_num": message.shot_num, "shot_recent": convert_to_uint8(message.shot_recent.array),