Skip to content

Commit 8d1a9b2

Browse files
authored
Merge pull request #7065 from jenshnielsen/refactor_doxd
Replace do0d and do1 with a wrapper around dond
2 parents 6fdb4f7 + ce875c0 commit 8d1a9b2

File tree

4 files changed

+66
-198
lines changed

4 files changed

+66
-198
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
The implementation of ``do0d`` and ``do1d`` have been replaced with a wrapper around `dond`.
2+
This aligns the keyword arguments with ``dond`` and ensures that these function support
3+
the same features as ``dond``. The same change is planned for ``do2d`` in the future.

src/qcodes/dataset/dond/do_0d.py

Lines changed: 19 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5,88 +5,46 @@
55

66
from opentelemetry import trace
77

8-
from qcodes import config
9-
from qcodes.parameters import ParameterBase
8+
from .do_nd import DondKWargs, dond
109

11-
from ..descriptions.detect_shapes import detect_shape_of_measurement
12-
from ..measurements import Measurement
13-
from ..threading import process_params_meas
14-
from .do_nd_utils import _handle_plotting, _register_parameters, _set_write_period
10+
if TYPE_CHECKING:
11+
from typing_extensions import Unpack
12+
13+
from .do_nd_utils import (
14+
AxesTupleListWithDataSet,
15+
ParamMeasT,
16+
)
1517

1618
LOG = logging.getLogger(__name__)
1719
TRACER = trace.get_tracer(__name__)
1820

19-
if TYPE_CHECKING:
20-
from ..descriptions.versioning.rundescribertypes import Shapes
21-
from ..experiment_container import Experiment
22-
from .do_nd_utils import AxesTupleListWithDataSet, ParamMeasT
23-
2421

2522
@TRACER.start_as_current_span("qcodes.dataset.do0d")
2623
def do0d(
27-
*param_meas: ParamMeasT,
28-
write_period: float | None = None,
29-
measurement_name: str = "",
30-
exp: Experiment | None = None,
31-
do_plot: bool | None = None,
32-
use_threads: bool | None = None,
33-
log_info: str | None = None,
24+
*param_meas: ParamMeasT, **kwargs: Unpack[DondKWargs]
3425
) -> AxesTupleListWithDataSet:
3526
"""
3627
Perform a measurement of a single parameter. This is probably most
37-
useful for an ArrayParameter that already returns an array of data points
28+
useful for a ParameterWithSetpoints that already returns an array of data points.
3829
3930
Args:
4031
*param_meas: Parameter(s) to measure at each step or functions that
4132
will be called at each step. The function should take no arguments.
4233
The parameters and functions are called in the order they are
4334
supplied.
44-
write_period: The time after which the data is actually written to the
45-
database.
46-
measurement_name: Name of the measurement. This will be passed down to
47-
the dataset produced by the measurement. If not given, a default
48-
value of 'results' is used for the dataset.
49-
exp: The experiment to use for this measurement.
50-
do_plot: should png and pdf versions of the images be saved after the
51-
run. If None the setting will be read from ``qcodesrc.json``
52-
use_threads: If True measurements from each instrument will be done on
53-
separate threads. If you are measuring from several instruments
54-
this may give a significant speedup.
55-
log_info: Message that is logged during the measurement. If None a default
56-
message is used.
35+
**kwargs: kwargs are the same as for :func:`dond` and forwarded directly to :func:`dond`.
5736
5837
Returns:
5938
The QCoDeS dataset.
6039
6140
"""
62-
if do_plot is None:
63-
do_plot = cast("bool", config.dataset.dond_plot)
64-
meas = Measurement(name=measurement_name, exp=exp)
65-
if log_info is not None:
66-
meas._extra_log_info = log_info
67-
else:
68-
meas._extra_log_info = "Using 'qcodes.dataset.do0d'"
6941

70-
measured_parameters = tuple(
71-
param for param in param_meas if isinstance(param, ParameterBase)
72-
)
73-
74-
try:
75-
shapes: Shapes | None = detect_shape_of_measurement(
76-
measured_parameters,
77-
)
78-
except TypeError:
79-
LOG.exception(
80-
f"Could not detect shape of {measured_parameters} "
81-
f"falling back to unknown shape."
82-
)
83-
shapes = None
84-
85-
_register_parameters(meas, param_meas, shapes=shapes)
86-
_set_write_period(meas, write_period)
42+
kwargs.setdefault("log_info", "Using 'qcodes.dataset.do0d'")
8743

88-
with meas.run() as datasaver:
89-
datasaver.add_result(*process_params_meas(param_meas, use_threads=use_threads))
90-
dataset = datasaver.dataset
91-
92-
return _handle_plotting(dataset, do_plot)
44+
# since we only support entering parameters
45+
# as a simple list or args we are sure to always
46+
# get back a AxesTupleListWithDataSet and cast is safe
47+
return cast(
48+
"AxesTupleListWithDataSet",
49+
dond(*param_meas, **kwargs, squeeze=True),
50+
)

src/qcodes/dataset/dond/do_1d.py

Lines changed: 26 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,24 @@
11
from __future__ import annotations
22

33
import logging
4-
import sys
5-
import time
64
from typing import TYPE_CHECKING, cast
75

8-
import numpy as np
96
from opentelemetry import trace
10-
from tqdm.auto import tqdm
117

12-
from qcodes import config
13-
from qcodes.dataset.descriptions.detect_shapes import detect_shape_of_measurement
14-
from qcodes.dataset.dond.do_nd_utils import (
15-
BreakConditionInterrupt,
16-
_handle_plotting,
17-
_register_actions,
18-
_register_parameters,
19-
_set_write_period,
20-
catch_interrupts,
21-
)
22-
from qcodes.dataset.measurements import Measurement
23-
from qcodes.dataset.threading import (
24-
SequentialParamsCaller,
25-
ThreadPoolParamsCaller,
26-
process_params_meas,
27-
)
28-
from qcodes.parameters import ParameterBase
29-
30-
LOG = logging.getLogger(__name__)
31-
TRACER = trace.get_tracer(__name__)
8+
from .do_nd import DondKWargs, dond
9+
from .sweeps import LinSweep
3210

3311
if TYPE_CHECKING:
34-
from collections.abc import Sequence
12+
from typing_extensions import Unpack
3513

36-
from qcodes.dataset.descriptions.versioning.rundescribertypes import Shapes
3714
from qcodes.dataset.dond.do_nd_utils import (
38-
ActionsT,
3915
AxesTupleListWithDataSet,
40-
BreakConditionT,
4116
ParamMeasT,
4217
)
43-
from qcodes.dataset.experiment_container import Experiment
18+
from qcodes.parameters import ParameterBase
19+
20+
LOG = logging.getLogger(__name__)
21+
TRACER = trace.get_tracer(__name__)
4422

4523

4624
@TRACER.start_as_current_span("qcodes.dataset.do1d")
@@ -51,17 +29,7 @@ def do1d(
5129
num_points: int,
5230
delay: float,
5331
*param_meas: ParamMeasT,
54-
enter_actions: ActionsT = (),
55-
exit_actions: ActionsT = (),
56-
write_period: float | None = None,
57-
measurement_name: str = "",
58-
exp: Experiment | None = None,
59-
do_plot: bool | None = None,
60-
use_threads: bool | None = None,
61-
additional_setpoints: Sequence[ParameterBase] = tuple(),
62-
show_progress: bool | None = None,
63-
log_info: str | None = None,
64-
break_condition: BreakConditionT | None = None,
32+
**kwargs: Unpack[DondKWargs],
6533
) -> AxesTupleListWithDataSet:
6634
"""
6735
Perform a 1D scan of ``param_set`` from ``start`` to ``stop`` in
@@ -74,106 +42,30 @@ def do1d(
7442
stop: End point of sweep
7543
num_points: Number of points in sweep
7644
delay: Delay after setting parameter before measurement is performed
77-
param_meas: Parameter(s) to measure at each step or functions that
45+
*param_meas: Parameter(s) to measure at each step or functions that
7846
will be called at each step. The function should take no arguments.
7947
The parameters and functions are called in the order they are
8048
supplied.
81-
enter_actions: A list of functions taking no arguments that will be
82-
called before the measurements start
83-
exit_actions: A list of functions taking no arguments that will be
84-
called after the measurements ends
85-
write_period: The time after which the data is actually written to the
86-
database.
87-
additional_setpoints: A list of setpoint parameters to be registered in
88-
the measurement but not scanned.
89-
measurement_name: Name of the measurement. This will be passed down to
90-
the dataset produced by the measurement. If not given, a default
91-
value of 'results' is used for the dataset.
92-
exp: The experiment to use for this measurement.
93-
do_plot: should png and pdf versions of the images be saved after the
94-
run. If None the setting will be read from ``qcodesrc.json``
95-
use_threads: If True measurements from each instrument will be done on
96-
separate threads. If you are measuring from several instruments
97-
this may give a significant speedup.
98-
show_progress: should a progress bar be displayed during the
99-
measurement. If None the setting will be read from ``qcodesrc.json``
100-
log_info: Message that is logged during the measurement. If None a default
101-
message is used.
102-
break_condition: Callable that takes no arguments. If returned True,
103-
measurement is interrupted.
49+
**kwargs: kwargs are the same as for :func:`dond` and forwarded directly to :func:`dond`.
10450
10551
Returns:
10652
The QCoDeS dataset.
10753
10854
"""
109-
if do_plot is None:
110-
do_plot = cast("bool", config.dataset.dond_plot)
111-
if show_progress is None:
112-
show_progress = config.dataset.dond_show_progress
113-
114-
meas = Measurement(name=measurement_name, exp=exp)
115-
if log_info is not None:
116-
meas._extra_log_info = log_info
117-
else:
118-
meas._extra_log_info = "Using 'qcodes.dataset.do1d'"
119-
120-
all_setpoint_params = (param_set, *tuple(s for s in additional_setpoints))
121-
122-
measured_parameters = tuple(
123-
param for param in param_meas if isinstance(param, ParameterBase)
55+
kwargs.setdefault("log_info", "Using 'qcodes.dataset.do1d'")
56+
57+
return cast(
58+
"AxesTupleListWithDataSet",
59+
dond(
60+
LinSweep(
61+
param=param_set,
62+
start=start,
63+
stop=stop,
64+
delay=delay,
65+
num_points=num_points,
66+
),
67+
*param_meas,
68+
**kwargs,
69+
squeeze=True,
70+
),
12471
)
125-
try:
126-
loop_shape = (num_points, *tuple(1 for _ in additional_setpoints))
127-
shapes: Shapes | None = detect_shape_of_measurement(
128-
measured_parameters, loop_shape
129-
)
130-
except TypeError:
131-
LOG.exception(
132-
f"Could not detect shape of {measured_parameters} "
133-
f"falling back to unknown shape."
134-
)
135-
shapes = None
136-
137-
_register_parameters(meas, all_setpoint_params)
138-
_register_parameters(meas, param_meas, setpoints=all_setpoint_params, shapes=shapes)
139-
_set_write_period(meas, write_period)
140-
_register_actions(meas, enter_actions, exit_actions)
141-
142-
if use_threads is None:
143-
use_threads = config.dataset.use_threads
144-
145-
param_meas_caller = (
146-
ThreadPoolParamsCaller(*param_meas)
147-
if use_threads
148-
else SequentialParamsCaller(*param_meas)
149-
)
150-
151-
# do1D enforces a simple relationship between measured parameters
152-
# and set parameters. For anything more complicated this should be
153-
# reimplemented from scratch
154-
with (
155-
catch_interrupts() as interrupted,
156-
meas.run() as datasaver,
157-
param_meas_caller as call_param_meas,
158-
):
159-
dataset = datasaver.dataset
160-
additional_setpoints_data = process_params_meas(additional_setpoints)
161-
setpoints = np.linspace(start, stop, num_points)
162-
163-
# flush to prevent unflushed print's to visually interrupt tqdm bar
164-
# updates
165-
sys.stdout.flush()
166-
sys.stderr.flush()
167-
168-
for set_point in tqdm(setpoints, disable=not show_progress):
169-
param_set.set(set_point)
170-
time.sleep(delay)
171-
datasaver.add_result(
172-
(param_set, set_point), *call_param_meas(), *additional_setpoints_data
173-
)
174-
175-
if callable(break_condition):
176-
if break_condition():
177-
raise BreakConditionInterrupt("Break condition was met.")
178-
179-
return _handle_plotting(dataset, do_plot, interrupted())

src/qcodes/dataset/dond/do_nd.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Callable, Mapping, Sequence
77
from contextlib import ExitStack
88
from dataclasses import dataclass
9-
from typing import TYPE_CHECKING, Any, Literal, cast, overload
9+
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast, overload
1010

1111
import numpy as np
1212
from opentelemetry import trace
@@ -33,8 +33,6 @@
3333

3434
from .sweeps import AbstractSweep, TogetherSweep
3535

36-
LOG = logging.getLogger(__name__)
37-
3836
if TYPE_CHECKING:
3937
from qcodes.dataset.descriptions.versioning.rundescribertypes import Shapes
4038
from qcodes.dataset.dond.do_nd_utils import (
@@ -46,6 +44,7 @@
4644
)
4745
from qcodes.dataset.experiment_container import Experiment
4846

47+
LOG = logging.getLogger(__name__)
4948
SweepVarType = Any
5049

5150
TRACER = trace.get_tracer(__name__)
@@ -567,6 +566,22 @@ def parameters(self) -> tuple[ParameterBase, ...]:
567566
return self._parameters
568567

569568

569+
class DondKWargs(TypedDict):
570+
write_period: NotRequired[float | None]
571+
measurement_name: NotRequired[str | Sequence[str]]
572+
exp: NotRequired[Experiment | Sequence[Experiment] | None]
573+
enter_actions: NotRequired[ActionsT]
574+
exit_actions: NotRequired[ActionsT]
575+
do_plot: NotRequired[bool | None]
576+
show_progress: NotRequired[bool | None]
577+
use_threads: NotRequired[bool | None]
578+
additional_setpoints: NotRequired[Sequence[ParameterBase]]
579+
log_info: NotRequired[str | None]
580+
break_condition: NotRequired[BreakConditionT | None]
581+
dataset_dependencies: NotRequired[Mapping[str, Sequence[ParamMeasT]]]
582+
in_memory_cache: NotRequired[bool | None]
583+
584+
570585
@overload
571586
def dond(
572587
*params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT],

0 commit comments

Comments
 (0)