diff --git a/pyproject.toml b/pyproject.toml index 5a283d0..2b29b3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,12 +9,8 @@ dependencies = [ "keyrings-google-artifactregistry-auth>=1.1.2", "equinox>=0.11.4", "jaxtyping>=0.2.29", - "pytest>=8.2.2", "plum-dispatch>=2.4.1", - "jaxlib[cuda12_pip]>=0.4.29", "ipdb>=0.13.13", - "jax[cuda12]>=0.4.29", - "marimo>=0.6.19", "altair>=5.3.0", "polars>=0.20.31", "pyarrow>=16.1.0", @@ -22,16 +18,32 @@ dependencies = [ "huggingface-hub>=0.24.6", "safetensors>=0.4.5", ] + readme = "README.md" requires-python = ">= 3.8" +[project.optional-dependencies] +cpu = [ + "jax[cpu]>=0.4.34", +] +metal = [ + "jax-metal>=0.1.1", +] +cuda = [ + "jax[cuda12]>=0.4.34", + "jaxlib[cuda12]>=0.4.34", +] + + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.rye] managed = true -dev-dependencies = [] +dev-dependencies = [ + "pytest>=8.3.3", +] [tool.hatch.metadata] allow-direct-references = true diff --git a/requirements-dev.lock b/requirements-dev.lock index 909957d..7722380 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,55 +10,43 @@ # universal: false -e file:. -altair==5.4.0 +altair==5.4.1 # via genspn -anyio==4.4.0 - # via starlette +appnope==0.1.4 + # via ipykernel asttokens==2.4.1 # via stack-data attrs==24.2.0 # via jsonschema # via referencing -beartype==0.18.5 +beartype==0.19.0 # via plum-dispatch cachetools==5.5.0 # via google-auth -certifi==2024.7.4 +certifi==2024.8.30 # via requests -cffi==1.17.0 - # via cryptography -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests -click==8.1.7 - # via marimo - # via uvicorn comm==0.2.2 # via ipykernel -cryptography==43.0.0 - # via secretstorage -debugpy==1.8.5 +debugpy==1.8.7 # via ipykernel decorator==5.1.1 # via ipdb # via ipython -docutils==0.21.2 - # via marimo -equinox==0.11.5 +equinox==0.11.7 # via genspn -executing==2.0.1 +executing==2.1.0 # via stack-data -filelock==3.15.4 +filelock==3.16.1 # via huggingface-hub -fsspec==2024.6.1 +fsspec==2024.9.0 # via huggingface-hub -google-auth==2.34.0 +google-auth==2.35.0 # via keyrings-google-artifactregistry-auth -h11==0.14.0 - # via uvicorn -huggingface-hub==0.24.6 +huggingface-hub==0.25.2 # via genspn -idna==3.7 - # via anyio +idna==3.10 # via requests iniconfig==2.0.0 # via pytest @@ -66,161 +54,107 @@ ipdb==0.13.13 # via genspn ipykernel==6.29.5 # via genspn -ipython==8.26.0 +ipython==8.28.0 # via ipdb # via ipykernel -itsdangerous==2.2.0 - # via marimo jaraco-classes==3.4.0 # via keyring jaraco-context==6.0.1 # via keyring -jaraco-functools==4.0.2 +jaraco-functools==4.1.0 # via keyring -jax==0.4.31 +jax==0.4.34 # via equinox - # via genspn -jax-cuda12-pjrt==0.4.31 - # via jax-cuda12-plugin -jax-cuda12-plugin==0.4.31 - # via jax -jaxlib==0.4.31 - # via genspn +jaxlib==0.4.34 # via jax -jaxtyping==0.2.33 +jaxtyping==0.2.34 # via equinox # via genspn jedi==0.19.1 # via ipython - # via marimo -jeepney==0.8.0 - # via keyring - # via secretstorage jinja2==3.1.4 # via altair jsonschema==4.23.0 # via altair -jsonschema-specifications==2023.12.1 +jsonschema-specifications==2024.10.1 # via jsonschema -jupyter-client==8.6.2 +jupyter-client==8.6.3 # via ipykernel jupyter-core==5.7.2 # via ipykernel # via jupyter-client -keyring==25.3.0 +keyring==25.4.1 # via keyrings-google-artifactregistry-auth keyrings-google-artifactregistry-auth==1.1.2 # via genspn -marimo==0.8.0 - # via genspn -markdown==3.7 - # via marimo - # via pymdown-extensions markdown-it-py==3.0.0 # via rich -markupsafe==2.1.5 +markupsafe==3.0.1 # via jinja2 matplotlib-inline==0.1.7 # via ipykernel # via ipython mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.4.0 +ml-dtypes==0.5.0 # via jax # via jaxlib -more-itertools==10.4.0 +more-itertools==10.5.0 # via jaraco-classes # via jaraco-functools -narwhals==1.5.2 +narwhals==1.9.3 # via altair nest-asyncio==1.6.0 # via ipykernel -numpy==2.1.0 +numpy==2.1.2 # via jax # via jaxlib # via ml-dtypes - # via opt-einsum # via pyarrow # via scipy -nvidia-cublas-cu12==12.6.0.22 - # via jax-cuda12-plugin - # via nvidia-cudnn-cu12 - # via nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.6.37 - # via jax-cuda12-plugin -nvidia-cuda-nvcc-cu12==12.6.20 - # via jax-cuda12-plugin -nvidia-cuda-runtime-cu12==12.6.37 - # via jax-cuda12-plugin -nvidia-cudnn-cu12==9.3.0.75 - # via jax-cuda12-plugin -nvidia-cufft-cu12==11.2.6.28 - # via jax-cuda12-plugin -nvidia-cusolver-cu12==11.6.4.38 - # via jax-cuda12-plugin -nvidia-cusparse-cu12==12.5.2.23 - # via jax-cuda12-plugin - # via nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.22.3 - # via jax-cuda12-plugin -nvidia-nvjitlink-cu12==12.6.20 - # via jax-cuda12-plugin - # via nvidia-cufft-cu12 - # via nvidia-cusolver-cu12 - # via nvidia-cusparse-cu12 -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via jax packaging==24.1 # via altair # via huggingface-hub # via ipykernel - # via marimo # via pytest parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.2.2 +platformdirs==4.3.6 # via jupyter-core pluggy==1.5.0 # via keyrings-google-artifactregistry-auth # via pytest plum-dispatch==2.5.2 # via genspn -polars==1.5.0 +polars==1.9.0 # via genspn -prompt-toolkit==3.0.47 +prompt-toolkit==3.0.48 # via ipython psutil==6.0.0 # via ipykernel - # via marimo ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pyarrow==17.0.0 # via genspn -pyasn1==0.6.0 +pyasn1==0.6.1 # via pyasn1-modules # via rsa -pyasn1-modules==0.4.0 +pyasn1-modules==0.4.1 # via google-auth -pycparser==2.22 - # via cffi pygments==2.18.0 # via ipython - # via marimo # via rich -pymdown-extensions==10.9 - # via marimo -pytest==8.3.2 - # via genspn +pytest==8.3.3 python-dateutil==2.9.0.post0 # via jupyter-client pyyaml==6.0.2 # via huggingface-hub - # via marimo - # via pymdown-extensions pyzmq==26.2.0 # via ipykernel # via jupyter-client @@ -230,33 +164,23 @@ referencing==0.35.1 requests==2.32.3 # via huggingface-hub # via keyrings-google-artifactregistry-auth -rich==13.7.1 +rich==13.9.2 # via plum-dispatch rpds-py==0.20.0 # via jsonschema # via referencing rsa==4.9 # via google-auth -ruff==0.6.2 - # via marimo safetensors==0.4.5 # via genspn scipy==1.14.1 # via jax # via jaxlib -secretstorage==3.3.3 - # via keyring six==1.16.0 # via asttokens # via python-dateutil -sniffio==1.3.1 - # via anyio stack-data==0.6.3 # via ipython -starlette==0.38.2 - # via marimo -tomlkit==0.13.2 - # via marimo tornado==6.4.1 # via ipykernel # via jupyter-client @@ -276,11 +200,7 @@ typing-extensions==4.12.2 # via equinox # via huggingface-hub # via plum-dispatch -urllib3==2.2.2 +urllib3==2.2.3 # via requests -uvicorn==0.30.6 - # via marimo wcwidth==0.2.13 # via prompt-toolkit -websockets==12.0 - # via marimo diff --git a/requirements.lock b/requirements.lock index 909957d..14b3009 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,217 +10,146 @@ # universal: false -e file:. -altair==5.4.0 +altair==5.4.1 # via genspn -anyio==4.4.0 - # via starlette +appnope==0.1.4 + # via ipykernel asttokens==2.4.1 # via stack-data attrs==24.2.0 # via jsonschema # via referencing -beartype==0.18.5 +beartype==0.19.0 # via plum-dispatch cachetools==5.5.0 # via google-auth -certifi==2024.7.4 +certifi==2024.8.30 # via requests -cffi==1.17.0 - # via cryptography -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests -click==8.1.7 - # via marimo - # via uvicorn comm==0.2.2 # via ipykernel -cryptography==43.0.0 - # via secretstorage -debugpy==1.8.5 +debugpy==1.8.7 # via ipykernel decorator==5.1.1 # via ipdb # via ipython -docutils==0.21.2 - # via marimo -equinox==0.11.5 +equinox==0.11.7 # via genspn -executing==2.0.1 +executing==2.1.0 # via stack-data -filelock==3.15.4 +filelock==3.16.1 # via huggingface-hub -fsspec==2024.6.1 +fsspec==2024.9.0 # via huggingface-hub -google-auth==2.34.0 +google-auth==2.35.0 # via keyrings-google-artifactregistry-auth -h11==0.14.0 - # via uvicorn -huggingface-hub==0.24.6 +huggingface-hub==0.25.2 # via genspn -idna==3.7 - # via anyio +idna==3.10 # via requests -iniconfig==2.0.0 - # via pytest ipdb==0.13.13 # via genspn ipykernel==6.29.5 # via genspn -ipython==8.26.0 +ipython==8.28.0 # via ipdb # via ipykernel -itsdangerous==2.2.0 - # via marimo jaraco-classes==3.4.0 # via keyring jaraco-context==6.0.1 # via keyring -jaraco-functools==4.0.2 +jaraco-functools==4.1.0 # via keyring -jax==0.4.31 +jax==0.4.34 # via equinox - # via genspn -jax-cuda12-pjrt==0.4.31 - # via jax-cuda12-plugin -jax-cuda12-plugin==0.4.31 - # via jax -jaxlib==0.4.31 - # via genspn +jaxlib==0.4.34 # via jax -jaxtyping==0.2.33 +jaxtyping==0.2.34 # via equinox # via genspn jedi==0.19.1 # via ipython - # via marimo -jeepney==0.8.0 - # via keyring - # via secretstorage jinja2==3.1.4 # via altair jsonschema==4.23.0 # via altair -jsonschema-specifications==2023.12.1 +jsonschema-specifications==2024.10.1 # via jsonschema -jupyter-client==8.6.2 +jupyter-client==8.6.3 # via ipykernel jupyter-core==5.7.2 # via ipykernel # via jupyter-client -keyring==25.3.0 +keyring==25.4.1 # via keyrings-google-artifactregistry-auth keyrings-google-artifactregistry-auth==1.1.2 # via genspn -marimo==0.8.0 - # via genspn -markdown==3.7 - # via marimo - # via pymdown-extensions markdown-it-py==3.0.0 # via rich -markupsafe==2.1.5 +markupsafe==3.0.1 # via jinja2 matplotlib-inline==0.1.7 # via ipykernel # via ipython mdurl==0.1.2 # via markdown-it-py -ml-dtypes==0.4.0 +ml-dtypes==0.5.0 # via jax # via jaxlib -more-itertools==10.4.0 +more-itertools==10.5.0 # via jaraco-classes # via jaraco-functools -narwhals==1.5.2 +narwhals==1.9.3 # via altair nest-asyncio==1.6.0 # via ipykernel -numpy==2.1.0 +numpy==2.1.2 # via jax # via jaxlib # via ml-dtypes - # via opt-einsum # via pyarrow # via scipy -nvidia-cublas-cu12==12.6.0.22 - # via jax-cuda12-plugin - # via nvidia-cudnn-cu12 - # via nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.6.37 - # via jax-cuda12-plugin -nvidia-cuda-nvcc-cu12==12.6.20 - # via jax-cuda12-plugin -nvidia-cuda-runtime-cu12==12.6.37 - # via jax-cuda12-plugin -nvidia-cudnn-cu12==9.3.0.75 - # via jax-cuda12-plugin -nvidia-cufft-cu12==11.2.6.28 - # via jax-cuda12-plugin -nvidia-cusolver-cu12==11.6.4.38 - # via jax-cuda12-plugin -nvidia-cusparse-cu12==12.5.2.23 - # via jax-cuda12-plugin - # via nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.22.3 - # via jax-cuda12-plugin -nvidia-nvjitlink-cu12==12.6.20 - # via jax-cuda12-plugin - # via nvidia-cufft-cu12 - # via nvidia-cusolver-cu12 - # via nvidia-cusparse-cu12 -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via jax packaging==24.1 # via altair # via huggingface-hub # via ipykernel - # via marimo - # via pytest parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.2.2 +platformdirs==4.3.6 # via jupyter-core pluggy==1.5.0 # via keyrings-google-artifactregistry-auth - # via pytest plum-dispatch==2.5.2 # via genspn -polars==1.5.0 +polars==1.9.0 # via genspn -prompt-toolkit==3.0.47 +prompt-toolkit==3.0.48 # via ipython psutil==6.0.0 # via ipykernel - # via marimo ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pyarrow==17.0.0 # via genspn -pyasn1==0.6.0 +pyasn1==0.6.1 # via pyasn1-modules # via rsa -pyasn1-modules==0.4.0 +pyasn1-modules==0.4.1 # via google-auth -pycparser==2.22 - # via cffi pygments==2.18.0 # via ipython - # via marimo # via rich -pymdown-extensions==10.9 - # via marimo -pytest==8.3.2 - # via genspn python-dateutil==2.9.0.post0 # via jupyter-client pyyaml==6.0.2 # via huggingface-hub - # via marimo - # via pymdown-extensions pyzmq==26.2.0 # via ipykernel # via jupyter-client @@ -230,33 +159,23 @@ referencing==0.35.1 requests==2.32.3 # via huggingface-hub # via keyrings-google-artifactregistry-auth -rich==13.7.1 +rich==13.9.2 # via plum-dispatch rpds-py==0.20.0 # via jsonschema # via referencing rsa==4.9 # via google-auth -ruff==0.6.2 - # via marimo safetensors==0.4.5 # via genspn scipy==1.14.1 # via jax # via jaxlib -secretstorage==3.3.3 - # via keyring six==1.16.0 # via asttokens # via python-dateutil -sniffio==1.3.1 - # via anyio stack-data==0.6.3 # via ipython -starlette==0.38.2 - # via marimo -tomlkit==0.13.2 - # via marimo tornado==6.4.1 # via ipykernel # via jupyter-client @@ -276,11 +195,7 @@ typing-extensions==4.12.2 # via equinox # via huggingface-hub # via plum-dispatch -urllib3==2.2.2 +urllib3==2.2.3 # via requests -uvicorn==0.30.6 - # via marimo wcwidth==0.2.13 # via prompt-toolkit -websockets==12.0 - # via marimo diff --git a/src/genspn/__init__.py b/src/genspn/__init__.py deleted file mode 100644 index 5eb6009..0000000 --- a/src/genspn/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -def hello() -> str: - return "Hello from genspn!" diff --git a/src/genspn/distributions.py b/src/genspn/distributions.py index 17b411a..7faac99 100644 --- a/src/genspn/distributions.py +++ b/src/genspn/distributions.py @@ -191,8 +191,11 @@ def logpdf(dist: Mixed, x: Datapoint) -> Float[Array, ""]: @dispatch def logpdf(dist: GEM, pi: Float[Array, "n"], K: Integer[Array, ""]) -> Float[Array, ""]: - betas = jax.vmap(lambda i: 1 - pi[i] / pi[i-1])(jnp.arange(len(pi))) - betas = betas.at[0].set(pi[0]) + def unfold(carry, pi): + beta = pi / carry + return carry * (1-beta), beta + + _, betas = jax.lax.scan(unfold, 1.0, pi) logprobs = jax.vmap(jax.scipy.stats.beta.logpdf, in_axes=(0, None, 0))(betas, 1-dist.d, dist.alpha + (1 + jnp.arange(len(pi))) * dist.d) idx = jnp.arange(logprobs.shape[0]) logprobs = jnp.where(idx < K, logprobs, 0) diff --git a/test/test_score.py b/test/test_score.py index 1b41f1c..6306598 100644 --- a/test/test_score.py +++ b/test/test_score.py @@ -24,15 +24,12 @@ def test_score_stick_breaking(): logp0 = logpdf(gem, pis0, K) logp1 = logpdf(gem, pis1, K) - assert logp0 == ( - jax.scipy.stats.beta.logpdf(pis0[0], 1 - d, alpha + d) + - jax.scipy.stats.beta.logpdf(1 - pis0[1]/pis0[0], 1 - d, alpha + 2 * d) - ) - - assert logp1 == ( - jax.scipy.stats.beta.logpdf(pis1[0], 1 - d, alpha + d) + - jax.scipy.stats.beta.logpdf(1 - pis1[1]/pis1[0], 1 - d, alpha + 2 * d) - ) + logp0_expected = jax.scipy.stats.beta.logpdf(pis0[0], 1 - d, alpha + d) + jax.scipy.stats.beta.logpdf(pis0[1]/(1-pis0[0]), 1 - d, alpha + 2 * d) + assert logp0 == logp0_expected + + logp1_expected = jax.scipy.stats.beta.logpdf(pis1[0], 1 - d, alpha + d) + jax.scipy.stats.beta.logpdf(pis1[1]/(1-pis1[0]), 1 - d, alpha + 2 * d) + + assert logp1 == logp1_expected def test_score_data(): x = jnp.zeros((5, 6))