diff --git a/pdfstream/analyzers/base.py b/pdfstream/analyzers/base.py
index 91950ca2..72a660cd 100644
--- a/pdfstream/analyzers/base.py
+++ b/pdfstream/analyzers/base.py
@@ -1,7 +1,7 @@
 from configparser import ConfigParser
 
 from bluesky.callbacks.core import CallbackBase
-from databroker.core import BlueskyRun
+from databroker.client import BlueskyRun
 
 
 class AnalyzerConfig(ConfigParser):
diff --git a/pdfstream/analyzers/xpd_analyzer.py b/pdfstream/analyzers/xpd_analyzer.py
index 7381cc6c..7f39856b 100644
--- a/pdfstream/analyzers/xpd_analyzer.py
+++ b/pdfstream/analyzers/xpd_analyzer.py
@@ -1,7 +1,7 @@
 import typing as tp
 
 from databroker import catalog
-from databroker.core import BlueskyRun
+from databroker.client import BlueskyRun
 
 from pdfstream.analyzers.base import AnalyzerConfig, Analyzer
 from pdfstream.servers.xpd_server import XPDRouter, XPDConfig
diff --git a/pdfstream/callbacks/analysis.py b/pdfstream/callbacks/analysis.py
index 229776e6..0bd8b192 100644
--- a/pdfstream/callbacks/analysis.py
+++ b/pdfstream/callbacks/analysis.py
@@ -5,15 +5,22 @@
 from configparser import ConfigParser
 from pathlib import Path
 
+import databroker.mongo_normalized
 import event_model
 import matplotlib.pyplot as plt
 import numpy as np
 from bluesky.callbacks.stream import LiveDispatcher
 from databroker.v1 import Broker
 from event_model import RunRouter
-from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
+# from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
+from pyFAI.integrator.azimuthal import AzimuthalIntegrator
+
 from suitcase.csv import Serializer as CSVSerializer
 from suitcase.json_metadata import Serializer as JsonSerializer
+import pandas
+from tiled.client import from_uri
+from tiled.client.array import ArrayClient
+from tiled.client.dataframe import DataFrameClient
 
 import pdfstream
 import pdfstream.callbacks.from_descriptor as from_desc
@@ -124,6 +131,8 @@ def trans_setting(self):
             "rmin": self.getfloat("ANALYSIS", "rmin", fallback=0.),
             "rmax": self.getfloat("ANALYSIS", "rmax", fallback=30.),
             "rstep": self.getfloat("ANALYSIS", "rstep", fallback=0.01),
+            "backgroundfiles": self.get("ANALYSIS", "bkg_file", fallback=""),
+            "bgscales": self.getfloat("ANALYSIS", "bgscale", fallback=1),
             "dataformat": "QA"
         }
 
@@ -175,9 +184,15 @@ def __init__(self, config: AnalysisConfig):
         self.dirc = None
         self.file_prefix = None
+
+        self._tiled_client = from_uri("https://tiled.nsls2.bnl.gov/api/v1/metadata/xpd/sandbox")
 
     def start(self, doc, _md=None):
         io.server_message("Receive the start of '{}'.".format(doc["uid"]))
         self.clear_cache()
+
+        # Get the detector names
+        self._detectors = doc["detectors"]
+
         # get indeps
         self.indeps = from_start.get_indeps(doc, exclude={"time"})
         # copy the default config and read the user config
@@ -217,7 +232,7 @@ def start(self, doc, _md=None):
         # create directory
         d = self.config.directory
         fp = self.config.file_prefix
-        self.dirc = Path(d).expanduser().joinpath(new_start["sample_name"])
+        self.dirc = Path(str(Path(d).expanduser().joinpath(new_start["sample_name"])).format(start=doc))
         if self.config.save_file:
             self.dirc.mkdir(parents=True, exist_ok=True)
         # create file prefix
@@ -242,6 +257,19 @@ def descriptor(self, doc):
         except ValueNotFoundError as error:
             self.dark_image = None
             io.server_message("Failed to find the dark: " + str(error))
+
+        stream_desc = {"stream": {"fields": []}}
+        for obj_name in doc["hints"]:
+            stream_desc["stream"]["fields"].extend(doc["hints"][obj_name]["fields"])
+
+        fields_to_add = ["chi_max", "chi_argmax", "gr_max", "gr_argmax"]
+        fields_to_add.extend([val for val in doc["object_keys"][self._detectors[0]] if not val.endswith("_image")])
+
+        stream_desc["stream"]["fields"].extend(fields_to_add)
+
+        doc["hints"].update(stream_desc)
+
         return super(AnalysisStream, self).descriptor(doc)
 
     def event(self, doc, _md=None):
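Note on the `descriptor` override above: merging the derived scalar fields into `doc["hints"]` is what lets downstream consumers (e.g. `BestEffortCallback`) pick up `chi_max`, `gr_argmax`, and the detector's scalar keys. A self-contained sketch of the same merging logic on a toy descriptor (the field and detector names here are hypothetical):

```python
# Toy descriptor illustrating the hint merge done in AnalysisStream.descriptor.
doc = {
    "hints": {"pe1": {"fields": ["pe1_image"]}},
    "object_keys": {"pe1": ["pe1_image", "pe1_stats1_total"]},
}
detectors = ["pe1"]  # doc["detectors"], captured in start()

stream_desc = {"stream": {"fields": []}}
for obj_name in doc["hints"]:
    stream_desc["stream"]["fields"].extend(doc["hints"][obj_name]["fields"])

fields_to_add = ["chi_max", "chi_argmax", "gr_max", "gr_argmax"]
# hint the detector's non-image keys as well
fields_to_add.extend(v for v in doc["object_keys"][detectors[0]] if not v.endswith("_image"))
stream_desc["stream"]["fields"].extend(fields_to_add)

doc["hints"].update(stream_desc)
print(doc["hints"]["stream"]["fields"])
# ['pe1_image', 'chi_max', 'chi_argmax', 'gr_max', 'gr_argmax', 'pe1_stats1_total']
```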
@@ -254,6 +282,34 @@ def stop(self, doc, _md=None):
         io.server_message("Receive the stop of '{}'.".format(doc["run_start"]))
         return super(AnalysisStream, self).stop(doc)
 
+    def _get_uid_from_uri(self, uri) -> str:
+        return uri.split("/")[-1]
+
+    def _write_dataframe_to_tiled(self, data_dict, columns, group_key, default_md):
+        df_data = []
+        for key in columns:
+            df_data.append(data_dict[key])
+        df_data = np.array(df_data).T
+
+        df = pandas.DataFrame(df_data, columns=columns)
+
+        metadata = {
+            "field": group_key,
+            **default_md
+        }
+        for m in ["argmax", "max"]:
+            k = f"{group_key}_{m}"
+            if k in data_dict:
+                metadata[k] = data_dict[k]
+
+        entry = self._tiled_client.write_dataframe(df, metadata=metadata, access_tags=["xpd_sandbox"])
+        entry_uri = entry.uri
+        entry_uid = self._get_uid_from_uri(entry_uri)
+
+        # print(f"{group_key = } {entry_uri = } {entry_uid = }")
+
+        return {"uri": entry_uri, "uid": entry_uid}
+
     def process_data(self, doc) -> dict:
         """Process the data in the event doc. Return a dictionary of processed data."""
         # the raw image in the data
@@ -267,6 +323,8 @@
         if not self.config.save_file:
             filename, directory = None, None
         # process the data output a dictionary
+        import time as ttime
+        start_time = ttime.monotonic()
         an_data = process(
             raw_img=raw_img,
             ai=self.ai,
@@ -286,8 +344,80 @@
         # filter the data
         if self.valid_keys:
             an_data = self.filter(an_data)
+        duration = ttime.monotonic() - start_time
+        print(f"process took {duration:.6f} sec.")
+        # from pprint import pformat
+        # print(f"{self.__class__.__name__}:\ndoc={pformat(doc)}\nraw_data={pformat(raw_data)}\nan_data={pformat(an_data)}")
+
+        import time as ttime
+        start_time = ttime.monotonic()
+
+        # Enter the information to Tiled:
+        tiled_dict = {}
+        default_md = {"run_start": self.start_doc["uid"]}
+
+        # DataFrames:
+
+        # DataFrame: chi
+        group_key = "chi"
+        tiled_key = f"tiled_{group_key}"
+        columns = ["chi_2theta", "chi_Q", "chi_I"]
+        tiled_dict[tiled_key] = self._write_dataframe_to_tiled(
+            an_data,
+            columns=columns,
+            group_key=group_key,
+            default_md=default_md,
+        )
+        for key in columns + ["iq_I", "iq_Q"]:  # special case
+            an_data[key] = tiled_dict[tiled_key]["uid"]
+
+        # DataFrame: fq/sq
+        group_key = "fqsq"
+        tiled_key = f"tiled_{group_key}"
+        columns = ["fq_F", "fq_Q", "sq_Q", "sq_S"]
+        tiled_dict[tiled_key] = self._write_dataframe_to_tiled(
+            an_data,
+            columns=columns,
+            group_key=group_key,
+            default_md=default_md,
+        )
+        for key in columns:
+            an_data[key] = tiled_dict[tiled_key]["uid"]
+
+        # DataFrame: gr
+        group_key = "gr"
+        tiled_key = f"tiled_{group_key}"
+        columns = ["gr_G", "gr_r"]
+        tiled_dict[tiled_key] = self._write_dataframe_to_tiled(
+            an_data,
+            columns=columns,
+            group_key=group_key,
+            default_md=default_md,
+        )
+        for key in columns:
+            an_data[key] = tiled_dict[tiled_key]["uid"]
+
+        # Arrays:
+        for key in ["dk_sub_image", "mask"]:
+            tiled_key = f"tiled_{key}"
+            entry = self._tiled_client.write_array(
+                an_data[key],
+                metadata={"field": key, **default_md},
+                access_tags=["xpd_sandbox"])
+            entry_uri = entry.uri
+            entry_uid = entry_uri.split("/")[-1]
+            tiled_dict[tiled_key] = {"uri": entry_uri, "uid": entry_uid}
+            an_data[key] = tiled_dict[tiled_key]["uid"]
+
+        duration = ttime.monotonic() - start_time
+        print(f"Uploading to tiled took {duration:.6f} sec.")
+
+        from pprint import pformat
+        print(f"tiled_dict:\n{pformat(tiled_dict)}")
+        # print(f"an_data:\n{pformat(an_data)}")
+
         # the final output data is a combination of the independent variables and processed data
-        return dict(**raw_data, **an_data)
+        return dict(**raw_data, **an_data, **tiled_dict)
 
     def filter(self, data: dict):
         return {k: v for k, v in data.items() if k in self.valid_keys}
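Note on the Tiled round trip introduced in `process_data` above: each bulky column group is written once and the event document carries only the entry's uid. A minimal sketch of the write/read pattern, assuming a writable Tiled node at the sandbox URI used in this patch and write access under the `xpd_sandbox` tag:

```python
import numpy as np
import pandas
from tiled.client import from_uri

client = from_uri("https://tiled.nsls2.bnl.gov/api/v1/metadata/xpd/sandbox")

# write once; keep only the uid in the document
df = pandas.DataFrame({"chi_Q": np.linspace(0.0, 25.0, 100), "chi_I": np.random.rand(100)})
entry = client.write_dataframe(df, metadata={"field": "chi"}, access_tags=["xpd_sandbox"])
uid = entry.uri.split("/")[-1]

# any consumer holding the uid can rehydrate the columns later
chi_I = np.array(client[uid].read()["chi_I"])
```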
@@ -337,8 +467,8 @@
         "chi_2theta": np.array([0.]),
         "chi_Q": np.array([0.]),
         "chi_I": np.array([0.]),
-        "chi_max": np.float(0.),
-        "chi_argmax": np.float(0.),
+        "chi_max": np.float64(0.),
+        "chi_argmax": np.float64(0.),
         "iq_Q": np.array([0.]),
         "iq_I": np.array([0.]),
         "sq_Q": np.array([0.]),
@@ -347,8 +477,8 @@
         "fq_F": np.array([0.]),
         "gr_r": np.array([0.]),
         "gr_G": np.array([0.]),
-        "gr_max": np.float(0.),
-        "gr_argmax": np.float(0.)
+        "gr_max": np.float64(0.),
+        "gr_argmax": np.float64(0.)
     }
     # dark subtraction
     if dk_img is not None:
@@ -379,7 +509,7 @@
     pdfconfig = PDFConfig(**pdfgetx_setting)
     pdfgetter = PDFGetter(pdfconfig)
     pdfgetter(x, y)
-    iq, sq, fq, gr = pdfgetter.iq, pdfgetter.sq, pdfgetter.fq, pdfgetter.gr
+    iq, sq, fq, gr = [x, y], pdfgetter.sq, pdfgetter.fq, pdfgetter.gr
     gr_max_ind = np.argmax(gr[1])
     data.update(
         {
@@ -440,19 +570,93 @@ def tiff_setting(self):
         }
 
 
+class TiledClientTypeException(Exception):
+    ...
+
+
+def fill_data_from_tiled(data, tiled_client):
+    """The helper function to fill the event's data field with the data from sandbox Tiled.
+
+    Parameters
+    ----------
+    data : dict
+        A subset of the event document (via the 'data' key).
+    tiled_client : tiled.client...
+        A Tiled client instance.
+
+    Examples
+    --------
+    Queries can look like this:
+
+    In [69]: queries
+    Out[69]:
+    {'32cf499e-0b28-4806-a83f-111f993812e6': <...>,
+     'fef412dd-88a3-48eb-8c12-93883a7983ea': <...>,
+     '9c1470af-dbd3-40f6-bea8-f5058e6bb146': <...>,
+     '41ad8fdf-12d0-4c1a-b057-87032ba6f34b': <...>,
+     '6e2eedec-2a03-421a-96e4-59f77a27213c': <...>}
+    """
+    all_values = list(data.values())
+    all_uids = [x for x in all_values if type(x) is str]
+    queries = {}
+
+    for uid in set(all_uids):
+        queries[uid] = tiled_client[uid]
+
+    for key in data:
+        if key.startswith("tiled_"):  # the values are dictionaries (can't be used as keys)
+            continue
+        client = queries.get(data[key], None)  # check if the uid (=data[key]) is in the 'queries' keys, otherwise skip the filling.
+        if client is None:
+            continue
+        if isinstance(client, ArrayClient):  # image data
+            data[key] = client.read()
+        elif isinstance(client, DataFrameClient):  # Pandas DataFrames
+            new_key = key
+            if key in ["iq_I", "iq_Q"]:  # special case
+                new_key = key.replace("iq", "chi")
+            data[key] = np.array(client.read()[new_key])
+        else:
+            raise TiledClientTypeException(f"Unknown tiled client type: {type(client)}")
+
+    return data
+
+
 class Exporter(RunRouter):
     """Export the processed data to file systems.
 
     Add readable_time to start doc."""
 
     def __init__(self, config: ExportConfig):
+        self._config = config
         factory = ExporterFactory(config)
-        super().__init__([factory])
-        io.server_message("Data will be exported in '{}'.".format(str(config.tiff_base)))
+        super().__init__([factory], handler_registry=databroker.mongo_normalized.discover_handlers())
+        io.server_message("Data will be exported in '{}' in a proposal directory.".format(str(config.tiff_base)))
+
+        self._tiled_client = from_uri("https://tiled.nsls2.bnl.gov/api/v1/metadata/xpd/sandbox")
 
     def start(self, start_doc):
+        save_dir = self._config.tiff_base.joinpath(self._config.directory_template.format(start=start_doc))
+        io.server_message("Data will be exported in '{}'.".format(save_dir))
         io.server_message("Receive the start of '{}'.".format(start_doc["uid"]))
         return super(Exporter, self).start(start_doc)
 
+    def descriptor(self, doc):
+        doc["data_keys"]["dk_sub_image"].update({"dtype": "array", "shape": [-1, -1]})
+        super().descriptor(doc)
+
     def event(self, doc):
+        # from pprint import pformat
+        # print(f"{self.__class__.__name__} (before filling from Tiled): {pformat(doc)}")
+
+        import time as ttime
+        start_time = ttime.monotonic()
+
+        data = doc["data"]
+        # Get information for all fillable entries in 'an_data' dict from Tiled:
+        data = fill_data_from_tiled(data=data, tiled_client=self._tiled_client)
+
+        duration = ttime.monotonic() - start_time
+        print(f"Downloading from tiled took {duration:.6f} sec.")
+
+        # print(f"{self.__class__.__name__} (after filling from Tiled): {pformat(doc)}")
+
         io.server_message("Export data in the event {}.".format(doc["seq_num"]))
         return super(Exporter, self).event(doc)
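Note on `fill_data_from_tiled` (used by `Exporter.event` above and `Visualizer.event` below): it resolves each unique uid once even when several keys reference the same entry, and DataFrame-backed `iq_*` keys are read back from their `chi_*` columns. A self-contained sketch of that flow with stub clients standing in for Tiled (all names here are hypothetical):

```python
import numpy as np
import pandas

class FakeArrayClient:  # stands in for tiled.client.array.ArrayClient
    def read(self):
        return np.zeros((2, 2))

class FakeDataFrameClient:  # stands in for tiled.client.dataframe.DataFrameClient
    def read(self):
        return pandas.DataFrame({"chi_Q": [0.0, 1.0], "chi_I": [1.0, 2.0]})

tiled_client = {"uid-df": FakeDataFrameClient(), "uid-img": FakeArrayClient()}
data = {"chi_Q": "uid-df", "iq_Q": "uid-df", "dk_sub_image": "uid-img", "chi_max": 3.14}

# one lookup per unique uid, as in fill_data_from_tiled
queries = {uid: tiled_client[uid] for uid in {v for v in data.values() if type(v) is str}}

for key, value in list(data.items()):
    client = queries.get(value) if isinstance(value, str) else None
    if client is None:
        continue  # e.g. chi_max is a plain scalar, not a uid
    if isinstance(client, FakeArrayClient):
        data[key] = client.read()
    else:
        column = key.replace("iq", "chi") if key in ["iq_I", "iq_Q"] else key
        data[key] = np.array(client.read()[column])
```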
@@ -606,7 +810,16 @@ class Visualizer(RunRouter):
 
     def __init__(self, config: VisConfig):
         self._factory = VisFactory(config)
-        super(Visualizer, self).__init__([self._factory])
+        super().__init__([self._factory])
+
+        self._tiled_client = from_uri("https://tiled.nsls2.bnl.gov/api/v1/metadata/xpd/sandbox")
+        print(f"Tiled client: {self._tiled_client}")
+
+    def event(self, doc):
+        data = doc["data"]
+        # Get information for all fillable entries in 'an_data' dict from Tiled:
+        data = fill_data_from_tiled(data=data, tiled_client=self._tiled_client)
+        return super().event(doc)
 
     def show_figs(self):
         """Show all the figures in the callbacks in the factory."""
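Note: the same monotonic-clock instrumentation now appears in three places (`process`, the Tiled upload, and `Exporter.event`). If the prints are kept, a small context manager would remove the repetition; a possible cleanup sketch, not part of this patch:

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    """Print the wall-clock duration of the wrapped block."""
    start = time.monotonic()
    try:
        yield
    finally:
        print(f"{label} took {time.monotonic() - start:.6f} sec.")

# usage:
with timed("Uploading to tiled"):
    pass  # ...write_dataframe()/write_array() calls...
```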
diff --git a/pdfstream/callbacks/basic.py b/pdfstream/callbacks/basic.py
index a32208d9..d2ad391b 100644
--- a/pdfstream/callbacks/basic.py
+++ b/pdfstream/callbacks/basic.py
@@ -242,6 +242,14 @@ def __init__(self, *, xlabel: str, ylabel: str, ax: Axes, **kwargs):
         )
         self.x_offset_slider.on_changed(self.update_x_offset)
 
+    def clear(self):
+        self.key_list.clear()
+        self.x_array_list.clear()
+        self.y_array_list.clear()
+        for line in self.ax.get_lines():
+            line.remove()
+        self.canvas.draw_idle()
+
 
 class LiveWaterfall(CallbackBase):
     """A live water plot for the one dimensional data."""
@@ -377,6 +385,8 @@ def event(self, doc):
         returned = super(MyTiffSerializer, self).event(doc)
         # go back to original data key
         self._file_prefix = _file_prefix
+        # TODO: submit the fix below to the 'suitcase-tiff' repo.
+        self.close()
         return returned
 
diff --git a/pdfstream/integration/tools.py b/pdfstream/integration/tools.py
index 7de97993..7a193164 100644
--- a/pdfstream/integration/tools.py
+++ b/pdfstream/integration/tools.py
@@ -6,7 +6,8 @@
 import numpy as np
 from matplotlib.axes import Axes
 from numpy import ndarray
-from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
+# from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
+from pyFAI.integrator.azimuthal import AzimuthalIntegrator
 
 from pdfstream.vend.masking import generate_binner, mask_img
 
diff --git a/pdfstream/io.py b/pdfstream/io.py
index fe79d5c3..c401387b 100644
--- a/pdfstream/io.py
+++ b/pdfstream/io.py
@@ -6,6 +6,7 @@
 import fabio
 import numpy as np
 import pyFAI
+from pyFAI.integrator.azimuthal import AzimuthalIntegrator
 import yaml
 from numpy import ndarray
 from tifffile import TiffWriter
@@ -14,15 +15,15 @@
 from pdfstream.vend.loaddata import load_data
 
 
-def load_ai_from_poni_file(poni_file: str) -> pyFAI.AzimuthalIntegrator:
+def load_ai_from_poni_file(poni_file: str) -> AzimuthalIntegrator:
     """Initiate the AzimuthalIntegrator using poni file."""
     ai = pyFAI.load(poni_file)
     return ai
 
 
-def load_ai_from_calib_result(calib_result: dict) -> pyFAI.AzimuthalIntegrator:
+def load_ai_from_calib_result(calib_result: dict) -> AzimuthalIntegrator:
     """Initiate the AzimuthalIntegrator using calibration information."""
-    ai = pyFAI.azimuthalIntegrator.AzimuthalIntegrator()
+    ai = AzimuthalIntegrator()
     # different from poni file, set_config only accepts dictionary of lowercase keys
     _calib_result = _lower_key(calib_result)
     # the pyFAI only accept strings so the None should be parsed to a string
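Note: the `AzimuthalIntegrator` import is updated in `analysis.py`, `tools.py`, and `io.py` because recent pyFAI releases moved the class to `pyFAI.integrator.azimuthal` and deprecated the old `pyFAI.azimuthalIntegrator` module. If older pyFAI versions still need to be supported, a guarded import is one option; a sketch, not part of this patch:

```python
try:
    # location in recent pyFAI releases
    from pyFAI.integrator.azimuthal import AzimuthalIntegrator
except ImportError:
    # fall back to the pre-move location in older pyFAI
    from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
```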
kafka_config["runengine_producer_config"], **kafka_dict} + + +class KafkaTopics(Enum): + raw = "xpd" + analysis = "xpd-analysis" + + +class BaseServerKafkaRaw(RemoteDispatcherKafka): + """The basic server class using Kafka message bus for consuming the raw data.""" + topic = KafkaTopics.raw.value + + def __init__(self, config: ServerConfig): + + kafka_dict = _get_kafka_consumer_config(topic=self.topic) + super().__init__(**kafka_dict) + self._config = config + self._kafka_dict = kafka_dict + + def start(self, *args, **kwargs): + try: + server_message( + "Server is started. " + + "Listen to {}, topics {}.".format(self._kafka_dict["bootstrap_servers"], self._kafka_dict["topics"]) + ) + super().start(*args, **kwargs) + except KeyboardInterrupt: + server_message("Server is terminated.") + + def install_qt_kicker(self): + pass + + +class BaseServerKafkaAnalysis(BaseServerKafkaRaw): + """The basic server class using Kafka message bus for consuming analysis data.""" + topic = KafkaTopics.analysis.value + + +class BaseServerKafkaViz(BaseServerKafkaAnalysis): + ... + + +class PublisherKafkaAnalysis(PublisherKafka): + def __call__(self, name, doc): + doc["topic"] = KafkaTopics.analysis.value + return super().__call__(name, doc) + + +from bluesky_widgets.qt.kafka_dispatcher import QtRemoteDispatcher + +class BaseServerKafkaVizQt(QtRemoteDispatcher): + """NOT WORKING YET!!! The basic server class using Kafka message bus for consuming analysis data for plotting.""" + topic = KafkaTopics.analysis.value + + def __init__(self, config: ServerConfig): + + kafka_dict = _get_kafka_consumer_config(topic=self.topic) + super().__init__(**kafka_dict) + self._config = config + self._kafka_dict = kafka_dict + + def start(self): + try: + server_message( + "Server is started. " + + "Listen to {}, topics {}.".format(self._kafka_dict["bootstrap_servers"], self._kafka_dict["topics"]) + ) + super().start() + except KeyboardInterrupt: + server_message("Server is terminated.") + + def install_qt_kicker(self): + pass + + class StartStopCallback(CallbackBase): """Print the time for analysis""" diff --git a/pdfstream/servers/xpd_server.py b/pdfstream/servers/xpd_server.py index 0b40aebb..0bcf4cde 100644 --- a/pdfstream/servers/xpd_server.py +++ b/pdfstream/servers/xpd_server.py @@ -1,16 +1,19 @@ """The analysis server. Process raw image to PDF.""" import typing as tp +import uuid -import databroker.core -from bluesky.callbacks.zmq import Publisher +import databroker.mongo_normalized +from bluesky.callbacks.zmq import Publisher as PublisherZMQ from databroker.v1 import Broker from event_model import RunRouter +from nslsii.kafka_utils import _read_bluesky_kafka_config_file + import pdfstream.io as io from pdfstream.callbacks.analysis import AnalysisConfig, VisConfig, ExportConfig, AnalysisStream, Exporter, \ Visualizer from pdfstream.callbacks.calibration import CalibrationConfig, Calibration -from pdfstream.servers.base import ServerConfig, BaseServer +from pdfstream.servers.base import ServerConfig, BaseServer as BaseServerZMQ, BaseServerKafkaRaw, BaseServerKafkaAnalysis, _get_kafka_producer_config, KafkaTopics, PublisherKafkaAnalysis class XPDConfig(CalibrationConfig, AnalysisConfig, VisConfig, ExportConfig): @@ -50,8 +53,11 @@ class XPDServerConfig(ServerConfig, XPDConfig): """The configuration for xpd server.""" pass +#BaseServerClass = BaseServerZMQ +BaseServerClass = BaseServerKafkaRaw + -class XPDServer(BaseServer): +class XPDServer(BaseServerClass): """The server of XPD data analysis. 
diff --git a/pdfstream/servers/xpd_server.py b/pdfstream/servers/xpd_server.py
index 0b40aebb..0bcf4cde 100644
--- a/pdfstream/servers/xpd_server.py
+++ b/pdfstream/servers/xpd_server.py
@@ -1,16 +1,19 @@
 """The analysis server. Process raw image to PDF."""
 import typing as tp
+import uuid
 
-import databroker.core
-from bluesky.callbacks.zmq import Publisher
+import databroker.mongo_normalized
+from bluesky.callbacks.zmq import Publisher as PublisherZMQ
 from databroker.v1 import Broker
 from event_model import RunRouter
 
+from nslsii.kafka_utils import _read_bluesky_kafka_config_file
+
 import pdfstream.io as io
 from pdfstream.callbacks.analysis import AnalysisConfig, VisConfig, ExportConfig, AnalysisStream, Exporter, \
     Visualizer
 from pdfstream.callbacks.calibration import CalibrationConfig, Calibration
-from pdfstream.servers.base import ServerConfig, BaseServer
+from pdfstream.servers.base import ServerConfig, BaseServer as BaseServerZMQ, BaseServerKafkaRaw, BaseServerKafkaAnalysis, _get_kafka_producer_config, KafkaTopics, PublisherKafkaAnalysis
 
 
 class XPDConfig(CalibrationConfig, AnalysisConfig, VisConfig, ExportConfig):
@@ -50,8 +53,11 @@ class XPDServerConfig(ServerConfig, XPDConfig):
     """The configuration for xpd server."""
     pass
 
+# BaseServerClass = BaseServerZMQ
+BaseServerClass = BaseServerKafkaRaw
+
 
-class XPDServer(BaseServer):
+class XPDServer(BaseServerClass):
     """The server of XPD data analysis.
 
     It is a live dispatcher with XPDRouter subscribed."""
 
     def __init__(self, config: XPDServerConfig):
         super(XPDServer, self).__init__(config)
@@ -97,9 +103,9 @@ class XPDRouter(RunRouter):
 
     def __init__(self, config: XPDConfig):
         factory = XPDFactory(config)
-        super(XPDRouter, self).__init__(
+        super().__init__(
             [factory],
-            handler_registry=databroker.core.discover_handlers()
+            handler_registry=databroker.mongo_normalized.discover_handlers()
         )
 
@@ -125,9 +131,17 @@ def __init__(self, config: XPDConfig):
                 pub_config["address"][0], pub_config["address"][1], pub_config["prefix"]
             )
         )
-        self.analysis[0].subscribe(Publisher(**pub_config))
+
+        # self.analysis[0].subscribe(PublisherZMQ(**pub_config))
+        # if self.calibration:
+        #     self.calibration[0].subscribe(PublisherZMQ(**pub_config))
+
+        # Kafka configuration for Producer:
+        kafka_pub_config = _get_kafka_producer_config(KafkaTopics.analysis.value)
+
+        self.analysis[0].subscribe(PublisherKafkaAnalysis(**kafka_pub_config))
         if self.calibration:
-            self.calibration[0].subscribe(Publisher(**pub_config))
+            self.calibration[0].subscribe(PublisherKafkaAnalysis(**kafka_pub_config))
 
     def __call__(self, name: str, doc: dict) -> tp.Tuple[list, list]:
         if name == "start":
diff --git a/pdfstream/servers/xpdsave_server.py b/pdfstream/servers/xpdsave_server.py
index 13853fac..42deea4f 100644
--- a/pdfstream/servers/xpdsave_server.py
+++ b/pdfstream/servers/xpdsave_server.py
@@ -1,5 +1,5 @@
 from pdfstream.callbacks.analysis import ExportConfig, Exporter
-from pdfstream.servers.base import BaseServer, ServerConfig
+from pdfstream.servers.base import BaseServer as BaseServerZMQ, ServerConfig, BaseServerKafkaAnalysis
 
 
 class XPDSaveServerConfig(ServerConfig, ExportConfig):
@@ -7,7 +7,10 @@ class XPDSaveServerConfig(ServerConfig, ExportConfig):
     pass
 
 
-class XPDSaveServer(BaseServer):
+# BaseServerClass = BaseServerZMQ
+BaseServerClass = BaseServerKafkaAnalysis
+
+class XPDSaveServer(BaseServerClass):
     """A server that saves the analyzed data from the xpd server."""
 
     def __init__(self, config: XPDSaveServerConfig):
diff --git a/pdfstream/servers/xpdvis_server.py b/pdfstream/servers/xpdvis_server.py
index 206dcca3..b0a3da90 100644
--- a/pdfstream/servers/xpdvis_server.py
+++ b/pdfstream/servers/xpdvis_server.py
@@ -1,7 +1,8 @@
 from bluesky.callbacks.best_effort import BestEffortCallback
 
 from pdfstream.callbacks.analysis import Visualizer, VisConfig
-from pdfstream.servers.base import BaseServer, ServerConfig
+from pdfstream.servers.base import BaseServer as BaseServerZMQ, BaseServerKafkaViz, BaseServerKafkaAnalysis, ServerConfig
+import matplotlib.pyplot as plt
 
 
 class XPDVisServerConfig(ServerConfig, VisConfig):
@@ -9,7 +10,12 @@ class XPDVisServerConfig(ServerConfig, VisConfig):
     pass
 
 
-class XPDVisServer(BaseServer):
+# BaseServerClass = BaseServerZMQ
+# BaseServerClass = BaseServerKafkaViz
+BaseServerClass = BaseServerKafkaAnalysis
+
+
+class XPDVisServer(BaseServerClass):
     """A server that visualizes the analyzed data from the xpd server."""
 
     def __init__(self, config: XPDVisServerConfig):
@@ -49,5 +55,16 @@ def make_and_run(
     config.read(cfg_file)
     server = XPDVisServer(config)
     if not test_mode:
+        # ZMQ:
         server.install_qt_kicker()
-        server.start()
+        kwargs = {}
+
+        # Kafka:
+        def polling_func():
+            plt.gcf().canvas.draw_idle()
+            plt.gcf().canvas.start_event_loop(0.05)
+
+        kwargs = {"work_during_wait": polling_func}
+
+        # Applies to both servers:
+        server.start(**kwargs)
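Note on `polling_func` above: unlike the ZMQ dispatcher, which runs on an asyncio loop that a Qt kicker can hook into, `bluesky_kafka.RemoteDispatcher.start` blocks in its own polling loop, so the figure windows are kept responsive by passing a `work_during_wait` callable that periodically flushes Matplotlib's GUI event loop. The generic shape of the pattern, assuming a Qt-based Matplotlib backend:

```python
import matplotlib.pyplot as plt

def keep_gui_alive():
    # flush pending draw/GUI events between Kafka polls (~50 ms budget)
    plt.gcf().canvas.draw_idle()
    plt.gcf().canvas.start_event_loop(0.05)

# dispatcher.start(work_during_wait=keep_gui_alive)
```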
diff --git a/pdfstream/vend/masking.py b/pdfstream/vend/masking.py
index 059dd473..02a28d44 100644
--- a/pdfstream/vend/masking.py
+++ b/pdfstream/vend/masking.py
@@ -122,7 +122,7 @@ def mask_img(
         mask_method=auto_type,
         pool=pool,
     )
-    working_mask = working_mask.astype(np.bool)
+    working_mask = working_mask.astype(np.bool_)
     return working_mask
 
diff --git a/pdfstream/vend/qt_kicker.py b/pdfstream/vend/qt_kicker.py
index 0e2dfc6e..f217ab95 100644
--- a/pdfstream/vend/qt_kicker.py
+++ b/pdfstream/vend/qt_kicker.py
@@ -25,12 +25,10 @@ def install_qt_kicker(loop=None, update_rate=0.03):
         return
     if not any(p in sys.modules for p in ['PyQt4', 'pyside', 'PyQt5']):
         return
-    import matplotlib.backends.backend_qt5
     from matplotlib.backends.backend_qt5 import _create_qApp
     from matplotlib._pylab_helpers import Gcf
 
-    _create_qApp()
-    qApp = matplotlib.backends.backend_qt5.qApp
+    qApp = _create_qApp()
 
     try:
         _draw_all = Gcf.draw_all  # mpl version >= 1.5
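Note: the `np.bool` and `np.float` replacements above track NumPy's removal of the deprecated scalar aliases (deprecated in NumPy 1.20, removed in 1.24); `np.bool_` and `np.float64` preserve the previous dtypes. Likewise, the `qt_kicker` change follows newer Matplotlib, where `_create_qApp()` returns the `QApplication` and the module-level `qApp` global is no longer available. A minimal illustration of the NumPy side:

```python
import numpy as np

mask = np.ones(4).astype(np.bool_)  # was: np.bool, removed in NumPy 1.24
zero = np.float64(0.0)              # was: np.float, removed in NumPy 1.24
```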