diff --git a/nums/api.py b/nums/api.py index 5d7bae84..e13ffd96 100644 --- a/nums/api.py +++ b/nums/api.py @@ -25,10 +25,14 @@ def init( address: Optional[str] = None, num_cpus: Optional[int] = None, cluster_shape: Optional[tuple] = None, + backend: Optional[str] = None, ): # pylint: disable = import-outside-toplevel import nums.core.settings as settings + if backend is not None: + assert backend in {"serial", "ray", "dask", "mpi"} + settings.backend_name = backend if cluster_shape is not None: settings.cluster_shape = cluster_shape settings.num_cpus = num_cpus diff --git a/tests/test_api.py b/tests/test_api.py index abb5cee1..8f459e98 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -24,6 +24,23 @@ # pylint: disable=import-outside-toplevel +def test_init(): + import nums + from nums.core import application_manager + from nums.core.backends.serial import SerialBackend + from nums.core.backends.ray import RayBackend + + nums.init(backend="serial") + instance = application_manager.instance() + assert isinstance(instance.km.backend, SerialBackend) + application_manager.destroy() + + nums.init(backend="ray") + instance = application_manager.instance() + assert isinstance(instance.km.backend, RayBackend) + application_manager.destroy() + + def test_rwd(): import nums from nums.core import application_manager