Skip to content

Commit 0af5ee9

Browse files
authored
Merge pull request #284 from scipp/type-hints
Fix some type hints and run mypy
2 parents 2a8009f + 71169a9 commit 0af5ee9

36 files changed

+862
-685
lines changed

.github/workflows/test.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,18 @@ jobs:
7777
with:
7878
name: CoverageReport
7979
path: coverage_html/
80+
81+
mypy:
82+
runs-on: ${{ inputs.os-variant }}
83+
steps:
84+
- uses: actions/checkout@v4
85+
with:
86+
ref: ${{ inputs.checkout_ref }}
87+
- uses: actions/setup-python@v5
88+
with:
89+
python-version: ${{ inputs.python-version }}
90+
- run: python -m pip install --upgrade pip
91+
# Use mypy manually because it does not respect the exclusion patterns
92+
# when called through tox.
93+
- run: python -m pip install -r requirements/mypy.txt -e .
94+
- run: python -m mypy .

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@
217217
# relative to this directory. They are copied after the builtin static files,
218218
# so a file named "default.css" will overwrite the builtin "default.css".
219219
html_static_path = ['_static']
220-
html_css_files = []
220+
html_css_files: list[str] = []
221221
html_js_files = ["anaconda-icon.js"]
222222

223223
# -- Options for HTMLHelp output ------------------------------------------

docs/user-guide/nexus-classes.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@
270270
"metadata": {},
271271
"source": [
272272
"In some cases the event data fields may be contained directly within an [NXdetector](#NXdetector).\n",
273-
"The event data can also be accessed from there:"
273+
"The event data can also be accessed from there:\n",
274+
"(This example only shows one of several datasets that comprise event data.)"
274275
]
275276
},
276277
{
@@ -280,7 +281,7 @@
280281
"metadata": {},
281282
"outputs": [],
282283
"source": [
283-
"f['entry/instrument/bank102']['events'][...]"
284+
"f['entry/instrument/bank102']['event_time_offset'][:1000]"
284285
]
285286
},
286287
{

pyproject.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,24 @@ enable_error_code = [
117117
]
118118
warn_unreachable = true
119119

120+
exclude = [
121+
# not yet updated for type hints:
122+
'^tests/',
123+
'^src/scippnexus/nxlog',
124+
'^src/scippnexus/nxsample',
125+
'^src/scippnexus/nxoff_geometry',
126+
'^src/scippnexus/nxevent_data',
127+
'^src/scippnexus/event_field',
128+
'^src/scippnexus/nxcylindrical_geometry',
129+
'^src/scippnexus/nxdata',
130+
'^src/scippnexus/_load',
131+
'^src/scippnexus/nxmonitor',
132+
'^src/scippnexus/nxdetector',
133+
'^src/scippnexus/application_definitions/',
134+
'^src/scippnexus/nxtransformations',
135+
'^src/scippnexus/base',
136+
]
137+
120138
[tool.codespell]
121139
ignore-words-list = [
122140
# Codespell wants "socioeconomic" which seems to be the standard spelling.

src/scippnexus/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from . import typing
1313
from ._load import load
14+
from .attrs import Attrs
1415
from .base import (
1516
Group,
1617
NexusStructureError,
@@ -19,7 +20,7 @@
1920
create_class,
2021
create_field,
2122
)
22-
from .field import Attrs, DependsOn, Field
23+
from .field import DependsOn, Field
2324
from .file import File
2425
from .nexus_classes import * # noqa: F403
2526
from .nxtransformations import TransformationChain, compute_positions, zip_pixel_offsets

src/scippnexus/_cache.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@
99
"""
1010

1111
# flake8: noqa: E501
12+
from __future__ import annotations
13+
14+
from collections.abc import Callable
1215
from types import GenericAlias
16+
from typing import Generic, TypeVar
1317

1418
_NOT_FOUND = object()
19+
R = TypeVar("R")
1520

1621

17-
class cached_property:
18-
def __init__(self, func):
22+
class cached_property(Generic[R]):
23+
def __init__(self, func: Callable[..., R]) -> None:
1924
self.func = func
20-
self.attrname = None
25+
self.attrname: str | None = None
2126
self.__doc__ = func.__doc__
2227

23-
def __set_name__(self, owner, name):
28+
def __set_name__(self, owner: object, name: str) -> None:
2429
if self.attrname is None:
2530
self.attrname = name
2631
elif name != self.attrname:
@@ -29,9 +34,9 @@ def __set_name__(self, owner, name):
2934
f"({self.attrname!r} and {name!r})."
3035
)
3136

32-
def __get__(self, instance, owner=None):
37+
def __get__(self, instance: object, owner: object = None) -> R:
3338
if instance is None:
34-
return self
39+
return self # type: ignore[return-value]
3540
if self.attrname is None:
3641
raise TypeError(
3742
"Cannot use cached_property instance without calling __set_name__ on it."
@@ -46,7 +51,7 @@ def __get__(self, instance, owner=None):
4651
f"instance to cache {self.attrname!r} property."
4752
)
4853
raise TypeError(msg) from None
49-
val = cache.get(self.attrname, _NOT_FOUND)
54+
val: R = cache.get(self.attrname, _NOT_FOUND)
5055
if val is _NOT_FOUND:
5156
val = self.func(instance)
5257
try:
@@ -59,4 +64,4 @@ def __get__(self, instance, owner=None):
5964
raise TypeError(msg) from None
6065
return val
6166

62-
__class_getitem__ = classmethod(GenericAlias)
67+
__class_getitem__ = classmethod(GenericAlias) # type: ignore[var-annotated]

src/scippnexus/_common.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
33
# @author Simon Heybrock
44

5+
from collections.abc import Sequence
6+
57
import numpy as np
68
import scipp as sc
79

@@ -10,8 +12,8 @@
1012

1113
def convert_time_to_datetime64(
1214
raw_times: sc.Variable,
13-
start: str | None = None,
14-
scaling_factor: float | np.float64 = None,
15+
start: sc.Variable,
16+
scaling_factor: float | np.float64 | None = None,
1517
) -> sc.Variable:
1618
"""
1719
The nexus standard allows an arbitrary scaling factor to be inserted
@@ -24,14 +26,16 @@ def convert_time_to_datetime64(
2426
2527
See https://manual.nexusformat.org/classes/base_classes/NXlog.html
2628
27-
Args:
28-
raw_times: The raw time data from a nexus file.
29-
start: Optional, the start time of the log in an ISO8601
30-
string. If not provided, defaults to the beginning of the
31-
unix epoch (1970-01-01T00:00:00).
32-
scaling_factor: Optional, the scaling factor between the provided
33-
time series data and the unit of the raw_times Variable. If
34-
not provided, defaults to 1 (a no-op scaling factor).
29+
Parameters
30+
----------
31+
raw_times:
32+
The raw time data from a nexus file.
33+
start:
34+
The start time of the log.
35+
scaling_factor:
36+
Optional, the scaling factor between the provided
37+
time series data and the unit of the raw_times Variable. If
38+
not provided, defaults to 1 (a no-op scaling factor).
3539
"""
3640
if (
3741
raw_times.dtype in (sc.DType.float64, sc.DType.float32)
@@ -53,10 +57,18 @@ def convert_time_to_datetime64(
5357
)
5458

5559

56-
def _to_canonical_select(dims: list[str], select: ScippIndex) -> dict[str, int | slice]:
60+
def has_time_unit(obj: sc.Variable) -> bool:
61+
if (unit := obj.unit) is None:
62+
return False
63+
return unit.to_dict().get('powers') == {'s': 1}
64+
65+
66+
def to_canonical_select(
67+
dims: Sequence[str], select: ScippIndex
68+
) -> dict[str, int | slice]:
5769
"""Return selection as dict with explicit dim labels"""
5870

59-
def check_1d():
71+
def check_1d() -> None:
6072
if len(dims) != 1:
6173
raise sc.DimensionError(
6274
f"Dataset has multiple dimensions {dims}, "
@@ -67,8 +79,8 @@ def check_1d():
6779
return {}
6880
if isinstance(select, tuple) and len(select) == 0:
6981
return {}
70-
if isinstance(select, tuple) and isinstance(select[0], str):
71-
key, sel = select
82+
if isinstance(select, tuple) and isinstance(select[0], str): # type: ignore[misc] # incorrect narrowing
83+
key, sel = select # type: ignore[misc] # incorrect narrowing
7284
return {key: sel}
7385
if isinstance(select, tuple):
7486
check_1d()
@@ -77,7 +89,7 @@ def check_1d():
7789
f"Dataset has single dimension {dims}, "
7890
"but multiple indices {select} were specified."
7991
)
80-
return {dims[0]: select[0]}
92+
return {dims[0]: select[0]} # type: ignore[unreachable] # incorrect narrowing
8193
elif isinstance(select, int | sc.Variable) or isinstance(select, slice):
8294
check_1d()
8395
return {dims[0]: select}
@@ -86,12 +98,14 @@ def check_1d():
8698
return select.copy()
8799

88100

89-
def to_plain_index(dims: list[str], select: ScippIndex) -> int | slice | tuple:
101+
def to_plain_index(
102+
dims: Sequence[str], select: ScippIndex
103+
) -> int | slice | tuple[int | slice, ...]:
90104
"""
91105
Given a valid "scipp" index 'select', return an equivalent plain numpy-style index.
92106
"""
93-
select = _to_canonical_select(dims, select)
94-
index = [slice(None)] * len(dims)
107+
select = to_canonical_select(dims, select)
108+
index: list[int | slice] = [slice(None)] * len(dims)
95109
for key, sel in select.items():
96110
if key not in dims:
97111
raise sc.DimensionError(
@@ -104,8 +118,8 @@ def to_plain_index(dims: list[str], select: ScippIndex) -> int | slice | tuple:
104118

105119

106120
def to_child_select(
107-
dims: list[str],
108-
child_dims: list[str],
121+
dims: Sequence[str],
122+
child_dims: Sequence[str],
109123
select: ScippIndex,
110124
bin_edge_dim: str | None = None,
111125
) -> ScippIndex:
@@ -115,7 +129,7 @@ def to_child_select(
115129
116130
This removes any selections that apply to the parent but not the child.
117131
"""
118-
select = _to_canonical_select(dims, select)
132+
select = to_canonical_select(dims, select)
119133
for d in dims:
120134
if d not in child_dims and d in select:
121135
del select[d]

src/scippnexus/_load.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
33
import contextlib
44
import io
5+
from contextlib import AbstractContextManager
56
from os import PathLike
7+
from typing import Any
68

79
import h5py as h5
810
import scipp as sc
@@ -23,7 +25,11 @@ def load(
2325
root: str | None = None,
2426
select: ScippIndex = (),
2527
definitions: Definitions | DefaultDefinitionsType = DefaultDefinitions,
26-
) -> sc.DataGroup | sc.DataArray | sc.Dataset:
28+
) -> (
29+
sc.DataGroup[sc.DataGroup[Any] | sc.DataArray | sc.Dataset]
30+
| sc.DataArray
31+
| sc.Dataset
32+
):
2733
"""Load a NeXus file.
2834
2935
This function is a shorthand for opening a file manually.
@@ -98,7 +104,7 @@ def load(
98104
def _open(
99105
filename: str | PathLike[str] | io.BytesIO | h5.Group | Group,
100106
definitions: Definitions | DefaultDefinitionsType = DefaultDefinitions,
101-
):
107+
) -> AbstractContextManager[Group | File]:
102108
if isinstance(filename, h5.Group):
103109
return contextlib.nullcontext(
104110
Group(
Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
33
# @author Matthew Jones
44
import warnings
5-
from typing import Any
65

76
import h5py
8-
import numpy as np
97

108

11-
def _cset_to_encoding(cset: int) -> str:
9+
def cset_to_encoding(cset: int) -> str:
1210
"""
1311
Converts a HDF5 cset into a python encoding. Allowed values for cset are
1412
h5py.h5t.CSET_ASCII and h5py.h5t.CSET_UTF8.
@@ -32,9 +30,9 @@ def _cset_to_encoding(cset: int) -> str:
3230
)
3331

3432

35-
def _warn_latin1_decode(obj, decoded, error):
33+
def warn_latin1_decode(obj: object, decoded: str, error: Exception) -> None:
3634
warnings.warn(
37-
f"Encoding for bytes '{obj}' declared as ascii, "
35+
f"Encoding for bytes '{obj!r}' declared as ascii, "
3836
f"but contains characters in extended ascii range. Assuming "
3937
f"extended ASCII (latin-1), but this behavior is not "
4038
f"specified by the HDF5 or nexus standards and may therefore "
@@ -44,7 +42,7 @@ def _warn_latin1_decode(obj, decoded, error):
4442
)
4543

4644

47-
def _ensure_str(str_or_bytes: str | bytes, encoding: str) -> str:
45+
def ensure_str(str_or_bytes: str | bytes, encoding: str) -> str:
4846
"""
4947
See https://docs.h5py.org/en/stable/strings.html for justification about some of
5048
the operations performed in this method. In particular, variable-length strings
@@ -66,21 +64,7 @@ def _ensure_str(str_or_bytes: str | bytes, encoding: str) -> str:
6664
return str(str_or_bytes, encoding="ascii")
6765
except UnicodeDecodeError as e:
6866
decoded = str(str_or_bytes, encoding="latin-1")
69-
_warn_latin1_decode(str_or_bytes, decoded, str(e))
67+
warn_latin1_decode(str_or_bytes, decoded, e)
7068
return decoded
7169
else:
7270
return str(str_or_bytes, encoding)
73-
74-
75-
_map_to_supported_type = {
76-
'int8': np.int32,
77-
'int16': np.int32,
78-
'uint8': np.int32,
79-
'uint16': np.int32,
80-
'uint32': np.int32,
81-
'uint64': np.int64,
82-
}
83-
84-
85-
def _ensure_supported_int_type(dataset_type: Any):
86-
return _map_to_supported_type.get(dataset_type, dataset_type)

src/scippnexus/attrs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from collections.abc import Iterator, Mapping
55
from typing import Any
66

7-
from ._hdf5_nexus import _cset_to_encoding, _ensure_str
7+
import h5py as h5
88

9+
from ._string import cset_to_encoding, ensure_str
910

10-
class Attrs(Mapping):
11-
def __init__(self, attrs: Mapping):
11+
12+
class Attrs(Mapping[str, Any]):
13+
def __init__(self, attrs: h5.AttributeManager) -> None:
1214
self._base_attrs = attrs
1315
self._attrs = dict(attrs)
1416

@@ -17,7 +19,7 @@ def __getitem__(self, name: str) -> Any:
1719
# Is this check for string attributes sufficient? Is there a better way?
1820
if isinstance(attr, str | bytes):
1921
cset = self._base_attrs.get_id(name.encode("utf-8")).get_type().get_cset()
20-
return _ensure_str(attr, _cset_to_encoding(cset))
22+
return ensure_str(attr, cset_to_encoding(cset))
2123
return attr
2224

2325
def __iter__(self) -> Iterator[str]:

0 commit comments

Comments
 (0)