diff --git a/gui/control_main.py b/gui/control_main.py index b5c1c408..46ade7c7 100644 --- a/gui/control_main.py +++ b/gui/control_main.py @@ -62,7 +62,8 @@ SnapCommentDialog, StaffScreenDialog, UserScreenDialog, - CalculatorWindow + CalculatorWindow, + MultiColDialog, ) from gui.widgets.log_widget import get_summary_widget from gui.raster import RasterCell, RasterGroup @@ -794,14 +795,10 @@ def createSampleTab(self): ) # something for criteria to decide on which hotspots to collect on for multi-xtal self.hBoxMultiColParamsLayout1 = QtWidgets.QHBoxLayout() self.hBoxMultiColParamsLayout1.setAlignment(QtCore.Qt.AlignLeft) - multiColCutoffLabel = QtWidgets.QLabel("Diffraction Cutoff") - multiColCutoffLabel.setFixedWidth(110) - self.multiColCutoffEdit = QtWidgets.QLineEdit( - "320" - ) # may need to store this in DB at some point, it's a silly number for now - self.multiColCutoffEdit.setFixedWidth(60) - self.hBoxMultiColParamsLayout1.addWidget(multiColCutoffLabel) - self.hBoxMultiColParamsLayout1.addWidget(self.multiColCutoffEdit) + self.multi_col_button = QtWidgets.QPushButton("Select Multicol Centers") + self.multi_col_button.clicked.connect(self.add_multicol) + self.hBoxMultiColParamsLayout1.addWidget(self.multi_col_button) + self.multiColParamsFrame.setLayout(self.hBoxMultiColParamsLayout1) self.characterizeParamsFrame = QFrame() vBoxCharacterizeParams1 = QtWidgets.QVBoxLayout() @@ -1264,6 +1261,7 @@ def createSampleTab(self): focusMinusButton.clicked.connect(functools.partial(self.focusTweakCB, -5)) annealButton = QtWidgets.QPushButton("Anneal") annealButton.clicked.connect(self.annealButtonCB) + annealButton.hide() annealTimeLabel = QtWidgets.QLabel("Time") self.annealTime_ledit = QtWidgets.QLineEdit() self.annealTime_ledit.setFixedWidth(40) @@ -2478,6 +2476,7 @@ def updateVectorLengthAndSpeed(self): ) self.vecLenLabelOutput.setText(str(int(vector_length))) self.vecSpeedLabelOutput.setText(str(int(vector_speed))) + self.calcLifetimeCB() return x_vec, y_vec, z_vec, vector_length def totalExpChanged(self, text): @@ -2885,6 +2884,19 @@ def moveEnergyMaxDeltaCB(self, max_delta=10.0): else: self.popupServerMessage("You don't have control") + def add_multicol(self): + if self.selectedSampleRequest and self.selectedSampleRequest.get("request_type") == "raster": + raster_results = db_lib.getResultsforRequest(self.selectedSampleRequest["uid"]) + + for result in raster_results: + if result["result_type"] == 'rasterResult': + raster_result = result + break + else: + return + multicol_dialog = MultiColDialog(parent=self, raster_req=self.selectedSampleRequest, raster_result=raster_result) + multicol_dialog.show() + def moveEnergyCB(self): if self.controlEnabled(): set_energy = SetEnergyDialog(parent=self) diff --git a/gui/dialog/__init__.py b/gui/dialog/__init__.py index 08dc984c..7088df54 100644 --- a/gui/dialog/__init__.py +++ b/gui/dialog/__init__.py @@ -12,3 +12,4 @@ from .screen_defaults import ScreenDefaultsDialog from .set_energy import SetEnergyDialog from .resolution_dialog import CalculatorWindow +from .multicol import MultiColDialog diff --git a/gui/dialog/multicol.py b/gui/dialog/multicol.py new file mode 100644 index 00000000..2b81c357 --- /dev/null +++ b/gui/dialog/multicol.py @@ -0,0 +1,114 @@ +from datetime import date +import logging +import typing + +from qtpy import QtCore, QtGui, QtWidgets +from qtpy.QtCore import Qt + +if typing.TYPE_CHECKING: + from lsdcGui import ControlMain + +from gui.widgets.heatmap_widget import HeatmapWidget +from utils.raster import determine_raster_shape, create_snake_array, peakfind_maxburn, calculate_flattened_index, get_score_vals +import db_lib +import daq_utils + +logger = logging.getLogger() + +class MultiColDialog(QtWidgets.QDialog): + + def __init__(self, parent: "ControlMain", raster_req: dict, raster_result: dict): + # Pass in the raster request and result for the widget to run a cell selection algorithm + super().__init__(parent) + self._parent = parent + self.raster_req = raster_req + self.raster_result = raster_result + raster_def = raster_req["request_obj"]["rasterDef"] + self.cell_results = raster_result["result_obj"]["rasterCellResults"]['resultObj'] + self.raster_map = raster_result["result_obj"]["rasterCellMap"] + score_vals = get_score_vals(self.cell_results, "spot_count_no_ice") + self.direction, self.M, self.N = determine_raster_shape(raster_def) + self.raster_array = create_snake_array(score_vals, self.direction, self.M, self.N) + self.initUI(self.raster_array, self.cell_results) + + def initUI(self, data, cell_results): + layout = QtWidgets.QGridLayout() + self.heatmap_widget = HeatmapWidget(self._parent, data=data, cell_results=cell_results) + threshold_label = QtWidgets.QLabel("Number of centers:") + validator = QtGui.QIntValidator() + self.threshold_edit = QtWidgets.QLineEdit("10") + self.threshold_edit.setValidator(validator) + self.calculate_centers_button = QtWidgets.QPushButton("Calculate centers") + self.calculate_centers_button.clicked.connect(self.calculate_centers) + self.clear_centers_button = QtWidgets.QPushButton("Clear Centers") + self.clear_centers_button.clicked.connect(self.heatmap_widget.clear_highlights) + + wedge_label = QtWidgets.QLabel("Wedge:") + self.wedge_edit = QtWidgets.QLineEdit() + self.wedge_edit.setValidator(QtGui.QDoubleValidator()) + + self.submit_centers_button = QtWidgets.QPushButton("Submit Centers") + self.submit_centers_button.clicked.connect(self.submit_centers) + self.cancel_button = QtWidgets.QPushButton("Cancel") + self.cancel_button.clicked.connect(self.close) + layout.addWidget(self.heatmap_widget, 0, 0, 1, 4) + layout.addWidget(threshold_label, 1, 0) + layout.addWidget(self.threshold_edit, 1, 1) + layout.addWidget(self.calculate_centers_button, 1, 2) + layout.addWidget(self.clear_centers_button, 1, 3, 1, 1) + + layout.addWidget(wedge_label, 2, 0, 1, 1) + + layout.addWidget(self.submit_centers_button, 3, 0, 1, 1) + layout.addWidget(self.cancel_button, 3, 3, 1, 1) + self.setLayout(layout) + + def calculate_centers(self): + self.heatmap_widget.clear_highlights() + indices, array = peakfind_maxburn(self.raster_array, int(self.threshold_edit.text())) + self.heatmap_widget.highlight_cells(indices) + + + def submit_centers(self): + indices = self.heatmap_widget.highlighted_patches.keys() + for (x, y) in indices: + flattened_index = calculate_flattened_index(x, y, self.M, self.N, self.direction) + hitFile = self.cell_results[flattened_index]["cellMapKey"] + hitCoords = self.raster_map[hitFile] + self.addMultiRequestLocation(self.raster_result["request"], hitCoords, flattened_index, float(self._parent.osc_end_ledit.text())) + self._parent.treeChanged_pv.put(1) + self.accept() + + + def addMultiRequestLocation(self, parentReqID, hitCoords, locIndex, wedge=10.0): + parentRequest = db_lib.getRequestByID(parentReqID) + sampleID = parentRequest["sample"] + + logger.info(str(sampleID)) + logger.info(hitCoords) + dataDirectory = parentRequest["request_obj"]['directory']+"multi_"+str(locIndex) + runNum = parentRequest["request_obj"]['runNum'] + tempnewStratRequest = daq_utils.createDefaultRequest(sampleID) + ss = parentRequest["request_obj"]["rasterDef"]["omega"] + if "wedge" in parentRequest["request_obj"]: + wedge = float(parentRequest["request_obj"]["wedge"]) + + newReqObj = tempnewStratRequest["request_obj"] + newReqObj["sweep_start"] = ss - wedge/2 + newReqObj["sweep_end"] = ss + wedge/2 + newReqObj["img_width"] = float(self._parent.osc_range_ledit.text()) + newReqObj["exposure_time"] = float(self._parent.exp_time_ledit.text()) + newReqObj["detDist"] = float(self._parent.detDistMotorEntry.getEntry().text()) + newReqObj["directory"] = dataDirectory + newReqObj["pos_x"] = hitCoords['x'] + newReqObj["pos_y"] = hitCoords['y'] + newReqObj["pos_z"] = hitCoords['z'] + newReqObj["fastDP"] = True + newReqObj["fastEP"] = False + newReqObj["dimple"] = False + newReqObj["xia2"] = False + newReqObj["runNum"] = runNum + newReqObj["parentReqID"] = parentReqID + newReqObj["energy"] = self._parent.energy_pv.get() + newReqObj["wavelength"] = daq_utils.energy2wave(newReqObj["energy"]) + db_lib.addRequesttoSample(sampleID,newReqObj["protocol"],daq_utils.owner,newReqObj,priority=6000,proposalID=daq_utils.getProposalID()) diff --git a/gui/widgets/heatmap_widget.py b/gui/widgets/heatmap_widget.py new file mode 100644 index 00000000..21bc6268 --- /dev/null +++ b/gui/widgets/heatmap_widget.py @@ -0,0 +1,178 @@ +from qtpy.QtWidgets import ( + QSizePolicy, +) +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from matplotlib.patches import Rectangle +from mpl_toolkits.axes_grid1 import make_axes_locatable +from utils.raster import calculate_flattened_index, determine_raster_shape, get_score_vals +import numpy as np +from typing import Tuple, TYPE_CHECKING +if TYPE_CHECKING: + from lsdcGui import ControlMain + +class MplCanvas(FigureCanvas): + def __init__(self, parent=None, width=5, height=4, dpi=100): + fig = Figure(figsize=(width, height), dpi=dpi) + self.axes = fig.add_subplot(111) + super(MplCanvas, self).__init__(fig) + self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.updateGeometry() + self.colorbar = None + + +class HeatmapWidget(MplCanvas): + def __init__(self, parent: "ControlMain", width=5, height=4, dpi=100, data=None, cell_results=None): + super().__init__(parent, width, height, dpi) + self.mpl_connect("button_press_event", self.on_click) + self.mpl_connect("motion_notify_event", self.on_hover) + self._parent = parent + self.data = data + self.cell_results = cell_results + self.highlighted_patches = {} + self.render_heatmap() + + def render_heatmap(self): + if self.data is None: + return + + # Create the heatmap + cax = self.axes.imshow(self.data, cmap="inferno", origin="upper") + + y_ticks = np.arange(self.data.shape[0]) + x_ticks = np.arange(self.data.shape[1]) + + # Label offset: adjust ticks to be between points + self.axes.set_xticks(x_ticks) + self.axes.set_yticks(y_ticks) + + self.axes.set_xlim([x_ticks[0] - 0.5, x_ticks[-1] + 0.5]) + self.axes.set_ylim([y_ticks[-1] + 0.5, y_ticks[0] - 0.5]) + + divider = make_axes_locatable(self.axes) + cax_cb = divider.append_axes( + "right", size="5%", pad=0.05 + ) # Adjust size and padding as needed + self.colorbar = self.figure.colorbar(cax, cax=cax_cb) + + # Create a text annotation for the tooltip, initially hidden + self.tooltip = self.axes.text( + -2.5, + 0.95, + "", + color="white", + backgroundcolor="black", + ha="center", + va="center", + fontsize=10, + bbox=dict(facecolor="black", alpha=0.8), + ) + self.tooltip.set_visible(False) + + # Create a rectangle for highlighting cells, initially hidden + self.highlight = Rectangle( + (0, 0), 1, 1, linewidth=2, edgecolor="white", facecolor="none" + ) + self.axes.add_patch(self.highlight) + self.highlight.set_visible(False) + + # Redraw the canvas + self.draw() + self.figure.tight_layout() + + def on_click(self, event): + if self.data is None or event.inaxes is None: + return + x, y = int(round(event.xdata)), int(round(event.ydata)) + col, row = int(np.floor(x)), int(np.floor(y)) + if 0 <= row < self.data.shape[0] and 0 <= col < self.data.shape[1]: + value = self.data[row, col] + if event.dblclick: + self.highlight_patch(row, col) + else: + self.show_diffraction_image(row, col) + + def show_diffraction_image(self, row, col): + if not self.cell_results: + return + flattened_index = calculate_flattened_index(row, col, self.data.shape[0], self.data.shape[1]) + cell_filename = self.cell_results[flattened_index].get("image") + if not cell_filename: + return + self._parent.albulaInterface.open_file(cell_filename) + + + def highlight_patch(self, row, col): + if (row, col) in self.highlighted_patches: + patch = self.highlighted_patches.pop((row, col)) + patch.remove() + else: + rect = Rectangle( + (col - 0.5, row - 0.5), + 1, + 1, + linewidth=2, + edgecolor="green", + facecolor="none" + ) + self.axes.add_patch(rect) + self.highlighted_patches[(row, col)] = rect + self.draw_idle() # Redraw the canvas + + + def on_hover(self, event): + """Display a tooltip with the intensity value at the mouse position.""" + if self.data is None: + return + + matrix = self.data + # Check if the mouse is over the axes + if event.inaxes == self.axes: + # Get the row and column indices + #x, y = event.xdata, event.ydata + x, y = int(round(event.xdata)), int(round(event.ydata)) + col, row = int(np.floor(x)), int(np.floor(y)) + + # Check if the indices are within the bounds of the matrix + if 0 <= row < matrix.shape[0] and 0 <= col < matrix.shape[1]: + # Get the intensity of the current cell + intensity = matrix[row, col] + + # Update the position and text of the tooltip + # self.tooltip.set_position((col+5, row+5)) + self.tooltip.set_text(f"({row}, {col})\nSpot Count: {intensity:.2f}") + self.tooltip.set_visible(True) + + # Update the position of the highlight rectangle + self.highlight.set_bounds(col - 0.5, row - 0.5, 1, 1) + self.highlight.set_visible(True) + else: + # Hide the tooltip and highlight if outside the matrix + self.tooltip.set_visible(False) + self.highlight.set_visible(False) + else: + # Hide the tooltip and highlight if the mouse is outside the axes + self.tooltip.set_visible(False) + self.highlight.set_visible(False) + + self.draw_idle() # Redraw the canvas + + def highlight_cells(self, indices: "list[Tuple[float, float]]"): + for (i, j) in indices: + rect = Rectangle( + (j - 0.5, i - 0.5), + 1, + 1, + linewidth=2, + edgecolor="green", + facecolor="none" + ) + self.axes.add_patch(rect) + self.highlighted_patches[(i, j)] = rect + self.draw_idle() + + def clear_highlights(self): + for key in list(self.highlighted_patches): + patch = self.highlighted_patches.pop(key) + patch.remove() + self.draw_idle() diff --git a/utils/raster.py b/utils/raster.py index 97a6fd5f..835f2627 100644 --- a/utils/raster.py +++ b/utils/raster.py @@ -1,4 +1,20 @@ import numpy as np +import logging +import db_lib, daq_utils + +logger = logging.getLogger() + +def get_score_vals(cellResults, scoreOption): + """ + Returns a numpy 1d-array that stores the selected scores as a flattened array + """ + score_vals = np.zeros(len(cellResults)) + for i, res in enumerate(cellResults): + try: + score_vals[i] = float(res[scoreOption]) + except TypeError: + logger.debug(f"Option {scoreOption} not found for {res}") + return score_vals def calculate_matrix_index(k, M, N, pattern="horizontal"): @@ -108,3 +124,25 @@ def get_flattened_indices_of_max_col(raster_def, max_col): ) return indices + +def peakfind_maxburn(array, num_iter): + ''' + Collection center finding for multiCol protocol + + Input 2D array, find max element, store max index in list, + then set max element and its 8 neighbors to zero (a.k.a. burn this spot) + + Repeat until all elements set to zero, or maximum number of iterations reached + + Returns center list, and "burnt" array + ''' + arr_work = array.copy() + indices = [] + iterate = 1 + while arr_work.max() != 0 and iterate <= num_iter: + iterate = iterate + 1 + i, j = np.unravel_index(np.argmax(arr_work), arr_work.shape) + indices.append((i, j)) + arr_work[max(i-1, 0):i+2, max(j-1, 0):j+2] = 0 + return indices, arr_work +