2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
@@ -41,7 +41,7 @@
"ghcr.io/schlich/devcontainer-features/powerlevel10k:1": {},
"ghcr.io/nils-geistmann/devcontainers-features/zsh:0": {
"setLocale": true,
"theme": "agnoster",
"theme": "robbyrussell",
"plugins": "git docker",
"desiredLocale": "en_US.UTF-8 UTF-8"
},
8 changes: 7 additions & 1 deletion conftest.py
@@ -46,6 +46,11 @@ def integrators(request):
return request.param


@pytest.fixture(scope='function', params=[None] if "torch" in available_backends() else [])
def pytorch_only(request):
return request.param


def pytest_generate_tests(metafunc: pytest.Metafunc):
autodiff_needed = "requires_autodiff" in metafunc.fixturenames

@@ -58,7 +63,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
argnames = sorted([key for key in argvalues_map if key in metafunc.fixturenames])
argvalues = list(itertools.product(*[argvalues_map[key] for key in argnames]))

if "dtype_var" and "backend_var" in metafunc.fixturenames:
if "dtype_var" in metafunc.fixturenames and "backend_var" in metafunc.fixturenames:
if np.finfo(np.longdouble).bits > np.finfo(np.float64).bits:
expansion_map = {
"dtype_var": ["longdouble"],
@@ -82,4 +87,5 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
raise TypeError("Test configuration requests autodiff, but no dynamic backend specified!")
argnames.append("requires_autodiff")
argvalues = [(*aval, True) for aval in argvalues if len(aval) > 1 and aval[1] not in ["numpy"]]

metafunc.parametrize(argnames, argvalues)
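A note on the condition fix above: in the old form the string literal "dtype_var" is always truthy, so the check silently reduced to testing only "backend_var". A minimal standalone illustration (the fixture list below is made up):

```python
fixturenames = ["backend_var", "integrators"]

# Old check: `in` binds tighter than `and`, so this evaluates as
# "dtype_var" and ("backend_var" in fixturenames) -> True,
# even though dtype_var is not requested at all.
old_check = "dtype_var" and "backend_var" in fixturenames      # True

# Fixed check: both fixtures must actually be requested.
new_check = ("dtype_var" in fixturenames
             and "backend_var" in fixturenames)                # False

print(old_check, new_check)
```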
5 changes: 4 additions & 1 deletion desolver/__init__.py
@@ -5,4 +5,7 @@

from desolver.differential_system import *

-from desolver.integrators import available_methods
+from desolver.integrators import available_methods

if backend.is_backend_available("torch"):
from desolver import torch_ext
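The guarded import above means torch_ext only appears on the package when torch could be imported. A hedged sketch of how downstream code might check for it (the hasattr pattern is illustrative, not part of this diff):

```python
import desolver as de

# desolver.torch_ext is only attached when the torch backend imported cleanly,
# so optional use can be guarded with a plain attribute check.
if hasattr(de, "torch_ext"):
    print("torch extensions available")
else:
    print("numpy-only installation")
```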
31 changes: 12 additions & 19 deletions desolver/backend/autoray_backend.py
@@ -7,34 +7,27 @@

@lru_cache(maxsize=32, typed=False)
 def epsilon(dtype: str|np.dtype):
-    if isinstance(dtype, str) and 'numpy' in dtype:
+    if isinstance(dtype, str) and 'torch' not in dtype:
         dtype = np.dtype(dtype)
-    if dtype in (np.half, np.single, np.double, np.longdouble):
+    try:
         return np.finfo(dtype).eps*4
-    elif 'torch' in str(dtype):
-        import torch
-        return torch.finfo(dtype).eps*4
-    else:
-        return 4e-14
+    except:
+        if 'torch' in str(dtype):
+            import torch
+            return torch.finfo(dtype).eps*4
+        else:
+            return 4e-14


@lru_cache(maxsize=32, typed=False)
 def tol_epsilon(dtype: str|np.dtype):
-    if isinstance(dtype, str) and 'numpy' in dtype:
-        dtype = np.dtype(dtype)
-    if dtype in (np.half, np.single, np.double, np.longdouble):
-        return np.finfo(dtype).eps*32
-    elif 'torch' in str(dtype):
-        import torch
-        return torch.finfo(dtype).eps*32
-    else:
-        return 32e-14
+    return 8*epsilon(dtype)


@lru_cache(maxsize=32, typed=False)
 def backend_like_dtype(dtype: str|np.dtype):
-    if (isinstance(dtype, str) and 'numpy' in dtype) or isinstance(dtype, np.dtype):
-        return 'numpy'
-    elif 'torch' in str(dtype):
+    if 'torch' in str(dtype):
         return 'torch'
     else:
         return 'numpy'
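A quick sanity check of the refactored tolerances, assuming epsilon and tol_epsilon are importable directly from desolver.backend.autoray_backend as laid out in this file:

```python
import numpy as np

from desolver.backend.autoray_backend import epsilon, tol_epsilon

# epsilon returns 4x machine epsilon for a recognised dtype, and tol_epsilon
# is now simply 8*epsilon, i.e. 32x machine epsilon, matching the value the
# old hand-written branches produced.
assert epsilon("float64") == 4 * np.finfo(np.float64).eps
assert tol_epsilon("float64") == 32 * np.finfo(np.float64).eps
```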

7 changes: 7 additions & 0 deletions desolver/backend/load_backend.py
@@ -1,11 +1,14 @@
import sys
import einops

__AVAILABLE_BACKENDS__ = ["numpy"]

from desolver.backend.common import *
from desolver.backend.autoray_backend import *
from desolver.backend.numpy_backend import *
try:
from desolver.backend.torch_backend import *
__AVAILABLE_BACKENDS__.append("torch")
except ImportError:
pass

@@ -65,3 +68,7 @@ def contract_first_ndims(a, b, n=1):
estr3 = "..."
einsum_str = einsum_str.format(estr1, estr2, estr3)
return einops.einsum(a, b, einsum_str)


def is_backend_available(backend_name):
return backend_name.lower().strip() in __AVAILABLE_BACKENDS__
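For reference, a small usage sketch of the new registry check; the import path follows the desolver.backend usage in __init__.py above, and the specific calls are illustrative:

```python
from desolver import backend

# "numpy" is always registered; "torch" only if desolver.backend.torch_backend
# imported without error.
assert backend.is_backend_available("numpy")

# The lookup is case- and whitespace-insensitive thanks to .lower().strip().
assert backend.is_backend_available("  NumPy ")
print("torch available:", backend.is_backend_available("torch"))
```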
6 changes: 5 additions & 1 deletion desolver/backend/numpy_backend.py
@@ -5,6 +5,7 @@
import scipy.special
import scipy.sparse
import scipy.sparse.linalg
import scipy.linalg
import autoray
import contextlib

@@ -13,7 +14,10 @@ def __solve_linear_system(A,b,overwrite_a=False,overwrite_b=False,check_finite=F
if sparse and A.dtype not in (numpy.half, numpy.longdouble) and b.dtype not in (numpy.half, numpy.longdouble):
return scipy.sparse.linalg.spsolve(scipy.sparse.csc_matrix(A),b)
else:
-return scipy.linalg.solve(A,b,overwrite_a=overwrite_a,overwrite_b=overwrite_b,check_finite=check_finite)
+try:
+    return scipy.linalg.solve(A,b,overwrite_a=overwrite_a,overwrite_b=overwrite_b,check_finite=check_finite)
+except numpy.linalg.LinAlgError:
+    return scipy.linalg.lstsq(A,b,overwrite_a=overwrite_a,overwrite_b=overwrite_b,check_finite=check_finite)[0]


autoray.register_function("numpy", "solve_linear_system", __solve_linear_system)
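The dense branch now degrades to a least-squares solution instead of raising on a singular matrix. A standalone sketch of the same pattern with a deliberately rank-deficient system (the matrices are illustrative):

```python
import numpy
import scipy.linalg

A = numpy.array([[1.0, 2.0],
                 [2.0, 4.0]])        # rank-deficient: second row is 2x the first
b = numpy.array([1.0, 2.0])

try:
    x = scipy.linalg.solve(A, b)     # raises LinAlgError for a singular matrix
except numpy.linalg.LinAlgError:
    x = scipy.linalg.lstsq(A, b)[0]  # minimum-norm least-squares solution

print(x)                             # approximately [0.2, 0.4]
```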
35 changes: 27 additions & 8 deletions desolver/backend/torch_backend.py
@@ -5,14 +5,33 @@

linear_algebra_exceptions.append(torch._C._LinAlgError)

-def __solve_linear_system(A, b, sparse=False):
-    __A = A
-    __b = b
-    if __A.dtype in (torch.float16, torch.bfloat16):
-        __A = __A.float()
-    if __b.dtype in (torch.float16, torch.bfloat16):
-        __b = __b.float()
-    return torch.linalg.solve(__A, __b).to(A.dtype)
+
+def __solve_linear_system(A:torch.Tensor, b:torch.Tensor, sparse=False):
+    """Solves a linear system either exactly when A is invertible, or
+    approximately when A is not invertible"""
+    if b.dtype in {torch.float16, torch.bfloat16}:
+        return __solve_linear_system(A.to(torch.float32), b.to(torch.float32), sparse=sparse).to(b.dtype)
+    eps_threshold = torch.finfo(b.dtype).eps**0.5
+    soln = torch.empty_like(A[...,0,:,None])
+    is_square = A.shape[-2] == A.shape[-1]
+    if is_square:
+        use_solve = torch.linalg.det(A).abs() > eps_threshold
+    else:
+        use_solve = torch.zeros_like(soln[...,0,0], dtype=torch.bool)
+    info = torch.ones_like(use_solve, dtype=torch.int)
+    soln, info = torch.linalg.solve_ex(A, b, check_errors=False)
+    use_solve = use_solve & ((info == 0) | torch.all(torch.isfinite(soln[...,0]), dim=-1))
+    use_svd = ~use_solve
+    U,S,Vh = torch.linalg.svd(A, full_matrices=is_square)
+    if A.dim() == 2:
+        soln = (Vh.mT @ torch.linalg.pinv(torch.diag_embed(S)) @ U.mT @ b)
+    else:
+        soln = torch.where(
+            use_svd[...,None,None],
+            torch.bmm(torch.bmm(torch.bmm(Vh.mT, torch.linalg.pinv(torch.diag_embed(S))), U.mT), b),
+            soln,
+        )
+    return soln


def to_cpu_wrapper(fn):
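A much-reduced standalone sketch of the idea behind the new function for the plain 2-D case: try the exact solver via torch.linalg.solve_ex, and fall back to a pseudo-inverse solve when it reports failure. The helper name is made up and the fallback applies pinv to A directly rather than to its SVD factors, so this is a simplification, not the implementation above:

```python
import torch

def solve_or_lstsq(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # solve_ex reports failure through `info` instead of raising, so a singular
    # system can be detected without exception handling.
    soln, info = torch.linalg.solve_ex(A, b, check_errors=False)
    if int(info) == 0 and bool(torch.isfinite(soln).all()):
        return soln
    # Fallback: minimum-norm least-squares solution via the pseudo-inverse.
    return torch.linalg.pinv(A) @ b

A = torch.tensor([[1.0, 2.0],
                  [2.0, 4.0]])       # singular: determinant is exactly zero
b = torch.tensor([[1.0],
                  [2.0]])
print(solve_or_lstsq(A, b))          # approximately [[0.2], [0.4]]
```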