diff --git a/src/qcodes/instrument_drivers/AlazarTech/utils.py b/src/qcodes/instrument_drivers/AlazarTech/utils.py index 1bbf8b8b606c..76aa8ffbeea2 100644 --- a/src/qcodes/instrument_drivers/AlazarTech/utils.py +++ b/src/qcodes/instrument_drivers/AlazarTech/utils.py @@ -5,15 +5,16 @@ :mod:`.AlazarTech.helpers` module). """ -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from qcodes.parameters import Parameter, ParamRawDataType if TYPE_CHECKING: - from .ATS import AlazarTechATS + # ruff does not detect that this is used as a generic parameter below + from .ATS import AlazarTechATS # noqa: F401 -class TraceParameter(Parameter): +class TraceParameter(Parameter[Any, "AlazarTechATS"]): """ A parameter that keeps track of if its value has been synced to the ``Instrument``. To achieve that, this parameter sets @@ -38,6 +39,5 @@ def synced_to_card(self) -> bool: return self._synced_to_card def set_raw(self, value: ParamRawDataType) -> None: - instrument = cast("AlazarTechATS", self.instrument) - instrument._parameters_synced = False + self.instrument._parameters_synced = False self._synced_to_card = False diff --git a/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py b/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py index d401b2b9f46c..5e6457dbe666 100644 --- a/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py +++ b/src/qcodes/instrument_drivers/Keysight/Keysight_N9030B.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar import numpy as np +import numpy.typing as npt from qcodes.instrument import ( InstrumentBaseKWArgs, @@ -23,56 +24,62 @@ from typing_extensions import Unpack +_T = TypeVar( + "_T", bound="KeysightN9030BSpectrumAnalyzerMode | KeysightN9030BPhaseNoiseMode" +) +_S = TypeVar("_S") + -class FrequencyAxis(Parameter): +class FrequencyAxis( + Parameter[ + npt.NDArray[np.float64], + _T, + ], + Generic[_T], +): def __init__( self, - start: Parameter, - stop: Parameter, - npts: Parameter, + start: Parameter[float, _T], + stop: Parameter[float, _T], + npts: Parameter[int, _T], *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._start: Parameter = start - self._stop: Parameter = stop - self._npts: Parameter = npts + self._start = start + self._stop = stop + self._npts = npts def get_raw(self) -> ParamRawDataType: start_val = self._start() stop_val = self._stop() npts_val = self._npts() - assert start_val is not None - assert stop_val is not None - assert npts_val is not None + if start_val is None or stop_val is None or npts_val is None: + raise RuntimeError("Start, Stop and Npts parameters must be set.") return np.linspace(start_val, stop_val, npts_val) -class Trace(ParameterWithSetpoints): +class Trace(ParameterWithSetpoints[_S, _T], Generic[_S, _T]): def __init__( self, number: int, *args: Any, - get_data: Callable[[int], ParamRawDataType], + get_data: Callable[[int], _S], **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - # the parameter classes should ideally be generic in instrument - # and root instrument classes so we can specialize here. - # for now we have to ignore a type error from pyright - self.instrument: ( - KeysightN9030BSpectrumAnalyzerMode | KeysightN9030BPhaseNoiseMode - ) + # while the parameter classes should ideally be generic in instrument + # type it is not generic in the root instrument type self.root_instrument: KeysightN9030B self.number = number self.get_data = get_data - def get_raw(self) -> ParamRawDataType: + def get_raw(self) -> _S: return self.get_data(self.number) -class KeysightN9030BSpectrumAnalyzerMode(InstrumentChannel): +class KeysightN9030BSpectrumAnalyzerMode(InstrumentChannel["KeysightN9030B"]): """ Spectrum Analyzer Mode for Keysight N9030B instrument. """ @@ -86,6 +93,7 @@ def __init__( **kwargs: Unpack[InstrumentBaseKWArgs], ): super().__init__(parent, name, *arg, **kwargs) + self.root_instrument: KeysightN9030B self._additional_wait = additional_wait self._min_freq = -8e7 @@ -104,68 +112,82 @@ def __init__( self._max_freq = self._valid_max_freq[opt] # Frequency Parameters - self.start: Parameter = self.add_parameter( - name="start", - unit="Hz", - get_cmd=":SENSe:FREQuency:STARt?", - set_cmd=self._set_start, - get_parser=float, - vals=Numbers(self._min_freq, self._max_freq - 10), - docstring="Start Frequency", + self.start: Parameter[float, KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="start", + unit="Hz", + get_cmd=":SENSe:FREQuency:STARt?", + set_cmd=self._set_start, + get_parser=float, + vals=Numbers(self._min_freq, self._max_freq - 10), + docstring="Start Frequency", + ) ) """Start Frequency""" - self.stop: Parameter = self.add_parameter( - name="stop", - unit="Hz", - get_cmd=":SENSe:FREQuency:STOP?", - set_cmd=self._set_stop, - get_parser=float, - vals=Numbers(self._min_freq + 10, self._max_freq), - docstring="Stop Frequency", + self.stop: Parameter[float, KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="stop", + unit="Hz", + get_cmd=":SENSe:FREQuency:STOP?", + set_cmd=self._set_stop, + get_parser=float, + vals=Numbers(self._min_freq + 10, self._max_freq), + docstring="Stop Frequency", + ) ) """Stop Frequency""" - self.center: Parameter = self.add_parameter( - name="center", - unit="Hz", - get_cmd=":SENSe:FREQuency:CENTer?", - set_cmd=self._set_center, - get_parser=float, - vals=Numbers(self._min_freq + 5, self._max_freq - 5), - docstring="Sets and gets center frequency", + self.center: Parameter[float, KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="center", + unit="Hz", + get_cmd=":SENSe:FREQuency:CENTer?", + set_cmd=self._set_center, + get_parser=float, + vals=Numbers(self._min_freq + 5, self._max_freq - 5), + docstring="Sets and gets center frequency", + ) ) """Sets and gets center frequency""" - self.span: Parameter = self.add_parameter( - name="span", - unit="Hz", - get_cmd=":SENSe:FREQuency:SPAN?", - set_cmd=self._set_span, - get_parser=float, - vals=Numbers(10, self._max_freq - self._min_freq), - docstring="Changes span of frequency", + self.span: Parameter[float, KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="span", + unit="Hz", + get_cmd=":SENSe:FREQuency:SPAN?", + set_cmd=self._set_span, + get_parser=float, + vals=Numbers(10, self._max_freq - self._min_freq), + docstring="Changes span of frequency", + ) ) """Changes span of frequency""" - self.npts: Parameter = self.add_parameter( - name="npts", - get_cmd=":SENSe:SWEep:POINts?", - set_cmd=":SENSe:SWEep:POINts {}", - get_parser=int, - vals=Ints(1, 20001), - docstring="Number of points for the sweep", + self.npts: Parameter[int, KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="npts", + get_cmd=":SENSe:SWEep:POINts?", + set_cmd=":SENSe:SWEep:POINts {}", + get_parser=int, + vals=Ints(1, 20001), + docstring="Number of points for the sweep", + ) ) """Number of points for the sweep""" # Amplitude/Input Parameters - self.mech_attenuation: Parameter = self.add_parameter( - name="mech_attenuation", - unit="dB", - get_cmd=":SENS:POW:ATT?", - set_cmd=":SENS:POW:ATT {}", - get_parser=int, - vals=Ints(0, 70), - docstring="Internal mechanical attenuation", + self.mech_attenuation: Parameter[int, KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="mech_attenuation", + unit="dB", + get_cmd=":SENS:POW:ATT?", + set_cmd=":SENS:POW:ATT {}", + get_parser=int, + vals=Ints(0, 70), + docstring="Internal mechanical attenuation", + ) ) """Internal mechanical attenuation""" - self.preamp: Parameter = self.add_parameter( + self.preamp: Parameter[ + Literal["LOW", "FULL"], KeysightN9030BSpectrumAnalyzerMode + ] = self.add_parameter( name="preamp", get_cmd=":SENS:POW:GAIN:BAND?", set_cmd=":SENS:POW:GAIN:BAND {}", @@ -354,20 +376,25 @@ def __init__( """Sets up sweep type. Possible options are 'fft' and 'sweep'.""" # Array (Data) Parameters - self.freq_axis: FrequencyAxis = self.add_parameter( - name="freq_axis", - label="Frequency", - unit="Hz", - start=self.start, - stop=self.stop, - npts=self.npts, - vals=Arrays(shape=(self.npts.get_latest,)), - parameter_class=FrequencyAxis, - docstring="Creates frequency axis for the sweep from start, " - "stop and npts values.", + self.freq_axis: FrequencyAxis[KeysightN9030BSpectrumAnalyzerMode] = ( + self.add_parameter( + name="freq_axis", + label="Frequency", + unit="Hz", + start=self.start, + stop=self.stop, + npts=self.npts, + vals=Arrays(shape=(self.npts.get_latest,)), + parameter_class=FrequencyAxis, + docstring="Creates frequency axis for the sweep from start, " + "stop and npts values.", + ) ) + """Creates frequency axis for the sweep from start, stop and npts values.""" - self.trace: Trace = self.add_parameter( + self.trace: Trace[ + npt.NDArray[np.float64], KeysightN9030BSpectrumAnalyzerMode + ] = self.add_parameter( name="trace", label="Trace", unit="dB", @@ -418,7 +445,7 @@ def _set_span(self, val: float) -> None: self.write(f":SENSe:FREQuency:SPAN {val}") self.update_trace() - def _get_data(self, trace_num: int) -> ParamRawDataType: + def _get_data(self, trace_num: int) -> npt.NDArray[np.float64]: """ Gets data from the measurement. """ @@ -443,8 +470,8 @@ def _get_data(self, trace_num: int) -> ParamRawDataType: is_big_endian=False, ) - data = np.array(data).reshape((-1, 2)) - return data[:, 1] + data_array = np.array(data).reshape((-1, 2)) + return data_array[:, 1] def update_trace(self) -> None: """ @@ -481,7 +508,7 @@ def autotune(self) -> None: self.center() -class KeysightN9030BPhaseNoiseMode(InstrumentChannel): +class KeysightN9030BPhaseNoiseMode(InstrumentChannel["KeysightN9030B"]): """ Phase Noise Mode for Keysight N9030B instrument. """ diff --git a/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1517A.py b/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1517A.py index d1ccb9912282..2876d2ee2b39 100644 --- a/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1517A.py +++ b/src/qcodes/instrument_drivers/Keysight/keysightb1500/KeysightB1517A.py @@ -1,6 +1,6 @@ import re import textwrap -from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast, overload +from typing import TYPE_CHECKING, Any, Literal, NotRequired, overload import numpy as np import numpy.typing as npt @@ -701,7 +701,7 @@ def _get_sweep_steps_parser(response: str) -> SweepSteps: """ -class _ParameterWithStatus(Parameter): +class _ParameterWithStatus(Parameter[Any, "KeysightB1517A"]): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -728,7 +728,7 @@ def snapshot_base( class _SpotMeasurementVoltageParameter(_ParameterWithStatus): def set_raw(self, value: ParamRawDataType) -> None: - smu = cast("KeysightB1517A", self.instrument) + smu = self.instrument if smu._source_config["output_range"] is None: smu._source_config["output_range"] = constants.VOutputRange.AUTO @@ -752,7 +752,7 @@ def set_raw(self, value: ParamRawDataType) -> None: ) def get_raw(self) -> ParamRawDataType: - smu = cast("KeysightB1517A", self.instrument) + smu = self.instrument msg = MessageBuilder().tv( chnum=smu.channels[0], @@ -769,7 +769,7 @@ def get_raw(self) -> ParamRawDataType: class _SpotMeasurementCurrentParameter(_ParameterWithStatus): def set_raw(self, value: ParamRawDataType) -> None: - smu = cast("KeysightB1517A", self.instrument) + smu = self.instrument if smu._source_config["output_range"] is None: smu._source_config["output_range"] = constants.IOutputRange.AUTO @@ -793,7 +793,7 @@ def set_raw(self, value: ParamRawDataType) -> None: ) def get_raw(self) -> ParamRawDataType: - smu = cast("KeysightB1517A", self.instrument) + smu = self.instrument msg = MessageBuilder().ti( chnum=smu.channels[0], diff --git a/src/qcodes/instrument_drivers/Keysight/keysightb1500/message_builder.py b/src/qcodes/instrument_drivers/Keysight/keysightb1500/message_builder.py index 10efae3fa635..d60ef45f4ca0 100644 --- a/src/qcodes/instrument_drivers/Keysight/keysightb1500/message_builder.py +++ b/src/qcodes/instrument_drivers/Keysight/keysightb1500/message_builder.py @@ -1,41 +1,39 @@ -from collections.abc import Callable from functools import wraps from operator import xor -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar from . import constants if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable -def as_csv(comps: "Iterable[Any]", sep: str = ",") -> str: +def as_csv(comps: "Iterable[object]", sep: str = ",") -> str: """Returns items in iterable ls as comma-separated string""" return sep.join(format(x) for x in comps) -MessageBuilderMethodT = TypeVar( - "MessageBuilderMethodT", bound=Callable[..., "MessageBuilder"] -) +P = ParamSpec("P") +T = TypeVar("T") -def final_command(f: MessageBuilderMethodT) -> MessageBuilderMethodT: +def final_command(f: "Callable[P, MessageBuilder]") -> "Callable[P, MessageBuilder]": @wraps(f) - def wrapper(*args: Any, **kwargs: Any) -> "MessageBuilder": + def wrapper(*args: P.args, **kwargs: P.kwargs) -> "MessageBuilder": res: MessageBuilder = f(*args, **kwargs) res._msg.set_final() return res - return cast("MessageBuilderMethodT", wrapper) + return wrapper -class CommandList(list[Any]): +class CommandList(list[T], Generic[T]): def __init__(self) -> None: super().__init__() self.is_final = False - def append(self, obj: Any) -> None: + def append(self, obj: T) -> None: if self.is_final: raise ValueError( f"Cannot add commands after `{self[-1]}`. " @@ -67,7 +65,7 @@ class MessageBuilder: """ def __init__(self) -> None: - self._msg = CommandList() + self._msg: CommandList[str] = CommandList() @property def message(self) -> str: diff --git a/src/qcodes/instrument_drivers/QuantumDesign/DynaCoolPPMS/DynaCool.py b/src/qcodes/instrument_drivers/QuantumDesign/DynaCoolPPMS/DynaCool.py index 657f15af54bd..b4f7b38d1e15 100644 --- a/src/qcodes/instrument_drivers/QuantumDesign/DynaCoolPPMS/DynaCool.py +++ b/src/qcodes/instrument_drivers/QuantumDesign/DynaCoolPPMS/DynaCool.py @@ -1,13 +1,7 @@ import warnings from functools import partial from time import sleep -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Literal, - cast, -) +from typing import TYPE_CHECKING, ClassVar, Literal, TypeVar, cast import numpy as np from pyvisa import VisaIOError @@ -22,6 +16,8 @@ from qcodes.parameters import Parameter +_T = TypeVar("_T") + class DynaCool(VisaInstrument): """ @@ -258,7 +254,7 @@ def error_code(self) -> int: return self._error_code @staticmethod - def _pick_one(which_one: int, parser: type, resp: str) -> Any: + def _pick_one(which_one: int, parser: "Callable[[str], _T]", resp: str) -> _T: """ Since most of the API calls return several values in a comma-separated string, here's a convenience function to pick out the substring of @@ -344,7 +340,7 @@ def _field_ramp_setter(self, target: float) -> None: def _measured_field_getter(self) -> float: resp = self.ask("FELD?") - number_in_oersted = cast("float", DynaCool._pick_one(1, float, resp)) + number_in_oersted = DynaCool._pick_one(1, float, resp) number_in_tesla = number_in_oersted * 1e-4 return number_in_tesla diff --git a/src/qcodes/instrument_drivers/rigol/Rigol_DG1062.py b/src/qcodes/instrument_drivers/rigol/Rigol_DG1062.py index 3991d8c8bba0..90a26baea441 100644 --- a/src/qcodes/instrument_drivers/rigol/Rigol_DG1062.py +++ b/src/qcodes/instrument_drivers/rigol/Rigol_DG1062.py @@ -414,9 +414,13 @@ def __init__( ): super().__init__(name, address, **kwargs) - self.ch1 = self.add_submodule("ch1", RigolDG1062Channel(self, "ch1", 1)) + self.ch1: RigolDG1062Channel = self.add_submodule( + "ch1", RigolDG1062Channel(self, "ch1", 1) + ) """Channel 1 submodule""" - self.ch2 = self.add_submodule("ch2", RigolDG1062Channel(self, "ch2", 2)) + self.ch2: RigolDG1062Channel = self.add_submodule( + "ch2", RigolDG1062Channel(self, "ch2", 2) + ) """Channel 2 submodule""" self.channels: ChannelTuple[RigolDG1062Channel] = self.add_submodule( diff --git a/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py b/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py index 23cc98998dc1..1529846de6c1 100644 --- a/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py +++ b/src/qcodes/instrument_drivers/rohde_schwarz/ZNB.py @@ -8,7 +8,6 @@ import qcodes.validators as vals from qcodes.instrument import ( ChannelList, - Instrument, InstrumentBaseKWArgs, InstrumentChannel, VisaInstrument, @@ -19,7 +18,6 @@ ManualParameter, MultiParameter, Parameter, - ParamRawDataType, create_on_off_val_mapping, ) @@ -29,7 +27,12 @@ log = logging.getLogger(__name__) -class FixedFrequencyTraceIQ(MultiParameter): +class FixedFrequencyTraceIQ( + MultiParameter[ + tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]], + "RohdeSchwarzZNBChannel", + ] +): """ Parameter for sweep that returns the real (I) and imaginary (Q) parts of the VNA response. @@ -94,12 +97,16 @@ def get_raw(self) -> tuple[npt.NDArray, npt.NDArray]: `cw_check_sweep_first` is set to `True` then at the cost of a few ms overhead checks if the vna is setup correctly. """ - assert isinstance(self.instrument, RohdeSchwarzZNBChannel) i, q = self.instrument._get_cw_data() return i, q -class FixedFrequencyPointIQ(MultiParameter): +class FixedFrequencyPointIQ( + MultiParameter[ + tuple[float, float], + "RohdeSchwarzZNBChannel", + ] +): """ Parameter for sweep that returns the mean of the real (I) and imaginary (Q) parts of the VNA response. @@ -142,12 +149,16 @@ def get_raw(self) -> tuple[float, float]: parameter `cw_check_sweep_first` is set to `True` then at the cost of a few ms overhead checks if the vna is setup correctly. """ - assert isinstance(self.instrument, RohdeSchwarzZNBChannel) i, q = self.instrument._get_cw_data() return float(np.mean(i)), float(np.mean(q)) -class FixedFrequencyPointMagPhase(MultiParameter): +class FixedFrequencyPointMagPhase( + MultiParameter[ + tuple[float, float], + "RohdeSchwarzZNBChannel", + ] +): """ Parameter for sweep that returns the magnitude of mean of the real (I) and imaginary (Q) parts of the VNA response and it's phase. @@ -185,20 +196,24 @@ def __init__( **kwargs, ) - def get_raw(self) -> tuple[float, ...]: + def get_raw(self) -> tuple[float, float]: """ Gets the magnitude and phase of the mean of the raw real and imaginary part of the data. If the parameter `cw_check_sweep_first` is set to `True` for the instrument then at the cost of a few ms overhead checks if the vna is setup correctly. """ - assert isinstance(self.instrument, RohdeSchwarzZNBChannel) i, q = self.instrument._get_cw_data() s = np.mean(i) + 1j * np.mean(q) return float(np.abs(s)), float(np.angle(s)) -class FrequencySweepMagPhase(MultiParameter): +class FrequencySweepMagPhase( + MultiParameter[ + tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]], + "RohdeSchwarzZNBChannel", + ] +): """ Sweep that return magnitude and phase. """ @@ -247,14 +262,18 @@ def set_sweep(self, start: float, stop: float, npts: int) -> None: self.setpoints = ((f,), (f,)) self.shapes = ((npts,), (npts,)) - def get_raw(self) -> tuple[ParamRawDataType, ...]: - assert isinstance(self.instrument, RohdeSchwarzZNBChannel) + def get_raw(self) -> tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]: with self.instrument.format.set_to("Complex"): data = self.instrument._get_sweep_data(force_polar=True) return abs(data), np.angle(data) -class FrequencySweepDBPhase(MultiParameter): +class FrequencySweepDBPhase( + MultiParameter[ + tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]], + "RohdeSchwarzZNBChannel", + ] +): """ Sweep that return magnitude in decibel (dB) and phase in radians. """ @@ -303,14 +322,18 @@ def set_sweep(self, start: float, stop: float, npts: int) -> None: self.setpoints = ((f,), (f,)) self.shapes = ((npts,), (npts,)) - def get_raw(self) -> tuple[ParamRawDataType, ...]: - assert isinstance(self.instrument, RohdeSchwarzZNBChannel) + def get_raw(self) -> tuple[npt.NDArray[np.floating], npt.NDArray[np.floating]]: with self.instrument.format.set_to("Complex"): data = self.instrument._get_sweep_data(force_polar=True) return 20 * np.log10(np.abs(data)), np.angle(data) -class FrequencySweep(ArrayParameter): +class FrequencySweep( + ArrayParameter[ + npt.NDArray[np.floating], + "RohdeSchwarzZNBChannel", + ] +): """ Hardware controlled parameter class for Rohde Schwarz ZNB trace. @@ -332,7 +355,7 @@ class FrequencySweep(ArrayParameter): def __init__( self, name: str, - instrument: Instrument, + instrument: "RohdeSchwarzZNBChannel", start: float, stop: float, npts: int, @@ -370,8 +393,7 @@ def set_sweep(self, start: float, stop: float, npts: int) -> None: self.setpoints = (f,) self.shape = (npts,) - def get_raw(self) -> ParamRawDataType: - assert isinstance(self.instrument, RohdeSchwarzZNBChannel) + def get_raw(self) -> npt.NDArray[np.floating]: return self.instrument._get_sweep_data() @@ -1087,7 +1109,7 @@ def __init__( self._max_freq: float self._min_freq, self._max_freq = m_frequency[model] - self.num_ports: Parameter = self.add_parameter( + self.num_ports: Parameter[int, RohdeSchwarzZNBBase] = self.add_parameter( name="num_ports", get_cmd="INST:PORT:COUN?", get_parser=int ) """Parameter num_ports""" diff --git a/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py b/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py index 57929e2c0c58..0a278a53fffc 100644 --- a/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py +++ b/src/qcodes/instrument_drivers/tektronix/DPO7200xx.py @@ -7,7 +7,7 @@ import textwrap import time from functools import partial -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar import numpy as np import numpy.typing as npt @@ -767,7 +767,7 @@ def _trigger_type(self, value: str) -> None: self.write(f"TRIGger:{self._identifier}:TYPE {value}") -class TektronixDPOMeasurementParameter(Parameter): +class TektronixDPOMeasurementParameter(Parameter[Any, "TektronixDPOMeasurement"]): """ A measurement parameter does not only return the instantaneous value of a measurement, but can also return some statistics. The accumulation @@ -778,7 +778,7 @@ class TektronixDPOMeasurementParameter(Parameter): """ def _get(self, metric: str) -> float: - measurement_channel = cast("TektronixDPOMeasurement", self.instrument) + measurement_channel = self.instrument if measurement_channel.type.get_latest() != self.name: measurement_channel.type(self.name) diff --git a/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py b/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py index aabb4e981c86..eaeae9b0ebd3 100644 --- a/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py +++ b/src/qcodes/instrument_drivers/yokogawa/Yokogawa_GS200.py @@ -1,5 +1,5 @@ from functools import partial -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal from qcodes.instrument import ( InstrumentBaseKWArgs, @@ -328,12 +328,14 @@ def __init__( ) """Parameter output""" - self.source_mode: Parameter = self.add_parameter( - "source_mode", - label="Source Mode", - get_cmd=":SOUR:FUNC?", - set_cmd=self._set_source_mode, - vals=Enum("VOLT", "CURR"), + self.source_mode: Parameter[Literal["VOLT", "CURR"], YokogawaGS200] = ( + self.add_parameter( + "source_mode", + label="Source Mode", + get_cmd=":SOUR:FUNC?", + set_cmd=self._set_source_mode, + vals=Enum("VOLT", "CURR"), + ) ) """Parameter source_mode""" @@ -410,12 +412,13 @@ def __init__( # We need to pass the source parameter for delegate parameters # (range and output_level) here according to the present # source_mode. - if self.source_mode() == "VOLT": - self.range.source = self.voltage_range - self.output_level.source = self.voltage - else: - self.range.source = self.current_range - self.output_level.source = self.current + match self.source_mode(): + case "VOLT": + self.range.source = self.voltage_range + self.output_level.source = self.voltage + case "CURR": + self.range.source = self.current_range + self.output_level.source = self.current self.voltage_limit: Parameter = self.add_parameter( "voltage_limit", @@ -429,7 +432,7 @@ def __init__( ) """Parameter voltage_limit""" - self.current_limit: Parameter = self.add_parameter( + self.current_limit: Parameter[float, YokogawaGS200] = self.add_parameter( "current_limit", label="Current Protection Limit", unit="I", @@ -621,10 +624,11 @@ def _set_output(self, output_level: float) -> None: ) else: mode = self.source_mode.get_latest() - if mode == "CURR": - self_range = 200e-3 - else: - self_range = 30.0 + match mode: + case "CURR": + self_range = 200e-3 + case "VOLT": + self_range = 30.0 # Check we are not trying to set an out of range value if self.range() is None or abs(output_level) > abs(self_range): @@ -670,7 +674,7 @@ def _update_measurement_module( # since the parameter is not generic in the data type this cannot # narrow None to ModeType even if that is the only valid values # for source_mode. - source_mode = cast("ModeType", self.source_mode.get_latest()) + source_mode = self.source_mode.get_latest() # Get source range if auto-range is off if source_range is None and not self.auto_range(): source_range = self.range() diff --git a/src/qcodes/parameters/array_parameter.py b/src/qcodes/parameters/array_parameter.py index c0c1bc3475c9..640916ba1820 100644 --- a/src/qcodes/parameters/array_parameter.py +++ b/src/qcodes/parameters/array_parameter.py @@ -12,15 +12,14 @@ has_loop = True except ImportError: has_loop = False +from typing import Generic -from .parameter_base import ParameterBase +from .parameter_base import ParameterBase, _InstrumentType_co, _ParameterDataTypeVar from .sequence_helpers import is_sequence_of if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from qcodes.instrument import InstrumentBase - try: from qcodes_loop.data.data_array import DataArray @@ -41,7 +40,10 @@ ) -class ArrayParameter(ParameterBase): +class ArrayParameter( + ParameterBase[_ParameterDataTypeVar, _InstrumentType_co], + Generic[_ParameterDataTypeVar, _InstrumentType_co], +): """ A gettable parameter that returns an array of values. Not necessarily part of an instrument. @@ -131,7 +133,7 @@ def __init__( self, name: str, shape: Sequence[int], - instrument: InstrumentBase | None = None, + instrument: _InstrumentType_co = None, label: str | None = None, unit: str | None = None, setpoints: Sequence[Any] | None = None, diff --git a/src/qcodes/parameters/cache.py b/src/qcodes/parameters/cache.py index 21dd6ade7220..85ac61ebda4d 100644 --- a/src/qcodes/parameters/cache.py +++ b/src/qcodes/parameters/cache.py @@ -1,14 +1,21 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Generic, Literal, Protocol, overload + +from .parameter_base import ( + _ParameterDataTypeVar, +) if TYPE_CHECKING: - from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType + from .parameter_base import ( + ParameterBase, + ParamRawDataType, + ) # The protocol is private to qcodes but used elsewhere in the codebase -class _CacheProtocol(Protocol): # noqa: PYI046 +class _CacheProtocol(Protocol, Generic[_ParameterDataTypeVar]): # noqa: PYI046 """ This protocol defines the interface that a Parameter Cache implementation must implement. This is currently used for 2 implementations, one in @@ -29,24 +36,33 @@ def valid(self) -> bool: ... def invalidate(self) -> None: ... - def set(self, value: ParamDataType) -> None: ... + def set(self, value: _ParameterDataTypeVar) -> None: ... def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: ... - def get(self, get_if_invalid: bool = True) -> ParamDataType: ... + @overload + def get(self, get_if_invalid: Literal[True]) -> _ParameterDataTypeVar: ... + + @overload + def get(self) -> _ParameterDataTypeVar: ... + + @overload + def get(self, get_if_invalid: Literal[False]) -> _ParameterDataTypeVar | None: ... + + def get(self, get_if_invalid: bool = True) -> _ParameterDataTypeVar | None: ... def _update_with( self, *, - value: ParamDataType, + value: _ParameterDataTypeVar, raw_value: ParamRawDataType, timestamp: datetime | None = None, ) -> None: ... - def __call__(self) -> ParamDataType: ... + def __call__(self) -> _ParameterDataTypeVar: ... -class _Cache: +class _Cache(Generic[_ParameterDataTypeVar]): """ Cache object for parameter to hold its value and raw value @@ -66,9 +82,11 @@ class _Cache: """ - def __init__(self, parameter: ParameterBase, max_val_age: float | None = None): + def __init__( + self, parameter: ParameterBase, max_val_age: float | None = None + ) -> None: self._parameter = parameter - self._value: ParamDataType = None + self._value: _ParameterDataTypeVar | None = None self._raw_value: ParamRawDataType = None self._timestamp: datetime | None = None self._max_val_age = max_val_age @@ -115,7 +133,7 @@ def invalidate(self) -> None: """ self._marked_valid = False - def set(self, value: ParamDataType) -> None: + def set(self, value: _ParameterDataTypeVar) -> None: """ Set the cached value of the parameter without invoking the ``set_cmd`` of the parameter (if it has one). For example, in case of @@ -146,7 +164,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: def _update_with( self, *, - value: ParamDataType, + value: _ParameterDataTypeVar, raw_value: ParamRawDataType, timestamp: datetime | None = None, ) -> None: @@ -187,7 +205,16 @@ def _timestamp_expired(self) -> bool: # parameter is still valid return False - def get(self, get_if_invalid: bool = True) -> ParamDataType: + @overload + def get(self, get_if_invalid: Literal[True]) -> _ParameterDataTypeVar: ... + + @overload + def get(self) -> _ParameterDataTypeVar: ... + + @overload + def get(self, get_if_invalid: Literal[False]) -> _ParameterDataTypeVar | None: ... + + def get(self, get_if_invalid: bool = True) -> _ParameterDataTypeVar | None: """ Return cached value if time since get was less than ``max_val_age``, or the parameter was explicitly marked invalid. @@ -246,7 +273,7 @@ def _construct_error_msg(self) -> str: ) return error_msg - def __call__(self) -> ParamDataType: + def __call__(self) -> _ParameterDataTypeVar: """ Same as :meth:`get` but always call ``get`` on parameter if the cache is not valid diff --git a/src/qcodes/parameters/delegate_parameter.py b/src/qcodes/parameters/delegate_parameter.py index f49d695f10e7..019089f604f5 100644 --- a/src/qcodes/parameters/delegate_parameter.py +++ b/src/qcodes/parameters/delegate_parameter.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic -from .parameter import Parameter +from .parameter import Parameter, _InstrumentType_co, _ParameterDataTypeVar if TYPE_CHECKING: from collections.abc import Sequence @@ -10,10 +10,16 @@ from qcodes.validators.validators import Validator - from .parameter_base import ParamDataType, ParamRawDataType + from .parameter_base import ( + ParamDataType, + ParamRawDataType, + ) -class DelegateParameter(Parameter): +class DelegateParameter( + Parameter[_ParameterDataTypeVar, _InstrumentType_co], + Generic[_ParameterDataTypeVar, _InstrumentType_co], +): """ The :class:`.DelegateParameter` wraps a given `source` :class:`Parameter`. Setting/getting it results in a set/get of the source parameter with @@ -52,7 +58,10 @@ class DelegateParameter(Parameter): """ class _DelegateCache: - def __init__(self, parameter: DelegateParameter): + def __init__( + self, + parameter: DelegateParameter[_ParameterDataTypeVar, _InstrumentType_co], + ): self._parameter = parameter self._marked_valid: bool = False @@ -99,7 +108,7 @@ def invalidate(self) -> None: if self._parameter.source is not None: self._parameter.source.cache.invalidate() - def get(self, get_if_invalid: bool = True) -> ParamDataType: + def get(self, get_if_invalid: bool = True) -> _ParameterDataTypeVar: if self._parameter.source is None: raise TypeError( "Cannot get the cache of a DelegateParameter that delegates to None" @@ -108,7 +117,7 @@ def get(self, get_if_invalid: bool = True) -> ParamDataType: self._parameter.source.cache.get(get_if_invalid=get_if_invalid) ) - def set(self, value: ParamDataType) -> None: + def set(self, value: _ParameterDataTypeVar) -> None: if self._parameter.source is None: raise TypeError( "Cannot set the cache of a DelegateParameter that delegates to None" @@ -128,7 +137,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: def _update_with( self, *, - value: ParamDataType, + value: _ParameterDataTypeVar, raw_value: ParamRawDataType, timestamp: datetime | None = None, ) -> None: @@ -142,7 +151,7 @@ def _update_with( """ pass - def __call__(self) -> ParamDataType: + def __call__(self) -> _ParameterDataTypeVar: return self.get(get_if_invalid=True) def __init__( diff --git a/src/qcodes/parameters/multi_parameter.py b/src/qcodes/parameters/multi_parameter.py index 7c850edf19bf..947c3c05eb40 100644 --- a/src/qcodes/parameters/multi_parameter.py +++ b/src/qcodes/parameters/multi_parameter.py @@ -2,16 +2,13 @@ import os from collections.abc import Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import Any, Generic import numpy as np -from .parameter_base import ParameterBase +from .parameter_base import ParameterBase, _InstrumentType_co, _ParameterDataTypeVar from .sequence_helpers import is_sequence_of -if TYPE_CHECKING: - from qcodes.instrument import InstrumentBase - try: from qcodes_loop.data.data_array import DataArray @@ -50,7 +47,10 @@ def _is_nested_sequence_or_none( return True -class MultiParameter(ParameterBase): +class MultiParameter( + ParameterBase[_ParameterDataTypeVar, _InstrumentType_co], + Generic[_ParameterDataTypeVar, _InstrumentType_co], +): """ A gettable parameter that returns multiple values with separate names, each of arbitrary shape. Not necessarily part of an instrument. @@ -141,7 +141,7 @@ def __init__( name: str, names: Sequence[str], shapes: Sequence[Sequence[int]], - instrument: InstrumentBase | None = None, + instrument: _InstrumentType_co = None, labels: Sequence[str] | None = None, units: Sequence[str] | None = None, setpoints: Sequence[Sequence[Any]] | None = None, diff --git a/src/qcodes/parameters/parameter.py b/src/qcodes/parameters/parameter.py index 1cfe55248471..311ce9ea5c59 100644 --- a/src/qcodes/parameters/parameter.py +++ b/src/qcodes/parameters/parameter.py @@ -6,10 +6,12 @@ import logging import os from types import MethodType -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Generic, Literal + +from typing_extensions import TypeVar from .command import Command -from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType +from .parameter_base import ParameterBase, ParamRawDataType from .sweep_values import SweepFixedValues if TYPE_CHECKING: @@ -23,8 +25,19 @@ log = logging.getLogger(__name__) +_ParameterDataTypeVar = TypeVar("_ParameterDataTypeVar", default=Any) +_InstrumentType_co = TypeVar( + "_InstrumentType_co", + bound="InstrumentBase | None", + default="InstrumentBase | None", + covariant=True, +) + -class Parameter(ParameterBase): +class Parameter( + ParameterBase[_ParameterDataTypeVar, _InstrumentType_co], + Generic[_ParameterDataTypeVar, _InstrumentType_co], +): """ A parameter represents a single degree of freedom. Most often, this is the standard parameter for Instruments, though it can also be @@ -172,16 +185,18 @@ class Parameter(ParameterBase): def __init__( self, name: str, - instrument: InstrumentBase | None = None, + # mypy seems to be confused here. The bound and default for _InstrumentType_co + # contains None but mypy will now allow it as a default as of v 1.19.0 + instrument: _InstrumentType_co = None, # type: ignore[assignment] label: str | None = None, unit: str | None = None, get_cmd: str | Callable[..., Any] | Literal[False] | None = None, set_cmd: str | Callable[..., Any] | Literal[False] | None = False, - initial_value: float | str | None = None, + initial_value: _ParameterDataTypeVar | None = None, max_val_age: float | None = None, vals: Validator[Any] | None = None, docstring: str | None = None, - initial_cache_value: float | str | None = None, + initial_cache_value: _ParameterDataTypeVar | None = None, bind_to_instrument: bool = True, **kwargs: Any, ) -> None: @@ -396,14 +411,16 @@ def __getitem__(self, keys: Any) -> SweepFixedValues: """ return SweepFixedValues(self, keys) - def increment(self, value: ParamDataType) -> None: + def increment(self, value: _ParameterDataTypeVar) -> None: """Increment the parameter with a value Args: value: Value to be added to the parameter. """ - self.set(self.get() + value) + # this method only works with parameters that support addition + # however we don't currently enforce that via typing + self.set(self.get() + value) # type: ignore[operator] def sweep( self, diff --git a/src/qcodes/parameters/parameter_base.py b/src/qcodes/parameters/parameter_base.py index ecf02fd4a22b..32a434a7d335 100644 --- a/src/qcodes/parameters/parameter_base.py +++ b/src/qcodes/parameters/parameter_base.py @@ -8,9 +8,10 @@ from contextlib import contextmanager from datetime import datetime from functools import cached_property, wraps -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, overload import numpy as np +from typing_extensions import TypeVar from qcodes.metadatable import Metadatable, MetadatableWithName from qcodes.parameters import ParamSpecBase @@ -41,6 +42,13 @@ from qcodes.dataset.data_set_protocol import ValuesType from qcodes.instrument import InstrumentBase from qcodes.logger.instrument_logger import InstrumentLoggerAdapter +_ParameterDataTypeVar = TypeVar("_ParameterDataTypeVar", default=Any) +_InstrumentType_co = TypeVar( + "_InstrumentType_co", + bound="InstrumentBase | None", + default="InstrumentBase | None", + covariant=True, +) LOG = logging.getLogger(__name__) @@ -109,7 +117,9 @@ def invert_val_mapping(val_mapping: Mapping[Any, Any]) -> dict[Any, Any]: return {v: k for k, v in val_mapping.items()} -class ParameterBase(MetadatableWithName): +class ParameterBase( + MetadatableWithName, Generic[_ParameterDataTypeVar, _InstrumentType_co] +): """ Shared behavior for all parameters. Not intended to be used directly, normally you should use ``Parameter``, ``ArrayParameter``, @@ -212,7 +222,7 @@ class ParameterBase(MetadatableWithName): def __init__( self, name: str, - instrument: InstrumentBase | None, + instrument: _InstrumentType_co = None, # type: ignore[assignment] snapshot_get: bool = True, metadata: Mapping[Any, Any] | None = None, step: float | None = None, @@ -230,7 +240,8 @@ def __init__( abstract: bool | None = False, bind_to_instrument: bool = True, register_name: str | None = None, - on_set_callback: Callable[[ParameterBase, ParamDataType], None] | None = None, + on_set_callback: Callable[[ParameterBase, _ParameterDataTypeVar], None] + | None = None, ) -> None: super().__init__(metadata) if not str(name).isidentifier(): @@ -279,15 +290,17 @@ def __init__( # ``_Cache`` stores "latest" value (and raw value) and timestamp # when it was set or measured - self.cache: _CacheProtocol = _Cache(self, max_val_age=max_val_age) + self.cache: _CacheProtocol[_ParameterDataTypeVar] = _Cache[ + _ParameterDataTypeVar + ](self, max_val_age=max_val_age) # ``GetLatest`` is left from previous versions where it would # implement a subset of features which ``_Cache`` has. # It is left for now for backwards compatibility reasons and shall # be deprecated and removed in the future versions. - self.get_latest: GetLatest - self.get_latest = GetLatest(self) + self.get_latest: GetLatest[_ParameterDataTypeVar] + self.get_latest = GetLatest[_ParameterDataTypeVar](self) - self.get: Callable[..., ParamDataType] + self.get: Callable[..., _ParameterDataTypeVar] self._gettable = False if self._implements_get_raw: self.get = self._wrap_get(self.get_raw) @@ -527,14 +540,14 @@ def __repr__(self) -> str: return named_repr(self) @overload - def __call__(self) -> ParamDataType: + def __call__(self) -> _ParameterDataTypeVar: pass @overload - def __call__(self, value: ParamDataType, **kwargs: Any) -> None: + def __call__(self, value: _ParameterDataTypeVar, **kwargs: Any) -> None: pass - def __call__(self, *args: Any, **kwargs: Any) -> ParamDataType | None: + def __call__(self, *args: Any, **kwargs: Any) -> _ParameterDataTypeVar | None: if len(args) == 0 and len(kwargs) == 0: if self.gettable: return self.get() @@ -606,7 +619,7 @@ def snapshot_base( state["ts"] = dttime.strftime("%Y-%m-%d %H:%M:%S") for attr in set(self._meta_attrs): - if attr == "instrument" and self._instrument: + if attr == "instrument" and self._instrument is not None: state.update( { "instrument": full_class(self._instrument), @@ -635,7 +648,9 @@ def snapshot_value(self) -> bool: """ return self._snapshot_value - def _from_value_to_raw_value(self, value: ParamDataType) -> ParamRawDataType: + def _from_value_to_raw_value( + self, value: _ParameterDataTypeVar + ) -> ParamRawDataType: raw_value: ParamRawDataType if self.val_mapping is not None: @@ -673,28 +688,35 @@ def _from_value_to_raw_value(self, value: ParamDataType) -> ParamRawDataType: return raw_value - def _from_raw_value_to_value(self, raw_value: ParamRawDataType) -> ParamDataType: - value: ParamDataType + def _from_raw_value_to_value( + self, raw_value: ParamRawDataType + ) -> _ParameterDataTypeVar: + value: _ParameterDataTypeVar if self.get_parser is not None: value = self.get_parser(raw_value) else: value = raw_value + # the code below is not very type safe but relies on duck typing / try except + # and assumes the user does not set scale/offset unless the datatype is numeric + # this should probably be rewritten but for now we ignore type errors # apply offset first (native scale) + if self.offset is not None and value is not None: # offset values try: - value = value - self.offset + value = value - self.offset # type: ignore[operator,assignment] except TypeError: if isinstance(self.offset, collections.abc.Iterable): # offset contains multiple elements, one for each value - value = tuple( - val - offset for val, offset in zip(value, self.offset) + value = tuple( # type: ignore[assignment] + val - offset + for val, offset in zip(value, self.offset) # type: ignore[call-overload] ) elif isinstance(value, collections.abc.Iterable): # Use single offset for all values - value = tuple(val - self.offset for val in value) + value = tuple(val - self.offset for val in value) # type: ignore[assignment] else: raise @@ -702,14 +724,14 @@ def _from_raw_value_to_value(self, raw_value: ParamRawDataType) -> ParamDataType if self.scale is not None and value is not None: # Scale values try: - value = value / self.scale + value = value / self.scale # type: ignore[assignment,operator] except TypeError: if isinstance(self.scale, collections.abc.Iterable): # Scale contains multiple elements, one for each value - value = tuple(val / scale for val, scale in zip(value, self.scale)) + value = tuple(val / scale for val, scale in zip(value, self.scale)) # type: ignore[call-overload,assignment] elif isinstance(value, collections.abc.Iterable): # Use single scale for all values - value = tuple(val / self.scale for val in value) + value = tuple(val / self.scale for val in value) # type: ignore[assignment] else: raise @@ -718,17 +740,17 @@ def _from_raw_value_to_value(self, raw_value: ParamRawDataType) -> ParamDataType value = self.inverse_val_mapping[value] else: try: - value = self.inverse_val_mapping[int(value)] + value = self.inverse_val_mapping[int(value)] # type: ignore[call-overload] except (ValueError, KeyError): raise KeyError(f"'{value}' not in val_mapping") - return value + return value # pyright: ignore[reportReturnType] def _wrap_get( self, get_function: Callable[..., ParamRawDataType] - ) -> Callable[..., ParamDataType]: + ) -> Callable[..., _ParameterDataTypeVar]: @wraps(get_function) - def get_wrapper(*args: Any, **kwargs: Any) -> ParamDataType: + def get_wrapper(*args: Any, **kwargs: Any) -> _ParameterDataTypeVar: if not self.gettable: raise TypeError("Trying to get a parameter that is not gettable.") if self.abstract: @@ -756,7 +778,7 @@ def get_wrapper(*args: Any, **kwargs: Any) -> ParamDataType: def _wrap_set(self, set_function: Callable[..., None]) -> Callable[..., None]: @wraps(set_function) - def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: + def set_wrapper(value: _ParameterDataTypeVar, **kwargs: Any) -> None: try: if not self.settable: raise TypeError("Trying to set a parameter that is not settable.") @@ -766,17 +788,20 @@ def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: ) self.validate(value) + # the code below is written in a duck typed way that assumes that + # the user has correctly set step size etc. This could be rewritten + # In some cases intermediate sweep values must be used. # Unless `self.step` is defined, get_sweep_values will return # a list containing only `value`. - steps = self.get_ramp_values(value, step=self.step) + steps = self.get_ramp_values(value, step=self.step) # type: ignore[arg-type] for val_step in steps: # even if the final value is valid we may be generating # steps that are not so validate them too - self.validate(val_step) + self.validate(val_step) # type: ignore[arg-type] - raw_val_step = self._from_value_to_raw_value(val_step) + raw_val_step = self._from_value_to_raw_value(val_step) # type: ignore[arg-type] # Check if delay between set operations is required t_elapsed = time.perf_counter() - self._t_last_set @@ -799,9 +824,9 @@ def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: # Sleep until total time is larger than self.post_delay time.sleep(self.post_delay - t_elapsed) - self.cache._update_with(value=val_step, raw_value=raw_val_step) + self.cache._update_with(value=val_step, raw_value=raw_val_step) # type: ignore[arg-type] - self._call_on_set_callback(val_step) + self._call_on_set_callback(val_step) # type: ignore[arg-type] except Exception as e: e.args = (*e.args, f"setting {self} to {value}") @@ -809,7 +834,7 @@ def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: return set_wrapper - def _call_on_set_callback(self, value: ParamDataType) -> None: + def _call_on_set_callback(self, value: _ParameterDataTypeVar) -> None: try: if self.on_set_callback is not None: self.on_set_callback(self, value) @@ -880,7 +905,7 @@ def _validate_context(self) -> str: context = self.name return "Parameter: " + context - def validate(self, value: ParamDataType) -> None: + def validate(self, value: _ParameterDataTypeVar) -> None: """ Validate the value supplied. @@ -1033,7 +1058,7 @@ def register_name(self) -> str: return self._register_name or self.full_name @property - def instrument(self) -> InstrumentBase | None: + def instrument(self) -> _InstrumentType_co: """ Return the first instrument that this parameter is bound to. E.g if this is bound to a channel it will return the channel @@ -1056,7 +1081,7 @@ def root_instrument(self) -> InstrumentBase | None: return None def set_to( - self, value: ParamDataType, allow_changes: bool = False + self, value: _ParameterDataTypeVar, allow_changes: bool = False ) -> _SetParamContext: """ Use a context manager to temporarily set a parameter to a value. By @@ -1245,7 +1270,7 @@ def unpack_self(self, value: ValuesType) -> list[tuple[ParameterBase, ValuesType return [(self, value)] -class GetLatest(DelegateAttributes): +class GetLatest(DelegateAttributes, Generic[_ParameterDataTypeVar]): """ Wrapper for a class:`.Parameter` that just returns the last set or measured value stored in the class:`.Parameter` itself. If get has never been called @@ -1277,7 +1302,7 @@ def __init__(self, parameter: ParameterBase): delegate_attr_objects: ClassVar[list[str]] = ["parameter"] omit_delegate_attrs: ClassVar[list[str]] = ["set"] - def get(self) -> ParamDataType: + def get(self) -> _ParameterDataTypeVar: """ Return latest value if time since get was less than `max_val_age`, otherwise perform `get()` and @@ -1304,7 +1329,7 @@ def get_raw_value(self) -> ParamRawDataType | None: """ return self.cache._raw_value - def __call__(self) -> ParamDataType: + def __call__(self) -> _ParameterDataTypeVar: """ Same as ``get()`` diff --git a/src/qcodes/parameters/parameter_with_setpoints.py b/src/qcodes/parameters/parameter_with_setpoints.py index d9ef9badf4dd..36effbbb0543 100644 --- a/src/qcodes/parameters/parameter_with_setpoints.py +++ b/src/qcodes/parameters/parameter_with_setpoints.py @@ -1,11 +1,15 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic import numpy as np -from qcodes.parameters.parameter import Parameter +from qcodes.parameters.parameter import ( + Parameter, + _InstrumentType_co, + _ParameterDataTypeVar, +) from qcodes.parameters.parameter_base import ParameterBase, ParameterSet from qcodes.validators import Arrays, Validator @@ -18,7 +22,10 @@ LOG = logging.getLogger(__name__) -class ParameterWithSetpoints(Parameter): +class ParameterWithSetpoints( + Parameter[_ParameterDataTypeVar, _InstrumentType_co], + Generic[_ParameterDataTypeVar, _InstrumentType_co], +): """ A parameter that has associated setpoints. The setpoints is nothing more than a list of other parameters that describe the values, names