Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changes/newsfragments/7730.improved
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
The QCoDeS Parameter classes ``ParameterBase``, ``Parameter``, ``ParameterWithSetpoints``, ``DelegateParameter``, ``ArrayParameter`` and ``MultiParameter`` now
takes two Optional Generic arguments to allow the data type and the type of the instrument the parameter is bound to to be fixed statically. This enables
the type of the output of ``parameter.get()``, input of ``parameter.set()`` and value of ``parameter.instrument`` to be known statically such that type
checkers and IDE's can make use of this information.
14 changes: 9 additions & 5 deletions src/qcodes/parameters/array_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
has_loop = True
except ImportError:
has_loop = False
from typing import Generic

from .parameter_base import ParameterBase
from .parameter_base import InstrumentType_co, ParameterBase, ParameterDataTypeVar
from .sequence_helpers import is_sequence_of

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

from qcodes.instrument import InstrumentBase


try:
from qcodes_loop.data.data_array import DataArray
Expand All @@ -41,7 +40,10 @@
)


class ArrayParameter(ParameterBase):
class ArrayParameter(
ParameterBase[ParameterDataTypeVar, InstrumentType_co],
Generic[ParameterDataTypeVar, InstrumentType_co],
):
"""
A gettable parameter that returns an array of values.
Not necessarily part of an instrument.
Expand Down Expand Up @@ -131,7 +133,9 @@ def __init__(
self,
name: str,
shape: Sequence[int],
instrument: InstrumentBase | None = None,
# mypy seems to be confused here. The bound and default for InstrumentType_co
# contains None but mypy will now allow it as a default as of v 1.19.0
instrument: InstrumentType_co = None, # type: ignore[assignment]
label: str | None = None,
unit: str | None = None,
setpoints: Sequence[Any] | None = None,
Expand Down
62 changes: 48 additions & 14 deletions src/qcodes/parameters/cache.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Protocol
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, overload

from typing_extensions import TypeVar

# due to circular imports we cannot import the TypeVar from parameter_base
ParameterDataTypeVar = TypeVar("ParameterDataTypeVar", default=Any)

if TYPE_CHECKING:
from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType
from .parameter_base import (
ParameterBase,
ParamRawDataType,
)


# The protocol is private to qcodes but used elsewhere in the codebase
class _CacheProtocol(Protocol): # noqa: PYI046
class _CacheProtocol(Protocol, Generic[ParameterDataTypeVar]): # noqa: PYI046
"""
This protocol defines the interface that a Parameter Cache implementation
must implement. This is currently used for 2 implementations, one in
Expand All @@ -29,24 +37,36 @@ def valid(self) -> bool: ...

def invalidate(self) -> None: ...

def set(self, value: ParamDataType) -> None: ...
def set(self, value: ParameterDataTypeVar) -> None: ...

def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: ...

def get(self, get_if_invalid: bool = True) -> ParamDataType: ...
@overload
def get(self, get_if_invalid: Literal[True]) -> ParameterDataTypeVar: ...

@overload
def get(self) -> ParameterDataTypeVar: ...

@overload
def get(self, get_if_invalid: Literal[False]) -> ParameterDataTypeVar | None: ...

@overload
def get(self, get_if_invalid: bool) -> ParameterDataTypeVar | None: ...

def get(self, get_if_invalid: bool = True) -> ParameterDataTypeVar | None: ...

def _update_with(
self,
*,
value: ParamDataType,
value: ParameterDataTypeVar,
raw_value: ParamRawDataType,
timestamp: datetime | None = None,
) -> None: ...

def __call__(self) -> ParamDataType: ...
def __call__(self) -> ParameterDataTypeVar: ...


class _Cache:
class _Cache(Generic[ParameterDataTypeVar]):
"""
Cache object for parameter to hold its value and raw value

Expand All @@ -66,9 +86,11 @@ class _Cache:

"""

def __init__(self, parameter: ParameterBase, max_val_age: float | None = None):
def __init__(
self, parameter: ParameterBase, max_val_age: float | None = None
) -> None:
self._parameter = parameter
self._value: ParamDataType = None
self._value: ParameterDataTypeVar | None = None
self._raw_value: ParamRawDataType = None
self._timestamp: datetime | None = None
self._max_val_age = max_val_age
Expand Down Expand Up @@ -115,7 +137,7 @@ def invalidate(self) -> None:
"""
self._marked_valid = False

def set(self, value: ParamDataType) -> None:
def set(self, value: ParameterDataTypeVar) -> None:
"""
Set the cached value of the parameter without invoking the
``set_cmd`` of the parameter (if it has one). For example, in case of
Expand Down Expand Up @@ -146,7 +168,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None:
def _update_with(
self,
*,
value: ParamDataType,
value: ParameterDataTypeVar,
raw_value: ParamRawDataType,
timestamp: datetime | None = None,
) -> None:
Expand Down Expand Up @@ -187,7 +209,19 @@ def _timestamp_expired(self) -> bool:
# parameter is still valid
return False

def get(self, get_if_invalid: bool = True) -> ParamDataType:
@overload
def get(self, get_if_invalid: Literal[True]) -> ParameterDataTypeVar: ...

@overload
def get(self) -> ParameterDataTypeVar: ...

@overload
def get(self, get_if_invalid: Literal[False]) -> ParameterDataTypeVar | None: ...

@overload
def get(self, get_if_invalid: bool) -> ParameterDataTypeVar | None: ...

def get(self, get_if_invalid: bool = True) -> ParameterDataTypeVar | None:
"""
Return cached value if time since get was less than ``max_val_age``,
or the parameter was explicitly marked invalid.
Expand Down Expand Up @@ -246,7 +280,7 @@ def _construct_error_msg(self) -> str:
)
return error_msg

def __call__(self) -> ParamDataType:
def __call__(self) -> ParameterDataTypeVar:
"""
Same as :meth:`get` but always call ``get`` on parameter if the
cache is not valid
Expand Down
51 changes: 39 additions & 12 deletions src/qcodes/parameters/delegate_parameter.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Generic

from typing_extensions import TypeVar

from .parameter import Parameter
from .parameter_base import InstrumentType_co, ParameterDataTypeVar

if TYPE_CHECKING:
from collections.abc import Sequence
from datetime import datetime

from qcodes.instrument import InstrumentBase
from qcodes.validators.validators import Validator

from .parameter_base import ParamDataType, ParamRawDataType


class DelegateParameter(Parameter):
from .parameter_base import (
ParamDataType,
ParamRawDataType,
)

# Generic type variables for inner cache class
# these need to be different variables such that both classes can be generic
_local_ParameterDataTypeVar = TypeVar("_local_ParameterDataTypeVar", default=Any)
_local_InstrumentType_co = TypeVar(
"_local_InstrumentType_co",
bound="InstrumentBase | None",
default="InstrumentBase | None",
covariant=True,
)


class DelegateParameter(
Parameter[ParameterDataTypeVar, InstrumentType_co],
Generic[ParameterDataTypeVar, InstrumentType_co],
):
"""
The :class:`.DelegateParameter` wraps a given `source` :class:`Parameter`.
Setting/getting it results in a set/get of the source parameter with
Expand Down Expand Up @@ -51,8 +71,15 @@ class DelegateParameter(Parameter):

"""

class _DelegateCache:
def __init__(self, parameter: DelegateParameter):
class _DelegateCache(
Generic[_local_ParameterDataTypeVar, _local_InstrumentType_co]
):
def __init__(
self,
parameter: DelegateParameter[
_local_ParameterDataTypeVar, _local_InstrumentType_co
],
):
self._parameter = parameter
self._marked_valid: bool = False

Expand Down Expand Up @@ -99,7 +126,7 @@ def invalidate(self) -> None:
if self._parameter.source is not None:
self._parameter.source.cache.invalidate()

def get(self, get_if_invalid: bool = True) -> ParamDataType:
def get(self, get_if_invalid: bool = True) -> _local_ParameterDataTypeVar:
if self._parameter.source is None:
raise TypeError(
"Cannot get the cache of a DelegateParameter that delegates to None"
Expand All @@ -108,7 +135,7 @@ def get(self, get_if_invalid: bool = True) -> ParamDataType:
self._parameter.source.cache.get(get_if_invalid=get_if_invalid)
)

def set(self, value: ParamDataType) -> None:
def set(self, value: _local_ParameterDataTypeVar) -> None:
if self._parameter.source is None:
raise TypeError(
"Cannot set the cache of a DelegateParameter that delegates to None"
Expand All @@ -128,7 +155,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None:
def _update_with(
self,
*,
value: ParamDataType,
value: _local_ParameterDataTypeVar,
raw_value: ParamRawDataType,
timestamp: datetime | None = None,
) -> None:
Expand All @@ -142,7 +169,7 @@ def _update_with(
"""
pass

def __call__(self) -> ParamDataType:
def __call__(self) -> _local_ParameterDataTypeVar:
return self.get(get_if_invalid=True)

def __init__(
Expand Down Expand Up @@ -183,7 +210,7 @@ def __init__(
# i.e. _SetParamContext overrides it
self._settable = True

self.cache = self._DelegateCache(self)
self.cache = self._DelegateCache[ParameterDataTypeVar, InstrumentType_co](self)
if initial_cache_value is not None:
self.cache.set(initial_cache_value)

Expand Down
16 changes: 9 additions & 7 deletions src/qcodes/parameters/multi_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@

import os
from collections.abc import Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from typing import Any, Generic

import numpy as np

from .parameter_base import ParameterBase
from .parameter_base import InstrumentType_co, ParameterBase, ParameterDataTypeVar
from .sequence_helpers import is_sequence_of

if TYPE_CHECKING:
from qcodes.instrument import InstrumentBase

try:
from qcodes_loop.data.data_array import DataArray

Expand Down Expand Up @@ -50,7 +47,10 @@ def _is_nested_sequence_or_none(
return True


class MultiParameter(ParameterBase):
class MultiParameter(
ParameterBase[ParameterDataTypeVar, InstrumentType_co],
Generic[ParameterDataTypeVar, InstrumentType_co],
):
"""
A gettable parameter that returns multiple values with separate names,
each of arbitrary shape. Not necessarily part of an instrument.
Expand Down Expand Up @@ -141,7 +141,9 @@ def __init__(
name: str,
names: Sequence[str],
shapes: Sequence[Sequence[int]],
instrument: InstrumentBase | None = None,
# mypy seems to be confused here. The bound and default for InstrumentType_co
# contains None but mypy will now allow it as a default as of v 1.19.0
instrument: InstrumentType_co = None, # type: ignore[assignment]
labels: Sequence[str] | None = None,
units: Sequence[str] | None = None,
setpoints: Sequence[Sequence[Any]] | None = None,
Expand Down
28 changes: 20 additions & 8 deletions src/qcodes/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
import logging
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Generic, Literal

from .command import Command
from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType
from .parameter_base import (
InstrumentType_co,
ParameterBase,
ParameterDataTypeVar,
ParamRawDataType,
)
from .sweep_values import SweepFixedValues

if TYPE_CHECKING:
Expand All @@ -24,7 +29,10 @@
log = logging.getLogger(__name__)


class Parameter(ParameterBase):
class Parameter(
ParameterBase[ParameterDataTypeVar, InstrumentType_co],
Generic[ParameterDataTypeVar, InstrumentType_co],
):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to agree on this order of generic arguments before merging since changing this would be a breaking change for all Parameter usage and implementations.

Roughly there are four relevant generic parameters for an instrument.

  1. ParameterData type (input/output from get/set)
  2. InstrumentType (what self.instrument returns)
  3. RootInstrumentType (what self.root_instrument) returns
  4. RawParameterData. What is passed to self.write/ask etc

Here I propose that we add 1 and 2 but omit 3 and 4 for the time being.
That is a tradeoff between the features and the amount of parameters
that needs to be passed to each generic Parameter since currently if
one generic parameter is supplied to a Generic function/class they must all be supplied.

This, however, also means that we cannot add these in a non-breaking way before we drop 3.12 support.
In 3.13 onwards it becomes possible to omit values for type parameters that have default values so we can simply append these and omit them when they don't add value

"""
A parameter represents a single degree of freedom. Most often,
this is the standard parameter for Instruments, though it can also be
Expand Down Expand Up @@ -172,16 +180,18 @@ class Parameter(ParameterBase):
def __init__(
self,
name: str,
instrument: InstrumentBase | None = None,
# mypy seems to be confused here. The bound and default for InstrumentType_co
# contains None but mypy will now allow it as a default as of v 1.19.0
instrument: InstrumentType_co = None, # type: ignore[assignment]
label: str | None = None,
unit: str | None = None,
get_cmd: str | Callable[..., Any] | Literal[False] | None = None,
set_cmd: str | Callable[..., Any] | Literal[False] | None = False,
initial_value: float | str | None = None,
initial_value: ParameterDataTypeVar | None = None,
max_val_age: float | None = None,
vals: Validator[Any] | None = None,
docstring: str | None = None,
initial_cache_value: float | str | None = None,
initial_cache_value: ParameterDataTypeVar | None = None,
bind_to_instrument: bool = True,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -396,14 +406,16 @@ def __getitem__(self, keys: Any) -> SweepFixedValues:
"""
return SweepFixedValues(self, keys)

def increment(self, value: ParamDataType) -> None:
def increment(self, value: ParameterDataTypeVar) -> None:
"""Increment the parameter with a value

Args:
value: Value to be added to the parameter.

"""
self.set(self.get() + value)
# this method only works with parameters that support addition
# however we don't currently enforce that via typing
self.set(self.get() + value) # type: ignore[operator]

def sweep(
self,
Expand Down
Loading
Loading