@@ -64,6 +64,7 @@ class Config(ConfigBase):
6464import copy
6565import dataclasses
6666import enum
67+ import importlib
6768import inspect
6869import re
6970import types
@@ -255,28 +256,28 @@ def validate_config_field_name(name: str) -> None:
255256 match_fn = lambda v : not isinstance (v , type ) and hasattr (v , "from_pretrained" ),
256257 validate_fn = lambda v : validate_config_field_value (v .to_dict ()),
257258)
258- # Register other validators for convenience and backwards compat.
259- # We allow these to be optional to avoid a hard dependency.
260- try :
261- import numpy as np
262259
263- register_validator (
264- match_fn = lambda v : isinstance (v , np .dtype ),
265- validate_fn = lambda _ : None ,
266- )
267- except ImportError :
268- pass
269260
270- try :
271- # As of 0.6.1, PartitionSpec is not longer a tuple.
272- import jax
261+ def _maybe_register_optional_type (module : str , attribute : str ):
262+ """Attempts to register a valid type specified via `module` path and `attribute` name.
273263
274- register_validator (
275- match_fn = lambda v : isinstance (v , jax .sharding .PartitionSpec ),
276- validate_fn = lambda _ : None ,
277- )
278- except ImportError :
279- pass
264+ This is used to register optional types so that we can avoid a hard dependency on the module.
265+ """
266+ try :
267+ module = importlib .import_module (module )
268+ register_validator (
269+ match_fn = lambda v : isinstance (v , getattr (module , attribute )),
270+ validate_fn = lambda _ : None ,
271+ )
272+ except (ImportError , AttributeError ):
273+ pass
274+
275+
276+ # Register other validators for convenience and backwards compat.
277+ # We allow these to be optional to avoid a hard dependency.
278+ _maybe_register_optional_type ("numpy" , "dtype" )
279+ # As of 0.6.1, PartitionSpec is not longer a tuple.
280+ _maybe_register_optional_type ("jax.sharding" , "PartitionSpec" )
280281
281282
282283def validate_config_field_value (value : Any ) -> None :
0 commit comments