Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile_processor
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM python:3.11
WORKDIR /app

COPY . /app
RUN pip install ./arroyo
# RUN pip install ./arroyo
RUN pip install --no-cache-dir --upgrade .

CMD ["python", "-m", "tr_ap_xps.apps.processor_cli"]
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ classifiers = [
]

dependencies = [
# "arroyo",
"arroyopy",
"astropy",
"dynaconf",
"python-dotenv",
Expand All @@ -22,7 +22,8 @@ dependencies = [
"numpy",
"Pillow",
"pyzmq",
"tiled[client] @ git+https://github.com/bluesky/tiled.git",
"scipy==1.14.1",
"tiled[client]",
"tqdm",
"typer",
"websockets",
Expand All @@ -34,6 +35,7 @@ dependencies = [
# the documentation) but not necessarily required for _using_ it.
dev = [
"flake8",
"fakeredis",
"pre-commit",
"pytest-asyncio",
"pytest-mock",
Expand Down
2 changes: 1 addition & 1 deletion settings.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
xps:
xps_operator:
log_level: "INFO"
tiled_uri: "http://localhost:8000"
lv_zmq_listener:
Expand Down
3 changes: 1 addition & 2 deletions src/_tests/test_calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def test_peak_fit(test_array):

def test_fft_items(test_array):
"""Test the FFT calculation functionality."""
vfft, sum, ifft = calculate_fft_items(test_array)
vfft, ifft = calculate_fft_items(test_array)

assert vfft.shape == test_array.shape, "vfft shape mismatch"
assert len(sum.shape) == 1, "sum should be a 1D array"
assert ifft.shape == test_array.shape, "ifft shape mismatch"
16 changes: 11 additions & 5 deletions src/_tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ def start_processor_cli():
subprocess.run(["python", "processor_cli.py"])


def start_zmq_publisher():
def start_zmq_publisher(port):
context = zmq.Context()
socket = context.socket(zmq.PUB)
socket.bind("tcp://*:5555")
socket.bind(f"tcp://*:{port}")
while True:
socket.send_string("test message")
time.sleep(1)
Expand All @@ -28,8 +28,14 @@ def test_integration():
processor_cli_process.start()
time.sleep(2) # Give it time to start

# Start zmq publisher in a background process
zmq_publisher_process = Process(target=start_zmq_publisher)
# Dynamically assign a port for ZMQ publisher
context = zmq.Context()
temp_socket = context.socket(zmq.PUB)
port = temp_socket.bind_to_random_port("tcp://*")
temp_socket.close()

# Start zmq publisher in a background process with the random port
zmq_publisher_process = Process(target=start_zmq_publisher, args=(port,))
zmq_publisher_process.start()
time.sleep(2) # Give it time to start

Expand All @@ -41,7 +47,7 @@ def test_integration():
# Set up zmq subscriber to receive messages
context = zmq.Context()
socket = context.socket(zmq.SUB)
socket.connect("tcp://localhost:5555")
socket.connect(f"tcp://localhost:{port}")
socket.setsockopt_string(zmq.SUBSCRIBE, "")
#
# Check if messages are received and processed
Expand Down
5 changes: 3 additions & 2 deletions src/_tests/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ async def run_simulator(num_frames: int = 1):


@pytest.fixture
async def mock_operator():
def mock_operator():
return AsyncMock()


@pytest.mark.asyncio
async def test_listen_zmq_interface(mock_operator):
async def test_listen_zmq_interface(mock_operator, monkeypatch):
# monkeypatch.setattr("tr_ap_xps.labview.app_settings.lv_zmq_listener.zmq_pub_port", "6000")
zmq_socket = setup_zmq() # Ensure setup_zmq supports async if needed

async with run_simulator(num_frames=1):
Expand Down
12 changes: 4 additions & 8 deletions src/tr_ap_xps/apps/processor_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
logger = logging.getLogger("tr_ap_xps")
setup_logger(logger)

app_settings = settings.xps

app_settings = settings.xps_operator

def tiled_runs_container() -> Container:
try:
Expand Down Expand Up @@ -48,14 +47,11 @@ async def listen() -> None:

# setup websocket server
operator = XPSOperator()
ws_publisher = XPSWSResultPublisher(
host=app_settings.websockets_publisher.host,
port=app_settings.websockets_publisher.port,
)
tiled_pub = TiledPublisher(tiled_runs_container())
ws_publisher = XPSWSResultPublisher(app_settings.websocket_url)
# tiled_pub = TiledPublisher(tiled_runs_container())

operator.add_publisher(ws_publisher)
operator.add_publisher(tiled_pub)
# operator.add_publisher(tiled_pub)
# connect to labview zmq

lv_zmq_socket = setup_zmq()
Expand Down
4 changes: 2 additions & 2 deletions src/tr_ap_xps/labview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import zmq.asyncio

from arroyo.zmq import ZMQListener
from arroyopy.zmq import ZMQListener

from .config import settings
from .schemas import NumpyArrayModel, XPSImageInfo, XPSRawEvent, XPSStart, XPSStop
Expand All @@ -26,7 +26,7 @@
"Double Float": np.dtype(np.double).newbyteorder(">"),
}

app_settings = settings.xps
app_settings = settings.xps_operator

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions src/tr_ap_xps/pipeline/xps_operator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import logging

from arroyo.operator import Operator
from arroyo.schemas import Message
from arroyopy.operator import Operator
from arroyopy.schemas import Message

from ..schemas import DataFrameModel, XPSRawEvent, XPSResultStop, XPSStart, XPSStop
from ..timing import timer
Expand Down
5 changes: 1 addition & 4 deletions src/tr_ap_xps/pipeline/xps_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,20 @@ def process_frame(self, message: XPSRawEvent) -> None:
# Things to do with every shot (a "shot" is a complete cycle of frames)
if (
message.image_info.frame_number != 0
and message.image_info.frame_number % self.frames_per_cycle == 0
and (message.image_info.frame_number + 1) % self.frames_per_cycle == 0
):
self.shot_num += 1

self.shot_recent = self.shot_cache

self._compute_rolling_values(self.shot_cache)


logger.info(f"Processing frame {message.image_info.frame_number}")
# Peak detection on new_integrated_frame
detected_peaks_df = peak_fit(new_integrated_frame)
# TODO: allow user to select repeat factor and width on UI
vfft_np, ifft_np = calculate_fft_items(
self.integrated_frames, repeat_factor=20, width=0
)

result = XPSResult(
frame_number=message.image_info.frame_number,
integrated_frames=NumpyArrayModel(array=self.integrated_frames),
Expand Down
2 changes: 1 addition & 1 deletion src/tr_ap_xps/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel, Field

from arroyo.schemas import DataFrameModel, Event, Message, NumpyArrayModel, Start, Stop
from arroyopy.schemas import DataFrameModel, Event, Message, NumpyArrayModel, Start, Stop

"""
This module defines schemas for XPS (X-ray Photoelectron Spectroscopy) messages and events using
Expand Down
4 changes: 2 additions & 2 deletions src/tr_ap_xps/tiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from tiled.structures.data_source import DataSource
from tiled.structures.table import TableStructure

from arroyo.publisher import Publisher
from arroyopy.publisher import Publisher

from .config import settings
from .schemas import XPSResult, XPSResultStop, XPSStart

app_settings = settings.xps
app_settings = settings.xps_operator

logger = logging.getLogger(__name__)

Expand Down
34 changes: 19 additions & 15 deletions src/tr_ap_xps/websockets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from urllib.parse import urlparse
import json
import logging
from typing import Union
Expand All @@ -8,7 +9,7 @@
import pandas as pd
import websockets

from arroyo.publisher import Publisher
from arroyopy.publisher import Publisher

from .schemas import XPSResult, XPSResultStop, XPSStart

Expand All @@ -25,21 +26,21 @@ class XPSWSResultPublisher(Publisher):
connected_clients = set()
current_start_message = None

def __init__(self, host: str = "localhost", port: int = 8001):
def __init__(self, ws_url: str = "ws://localhost:8765"):
super().__init__()
self.host = host
self.port = port
self.ws_url = ws_url

async def start(
self,
):
# Use partial to bind `self` while matching the expected handler signature
parsed_url = urlparse(self.ws_url)
server = await websockets.serve(
self.websocket_handler,
self.host,
self.port,
parsed_url.hostname,
parsed_url.port,
)
logger.info(f"Websocket server started at ws://{self.host}:{self.port}")
logger.info(f"Websocket server started at ws://{parsed_url.hostname}:{parsed_url.port}")
await server.wait_closed()

async def publish(self, message: XPSResult) -> None:
Expand Down Expand Up @@ -79,7 +80,7 @@ async def publish_ws(

async def websocket_handler(self, websocket):
logger.info(f"New connection from {websocket.remote_address}")
if websocket.request.path != "/simImages":
if websocket.request.path != "/xps_operator":
logger.info(
f"Invalid path: {websocket.request.path}, we only support /simImages"
)
Expand All @@ -98,14 +99,17 @@ def convert_to_uint8(image: np.ndarray) -> bytes:
"""
Convert an image to uint8, scaling image
"""

# scaled = (image - image.min()) / (image.max() - image.min()) * 255
# return scaled.astype(np.uint8).tobytes()


if np.allclose(image, 0):
return image.astype(np.uint8).tobytes()

image_normalized = (image - image.min()) / (image.max() - image.min())

# Apply logarithmic stretch
log_stretched = np.log1p(image_normalized) # log(1 + x) to handle near-zero values

# Normalize the log-stretched image to [0, 1] again
log_stretched_normalized = (log_stretched - log_stretched.min()) / (
log_stretched.max() - log_stretched.min()
Expand Down Expand Up @@ -135,11 +139,11 @@ def pack_images(message: XPSResult) -> bytes:
"""
return msgpack.packb(
{
"raw": convert_to_uint8(message.integrated_frames.array),
"vfft": convert_to_uint8(message.vfft.array),
"ifft": convert_to_uint8(message.ifft.array),
"width": message.integrated_frames.array.shape[0],
"height": message.integrated_frames.array.shape[1],
# "raw": convert_to_uint8(message.integrated_frames.array),
# "vfft": convert_to_uint8(message.vfft.array),
# "ifft": convert_to_uint8(message.ifft.array),
"width": message.shot_mean.array.shape[0],
"height": message.shot_mean.array.shape[1],
"fitted": json.dumps(peaks_output(message.detected_peaks.df)),
"shot_num": message.shot_num,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One note about shot_num: I noticed that it starts from 1 instead of 0 during beamtime. It works well — just a note for awareness.

Shall we add tiled_url for shot_mean in the websocket message?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeas, I like the tiled_url for shot_mean, but we'll need to reconstitute the TiledPublisher here to do that.

"shot_recent": convert_to_uint8(message.shot_recent.array),
Expand Down