diff --git a/docs/changes/newsfragments/7730.improved b/docs/changes/newsfragments/7730.improved new file mode 100644 index 00000000000..cf1eab68ef3 --- /dev/null +++ b/docs/changes/newsfragments/7730.improved @@ -0,0 +1,4 @@ +The QCoDeS Parameter classes ``ParameterBase``, ``Parameter``, ``ParameterWithSetpoints``, ``DelegateParameter``, ``ArrayParameter`` and ``MultiParameter`` now +takes two Optional Generic arguments to allow the data type and the type of the instrument the parameter is bound to to be fixed statically. This enables +the type of the output of ``parameter.get()``, input of ``parameter.set()`` and value of ``parameter.instrument`` to be known statically such that type +checkers and IDE's can make use of this information. diff --git a/src/qcodes/parameters/array_parameter.py b/src/qcodes/parameters/array_parameter.py index c0c1bc3475c..595e19fb7f6 100644 --- a/src/qcodes/parameters/array_parameter.py +++ b/src/qcodes/parameters/array_parameter.py @@ -12,15 +12,14 @@ has_loop = True except ImportError: has_loop = False +from typing import Generic -from .parameter_base import ParameterBase +from .parameter_base import InstrumentType_co, ParameterBase, ParameterDataTypeVar from .sequence_helpers import is_sequence_of if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from qcodes.instrument import InstrumentBase - try: from qcodes_loop.data.data_array import DataArray @@ -41,7 +40,10 @@ ) -class ArrayParameter(ParameterBase): +class ArrayParameter( + ParameterBase[ParameterDataTypeVar, InstrumentType_co], + Generic[ParameterDataTypeVar, InstrumentType_co], +): """ A gettable parameter that returns an array of values. Not necessarily part of an instrument. @@ -131,7 +133,9 @@ def __init__( self, name: str, shape: Sequence[int], - instrument: InstrumentBase | None = None, + # mypy seems to be confused here. The bound and default for InstrumentType_co + # contains None but mypy will now allow it as a default as of v 1.19.0 + instrument: InstrumentType_co = None, # type: ignore[assignment] label: str | None = None, unit: str | None = None, setpoints: Sequence[Any] | None = None, diff --git a/src/qcodes/parameters/cache.py b/src/qcodes/parameters/cache.py index 21dd6ade722..28c8f55ae6b 100644 --- a/src/qcodes/parameters/cache.py +++ b/src/qcodes/parameters/cache.py @@ -1,14 +1,22 @@ from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, overload + +from typing_extensions import TypeVar + +# due to circular imports we cannot import the TypeVar from parameter_base +ParameterDataTypeVar = TypeVar("ParameterDataTypeVar", default=Any) if TYPE_CHECKING: - from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType + from .parameter_base import ( + ParameterBase, + ParamRawDataType, + ) # The protocol is private to qcodes but used elsewhere in the codebase -class _CacheProtocol(Protocol): # noqa: PYI046 +class _CacheProtocol(Protocol, Generic[ParameterDataTypeVar]): # noqa: PYI046 """ This protocol defines the interface that a Parameter Cache implementation must implement. This is currently used for 2 implementations, one in @@ -29,24 +37,36 @@ def valid(self) -> bool: ... def invalidate(self) -> None: ... - def set(self, value: ParamDataType) -> None: ... + def set(self, value: ParameterDataTypeVar) -> None: ... def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: ... - def get(self, get_if_invalid: bool = True) -> ParamDataType: ... + @overload + def get(self, get_if_invalid: Literal[True]) -> ParameterDataTypeVar: ... + + @overload + def get(self) -> ParameterDataTypeVar: ... + + @overload + def get(self, get_if_invalid: Literal[False]) -> ParameterDataTypeVar | None: ... + + @overload + def get(self, get_if_invalid: bool) -> ParameterDataTypeVar | None: ... + + def get(self, get_if_invalid: bool = True) -> ParameterDataTypeVar | None: ... def _update_with( self, *, - value: ParamDataType, + value: ParameterDataTypeVar, raw_value: ParamRawDataType, timestamp: datetime | None = None, ) -> None: ... - def __call__(self) -> ParamDataType: ... + def __call__(self) -> ParameterDataTypeVar: ... -class _Cache: +class _Cache(Generic[ParameterDataTypeVar]): """ Cache object for parameter to hold its value and raw value @@ -66,9 +86,11 @@ class _Cache: """ - def __init__(self, parameter: ParameterBase, max_val_age: float | None = None): + def __init__( + self, parameter: ParameterBase, max_val_age: float | None = None + ) -> None: self._parameter = parameter - self._value: ParamDataType = None + self._value: ParameterDataTypeVar | None = None self._raw_value: ParamRawDataType = None self._timestamp: datetime | None = None self._max_val_age = max_val_age @@ -115,7 +137,7 @@ def invalidate(self) -> None: """ self._marked_valid = False - def set(self, value: ParamDataType) -> None: + def set(self, value: ParameterDataTypeVar) -> None: """ Set the cached value of the parameter without invoking the ``set_cmd`` of the parameter (if it has one). For example, in case of @@ -146,7 +168,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: def _update_with( self, *, - value: ParamDataType, + value: ParameterDataTypeVar, raw_value: ParamRawDataType, timestamp: datetime | None = None, ) -> None: @@ -187,7 +209,19 @@ def _timestamp_expired(self) -> bool: # parameter is still valid return False - def get(self, get_if_invalid: bool = True) -> ParamDataType: + @overload + def get(self, get_if_invalid: Literal[True]) -> ParameterDataTypeVar: ... + + @overload + def get(self) -> ParameterDataTypeVar: ... + + @overload + def get(self, get_if_invalid: Literal[False]) -> ParameterDataTypeVar | None: ... + + @overload + def get(self, get_if_invalid: bool) -> ParameterDataTypeVar | None: ... + + def get(self, get_if_invalid: bool = True) -> ParameterDataTypeVar | None: """ Return cached value if time since get was less than ``max_val_age``, or the parameter was explicitly marked invalid. @@ -246,7 +280,7 @@ def _construct_error_msg(self) -> str: ) return error_msg - def __call__(self) -> ParamDataType: + def __call__(self) -> ParameterDataTypeVar: """ Same as :meth:`get` but always call ``get`` on parameter if the cache is not valid diff --git a/src/qcodes/parameters/delegate_parameter.py b/src/qcodes/parameters/delegate_parameter.py index f49d695f10e..d8b43d988fc 100644 --- a/src/qcodes/parameters/delegate_parameter.py +++ b/src/qcodes/parameters/delegate_parameter.py @@ -1,19 +1,39 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic + +from typing_extensions import TypeVar from .parameter import Parameter +from .parameter_base import InstrumentType_co, ParameterDataTypeVar if TYPE_CHECKING: from collections.abc import Sequence from datetime import datetime + from qcodes.instrument import InstrumentBase from qcodes.validators.validators import Validator - from .parameter_base import ParamDataType, ParamRawDataType - - -class DelegateParameter(Parameter): + from .parameter_base import ( + ParamDataType, + ParamRawDataType, + ) + +# Generic type variables for inner cache class +# these need to be different variables such that both classes can be generic +_local_ParameterDataTypeVar = TypeVar("_local_ParameterDataTypeVar", default=Any) +_local_InstrumentType_co = TypeVar( + "_local_InstrumentType_co", + bound="InstrumentBase | None", + default="InstrumentBase | None", + covariant=True, +) + + +class DelegateParameter( + Parameter[ParameterDataTypeVar, InstrumentType_co], + Generic[ParameterDataTypeVar, InstrumentType_co], +): """ The :class:`.DelegateParameter` wraps a given `source` :class:`Parameter`. Setting/getting it results in a set/get of the source parameter with @@ -51,8 +71,15 @@ class DelegateParameter(Parameter): """ - class _DelegateCache: - def __init__(self, parameter: DelegateParameter): + class _DelegateCache( + Generic[_local_ParameterDataTypeVar, _local_InstrumentType_co] + ): + def __init__( + self, + parameter: DelegateParameter[ + _local_ParameterDataTypeVar, _local_InstrumentType_co + ], + ): self._parameter = parameter self._marked_valid: bool = False @@ -99,7 +126,7 @@ def invalidate(self) -> None: if self._parameter.source is not None: self._parameter.source.cache.invalidate() - def get(self, get_if_invalid: bool = True) -> ParamDataType: + def get(self, get_if_invalid: bool = True) -> _local_ParameterDataTypeVar: if self._parameter.source is None: raise TypeError( "Cannot get the cache of a DelegateParameter that delegates to None" @@ -108,7 +135,7 @@ def get(self, get_if_invalid: bool = True) -> ParamDataType: self._parameter.source.cache.get(get_if_invalid=get_if_invalid) ) - def set(self, value: ParamDataType) -> None: + def set(self, value: _local_ParameterDataTypeVar) -> None: if self._parameter.source is None: raise TypeError( "Cannot set the cache of a DelegateParameter that delegates to None" @@ -128,7 +155,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: def _update_with( self, *, - value: ParamDataType, + value: _local_ParameterDataTypeVar, raw_value: ParamRawDataType, timestamp: datetime | None = None, ) -> None: @@ -142,7 +169,7 @@ def _update_with( """ pass - def __call__(self) -> ParamDataType: + def __call__(self) -> _local_ParameterDataTypeVar: return self.get(get_if_invalid=True) def __init__( @@ -183,7 +210,7 @@ def __init__( # i.e. _SetParamContext overrides it self._settable = True - self.cache = self._DelegateCache(self) + self.cache = self._DelegateCache[ParameterDataTypeVar, InstrumentType_co](self) if initial_cache_value is not None: self.cache.set(initial_cache_value) diff --git a/src/qcodes/parameters/multi_parameter.py b/src/qcodes/parameters/multi_parameter.py index 7c850edf19b..501cd128660 100644 --- a/src/qcodes/parameters/multi_parameter.py +++ b/src/qcodes/parameters/multi_parameter.py @@ -2,16 +2,13 @@ import os from collections.abc import Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import Any, Generic import numpy as np -from .parameter_base import ParameterBase +from .parameter_base import InstrumentType_co, ParameterBase, ParameterDataTypeVar from .sequence_helpers import is_sequence_of -if TYPE_CHECKING: - from qcodes.instrument import InstrumentBase - try: from qcodes_loop.data.data_array import DataArray @@ -50,7 +47,10 @@ def _is_nested_sequence_or_none( return True -class MultiParameter(ParameterBase): +class MultiParameter( + ParameterBase[ParameterDataTypeVar, InstrumentType_co], + Generic[ParameterDataTypeVar, InstrumentType_co], +): """ A gettable parameter that returns multiple values with separate names, each of arbitrary shape. Not necessarily part of an instrument. @@ -141,7 +141,9 @@ def __init__( name: str, names: Sequence[str], shapes: Sequence[Sequence[int]], - instrument: InstrumentBase | None = None, + # mypy seems to be confused here. The bound and default for InstrumentType_co + # contains None but mypy will now allow it as a default as of v 1.19.0 + instrument: InstrumentType_co = None, # type: ignore[assignment] labels: Sequence[str] | None = None, units: Sequence[str] | None = None, setpoints: Sequence[Sequence[Any]] | None = None, diff --git a/src/qcodes/parameters/parameter.py b/src/qcodes/parameters/parameter.py index 1cfe5524847..52cb5945eb9 100644 --- a/src/qcodes/parameters/parameter.py +++ b/src/qcodes/parameters/parameter.py @@ -6,10 +6,15 @@ import logging import os from types import MethodType -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Generic, Literal from .command import Command -from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType +from .parameter_base import ( + InstrumentType_co, + ParameterBase, + ParameterDataTypeVar, + ParamRawDataType, +) from .sweep_values import SweepFixedValues if TYPE_CHECKING: @@ -24,7 +29,10 @@ log = logging.getLogger(__name__) -class Parameter(ParameterBase): +class Parameter( + ParameterBase[ParameterDataTypeVar, InstrumentType_co], + Generic[ParameterDataTypeVar, InstrumentType_co], +): """ A parameter represents a single degree of freedom. Most often, this is the standard parameter for Instruments, though it can also be @@ -172,16 +180,18 @@ class Parameter(ParameterBase): def __init__( self, name: str, - instrument: InstrumentBase | None = None, + # mypy seems to be confused here. The bound and default for InstrumentType_co + # contains None but mypy will now allow it as a default as of v 1.19.0 + instrument: InstrumentType_co = None, # type: ignore[assignment] label: str | None = None, unit: str | None = None, get_cmd: str | Callable[..., Any] | Literal[False] | None = None, set_cmd: str | Callable[..., Any] | Literal[False] | None = False, - initial_value: float | str | None = None, + initial_value: ParameterDataTypeVar | None = None, max_val_age: float | None = None, vals: Validator[Any] | None = None, docstring: str | None = None, - initial_cache_value: float | str | None = None, + initial_cache_value: ParameterDataTypeVar | None = None, bind_to_instrument: bool = True, **kwargs: Any, ) -> None: @@ -396,14 +406,16 @@ def __getitem__(self, keys: Any) -> SweepFixedValues: """ return SweepFixedValues(self, keys) - def increment(self, value: ParamDataType) -> None: + def increment(self, value: ParameterDataTypeVar) -> None: """Increment the parameter with a value Args: value: Value to be added to the parameter. """ - self.set(self.get() + value) + # this method only works with parameters that support addition + # however we don't currently enforce that via typing + self.set(self.get() + value) # type: ignore[operator] def sweep( self, diff --git a/src/qcodes/parameters/parameter_base.py b/src/qcodes/parameters/parameter_base.py index ecf02fd4a22..16710c3d7c9 100644 --- a/src/qcodes/parameters/parameter_base.py +++ b/src/qcodes/parameters/parameter_base.py @@ -8,9 +8,10 @@ from contextlib import contextmanager from datetime import datetime from functools import cached_property, wraps -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, overload import numpy as np +from typing_extensions import TypeVar from qcodes.metadatable import Metadatable, MetadatableWithName from qcodes.parameters import ParamSpecBase @@ -30,7 +31,9 @@ from .named_repr import named_repr from .permissive_range import permissive_range -# for now the type the parameter may contain is not restricted at all +# ParamDataType is legacy and has been replaced by a generic type variable +# ParamRawDataType reprecents the raw data type used in get_raw and set_raw +# and may be replaced by a generic type variable in the future ParamDataType = Any ParamRawDataType = Any @@ -41,6 +44,18 @@ from qcodes.dataset.data_set_protocol import ValuesType from qcodes.instrument import InstrumentBase from qcodes.logger.instrument_logger import InstrumentLoggerAdapter +ParameterDataTypeVar = TypeVar("ParameterDataTypeVar", default=Any) +# InstrumentType_co is a covariant type variable representing the instrument +# type associated with the parameter. It needs to be covariant to allow passing +# a Parameter bound to None or a specific instrument where the default is used in the type hint. +# Otherwise we see errors such as +# Type parameter "InstrumentType@ParameterBase" is invariant, but "None" is not the same as "InstrumentBase | None" +InstrumentType_co = TypeVar( + "InstrumentType_co", + bound="InstrumentBase | None", + default="InstrumentBase | None", + covariant=True, +) LOG = logging.getLogger(__name__) @@ -109,7 +124,9 @@ def invert_val_mapping(val_mapping: Mapping[Any, Any]) -> dict[Any, Any]: return {v: k for k, v in val_mapping.items()} -class ParameterBase(MetadatableWithName): +class ParameterBase( + MetadatableWithName, Generic[ParameterDataTypeVar, InstrumentType_co] +): """ Shared behavior for all parameters. Not intended to be used directly, normally you should use ``Parameter``, ``ArrayParameter``, @@ -212,7 +229,7 @@ class ParameterBase(MetadatableWithName): def __init__( self, name: str, - instrument: InstrumentBase | None, + instrument: InstrumentType_co = None, # type: ignore[assignment] snapshot_get: bool = True, metadata: Mapping[Any, Any] | None = None, step: float | None = None, @@ -230,7 +247,8 @@ def __init__( abstract: bool | None = False, bind_to_instrument: bool = True, register_name: str | None = None, - on_set_callback: Callable[[ParameterBase, ParamDataType], None] | None = None, + on_set_callback: Callable[[ParameterBase, ParameterDataTypeVar], None] + | None = None, ) -> None: super().__init__(metadata) if not str(name).isidentifier(): @@ -279,15 +297,17 @@ def __init__( # ``_Cache`` stores "latest" value (and raw value) and timestamp # when it was set or measured - self.cache: _CacheProtocol = _Cache(self, max_val_age=max_val_age) + self.cache: _CacheProtocol[ParameterDataTypeVar] = _Cache[ParameterDataTypeVar]( + self, max_val_age=max_val_age + ) # ``GetLatest`` is left from previous versions where it would # implement a subset of features which ``_Cache`` has. # It is left for now for backwards compatibility reasons and shall # be deprecated and removed in the future versions. - self.get_latest: GetLatest - self.get_latest = GetLatest(self) + self.get_latest: GetLatest[ParameterDataTypeVar] + self.get_latest = GetLatest[ParameterDataTypeVar](self) - self.get: Callable[..., ParamDataType] + self.get: Callable[..., ParameterDataTypeVar] self._gettable = False if self._implements_get_raw: self.get = self._wrap_get(self.get_raw) @@ -527,14 +547,14 @@ def __repr__(self) -> str: return named_repr(self) @overload - def __call__(self) -> ParamDataType: + def __call__(self) -> ParameterDataTypeVar: pass @overload - def __call__(self, value: ParamDataType, **kwargs: Any) -> None: + def __call__(self, value: ParameterDataTypeVar, **kwargs: Any) -> None: pass - def __call__(self, *args: Any, **kwargs: Any) -> ParamDataType | None: + def __call__(self, *args: Any, **kwargs: Any) -> ParameterDataTypeVar | None: if len(args) == 0 and len(kwargs) == 0: if self.gettable: return self.get() @@ -606,7 +626,7 @@ def snapshot_base( state["ts"] = dttime.strftime("%Y-%m-%d %H:%M:%S") for attr in set(self._meta_attrs): - if attr == "instrument" and self._instrument: + if attr == "instrument" and self._instrument is not None: state.update( { "instrument": full_class(self._instrument), @@ -635,7 +655,7 @@ def snapshot_value(self) -> bool: """ return self._snapshot_value - def _from_value_to_raw_value(self, value: ParamDataType) -> ParamRawDataType: + def _from_value_to_raw_value(self, value: ParameterDataTypeVar) -> ParamRawDataType: raw_value: ParamRawDataType if self.val_mapping is not None: @@ -673,28 +693,35 @@ def _from_value_to_raw_value(self, value: ParamDataType) -> ParamRawDataType: return raw_value - def _from_raw_value_to_value(self, raw_value: ParamRawDataType) -> ParamDataType: - value: ParamDataType + def _from_raw_value_to_value( + self, raw_value: ParamRawDataType + ) -> ParameterDataTypeVar: + value: ParameterDataTypeVar if self.get_parser is not None: value = self.get_parser(raw_value) else: value = raw_value + # the code below is not very type safe but relies on duck typing / try except + # and assumes the user does not set scale/offset unless the datatype is numeric + # this should probably be rewritten but for now we ignore type errors # apply offset first (native scale) + if self.offset is not None and value is not None: # offset values try: - value = value - self.offset + value = value - self.offset # type: ignore[operator,assignment] except TypeError: if isinstance(self.offset, collections.abc.Iterable): # offset contains multiple elements, one for each value - value = tuple( - val - offset for val, offset in zip(value, self.offset) + value = tuple( # type: ignore[assignment] + val - offset + for val, offset in zip(value, self.offset) # type: ignore[call-overload] ) elif isinstance(value, collections.abc.Iterable): # Use single offset for all values - value = tuple(val - self.offset for val in value) + value = tuple(val - self.offset for val in value) # type: ignore[assignment] else: raise @@ -702,14 +729,14 @@ def _from_raw_value_to_value(self, raw_value: ParamRawDataType) -> ParamDataType if self.scale is not None and value is not None: # Scale values try: - value = value / self.scale + value = value / self.scale # type: ignore[assignment,operator] except TypeError: if isinstance(self.scale, collections.abc.Iterable): # Scale contains multiple elements, one for each value - value = tuple(val / scale for val, scale in zip(value, self.scale)) + value = tuple(val / scale for val, scale in zip(value, self.scale)) # type: ignore[call-overload,assignment] elif isinstance(value, collections.abc.Iterable): # Use single scale for all values - value = tuple(val / self.scale for val in value) + value = tuple(val / self.scale for val in value) # type: ignore[assignment] else: raise @@ -718,17 +745,17 @@ def _from_raw_value_to_value(self, raw_value: ParamRawDataType) -> ParamDataType value = self.inverse_val_mapping[value] else: try: - value = self.inverse_val_mapping[int(value)] + value = self.inverse_val_mapping[int(value)] # type: ignore[call-overload] except (ValueError, KeyError): raise KeyError(f"'{value}' not in val_mapping") - return value + return value # pyright: ignore[reportReturnType] def _wrap_get( self, get_function: Callable[..., ParamRawDataType] - ) -> Callable[..., ParamDataType]: + ) -> Callable[..., ParameterDataTypeVar]: @wraps(get_function) - def get_wrapper(*args: Any, **kwargs: Any) -> ParamDataType: + def get_wrapper(*args: Any, **kwargs: Any) -> ParameterDataTypeVar: if not self.gettable: raise TypeError("Trying to get a parameter that is not gettable.") if self.abstract: @@ -756,7 +783,7 @@ def get_wrapper(*args: Any, **kwargs: Any) -> ParamDataType: def _wrap_set(self, set_function: Callable[..., None]) -> Callable[..., None]: @wraps(set_function) - def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: + def set_wrapper(value: ParameterDataTypeVar, **kwargs: Any) -> None: try: if not self.settable: raise TypeError("Trying to set a parameter that is not settable.") @@ -766,17 +793,20 @@ def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: ) self.validate(value) + # the code below is written in a duck typed way that assumes that + # the user has correctly set step size etc. This could be rewritten + # In some cases intermediate sweep values must be used. # Unless `self.step` is defined, get_sweep_values will return # a list containing only `value`. - steps = self.get_ramp_values(value, step=self.step) + steps = self.get_ramp_values(value, step=self.step) # type: ignore[arg-type] for val_step in steps: # even if the final value is valid we may be generating # steps that are not so validate them too - self.validate(val_step) + self.validate(val_step) # type: ignore[arg-type] - raw_val_step = self._from_value_to_raw_value(val_step) + raw_val_step = self._from_value_to_raw_value(val_step) # type: ignore[arg-type] # Check if delay between set operations is required t_elapsed = time.perf_counter() - self._t_last_set @@ -799,9 +829,9 @@ def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: # Sleep until total time is larger than self.post_delay time.sleep(self.post_delay - t_elapsed) - self.cache._update_with(value=val_step, raw_value=raw_val_step) + self.cache._update_with(value=val_step, raw_value=raw_val_step) # type: ignore[arg-type] - self._call_on_set_callback(val_step) + self._call_on_set_callback(val_step) # type: ignore[arg-type] except Exception as e: e.args = (*e.args, f"setting {self} to {value}") @@ -809,7 +839,7 @@ def set_wrapper(value: ParamDataType, **kwargs: Any) -> None: return set_wrapper - def _call_on_set_callback(self, value: ParamDataType) -> None: + def _call_on_set_callback(self, value: ParameterDataTypeVar) -> None: try: if self.on_set_callback is not None: self.on_set_callback(self, value) @@ -880,7 +910,7 @@ def _validate_context(self) -> str: context = self.name return "Parameter: " + context - def validate(self, value: ParamDataType) -> None: + def validate(self, value: ParameterDataTypeVar) -> None: """ Validate the value supplied. @@ -1033,7 +1063,7 @@ def register_name(self) -> str: return self._register_name or self.full_name @property - def instrument(self) -> InstrumentBase | None: + def instrument(self) -> InstrumentType_co: """ Return the first instrument that this parameter is bound to. E.g if this is bound to a channel it will return the channel @@ -1056,7 +1086,7 @@ def root_instrument(self) -> InstrumentBase | None: return None def set_to( - self, value: ParamDataType, allow_changes: bool = False + self, value: ParameterDataTypeVar, allow_changes: bool = False ) -> _SetParamContext: """ Use a context manager to temporarily set a parameter to a value. By @@ -1245,7 +1275,7 @@ def unpack_self(self, value: ValuesType) -> list[tuple[ParameterBase, ValuesType return [(self, value)] -class GetLatest(DelegateAttributes): +class GetLatest(DelegateAttributes, Generic[ParameterDataTypeVar]): """ Wrapper for a class:`.Parameter` that just returns the last set or measured value stored in the class:`.Parameter` itself. If get has never been called @@ -1277,7 +1307,7 @@ def __init__(self, parameter: ParameterBase): delegate_attr_objects: ClassVar[list[str]] = ["parameter"] omit_delegate_attrs: ClassVar[list[str]] = ["set"] - def get(self) -> ParamDataType: + def get(self) -> ParameterDataTypeVar: """ Return latest value if time since get was less than `max_val_age`, otherwise perform `get()` and @@ -1304,7 +1334,7 @@ def get_raw_value(self) -> ParamRawDataType | None: """ return self.cache._raw_value - def __call__(self) -> ParamDataType: + def __call__(self) -> ParameterDataTypeVar: """ Same as ``get()`` diff --git a/src/qcodes/parameters/parameter_with_setpoints.py b/src/qcodes/parameters/parameter_with_setpoints.py index d9ef9badf4d..1110425fefd 100644 --- a/src/qcodes/parameters/parameter_with_setpoints.py +++ b/src/qcodes/parameters/parameter_with_setpoints.py @@ -1,12 +1,19 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic import numpy as np -from qcodes.parameters.parameter import Parameter -from qcodes.parameters.parameter_base import ParameterBase, ParameterSet +from qcodes.parameters.parameter import ( + Parameter, +) +from qcodes.parameters.parameter_base import ( + InstrumentType_co, + ParameterBase, + ParameterDataTypeVar, + ParameterSet, +) from qcodes.validators import Arrays, Validator if TYPE_CHECKING: @@ -18,7 +25,10 @@ LOG = logging.getLogger(__name__) -class ParameterWithSetpoints(Parameter): +class ParameterWithSetpoints( + Parameter[ParameterDataTypeVar, InstrumentType_co], + Generic[ParameterDataTypeVar, InstrumentType_co], +): """ A parameter that has associated setpoints. The setpoints is nothing more than a list of other parameters that describe the values, names