Skip to content

Commit 3b42d08

Browse files
author
mark
committed
Adds a helper function.
1 parent 6c02a68 commit 3b42d08

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

axlearn/common/config.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class Config(ConfigBase):
6464
import copy
6565
import dataclasses
6666
import enum
67+
import importlib
6768
import inspect
6869
import re
6970
import types
@@ -255,28 +256,28 @@ def validate_config_field_name(name: str) -> None:
255256
match_fn=lambda v: not isinstance(v, type) and hasattr(v, "from_pretrained"),
256257
validate_fn=lambda v: validate_config_field_value(v.to_dict()),
257258
)
258-
# Register other validators for convenience and backwards compat.
259-
# We allow these to be optional to avoid a hard dependency.
260-
try:
261-
import numpy as np
262259

263-
register_validator(
264-
match_fn=lambda v: isinstance(v, np.dtype),
265-
validate_fn=lambda _: None,
266-
)
267-
except ImportError:
268-
pass
269260

270-
try:
271-
# As of 0.6.1, PartitionSpec is not longer a tuple.
272-
import jax
261+
def _maybe_register_optional_type(module: str, attribute: str):
262+
"""Attempts to register a valid type specified via `module` path and `attribute` name.
273263
274-
register_validator(
275-
match_fn=lambda v: isinstance(v, jax.sharding.PartitionSpec),
276-
validate_fn=lambda _: None,
277-
)
278-
except ImportError:
279-
pass
264+
This is used to register optional types so that we can avoid a hard dependency on the module.
265+
"""
266+
try:
267+
module = importlib.import_module(module)
268+
register_validator(
269+
match_fn=lambda v: isinstance(v, getattr(module, attribute)),
270+
validate_fn=lambda _: None,
271+
)
272+
except (ImportError, AttributeError):
273+
pass
274+
275+
276+
# Register other validators for convenience and backwards compat.
277+
# We allow these to be optional to avoid a hard dependency.
278+
_maybe_register_optional_type("numpy", "dtype")
279+
# As of 0.6.1, PartitionSpec is not longer a tuple.
280+
_maybe_register_optional_type("jax.sharding", "PartitionSpec")
280281

281282

282283
def validate_config_field_value(value: Any) -> None:

axlearn/common/config_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,7 @@ class Config(Configurable.Config):
10271027

10281028
def test_custom_validators(self):
10291029
try:
1030+
# pytype: disable=invalid-annotation
10301031
# pylint: disable-next=import-outside-toplevel
10311032
import jax
10321033

@@ -1037,6 +1038,7 @@ class Test(ConfigBase):
10371038
spec = jax.sharding.PartitionSpec("test")
10381039
cfg = Test(partition_spec=spec)
10391040
self.assertEqual(cfg.partition_spec, spec)
1041+
# pytype: enable=invalid-annotation
10401042

10411043
except ImportError:
10421044
pass

0 commit comments

Comments
 (0)