diff --git a/.gitignore b/.gitignore index 68bc17f..486fe26 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ celerybeat.pid # Environments .env +.envrc .venv env/ venv/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3491c8a..62accfe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,32 +4,37 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/psf/black - rev: 23.1.0 + - id: trailing-whitespace + - id: check-added-large-files + args: ['--maxkb=1500'] + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-merge-conflict + - id: check-toml + - id: detect-private-key + - id: end-of-file-fixer + - id: fix-encoding-pragma + - id: name-tests-test + args: ['--pytest-test-first'] + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.9 hooks: - - id: black - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.6.1 - hooks: - - id: nbqa-black - - id: nbqa-isort - args: ["--float-to-top"] + # Run the linter. + - id: ruff + types_or: [ python, pyi, jupyter ] + # Run the formatter. + - id: ruff-format + types_or: [ python, pyi, jupyter ] - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: - id: nbstripout - repo: local hooks: - - id: unittest - name: unittest - entry: python3 -m pytest -vx -m "not slow" + - id: pytest + name: pytest + entry: python3 -m pytest -x -m "not slow" pass_filenames: false language: system types: [python] diff --git a/Cargo.lock b/Cargo.lock index 32241ac..262e556 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,272 +2,1573 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "argminmax" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "202108b46429b765ef483f8a24d5c46f48c14acfdacc086dd4ab6dddf6bcdbd2" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + +[[package]] +name = "arrow-format" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07884ea216994cdc32a2d5f8274a8bee979cfe90274b83f86f440866ee3132c7" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + +[[package]] +name = "bumpalo" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" + +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "jobserver", + "libc", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "windows-targets 0.48.5", +] + +[[package]] +name = "comfy-table" +version = "7.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" +dependencies = [ + "crossterm", + "strum", + "strum_macros", + "unicode-width", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crossbeam-channel" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.4.1", + "crossterm_winapi", + "libc", + "parking_lot", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "dyn-clone" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "enum_dispatch" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f33313078bb8d4d05a2733a94ac4c2d8a0df9a2b84424ebf4f33bfc224a890e" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + +[[package]] +name = "getrandom" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", + "rayon", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.59" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "indexmap" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "jobserver" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "lz4" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e9e2dd86df36ce760a60f6ff6ad526f7ba1f14ba0356f8254fb6905e6494df1" +dependencies = [ + "libc", + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "memchr" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" + +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "multiversion" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2c7b9d7fe61760ce5ea19532ead98541f6b4c495d87247aff9826445cf6872a" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26a83d8500ed06d68877e9de1dde76c1dbb83885dcdbda4ef44ccbc3fbda2ac8" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "target-features", +] + +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.48.5", +] + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pkg-config" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" + +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + +[[package]] +name = "polars" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938048fcda6a8e2ace6eb168bee1b415a92423ce51e418b853bf08fc40349b6b" +dependencies = [ + "getrandom", + "polars-core", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-sql", + "polars-time", + "version_check", +] + +[[package]] +name = "polars-arrow" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce68a02f698ff7787c261aea1b4c040a8fe183a8fb200e2436d7f35d95a1b86f" +dependencies = [ + "ahash", + "arrow-format", + "atoi_simd", + "bytemuck", + "chrono", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "foreign_vec", + "getrandom", + "hashbrown", + "itoa", + "lz4", + "multiversion", + "num-traits", + "polars-error", + "polars-utils", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "version_check", + "zstd", +] + +[[package]] +name = "polars-compute" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14fbc5f141b29b656a4cec4802632e5bff10bf801c6809c6bbfbd4078a044dd" +dependencies = [ + "bytemuck", + "num-traits", + "polars-arrow", + "polars-utils", + "version_check", +] + +[[package]] +name = "polars-core" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0f5efe734b6cbe5f97ea769be8360df5324fade396f1f3f5ad7fe9360ca4a23" +dependencies = [ + "ahash", + "bitflags 2.4.1", + "bytemuck", + "chrono", + "comfy-table", + "either", + "hashbrown", + "indexmap", + "num-traits", + "once_cell", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-utils", + "rand", + "rand_distr", + "rayon", + "regex", + "smartstring", + "thiserror", + "version_check", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6396de788f99ebfc9968e7b6f523e23000506cde4ba6dfc62ae4ce949002a886" +dependencies = [ + "arrow-format", + "regex", + "simdutf8", + "thiserror", +] + +[[package]] +name = "polars-io" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d0458efe8946f4718fd352f230c0db5a37926bd0d2bd25af79dc24746abaaea" +dependencies = [ + "ahash", + "atoi_simd", + "bytes", + "chrono", + "fast-float", + "home", + "itoa", + "memchr", + "memmap2", + "num-traits", + "once_cell", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-error", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", +] + +[[package]] +name = "polars-lazy" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d7105b40905bb38e8fc4a7fd736594b7491baa12fad3ac492969ca221a1b5d5" +dependencies = [ + "ahash", + "bitflags 2.4.1", + "glob", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-ops" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e09afc456ab11e75e5dcb43e00a01c71f3a46a2781e450054acb6bb096ca78e" +dependencies = [ + "ahash", + "argminmax", + "bytemuck", + "either", + "hashbrown", + "indexmap", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-pipe" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b7ead073cc3917027d77b59861a9f071db47125de9314f8907db1a0a3e4100" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-plan" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384a175624d050c31c473ee11df9d7af5d729ae626375e522158cfb3d150acd0" +dependencies = [ + "ahash", + "bytemuck", + "once_cell", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-time", + "polars-utils", + "rayon", + "regex", + "smartstring", + "strum_macros", + "version_check", +] + +[[package]] +name = "polars-row" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32322f7acbb83db3e9c7697dc821be73d06238da89c817dcc8bc1549a5e9c72f" +dependencies = [ + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f0b4c6ddffdfd0453e84bc3918572c633014d661d166654399cf93752aa95b5" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-plan", + "rand", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dee2649fc96bd1b6584e0e4a4b3ca7d22ed3d117a990e63ad438ecb26f7544d0" +dependencies = [ + "atoi", + "chrono", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b174ca4a77ad47d7b91a0460aaae65bbf874c8bfbaaa5308675dadef3976bbda" +dependencies = [ + "ahash", + "bytemuck", + "hashbrown", + "indexmap", + "num-traits", + "once_cell", + "polars-error", + "rayon", + "smartstring", + "sysinfo", + "version_check", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "pyo3-polars" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e983cb07cf665ea6e645ae9263c358062580f23a9aee41618a5706d4a7cc21" +dependencies = [ + "polars", + "polars-core", + "pyo3", + "thiserror", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] [[package]] -name = "bitflags" -version = "1.3.2" +name = "rand_distr" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] [[package]] -name = "cfg-if" -version = "1.0.0" +name = "random-tree-models" +version = "0.6.2" +dependencies = [ + "polars", + "pyo3", + "pyo3-polars", + "rand", + "rand_chacha", + "uuid", +] + +[[package]] +name = "rayon" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] [[package]] -name = "indoc" -version = "1.0.9" +name = "rayon-core" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] [[package]] -name = "libc" -version = "0.2.146" +name = "redox_syscall" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] [[package]] -name = "lock_api" -version = "0.4.10" +name = "regex" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ - "autocfg", - "scopeguard", + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", ] [[package]] -name = "memoffset" -version = "0.8.0" +name = "regex-automata" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ - "autocfg", + "aho-corasick", + "memchr", + "regex-syntax", ] [[package]] -name = "once_cell" -version = "1.18.0" +name = "regex-syntax" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] -name = "parking_lot" -version = "0.12.1" +name = "rustversion" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.195" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" dependencies = [ - "lock_api", - "parking_lot_core", + "serde_derive", ] [[package]] -name = "parking_lot_core" -version = "0.9.8" +name = "serde_derive" +version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", + "proc-macro2", + "quote", + "syn 2.0.48", ] [[package]] -name = "proc-macro2" -version = "1.0.60" +name = "serde_json" +version = "1.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" +checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" dependencies = [ - "unicode-ident", + "itoa", + "ryu", + "serde", ] [[package]] -name = "pyo3" -version = "0.18.3" +name = "simdutf8" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + +[[package]] +name = "smallvec" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", + "autocfg", + "static_assertions", + "version_check", ] [[package]] -name = "pyo3-build-config" -version = "0.18.3" +name = "sqlparser" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3" +checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" dependencies = [ - "once_cell", - "target-lexicon", + "log", ] [[package]] -name = "pyo3-ffi" -version = "0.18.3" +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ - "libc", - "pyo3-build-config", + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.48", ] [[package]] -name = "pyo3-macros" -version = "0.18.3" +name = "syn" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "pyo3-macros-backend", "quote", - "syn", + "unicode-ident", ] [[package]] -name = "pyo3-macros-backend" -version = "0.18.3" +name = "syn" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", - "syn", + "unicode-ident", ] [[package]] -name = "quote" -version = "1.0.28" +name = "sysinfo" +version = "0.30.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "windows", +] + +[[package]] +name = "target-features" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" + +[[package]] +name = "target-lexicon" +version = "0.12.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" + +[[package]] +name = "thiserror" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", + "quote", + "syn 2.0.48", ] [[package]] -name = "random-tree-models" -version = "0.3.1" +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "uuid" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ - "pyo3", + "getrandom", ] [[package]] -name = "redox_syscall" -version = "0.3.5" +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ - "bitflags", + "cfg-if", + "wasm-bindgen-macro", ] [[package]] -name = "scopeguard" -version = "1.1.0" +name = "wasm-bindgen-backend" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.48", + "wasm-bindgen-shared", +] [[package]] -name = "smallvec" -version = "1.10.0" +name = "wasm-bindgen-macro" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] [[package]] -name = "syn" -version = "1.0.109" +name = "wasm-bindgen-macro-support" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "unicode-ident", + "syn 2.0.48", + "wasm-bindgen-backend", + "wasm-bindgen-shared", ] [[package]] -name = "target-lexicon" -version = "0.12.7" +name = "wasm-bindgen-shared" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] -name = "unicode-ident" -version = "1.0.9" +name = "winapi" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] [[package]] -name = "unindent" -version = "0.1.11" +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + +[[package]] +name = "xxhash-rust" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.9+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 90d0601..b719ae4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,12 @@ name = "random_tree_models" crate-type = ["cdylib"] [dependencies] -pyo3 = "0.18.3" +polars = { version = "0.36.2", features = ["lazy", "dtype-struct"] } +pyo3 = "0.20.0" +pyo3-polars = "0.10.0" +rand = "0.8.5" +rand_chacha = "0.3.1" +uuid = { version = "1.6.1", features = ["v4"] } [features] extension-module = ["pyo3/extension-module"] diff --git a/Makefile b/Makefile index e16e215..7a3396e 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ help: @echo "compile : update the environment requirements after changes to dependencies in pyproject.toml." @echo "update : pip install new requriements into the virtual environment." @echo "test : run pytests." + @echo "format : format rust code." # create a virtual environment .PHONY: venv @@ -59,3 +60,11 @@ update: test: source .venv/bin/activate && \ pytest -vx . + +# ============================================================================== +# format code +# ============================================================================== + +.PHONY: format +format: + cargo fmt diff --git a/config/requirements.txt b/config/requirements.txt index ffb0ac6..576fb2b 100644 --- a/config/requirements.txt +++ b/config/requirements.txt @@ -234,6 +234,7 @@ numpy==1.24.3 # contourpy # matplotlib # pandas + # pyarrow # scikit-learn # scipy # seaborn @@ -272,6 +273,8 @@ platformdirs==3.5.1 # virtualenv pluggy==1.0.0 # via pytest +polars==0.20.0 + # via random-tree-models (pyproject.toml) pre-commit==3.3.2 # via random-tree-models (pyproject.toml) prometheus-client==0.17.0 @@ -289,6 +292,8 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data +pyarrow==14.0.2 + # via random-tree-models (pyproject.toml) pycparser==2.21 # via cffi pydantic==1.10.8 diff --git a/nbs/decision-tree-rs.ipynb b/nbs/decision-tree-rs.ipynb new file mode 100644 index 0000000..d22721c --- /dev/null +++ b/nbs/decision-tree-rs.ipynb @@ -0,0 +1,418 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision tree" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "* https://medium.com/@penggongting/implementing-decision-tree-from-scratch-in-python-c732e7c69aea\n", + "* https://www.kdnuggets.com/2020/01/decision-tree-algorithm-explained.html" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The core algorithm aka the CART algorithm\n", + "\n", + "CART = Classification And Regression Tree\n", + "\n", + "Starting with a tabular dataset we have columns / features and rows / observations. Each row has a target value, of which either all are continuous or categorical. \n", + "\n", + "Taking a subset of the observations as a training set, the algorithm iterates:\n", + "\n", + "1. select a feature\n", + "2. select a range of thresholds (e.g. the feature values in the taining set) \n", + "3. for each threshold\n", + " * create two groups of observations, one below the threshold and one above and \n", + " * evaluate the split score\n", + "4. select the threshold with the optimal split score (here that always means largest)\n", + "5. select the related group split \n", + "6. continue from 1. for each group whose target values are not yet homogeneous (e.g. not all the same class, or the standard variation is greater than zero)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "import sklearn.datasets as sk_datasets\n", + "\n", + "# import random_tree_models.decisiontree as dtree\n", + "import random_tree_models._rust as rust\n", + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.RandomState(42)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Classification\n", + "\n", + "split score:\n", + "* gini\n", + "* entropy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X, y = sk_datasets.make_classification(\n", + " n_samples=1_000,\n", + " n_features=2,\n", + " n_classes=2,\n", + " n_redundant=0,\n", + " class_sep=2,\n", + " random_state=rng,\n", + ")\n", + "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, alpha=0.3);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = rust.DecisionTreeClassifier(max_depth=4) # measure_name=\"gini\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X = pl.from_numpy(X)\n", + "y = pl.from_numpy(y).to_series()\n", + "display(X.head(2), y.head(2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def fit(X: pl.DataFrame, y: pl.Series):\n", + " model = rust.DecisionTreeClassifier(max_depth=4) # measure_name=\"gini\"\n", + " model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%timeit fit(X,y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# dtree.show_tree(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def predict(X: pl.DataFrame):\n", + " model.predict(X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%timeit predict(X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_prob = model.predict_proba(X)\n", + "y_prob.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x0 = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)\n", + "x1 = np.linspace(X[:, 1].min(), X[:, 1].max(), 100)\n", + "X0, X1 = np.meshgrid(x0, x1)\n", + "X_plot = np.array([X0.ravel(), X1.ravel()]).T\n", + "X_plot = pl.from_numpy(X_plot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_prob = model.predict_proba(X_plot)\n", + "y_prob = y_prob.select(\"class_1\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# polars series to numpy array\n", + "y_prob = y_prob.to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_prob" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO - CONTINUE HERE: why is column_1 not used for prediction?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "im = ax.pcolormesh(X0, X1, y_prob.reshape(X0.shape), alpha=0.2)\n", + "fig.colorbar(im, ax=ax)\n", + "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=ax, alpha=0.3)\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Regression\n", + "\n", + "split score:\n", + "\n", + "* variance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X, y, coefs = sk_datasets.make_regression(\n", + " n_samples=1_000, n_features=2, n_targets=1, coef=True, random_state=rng\n", + ")\n", + "sns.scatterplot(x=X[:, 0], y=y, alpha=0.3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = rust.DecisionTreeRegressor(max_depth=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X = pl.from_numpy(X)\n", + "y = pl.from_numpy(y).to_series()\n", + "display(X.head(2), y.head(2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# dtree.show_tree(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x0 = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)\n", + "x1 = np.linspace(X[:, 1].min(), X[:, 1].max(), 100)\n", + "X0, X1 = np.meshgrid(x0, x1)\n", + "X_plot = np.array([X0.ravel(), X1.ravel()]).T\n", + "X_plot = pl.from_numpy(X_plot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = model.predict(X_plot)\n", + "y_pred[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(nrows=2, figsize=(12, 6))\n", + "\n", + "ax = axs[0]\n", + "sns.scatterplot(x=X_plot[:, 0], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "\n", + "ax = axs[1]\n", + "sns.scatterplot(x=X_plot[:, 1], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "im = ax.pcolormesh(X0, X1, y_pred.to_numpy().reshape(X0.shape), alpha=0.2)\n", + "fig.colorbar(im, ax=ax)\n", + "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=ax, alpha=0.3)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = model.predict(X).to_numpy()\n", + "\n", + "fig, axs = plt.subplots(nrows=2, figsize=(12, 6))\n", + "\n", + "ax = axs[0]\n", + "sns.scatterplot(x=X[:, 0], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "sns.scatterplot(x=X[:, 0], y=y, ax=ax, alpha=0.1, label=\"actual\")\n", + "\n", + "ax = axs[1]\n", + "sns.scatterplot(x=X[:, 1], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "sns.scatterplot(x=X[:, 1], y=y, ax=ax, alpha=0.1, label=\"actual\")\n", + "\n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index e6ade6a..1218fe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,9 @@ dependencies = [ "pandas", "pydantic", "snakeviz", - "maturin" + "maturin", + "polars", + "pyarrow", ] [tool.maturin] @@ -38,14 +40,65 @@ bindings = "pyo3" features = ["pyo3/extension-module"] module-name = "random_tree_models._rust" -[tool.black] +# [tool.black] +# line-length = 80 + +# [tool.isort] +# multi_line_output = 3 +# line_length = 80 +# include_trailing_comma = true +# profile = "black" + +[tool.ruff] +# https://docs.astral.sh/ruff/configuration/ line-length = 80 +indent-width = 4 +target-version = "py310" +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "src", +] +[tool.ruff.lint] +fixable = ["ALL"] +unfixable = [] + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false -[tool.isort] -multi_line_output = 3 -line_length = 80 -include_trailing_comma = true -profile = "black" +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" # [tool.setuptools.packages.find] # where = ["."] # list of folders that contain the packages (["."] by default) diff --git a/random_tree_models/decisiontree.py b/random_tree_models/decisiontree.py index 8b0ea1d..3fa90e1 100644 --- a/random_tree_models/decisiontree.py +++ b/random_tree_models/decisiontree.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import typing as T import uuid @@ -21,6 +22,7 @@ import random_tree_models.leafweights as leafweights import random_tree_models.scoring as scoring import random_tree_models.utils as utils +import abc logger = utils.logger @@ -199,12 +201,11 @@ def get_column( def find_best_split( X: np.ndarray, y: np.ndarray, - measure_name: str, - yhat: np.ndarray = None, + growth_params: utils.TreeGrowthParameters, g: np.ndarray = None, h: np.ndarray = None, - growth_params: utils.TreeGrowthParameters = None, rng: np.random.RandomState = np.random.RandomState(42), + incrementing_score: scoring.IncrementingScore = None, ) -> BestSplit: """Find the best split, detecting the "default direction" with missing data.""" @@ -225,13 +226,13 @@ def find_best_split( ) in get_thresholds_and_target_groups( feature_values, growth_params.threshold_params, rng ): - split_score = scoring.SplitScoreMetrics[measure_name]( + split_score = scoring.calc_score( y, target_groups, - yhat=yhat, + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, + incrementing_score=incrementing_score, ) if best is None or split_score > best.score: @@ -271,23 +272,21 @@ def check_if_split_sensible( def calc_leaf_weight_and_split_score( y: np.ndarray, - measure_name: str, growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, + incrementing_score: scoring.IncrementingScore = None, ) -> T.Tuple[float]: - leaf_weight = leafweights.calc_leaf_weight( - y, measure_name, growth_params, g=g, h=h - ) + leaf_weight = leafweights.calc_leaf_weight(y, growth_params, g=g, h=h) - yhat = leaf_weight * np.ones_like(y) - score = scoring.SplitScoreMetrics[measure_name]( - y, - np.ones_like(y, dtype=bool), - yhat=yhat, + # yhat = leaf_weight * np.ones_like(y) + score = scoring.calc_score( + y=y, + target_groups=np.ones_like(y, dtype=bool), + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, + incrementing_score=incrementing_score, ) return leaf_weight, score @@ -312,24 +311,22 @@ def select_arrays_for_child_node( def grow_tree( X: np.ndarray, y: np.ndarray, - measure_name: str, + growth_params: utils.TreeGrowthParameters, parent_node: Node = None, depth: int = 0, - growth_params: utils.TreeGrowthParameters = None, g: np.ndarray = None, h: np.ndarray = None, random_state: int = 42, - **kwargs, + incrementing_score: scoring.IncrementingScore = None, ) -> Node: """Implementation of the Classification And Regression Tree (CART) algorithm Args: X (np.ndarray): Input feature values to do thresholding on. y (np.ndarray): Target values. - measure_name (str): Values indicating which functions in scoring.SplitScoreMetrics and leafweights.LeafWeightSchemes to call. + growth_params (utils.TreeGrowthParameters, optional): Parameters controlling tree growth. parent_node (Node, optional): Parent node in tree. Defaults to None. depth (int, optional): Current tree depth. Defaults to 0. - growth_params (utils.TreeGrowthParameters, optional): Parameters controlling tree growth. Defaults to None. g (np.ndarray, optional): Boosting and loss specific precomputed 1st order derivative dloss/dyhat. Defaults to None. h (np.ndarray, optional): Boosting and loss specific precomputed 2nd order derivative d^2loss/dyhat^2. Defaults to None. @@ -340,7 +337,7 @@ def grow_tree( Node: Tree node with leaf weight, node score and potential child nodes. Note: - Currently measure_name controls how the split score and the leaf weights are computed. + Currently growth_params.split_score_name controls how the split score and the leaf weights are computed. But only the decision tree algorithm directly uses y for that and can predict y using the leaf weight values directly. @@ -361,9 +358,15 @@ def grow_tree( # compute leaf weight (for prediction) and node score (for split gain check) leaf_weight, score = calc_leaf_weight_and_split_score( - y, measure_name, growth_params, g, h + y, + growth_params, + g, + h, + incrementing_score=incrementing_score, ) + measure_name = growth_params.split_score_metric.name + if is_baselevel: # end of the line buddy return Node( prediction=leaf_weight, @@ -377,7 +380,13 @@ def grow_tree( rng = np.random.RandomState(random_state) best = find_best_split( - X, y, measure_name, g=g, h=h, growth_params=growth_params, rng=rng + X, + y, + g=g, + h=h, + growth_params=growth_params, + rng=rng, + incrementing_score=incrementing_score, ) # check if improvement due to split is below minimum requirement @@ -414,13 +423,13 @@ def grow_tree( new_node.left = grow_tree( _X, _y, - measure_name=measure_name, + growth_params=growth_params, parent_node=new_node, depth=depth + 1, - growth_params=growth_params, g=_g, h=_h, random_state=random_state_left, + incrementing_score=incrementing_score, ) # descend right @@ -428,13 +437,13 @@ def grow_tree( new_node.right = grow_tree( _X, _y, - measure_name=measure_name, + growth_params=growth_params, parent_node=new_node, depth=depth + 1, - growth_params=growth_params, g=_g, h=_h, random_state=random_state_right, + incrementing_score=incrementing_score, ) return new_node @@ -480,7 +489,7 @@ def predict_with_tree(tree: Node, X: np.ndarray) -> np.ndarray: return predictions -class DecisionTreeTemplate(base.BaseEstimator): +class DecisionTreeTemplate(abc.ABC, base.BaseEstimator): """Template for DecisionTree classes Based on: https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator @@ -514,24 +523,36 @@ def __init__( self.column_method = column_method self.n_columns_to_try = n_columns_to_try + def _sanity_check_measure_name(self) -> utils.SplitScoreMetrics: + try: + return utils.SplitScoreMetrics[self.measure_name] + except KeyError as ex: + raise KeyError( + f"Unknown measure_name: {self.measure_name}. " + f"Valid options: {', '.join(list(utils.SplitScoreMetrics.__members__.keys()))}. {ex=}" + ) + def _organize_growth_parameters(self): + threshold_params = utils.ThresholdSelectionParameters( + method=self.threshold_method, + quantile=self.threshold_quantile, + n_thresholds=self.n_thresholds, + random_state=int(self.random_state), + ) + column_params = utils.ColumnSelectionParameters( + method=self.column_method, + n_trials=self.n_columns_to_try, + ) self.growth_params_ = utils.TreeGrowthParameters( max_depth=self.max_depth, + split_score_metric=self._sanity_check_measure_name(), min_improvement=self.min_improvement, lam=-abs(self.lam), frac_subsamples=float(self.frac_subsamples), frac_features=float(self.frac_features), random_state=int(self.random_state), - threshold_params=utils.ThresholdSelectionParameters( - method=self.threshold_method, - quantile=self.threshold_quantile, - n_thresholds=self.n_thresholds, - random_state=int(self.random_state), - ), - column_params=utils.ColumnSelectionParameters( - method=self.column_method, - n_trials=self.n_columns_to_try, - ), + threshold_params=threshold_params, + column_params=column_params, ) def _select_samples_and_features( @@ -539,7 +560,7 @@ def _select_samples_and_features( ) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: "Sub-samples rows and columns from X and y" if not hasattr(self, "growth_params_"): - raise ValueError(f"Try calling `fit` first.") + raise ValueError("Try calling `fit` first.") ix = np.arange(len(X)) rng = np.random.RandomState(self.growth_params_.random_state) @@ -570,15 +591,17 @@ def _select_features( ) -> np.ndarray: return X[:, ix_features] + @abc.abstractmethod def fit( self, X: T.Union[pd.DataFrame, np.ndarray], y: T.Union[pd.Series, np.ndarray], ) -> "DecisionTreeTemplate": - raise NotImplementedError() + ... + @abc.abstractmethod def predict(self, X: T.Union[pd.DataFrame, np.ndarray]) -> np.ndarray: - raise NotImplementedError() + ... class DecisionTreeRegressor(base.RegressorMixin, DecisionTreeTemplate): @@ -611,7 +634,6 @@ def fit( self.tree_ = grow_tree( _X, _y, - measure_name=self.measure_name, growth_params=self.growth_params_, random_state=self.random_state, **kwargs, @@ -675,7 +697,6 @@ def fit( self.tree_ = grow_tree( _X, _y, - measure_name=self.measure_name, growth_params=self.growth_params_, random_state=self.random_state, ) diff --git a/random_tree_models/isolationforest.py b/random_tree_models/isolationforest.py index 577aeac..d1c7276 100644 --- a/random_tree_models/isolationforest.py +++ b/random_tree_models/isolationforest.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import typing as T import numpy as np @@ -7,6 +8,7 @@ from sklearn.utils.validation import check_array, check_is_fitted import random_tree_models.decisiontree as dtree +import random_tree_models.scoring as scoring # TODO: add tests @@ -60,13 +62,14 @@ def fit( _X, _y, self.ix_features_ = self._select_samples_and_features( X, dummy_y ) + self.incrementing_score_ = scoring.IncrementingScore() self.tree_ = dtree.grow_tree( _X, _y, - measure_name=self.measure_name, growth_params=self.growth_params_, random_state=self.random_state, + incrementing_score=self.incrementing_score_, **kwargs, ) diff --git a/random_tree_models/leafweights.py b/random_tree_models/leafweights.py index 351a3b0..62ac87b 100644 --- a/random_tree_models/leafweights.py +++ b/random_tree_models/leafweights.py @@ -1,18 +1,15 @@ -from enum import Enum -from functools import partial - +# -*- coding: utf-8 -*- import numpy as np import random_tree_models.utils as utils -def leaf_weight_mean(y: np.ndarray, **kwargs) -> float: +def leaf_weight_mean(y: np.ndarray) -> float: return np.mean(y) def leaf_weight_binary_classification_friedman2001( g: np.ndarray, - **kwargs, ) -> float: "Computes optimal leaf weight as in Friedman et al. 2001 Algorithm 5" @@ -25,7 +22,6 @@ def leaf_weight_xgboost( growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, - **kwargs, ) -> float: "Computes optimal leaf weight as in Chen et al. 2016 equation 5" @@ -33,32 +29,8 @@ def leaf_weight_xgboost( return w -class LeafWeightSchemes(Enum): - # https://stackoverflow.com/questions/40338652/how-to-define-enum-values-that-are-functions - friedman_binary_classification = partial( - leaf_weight_binary_classification_friedman2001 - ) - variance = partial(leaf_weight_mean) - entropy = partial(leaf_weight_mean) - entropy_rs = partial(leaf_weight_mean) - gini = partial(leaf_weight_mean) - gini_rs = partial(leaf_weight_mean) - xgboost = partial(leaf_weight_xgboost) - incrementing = partial(leaf_weight_mean) - - def __call__( - self, - y: np.ndarray, - growth_params: utils.TreeGrowthParameters, - g: np.ndarray = None, - h: np.ndarray = None, - ) -> float: - return self.value(y=y, growth_params=growth_params, g=g, h=h) - - def calc_leaf_weight( y: np.ndarray, - measure_name: str, growth_params: utils.TreeGrowthParameters, g: np.ndarray = None, h: np.ndarray = None, @@ -71,7 +43,25 @@ def calc_leaf_weight( if len(y) == 0: return None - weight_func = LeafWeightSchemes[measure_name] - leaf_weight = weight_func(y=y, growth_params=growth_params, g=g, h=h) + measure_name = growth_params.split_score_metric + + match measure_name: + case ( + utils.SplitScoreMetrics.variance + | utils.SplitScoreMetrics.entropy + | utils.SplitScoreMetrics.entropy_rs + | utils.SplitScoreMetrics.gini + | utils.SplitScoreMetrics.gini_rs + | utils.SplitScoreMetrics.incrementing + ): + leaf_weight = leaf_weight_mean(y) + case utils.SplitScoreMetrics.friedman_binary_classification: + leaf_weight = leaf_weight_binary_classification_friedman2001(g) + case utils.SplitScoreMetrics.xgboost: + leaf_weight = leaf_weight_xgboost(growth_params, g, h) + case _: + raise KeyError( + f"Unknown measure_name: {measure_name}, expected one of {', '.join(list(utils.SplitScoreMetrics.__members__.keys()))}" + ) return leaf_weight diff --git a/random_tree_models/scoring.py b/random_tree_models/scoring.py index 4caabec..d9f59ab 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -1,6 +1,4 @@ -from enum import Enum -from functools import partial - +# -*- coding: utf-8 -*- import numpy as np import random_tree_models.utils as utils @@ -16,7 +14,7 @@ def check_y_and_target_groups(y: np.ndarray, target_groups: np.ndarray = None): raise ValueError(f"{y.shape=} != {target_groups.shape=}") -def calc_variance(y: np.ndarray, target_groups: np.ndarray, **kwargs) -> float: +def calc_variance(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the variance of a split""" check_y_and_target_groups(y, target_groups=target_groups) @@ -58,7 +56,7 @@ def entropy(y: np.ndarray) -> float: return h -def calc_entropy(y: np.ndarray, target_groups: np.ndarray, **kwargs) -> float: +def calc_entropy(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the entropy of a split""" check_y_and_target_groups(y, target_groups=target_groups) @@ -73,9 +71,7 @@ def calc_entropy(y: np.ndarray, target_groups: np.ndarray, **kwargs) -> float: return h -def calc_entropy_rs( - y: np.ndarray, target_groups: np.ndarray, **kwargs -) -> float: +def calc_entropy_rs(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the entropy of a split""" check_y_and_target_groups(y, target_groups=target_groups) @@ -113,9 +109,7 @@ def gini_impurity(y: np.ndarray) -> float: return -g -def calc_gini_impurity( - y: np.ndarray, target_groups: np.ndarray, **kwargs -) -> float: +def calc_gini_impurity(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the gini impurity of a split Based on: https://scikit-learn.org/stable/modules/tree.html#classification-criteria @@ -133,9 +127,7 @@ def calc_gini_impurity( return g -def calc_gini_impurity_rs( - y: np.ndarray, target_groups: np.ndarray, **kwargs -) -> float: +def calc_gini_impurity_rs(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the gini impurity of a split Based on: https://scikit-learn.org/stable/modules/tree.html#classification-criteria @@ -172,12 +164,10 @@ def xgboost_split_score( def calc_xgboost_split_score( - y: np.ndarray, target_groups: np.ndarray, g: np.ndarray, h: np.ndarray, growth_params: utils.TreeGrowthParameters, - **kwargs, ) -> float: """Calculates the xgboost general version score of a split with loss specifics in g and h. @@ -206,36 +196,44 @@ def calc_xgboost_split_score( class IncrementingScore: - score = 0 + score: int = 0 - def __call__(self, *args, **kwargs) -> float: + def update(self) -> float: """Calculates the random cut score of a split""" self.score += 1 return self.score -class SplitScoreMetrics(Enum): - # https://stackoverflow.com/questions/40338652/how-to-define-enum-values-that-are-functions - variance = partial(calc_variance) - entropy = partial(calc_entropy) - entropy_rs = partial(calc_entropy_rs) - gini = partial(calc_gini_impurity) - gini_rs = partial(calc_gini_impurity_rs) - # variance for split score because Friedman et al. 2001 in Algorithm 1 - # step 4 minimize the squared error between actual and predicted dloss/dyhat - friedman_binary_classification = partial(calc_variance) - xgboost = partial(calc_xgboost_split_score) - incrementing = partial(IncrementingScore()) - - def __call__( - self, - y: np.ndarray, - target_groups: np.ndarray, - yhat: np.ndarray = None, - g: np.ndarray = None, - h: np.ndarray = None, - growth_params: utils.TreeGrowthParameters = None, - ) -> float: - return self.value( - y, target_groups, yhat=yhat, g=g, h=h, growth_params=growth_params - ) +def calc_score( + y: np.ndarray, + target_groups: np.ndarray, + growth_params: utils.TreeGrowthParameters, + g: np.ndarray = None, + h: np.ndarray = None, + incrementing_score: IncrementingScore = None, +) -> float: + measure_name = growth_params.split_score_metric + + match measure_name: + case utils.SplitScoreMetrics.variance: + return calc_variance(y, target_groups) + case utils.SplitScoreMetrics.entropy: + return calc_entropy(y, target_groups) + case utils.SplitScoreMetrics.entropy_rs: + return calc_entropy_rs(y, target_groups) + case utils.SplitScoreMetrics.gini: + return calc_gini_impurity(y, target_groups) + case utils.SplitScoreMetrics.gini_rs: + return calc_gini_impurity_rs(y, target_groups) + case utils.SplitScoreMetrics.friedman_binary_classification: + return calc_variance(y, target_groups) + case utils.SplitScoreMetrics.xgboost: + return calc_xgboost_split_score(target_groups, g, h, growth_params) + case utils.SplitScoreMetrics.incrementing: + if incrementing_score is None: + raise ValueError( + f"{incrementing_score=} must be provided as an instance of scoring.IncrementingScore {measure_name=}" + ) + return incrementing_score.update() + case _: + raise ValueError(f"{measure_name=} not supported") diff --git a/random_tree_models/utils.py b/random_tree_models/utils.py index 0da4b7e..4feeccc 100644 --- a/random_tree_models/utils.py +++ b/random_tree_models/utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import logging from enum import Enum @@ -63,10 +64,22 @@ class ColumnSelectionParameters: n_trials: StrictInt = None +class SplitScoreMetrics(Enum): + variance = "variance" + entropy = "entropy" + entropy_rs = "entropy_rs" + gini = "gini" + gini_rs = "gini_rs" + friedman_binary_classification = "friedman_binary_classification" + xgboost = "xgboost" + incrementing = "incrementing" + + @dataclass class TreeGrowthParameters: max_depth: StrictInt min_improvement: StrictFloat = 0.0 + split_score_metric: SplitScoreMetrics = SplitScoreMetrics.variance # xgboost lambda - multiplied with sum of squares of leaf weights # see Chen et al. 2016 equation 2 lam: StrictFloat = 0.0 diff --git a/src/decisiontree.rs b/src/decisiontree.rs new file mode 100644 index 0000000..d9ebeb2 --- /dev/null +++ b/src/decisiontree.rs @@ -0,0 +1,942 @@ +use polars::{lazy::dsl::GetOutput, prelude::*}; +// use rand::SeedableRng; +// use rand_chacha::ChaCha20Rng; +use uuid::Uuid; + +use crate::{ + scoring, + utils::{SplitScoreMetrics, TreeGrowthParameters}, +}; + +#[derive(PartialEq, Debug, Clone)] +pub struct SplitScore { + pub name: String, + pub score: f64, +} + +impl SplitScore { + pub fn new(name: String, score: f64) -> Self { + SplitScore { name, score } + } +} + +#[derive(PartialEq, Debug, Clone)] +pub struct Node { + pub column: Option, + pub column_idx: Option, + pub threshold: Option, + pub prediction: Option, + pub default_is_left: Option, + + // descendants + pub left: Option>, + pub right: Option>, + + // misc + pub measure: Option, + + pub n_obs: usize, + pub reason: String, + pub depth: usize, + pub node_id: Uuid, +} + +impl Node { + pub fn new( + column: Option, + column_idx: Option, + threshold: Option, + prediction: Option, + default_is_left: Option, + left: Option>, + right: Option>, + measure: Option, + n_obs: usize, + reason: String, + depth: usize, + ) -> Self { + let node_id = Uuid::new_v4(); + Node { + column, + column_idx, + threshold, + prediction, + default_is_left, + left, + right, + measure, + n_obs, + reason, + depth, + node_id, + } + } + + pub fn is_leaf(&self) -> bool { + self.left.is_none() && self.right.is_none() + } + + pub fn insert(&mut self, new_node: Node, insert_left: bool) { + if insert_left { + match self.left { + Some(ref mut _left) => { + panic!("Something went wrong. The left node is already occupied.") + } // left.insert(new_node) + None => self.left = Some(Box::new(new_node)), + } + } else { + match self.right { + Some(ref mut _right) => { + panic!("Something went wrong. The right node is already occupied.") + } + //right.insert(new_node), + None => self.right = Some(Box::new(new_node)), + } + } + } +} + +pub fn check_is_baselevel( + y: &Series, + depth: usize, + growth_params: &TreeGrowthParameters, +) -> (bool, String) { + let n_obs = y.len(); + let n_unique = y + .n_unique() + .expect("Something went wrong. Could not get n_unique."); + let max_depth = growth_params.max_depth; + + if max_depth.is_some() && depth >= max_depth.unwrap() { + return (true, "max depth reached".to_string()); + } else if n_unique == 1 { + return (true, "homogenous group".to_string()); + } else if n_obs <= 1 { + return (true, "<= 1 data point in group".to_string()); + } else { + (false, "".to_string()) + } +} + +pub fn calc_leaf_weight( + y: &Series, + growth_params: &TreeGrowthParameters, + _g: Option<&Series>, + _h: Option<&Series>, +) -> f64 { + match growth_params.split_score_metric { + Some(SplitScoreMetrics::NegEntropy) => y.mean().unwrap(), + Some(SplitScoreMetrics::NegVariance) => y.mean().unwrap(), + _ => panic!("Something went wrong. The split_score_metric is not defined."), + } +} + +pub fn calc_leaf_weight_and_split_score( + y: &Series, + growth_params: &TreeGrowthParameters, + g: Option<&Series>, + h: Option<&Series>, + incrementing_score: Option, +) -> (f64, SplitScore) { + let leaf_weight = calc_leaf_weight(y, growth_params, g, h); + + let target_groups: Series = Series::new("target_groups", vec![true; y.len()]); + let score = scoring::calc_score(y, &target_groups, growth_params, g, h, incrementing_score); + let name = match &growth_params.split_score_metric { + Some(metric) => metric.to_string(), + _ => panic!("Something went wrong. The split_score_metric is not defined."), + }; + + let score = SplitScore::new(name, score); + + (leaf_weight, score) +} + +#[derive(PartialEq, Debug, Clone)] +pub struct BestSplit { + pub score: f64, + pub column: String, + pub column_idx: usize, + pub threshold: f64, + pub target_groups: Series, + pub default_is_left: Option, +} + +impl BestSplit { + pub fn new( + score: f64, + column: String, + column_idx: usize, + threshold: f64, + target_groups: Series, + default_is_left: Option, + ) -> Self { + BestSplit { + score, + column, + column_idx, + threshold, + target_groups, + default_is_left, + } + } +} + +pub fn find_best_split( + x: &DataFrame, + y: &Series, + growth_params: &TreeGrowthParameters, + g: Option<&Series>, + h: Option<&Series>, + incrementing_score: Option, +) -> BestSplit { + if y.len() <= 1 { + panic!("Something went wrong. The parent_node handed down less than two data points.") + } + // TODO: handle case where there are duplicates + // TODO: handle case where there is only one unique y value but multiple non-duplicate x values + let mut best_split: Option = None; + // println!("finding best split"); + for (idx, col) in x.get_column_names().iter().enumerate() { + // println!("col {:?}, idx {:?}", col, idx); + let feature_values = x.select_series(&[col]).unwrap()[0] + .clone() + .cast(&DataType::Float64) + .unwrap(); + + let unique_values = feature_values.unique().unwrap().sort(false); + + // skip the below if there is only one unique value + if unique_values.len() == 1 { + continue; + } + + let mut unique_iter = unique_values.iter(); + unique_iter.next().unwrap(); // skipping the first value + for value in unique_iter { + let value: f64 = value.try_extract().unwrap(); + // println!("value {:?}", value); + let target_groups = feature_values.lt(value).unwrap(); + let target_groups = Series::new("target_groups", target_groups); + + let score = + scoring::calc_score(y, &target_groups, growth_params, g, h, incrementing_score); + // println!("score {:?}", score); + match best_split { + Some(ref mut best_split) => { + if score > best_split.score { + best_split.score = score; + best_split.column_idx = idx; + best_split.column = col.to_string(); + best_split.threshold = value; + best_split.target_groups = target_groups; + best_split.default_is_left = None; + } + } + None => { + best_split = Some(BestSplit::new( + score, + col.to_string(), + idx, + value, + target_groups, + None, + )); + } + } + // println!("best_split {:?}", best_split); + } + } + + best_split.unwrap() +} + +pub fn select_arrays_for_child_node( + go_left: bool, + best: &BestSplit, + x: DataFrame, + y: Series, +) -> (DataFrame, Series) { + if go_left { + let x_child = x + .clone() + .filter(&best.target_groups.bool().unwrap()) + .unwrap(); + let y_child = y + .clone() + .filter(&best.target_groups.bool().unwrap()) + .unwrap(); + return (x_child, y_child); + } else { + let x_child = x + .clone() + .filter(&!best.target_groups.bool().unwrap()) + .unwrap(); + let y_child = y + .clone() + .filter(&!best.target_groups.bool().unwrap()) + .unwrap(); + return (x_child, y_child); + } +} + +// Inspirations: +// * https://rusty-ferris.pages.dev/blog/binary-tree-sum-of-values/ +// * https://gist.github.com/aidanhs/5ac9088ca0f6bdd4a370 +pub fn grow_tree( + x: &DataFrame, + y: &Series, + growth_params: &TreeGrowthParameters, + _parent_node: Option<&Node>, + depth: usize, +) -> Node { + let n_obs = x.height(); + // println!("\nn_obs {:?}", n_obs); + // println!("depth {:?}", depth); + if n_obs == 0 { + panic!("Something went wrong. The parent_node handed down an empty set of data points.") + } + + let (is_baselevel, reason) = check_is_baselevel(y, depth, growth_params); + + let (leaf_weight, score) = calc_leaf_weight_and_split_score(y, growth_params, None, None, None); + // println!("leaf_weight {:?}", leaf_weight); + if is_baselevel { + let new_node = Node::new( + None, + None, + None, + Some(leaf_weight), + None, + None, + None, + Some(score), + n_obs, + reason, + depth, + ); + return new_node; + } + + // find best split + let best = find_best_split(x, y, growth_params, None, None, None); + // println!("column {:?}", best.column); + // println!("column_idx {:?}", best.column_idx); + // println!("threshold {:?}", best.threshold); + // println!("target_groups {:?}", best.target_groups); + + // let mut rng = ChaCha20Rng::seed_from_u64(42); + + let best_ = best.clone(); + let mut new_node = Node::new( + Some(best_.column), + Some(best_.column_idx), + Some(best_.threshold), + Some(leaf_weight), + match best_.default_is_left { + Some(default_is_left) => Some(default_is_left), + None => None, + }, + None, + None, + Some(SplitScore::new("neg_entropy".to_string(), best_.score)), + n_obs, + "leaf node".to_string(), + depth, + ); + + // check if improvement due to split is below minimum requirement + + // descend left + let (x_left, y_left) = select_arrays_for_child_node(true, &best, x.clone(), y.clone()); + // println!("x_left {:?}", x_left); + // println!("y_left {:?}", y_left); + let new_left_node = grow_tree(&x_left, &y_left, growth_params, Some(&new_node), &depth + 1); // mut new_node, + new_node.insert(new_left_node, true); + + // descend right + let (x_right, y_right) = select_arrays_for_child_node(false, &best, x.clone(), y.clone()); + // println!("x_right {:?}", x_right); + // println!("y_right {:?}", y_right); + let new_right_node = grow_tree( + &x_right, + &y_right, + growth_params, + Some(&new_node), + depth + 1, + ); // mut new_node, + new_node.insert(new_right_node, false); + + return new_node; +} + +pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { + let mut node = tree; + + let row_f64 = (*row).cast(&DataType::Float64).unwrap(); + let row = row_f64.f64().unwrap(); + + while !node.is_leaf() { + let idx = node.column_idx.unwrap(); + let value: f64 = row.get(idx).expect("Accessing failed."); + + let threshold = node.threshold.unwrap(); + let is_left = value < threshold; + // println!("idx {:?} value {:?} threshold {:?} is_left {:?}", idx, value, threshold, is_left); + // let is_left = if value < threshold { + // node.default_is_left.unwrap() + // } else { + // !node.default_is_left.unwrap() + // }; + if is_left { + node = node.left.as_ref().unwrap(); + } else { + node = node.right.as_ref().unwrap(); + } + } + node.prediction.unwrap() +} + +pub fn udf<'a, 'b>( + s: Series, + n_cols: &'a usize, + tree: &'b Node, +) -> Result, PolarsError> { + let mut preds: Vec = vec![]; + + for struct_ in s.iter() { + let mut row: Vec = vec![]; + let mut iter = struct_._iter_struct_av(); + for _ in 0..*n_cols { + let value = iter.next().unwrap().try_extract::().unwrap(); + row.push(value); + } + let row = Series::new("", row); + // println!("\nrow {:?}", row); + let prediction = predict_for_row_with_tree(&row, tree); + // println!("prediction {:?}", prediction); + preds.push(prediction); + } + + Ok(Some(Series::new("predictions", preds))) +} + +pub fn predict_with_tree(x: DataFrame, tree: Node) -> Series { + // use polars to apply predict_for_row_with_tree to get one prediction per row + + let mut columns: Vec = vec![]; + let column_names = x.get_column_names(); + for v in column_names { + columns.push(col(v)); + } + let n_cols: usize = columns.len(); + + let predictions = x + .lazy() + .select([as_struct(columns) + .apply( + move |s| udf(s, &n_cols, &tree), + GetOutput::from_type(DataType::Float64), + ) + .alias("predictions")]) + .collect() + .unwrap(); + + predictions.select_series(&["predictions"]).unwrap()[0].clone() +} + +pub struct DecisionTreeCore { + pub growth_params: TreeGrowthParameters, + tree: Option, +} + +impl DecisionTreeCore { + pub fn new(max_depth: usize, split_score_metric: SplitScoreMetrics) -> Self { + let growth_params = TreeGrowthParameters { + max_depth: Some(max_depth), + split_score_metric: Some(split_score_metric), + }; + DecisionTreeCore { + growth_params, + tree: None, + } + } + + pub fn fit(&mut self, x: &DataFrame, y: &Series) { + self.tree = Some(grow_tree(x, y, &self.growth_params, None, 0)); + } + + pub fn predict(&self, x: &DataFrame) -> Series { + let x = x.clone(); + let tree_ = self.tree.clone(); + match tree_ { + Some(tree) => predict_with_tree(x, tree), + None => panic!("Something went wrong. The tree is not initialized."), + } + } +} + +pub struct DecisionTreeClassifier { + decision_tree_core: DecisionTreeCore, +} + +impl DecisionTreeClassifier { + pub fn new(max_depth: usize) -> Self { + DecisionTreeClassifier { + decision_tree_core: DecisionTreeCore::new(max_depth, SplitScoreMetrics::NegEntropy), + } + } + + pub fn fit(&mut self, x: &DataFrame, y: &Series) { + self.decision_tree_core.fit(x, y); + } + + pub fn predict_proba(&self, x: &DataFrame) -> DataFrame { + // println!("predict_proba for {:?}", x.shape()); + let class1 = self.decision_tree_core.predict(x); + // println!("class1 {:?}", class1.len()); + let y_proba: DataFrame = df!("class_1" => &class1) + .unwrap() + .lazy() + .with_columns([(lit(1.) - col("class_1")).alias("class_0")]) + .collect() + .unwrap(); + let y_proba = y_proba.select(&["class_0", "class_1"]).unwrap(); + y_proba + } + + pub fn predict(&self, x: &DataFrame) -> Series { + let y_proba = self.predict_proba(x); + // define "y" as a Series that contains the index of the maximum value column per row + let y = y_proba + .lazy() + .select([(col("class_1").gt(0.5)).alias("y")]) + .collect() + .unwrap(); + + y.select_series(&["y"]).unwrap()[0].clone() + } +} + +pub struct DecisionTreeRegressor { + decision_tree_core: DecisionTreeCore, +} + +impl DecisionTreeRegressor { + pub fn new(max_depth: usize) -> Self { + DecisionTreeRegressor { + decision_tree_core: DecisionTreeCore::new(max_depth, SplitScoreMetrics::NegVariance), + } + } + + pub fn fit(&mut self, x: &DataFrame, y: &Series) { + self.decision_tree_core.fit(x, y); + } + + pub fn predict(&self, x: &DataFrame) -> Series { + let y_pred = Series::new("y", self.decision_tree_core.predict(x)); + + y_pred + } +} + +#[cfg(test)] +mod tests { + // use rand_chacha::ChaCha20Rng; + // use rand::SeedableRng; + + use super::*; + + #[test] + fn test_split_score() { + let split_score = SplitScore::new("test".to_string(), 0.5); + assert_eq!(split_score.name, "test"); + assert_eq!(split_score.score, 0.5); + } + + #[test] + fn test_node_init() { + let node = Node::new( + Some("column".to_string()), + Some(0), + Some(0.0), + Some(1.0), + Some(true), + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), + 10, + "leaf node".to_string(), + 0, + ); + assert_eq!(node.column.unwrap(), "column".to_string()); + assert_eq!(node.column_idx.unwrap(), 0); + assert_eq!(node.threshold.unwrap(), 0.0); + assert_eq!(node.prediction.unwrap(), 1.0); + assert_eq!(node.default_is_left.unwrap(), true); + assert_eq!(node.left, None); + assert_eq!(node.right, None); + let m = node.measure.unwrap(); + assert_eq!(m.name, "score"); + assert_eq!(m.score, 0.5); + assert_eq!(node.n_obs, 10); + assert_eq!(node.reason, "leaf node".to_string()); + assert_eq!(node.depth, 0); + } + + #[test] + fn test_child_node_assignment() { + let mut node = Node::new( + Some("column".to_string()), + Some(0), + Some(0.0), + Some(1.0), + Some(true), + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), + 10, + "leaf node".to_string(), + 0, + ); + let child_node = Node::new( + Some("column".to_string()), + Some(0), + Some(0.0), + Some(1.0), + Some(true), + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), + 10, + "leaf node".to_string(), + 0, + ); + node.left = Some(Box::new(child_node)); + assert_eq!(node.left.is_some(), true); + assert_eq!(node.right.is_none(), true); + } + + #[test] + fn test_grandchild_node_assignment() { + let mut node = Node::new( + Some("column".to_string()), + Some(0), + Some(0.0), + Some(1.0), + Some(true), + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), + 10, + "leaf node".to_string(), + 0, + ); + let child_node = Node::new( + Some("column".to_string()), + Some(0), + Some(0.0), + Some(1.0), + Some(true), + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), + 10, + "leaf node".to_string(), + 0, + ); + let grandchild_node = Node::new( + Some("column".to_string()), + Some(0), + Some(0.0), + Some(1.0), + Some(true), + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), + 10, + "leaf node".to_string(), + 0, + ); + node.left = Some(Box::new(child_node)); + node.left.as_mut().unwrap().left = Some(Box::new(grandchild_node)); + assert_eq!(node.left.is_some(), true); + assert_eq!(node.right.is_none(), true); + assert_eq!(node.left.as_ref().unwrap().left.is_some(), true); + assert_eq!(node.left.as_ref().unwrap().right.is_none(), true); + } + + #[test] + fn test_node_is_leaf() { + let node = Node { + column: Some("column".to_string()), + column_idx: Some(0), + threshold: Some(0.0), + prediction: Some(1.0), + default_is_left: Some(true), + left: None, + right: None, + measure: Some(SplitScore::new("score".to_string(), 0.5)), + n_obs: 10, + reason: "leaf node".to_string(), + depth: 1, + node_id: Uuid::new_v4(), + }; + assert_eq!(node.is_leaf(), true); + } + + // test calc_leaf_weight_and_split_score + #[test] + fn test_calc_leaf_weight_and_split_score() { + let y = Series::new("y", &[1, 1, 1]); + let growth_params = TreeGrowthParameters { + max_depth: Some(2), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let (leaf_weight, score) = + calc_leaf_weight_and_split_score(&y, &growth_params, None, None, None); + assert_eq!(leaf_weight, 1.0); + assert_eq!(score.name, "NegEntropy"); + assert_eq!(score.score, 0.0); + } + + #[test] + fn test_grow_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 2]); + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + + let tree = grow_tree(&df, &y, &growth_params, None, 0); + + assert!(tree.is_leaf() == false); + assert_eq!(tree.left.is_some(), true); + assert_eq!(tree.right.is_some(), true); + assert_eq!(tree.left.as_ref().unwrap().is_leaf(), true); + assert_eq!(tree.right.as_ref().unwrap().is_leaf(), true); + } + + #[test] + fn test_predict_for_row_with_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1]); + let growth_params = TreeGrowthParameters { + max_depth: Some(2), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + + let tree = grow_tree(&df, &y, &growth_params, None, 0); + + let row = df.select_at_idx(0).unwrap(); + let prediction = predict_for_row_with_tree(&row, &tree); + assert_eq!(prediction, 1.0); + } + + #[test] + fn test_predict_with_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3, 4]), + Series::new("b", &[1, 2, 3, 4]), + Series::new("c", &[1, 2, 3, 4]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1, 1]); + let growth_params = TreeGrowthParameters { + max_depth: Some(2), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let tree = grow_tree(&df, &y, &growth_params, None, 0); + + let predictions = predict_with_tree(df, tree); + assert_eq!( + predictions, + Series::new("predictions", &[1.0, 1.0, 1.0, 1.0]) + ); + } + + #[test] + fn test_decision_tree_core() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1]); + + let mut dtree = DecisionTreeCore::new(2, SplitScoreMetrics::NegEntropy); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_classifier() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[1, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (3, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_classifier_1d() { + let df = DataFrame::new(vec![Series::new("a", &[1, 2, 3])]).unwrap(); + let y = Series::new("y", &[0, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (3, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_classifier_2d_case1() { + // is given two columns but only needs one + let df = DataFrame::new(vec![ + Series::new("a", &[1, 1, 1]), + Series::new("b", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[0, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (3, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_classifier_2d_case2() { + // is given two columns and needs both + let df = DataFrame::new(vec![ + Series::new("a", &[-1, 1, -1, 1]), + Series::new("b", &[-1, -1, 1, 1]), + ]) + .unwrap(); + let y = Series::new("y", &[0, 1, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (4, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_regressor() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[1, 1, 1])); + } + + #[test] + fn test_decision_tree_regressor_1d() { + let df = DataFrame::new(vec![Series::new("a", &[1, 2, 3])]).unwrap(); + let y = Series::new("y", &[-1, 1, 1]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[-1, 1, 1])); + } + + #[test] + fn test_decision_tree_regressor_2d_case1() { + // is given two columns but only needs one + let df = DataFrame::new(vec![ + Series::new("a", &[1, 1, 1]), + Series::new("b", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[-1, 1, 1]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[-1, 1, 1])); + } + + #[test] + fn test_decision_tree_regressor_2d_case2() { + // is given two columns and needs both + let df = DataFrame::new(vec![ + Series::new("a", &[-1, 1, -1, 1]), + Series::new("b", &[-1, -1, 1, 1]), + ]) + .unwrap(); + let y = Series::new("y", &[0, 1, 1, 2]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1, 2])); + } +} diff --git a/src/lib.rs b/src/lib.rs index 756e91d..8da91e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,16 @@ +use polars::{frame::DataFrame, series::Series}; use pyo3::prelude::*; +mod decisiontree; mod scoring; +use pyo3_polars::{PyDataFrame, PySeries}; +mod utils; #[pymodule] #[pyo3(name = "_rust")] fn random_tree_models(py: Python<'_>, m: &PyModule) -> PyResult<()> { register_scoring_module(py, m)?; + m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -16,3 +22,76 @@ fn register_scoring_module(py: Python<'_>, parent_module: &PyModule) -> PyResult parent_module.add_submodule(child_module)?; Ok(()) } + +#[pyclass] +struct DecisionTreeClassifier { + max_depth: usize, + tree_: Option, +} + +#[pymethods] +impl DecisionTreeClassifier { + #[new] + fn new(max_depth: usize) -> Self { + DecisionTreeClassifier { + max_depth, + tree_: None, + } + } + + fn fit(&mut self, x: PyDataFrame, y: PySeries) -> PyResult<()> { + let mut tree = decisiontree::DecisionTreeClassifier::new(self.max_depth); + let x: DataFrame = x.into(); + let y: Series = y.into(); + tree.fit(&x, &y); + self.tree_ = Some(tree); + Ok(()) + } + + fn predict(&self, x: PyDataFrame) -> PyResult { + let x: DataFrame = x.into(); + let y_pred = self.tree_.as_ref().unwrap().predict(&x); + + Ok(PySeries(y_pred)) + } + + fn predict_proba(&self, x: PyDataFrame) -> PyResult { + let x: DataFrame = x.into(); + let y_pred = self.tree_.as_ref().unwrap().predict_proba(&x); + + Ok(PyDataFrame(y_pred)) + } +} + +#[pyclass] +struct DecisionTreeRegressor { + max_depth: usize, + tree_: Option, +} + +#[pymethods] +impl DecisionTreeRegressor { + #[new] + fn new(max_depth: usize) -> Self { + DecisionTreeRegressor { + max_depth, + tree_: None, + } + } + + fn fit(&mut self, x: PyDataFrame, y: PySeries) -> PyResult<()> { + let mut tree = decisiontree::DecisionTreeRegressor::new(self.max_depth); + let x: DataFrame = x.into(); + let y: Series = y.into(); + tree.fit(&x, &y); + self.tree_ = Some(tree); + Ok(()) + } + + fn predict(&self, x: PyDataFrame) -> PyResult { + let x: DataFrame = x.into(); + let y_pred = self.tree_.as_ref().unwrap().predict(&x); + + Ok(PySeries(y_pred)) + } +} diff --git a/src/scoring.rs b/src/scoring.rs index dae0432..172e0ac 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -1,8 +1,11 @@ use std::collections::HashMap; +use polars::prelude::*; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use crate::utils::{SplitScoreMetrics, TreeGrowthParameters}; + // compute gini impurity of an array of discrete values #[pyfunction(name = "gini_impurity")] pub fn gini_impurity_py(values: Vec) -> PyResult { @@ -57,25 +60,391 @@ fn entropy(values: Vec) -> f64 { entropy } +pub fn count_y_values(y: &Series) -> Series { + let df = y.value_counts(false, false).unwrap(); + let counts: Series = df.select_at_idx(1).unwrap().clone(); + counts +} + +pub fn calc_probabilities(y: &Series) -> Series { + let msg = "Could not cast to f64"; + let counts = count_y_values(y); + let counts = counts.cast(&DataType::Float64).expect(msg); + let total: f64 = counts.sum().unwrap(); + let ps = Series::new("probs", counts / total); + ps +} + +pub fn calc_neg_entropy_series(ps: &Series) -> f64 { + let neg_entropy = ps + .f64() + .expect("not f64 dtype") + .into_iter() + .map(|x| x.unwrap() * x.unwrap().log2()) + .sum(); + neg_entropy +} + +pub fn neg_entropy_rs(y: &Series, target_groups: &Series) -> f64 { + let msg = "Could not cast to f64"; + let w_left: f64 = (*target_groups) + .cast(&polars::datatypes::DataType::Float64) + .expect(msg) + .sum::() + .unwrap() + / y.len() as f64; + let w_right: f64 = 1.0 - w_left; + + // generate boolean chunked array of target_groups + let trues = Series::new("", vec![true; target_groups.len()]); + let target_groups = target_groups.equal(&trues).unwrap(); + + let entropy_left: f64; + let entropy_right: f64; + if w_left > 0. { + let y_left = y.filter(&target_groups).unwrap(); + let probs = calc_probabilities(&y_left); + entropy_left = calc_neg_entropy_series(&probs); + } else { + entropy_left = 0.0; + } + if w_right > 0. { + let y_right = y.filter(&!target_groups).unwrap(); + let probs = calc_probabilities(&y_right); + entropy_right = calc_neg_entropy_series(&probs); + } else { + entropy_right = 0.0; + } + let score = (w_left * entropy_left) + (w_right * entropy_right); + score +} + +pub fn series_neg_variance(y: Series) -> f64 { + if y.len() == 1 { + return 0.; + } + + let var_series = y.var_as_series(1).unwrap(); + + let variance = var_series + .f64() + .expect("not float64") + .get(0) + .expect("was null"); + + -variance +} + +pub fn neg_variance_rs(y: &Series, target_groups: &Series) -> f64 { + let msg = "Could not cast to f64"; + let w_left: f64 = (*target_groups) + .cast(&polars::datatypes::DataType::Float64) + .expect(msg) + .sum::() + .unwrap() + / y.len() as f64; + let w_right: f64 = 1.0 - w_left; + + // generate boolean chunked array of target_groups + let trues = Series::new("", vec![true; target_groups.len()]); + let target_groups = target_groups.equal(&trues).unwrap(); + + let variance_left: f64; + let variance_right: f64; + if w_left == 1. || w_right == 1. { + return series_neg_variance(y.clone()); + } + if w_left > 0. { + let y_left = y.filter(&target_groups).unwrap(); + variance_left = series_neg_variance(y_left); + } else { + variance_left = 0.0; + } + if w_right > 0. { + let y_right = y.filter(&!target_groups).unwrap(); + variance_right = series_neg_variance(y_right); + } else { + variance_right = 0.0; + } + let score = (w_left * variance_left) + (w_right * variance_right); + score +} + +pub fn calc_score( + y: &Series, + target_groups: &Series, + growth_params: &TreeGrowthParameters, + _g: Option<&Series>, + _h: Option<&Series>, + _incrementing_score: Option, +) -> f64 { + match growth_params.split_score_metric { + Some(SplitScoreMetrics::NegEntropy) => neg_entropy_rs(y, target_groups), + Some(SplitScoreMetrics::NegVariance) => neg_variance_rs(y, target_groups), + _ => panic!( + "split_score_metric {:?} not supported", + growth_params.split_score_metric + ), + } +} + +#[cfg(test)] mod tests { + + use super::*; // test that gini impurity correctly computes values smaller than zero for a couple of vectors + #[test] fn test_gini_impurity() { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; - assert_eq!(super::gini_impurity(values), 0.0); + assert_eq!(gini_impurity(values), 0.0); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; - assert_eq!(super::gini_impurity(values), -0.5); + assert_eq!(gini_impurity(values), -0.5); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - assert_eq!(super::gini_impurity(values), -0.875); + assert_eq!(gini_impurity(values), -0.875); } // test that entropy correctly computes values smaller than zero for a couple of vectors #[test] fn test_entropy() { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; - assert_eq!(super::entropy(values), 0.0); + assert_eq!(entropy(values), 0.0); + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + assert_eq!(entropy(values), -1.0); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + assert_eq!(entropy(values), -3.0); + } + // test count_y_values + #[test] + fn test_count_y_values() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let counts = count_y_values(&s); + // assert that counts is a series with one value + let exp: Vec = vec![8]; + assert_eq!(counts, Series::new("count", exp)); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let counts = count_y_values(&s); + let exp: Vec = vec![4, 4]; + assert_eq!(counts, Series::new("count", exp)); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let counts = count_y_values(&s); + let exp: Vec = vec![1, 1, 1, 1, 1, 1, 1, 1]; + assert_eq!(counts, Series::new("count", exp)); + } + + // test calc_probabilities + #[test] + fn test_calc_probabilities() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + // assert that counts is a series with one value + let exp: Vec = vec![1.0]; + assert_eq!(probs, Series::new("probs", exp)); + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; - assert_eq!(super::entropy(values), -1.0); + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let exp: Vec = vec![0.5, 0.5]; + assert_eq!(probs, Series::new("probs", exp)); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - assert_eq!(super::entropy(values), -3.0); + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let exp: Vec = vec![0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]; + assert_eq!(probs, Series::new("probs", exp)); + } + + // test calc_neg_entropy_series + #[test] + fn test_calc_neg_entropy_series() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let neg_entropy = calc_neg_entropy_series(&probs); + assert_eq!(neg_entropy, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let neg_entropy = calc_neg_entropy_series(&probs); + assert_eq!(neg_entropy, -1.0); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let neg_entropy = calc_neg_entropy_series(&probs); + assert_eq!(neg_entropy, -3.0); + } + + // test calc_score + #[test] + fn test_calc_score() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, -1.0); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, -3.0); + + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), + }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, -1.0); + } + + // test entropy_rs + #[test] + fn test_entropy_rs() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_entropy_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_entropy_rs(&s, &target_groups); + assert_eq!(score, -1.0); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_entropy_rs(&s, &target_groups); + assert_eq!(score, -3.0); + + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_entropy_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_entropy_rs(&s, &target_groups); + assert_eq!(score, -1.0); + } + + #[test] + fn test_series_neg_variance_rs() { + let s = Series::new("y", vec![1]); + let score = series_neg_variance(s); + assert_eq!(score, 0.0); + + let s = Series::new("y", vec![0, 0, 0, 0, 0, 0, 0, 0]); + let score = series_neg_variance(s); + assert_eq!(score, 0.0); + + let s = Series::new("y", vec![0, 1, 0, 1, 0, 1, 0, 1]); + let score = series_neg_variance(s); + assert_eq!(score, -0.2857142857142857); + + let s = Series::new("y", vec![0, 1, 2, 3, 4, 5, 6, 7]); + let score = series_neg_variance(s); + assert_eq!(score, -6.); + } + + #[test] + fn test_neg_variance_rs() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -0.2857142857142857); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -6.); + + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -0.3333333333333333); + + let values = vec![0., 1., 0., 1., 0., 1., 0., 1.]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -0.3333333333333333); } } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..d3907ef --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,20 @@ +use std::fmt; + +#[derive(Debug)] +pub enum SplitScoreMetrics { + NegVariance, + NegEntropy, +} + +impl fmt::Display for SplitScoreMetrics { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self) + // or, alternatively: + // fmt::Debug::fmt(self, f) + } +} + +pub struct TreeGrowthParameters { + pub max_depth: Option, + pub split_score_metric: Option, +} diff --git a/tests/test_decisiontree.py b/tests/test_decisiontree.py index 5e569e6..aeb47da 100644 --- a/tests/test_decisiontree.py +++ b/tests/test_decisiontree.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import types from unittest.mock import patch @@ -149,8 +150,6 @@ def test_Node(int_val, float_val, node_val, str_val, bool_val): ], ) def test_check_is_baselevel(y, depths): - node = dtree.Node() - y, is_baselevel_exp_y = y depth, max_depth, is_baselevel_exp_depth = depths is_baselevel_exp = is_baselevel_exp_depth or is_baselevel_exp_y @@ -586,13 +585,15 @@ def test_1d( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - grow_params = utils.TreeGrowthParameters(max_depth=2) + grow_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( self.X_1D, y, - measure_name=measure_name, g=g, h=h, growth_params=grow_params, @@ -638,13 +639,15 @@ def test_1d_missing( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - grow_params = utils.TreeGrowthParameters(max_depth=2) + grow_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( self.X_1D_missing, y, - measure_name=measure_name, g=g, h=h, growth_params=grow_params, @@ -690,16 +693,18 @@ def test_2d( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - growth_params = utils.TreeGrowthParameters(max_depth=2) + growth_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( - self.X_2D, - y, - measure_name, + X=self.X_2D, + y=y, + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, ) except ValueError as ex: if is_homogenous: @@ -743,16 +748,18 @@ def test_2d_missing( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - growth_params = utils.TreeGrowthParameters(max_depth=2) + growth_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( - self.X_2D_missing, - y, - measure_name, + X=self.X_2D_missing, + y=y, + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, ) except ValueError as ex: if is_homogenous: @@ -855,35 +862,54 @@ def test_check_if_split_sensible( assert gain is None -def test_calc_leaf_weight_and_split_score(): - # calls leafweights.calc_leaf_weight and scoreing.SplitScoreMetrics - # and returns two floats - y = np.array([True, True, False]) - measure_name = "gini" - growth_params = utils.TreeGrowthParameters(max_depth=2) - g = np.array([1, 2, 3]) - h = np.array([4, 5, 6]) - leaf_weight_exp = 1.0 - score_exp = 42.0 - with ( - patch( - "random_tree_models.decisiontree.leafweights.calc_leaf_weight", - return_value=leaf_weight_exp, - ) as mock_calc_leaf_weight, - patch( - "random_tree_models.decisiontree.scoring.SplitScoreMetrics.__call__", - return_value=score_exp, - ) as mock_SplitScoreMetrics, - ): - # line to test - leaf_weight, split_score = dtree.calc_leaf_weight_and_split_score( - y, measure_name, growth_params, g, h - ) - - assert leaf_weight == leaf_weight_exp - assert split_score == score_exp - assert mock_calc_leaf_weight.call_count == 1 - assert mock_SplitScoreMetrics.call_count == 1 +# write tests for calc_leaf_weight_and_split_score +@pytest.mark.parametrize( + "y,growth_params,g,h", + [ + (y, growth_params, g, h) + for y in [ + np.array([True, True, False]), + np.array([True, True, True]), + np.array([False, False, False]), + ] + for growth_params in [ + utils.TreeGrowthParameters( + max_depth=2, split_score_metric=utils.SplitScoreMetrics.gini + ), + utils.TreeGrowthParameters( + max_depth=2, split_score_metric=utils.SplitScoreMetrics.entropy + ), + utils.TreeGrowthParameters( + max_depth=2, split_score_metric=utils.SplitScoreMetrics.variance + ), + utils.TreeGrowthParameters( + max_depth=2, + min_improvement=0.2, + split_score_metric=utils.SplitScoreMetrics.gini, + ), + utils.TreeGrowthParameters( + max_depth=2, + min_improvement=0.2, + split_score_metric=utils.SplitScoreMetrics.entropy, + ), + utils.TreeGrowthParameters( + max_depth=2, + min_improvement=0.2, + split_score_metric=utils.SplitScoreMetrics.variance, + ), + ] + for g in [np.array([1, 2, 3])] + for h in [np.array([4, 5, 6])] + ], +) +def test_calc_leaf_weight_and_split_score( + y: np.ndarray, + growth_params: utils.TreeGrowthParameters, + g: np.ndarray, + h: np.ndarray, +): + # line to test + _, _ = dtree.calc_leaf_weight_and_split_score(y, growth_params, g, h) @pytest.mark.parametrize("go_left", [True, False]) @@ -925,12 +951,14 @@ class Test_grow_tree: X = np.array([[1], [2], [3]]) y = np.array([True, True, False]) target_groups = np.array([True, True, False]) - measure_name = "gini" + measure_name = utils.SplitScoreMetrics["gini"] depth_dummy = 0 def test_baselevel(self): # test returned leaf node - growth_params = utils.TreeGrowthParameters(max_depth=2) + growth_params = utils.TreeGrowthParameters( + max_depth=2, split_score_metric=self.measure_name + ) parent_node = None is_baselevel = True reason = "very custom leaf node comment" @@ -942,20 +970,21 @@ def test_baselevel(self): leaf_node = dtree.grow_tree( self.X, self.y, - self.measure_name, + growth_params=growth_params, parent_node=parent_node, depth=self.depth_dummy, - growth_params=growth_params, ) mock_check_is_baselevel.assert_called_once() - assert leaf_node.is_leaf == True + assert leaf_node.is_leaf is True assert leaf_node.reason == reason def test_split_improvement_insufficient(self): # test split improvement below minimum growth_params = utils.TreeGrowthParameters( - max_depth=2, min_improvement=0.2 + max_depth=2, + min_improvement=0.2, + split_score_metric=self.measure_name, ) parent_score = -1.0 new_score = -0.9 @@ -965,7 +994,7 @@ def test_split_improvement_insufficient(self): threshold=3.0, target_groups=self.target_groups, ) - measure = dtree.SplitScore(self.measure_name, parent_score) + measure = dtree.SplitScore(self.measure_name.name, parent_score) parent_node = dtree.Node( array_column=0, threshold=1.0, @@ -994,10 +1023,9 @@ def test_split_improvement_insufficient(self): node = dtree.grow_tree( self.X, self.y, - self.measure_name, + growth_params=growth_params, parent_node=parent_node, depth=self.depth_dummy, - growth_params=growth_params, ) mock_check_is_baselevel.assert_called_once() @@ -1009,7 +1037,9 @@ def test_split_improvement_insufficient(self): def test_split_improvement_sufficient(self): # test split improvement above minumum, leading to two leaf nodes growth_params = utils.TreeGrowthParameters( - max_depth=2, min_improvement=0.0 + max_depth=2, + split_score_metric=self.measure_name, + min_improvement=0.0, ) parent_score = -1.0 new_score = -0.9 @@ -1019,7 +1049,7 @@ def test_split_improvement_sufficient(self): threshold=3.0, target_groups=self.target_groups, ) - measure = dtree.SplitScore(self.measure_name, parent_score) + measure = dtree.SplitScore(self.measure_name.name, parent_score) parent_node = dtree.Node( array_column=0, threshold=1.0, @@ -1051,10 +1081,9 @@ def test_split_improvement_sufficient(self): tree = dtree.grow_tree( self.X, self.y, - self.measure_name, + growth_params=growth_params, parent_node=parent_node, depth=self.depth_dummy, - growth_params=growth_params, ) assert mock_check_is_baselevel.call_count == 3 @@ -1064,19 +1093,19 @@ def test_split_improvement_sufficient(self): assert tree.reason == "" assert tree.prediction == np.mean(self.y) assert tree.n_obs == len(self.y) - assert tree.is_leaf == False + assert tree.is_leaf is False # left leaf assert tree.left.reason == leaf_reason assert tree.left.prediction == 1.0 assert tree.left.n_obs == 2 - assert tree.left.is_leaf == True + assert tree.left.is_leaf is True # right leaf assert tree.right.reason == leaf_reason assert tree.right.prediction == 0.0 assert tree.right.n_obs == 1 - assert tree.right.is_leaf == True + assert tree.right.is_leaf is True @pytest.mark.parametrize( @@ -1143,8 +1172,18 @@ def test_predict_with_tree(): assert np.allclose(predictions, np.arange(0, 4, 1)) +class DecisionTreeTemplateTestClass(dtree.DecisionTreeTemplate): + "Class to test abstract class DecisionTreeTemplate" + + def fit(self): + pass + + def predict(self): + pass + + class TestDecisionTreeTemplate: - model = dtree.DecisionTreeTemplate() + model = DecisionTreeTemplateTestClass(measure_name="gini") X = np.random.normal(size=(100, 10)) y = np.random.normal(size=(100,)) @@ -1157,18 +1196,6 @@ def test_growth_params_(self): self.model._organize_growth_parameters() assert isinstance(self.model.growth_params_, utils.TreeGrowthParameters) - def test_fit(self): - try: - self.model.fit(None, None) - except NotImplementedError as ex: - pytest.xfail("DecisionTreeTemplate.fit expectedly refused call") - - def test_predict(self): - try: - self.model.predict(None) - except NotImplementedError as ex: - pytest.xfail("DecisionTreeTemplate.predict expectedly refused call") - def test_select_samples_and_features_no_sampling(self): self.model.frac_features = 1.0 self.model.frac_samples = 1.0 @@ -1325,7 +1352,10 @@ def test_predict(self): @pytest.mark.slow @parametrize_with_checks( - [dtree.DecisionTreeRegressor(), dtree.DecisionTreeClassifier()] + [ + dtree.DecisionTreeRegressor(measure_name="variance"), + dtree.DecisionTreeClassifier(measure_name="gini"), + ] ) def test_dtree_estimators_with_sklearn_checks(estimator, check): """Test of estimators using scikit-learn test suite diff --git a/tests/test_isolationforest.py b/tests/test_isolationforest.py index 10bd305..7ec43b1 100644 --- a/tests/test_isolationforest.py +++ b/tests/test_isolationforest.py @@ -1,8 +1,6 @@ +# -*- coding: utf-8 -*- import numpy as np -import pytest -from sklearn.utils.estimator_checks import parametrize_with_checks -import random_tree_models.decisiontree as dtree import random_tree_models.isolationforest as iforest rng = np.random.RandomState(42) @@ -17,6 +15,7 @@ def test_fit(self): model.fit(self.X_inlier) assert hasattr(model, "tree_") assert hasattr(model, "growth_params_") + assert model.incrementing_score_.score > 0 def test_predict(self): model = iforest.IsolationTree() diff --git a/tests/test_leafweights.py b/tests/test_leafweights.py index 26dbd1b..e0ee3a8 100644 --- a/tests/test_leafweights.py +++ b/tests/test_leafweights.py @@ -1,5 +1,7 @@ +# -*- coding: utf-8 -*- import numpy as np import pytest +import pydantic import random_tree_models.leafweights as leafweights import random_tree_models.utils as utils @@ -7,90 +9,48 @@ def test_leaf_weight_mean(): y = np.array([1, 2, 3]) - g = np.array([1, 2, 3]) * 2 - assert leafweights.leaf_weight_mean(y=y, g=g) == 2.0 + assert leafweights.leaf_weight_mean(y=y) == 2.0 def test_leaf_weight_binary_classification_friedman2001(): - y = np.array([1, 2, 3]) g = np.array([1, 2, 3]) * 2 assert ( - leafweights.leaf_weight_binary_classification_friedman2001(y=y, g=g) + leafweights.leaf_weight_binary_classification_friedman2001(g=g) == -0.375 ) def test_leaf_weight_xgboost(): - y = np.array([1, 2, 3]) g = np.array([1, 2, 3]) * 2 h = np.array([1, 2, 3]) * 4 params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) assert ( - leafweights.leaf_weight_xgboost(y=y, g=g, h=h, growth_params=params) - == -0.5 + leafweights.leaf_weight_xgboost(growth_params=params, g=g, h=h) == -0.5 ) -class TestLeafWeightSchemes: - def test_leaf_weight_mean_references(self): - mean_schemes = [ - "variance", - "entropy", - "entropy_rs", - "gini", - "gini_rs", - "incrementing", - ] - - for scheme in mean_schemes: - assert ( - leafweights.LeafWeightSchemes[scheme].value.func - is leafweights.leaf_weight_mean - ) - - def test_leaf_weight_xgboost_references(self): - assert ( - leafweights.LeafWeightSchemes["xgboost"].value.func - is leafweights.leaf_weight_xgboost - ) - - def test_leaf_weight_friedman_references(self): - assert ( - leafweights.LeafWeightSchemes[ - "friedman_binary_classification" - ].value.func - is leafweights.leaf_weight_binary_classification_friedman2001 - ) - - class Test_calc_leaf_weight: def test_error_for_unknown_scheme(self): - y = np.array([1, 2, 3]) - growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) try: - leafweights.calc_leaf_weight( - y=y, growth_params=growth_params, measure_name="not_a_scheme" + _ = utils.TreeGrowthParameters( + max_depth=2, split_score_metric="not_a_scheme", lam=0.0 ) - except KeyError as ex: + except pydantic.ValidationError: pytest.xfail("ValueError correctly raised for unknown scheme") else: pytest.fail("ValueError not raised for unknown scheme") - def test_leaf_weight_none_if_y_empty(self): - y = np.array([]) - growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) - - weight = leafweights.calc_leaf_weight( - y=y, growth_params=growth_params, measure_name="not_a_scheme" - ) - assert weight is None - # returns a float if y is not empty def test_leaf_weight_float_if_y_not_empty(self): y = np.array([1, 2, 3]) - growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) + growth_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics["variance"], + lam=0.0, + ) weight = leafweights.calc_leaf_weight( - y=y, growth_params=growth_params, measure_name="variance" + y=y, + growth_params=growth_params, ) assert isinstance(weight, float) diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 4028281..8e7f781 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import numpy as np import pytest @@ -359,11 +360,11 @@ def test_calc_xgboost_split_score( g: np.ndarray, h: np.ndarray, target_groups: np.ndarray, score_exp: float ): growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) - y = None + try: # line to test score = scoring.calc_xgboost_split_score( - y, target_groups, g, h, growth_params + target_groups, g, h, growth_params ) except ValueError as ex: if score_exp is None: @@ -380,6 +381,7 @@ def test_calc_xgboost_split_score( class TestSplitScoreMetrics: "Redudancy test - calling calc_xgboost_split_score etc via SplitScoreMetrics needs to yield the same values as in the test above." + y = np.array([1, 1, 2, 2]) target_groups = np.array([False, True, False, True]) @@ -388,29 +390,39 @@ class TestSplitScoreMetrics: var_exp = -0.25 def test_gini(self): - g = scoring.SplitScoreMetrics["gini"](self.y, self.target_groups) + measure = utils.SplitScoreMetrics["gini"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + g = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert g == self.g_exp def test_gini_rs(self): - g = scoring.SplitScoreMetrics["gini_rs"](self.y, self.target_groups) + measure = utils.SplitScoreMetrics["gini_rs"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + g = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert g == self.g_exp def test_entropy(self): - h = scoring.SplitScoreMetrics["entropy"](self.y, self.target_groups) + measure = utils.SplitScoreMetrics["entropy"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + h = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert h == self.h_exp - def test_entropy(self): - h = scoring.SplitScoreMetrics["entropy_rs"](self.y, self.target_groups) + def test_entropy_rs(self): + measure = utils.SplitScoreMetrics["entropy_rs"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + h = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert h == self.h_exp def test_variance(self): - var = scoring.SplitScoreMetrics["variance"](self.y, self.target_groups) + measure = utils.SplitScoreMetrics["variance"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + var = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert var == self.var_exp def test_friedman_binary_classification(self): - var = scoring.SplitScoreMetrics["friedman_binary_classification"]( - self.y, self.target_groups - ) + measure = utils.SplitScoreMetrics["friedman_binary_classification"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + var = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert var == self.var_exp @pytest.mark.parametrize( @@ -452,11 +464,37 @@ def test_xgboost( score_exp: float, ): growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) - y = None # line to test score = scoring.calc_xgboost_split_score( - y, target_groups, g, h, growth_params + target_groups, g, h, growth_params ) assert score == score_exp + + def test_incrementing(self): + incrementing_score = scoring.IncrementingScore() + score_metric = utils.SplitScoreMetrics["incrementing"] + gp = utils.TreeGrowthParameters(1, split_score_metric=score_metric) + + # line to test + score = scoring.calc_score( + self.y, + self.target_groups, + growth_params=gp, + incrementing_score=incrementing_score, + ) + + assert score == 1 + assert incrementing_score.score == 1 + + # line to test + score = scoring.calc_score( + self.y, + self.target_groups, + growth_params=gp, + incrementing_score=incrementing_score, + ) + + assert score == 2 + assert incrementing_score.score == 2