Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions gui/control_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
SnapCommentDialog,
StaffScreenDialog,
UserScreenDialog,
CalculatorWindow
CalculatorWindow,
MultiColDialog,
)
from gui.widgets.log_widget import get_summary_widget
from gui.raster import RasterCell, RasterGroup
Expand Down Expand Up @@ -794,14 +795,10 @@ def createSampleTab(self):
) # something for criteria to decide on which hotspots to collect on for multi-xtal
self.hBoxMultiColParamsLayout1 = QtWidgets.QHBoxLayout()
self.hBoxMultiColParamsLayout1.setAlignment(QtCore.Qt.AlignLeft)
multiColCutoffLabel = QtWidgets.QLabel("Diffraction Cutoff")
multiColCutoffLabel.setFixedWidth(110)
self.multiColCutoffEdit = QtWidgets.QLineEdit(
"320"
) # may need to store this in DB at some point, it's a silly number for now
self.multiColCutoffEdit.setFixedWidth(60)
self.hBoxMultiColParamsLayout1.addWidget(multiColCutoffLabel)
self.hBoxMultiColParamsLayout1.addWidget(self.multiColCutoffEdit)
self.multi_col_button = QtWidgets.QPushButton("Select Multicol Centers")
self.multi_col_button.clicked.connect(self.add_multicol)
self.hBoxMultiColParamsLayout1.addWidget(self.multi_col_button)

self.multiColParamsFrame.setLayout(self.hBoxMultiColParamsLayout1)
self.characterizeParamsFrame = QFrame()
vBoxCharacterizeParams1 = QtWidgets.QVBoxLayout()
Expand Down Expand Up @@ -1264,6 +1261,7 @@ def createSampleTab(self):
focusMinusButton.clicked.connect(functools.partial(self.focusTweakCB, -5))
annealButton = QtWidgets.QPushButton("Anneal")
annealButton.clicked.connect(self.annealButtonCB)
annealButton.hide()
annealTimeLabel = QtWidgets.QLabel("Time")
self.annealTime_ledit = QtWidgets.QLineEdit()
self.annealTime_ledit.setFixedWidth(40)
Expand Down Expand Up @@ -2478,6 +2476,7 @@ def updateVectorLengthAndSpeed(self):
)
self.vecLenLabelOutput.setText(str(int(vector_length)))
self.vecSpeedLabelOutput.setText(str(int(vector_speed)))
self.calcLifetimeCB()
return x_vec, y_vec, z_vec, vector_length

def totalExpChanged(self, text):
Expand Down Expand Up @@ -2885,6 +2884,19 @@ def moveEnergyMaxDeltaCB(self, max_delta=10.0):
else:
self.popupServerMessage("You don't have control")

def add_multicol(self):
if self.selectedSampleRequest and self.selectedSampleRequest.get("request_type") == "raster":
raster_results = db_lib.getResultsforRequest(self.selectedSampleRequest["uid"])

for result in raster_results:
if result["result_type"] == 'rasterResult':
raster_result = result
break
else:
return
multicol_dialog = MultiColDialog(parent=self, raster_req=self.selectedSampleRequest, raster_result=raster_result)
multicol_dialog.show()

def moveEnergyCB(self):
if self.controlEnabled():
set_energy = SetEnergyDialog(parent=self)
Expand Down
1 change: 1 addition & 0 deletions gui/dialog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .screen_defaults import ScreenDefaultsDialog
from .set_energy import SetEnergyDialog
from .resolution_dialog import CalculatorWindow
from .multicol import MultiColDialog
114 changes: 114 additions & 0 deletions gui/dialog/multicol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from datetime import date
import logging
import typing

from qtpy import QtCore, QtGui, QtWidgets
from qtpy.QtCore import Qt

if typing.TYPE_CHECKING:
from lsdcGui import ControlMain

from gui.widgets.heatmap_widget import HeatmapWidget
from utils.raster import determine_raster_shape, create_snake_array, peakfind_maxburn, calculate_flattened_index, get_score_vals
import db_lib
import daq_utils

logger = logging.getLogger()

class MultiColDialog(QtWidgets.QDialog):

def __init__(self, parent: "ControlMain", raster_req: dict, raster_result: dict):
# Pass in the raster request and result for the widget to run a cell selection algorithm
super().__init__(parent)
self._parent = parent
self.raster_req = raster_req
self.raster_result = raster_result
raster_def = raster_req["request_obj"]["rasterDef"]
self.cell_results = raster_result["result_obj"]["rasterCellResults"]['resultObj']
self.raster_map = raster_result["result_obj"]["rasterCellMap"]
score_vals = get_score_vals(self.cell_results, "spot_count_no_ice")
self.direction, self.M, self.N = determine_raster_shape(raster_def)
self.raster_array = create_snake_array(score_vals, self.direction, self.M, self.N)
self.initUI(self.raster_array, self.cell_results)

def initUI(self, data, cell_results):
layout = QtWidgets.QGridLayout()
self.heatmap_widget = HeatmapWidget(self._parent, data=data, cell_results=cell_results)
threshold_label = QtWidgets.QLabel("Number of centers:")
validator = QtGui.QIntValidator()
self.threshold_edit = QtWidgets.QLineEdit("10")
self.threshold_edit.setValidator(validator)
self.calculate_centers_button = QtWidgets.QPushButton("Calculate centers")
self.calculate_centers_button.clicked.connect(self.calculate_centers)
self.clear_centers_button = QtWidgets.QPushButton("Clear Centers")
self.clear_centers_button.clicked.connect(self.heatmap_widget.clear_highlights)

wedge_label = QtWidgets.QLabel("Wedge:")
self.wedge_edit = QtWidgets.QLineEdit()
self.wedge_edit.setValidator(QtGui.QDoubleValidator())

self.submit_centers_button = QtWidgets.QPushButton("Submit Centers")
self.submit_centers_button.clicked.connect(self.submit_centers)
self.cancel_button = QtWidgets.QPushButton("Cancel")
self.cancel_button.clicked.connect(self.close)
layout.addWidget(self.heatmap_widget, 0, 0, 1, 4)
layout.addWidget(threshold_label, 1, 0)
layout.addWidget(self.threshold_edit, 1, 1)
layout.addWidget(self.calculate_centers_button, 1, 2)
layout.addWidget(self.clear_centers_button, 1, 3, 1, 1)

layout.addWidget(wedge_label, 2, 0, 1, 1)

layout.addWidget(self.submit_centers_button, 3, 0, 1, 1)
layout.addWidget(self.cancel_button, 3, 3, 1, 1)
self.setLayout(layout)

def calculate_centers(self):
self.heatmap_widget.clear_highlights()
indices, array = peakfind_maxburn(self.raster_array, int(self.threshold_edit.text()))
self.heatmap_widget.highlight_cells(indices)


def submit_centers(self):
indices = self.heatmap_widget.highlighted_patches.keys()
for (x, y) in indices:
flattened_index = calculate_flattened_index(x, y, self.M, self.N, self.direction)
hitFile = self.cell_results[flattened_index]["cellMapKey"]
hitCoords = self.raster_map[hitFile]
self.addMultiRequestLocation(self.raster_result["request"], hitCoords, flattened_index, float(self._parent.osc_end_ledit.text()))
self._parent.treeChanged_pv.put(1)
self.accept()


def addMultiRequestLocation(self, parentReqID, hitCoords, locIndex, wedge=10.0):
parentRequest = db_lib.getRequestByID(parentReqID)
sampleID = parentRequest["sample"]

logger.info(str(sampleID))
logger.info(hitCoords)
dataDirectory = parentRequest["request_obj"]['directory']+"multi_"+str(locIndex)
runNum = parentRequest["request_obj"]['runNum']
tempnewStratRequest = daq_utils.createDefaultRequest(sampleID)
ss = parentRequest["request_obj"]["rasterDef"]["omega"]
if "wedge" in parentRequest["request_obj"]:
wedge = float(parentRequest["request_obj"]["wedge"])

newReqObj = tempnewStratRequest["request_obj"]
newReqObj["sweep_start"] = ss - wedge/2
newReqObj["sweep_end"] = ss + wedge/2
newReqObj["img_width"] = float(self._parent.osc_range_ledit.text())
newReqObj["exposure_time"] = float(self._parent.exp_time_ledit.text())
newReqObj["detDist"] = float(self._parent.detDistMotorEntry.getEntry().text())
newReqObj["directory"] = dataDirectory
newReqObj["pos_x"] = hitCoords['x']
newReqObj["pos_y"] = hitCoords['y']
newReqObj["pos_z"] = hitCoords['z']
newReqObj["fastDP"] = True
newReqObj["fastEP"] = False
newReqObj["dimple"] = False
newReqObj["xia2"] = False
newReqObj["runNum"] = runNum
newReqObj["parentReqID"] = parentReqID
newReqObj["energy"] = self._parent.energy_pv.get()
newReqObj["wavelength"] = daq_utils.energy2wave(newReqObj["energy"])
db_lib.addRequesttoSample(sampleID,newReqObj["protocol"],daq_utils.owner,newReqObj,priority=6000,proposalID=daq_utils.getProposalID())
178 changes: 178 additions & 0 deletions gui/widgets/heatmap_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from qtpy.QtWidgets import (
QSizePolicy,
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import make_axes_locatable
from utils.raster import calculate_flattened_index, determine_raster_shape, get_score_vals
import numpy as np
from typing import Tuple, TYPE_CHECKING
if TYPE_CHECKING:
from lsdcGui import ControlMain

class MplCanvas(FigureCanvas):
def __init__(self, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
self.axes = fig.add_subplot(111)
super(MplCanvas, self).__init__(fig)
self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.updateGeometry()
self.colorbar = None


class HeatmapWidget(MplCanvas):
def __init__(self, parent: "ControlMain", width=5, height=4, dpi=100, data=None, cell_results=None):
super().__init__(parent, width, height, dpi)
self.mpl_connect("button_press_event", self.on_click)
self.mpl_connect("motion_notify_event", self.on_hover)
self._parent = parent
self.data = data
self.cell_results = cell_results
self.highlighted_patches = {}
self.render_heatmap()

def render_heatmap(self):
if self.data is None:
return

# Create the heatmap
cax = self.axes.imshow(self.data, cmap="inferno", origin="upper")

y_ticks = np.arange(self.data.shape[0])
x_ticks = np.arange(self.data.shape[1])

# Label offset: adjust ticks to be between points
self.axes.set_xticks(x_ticks)
self.axes.set_yticks(y_ticks)

self.axes.set_xlim([x_ticks[0] - 0.5, x_ticks[-1] + 0.5])
self.axes.set_ylim([y_ticks[-1] + 0.5, y_ticks[0] - 0.5])

divider = make_axes_locatable(self.axes)
cax_cb = divider.append_axes(
"right", size="5%", pad=0.05
) # Adjust size and padding as needed
self.colorbar = self.figure.colorbar(cax, cax=cax_cb)

# Create a text annotation for the tooltip, initially hidden
self.tooltip = self.axes.text(
-2.5,
0.95,
"",
color="white",
backgroundcolor="black",
ha="center",
va="center",
fontsize=10,
bbox=dict(facecolor="black", alpha=0.8),
)
self.tooltip.set_visible(False)

# Create a rectangle for highlighting cells, initially hidden
self.highlight = Rectangle(
(0, 0), 1, 1, linewidth=2, edgecolor="white", facecolor="none"
)
self.axes.add_patch(self.highlight)
self.highlight.set_visible(False)

# Redraw the canvas
self.draw()
self.figure.tight_layout()

def on_click(self, event):
if self.data is None or event.inaxes is None:
return
x, y = int(round(event.xdata)), int(round(event.ydata))
col, row = int(np.floor(x)), int(np.floor(y))
if 0 <= row < self.data.shape[0] and 0 <= col < self.data.shape[1]:
value = self.data[row, col]
if event.dblclick:
self.highlight_patch(row, col)
else:
self.show_diffraction_image(row, col)

def show_diffraction_image(self, row, col):
if not self.cell_results:
return
flattened_index = calculate_flattened_index(row, col, self.data.shape[0], self.data.shape[1])
cell_filename = self.cell_results[flattened_index].get("image")
if not cell_filename:
return
self._parent.albulaInterface.open_file(cell_filename)


def highlight_patch(self, row, col):
if (row, col) in self.highlighted_patches:
patch = self.highlighted_patches.pop((row, col))
patch.remove()
else:
rect = Rectangle(
(col - 0.5, row - 0.5),
1,
1,
linewidth=2,
edgecolor="green",
facecolor="none"
)
self.axes.add_patch(rect)
self.highlighted_patches[(row, col)] = rect
self.draw_idle() # Redraw the canvas


def on_hover(self, event):
"""Display a tooltip with the intensity value at the mouse position."""
if self.data is None:
return

matrix = self.data
# Check if the mouse is over the axes
if event.inaxes == self.axes:
# Get the row and column indices
#x, y = event.xdata, event.ydata
x, y = int(round(event.xdata)), int(round(event.ydata))
col, row = int(np.floor(x)), int(np.floor(y))

# Check if the indices are within the bounds of the matrix
if 0 <= row < matrix.shape[0] and 0 <= col < matrix.shape[1]:
# Get the intensity of the current cell
intensity = matrix[row, col]

# Update the position and text of the tooltip
# self.tooltip.set_position((col+5, row+5))
self.tooltip.set_text(f"({row}, {col})\nSpot Count: {intensity:.2f}")
self.tooltip.set_visible(True)

# Update the position of the highlight rectangle
self.highlight.set_bounds(col - 0.5, row - 0.5, 1, 1)
self.highlight.set_visible(True)
else:
# Hide the tooltip and highlight if outside the matrix
self.tooltip.set_visible(False)
self.highlight.set_visible(False)
else:
# Hide the tooltip and highlight if the mouse is outside the axes
self.tooltip.set_visible(False)
self.highlight.set_visible(False)

self.draw_idle() # Redraw the canvas

def highlight_cells(self, indices: "list[Tuple[float, float]]"):
for (i, j) in indices:
rect = Rectangle(
(j - 0.5, i - 0.5),
1,
1,
linewidth=2,
edgecolor="green",
facecolor="none"
)
self.axes.add_patch(rect)
self.highlighted_patches[(i, j)] = rect
self.draw_idle()

def clear_highlights(self):
for key in list(self.highlighted_patches):
patch = self.highlighted_patches.pop(key)
patch.remove()
self.draw_idle()
Loading