From c44959e7a9f7e826787ffbeca32eac20075490b9 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 15:44:03 +0100 Subject: [PATCH 01/28] chore: updated dependencies --- config/requirements.txt | 5 +++++ pyproject.toml | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/config/requirements.txt b/config/requirements.txt index ffb0ac6..576fb2b 100644 --- a/config/requirements.txt +++ b/config/requirements.txt @@ -234,6 +234,7 @@ numpy==1.24.3 # contourpy # matplotlib # pandas + # pyarrow # scikit-learn # scipy # seaborn @@ -272,6 +273,8 @@ platformdirs==3.5.1 # virtualenv pluggy==1.0.0 # via pytest +polars==0.20.0 + # via random-tree-models (pyproject.toml) pre-commit==3.3.2 # via random-tree-models (pyproject.toml) prometheus-client==0.17.0 @@ -289,6 +292,8 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data +pyarrow==14.0.2 + # via random-tree-models (pyproject.toml) pycparser==2.21 # via cffi pydantic==1.10.8 diff --git a/pyproject.toml b/pyproject.toml index e6ade6a..70251eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,9 @@ dependencies = [ "pandas", "pydantic", "snakeviz", - "maturin" + "maturin", + "polars", + "pyarrow", ] [tool.maturin] From 79191801f55f5869c7accbbc81666b2ebce10bb8 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 17:56:39 +0100 Subject: [PATCH 02/28] chore: replaced enum + callable values with enum + match logic, more verbose but easier to understand --- .pre-commit-config.yaml | 30 ++++---- Cargo.lock | 45 ++++++------ Cargo.toml | 2 +- pyproject.toml | 63 +++++++++++++++-- random_tree_models/decisiontree.py | 40 +++++++---- random_tree_models/isolationforest.py | 2 +- random_tree_models/leafweights.py | 49 ++++++-------- random_tree_models/scoring.py | 70 ++++++++++++------- tests/test_decisiontree.py | 98 ++++++++++++++------------- tests/test_leafweights.py | 39 
++--------- tests/test_scoring.py | 25 ++++--- 11 files changed, 261 insertions(+), 202 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3491c8a..daf9a50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,31 +5,25 @@ repos: rev: v4.4.0 hooks: - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/psf/black - rev: 23.1.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.9 hooks: - - id: black - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - - repo: https://github.com/nbQA-dev/nbQA - rev: 1.6.1 - hooks: - - id: nbqa-black - - id: nbqa-isort - args: ["--float-to-top"] + # Run the linter. + - id: ruff + types_or: [ python, pyi, jupyter ] + # Run the formatter. + - id: ruff-format + types_or: [ python, pyi, jupyter ] - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: - id: nbstripout - repo: local hooks: - - id: unittest - name: unittest - entry: python3 -m pytest -vx -m "not slow" + - id: pytest + name: pytest + entry: python3 -m pytest -x -m "not slow" pass_filenames: false language: system types: [python] diff --git a/Cargo.lock b/Cargo.lock index 32241ac..a4e82fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,11 +20,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "indoc" -version = "1.0.9" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" [[package]] name = "libc" @@ -44,9 
+50,9 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" dependencies = [ "autocfg", ] @@ -91,9 +97,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.18.3" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" dependencies = [ "cfg-if", "indoc", @@ -108,9 +114,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.18.3" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3" +checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" dependencies = [ "once_cell", "target-lexicon", @@ -118,9 +124,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.18.3" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" dependencies = [ "libc", "pyo3-build-config", @@ -128,9 +134,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.18.3" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -140,10 +146,11 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.18.3" +version = "0.20.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" +checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" dependencies = [ + "heck", "proc-macro2", "quote", "syn", @@ -160,7 +167,7 @@ dependencies = [ [[package]] name = "random-tree-models" -version = "0.3.1" +version = "0.6.2" dependencies = [ "pyo3", ] @@ -188,9 +195,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "syn" -version = "1.0.109" +version = "2.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +checksum = "fcb8d4cebc40aa517dfb69618fa647a346562e67228e2236ae0042ee6ac14775" dependencies = [ "proc-macro2", "quote", @@ -211,9 +218,9 @@ checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" [[package]] name = "unindent" -version = "0.1.11" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "windows-targets" diff --git a/Cargo.toml b/Cargo.toml index 90d0601..d209f42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ name = "random_tree_models" crate-type = ["cdylib"] [dependencies] -pyo3 = "0.18.3" +pyo3 = "0.20.0" [features] extension-module = ["pyo3/extension-module"] diff --git a/pyproject.toml b/pyproject.toml index 70251eb..1218fe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,14 +40,65 @@ bindings = "pyo3" features = ["pyo3/extension-module"] module-name = "random_tree_models._rust" -[tool.black] +# [tool.black] +# line-length = 80 + +# [tool.isort] +# multi_line_output = 3 +# line_length = 80 +# include_trailing_comma = true +# profile = "black" + +[tool.ruff] +# 
https://docs.astral.sh/ruff/configuration/ line-length = 80 +indent-width = 4 +target-version = "py310" +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "src", +] +[tool.ruff.lint] +fixable = ["ALL"] +unfixable = [] + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false -[tool.isort] -multi_line_output = 3 -line_length = 80 -include_trailing_comma = true -profile = "black" +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" # [tool.setuptools.packages.find] # where = ["."] # list of folders that contain the packages (["."] by default) diff --git a/random_tree_models/decisiontree.py b/random_tree_models/decisiontree.py index 8b0ea1d..c688805 100644 --- a/random_tree_models/decisiontree.py +++ b/random_tree_models/decisiontree.py @@ -199,7 +199,7 @@ def get_column( def find_best_split( X: np.ndarray, y: np.ndarray, - measure_name: str, + measure_name: scoring.SplitScoreMetrics, yhat: np.ndarray = None, g: np.ndarray = None, h: np.ndarray = None, @@ -225,13 +225,13 @@ def find_best_split( ) in get_thresholds_and_target_groups( feature_values, growth_params.threshold_params, rng ): - split_score = scoring.SplitScoreMetrics[measure_name]( + split_score = scoring.calc_score( y, target_groups, - yhat=yhat, g=g, h=h, growth_params=growth_params, + score_metric=measure_name, ) if best is None or split_score > best.score: @@ -271,7 +271,7 @@ def check_if_split_sensible( def calc_leaf_weight_and_split_score( y: np.ndarray, - 
measure_name: str, + measure_name: scoring.SplitScoreMetrics, growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, @@ -280,14 +280,14 @@ def calc_leaf_weight_and_split_score( y, measure_name, growth_params, g=g, h=h ) - yhat = leaf_weight * np.ones_like(y) - score = scoring.SplitScoreMetrics[measure_name]( + # yhat = leaf_weight * np.ones_like(y) + score = scoring.calc_score( y, np.ones_like(y, dtype=bool), - yhat=yhat, g=g, h=h, growth_params=growth_params, + score_metric=measure_name, ) return leaf_weight, score @@ -312,7 +312,7 @@ def select_arrays_for_child_node( def grow_tree( X: np.ndarray, y: np.ndarray, - measure_name: str, + measure_name: scoring.SplitScoreMetrics, parent_node: Node = None, depth: int = 0, growth_params: utils.TreeGrowthParameters = None, @@ -367,7 +367,7 @@ def grow_tree( if is_baselevel: # end of the line buddy return Node( prediction=leaf_weight, - measure=SplitScore(measure_name, score=score), + measure=SplitScore(measure_name.name, score=score), n_obs=n_obs, reason=reason, depth=depth, @@ -389,7 +389,7 @@ def grow_tree( reason = f"gain due split ({gain=}) lower than {growth_params.min_improvement=} or all data points assigned to one side (is left {best.target_groups.mean()=:.2%})" leaf_node = Node( prediction=leaf_weight, - measure=SplitScore(measure_name, score=score), + measure=SplitScore(measure_name.name, score=score), n_obs=n_obs, reason=reason, depth=depth, @@ -402,7 +402,7 @@ def grow_tree( threshold=best.threshold, prediction=leaf_weight, default_is_left=best.default_is_left, - measure=SplitScore(measure_name, best.score), + measure=SplitScore(measure_name.name, best.score), n_obs=n_obs, reason="", depth=depth, @@ -514,7 +514,19 @@ def __init__( self.column_method = column_method self.n_columns_to_try = n_columns_to_try + def _sanity_check_measure_name(self): + try: + self.measure_name_enum_ = scoring.SplitScoreMetrics[ + self.measure_name + ] + except KeyError as ex: + raise KeyError( + f"Unknown 
measure_name: {self.measure_name}. " + f"Valid options: {', '.join(list(scoring.SplitScoreMetrics.__members__.keys()))}. {ex=}" + ) + def _organize_growth_parameters(self): + self._sanity_check_measure_name() self.growth_params_ = utils.TreeGrowthParameters( max_depth=self.max_depth, min_improvement=self.min_improvement, @@ -539,7 +551,7 @@ def _select_samples_and_features( ) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: "Sub-samples rows and columns from X and y" if not hasattr(self, "growth_params_"): - raise ValueError(f"Try calling `fit` first.") + raise ValueError("Try calling `fit` first.") ix = np.arange(len(X)) rng = np.random.RandomState(self.growth_params_.random_state) @@ -611,7 +623,7 @@ def fit( self.tree_ = grow_tree( _X, _y, - measure_name=self.measure_name, + measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, **kwargs, @@ -675,7 +687,7 @@ def fit( self.tree_ = grow_tree( _X, _y, - measure_name=self.measure_name, + measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, ) diff --git a/random_tree_models/isolationforest.py b/random_tree_models/isolationforest.py index 577aeac..dcd26c8 100644 --- a/random_tree_models/isolationforest.py +++ b/random_tree_models/isolationforest.py @@ -64,7 +64,7 @@ def fit( self.tree_ = dtree.grow_tree( _X, _y, - measure_name=self.measure_name, + measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, **kwargs, diff --git a/random_tree_models/leafweights.py b/random_tree_models/leafweights.py index 351a3b0..0fffd7d 100644 --- a/random_tree_models/leafweights.py +++ b/random_tree_models/leafweights.py @@ -1,8 +1,6 @@ -from enum import Enum -from functools import partial - import numpy as np +import random_tree_models.scoring as scoring import random_tree_models.utils as utils @@ -33,32 +31,9 @@ def leaf_weight_xgboost( return w -class LeafWeightSchemes(Enum): - # 
https://stackoverflow.com/questions/40338652/how-to-define-enum-values-that-are-functions - friedman_binary_classification = partial( - leaf_weight_binary_classification_friedman2001 - ) - variance = partial(leaf_weight_mean) - entropy = partial(leaf_weight_mean) - entropy_rs = partial(leaf_weight_mean) - gini = partial(leaf_weight_mean) - gini_rs = partial(leaf_weight_mean) - xgboost = partial(leaf_weight_xgboost) - incrementing = partial(leaf_weight_mean) - - def __call__( - self, - y: np.ndarray, - growth_params: utils.TreeGrowthParameters, - g: np.ndarray = None, - h: np.ndarray = None, - ) -> float: - return self.value(y=y, growth_params=growth_params, g=g, h=h) - - def calc_leaf_weight( y: np.ndarray, - measure_name: str, + measure_name: scoring.SplitScoreMetrics, growth_params: utils.TreeGrowthParameters, g: np.ndarray = None, h: np.ndarray = None, @@ -71,7 +46,23 @@ def calc_leaf_weight( if len(y) == 0: return None - weight_func = LeafWeightSchemes[measure_name] - leaf_weight = weight_func(y=y, growth_params=growth_params, g=g, h=h) + match measure_name: + case ( + scoring.SplitScoreMetrics.variance + | scoring.SplitScoreMetrics.entropy + | scoring.SplitScoreMetrics.entropy_rs + | scoring.SplitScoreMetrics.gini + | scoring.SplitScoreMetrics.gini_rs + | scoring.SplitScoreMetrics.incrementing + ): + leaf_weight = leaf_weight_mean(y) + case scoring.SplitScoreMetrics.friedman_binary_classification: + leaf_weight = leaf_weight_binary_classification_friedman2001(g) + case scoring.SplitScoreMetrics.xgboost: + leaf_weight = leaf_weight_xgboost(growth_params, g, h) + case _: + raise KeyError( + f"Unknown measure_name: {measure_name}, expected one of {', '.join(list(scoring.SplitScoreMetrics.__members__.keys()))}" + ) return leaf_weight diff --git a/random_tree_models/scoring.py b/random_tree_models/scoring.py index 4caabec..c6c252a 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -1,5 +1,4 @@ from enum import Enum -from functools 
import partial import numpy as np @@ -214,28 +213,47 @@ def __call__(self, *args, **kwargs) -> float: return self.score -class SplitScoreMetrics(Enum): - # https://stackoverflow.com/questions/40338652/how-to-define-enum-values-that-are-functions - variance = partial(calc_variance) - entropy = partial(calc_entropy) - entropy_rs = partial(calc_entropy_rs) - gini = partial(calc_gini_impurity) - gini_rs = partial(calc_gini_impurity_rs) - # variance for split score because Friedman et al. 2001 in Algorithm 1 - # step 4 minimize the squared error between actual and predicted dloss/dyhat - friedman_binary_classification = partial(calc_variance) - xgboost = partial(calc_xgboost_split_score) - incrementing = partial(IncrementingScore()) - - def __call__( - self, - y: np.ndarray, - target_groups: np.ndarray, - yhat: np.ndarray = None, - g: np.ndarray = None, - h: np.ndarray = None, - growth_params: utils.TreeGrowthParameters = None, - ) -> float: - return self.value( - y, target_groups, yhat=yhat, g=g, h=h, growth_params=growth_params - ) +SplitScoreMetrics = Enum( + "SplitScoreMetrics", + [ + "variance", + "entropy", + "entropy_rs", + "gini", + "gini_rs", + "friedman_binary_classification", + "xgboost", + "incrementing", + ], +) + + +def calc_score( + y: np.ndarray, + target_groups: np.ndarray, + g: np.ndarray = None, + h: np.ndarray = None, + growth_params: utils.TreeGrowthParameters = None, + score_metric: SplitScoreMetrics = SplitScoreMetrics.variance, +) -> float: + match score_metric: + case SplitScoreMetrics.variance: + return calc_variance(y, target_groups) + case SplitScoreMetrics.entropy: + return calc_entropy(y, target_groups) + case SplitScoreMetrics.entropy_rs: + return calc_entropy_rs(y, target_groups) + case SplitScoreMetrics.gini: + return calc_gini_impurity(y, target_groups) + case SplitScoreMetrics.gini_rs: + return calc_gini_impurity_rs(y, target_groups) + case SplitScoreMetrics.friedman_binary_classification: + return calc_variance(y, target_groups) + 
case SplitScoreMetrics.xgboost: + return calc_xgboost_split_score( + y, target_groups, g, h, growth_params + ) + case SplitScoreMetrics.incrementing: + return IncrementingScore()() + case _: + raise ValueError(f"{score_metric=} not supported") diff --git a/tests/test_decisiontree.py b/tests/test_decisiontree.py index 5e569e6..0438a1d 100644 --- a/tests/test_decisiontree.py +++ b/tests/test_decisiontree.py @@ -8,6 +8,7 @@ from sklearn.utils.estimator_checks import parametrize_with_checks import random_tree_models.decisiontree as dtree +import random_tree_models.scoring as scoring import random_tree_models.utils as utils # first value in each tuple is the value to test and the second is the flag indicating if this should work @@ -149,8 +150,6 @@ def test_Node(int_val, float_val, node_val, str_val, bool_val): ], ) def test_check_is_baselevel(y, depths): - node = dtree.Node() - y, is_baselevel_exp_y = y depth, max_depth, is_baselevel_exp_depth = depths is_baselevel_exp = is_baselevel_exp_depth or is_baselevel_exp_y @@ -592,7 +591,7 @@ def test_1d( best = dtree.find_best_split( self.X_1D, y, - measure_name=measure_name, + measure_name=scoring.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=grow_params, @@ -644,7 +643,7 @@ def test_1d_missing( best = dtree.find_best_split( self.X_1D_missing, y, - measure_name=measure_name, + measure_name=scoring.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=grow_params, @@ -696,7 +695,7 @@ def test_2d( best = dtree.find_best_split( self.X_2D, y, - measure_name, + scoring.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=growth_params, @@ -749,7 +748,7 @@ def test_2d_missing( best = dtree.find_best_split( self.X_2D_missing, y, - measure_name, + scoring.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=growth_params, @@ -855,35 +854,39 @@ def test_check_if_split_sensible( assert gain is None -def test_calc_leaf_weight_and_split_score(): - # calls leafweights.calc_leaf_weight and scoreing.SplitScoreMetrics 
- # and returns two floats - y = np.array([True, True, False]) - measure_name = "gini" - growth_params = utils.TreeGrowthParameters(max_depth=2) - g = np.array([1, 2, 3]) - h = np.array([4, 5, 6]) - leaf_weight_exp = 1.0 - score_exp = 42.0 - with ( - patch( - "random_tree_models.decisiontree.leafweights.calc_leaf_weight", - return_value=leaf_weight_exp, - ) as mock_calc_leaf_weight, - patch( - "random_tree_models.decisiontree.scoring.SplitScoreMetrics.__call__", - return_value=score_exp, - ) as mock_SplitScoreMetrics, - ): - # line to test - leaf_weight, split_score = dtree.calc_leaf_weight_and_split_score( - y, measure_name, growth_params, g, h - ) - - assert leaf_weight == leaf_weight_exp - assert split_score == score_exp - assert mock_calc_leaf_weight.call_count == 1 - assert mock_SplitScoreMetrics.call_count == 1 +# write tests for calc_leaf_weight_and_split_score +@pytest.mark.parametrize( + "y,measure_name,growth_params,g,h", + [ + (y, measure_name, growth_params, g, h) + for y in [ + np.array([True, True, False]), + np.array([True, True, True]), + np.array([False, False, False]), + ] + for measure_name in [ + scoring.SplitScoreMetrics.gini, + scoring.SplitScoreMetrics.entropy, + ] + for growth_params in [ + utils.TreeGrowthParameters(max_depth=2), + utils.TreeGrowthParameters(max_depth=2, min_improvement=0.2), + ] + for g in [np.array([1, 2, 3])] + for h in [np.array([4, 5, 6])] + ], +) +def test_calc_leaf_weight_and_split_score( + y: np.ndarray, + measure_name: str, + growth_params: utils.TreeGrowthParameters, + g: np.ndarray, + h: np.ndarray, +): + # line to test + _, _ = dtree.calc_leaf_weight_and_split_score( + y, measure_name, growth_params, g, h + ) @pytest.mark.parametrize("go_left", [True, False]) @@ -925,7 +928,7 @@ class Test_grow_tree: X = np.array([[1], [2], [3]]) y = np.array([True, True, False]) target_groups = np.array([True, True, False]) - measure_name = "gini" + measure_name = scoring.SplitScoreMetrics["gini"] depth_dummy = 0 def 
test_baselevel(self): @@ -949,7 +952,7 @@ def test_baselevel(self): ) mock_check_is_baselevel.assert_called_once() - assert leaf_node.is_leaf == True + assert leaf_node.is_leaf is True assert leaf_node.reason == reason def test_split_improvement_insufficient(self): @@ -965,7 +968,7 @@ def test_split_improvement_insufficient(self): threshold=3.0, target_groups=self.target_groups, ) - measure = dtree.SplitScore(self.measure_name, parent_score) + measure = dtree.SplitScore(self.measure_name.name, parent_score) parent_node = dtree.Node( array_column=0, threshold=1.0, @@ -1019,7 +1022,7 @@ def test_split_improvement_sufficient(self): threshold=3.0, target_groups=self.target_groups, ) - measure = dtree.SplitScore(self.measure_name, parent_score) + measure = dtree.SplitScore(self.measure_name.name, parent_score) parent_node = dtree.Node( array_column=0, threshold=1.0, @@ -1064,19 +1067,19 @@ def test_split_improvement_sufficient(self): assert tree.reason == "" assert tree.prediction == np.mean(self.y) assert tree.n_obs == len(self.y) - assert tree.is_leaf == False + assert tree.is_leaf is False # left leaf assert tree.left.reason == leaf_reason assert tree.left.prediction == 1.0 assert tree.left.n_obs == 2 - assert tree.left.is_leaf == True + assert tree.left.is_leaf is True # right leaf assert tree.right.reason == leaf_reason assert tree.right.prediction == 0.0 assert tree.right.n_obs == 1 - assert tree.right.is_leaf == True + assert tree.right.is_leaf is True @pytest.mark.parametrize( @@ -1144,7 +1147,7 @@ def test_predict_with_tree(): class TestDecisionTreeTemplate: - model = dtree.DecisionTreeTemplate() + model = dtree.DecisionTreeTemplate(measure_name="gini") X = np.random.normal(size=(100, 10)) y = np.random.normal(size=(100,)) @@ -1160,13 +1163,13 @@ def test_growth_params_(self): def test_fit(self): try: self.model.fit(None, None) - except NotImplementedError as ex: + except NotImplementedError: pytest.xfail("DecisionTreeTemplate.fit expectedly refused call") def 
test_predict(self): try: self.model.predict(None) - except NotImplementedError as ex: + except NotImplementedError: pytest.xfail("DecisionTreeTemplate.predict expectedly refused call") def test_select_samples_and_features_no_sampling(self): @@ -1325,7 +1328,10 @@ def test_predict(self): @pytest.mark.slow @parametrize_with_checks( - [dtree.DecisionTreeRegressor(), dtree.DecisionTreeClassifier()] + [ + dtree.DecisionTreeRegressor(measure_name="variance"), + dtree.DecisionTreeClassifier(measure_name="gini"), + ] ) def test_dtree_estimators_with_sklearn_checks(estimator, check): """Test of estimators using scikit-learn test suite diff --git a/tests/test_leafweights.py b/tests/test_leafweights.py index 26dbd1b..d3b6a55 100644 --- a/tests/test_leafweights.py +++ b/tests/test_leafweights.py @@ -2,6 +2,7 @@ import pytest import random_tree_models.leafweights as leafweights +import random_tree_models.scoring as scoring import random_tree_models.utils as utils @@ -31,38 +32,6 @@ def test_leaf_weight_xgboost(): ) -class TestLeafWeightSchemes: - def test_leaf_weight_mean_references(self): - mean_schemes = [ - "variance", - "entropy", - "entropy_rs", - "gini", - "gini_rs", - "incrementing", - ] - - for scheme in mean_schemes: - assert ( - leafweights.LeafWeightSchemes[scheme].value.func - is leafweights.leaf_weight_mean - ) - - def test_leaf_weight_xgboost_references(self): - assert ( - leafweights.LeafWeightSchemes["xgboost"].value.func - is leafweights.leaf_weight_xgboost - ) - - def test_leaf_weight_friedman_references(self): - assert ( - leafweights.LeafWeightSchemes[ - "friedman_binary_classification" - ].value.func - is leafweights.leaf_weight_binary_classification_friedman2001 - ) - - class Test_calc_leaf_weight: def test_error_for_unknown_scheme(self): y = np.array([1, 2, 3]) @@ -71,7 +40,7 @@ def test_error_for_unknown_scheme(self): leafweights.calc_leaf_weight( y=y, growth_params=growth_params, measure_name="not_a_scheme" ) - except KeyError as ex: + except KeyError: 
pytest.xfail("ValueError correctly raised for unknown scheme") else: pytest.fail("ValueError not raised for unknown scheme") @@ -91,6 +60,8 @@ def test_leaf_weight_float_if_y_not_empty(self): growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) weight = leafweights.calc_leaf_weight( - y=y, growth_params=growth_params, measure_name="variance" + y=y, + growth_params=growth_params, + measure_name=scoring.SplitScoreMetrics["variance"], ) assert isinstance(weight, float) diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 4028281..e927401 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -380,6 +380,7 @@ def test_calc_xgboost_split_score( class TestSplitScoreMetrics: "Redudancy test - calling calc_xgboost_split_score etc via SplitScoreMetrics needs to yield the same values as in the test above." + y = np.array([1, 1, 2, 2]) target_groups = np.array([False, True, False, True]) @@ -388,28 +389,36 @@ class TestSplitScoreMetrics: var_exp = -0.25 def test_gini(self): - g = scoring.SplitScoreMetrics["gini"](self.y, self.target_groups) + measure = scoring.SplitScoreMetrics["gini"] + g = scoring.calc_score(self.y, self.target_groups, score_metric=measure) assert g == self.g_exp def test_gini_rs(self): - g = scoring.SplitScoreMetrics["gini_rs"](self.y, self.target_groups) + measure = scoring.SplitScoreMetrics["gini_rs"] + g = scoring.calc_score(self.y, self.target_groups, score_metric=measure) assert g == self.g_exp def test_entropy(self): - h = scoring.SplitScoreMetrics["entropy"](self.y, self.target_groups) + measure = scoring.SplitScoreMetrics["entropy"] + h = scoring.calc_score(self.y, self.target_groups, score_metric=measure) assert h == self.h_exp - def test_entropy(self): - h = scoring.SplitScoreMetrics["entropy_rs"](self.y, self.target_groups) + def test_entropy_rs(self): + measure = scoring.SplitScoreMetrics["entropy_rs"] + h = scoring.calc_score(self.y, self.target_groups, score_metric=measure) assert h == self.h_exp def 
test_variance(self): - var = scoring.SplitScoreMetrics["variance"](self.y, self.target_groups) + measure = scoring.SplitScoreMetrics["variance"] + var = scoring.calc_score( + self.y, self.target_groups, score_metric=measure + ) assert var == self.var_exp def test_friedman_binary_classification(self): - var = scoring.SplitScoreMetrics["friedman_binary_classification"]( - self.y, self.target_groups + measure = scoring.SplitScoreMetrics["friedman_binary_classification"] + var = scoring.calc_score( + self.y, self.target_groups, score_metric=measure ) assert var == self.var_exp From d14328de6eb7ecbda0f7758e9cfded6b0945fdfa Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 17:58:58 +0100 Subject: [PATCH 03/28] chore: extended pre-commit hooks --- .pre-commit-config.yaml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index daf9a50..62accfe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,18 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: check-yaml + - id: trailing-whitespace + - id: check-added-large-files + args: ['--maxkb=1500'] + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-merge-conflict + - id: check-toml + - id: detect-private-key + - id: end-of-file-fixer + - id: fix-encoding-pragma + - id: name-tests-test + args: ['--pytest-test-first'] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
rev: v0.1.9 From fa75898c269eeff6807c353e0ab2688deb51072b Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 18:11:19 +0100 Subject: [PATCH 04/28] chore: removed superfluous **kwargs --- random_tree_models/leafweights.py | 5 ++--- random_tree_models/scoring.py | 23 +++++++---------------- tests/test_leafweights.py | 11 ++++------- tests/test_scoring.py | 8 ++++---- 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/random_tree_models/leafweights.py b/random_tree_models/leafweights.py index 0fffd7d..0f98d45 100644 --- a/random_tree_models/leafweights.py +++ b/random_tree_models/leafweights.py @@ -1,16 +1,16 @@ +# -*- coding: utf-8 -*- import numpy as np import random_tree_models.scoring as scoring import random_tree_models.utils as utils -def leaf_weight_mean(y: np.ndarray, **kwargs) -> float: +def leaf_weight_mean(y: np.ndarray) -> float: return np.mean(y) def leaf_weight_binary_classification_friedman2001( g: np.ndarray, - **kwargs, ) -> float: "Computes optimal leaf weight as in Friedman et al. 2001 Algorithm 5" @@ -23,7 +23,6 @@ def leaf_weight_xgboost( growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, - **kwargs, ) -> float: "Computes optimal leaf weight as in Chen et al. 
2016 equation 5" diff --git a/random_tree_models/scoring.py b/random_tree_models/scoring.py index c6c252a..c9c161a 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from enum import Enum import numpy as np @@ -15,7 +16,7 @@ def check_y_and_target_groups(y: np.ndarray, target_groups: np.ndarray = None): raise ValueError(f"{y.shape=} != {target_groups.shape=}") -def calc_variance(y: np.ndarray, target_groups: np.ndarray, **kwargs) -> float: +def calc_variance(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the variance of a split""" check_y_and_target_groups(y, target_groups=target_groups) @@ -57,7 +58,7 @@ def entropy(y: np.ndarray) -> float: return h -def calc_entropy(y: np.ndarray, target_groups: np.ndarray, **kwargs) -> float: +def calc_entropy(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the entropy of a split""" check_y_and_target_groups(y, target_groups=target_groups) @@ -72,9 +73,7 @@ def calc_entropy(y: np.ndarray, target_groups: np.ndarray, **kwargs) -> float: return h -def calc_entropy_rs( - y: np.ndarray, target_groups: np.ndarray, **kwargs -) -> float: +def calc_entropy_rs(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the entropy of a split""" check_y_and_target_groups(y, target_groups=target_groups) @@ -112,9 +111,7 @@ def gini_impurity(y: np.ndarray) -> float: return -g -def calc_gini_impurity( - y: np.ndarray, target_groups: np.ndarray, **kwargs -) -> float: +def calc_gini_impurity(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the gini impurity of a split Based on: https://scikit-learn.org/stable/modules/tree.html#classification-criteria @@ -132,9 +129,7 @@ def calc_gini_impurity( return g -def calc_gini_impurity_rs( - y: np.ndarray, target_groups: np.ndarray, **kwargs -) -> float: +def calc_gini_impurity_rs(y: np.ndarray, target_groups: np.ndarray) -> float: """Calculates the gini impurity of a split 
Based on: https://scikit-learn.org/stable/modules/tree.html#classification-criteria @@ -171,12 +166,10 @@ def xgboost_split_score( def calc_xgboost_split_score( - y: np.ndarray, target_groups: np.ndarray, g: np.ndarray, h: np.ndarray, growth_params: utils.TreeGrowthParameters, - **kwargs, ) -> float: """Calculates the xgboost general version score of a split with loss specifics in g and h. @@ -250,9 +243,7 @@ def calc_score( case SplitScoreMetrics.friedman_binary_classification: return calc_variance(y, target_groups) case SplitScoreMetrics.xgboost: - return calc_xgboost_split_score( - y, target_groups, g, h, growth_params - ) + return calc_xgboost_split_score(target_groups, g, h, growth_params) case SplitScoreMetrics.incrementing: return IncrementingScore()() case _: diff --git a/tests/test_leafweights.py b/tests/test_leafweights.py index d3b6a55..8b900f0 100644 --- a/tests/test_leafweights.py +++ b/tests/test_leafweights.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import numpy as np import pytest @@ -8,27 +9,23 @@ def test_leaf_weight_mean(): y = np.array([1, 2, 3]) - g = np.array([1, 2, 3]) * 2 - assert leafweights.leaf_weight_mean(y=y, g=g) == 2.0 + assert leafweights.leaf_weight_mean(y=y) == 2.0 def test_leaf_weight_binary_classification_friedman2001(): - y = np.array([1, 2, 3]) g = np.array([1, 2, 3]) * 2 assert ( - leafweights.leaf_weight_binary_classification_friedman2001(y=y, g=g) + leafweights.leaf_weight_binary_classification_friedman2001(g=g) == -0.375 ) def test_leaf_weight_xgboost(): - y = np.array([1, 2, 3]) g = np.array([1, 2, 3]) * 2 h = np.array([1, 2, 3]) * 4 params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) assert ( - leafweights.leaf_weight_xgboost(y=y, g=g, h=h, growth_params=params) - == -0.5 + leafweights.leaf_weight_xgboost(growth_params=params, g=g, h=h) == -0.5 ) diff --git a/tests/test_scoring.py b/tests/test_scoring.py index e927401..1eeef2d 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -1,3 +1,4 @@ +# -*- 
coding: utf-8 -*- import numpy as np import pytest @@ -359,11 +360,11 @@ def test_calc_xgboost_split_score( g: np.ndarray, h: np.ndarray, target_groups: np.ndarray, score_exp: float ): growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) - y = None + try: # line to test score = scoring.calc_xgboost_split_score( - y, target_groups, g, h, growth_params + target_groups, g, h, growth_params ) except ValueError as ex: if score_exp is None: @@ -461,11 +462,10 @@ def test_xgboost( score_exp: float, ): growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) - y = None # line to test score = scoring.calc_xgboost_split_score( - y, target_groups, g, h, growth_params + target_groups, g, h, growth_params ) assert score == score_exp From 35f939405320031e120073a0e233efed82d8ea52 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 18:21:52 +0100 Subject: [PATCH 05/28] chore: SplitScoreMetrics Enum as a class to enable autocomplete of members --- random_tree_models/scoring.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/random_tree_models/scoring.py b/random_tree_models/scoring.py index c9c161a..f05d625 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -206,19 +206,15 @@ def __call__(self, *args, **kwargs) -> float: return self.score -SplitScoreMetrics = Enum( - "SplitScoreMetrics", - [ - "variance", - "entropy", - "entropy_rs", - "gini", - "gini_rs", - "friedman_binary_classification", - "xgboost", - "incrementing", - ], -) +class SplitScoreMetrics(Enum): + variance = "variance" + entropy = "entropy" + entropy_rs = "entropy_rs" + gini = "gini" + gini_rs = "gini_rs" + friedman_binary_classification = "friedman_binary_classification" + xgboost = "xgboost" + incrementing = "incrementing" def calc_score( From 738e85aeaffaeef6300f6195675fd96b89320b3e Mon Sep 17 00:00:00 2001 From: eschmidt42 
<11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 18:55:26 +0100 Subject: [PATCH 06/28] fix: passed incrementing_score --- random_tree_models/decisiontree.py | 24 ++++++++++++++++++++++-- random_tree_models/isolationforest.py | 4 ++++ random_tree_models/scoring.py | 11 ++++++++--- tests/test_isolationforest.py | 5 ++--- tests/test_scoring.py | 26 ++++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 8 deletions(-) diff --git a/random_tree_models/decisiontree.py b/random_tree_models/decisiontree.py index c688805..eaf0042 100644 --- a/random_tree_models/decisiontree.py +++ b/random_tree_models/decisiontree.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import typing as T import uuid @@ -205,6 +206,7 @@ def find_best_split( h: np.ndarray = None, growth_params: utils.TreeGrowthParameters = None, rng: np.random.RandomState = np.random.RandomState(42), + incrementing_score: scoring.IncrementingScore = None, ) -> BestSplit: """Find the best split, detecting the "default direction" with missing data.""" @@ -232,6 +234,7 @@ def find_best_split( h=h, growth_params=growth_params, score_metric=measure_name, + incrementing_score=incrementing_score, ) if best is None or split_score > best.score: @@ -275,6 +278,7 @@ def calc_leaf_weight_and_split_score( growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, + incrementing_score: scoring.IncrementingScore = None, ) -> T.Tuple[float]: leaf_weight = leafweights.calc_leaf_weight( y, measure_name, growth_params, g=g, h=h @@ -288,6 +292,7 @@ def calc_leaf_weight_and_split_score( h=h, growth_params=growth_params, score_metric=measure_name, + incrementing_score=incrementing_score, ) return leaf_weight, score @@ -319,6 +324,7 @@ def grow_tree( g: np.ndarray = None, h: np.ndarray = None, random_state: int = 42, + incrementing_score: scoring.IncrementingScore = None, **kwargs, ) -> Node: """Implementation of the Classification And Regression Tree (CART) algorithm @@ -361,7 +367,12 @@ def grow_tree( 
# compute leaf weight (for prediction) and node score (for split gain check) leaf_weight, score = calc_leaf_weight_and_split_score( - y, measure_name, growth_params, g, h + y, + measure_name, + growth_params, + g, + h, + incrementing_score=incrementing_score, ) if is_baselevel: # end of the line buddy @@ -377,7 +388,14 @@ def grow_tree( rng = np.random.RandomState(random_state) best = find_best_split( - X, y, measure_name, g=g, h=h, growth_params=growth_params, rng=rng + X, + y, + measure_name, + g=g, + h=h, + growth_params=growth_params, + rng=rng, + incrementing_score=incrementing_score, ) # check if improvement due to split is below minimum requirement @@ -421,6 +439,7 @@ def grow_tree( g=_g, h=_h, random_state=random_state_left, + incrementing_score=incrementing_score, ) # descend right @@ -435,6 +454,7 @@ def grow_tree( g=_g, h=_h, random_state=random_state_right, + incrementing_score=incrementing_score, ) return new_node diff --git a/random_tree_models/isolationforest.py b/random_tree_models/isolationforest.py index dcd26c8..e4a66ef 100644 --- a/random_tree_models/isolationforest.py +++ b/random_tree_models/isolationforest.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import typing as T import numpy as np @@ -7,6 +8,7 @@ from sklearn.utils.validation import check_array, check_is_fitted import random_tree_models.decisiontree as dtree +import random_tree_models.scoring as scoring # TODO: add tests @@ -60,6 +62,7 @@ def fit( _X, _y, self.ix_features_ = self._select_samples_and_features( X, dummy_y ) + self.incrementing_score_ = scoring.IncrementingScore() self.tree_ = dtree.grow_tree( _X, @@ -67,6 +70,7 @@ def fit( measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, + incrementing_score=self.incrementing_score_, **kwargs, ) diff --git a/random_tree_models/scoring.py b/random_tree_models/scoring.py index f05d625..358e913 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -198,9 
+198,9 @@ def calc_xgboost_split_score( class IncrementingScore: - score = 0 + score: int = 0 - def __call__(self, *args, **kwargs) -> float: + def update(self) -> float: """Calculates the random cut score of a split""" self.score += 1 return self.score @@ -224,6 +224,7 @@ def calc_score( h: np.ndarray = None, growth_params: utils.TreeGrowthParameters = None, score_metric: SplitScoreMetrics = SplitScoreMetrics.variance, + incrementing_score: IncrementingScore = None, ) -> float: match score_metric: case SplitScoreMetrics.variance: @@ -241,6 +242,10 @@ def calc_score( case SplitScoreMetrics.xgboost: return calc_xgboost_split_score(target_groups, g, h, growth_params) case SplitScoreMetrics.incrementing: - return IncrementingScore()() + if incrementing_score is None: + raise ValueError( + f"{incrementing_score=} must be provided as an instance of scoring.IncrementingScore {score_metric=}" + ) + return incrementing_score.update() case _: raise ValueError(f"{score_metric=} not supported") diff --git a/tests/test_isolationforest.py b/tests/test_isolationforest.py index 10bd305..7ec43b1 100644 --- a/tests/test_isolationforest.py +++ b/tests/test_isolationforest.py @@ -1,8 +1,6 @@ +# -*- coding: utf-8 -*- import numpy as np -import pytest -from sklearn.utils.estimator_checks import parametrize_with_checks -import random_tree_models.decisiontree as dtree import random_tree_models.isolationforest as iforest rng = np.random.RandomState(42) @@ -17,6 +15,7 @@ def test_fit(self): model.fit(self.X_inlier) assert hasattr(model, "tree_") assert hasattr(model, "growth_params_") + assert model.incrementing_score_.score > 0 def test_predict(self): model = iforest.IsolationTree() diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 1eeef2d..275a785 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -469,3 +469,29 @@ def test_xgboost( ) assert score == score_exp + + def test_incrementing(self): + incrementing_score = scoring.IncrementingScore() + score_metric 
= scoring.SplitScoreMetrics["incrementing"] + + # line to test + score = scoring.calc_score( + self.y, + self.target_groups, + score_metric=score_metric, + incrementing_score=incrementing_score, + ) + + assert score == 1 + assert incrementing_score.score == 1 + + # line to test + score = scoring.calc_score( + self.y, + self.target_groups, + score_metric=score_metric, + incrementing_score=incrementing_score, + ) + + assert score == 2 + assert incrementing_score.score == 2 From 589830f605bf5693609c70c38e1f5a14e8c7370a Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 19:48:42 +0100 Subject: [PATCH 07/28] chore: moved SplitScoreMetrics to utils to make split_score_metric an attribute of that type in TreeGrowthParameters --- random_tree_models/decisiontree.py | 57 +++++++++------------ random_tree_models/isolationforest.py | 1 - random_tree_models/leafweights.py | 22 ++++---- random_tree_models/scoring.py | 38 +++++--------- random_tree_models/utils.py | 13 +++++ tests/test_decisiontree.py | 74 +++++++++++++++++---------- tests/test_leafweights.py | 26 ++++------ tests/test_scoring.py | 41 ++++++++------- 8 files changed, 139 insertions(+), 133 deletions(-) diff --git a/random_tree_models/decisiontree.py b/random_tree_models/decisiontree.py index eaf0042..85cddcb 100644 --- a/random_tree_models/decisiontree.py +++ b/random_tree_models/decisiontree.py @@ -200,7 +200,6 @@ def get_column( def find_best_split( X: np.ndarray, y: np.ndarray, - measure_name: scoring.SplitScoreMetrics, yhat: np.ndarray = None, g: np.ndarray = None, h: np.ndarray = None, @@ -233,7 +232,6 @@ def find_best_split( g=g, h=h, growth_params=growth_params, - score_metric=measure_name, incrementing_score=incrementing_score, ) @@ -274,15 +272,12 @@ def check_if_split_sensible( def calc_leaf_weight_and_split_score( y: np.ndarray, - measure_name: scoring.SplitScoreMetrics, growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, 
incrementing_score: scoring.IncrementingScore = None, ) -> T.Tuple[float]: - leaf_weight = leafweights.calc_leaf_weight( - y, measure_name, growth_params, g=g, h=h - ) + leaf_weight = leafweights.calc_leaf_weight(y, growth_params, g=g, h=h) # yhat = leaf_weight * np.ones_like(y) score = scoring.calc_score( @@ -291,7 +286,6 @@ def calc_leaf_weight_and_split_score( g=g, h=h, growth_params=growth_params, - score_metric=measure_name, incrementing_score=incrementing_score, ) @@ -317,7 +311,6 @@ def select_arrays_for_child_node( def grow_tree( X: np.ndarray, y: np.ndarray, - measure_name: scoring.SplitScoreMetrics, parent_node: Node = None, depth: int = 0, growth_params: utils.TreeGrowthParameters = None, @@ -332,7 +325,7 @@ def grow_tree( Args: X (np.ndarray): Input feature values to do thresholding on. y (np.ndarray): Target values. - measure_name (str): Values indicating which functions in scoring.SplitScoreMetrics and leafweights.LeafWeightSchemes to call. + measure_name (str): Values indicating which functions in utils.SplitScoreMetrics and leafweights.LeafWeightSchemes to call. parent_node (Node, optional): Parent node in tree. Defaults to None. depth (int, optional): Current tree depth. Defaults to 0. growth_params (utils.TreeGrowthParameters, optional): Parameters controlling tree growth. Defaults to None. 
@@ -368,17 +361,18 @@ def grow_tree( # compute leaf weight (for prediction) and node score (for split gain check) leaf_weight, score = calc_leaf_weight_and_split_score( y, - measure_name, growth_params, g, h, incrementing_score=incrementing_score, ) + measure_name = growth_params.split_score_metric.name + if is_baselevel: # end of the line buddy return Node( prediction=leaf_weight, - measure=SplitScore(measure_name.name, score=score), + measure=SplitScore(measure_name, score=score), n_obs=n_obs, reason=reason, depth=depth, @@ -390,7 +384,6 @@ def grow_tree( best = find_best_split( X, y, - measure_name, g=g, h=h, growth_params=growth_params, @@ -407,7 +400,7 @@ def grow_tree( reason = f"gain due split ({gain=}) lower than {growth_params.min_improvement=} or all data points assigned to one side (is left {best.target_groups.mean()=:.2%})" leaf_node = Node( prediction=leaf_weight, - measure=SplitScore(measure_name.name, score=score), + measure=SplitScore(measure_name, score=score), n_obs=n_obs, reason=reason, depth=depth, @@ -420,7 +413,7 @@ def grow_tree( threshold=best.threshold, prediction=leaf_weight, default_is_left=best.default_is_left, - measure=SplitScore(measure_name.name, best.score), + measure=SplitScore(measure_name, best.score), n_obs=n_obs, reason="", depth=depth, @@ -432,7 +425,6 @@ def grow_tree( new_node.left = grow_tree( _X, _y, - measure_name=measure_name, parent_node=new_node, depth=depth + 1, growth_params=growth_params, @@ -447,7 +439,6 @@ def grow_tree( new_node.right = grow_tree( _X, _y, - measure_name=measure_name, parent_node=new_node, depth=depth + 1, growth_params=growth_params, @@ -534,36 +525,36 @@ def __init__( self.column_method = column_method self.n_columns_to_try = n_columns_to_try - def _sanity_check_measure_name(self): + def _sanity_check_measure_name(self) -> utils.SplitScoreMetrics: try: - self.measure_name_enum_ = scoring.SplitScoreMetrics[ - self.measure_name - ] + return utils.SplitScoreMetrics[self.measure_name] except 
KeyError as ex: raise KeyError( f"Unknown measure_name: {self.measure_name}. " - f"Valid options: {', '.join(list(scoring.SplitScoreMetrics.__members__.keys()))}. {ex=}" + f"Valid options: {', '.join(list(utils.SplitScoreMetrics.__members__.keys()))}. {ex=}" ) def _organize_growth_parameters(self): - self._sanity_check_measure_name() + threshold_params = utils.ThresholdSelectionParameters( + method=self.threshold_method, + quantile=self.threshold_quantile, + n_thresholds=self.n_thresholds, + random_state=int(self.random_state), + ) + column_params = utils.ColumnSelectionParameters( + method=self.column_method, + n_trials=self.n_columns_to_try, + ) self.growth_params_ = utils.TreeGrowthParameters( max_depth=self.max_depth, + split_score_metric=self._sanity_check_measure_name(), min_improvement=self.min_improvement, lam=-abs(self.lam), frac_subsamples=float(self.frac_subsamples), frac_features=float(self.frac_features), random_state=int(self.random_state), - threshold_params=utils.ThresholdSelectionParameters( - method=self.threshold_method, - quantile=self.threshold_quantile, - n_thresholds=self.n_thresholds, - random_state=int(self.random_state), - ), - column_params=utils.ColumnSelectionParameters( - method=self.column_method, - n_trials=self.n_columns_to_try, - ), + threshold_params=threshold_params, + column_params=column_params, ) def _select_samples_and_features( @@ -643,7 +634,6 @@ def fit( self.tree_ = grow_tree( _X, _y, - measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, **kwargs, @@ -707,7 +697,6 @@ def fit( self.tree_ = grow_tree( _X, _y, - measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, ) diff --git a/random_tree_models/isolationforest.py b/random_tree_models/isolationforest.py index e4a66ef..d1c7276 100644 --- a/random_tree_models/isolationforest.py +++ b/random_tree_models/isolationforest.py @@ -67,7 +67,6 @@ def fit( self.tree_ = 
dtree.grow_tree( _X, _y, - measure_name=self.measure_name_enum_, growth_params=self.growth_params_, random_state=self.random_state, incrementing_score=self.incrementing_score_, diff --git a/random_tree_models/leafweights.py b/random_tree_models/leafweights.py index 0f98d45..62ac87b 100644 --- a/random_tree_models/leafweights.py +++ b/random_tree_models/leafweights.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import numpy as np -import random_tree_models.scoring as scoring import random_tree_models.utils as utils @@ -32,7 +31,6 @@ def leaf_weight_xgboost( def calc_leaf_weight( y: np.ndarray, - measure_name: scoring.SplitScoreMetrics, growth_params: utils.TreeGrowthParameters, g: np.ndarray = None, h: np.ndarray = None, @@ -45,23 +43,25 @@ def calc_leaf_weight( if len(y) == 0: return None + measure_name = growth_params.split_score_metric + match measure_name: case ( - scoring.SplitScoreMetrics.variance - | scoring.SplitScoreMetrics.entropy - | scoring.SplitScoreMetrics.entropy_rs - | scoring.SplitScoreMetrics.gini - | scoring.SplitScoreMetrics.gini_rs - | scoring.SplitScoreMetrics.incrementing + utils.SplitScoreMetrics.variance + | utils.SplitScoreMetrics.entropy + | utils.SplitScoreMetrics.entropy_rs + | utils.SplitScoreMetrics.gini + | utils.SplitScoreMetrics.gini_rs + | utils.SplitScoreMetrics.incrementing ): leaf_weight = leaf_weight_mean(y) - case scoring.SplitScoreMetrics.friedman_binary_classification: + case utils.SplitScoreMetrics.friedman_binary_classification: leaf_weight = leaf_weight_binary_classification_friedman2001(g) - case scoring.SplitScoreMetrics.xgboost: + case utils.SplitScoreMetrics.xgboost: leaf_weight = leaf_weight_xgboost(growth_params, g, h) case _: raise KeyError( - f"Unknown measure_name: {measure_name}, expected one of {', '.join(list(scoring.SplitScoreMetrics.__members__.keys()))}" + f"Unknown measure_name: {measure_name}, expected one of {', '.join(list(utils.SplitScoreMetrics.__members__.keys()))}" ) return leaf_weight diff --git 
a/random_tree_models/scoring.py b/random_tree_models/scoring.py index 358e913..a0ae640 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -from enum import Enum - import numpy as np import random_tree_models.utils as utils @@ -206,46 +204,36 @@ def update(self) -> float: return self.score -class SplitScoreMetrics(Enum): - variance = "variance" - entropy = "entropy" - entropy_rs = "entropy_rs" - gini = "gini" - gini_rs = "gini_rs" - friedman_binary_classification = "friedman_binary_classification" - xgboost = "xgboost" - incrementing = "incrementing" - - def calc_score( y: np.ndarray, target_groups: np.ndarray, g: np.ndarray = None, h: np.ndarray = None, growth_params: utils.TreeGrowthParameters = None, - score_metric: SplitScoreMetrics = SplitScoreMetrics.variance, incrementing_score: IncrementingScore = None, ) -> float: - match score_metric: - case SplitScoreMetrics.variance: + measure_name = growth_params.split_score_metric + + match measure_name: + case utils.SplitScoreMetrics.variance: return calc_variance(y, target_groups) - case SplitScoreMetrics.entropy: + case utils.SplitScoreMetrics.entropy: return calc_entropy(y, target_groups) - case SplitScoreMetrics.entropy_rs: + case utils.SplitScoreMetrics.entropy_rs: return calc_entropy_rs(y, target_groups) - case SplitScoreMetrics.gini: + case utils.SplitScoreMetrics.gini: return calc_gini_impurity(y, target_groups) - case SplitScoreMetrics.gini_rs: + case utils.SplitScoreMetrics.gini_rs: return calc_gini_impurity_rs(y, target_groups) - case SplitScoreMetrics.friedman_binary_classification: + case utils.SplitScoreMetrics.friedman_binary_classification: return calc_variance(y, target_groups) - case SplitScoreMetrics.xgboost: + case utils.SplitScoreMetrics.xgboost: return calc_xgboost_split_score(target_groups, g, h, growth_params) - case SplitScoreMetrics.incrementing: + case utils.SplitScoreMetrics.incrementing: if incrementing_score is None: 
raise ValueError( - f"{incrementing_score=} must be provided as an instance of scoring.IncrementingScore {score_metric=}" + f"{incrementing_score=} must be provided as an instance of scoring.IncrementingScore {measure_name=}" ) return incrementing_score.update() case _: - raise ValueError(f"{score_metric=} not supported") + raise ValueError(f"{measure_name=} not supported") diff --git a/random_tree_models/utils.py b/random_tree_models/utils.py index 0da4b7e..4feeccc 100644 --- a/random_tree_models/utils.py +++ b/random_tree_models/utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import logging from enum import Enum @@ -63,10 +64,22 @@ class ColumnSelectionParameters: n_trials: StrictInt = None +class SplitScoreMetrics(Enum): + variance = "variance" + entropy = "entropy" + entropy_rs = "entropy_rs" + gini = "gini" + gini_rs = "gini_rs" + friedman_binary_classification = "friedman_binary_classification" + xgboost = "xgboost" + incrementing = "incrementing" + + @dataclass class TreeGrowthParameters: max_depth: StrictInt min_improvement: StrictFloat = 0.0 + split_score_metric: SplitScoreMetrics = SplitScoreMetrics.variance # xgboost lambda - multiplied with sum of squares of leaf weights # see Chen et al. 
2016 equation 2 lam: StrictFloat = 0.0 diff --git a/tests/test_decisiontree.py b/tests/test_decisiontree.py index 0438a1d..a85a27f 100644 --- a/tests/test_decisiontree.py +++ b/tests/test_decisiontree.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import types from unittest.mock import patch @@ -8,7 +9,6 @@ from sklearn.utils.estimator_checks import parametrize_with_checks import random_tree_models.decisiontree as dtree -import random_tree_models.scoring as scoring import random_tree_models.utils as utils # first value in each tuple is the value to test and the second is the flag indicating if this should work @@ -585,13 +585,15 @@ def test_1d( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - grow_params = utils.TreeGrowthParameters(max_depth=2) + grow_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( self.X_1D, y, - measure_name=scoring.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=grow_params, @@ -637,13 +639,15 @@ def test_1d_missing( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - grow_params = utils.TreeGrowthParameters(max_depth=2) + grow_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( self.X_1D_missing, y, - measure_name=scoring.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=grow_params, @@ -695,7 +699,7 @@ def test_2d( best = dtree.find_best_split( self.X_2D, y, - scoring.SplitScoreMetrics[measure_name], + utils.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=growth_params, @@ -748,7 +752,7 @@ def test_2d_missing( best = dtree.find_best_split( self.X_2D_missing, y, - scoring.SplitScoreMetrics[measure_name], + utils.SplitScoreMetrics[measure_name], g=g, h=h, growth_params=growth_params, @@ -856,21 +860,39 @@ def test_check_if_split_sensible( # write tests for 
calc_leaf_weight_and_split_score @pytest.mark.parametrize( - "y,measure_name,growth_params,g,h", + "y,growth_params,g,h", [ - (y, measure_name, growth_params, g, h) + (y, growth_params, g, h) for y in [ np.array([True, True, False]), np.array([True, True, True]), np.array([False, False, False]), ] - for measure_name in [ - scoring.SplitScoreMetrics.gini, - scoring.SplitScoreMetrics.entropy, - ] for growth_params in [ - utils.TreeGrowthParameters(max_depth=2), - utils.TreeGrowthParameters(max_depth=2, min_improvement=0.2), + utils.TreeGrowthParameters( + max_depth=2, split_score_metric=utils.SplitScoreMetrics.gini + ), + utils.TreeGrowthParameters( + max_depth=2, split_score_metric=utils.SplitScoreMetrics.entropy + ), + utils.TreeGrowthParameters( + max_depth=2, split_score_metric=utils.SplitScoreMetrics.variance + ), + utils.TreeGrowthParameters( + max_depth=2, + min_improvement=0.2, + split_score_metric=utils.SplitScoreMetrics.gini, + ), + utils.TreeGrowthParameters( + max_depth=2, + min_improvement=0.2, + split_score_metric=utils.SplitScoreMetrics.entropy, + ), + utils.TreeGrowthParameters( + max_depth=2, + min_improvement=0.2, + split_score_metric=utils.SplitScoreMetrics.variance, + ), ] for g in [np.array([1, 2, 3])] for h in [np.array([4, 5, 6])] @@ -878,15 +900,12 @@ def test_check_if_split_sensible( ) def test_calc_leaf_weight_and_split_score( y: np.ndarray, - measure_name: str, growth_params: utils.TreeGrowthParameters, g: np.ndarray, h: np.ndarray, ): # line to test - _, _ = dtree.calc_leaf_weight_and_split_score( - y, measure_name, growth_params, g, h - ) + _, _ = dtree.calc_leaf_weight_and_split_score(y, growth_params, g, h) @pytest.mark.parametrize("go_left", [True, False]) @@ -928,12 +947,14 @@ class Test_grow_tree: X = np.array([[1], [2], [3]]) y = np.array([True, True, False]) target_groups = np.array([True, True, False]) - measure_name = scoring.SplitScoreMetrics["gini"] + measure_name = utils.SplitScoreMetrics["gini"] depth_dummy = 0 def 
test_baselevel(self): # test returned leaf node - growth_params = utils.TreeGrowthParameters(max_depth=2) + growth_params = utils.TreeGrowthParameters( + max_depth=2, split_score_metric=self.measure_name + ) parent_node = None is_baselevel = True reason = "very custom leaf node comment" @@ -945,7 +966,6 @@ def test_baselevel(self): leaf_node = dtree.grow_tree( self.X, self.y, - self.measure_name, parent_node=parent_node, depth=self.depth_dummy, growth_params=growth_params, @@ -958,7 +978,9 @@ def test_baselevel(self): def test_split_improvement_insufficient(self): # test split improvement below minimum growth_params = utils.TreeGrowthParameters( - max_depth=2, min_improvement=0.2 + max_depth=2, + min_improvement=0.2, + split_score_metric=self.measure_name, ) parent_score = -1.0 new_score = -0.9 @@ -997,7 +1019,6 @@ def test_split_improvement_insufficient(self): node = dtree.grow_tree( self.X, self.y, - self.measure_name, parent_node=parent_node, depth=self.depth_dummy, growth_params=growth_params, @@ -1012,7 +1033,9 @@ def test_split_improvement_insufficient(self): def test_split_improvement_sufficient(self): # test split improvement above minumum, leading to two leaf nodes growth_params = utils.TreeGrowthParameters( - max_depth=2, min_improvement=0.0 + max_depth=2, + split_score_metric=self.measure_name, + min_improvement=0.0, ) parent_score = -1.0 new_score = -0.9 @@ -1054,7 +1077,6 @@ def test_split_improvement_sufficient(self): tree = dtree.grow_tree( self.X, self.y, - self.measure_name, parent_node=parent_node, depth=self.depth_dummy, growth_params=growth_params, diff --git a/tests/test_leafweights.py b/tests/test_leafweights.py index 8b900f0..e0ee3a8 100644 --- a/tests/test_leafweights.py +++ b/tests/test_leafweights.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import numpy as np import pytest +import pydantic import random_tree_models.leafweights as leafweights -import random_tree_models.scoring as scoring import random_tree_models.utils as utils @@ -31,34 
+31,26 @@ def test_leaf_weight_xgboost(): class Test_calc_leaf_weight: def test_error_for_unknown_scheme(self): - y = np.array([1, 2, 3]) - growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) try: - leafweights.calc_leaf_weight( - y=y, growth_params=growth_params, measure_name="not_a_scheme" + _ = utils.TreeGrowthParameters( + max_depth=2, split_score_metric="not_a_scheme", lam=0.0 ) - except KeyError: + except pydantic.ValidationError: pytest.xfail("ValueError correctly raised for unknown scheme") else: pytest.fail("ValueError not raised for unknown scheme") - def test_leaf_weight_none_if_y_empty(self): - y = np.array([]) - growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) - - weight = leafweights.calc_leaf_weight( - y=y, growth_params=growth_params, measure_name="not_a_scheme" - ) - assert weight is None - # returns a float if y is not empty def test_leaf_weight_float_if_y_not_empty(self): y = np.array([1, 2, 3]) - growth_params = utils.TreeGrowthParameters(max_depth=2, lam=0.0) + growth_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics["variance"], + lam=0.0, + ) weight = leafweights.calc_leaf_weight( y=y, growth_params=growth_params, - measure_name=scoring.SplitScoreMetrics["variance"], ) assert isinstance(weight, float) diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 275a785..8e7f781 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -390,37 +390,39 @@ class TestSplitScoreMetrics: var_exp = -0.25 def test_gini(self): - measure = scoring.SplitScoreMetrics["gini"] - g = scoring.calc_score(self.y, self.target_groups, score_metric=measure) + measure = utils.SplitScoreMetrics["gini"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + g = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert g == self.g_exp def test_gini_rs(self): - measure = scoring.SplitScoreMetrics["gini_rs"] - g = scoring.calc_score(self.y, 
self.target_groups, score_metric=measure) + measure = utils.SplitScoreMetrics["gini_rs"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + g = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert g == self.g_exp def test_entropy(self): - measure = scoring.SplitScoreMetrics["entropy"] - h = scoring.calc_score(self.y, self.target_groups, score_metric=measure) + measure = utils.SplitScoreMetrics["entropy"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + h = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert h == self.h_exp def test_entropy_rs(self): - measure = scoring.SplitScoreMetrics["entropy_rs"] - h = scoring.calc_score(self.y, self.target_groups, score_metric=measure) + measure = utils.SplitScoreMetrics["entropy_rs"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + h = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert h == self.h_exp def test_variance(self): - measure = scoring.SplitScoreMetrics["variance"] - var = scoring.calc_score( - self.y, self.target_groups, score_metric=measure - ) + measure = utils.SplitScoreMetrics["variance"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + var = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert var == self.var_exp def test_friedman_binary_classification(self): - measure = scoring.SplitScoreMetrics["friedman_binary_classification"] - var = scoring.calc_score( - self.y, self.target_groups, score_metric=measure - ) + measure = utils.SplitScoreMetrics["friedman_binary_classification"] + gp = utils.TreeGrowthParameters(1, split_score_metric=measure) + var = scoring.calc_score(self.y, self.target_groups, growth_params=gp) assert var == self.var_exp @pytest.mark.parametrize( @@ -472,13 +474,14 @@ def test_xgboost( def test_incrementing(self): incrementing_score = scoring.IncrementingScore() - score_metric = scoring.SplitScoreMetrics["incrementing"] + score_metric = 
utils.SplitScoreMetrics["incrementing"] + gp = utils.TreeGrowthParameters(1, split_score_metric=score_metric) # line to test score = scoring.calc_score( self.y, self.target_groups, - score_metric=score_metric, + growth_params=gp, incrementing_score=incrementing_score, ) @@ -489,7 +492,7 @@ def test_incrementing(self): score = scoring.calc_score( self.y, self.target_groups, - score_metric=score_metric, + growth_params=gp, incrementing_score=incrementing_score, ) From 18c33d5f4496cdf05e1998a1e70a9fb5f9af9b21 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 26 Dec 2023 20:04:45 +0100 Subject: [PATCH 08/28] docs: updated grow_tree docstring --- random_tree_models/decisiontree.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/random_tree_models/decisiontree.py b/random_tree_models/decisiontree.py index 85cddcb..88bd5c7 100644 --- a/random_tree_models/decisiontree.py +++ b/random_tree_models/decisiontree.py @@ -325,7 +325,6 @@ def grow_tree( Args: X (np.ndarray): Input feature values to do thresholding on. y (np.ndarray): Target values. - measure_name (str): Values indicating which functions in utils.SplitScoreMetrics and leafweights.LeafWeightSchemes to call. parent_node (Node, optional): Parent node in tree. Defaults to None. depth (int, optional): Current tree depth. Defaults to 0. growth_params (utils.TreeGrowthParameters, optional): Parameters controlling tree growth. Defaults to None. @@ -339,7 +338,7 @@ def grow_tree( Node: Tree node with leaf weight, node score and potential child nodes. Note: - Currently measure_name controls how the split score and the leaf weights are computed. + Currently growth_params.split_score_name controls how the split score and the leaf weights are computed. But only the decision tree algorithm directly uses y for that and can predict y using the leaf weight values directly. 
From 5f85918137d29a8bfebd53f7e7a1703fb35116dc Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Fri, 29 Dec 2023 16:22:45 +0100 Subject: [PATCH 09/28] chore: added abc module for decisiontreetremplate and rearranged some grow tree and split find arguments and their optionality --- random_tree_models/decisiontree.py | 29 ++++++++-------- random_tree_models/scoring.py | 2 +- tests/test_decisiontree.py | 54 ++++++++++++++++-------------- 3 files changed, 44 insertions(+), 41 deletions(-) diff --git a/random_tree_models/decisiontree.py b/random_tree_models/decisiontree.py index 88bd5c7..3fa90e1 100644 --- a/random_tree_models/decisiontree.py +++ b/random_tree_models/decisiontree.py @@ -22,6 +22,7 @@ import random_tree_models.leafweights as leafweights import random_tree_models.scoring as scoring import random_tree_models.utils as utils +import abc logger = utils.logger @@ -200,10 +201,9 @@ def get_column( def find_best_split( X: np.ndarray, y: np.ndarray, - yhat: np.ndarray = None, + growth_params: utils.TreeGrowthParameters, g: np.ndarray = None, h: np.ndarray = None, - growth_params: utils.TreeGrowthParameters = None, rng: np.random.RandomState = np.random.RandomState(42), incrementing_score: scoring.IncrementingScore = None, ) -> BestSplit: @@ -229,9 +229,9 @@ def find_best_split( split_score = scoring.calc_score( y, target_groups, + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, incrementing_score=incrementing_score, ) @@ -281,11 +281,11 @@ def calc_leaf_weight_and_split_score( # yhat = leaf_weight * np.ones_like(y) score = scoring.calc_score( - y, - np.ones_like(y, dtype=bool), + y=y, + target_groups=np.ones_like(y, dtype=bool), + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, incrementing_score=incrementing_score, ) @@ -311,23 +311,22 @@ def select_arrays_for_child_node( def grow_tree( X: np.ndarray, y: np.ndarray, + growth_params: utils.TreeGrowthParameters, parent_node: 
Node = None, depth: int = 0, - growth_params: utils.TreeGrowthParameters = None, g: np.ndarray = None, h: np.ndarray = None, random_state: int = 42, incrementing_score: scoring.IncrementingScore = None, - **kwargs, ) -> Node: """Implementation of the Classification And Regression Tree (CART) algorithm Args: X (np.ndarray): Input feature values to do thresholding on. y (np.ndarray): Target values. + growth_params (utils.TreeGrowthParameters, optional): Parameters controlling tree growth. parent_node (Node, optional): Parent node in tree. Defaults to None. depth (int, optional): Current tree depth. Defaults to 0. - growth_params (utils.TreeGrowthParameters, optional): Parameters controlling tree growth. Defaults to None. g (np.ndarray, optional): Boosting and loss specific precomputed 1st order derivative dloss/dyhat. Defaults to None. h (np.ndarray, optional): Boosting and loss specific precomputed 2nd order derivative d^2loss/dyhat^2. Defaults to None. @@ -424,9 +423,9 @@ def grow_tree( new_node.left = grow_tree( _X, _y, + growth_params=growth_params, parent_node=new_node, depth=depth + 1, - growth_params=growth_params, g=_g, h=_h, random_state=random_state_left, @@ -438,9 +437,9 @@ def grow_tree( new_node.right = grow_tree( _X, _y, + growth_params=growth_params, parent_node=new_node, depth=depth + 1, - growth_params=growth_params, g=_g, h=_h, random_state=random_state_right, @@ -490,7 +489,7 @@ def predict_with_tree(tree: Node, X: np.ndarray) -> np.ndarray: return predictions -class DecisionTreeTemplate(base.BaseEstimator): +class DecisionTreeTemplate(abc.ABC, base.BaseEstimator): """Template for DecisionTree classes Based on: https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator @@ -592,15 +591,17 @@ def _select_features( ) -> np.ndarray: return X[:, ix_features] + @abc.abstractmethod def fit( self, X: T.Union[pd.DataFrame, np.ndarray], y: T.Union[pd.Series, np.ndarray], ) -> "DecisionTreeTemplate": - raise NotImplementedError() + ... 
+ @abc.abstractmethod def predict(self, X: T.Union[pd.DataFrame, np.ndarray]) -> np.ndarray: - raise NotImplementedError() + ... class DecisionTreeRegressor(base.RegressorMixin, DecisionTreeTemplate): diff --git a/random_tree_models/scoring.py b/random_tree_models/scoring.py index a0ae640..d9f59ab 100644 --- a/random_tree_models/scoring.py +++ b/random_tree_models/scoring.py @@ -207,9 +207,9 @@ def update(self) -> float: def calc_score( y: np.ndarray, target_groups: np.ndarray, + growth_params: utils.TreeGrowthParameters, g: np.ndarray = None, h: np.ndarray = None, - growth_params: utils.TreeGrowthParameters = None, incrementing_score: IncrementingScore = None, ) -> float: measure_name = growth_params.split_score_metric diff --git a/tests/test_decisiontree.py b/tests/test_decisiontree.py index a85a27f..aeb47da 100644 --- a/tests/test_decisiontree.py +++ b/tests/test_decisiontree.py @@ -693,16 +693,18 @@ def test_2d( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - growth_params = utils.TreeGrowthParameters(max_depth=2) + growth_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( - self.X_2D, - y, - utils.SplitScoreMetrics[measure_name], + X=self.X_2D, + y=y, + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, ) except ValueError as ex: if is_homogenous: @@ -746,16 +748,18 @@ def test_2d_missing( h: np.ndarray, ): is_homogenous = len(np.unique(y)) == 1 - growth_params = utils.TreeGrowthParameters(max_depth=2) + growth_params = utils.TreeGrowthParameters( + max_depth=2, + split_score_metric=utils.SplitScoreMetrics[measure_name], + ) try: # line to test best = dtree.find_best_split( - self.X_2D_missing, - y, - utils.SplitScoreMetrics[measure_name], + X=self.X_2D_missing, + y=y, + growth_params=growth_params, g=g, h=h, - growth_params=growth_params, ) except ValueError as ex: if is_homogenous: @@ -966,9 +970,9 @@ def 
test_baselevel(self): leaf_node = dtree.grow_tree( self.X, self.y, + growth_params=growth_params, parent_node=parent_node, depth=self.depth_dummy, - growth_params=growth_params, ) mock_check_is_baselevel.assert_called_once() @@ -1019,9 +1023,9 @@ def test_split_improvement_insufficient(self): node = dtree.grow_tree( self.X, self.y, + growth_params=growth_params, parent_node=parent_node, depth=self.depth_dummy, - growth_params=growth_params, ) mock_check_is_baselevel.assert_called_once() @@ -1077,9 +1081,9 @@ def test_split_improvement_sufficient(self): tree = dtree.grow_tree( self.X, self.y, + growth_params=growth_params, parent_node=parent_node, depth=self.depth_dummy, - growth_params=growth_params, ) assert mock_check_is_baselevel.call_count == 3 @@ -1168,8 +1172,18 @@ def test_predict_with_tree(): assert np.allclose(predictions, np.arange(0, 4, 1)) +class DecisionTreeTemplateTestClass(dtree.DecisionTreeTemplate): + "Class to test abstract class DecisionTreeTemplate" + + def fit(self): + pass + + def predict(self): + pass + + class TestDecisionTreeTemplate: - model = dtree.DecisionTreeTemplate(measure_name="gini") + model = DecisionTreeTemplateTestClass(measure_name="gini") X = np.random.normal(size=(100, 10)) y = np.random.normal(size=(100,)) @@ -1182,18 +1196,6 @@ def test_growth_params_(self): self.model._organize_growth_parameters() assert isinstance(self.model.growth_params_, utils.TreeGrowthParameters) - def test_fit(self): - try: - self.model.fit(None, None) - except NotImplementedError: - pytest.xfail("DecisionTreeTemplate.fit expectedly refused call") - - def test_predict(self): - try: - self.model.predict(None) - except NotImplementedError: - pytest.xfail("DecisionTreeTemplate.predict expectedly refused call") - def test_select_samples_and_features_no_sampling(self): self.model.frac_features = 1.0 self.model.frac_samples = 1.0 From 837c13b6528cb6581e2afc55cd1f4127f6a41f10 Mon Sep 17 00:00:00 2001 From: eschmidt42 
<11818904+eschmidt42@users.noreply.github.com> Date: Sun, 31 Dec 2023 15:00:42 +0100 Subject: [PATCH 10/28] chore: added lazy polars, rand and uuid as dependencies --- Cargo.lock | 1510 ++++++++++++++++++++++++++++++++++++++++++++++++---- Cargo.toml | 4 + 2 files changed, 1398 insertions(+), 116 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a4e82fb..5f9feaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,225 +2,1392 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "argminmax" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "202108b46429b765ef483f8a24d5c46f48c14acfdacc086dd4ab6dddf6bcdbd2" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + +[[package]] +name = "arrow-format" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07884ea216994cdc32a2d5f8274a8bee979cfe90274b83f86f440866ee3132c7" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + +[[package]] +name = "atoi_simd" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccfc14f5c3e34de57539a7ba9c18ecde3d9bbde48d232ea1da3e468adb307fd0" + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + +[[package]] +name = "bumpalo" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" + +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", +] + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "jobserver", + "libc", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "windows-targets 0.48.0", +] + +[[package]] +name = "comfy-table" +version = "7.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" +dependencies = [ + "crossterm", + "strum", + "strum_macros", + "unicode-width", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crossbeam-channel" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82a9b73a36529d9c47029b9fb3a6f0ea3cc916a261195352ba19e770fc1748b2" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc6598521bb5a83d491e8c1fe51db7296019d2ca3cb93cc6c2a20369a4d78a2" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.4.1", + "crossterm_winapi", + "libc", + "parking_lot", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "dyn-clone" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "enum_dispatch" +version = "0.3.12" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f33313078bb8d4d05a2733a94ac4c2d8a0df9a2b84424ebf4f33bfc224a890e" +dependencies = [ + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.43", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + +[[package]] +name = "getrandom" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", + "rayon", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "home" +version = "0.5.9" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "indexmap" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "jobserver" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.151" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "lock_api" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "lz4" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e9e2dd86df36ce760a60f6ff6ad526f7ba1f14ba0356f8254fb6905e6494df1" +dependencies = [ + "libc", + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "memchr" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" + +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "multiversion" +version = "0.7.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2c7b9d7fe61760ce5ea19532ead98541f6b4c495d87247aff9826445cf6872a" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26a83d8500ed06d68877e9de1dde76c1dbb83885dcdbda4ef44ccbc3fbda2ac8" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "target-features", +] + +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.48.0", +] + +[[package]] +name = 
"percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pkg-config" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" + +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + +[[package]] +name = "polars" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df8e52f9236eb722da0990a70bbb1216dcc7a77bcb00c63439d2d982823e90d5" +dependencies = [ + "getrandom", + "polars-core", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-sql", + "polars-time", + "version_check", +] + +[[package]] +name = "polars-arrow" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd503430a6d9779b07915d858865fe998317ef3cfef8973881f578ac5d4baae7" +dependencies = [ + "ahash", + "arrow-format", + "atoi_simd", + "bytemuck", + "chrono", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "foreign_vec", + "getrandom", + "hashbrown", + "itoa", + "lz4", + "multiversion", + "num-traits", + "polars-error", + "polars-utils", + "rustc_version", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "zstd", +] + +[[package]] +name = "polars-core" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae73d5b8e55decde670caba1cc82b61f14bfb9a72503198f0997d657a98dcfd6" +dependencies = [ + "ahash", + "bitflags 2.4.1", + "bytemuck", + "chrono", + "comfy-table", + "either", + "hashbrown", + "indexmap", + "num-traits", + "once_cell", + "polars-arrow", + "polars-error", + "polars-row", + "polars-utils", + 
"rand", + "rand_distr", + "rayon", + "regex", + "smartstring", + "thiserror", + "version_check", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb0520d68eaa9993ae0c741409d1526beff5b8f48e1d73e4381616f8152cf488" +dependencies = [ + "arrow-format", + "regex", + "simdutf8", + "thiserror", +] + +[[package]] +name = "polars-io" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96e10a0745acd6009db64bef0ceb9e23a70b1c27b26a0a6517c91f3e6363bc06" +dependencies = [ + "ahash", + "atoi_simd", + "bytes", + "chrono", + "fast-float", + "home", + "itoa", + "memchr", + "memmap2", + "num-traits", + "once_cell", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-error", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", +] + +[[package]] +name = "polars-lazy" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3555f759705be6dd0d3762d16a0b8787b2dc4da73b57465f3b2bf1a070ba8f20" +dependencies = [ + "ahash", + "bitflags 2.4.1", + "glob", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-ops" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a7eb218296aaa7f79945f08288ca32ca3cf25fa505649eeee689ec21eebf636" +dependencies = [ + "ahash", + "argminmax", + "bytemuck", + "either", + "hashbrown", + "indexmap", + "memchr", + "num-traits", + "polars-arrow", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-pipe" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"66094e7df64c932a9a7bdfe7df0c65efdcb192096e11a6a765a9778f78b4bdec" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown", + "num-traits", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "version_check", +] + +[[package]] +name = "polars-plan" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10e32a0958ef854b132bad7f8369cb3237254635d5e864c99505bc0bc1035fbc" +dependencies = [ + "ahash", + "bytemuck", + "once_cell", + "percent-encoding", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-time", + "polars-utils", + "rayon", + "regex", + "smartstring", + "strum_macros", + "version_check", +] + +[[package]] +name = "polars-row" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135ab81cac2906ba74ea8984c7e6025d081ae5867615bcefb4d84dfdb456dac" +dependencies = [ + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dbd7786849a5e3ad1fde188bf38141632f626e3a57319b0bbf7a5f1d75519e" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-plan", + "rand", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae56f79e9cedd617773c1c8f5ca84a31a8b1d593714959d5f799e7bdd98fe51" +dependencies = [ + "atoi", + "chrono", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.35.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"da6ce68169fe61d46958c8eab7447360f30f2f23f6e24a0ce703a14b0a3cfbfc" +dependencies = [ + "ahash", + "bytemuck", + "hashbrown", + "indexmap", + "num-traits", + "once_cell", + "polars-error", + "rayon", + "smartstring", + "sysinfo", + "version_check", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.43", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.43", +] + +[[package]] +name = "quote" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "random-tree-models" +version = "0.6.2" +dependencies = [ + "polars", + "pyo3", + "rand", + "rand_chacha", + "uuid", +] + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + 
"crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "regex" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] -name = "bitflags" -version = "1.3.2" +name = "semver" +version = "1.0.20" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" [[package]] -name = "cfg-if" -version = "1.0.0" +name = "serde" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] [[package]] -name = "heck" -version = "0.4.1" +name = "serde_derive" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", +] [[package]] -name = "indoc" -version = "2.0.4" +name = "serde_json" +version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +dependencies = [ + "itoa", + "ryu", + "serde", +] [[package]] -name = "libc" -version = "0.2.146" +name = "simdutf8" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" [[package]] -name = "lock_api" -version = "0.4.10" +name = "smallvec" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + +[[package]] +name = "smartstring" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" dependencies = [ "autocfg", - "scopeguard", + "static_assertions", + "version_check", ] [[package]] -name = "memoffset" -version = "0.9.0" +name = "sqlparser" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" dependencies = [ - "autocfg", + "log", ] [[package]] -name = "once_cell" -version = "1.18.0" +name = "static_assertions" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] -name = "parking_lot" -version = "0.12.1" +name = "streaming-iterator" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ - "lock_api", - "parking_lot_core", + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.43", ] [[package]] -name = "parking_lot_core" 
-version = "0.9.8" +name = "syn" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", + "proc-macro2", + "quote", + "unicode-ident", ] [[package]] -name = "proc-macro2" -version = "1.0.60" +name = "syn" +version = "2.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" +checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" dependencies = [ + "proc-macro2", + "quote", "unicode-ident", ] [[package]] -name = "pyo3" -version = "0.20.0" +name = "sysinfo" +version = "0.29.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +checksum = "cd727fc423c2060f6c92d9534cef765c65a6ed3f428a03d7def74a8c4348e666" dependencies = [ "cfg-if", - "indoc", + "core-foundation-sys", "libc", - "memoffset", - "parking_lot", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", + "ntapi", + "once_cell", + "winapi", ] [[package]] -name = "pyo3-build-config" -version = "0.20.0" +name = "target-features" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" -dependencies = [ - "once_cell", - "target-lexicon", -] +checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" [[package]] -name = "pyo3-ffi" -version = "0.20.0" +name = "target-lexicon" +version = "0.12.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +checksum = 
"fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" + +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ - "libc", - "pyo3-build-config", + "thiserror-impl", ] [[package]] -name = "pyo3-macros" -version = "0.20.0" +name = "thiserror-impl" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", - "pyo3-macros-backend", "quote", - "syn", + "syn 2.0.43", ] [[package]] -name = "pyo3-macros-backend" -version = "0.20.0" +name = "unicode-ident" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" + +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "uuid" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", + "getrandom", ] [[package]] -name = "quote" -version = "1.0.28" +name = "version_check" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" 
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ - "proc-macro2", + "cfg-if", + "wasm-bindgen-macro", ] [[package]] -name = "random-tree-models" -version = "0.6.2" +name = "wasm-bindgen-backend" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ - "pyo3", + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.43", + "wasm-bindgen-shared", ] [[package]] -name = "redox_syscall" -version = "0.3.5" +name = "wasm-bindgen-macro" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ - "bitflags", + "quote", + "wasm-bindgen-macro-support", ] [[package]] -name = "scopeguard" -version = "1.1.0" +name = "wasm-bindgen-macro-support" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] [[package]] -name = "smallvec" -version = "1.10.0" +name = "wasm-bindgen-shared" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] -name = "syn" -version = "2.0.20" +name = "winapi" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb8d4cebc40aa517dfb69618fa647a346562e67228e2236ae0042ee6ac14775" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", ] [[package]] -name = "target-lexicon" -version = "0.12.7" +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] -name = "unicode-ident" -version = "1.0.9" +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "unindent" -version = "0.2.3" +name = "windows-core" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "af6041b3f84485c21b57acdc0fee4f4f0c93f426053dc05fa5d6fc262537bbff" +dependencies = [ + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", +] [[package]] name = "windows-targets" @@ -228,13 +1395,28 @@ version = "0.48.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -243,38 +1425,134 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + +[[package]] +name = "xxhash-rust" +version = "0.8.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.9+zstd.1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index d209f42..f81a95f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,11 @@ name = "random_tree_models" crate-type = ["cdylib"] [dependencies] +polars = { version = "0.35.4", features = ["lazy"] } pyo3 = "0.20.0" +rand = "0.8.5" +rand_chacha = "0.3.1" +uuid = { version = "1.6.1", features = ["v4"] } [features] extension-module = ["pyo3/extension-module"] From 8d0392f3e64c44a6c54e106e7b8399a8aa11f017 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Sun, 31 Dec 2023 15:07:34 +0100 Subject: [PATCH 11/28] feat: added basic implementation of 
decision tree functionality in decisiontree.rs --- src/decisiontree.rs | 408 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 20 +++ 2 files changed, 428 insertions(+) create mode 100644 src/decisiontree.rs diff --git a/src/decisiontree.rs b/src/decisiontree.rs new file mode 100644 index 0000000..f934761 --- /dev/null +++ b/src/decisiontree.rs @@ -0,0 +1,408 @@ +use polars::prelude::*; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use uuid::Uuid; + +#[derive(PartialEq, Debug)] +pub struct SplitScore { + pub name: String, + pub score: f64, +} + +impl SplitScore { + pub fn new(name: String, score: f64) -> Self { + SplitScore { name, score } + } +} + +#[derive(PartialEq, Debug)] +pub struct Node { + pub array_column: usize, + pub threshold: f64, + pub prediction: f64, + pub default_is_left: bool, + + // descendants + pub left: Option>, + pub right: Option>, + + // misc + pub measure: SplitScore, + + pub n_obs: usize, + pub reason: String, + pub depth: usize, + pub node_id: Uuid, +} + +impl Node { + pub fn new( + array_column: usize, + threshold: f64, + prediction: f64, + default_is_left: bool, + left: Option>, + right: Option>, + measure: SplitScore, + n_obs: usize, + reason: String, + depth: usize, + ) -> Self { + let node_id = Uuid::new_v4(); + Node { + array_column, + threshold, + prediction, + default_is_left, + left, + right, + measure, + n_obs, + reason, + depth, + node_id, + } + } + + pub fn is_leaf(&self) -> bool { + self.left.is_none() && self.right.is_none() + } + + pub fn insert(&mut self, new_node: Node, insert_left: bool) { + if insert_left { + match self.left { + Some(ref mut _left) => { + panic!("Something went wrong. The left node is already occupied.") + } // left.insert(new_node) + None => self.left = Some(Box::new(new_node)), + } + } else { + match self.right { + Some(ref mut _right) => { + panic!("Something went wrong. 
The right node is already occupied.") + } + //right.insert(new_node), + None => self.right = Some(Box::new(new_node)), + } + } + } +} +// Inspirations: +// * https://rusty-ferris.pages.dev/blog/binary-tree-sum-of-values/ +// * https://gist.github.com/aidanhs/5ac9088ca0f6bdd4a370 +pub fn grow_tree(x: &DataFrame, y: &Series, parent_node: Option<&Node>, depth: usize) -> Node { + // TODO: implement check_is_baselevel and such + let n_obs = x.height(); + if n_obs == 0 { + panic!("Something went wrong. The parent_node handed down an empty set of data points.") + } + + let is_baselevel: bool = depth == 1; + if is_baselevel { + let new_node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + return new_node; + } + + let mut rng = ChaCha20Rng::seed_from_u64(42); + + let mut new_node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + + // descend left + let new_left_node = grow_tree(x, y, Some(&new_node), &depth + 1); // mut new_node, + new_node.insert(new_left_node, true); + + // descend right + let new_right_node = grow_tree(x, y, Some(&new_node), depth + 1); // mut new_node, + new_node.insert(new_right_node, false); + + return new_node; +} + +pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { + let mut node = tree; + + let row_f64 = (*row).cast(&DataType::Float64).unwrap(); + let row = row_f64.f64().unwrap(); + + while !node.is_leaf() { + let value: f64 = row.get(node.array_column).expect("Accessing failed."); + + let is_left = if value < node.threshold { + node.default_is_left + } else { + !node.default_is_left + }; + if is_left { + node = node.left.as_ref().unwrap(); + } else { + node = node.right.as_ref().unwrap(); + } + } + node.prediction +} + +pub fn predict_with_tree(x: &DataFrame, tree: &Node) -> Series { + // use polars to apply predict_for_row_with_tree to 
get one prediction per row + let predictions: Series = x + .iter() + .map(|row| predict_for_row_with_tree(row, tree)) + .collect(); + + let predictions = Series::new("predictions", predictions); + + predictions +} + +struct DecisionTreeTemplate { + pub max_depth: usize, + tree: Option, +} + +impl DecisionTreeTemplate { + pub fn new(max_depth: usize) -> Self { + DecisionTreeTemplate { + max_depth, + tree: None, + } + } + + pub fn fit(&mut self, x: &DataFrame, y: &Series) { + self.tree = Some(grow_tree(x, y, None, 0)); + } + + pub fn predict(&self, x: &DataFrame) -> Series { + match &self.tree { + Some(tree) => predict_with_tree(x, tree), + None => panic!("Something went wrong. The tree is not initialized."), + } + } +} + +#[cfg(test)] +mod tests { + // use rand_chacha::ChaCha20Rng; + // use rand::SeedableRng; + + use super::*; + + #[test] + fn test_split_score() { + let split_score = SplitScore::new("test".to_string(), 0.5); + assert_eq!(split_score.name, "test"); + assert_eq!(split_score.score, 0.5); + } + + #[test] + fn test_node_init() { + let node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + assert_eq!(node.array_column, 0); + assert_eq!(node.threshold, 0.0); + assert_eq!(node.prediction, 1.0); + assert_eq!(node.default_is_left, true); + assert_eq!(node.left, None); + assert_eq!(node.right, None); + assert_eq!(node.measure.name, "score"); + assert_eq!(node.measure.score, 0.5); + assert_eq!(node.n_obs, 10); + assert_eq!(node.reason, "leaf node".to_string()); + assert_eq!(node.depth, 0); + } + + #[test] + fn test_child_node_assignment() { + let mut node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + let child_node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + node.left 
= Some(Box::new(child_node)); + assert_eq!(node.left.is_some(), true); + assert_eq!(node.right.is_none(), true); + } + + #[test] + fn test_grandchild_node_assignment() { + let mut node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + let child_node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + let grandchild_node = Node::new( + 0, + 0.0, + 1.0, + true, + None, + None, + SplitScore::new("score".to_string(), 0.5), + 10, + "leaf node".to_string(), + 0, + ); + node.left = Some(Box::new(child_node)); + node.left.as_mut().unwrap().left = Some(Box::new(grandchild_node)); + assert_eq!(node.left.is_some(), true); + assert_eq!(node.right.is_none(), true); + assert_eq!(node.left.as_ref().unwrap().left.is_some(), true); + assert_eq!(node.left.as_ref().unwrap().right.is_none(), true); + } + + #[test] + fn test_node_is_leaf() { + let node = Node { + array_column: 0, + threshold: 0.0, + prediction: 1.0, + default_is_left: true, + left: None, + right: None, + measure: SplitScore::new("score".to_string(), 0.5), + n_obs: 10, + reason: "leaf node".to_string(), + depth: 1, + node_id: Uuid::new_v4(), + }; + assert_eq!(node.is_leaf(), true); + } + + #[test] + fn test_grow_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 2, 3]); + + let tree = grow_tree(&df, &y, None, 0); + + assert!(tree.is_leaf() == false); + assert_eq!(tree.left.is_some(), true); + assert_eq!(tree.right.is_some(), true); + assert_eq!(tree.left.as_ref().unwrap().is_leaf(), true); + assert_eq!(tree.right.as_ref().unwrap().is_leaf(), true); + } + + #[test] + fn test_predict_for_row_with_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + 
Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 2, 3]); + + let tree = grow_tree(&df, &y, None, 0); + + let row = df.select_at_idx(0).unwrap(); + let prediction = predict_for_row_with_tree(&row, &tree); + assert_eq!(prediction, 1.0); + } + + // test predict_with_tree + #[test] + fn test_predict_with_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 2, 3]); + + let tree = grow_tree(&df, &y, None, 0); + + let predictions = predict_with_tree(&df, &tree); + assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); + } + + // test DecisionTreeTemplate + #[test] + fn test_decision_tree_template() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 2, 3]); + + let mut dtree = DecisionTreeTemplate::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); + } +} diff --git a/src/lib.rs b/src/lib.rs index 756e91d..46af9d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,13 @@ use pyo3::prelude::*; +mod decisiontree; mod scoring; +// mod utils; #[pymodule] #[pyo3(name = "_rust")] fn random_tree_models(py: Python<'_>, m: &PyModule) -> PyResult<()> { register_scoring_module(py, m)?; + // m.add_class::()?; Ok(()) } @@ -16,3 +19,20 @@ fn register_scoring_module(py: Python<'_>, parent_module: &PyModule) -> PyResult parent_module.add_submodule(child_module)?; Ok(()) } + +// #[pyclass] +// struct DecisionTree { +// num: usize, +// } + +// #[pymethods] +// impl DecisionTree { +// #[new] +// fn new(num: usize) -> Self { +// DecisionTree { num } +// } + +// fn get_num(&self) -> usize { +// self.num +// } +// } From e2136dfae4e308059a4f46f49fc8d20a29129d2a Mon Sep 17 00:00:00 2001 From: 
eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:12:16 +0100 Subject: [PATCH 12/28] Add pyo3-polars crate as a dependency --- Cargo.lock | 13 +++++++++++++ Cargo.toml | 1 + 2 files changed, 14 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 5f9feaa..4127c3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -937,6 +937,18 @@ dependencies = [ "syn 2.0.43", ] +[[package]] +name = "pyo3-polars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e37da190c68036cb620bbde9a8933839addfa4b66e9903b9b1dc751b2b00e7d7" +dependencies = [ + "polars", + "polars-core", + "pyo3", + "thiserror", +] + [[package]] name = "quote" version = "1.0.28" @@ -992,6 +1004,7 @@ version = "0.6.2" dependencies = [ "polars", "pyo3", + "pyo3-polars", "rand", "rand_chacha", "uuid", diff --git a/Cargo.toml b/Cargo.toml index f81a95f..2b94d88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib"] [dependencies] polars = { version = "0.35.4", features = ["lazy"] } pyo3 = "0.20.0" +pyo3-polars = "0.9.0" rand = "0.8.5" rand_chacha = "0.3.1" uuid = { version = "1.6.1", features = ["v4"] } From afe610a52c9d6275e10fa2514e68527f126901aa Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 1 Jan 2024 15:00:14 +0100 Subject: [PATCH 13/28] made more Node struct fields optional in decisiontree.rs --- src/decisiontree.rs | 141 +++++++++++++++++++++++--------------------- 1 file changed, 74 insertions(+), 67 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index f934761..3928eab 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -17,17 +17,17 @@ impl SplitScore { #[derive(PartialEq, Debug)] pub struct Node { - pub array_column: usize, - pub threshold: f64, - pub prediction: f64, - pub default_is_left: bool, + pub array_column: Option, + pub threshold: Option, + pub prediction: Option, + pub default_is_left: Option, // descendants 
pub left: Option>, pub right: Option>, // misc - pub measure: SplitScore, + pub measure: Option, pub n_obs: usize, pub reason: String, @@ -37,13 +37,13 @@ pub struct Node { impl Node { pub fn new( - array_column: usize, - threshold: f64, - prediction: f64, - default_is_left: bool, + array_column: Option, + threshold: Option, + prediction: Option, + default_is_left: Option, left: Option>, right: Option>, - measure: SplitScore, + measure: Option, n_obs: usize, reason: String, depth: usize, @@ -100,13 +100,13 @@ pub fn grow_tree(x: &DataFrame, y: &Series, parent_node: Option<&Node>, depth: u let is_baselevel: bool = depth == 1; if is_baselevel { let new_node = Node::new( - 0, - 0.0, - 1.0, - true, None, None, - SplitScore::new("score".to_string(), 0.5), + Some(1.0), + None, + None, + None, + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, @@ -114,21 +114,25 @@ pub fn grow_tree(x: &DataFrame, y: &Series, parent_node: Option<&Node>, depth: u return new_node; } + // find best split let mut rng = ChaCha20Rng::seed_from_u64(42); + let leaf_weight = 1.0; let mut new_node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), + Some(0.0), + Some(leaf_weight), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, ); + // check if improvement due to split is below minimum requirement + // descend left let new_left_node = grow_tree(x, y, Some(&new_node), &depth + 1); // mut new_node, new_node.insert(new_left_node, true); @@ -147,12 +151,14 @@ pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { let row = row_f64.f64().unwrap(); while !node.is_leaf() { - let value: f64 = row.get(node.array_column).expect("Accessing failed."); + let col = node.array_column.unwrap(); + let value: f64 = row.get(col).expect("Accessing failed."); - let is_left = if value < node.threshold { - node.default_is_left + let threshold = node.threshold.unwrap(); + 
let is_left = if value < threshold { + node.default_is_left.unwrap() } else { - !node.default_is_left + !node.default_is_left.unwrap() }; if is_left { node = node.left.as_ref().unwrap(); @@ -160,7 +166,7 @@ pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { node = node.right.as_ref().unwrap(); } } - node.prediction + node.prediction.unwrap() } pub fn predict_with_tree(x: &DataFrame, tree: &Node) -> Series { @@ -175,7 +181,7 @@ pub fn predict_with_tree(x: &DataFrame, tree: &Node) -> Series { predictions } -struct DecisionTreeTemplate { +pub struct DecisionTreeTemplate { pub max_depth: usize, tree: Option, } @@ -217,25 +223,26 @@ mod tests { #[test] fn test_node_init() { let node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), + Some(0.0), + Some(1.0), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, ); - assert_eq!(node.array_column, 0); - assert_eq!(node.threshold, 0.0); - assert_eq!(node.prediction, 1.0); - assert_eq!(node.default_is_left, true); + assert_eq!(node.array_column.unwrap(), 0); + assert_eq!(node.threshold.unwrap(), 0.0); + assert_eq!(node.prediction.unwrap(), 1.0); + assert_eq!(node.default_is_left.unwrap(), true); assert_eq!(node.left, None); assert_eq!(node.right, None); - assert_eq!(node.measure.name, "score"); - assert_eq!(node.measure.score, 0.5); + let m = node.measure.unwrap(); + assert_eq!(m.name, "score"); + assert_eq!(m.score, 0.5); assert_eq!(node.n_obs, 10); assert_eq!(node.reason, "leaf node".to_string()); assert_eq!(node.depth, 0); @@ -244,25 +251,25 @@ mod tests { #[test] fn test_child_node_assignment() { let mut node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), + Some(0.0), + Some(1.0), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, ); let child_node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), 
+ Some(0.0), + Some(1.0), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, @@ -275,37 +282,37 @@ mod tests { #[test] fn test_grandchild_node_assignment() { let mut node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), + Some(0.0), + Some(1.0), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, ); let child_node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), + Some(0.0), + Some(1.0), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, ); let grandchild_node = Node::new( - 0, - 0.0, - 1.0, - true, + Some(0), + Some(0.0), + Some(1.0), + Some(true), None, None, - SplitScore::new("score".to_string(), 0.5), + Some(SplitScore::new("score".to_string(), 0.5)), 10, "leaf node".to_string(), 0, @@ -321,13 +328,13 @@ mod tests { #[test] fn test_node_is_leaf() { let node = Node { - array_column: 0, - threshold: 0.0, - prediction: 1.0, - default_is_left: true, + array_column: Some(0), + threshold: Some(0.0), + prediction: Some(1.0), + default_is_left: Some(true), left: None, right: None, - measure: SplitScore::new("score".to_string(), 0.5), + measure: Some(SplitScore::new("score".to_string(), 0.5)), n_obs: 10, reason: "leaf node".to_string(), depth: 1, From 89bd7cafea960d830999ddc8291210acb4962321 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 1 Jan 2024 15:35:46 +0100 Subject: [PATCH 14/28] feat: added first rudimentary pyo3 class DecisionTree and the TreeGrowthParameters struct in utils.rs and check_is_baselevel --- src/decisiontree.rs | 60 ++++++++++++++++++++++++++++++++++++--------- src/lib.rs | 50 +++++++++++++++++++++++++------------ src/utils.rs | 3 +++ 3 files changed, 85 insertions(+), 28 deletions(-) create 
mode 100644 src/utils.rs diff --git a/src/decisiontree.rs b/src/decisiontree.rs index 3928eab..a7a3bdd 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -3,6 +3,8 @@ use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use uuid::Uuid; +use crate::utils::TreeGrowthParameters; + #[derive(PartialEq, Debug)] pub struct SplitScore { pub name: String, @@ -87,17 +89,46 @@ impl Node { } } } + +pub fn check_is_baselevel( + y: &Series, + depth: usize, + growth_params: &TreeGrowthParameters, +) -> (bool, String) { + let n_obs = y.len(); + let n_unique = y + .n_unique() + .expect("Something went wrong. Could not get n_unique."); + let max_depth = growth_params.max_depth; + + if max_depth.is_some() && depth >= max_depth.unwrap() { + return (true, "max depth reached".to_string()); + } else if n_unique == 1 { + return (true, "homogenous group".to_string()); + } else if n_obs <= 1 { + return (true, "<= 1 data point in group".to_string()); + } else { + (false, "".to_string()) + } +} + // Inspirations: // * https://rusty-ferris.pages.dev/blog/binary-tree-sum-of-values/ // * https://gist.github.com/aidanhs/5ac9088ca0f6bdd4a370 -pub fn grow_tree(x: &DataFrame, y: &Series, parent_node: Option<&Node>, depth: usize) -> Node { +pub fn grow_tree( + x: &DataFrame, + y: &Series, + growth_params: &TreeGrowthParameters, + parent_node: Option<&Node>, + depth: usize, +) -> Node { // TODO: implement check_is_baselevel and such let n_obs = x.height(); if n_obs == 0 { panic!("Something went wrong. 
The parent_node handed down an empty set of data points.") } - let is_baselevel: bool = depth == 1; + let (is_baselevel, reason) = check_is_baselevel(y, depth, growth_params); if is_baselevel { let new_node = Node::new( None, @@ -108,7 +139,7 @@ pub fn grow_tree(x: &DataFrame, y: &Series, parent_node: Option<&Node>, depth: u None, Some(SplitScore::new("score".to_string(), 0.5)), 10, - "leaf node".to_string(), + reason, 0, ); return new_node; @@ -134,11 +165,11 @@ pub fn grow_tree(x: &DataFrame, y: &Series, parent_node: Option<&Node>, depth: u // check if improvement due to split is below minimum requirement // descend left - let new_left_node = grow_tree(x, y, Some(&new_node), &depth + 1); // mut new_node, + let new_left_node = grow_tree(x, y, growth_params, Some(&new_node), &depth + 1); // mut new_node, new_node.insert(new_left_node, true); // descend right - let new_right_node = grow_tree(x, y, Some(&new_node), depth + 1); // mut new_node, + let new_right_node = grow_tree(x, y, growth_params, Some(&new_node), depth + 1); // mut new_node, new_node.insert(new_right_node, false); return new_node; @@ -182,20 +213,23 @@ pub fn predict_with_tree(x: &DataFrame, tree: &Node) -> Series { } pub struct DecisionTreeTemplate { - pub max_depth: usize, + pub growth_params: TreeGrowthParameters, tree: Option, } impl DecisionTreeTemplate { pub fn new(max_depth: usize) -> Self { + let growth_params = TreeGrowthParameters { + max_depth: Some(max_depth), + }; DecisionTreeTemplate { - max_depth, + growth_params, tree: None, } } pub fn fit(&mut self, x: &DataFrame, y: &Series) { - self.tree = Some(grow_tree(x, y, None, 0)); + self.tree = Some(grow_tree(x, y, &self.growth_params, None, 0)); } pub fn predict(&self, x: &DataFrame) -> Series { @@ -352,8 +386,9 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 2, 3]); + let growth_params = TreeGrowthParameters { max_depth: Some(1) }; - let tree = grow_tree(&df, &y, None, 0); + let tree = grow_tree(&df, &y, &growth_params, None, 
0); assert!(tree.is_leaf() == false); assert_eq!(tree.left.is_some(), true); @@ -371,8 +406,9 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 2, 3]); + let growth_params = TreeGrowthParameters { max_depth: Some(2) }; - let tree = grow_tree(&df, &y, None, 0); + let tree = grow_tree(&df, &y, &growth_params, None, 0); let row = df.select_at_idx(0).unwrap(); let prediction = predict_for_row_with_tree(&row, &tree); @@ -389,8 +425,8 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 2, 3]); - - let tree = grow_tree(&df, &y, None, 0); + let growth_params = TreeGrowthParameters { max_depth: Some(2) }; + let tree = grow_tree(&df, &y, &growth_params, None, 0); let predictions = predict_with_tree(&df, &tree); assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); diff --git a/src/lib.rs b/src/lib.rs index 46af9d7..981e054 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,15 @@ +use polars::{frame::DataFrame, series::Series}; use pyo3::prelude::*; mod decisiontree; mod scoring; -// mod utils; +use pyo3_polars::{PyDataFrame, PySeries}; +mod utils; #[pymodule] #[pyo3(name = "_rust")] fn random_tree_models(py: Python<'_>, m: &PyModule) -> PyResult<()> { register_scoring_module(py, m)?; - // m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -20,19 +22,35 @@ fn register_scoring_module(py: Python<'_>, parent_module: &PyModule) -> PyResult Ok(()) } -// #[pyclass] -// struct DecisionTree { -// num: usize, -// } +#[pyclass] +struct DecisionTree { + max_depth: usize, + tree_: Option, +} + +#[pymethods] +impl DecisionTree { + #[new] + fn new(max_depth: usize) -> Self { + DecisionTree { + max_depth, + tree_: None, + } + } -// #[pymethods] -// impl DecisionTree { -// #[new] -// fn new(num: usize) -> Self { -// DecisionTree { num } -// } + fn fit(&mut self, X: PyDataFrame, y: PySeries) -> PyResult<()> { + let mut tree = decisiontree::DecisionTreeTemplate::new(self.max_depth); + let X: DataFrame = X.into(); + let y: Series = y.into(); + tree.fit(&X, &y); 
+ self.tree_ = Some(tree); + Ok(()) + } -// fn get_num(&self) -> usize { -// self.num -// } -// } + fn predict(&self, X: PyDataFrame) -> PyResult { + let X: DataFrame = X.into(); + let y_pred = self.tree_.as_ref().unwrap().predict(&X); + + Ok(PySeries(y_pred)) + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..82b4485 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,3 @@ +pub struct TreeGrowthParameters { + pub max_depth: Option, +} From b14b5954ff1b85ab7958b9348ae873f0a2ee5749 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 1 Jan 2024 17:26:24 +0100 Subject: [PATCH 15/28] feat: added fn calc_leaf_weight_and_split_score --- src/decisiontree.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index a7a3bdd..8709d54 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -112,6 +112,21 @@ pub fn check_is_baselevel( } } +pub fn calc_leaf_weight_and_split_score( + y: &Series, + growth_params: &TreeGrowthParameters, + g: Option<&Series>, + h: Option<&Series>, + incrementing_score: Option, +) -> (f64, SplitScore) { + + let leaf_weight = 1.0; + + let score = SplitScore::new("score".to_string(), 0.5); + + (leaf_weight, score) +} + // Inspirations: // * https://rusty-ferris.pages.dev/blog/binary-tree-sum-of-values/ // * https://gist.github.com/aidanhs/5ac9088ca0f6bdd4a370 @@ -122,7 +137,7 @@ pub fn grow_tree( parent_node: Option<&Node>, depth: usize, ) -> Node { - // TODO: implement check_is_baselevel and such + let n_obs = x.height(); if n_obs == 0 { panic!("Something went wrong. 
The parent_node handed down an empty set of data points.") @@ -145,9 +160,12 @@ pub fn grow_tree( return new_node; } + // let leaf_weight = 1.0; + let (leaf_weight, score) = calc_leaf_weight_and_split_score(y, growth_params, None, None, None); + // find best split let mut rng = ChaCha20Rng::seed_from_u64(42); - let leaf_weight = 1.0; + let mut new_node = Node::new( Some(0), From 66f61a17bf65101837cbdba8d6d182bdef2ccb02 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 1 Jan 2024 17:36:47 +0100 Subject: [PATCH 16/28] feat: added basic version of fn calc_leaf_weight only computing the mean so far --- src/decisiontree.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index 8709d54..7d8713d 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -112,6 +112,18 @@ pub fn check_is_baselevel( } } +pub fn calc_leaf_weight( + y: &Series, + growth_params: &TreeGrowthParameters, + g: Option<&Series>, + h: Option<&Series>, +) -> f64 { + + let leaf_weight = y.mean().unwrap(); + + leaf_weight +} + pub fn calc_leaf_weight_and_split_score( y: &Series, growth_params: &TreeGrowthParameters, @@ -120,7 +132,7 @@ pub fn calc_leaf_weight_and_split_score( incrementing_score: Option, ) -> (f64, SplitScore) { - let leaf_weight = 1.0; + let leaf_weight = calc_leaf_weight(y, growth_params, g, h); let score = SplitScore::new("score".to_string(), 0.5); @@ -403,7 +415,7 @@ mod tests { Series::new("c", &[1, 2, 3]), ]) .unwrap(); - let y = Series::new("y", &[1, 2, 3]); + let y = Series::new("y", &[1, 1, 2]); let growth_params = TreeGrowthParameters { max_depth: Some(1) }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -423,7 +435,7 @@ mod tests { Series::new("c", &[1, 2, 3]), ]) .unwrap(); - let y = Series::new("y", &[1, 2, 3]); + let y = Series::new("y", &[1, 1, 1]); let growth_params = TreeGrowthParameters { max_depth: Some(2) }; let tree = 
grow_tree(&df, &y, &growth_params, None, 0); @@ -442,7 +454,7 @@ mod tests { Series::new("c", &[1, 2, 3]), ]) .unwrap(); - let y = Series::new("y", &[1, 2, 3]); + let y = Series::new("y", &[1, 1, 1]); let growth_params = TreeGrowthParameters { max_depth: Some(2) }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -459,7 +471,7 @@ mod tests { Series::new("c", &[1, 2, 3]), ]) .unwrap(); - let y = Series::new("y", &[1, 2, 3]); + let y = Series::new("y", &[1, 1, 1]); let mut dtree = DecisionTreeTemplate::new(2); dtree.fit(&df, &y); From 60c806f3fd35a17a1135701e3b58cd93bf38eafd Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Sat, 6 Jan 2024 10:53:47 +0100 Subject: [PATCH 17/28] added negative entropy score calculation --- src/decisiontree.rs | 22 ++-- src/scoring.rs | 240 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 250 insertions(+), 12 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index 7d8713d..81fe155 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -3,7 +3,7 @@ use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use uuid::Uuid; -use crate::utils::TreeGrowthParameters; +use crate::{scoring, utils::TreeGrowthParameters}; #[derive(PartialEq, Debug)] pub struct SplitScore { @@ -118,7 +118,6 @@ pub fn calc_leaf_weight( g: Option<&Series>, h: Option<&Series>, ) -> f64 { - let leaf_weight = y.mean().unwrap(); leaf_weight @@ -131,10 +130,11 @@ pub fn calc_leaf_weight_and_split_score( h: Option<&Series>, incrementing_score: Option, ) -> (f64, SplitScore) { - let leaf_weight = calc_leaf_weight(y, growth_params, g, h); - let score = SplitScore::new("score".to_string(), 0.5); + let target_groups: Series = Series::new("target_groups", vec![true; y.len()]); + let score = scoring::calc_score(y, &target_groups, growth_params, g, h, incrementing_score); + let score = SplitScore::new("neg_entropy".to_string(), score); (leaf_weight, score) } @@ -149,7 +149,6 @@ pub fn 
grow_tree( parent_node: Option<&Node>, depth: usize, ) -> Node { - let n_obs = x.height(); if n_obs == 0 { panic!("Something went wrong. The parent_node handed down an empty set of data points.") @@ -178,7 +177,6 @@ pub fn grow_tree( // find best split let mut rng = ChaCha20Rng::seed_from_u64(42); - let mut new_node = Node::new( Some(0), Some(0.0), @@ -407,6 +405,18 @@ mod tests { assert_eq!(node.is_leaf(), true); } + // test calc_leaf_weight_and_split_score + #[test] + fn test_calc_leaf_weight_and_split_score() { + let y = Series::new("y", &[1, 1, 1]); + let growth_params = TreeGrowthParameters { max_depth: Some(2) }; + let (leaf_weight, score) = + calc_leaf_weight_and_split_score(&y, &growth_params, None, None, None); + assert_eq!(leaf_weight, 1.0); + assert_eq!(score.name, "neg_entropy"); + assert_eq!(score.score, 0.0); + } + #[test] fn test_grow_tree() { let df = DataFrame::new(vec![ diff --git a/src/scoring.rs b/src/scoring.rs index dae0432..3faff29 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -1,8 +1,11 @@ use std::collections::HashMap; +use polars::prelude::*; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use crate::utils::TreeGrowthParameters; + // compute gini impurity of an array of discrete values #[pyfunction(name = "gini_impurity")] pub fn gini_impurity_py(values: Vec) -> PyResult { @@ -57,25 +60,250 @@ fn entropy(values: Vec) -> f64 { entropy } +pub fn count_y_values(y: &Series) -> Series { + let df = y.value_counts(false, false).unwrap(); + let counts: Series = df.select_at_idx(1).unwrap().clone(); + counts +} + +pub fn calc_probabilities(y: &Series) -> Series { + let msg = "Could not cast to f64"; + let counts = count_y_values(y); + let counts = counts.cast(&DataType::Float64).expect(msg); + let total: f64 = counts.sum().unwrap(); + let ps = Series::new("probs", counts / total); + ps +} + +pub fn calc_neg_entropy_series(ps: &Series) -> f64 { + let neg_entropy = ps + .f64() + .expect("not f64 dtype") + .into_iter() + .map(|x| 
x.unwrap() * x.unwrap().log2()) + .sum(); + neg_entropy +} + +pub fn entropy_rs(y: &Series, target_groups: &Series) -> f64 { + let msg = "Could not cast to f64"; + let w_left: f64 = (*target_groups) + .cast(&polars::datatypes::DataType::Float64) + .expect(msg) + .sum::() + .unwrap() + / y.len() as f64; + let w_right: f64 = 1.0 - w_left; + + // generate boolean chunked array of target_groups + let trues = Series::new("", vec![true; target_groups.len()]); + let target_groups = target_groups.equal(&trues).unwrap(); + + let mut entropy_left = 0.0; + let mut entropy_right = 0.0; + if w_left > 0. { + let y_left = y.filter(&target_groups).unwrap(); + let probs = calc_probabilities(&y_left); + entropy_left = calc_neg_entropy_series(&probs); + } else { + entropy_left = 0.0; + } + if w_right > 0. { + let y_right = y.filter(&!target_groups).unwrap(); + let probs = calc_probabilities(&y_right); + entropy_right = calc_neg_entropy_series(&probs); + } else { + entropy_right = 0.0; + } + let score = (w_left * entropy_left) + (w_right * entropy_right); + score +} + +pub fn calc_score( + y: &Series, + target_groups: &Series, + growth_params: &TreeGrowthParameters, + g: Option<&Series>, + h: Option<&Series>, + incrementing_score: Option, +) -> f64 { + let score = entropy_rs(y, target_groups); + + score +} + mod tests { + use super::*; // test that gini impurity correctly computes values smaller than zero for a couple of vectors #[test] fn test_gini_impurity() { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; - assert_eq!(super::gini_impurity(values), 0.0); + assert_eq!(gini_impurity(values), 0.0); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; - assert_eq!(super::gini_impurity(values), -0.5); + assert_eq!(gini_impurity(values), -0.5); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - assert_eq!(super::gini_impurity(values), -0.875); + assert_eq!(gini_impurity(values), -0.875); } // test that entropy correctly computes values smaller than zero for a couple of vectors #[test] fn test_entropy() { let 
values = vec![0, 0, 0, 0, 0, 0, 0, 0]; - assert_eq!(super::entropy(values), 0.0); + assert_eq!(entropy(values), 0.0); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; - assert_eq!(super::entropy(values), -1.0); + assert_eq!(entropy(values), -1.0); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - assert_eq!(super::entropy(values), -3.0); + assert_eq!(entropy(values), -3.0); + } + // test count_y_values + #[test] + fn test_count_y_values() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let counts = count_y_values(&s); + // assert that counts is a series with one value + let exp: Vec = vec![8]; + assert_eq!(counts, Series::new("counts", exp)); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let counts = count_y_values(&s); + let exp: Vec = vec![4, 4]; + assert_eq!(counts, Series::new("counts", exp)); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let counts = count_y_values(&s); + let exp: Vec = vec![1, 1, 1, 1, 1, 1, 1, 1]; + assert_eq!(counts, Series::new("counts", exp)); + } + + // test calc_probabilities + #[test] + fn test_calc_probabilities() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + // assert that counts is a series with one value + let exp: Vec = vec![1.0]; + assert_eq!(probs, Series::new("probs", exp)); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let exp: Vec = vec![0.5, 0.5]; + assert_eq!(probs, Series::new("probs", exp)); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let exp: Vec = vec![0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]; + assert_eq!(probs, Series::new("probs", exp)); + } + + // test calc_neg_entropy_series + #[test] + fn test_calc_neg_entropy_series() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + 
let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let neg_entropy = calc_neg_entropy_series(&probs); + assert_eq!(neg_entropy, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let neg_entropy = calc_neg_entropy_series(&probs); + assert_eq!(neg_entropy, -1.0); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let probs = calc_probabilities(&s); + let neg_entropy = calc_neg_entropy_series(&probs); + assert_eq!(neg_entropy, -3.0); + } + + // test calc_score + #[test] + fn test_calc_score() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, -1.0); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, -3.0); + + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, 0.0); + + let values = 
vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let score = calc_score(&s, &target_groups, &growth_params, None, None, None); + assert_eq!(score, -1.0); + } + + // test entropy_rs + #[test] + fn test_entropy_rs() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = entropy_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = entropy_rs(&s, &target_groups); + assert_eq!(score, -1.0); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = entropy_rs(&s, &target_groups); + assert_eq!(score, -3.0); + + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = entropy_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = entropy_rs(&s, &target_groups); + assert_eq!(score, -1.0); } } From 550f217fd27fb6884e2ea8a72c6b9e0d1fd8a802 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 15 Jan 2024 03:43:17 -0600 Subject: [PATCH 18/28] chore: updated polars and pyo3-polars dependencies and added polars lazy dependency --- Cargo.lock | 309 +++++++++++++++++++++++++++-------------------------- 
Cargo.toml | 4 +- 2 files changed, 158 insertions(+), 155 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4127c3f..262e556 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", "getrandom", @@ -81,9 +81,9 @@ dependencies = [ [[package]] name = "atoi_simd" -version = "0.15.5" +version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccfc14f5c3e34de57539a7ba9c18ecde3d9bbde48d232ea1da3e468adb307fd0" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" [[package]] name = "autocfg" @@ -126,7 +126,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] @@ -160,7 +160,7 @@ dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", - "windows-targets 0.48.0", + "windows-targets 0.48.5", ] [[package]] @@ -183,54 +183,46 @@ checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "crossbeam-channel" -version = "0.5.10" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82a9b73a36529d9c47029b9fb3a6f0ea3cc916a261195352ba19e770fc1748b2" +checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - 
"cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.17" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-queue" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc6598521bb5a83d491e8c1fe51db7296019d2ca3cb93cc6c2a20369a4d78a2" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.18" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" -dependencies = [ - "cfg-if", -] +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] name = "crossterm" @@ -275,7 +267,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] @@ -349,9 +341,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.58" +version = "0.1.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" +checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -412,9 +404,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.151" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = 
"libm" @@ -424,9 +416,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -534,9 +526,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "parking_lot" @@ -550,15 +542,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.8" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-targets 0.48.0", + "windows-targets 0.48.5", ] [[package]] @@ -584,9 +576,9 @@ dependencies = [ [[package]] name = "polars" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8e52f9236eb722da0990a70bbb1216dcc7a77bcb00c63439d2d982823e90d5" +checksum = "938048fcda6a8e2ace6eb168bee1b415a92423ce51e418b853bf08fc40349b6b" dependencies = [ "getrandom", "polars-core", @@ -600,9 +592,9 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd503430a6d9779b07915d858865fe998317ef3cfef8973881f578ac5d4baae7" +checksum = "ce68a02f698ff7787c261aea1b4c040a8fe183a8fb200e2436d7f35d95a1b86f" dependencies = [ 
"ahash", "arrow-format", @@ -622,19 +614,32 @@ dependencies = [ "num-traits", "polars-error", "polars-utils", - "rustc_version", "ryu", "simdutf8", "streaming-iterator", "strength_reduce", + "version_check", "zstd", ] +[[package]] +name = "polars-compute" +version = "0.36.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14fbc5f141b29b656a4cec4802632e5bff10bf801c6809c6bbfbd4078a044dd" +dependencies = [ + "bytemuck", + "num-traits", + "polars-arrow", + "polars-utils", + "version_check", +] + [[package]] name = "polars-core" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae73d5b8e55decde670caba1cc82b61f14bfb9a72503198f0997d657a98dcfd6" +checksum = "d0f5efe734b6cbe5f97ea769be8360df5324fade396f1f3f5ad7fe9360ca4a23" dependencies = [ "ahash", "bitflags 2.4.1", @@ -647,6 +652,7 @@ dependencies = [ "num-traits", "once_cell", "polars-arrow", + "polars-compute", "polars-error", "polars-row", "polars-utils", @@ -662,9 +668,9 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb0520d68eaa9993ae0c741409d1526beff5b8f48e1d73e4381616f8152cf488" +checksum = "6396de788f99ebfc9968e7b6f523e23000506cde4ba6dfc62ae4ce949002a886" dependencies = [ "arrow-format", "regex", @@ -674,9 +680,9 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96e10a0745acd6009db64bef0ceb9e23a70b1c27b26a0a6517c91f3e6363bc06" +checksum = "7d0458efe8946f4718fd352f230c0db5a37926bd0d2bd25af79dc24746abaaea" dependencies = [ "ahash", "atoi_simd", @@ -704,9 +710,9 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3555f759705be6dd0d3762d16a0b8787b2dc4da73b57465f3b2bf1a070ba8f20" +checksum = "9d7105b40905bb38e8fc4a7fd736594b7491baa12fad3ac492969ca221a1b5d5" dependencies = [ "ahash", "bitflags 2.4.1", @@ -727,9 +733,9 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a7eb218296aaa7f79945f08288ca32ca3cf25fa505649eeee689ec21eebf636" +checksum = "2e09afc456ab11e75e5dcb43e00a01c71f3a46a2781e450054acb6bb096ca78e" dependencies = [ "ahash", "argminmax", @@ -740,6 +746,7 @@ dependencies = [ "memchr", "num-traits", "polars-arrow", + "polars-compute", "polars-core", "polars-error", "polars-utils", @@ -751,9 +758,9 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66094e7df64c932a9a7bdfe7df0c65efdcb192096e11a6a765a9778f78b4bdec" +checksum = "d9b7ead073cc3917027d77b59861a9f071db47125de9314f8907db1a0a3e4100" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -761,6 +768,7 @@ dependencies = [ "hashbrown", "num-traits", "polars-arrow", + "polars-compute", "polars-core", "polars-io", "polars-ops", @@ -774,9 +782,9 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10e32a0958ef854b132bad7f8369cb3237254635d5e864c99505bc0bc1035fbc" +checksum = "384a175624d050c31c473ee11df9d7af5d729ae626375e522158cfb3d150acd0" dependencies = [ "ahash", "bytemuck", @@ -797,9 +805,9 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d135ab81cac2906ba74ea8984c7e6025d081ae5867615bcefb4d84dfdb456dac" +checksum = "32322f7acbb83db3e9c7697dc821be73d06238da89c817dcc8bc1549a5e9c72f" dependencies = [ "polars-arrow", "polars-error", @@ -808,9 
+816,9 @@ dependencies = [ [[package]] name = "polars-sql" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dbd7786849a5e3ad1fde188bf38141632f626e3a57319b0bbf7a5f1d75519e" +checksum = "9f0b4c6ddffdfd0453e84bc3918572c633014d661d166654399cf93752aa95b5" dependencies = [ "polars-arrow", "polars-core", @@ -825,9 +833,9 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aae56f79e9cedd617773c1c8f5ca84a31a8b1d593714959d5f799e7bdd98fe51" +checksum = "dee2649fc96bd1b6584e0e4a4b3ca7d22ed3d117a990e63ad438ecb26f7544d0" dependencies = [ "atoi", "chrono", @@ -844,9 +852,9 @@ dependencies = [ [[package]] name = "polars-utils" -version = "0.35.4" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da6ce68169fe61d46958c8eab7447360f30f2f23f6e24a0ce703a14b0a3cfbfc" +checksum = "b174ca4a77ad47d7b91a0460aaae65bbf874c8bfbaaa5308675dadef3976bbda" dependencies = [ "ahash", "bytemuck", @@ -869,18 +877,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.71" +version = "1.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" dependencies = [ "unicode-ident", ] [[package]] name = "pyo3" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" dependencies = [ "cfg-if", "indoc", @@ -895,9 +903,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.0" +version = "0.20.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" dependencies = [ "once_cell", "target-lexicon", @@ -905,9 +913,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" dependencies = [ "libc", "pyo3-build-config", @@ -915,33 +923,33 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] name = "pyo3-macros-backend" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] name = "pyo3-polars" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e37da190c68036cb620bbde9a8933839addfa4b66e9903b9b1dc751b2b00e7d7" +checksum = "b1e983cb07cf665ea6e645ae9263c358062580f23a9aee41618a5706d4a7cc21" dependencies = [ "polars", "polars-core", @@ -951,9 +959,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.28" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -1032,9 +1040,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags 1.3.2", ] @@ -1068,15 +1076,6 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - [[package]] name = "rustversion" version = "1.0.14" @@ -1091,41 +1090,35 @@ checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "scopeguard" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" - -[[package]] -name = "semver" -version = "1.0.20" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.193" 
+version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] name = "serde_json" -version = "1.0.108" +version = "1.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" dependencies = [ "itoa", "ryu", @@ -1140,9 +1133,9 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" [[package]] name = "smallvec" -version = "1.10.0" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "smartstring" @@ -1198,7 +1191,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] @@ -1214,9 +1207,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.43" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -1225,16 +1218,16 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.11" +version = "0.30.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd727fc423c2060f6c92d9534cef765c65a6ed3f428a03d7def74a8c4348e666" +checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", - "winapi", + 
"windows", ] [[package]] @@ -1245,35 +1238,35 @@ checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" [[package]] name = "target-lexicon" -version = "0.12.7" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-width" @@ -1329,7 +1322,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", "wasm-bindgen-shared", ] @@ -1351,7 +1344,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1384,13 +1377,23 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +dependencies = [ + "windows-core", + "windows-targets 0.52.0", +] + [[package]] name = "windows-core" -version = "0.50.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af6041b3f84485c21b57acdc0fee4f4f0c93f426053dc05fa5d6fc262537bbff" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] @@ -1404,17 +1407,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -1434,9 +1437,9 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" @@ -1446,9 +1449,9 @@ checksum = 
"cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" @@ -1458,9 +1461,9 @@ checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" @@ -1470,9 +1473,9 @@ checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" @@ -1482,9 +1485,9 @@ checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" @@ -1494,9 +1497,9 @@ checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" @@ -1506,9 +1509,9 @@ checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" @@ -1539,7 +1542,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.43", + "syn 2.0.48", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 2b94d88..b719ae4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,9 +9,9 @@ name = "random_tree_models" crate-type = ["cdylib"] [dependencies] -polars = { version = "0.35.4", features = ["lazy"] } +polars = { version = "0.36.2", features = ["lazy", "dtype-struct"] } pyo3 = "0.20.0" -pyo3-polars = "0.9.0" +pyo3-polars = "0.10.0" rand = "0.8.5" rand_chacha = "0.3.1" uuid = { version = "1.6.1", features = ["v4"] } From e8f5965b1bacc426bf03a1baf49c4b8a8dba2b77 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 15 Jan 2024 03:45:13 -0600 Subject: [PATCH 19/28] chore: added make format, rust only so far --- Makefile | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Makefile b/Makefile index e16e215..7a3396e 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ help: @echo "compile : update the environment requirements after changes to dependencies in pyproject.toml." @echo "update : pip install new requriements into the virtual environment." @echo "test : run pytests." + @echo "format : format rust code." 
# create a virtual environment .PHONY: venv @@ -59,3 +60,11 @@ update: test: source .venv/bin/activate && \ pytest -vx . + +# ============================================================================== +# format code +# ============================================================================== + +.PHONY: format +format: + cargo fmt From 5365d57f88a5cb3b0c3354710990f7c2c988207c Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 15 Jan 2024 03:46:00 -0600 Subject: [PATCH 20/28] chore: updated test_count_y_values --- src/scoring.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/scoring.rs b/src/scoring.rs index 3faff29..c310b4f 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -162,19 +162,19 @@ mod tests { let counts = count_y_values(&s); // assert that counts is a series with one value let exp: Vec = vec![8]; - assert_eq!(counts, Series::new("counts", exp)); + assert_eq!(counts, Series::new("count", exp)); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; let s = Series::new("y", values); let counts = count_y_values(&s); let exp: Vec = vec![4, 4]; - assert_eq!(counts, Series::new("counts", exp)); + assert_eq!(counts, Series::new("count", exp)); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; let s = Series::new("y", values); let counts = count_y_values(&s); let exp: Vec = vec![1, 1, 1, 1, 1, 1, 1, 1]; - assert_eq!(counts, Series::new("counts", exp)); + assert_eq!(counts, Series::new("count", exp)); } // test calc_probabilities From 91189435dfea0bd8320763a2783f980b704ee8c4 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 15 Jan 2024 03:47:37 -0600 Subject: [PATCH 21/28] fix: predict_with_tree now loops rows instead of columns :P --- src/decisiontree.rs | 152 +++++++++++++++++++++++++++++++++++++------- src/lib.rs | 31 +++++---- 2 files changed, 149 insertions(+), 34 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index 
81fe155..a26daaa 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -1,11 +1,14 @@ -use polars::prelude::*; +use polars::{ + lazy::dsl::GetOutput, + prelude::*, +}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use uuid::Uuid; use crate::{scoring, utils::TreeGrowthParameters}; -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Debug, Clone)] pub struct SplitScore { pub name: String, pub score: f64, @@ -17,7 +20,7 @@ impl SplitScore { } } -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Debug, Clone)] pub struct Node { pub array_column: Option, pub threshold: Option, @@ -228,29 +231,63 @@ pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { node.prediction.unwrap() } -pub fn predict_with_tree(x: &DataFrame, tree: &Node) -> Series { +pub fn udf<'a, 'b>( + s: Series, + n_cols: &'a usize, + tree: &'b Node, +) -> Result, PolarsError> { + let mut preds: Vec = vec![]; + + for struct_ in s.iter() { + let mut row: Vec = vec![]; + let mut iter = struct_._iter_struct_av(); + for _ in 0..*n_cols { + let value = iter.next().unwrap().try_extract::().unwrap(); + row.push(value); + } + let row = Series::new("", row); + let prediction = predict_for_row_with_tree(&row, tree); + preds.push(prediction); + } + + Ok(Some(Series::new("predictions", preds))) +} + +pub fn predict_with_tree(x: DataFrame, tree: Node) -> Series { // use polars to apply predict_for_row_with_tree to get one prediction per row - let predictions: Series = x - .iter() - .map(|row| predict_for_row_with_tree(row, tree)) - .collect(); - let predictions = Series::new("predictions", predictions); + let mut columns: Vec = vec![]; + let column_names = x.get_column_names(); + for v in column_names { + columns.push(col(v)); + } + let n_cols: usize = columns.len(); + + let predictions = x + .lazy() + .select([as_struct(columns) + .apply( + move |s| udf(s, &n_cols, &tree), + GetOutput::from_type(DataType::Float64), + ) + .alias("predictions")]) + .collect() + .unwrap(); - predictions + 
predictions.select_series(&["predictions"]).unwrap()[0].clone() } -pub struct DecisionTreeTemplate { +pub struct DecisionTreeCore { pub growth_params: TreeGrowthParameters, tree: Option, } -impl DecisionTreeTemplate { +impl DecisionTreeCore { pub fn new(max_depth: usize) -> Self { let growth_params = TreeGrowthParameters { max_depth: Some(max_depth), }; - DecisionTreeTemplate { + DecisionTreeCore { growth_params, tree: None, } @@ -261,13 +298,57 @@ impl DecisionTreeTemplate { } pub fn predict(&self, x: &DataFrame) -> Series { - match &self.tree { + let x = x.clone(); + let tree_ = self.tree.clone(); + match tree_ { Some(tree) => predict_with_tree(x, tree), None => panic!("Something went wrong. The tree is not initialized."), } } } +pub struct DecisionTreeClassifier { + decision_tree_core: DecisionTreeCore, +} + +impl DecisionTreeClassifier { + pub fn new(max_depth: usize) -> Self { + DecisionTreeClassifier { + decision_tree_core: DecisionTreeCore::new(max_depth), + } + } + + pub fn fit(&mut self, x: &DataFrame, y: &Series) { + self.decision_tree_core.fit(x, y); + } + + pub fn predict_proba(&self, x: &DataFrame) -> DataFrame { + println!("predict_proba for {:?}", x.shape()); + let class1 = self.decision_tree_core.predict(x); + println!("class1 {:?}", class1.len()); + let y_proba: DataFrame = df!("class_1" => &class1) + .unwrap() + .lazy() + .with_columns([(lit(1.) 
- col("class_1")).alias("class_0")]) + .collect() + .unwrap(); + let y_proba = y_proba.select(&["class_0", "class_1"]).unwrap(); + y_proba + } + + pub fn predict(&self, x: &DataFrame) -> Series { + let y_proba = self.predict_proba(x); + // define "y" as a Series that contains the index of the maximum value column per row + let y = y_proba + .lazy() + .select([(col("class_1").gt(0.5)).alias("y")]) + .collect() + .unwrap(); + + y.select_series(&["y"]).unwrap()[0].clone() + } +} + #[cfg(test)] mod tests { // use rand_chacha::ChaCha20Rng; @@ -455,9 +536,27 @@ mod tests { assert_eq!(prediction, 1.0); } - // test predict_with_tree #[test] fn test_predict_with_tree() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3, 4]), + Series::new("b", &[1, 2, 3, 4]), + Series::new("c", &[1, 2, 3, 4]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1, 1]); + let growth_params = TreeGrowthParameters { max_depth: Some(2) }; + let tree = grow_tree(&df, &y, &growth_params, None, 0); + + let predictions = predict_with_tree(df, tree); + assert_eq!( + predictions, + Series::new("predictions", &[1.0, 1.0, 1.0, 1.0]) + ); + } + + #[test] + fn test_decision_tree_core() { let df = DataFrame::new(vec![ Series::new("a", &[1, 2, 3]), Series::new("b", &[1, 2, 3]), @@ -465,16 +564,15 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 1, 1]); - let growth_params = TreeGrowthParameters { max_depth: Some(2) }; - let tree = grow_tree(&df, &y, &growth_params, None, 0); - let predictions = predict_with_tree(&df, &tree); + let mut dtree = DecisionTreeCore::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); } - // test DecisionTreeTemplate #[test] - fn test_decision_tree_template() { + fn test_decision_tree_classifier() { let df = DataFrame::new(vec![ Series::new("a", &[1, 2, 3]), Series::new("b", &[1, 2, 3]), @@ -483,9 +581,19 @@ mod tests { .unwrap(); let y = Series::new("y", &[1, 1, 1]); - 
let mut dtree = DecisionTreeTemplate::new(2); + let mut dtree = DecisionTreeClassifier::new(2); dtree.fit(&df, &y); let predictions = dtree.predict(&df); - assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); + assert_eq!(predictions, Series::new("y", &[1, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (3, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); } } diff --git a/src/lib.rs b/src/lib.rs index 981e054..0fde80f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ mod utils; #[pyo3(name = "_rust")] fn random_tree_models(py: Python<'_>, m: &PyModule) -> PyResult<()> { register_scoring_module(py, m)?; - m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -23,34 +23,41 @@ fn register_scoring_module(py: Python<'_>, parent_module: &PyModule) -> PyResult } #[pyclass] -struct DecisionTree { +struct DecisionTreeClassifier { max_depth: usize, - tree_: Option, + tree_: Option, } #[pymethods] -impl DecisionTree { +impl DecisionTreeClassifier { #[new] fn new(max_depth: usize) -> Self { - DecisionTree { + DecisionTreeClassifier { max_depth, tree_: None, } } - fn fit(&mut self, X: PyDataFrame, y: PySeries) -> PyResult<()> { - let mut tree = decisiontree::DecisionTreeTemplate::new(self.max_depth); - let X: DataFrame = X.into(); + fn fit(&mut self, x: PyDataFrame, y: PySeries) -> PyResult<()> { + let mut tree = decisiontree::DecisionTreeClassifier::new(self.max_depth); + let x: DataFrame = x.into(); let y: Series = y.into(); - tree.fit(&X, &y); + tree.fit(&x, &y); self.tree_ = Some(tree); Ok(()) } - fn predict(&self, X: PyDataFrame) -> PyResult { - let X: DataFrame = X.into(); - let y_pred = self.tree_.as_ref().unwrap().predict(&X); + fn predict(&self, x: PyDataFrame) -> PyResult 
{ + let x: DataFrame = x.into(); + let y_pred = self.tree_.as_ref().unwrap().predict(&x); Ok(PySeries(y_pred)) } + + fn predict_proba(&self, x: PyDataFrame) -> PyResult { + let x: DataFrame = x.into(); + let y_pred = self.tree_.as_ref().unwrap().predict_proba(&x); + + Ok(PyDataFrame(y_pred)) + } } From d3f3adc41317aefa08e07951ff0b2ba01b95f145 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Mon, 15 Jan 2024 04:41:20 -0600 Subject: [PATCH 22/28] feat: added pub fn find_best_split --- src/decisiontree.rs | 149 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 124 insertions(+), 25 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index a26daaa..ce7a7a6 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -1,7 +1,4 @@ -use polars::{ - lazy::dsl::GetOutput, - prelude::*, -}; +use polars::{lazy::dsl::GetOutput, prelude::*}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; use uuid::Uuid; @@ -22,7 +19,8 @@ impl SplitScore { #[derive(PartialEq, Debug, Clone)] pub struct Node { - pub array_column: Option, + pub column: Option, + pub column_idx: Option, pub threshold: Option, pub prediction: Option, pub default_is_left: Option, @@ -42,7 +40,8 @@ pub struct Node { impl Node { pub fn new( - array_column: Option, + column: Option, + column_idx: Option, threshold: Option, prediction: Option, default_is_left: Option, @@ -55,7 +54,8 @@ impl Node { ) -> Self { let node_id = Uuid::new_v4(); Node { - array_column, + column, + column_idx, threshold, prediction, default_is_left, @@ -142,6 +142,91 @@ pub fn calc_leaf_weight_and_split_score( (leaf_weight, score) } +pub struct BestSplit { + pub score: f64, + pub column: String, + pub column_idx: usize, + pub threshold: f64, + pub target_groups: Series, + pub default_is_left: Option, +} + +impl BestSplit { + pub fn new( + score: f64, + column: String, + column_idx: usize, + threshold: f64, + target_groups: Series, + default_is_left: Option, + ) -> Self 
{ + BestSplit { + score, + column, + column_idx, + threshold, + target_groups, + default_is_left, + } + } +} + +pub fn find_best_split( + x: &DataFrame, + y: &Series, + growth_params: &TreeGrowthParameters, + g: Option<&Series>, + h: Option<&Series>, + incrementing_score: Option, +) -> BestSplit { + if y.len() <= 1 { + panic!("Something went wrong. The parent_node handed down less than two data points.") + } + + let mut best_split: Option = None; + + for (idx, col) in x.get_column_names().iter().enumerate() { + let mut feature_values = x.select_series(&[col]).unwrap()[0] + .clone() + .cast(&DataType::Float64) + .unwrap(); + feature_values = feature_values.sort(false); + + for value in feature_values.iter() { + let value: f64 = value.try_extract().unwrap(); + let target_groups = feature_values.lt(value).unwrap(); + let target_groups = Series::new("target_groups", target_groups); + + let score = + scoring::calc_score(y, &target_groups, growth_params, g, h, incrementing_score); + + match best_split { + Some(ref mut best_split) => { + if score < best_split.score { + best_split.score = score; + best_split.column = col.to_string(); + best_split.threshold = value; + best_split.target_groups = target_groups; + best_split.default_is_left = None; + } + } + None => { + best_split = Some(BestSplit::new( + score, + col.to_string(), + idx, + value, + target_groups, + None, + )); + } + } + } + } + + best_split.unwrap() +} + // Inspirations: // * https://rusty-ferris.pages.dev/blog/binary-tree-sum-of-values/ // * https://gist.github.com/aidanhs/5ac9088ca0f6bdd4a370 @@ -158,39 +243,45 @@ pub fn grow_tree( } let (is_baselevel, reason) = check_is_baselevel(y, depth, growth_params); + + let (leaf_weight, score) = calc_leaf_weight_and_split_score(y, growth_params, None, None, None); + if is_baselevel { let new_node = Node::new( None, None, - Some(1.0), None, + Some(leaf_weight), None, None, - Some(SplitScore::new("score".to_string(), 0.5)), - 10, + None, + Some(score), + n_obs, reason, 
- 0, + depth, ); return new_node; } - // let leaf_weight = 1.0; - let (leaf_weight, score) = calc_leaf_weight_and_split_score(y, growth_params, None, None, None); - // find best split - let mut rng = ChaCha20Rng::seed_from_u64(42); + let best = find_best_split(x, y, growth_params, None, None, None); + // let mut rng = ChaCha20Rng::seed_from_u64(42); let mut new_node = Node::new( - Some(0), - Some(0.0), + Some(best.column), + Some(best.column_idx), + Some(best.threshold), Some(leaf_weight), - Some(true), + match best.default_is_left { + Some(default_is_left) => Some(default_is_left), + None => None, + }, None, None, - Some(SplitScore::new("score".to_string(), 0.5)), - 10, + Some(SplitScore::new("neg_entropy".to_string(), best.score)), + n_obs, "leaf node".to_string(), - 0, + depth, ); // check if improvement due to split is below minimum requirement @@ -213,8 +304,8 @@ pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { let row = row_f64.f64().unwrap(); while !node.is_leaf() { - let col = node.array_column.unwrap(); - let value: f64 = row.get(col).expect("Accessing failed."); + let idx = node.column_idx.unwrap(); + let value: f64 = row.get(idx).expect("Accessing failed."); let threshold = node.threshold.unwrap(); let is_left = if value < threshold { @@ -366,6 +457,7 @@ mod tests { #[test] fn test_node_init() { let node = Node::new( + Some("column".to_string()), Some(0), Some(0.0), Some(1.0), @@ -377,7 +469,8 @@ mod tests { "leaf node".to_string(), 0, ); - assert_eq!(node.array_column.unwrap(), 0); + assert_eq!(node.column.unwrap(), "column".to_string()); + assert_eq!(node.column_idx.unwrap(), 0); assert_eq!(node.threshold.unwrap(), 0.0); assert_eq!(node.prediction.unwrap(), 1.0); assert_eq!(node.default_is_left.unwrap(), true); @@ -394,6 +487,7 @@ mod tests { #[test] fn test_child_node_assignment() { let mut node = Node::new( + Some("column".to_string()), Some(0), Some(0.0), Some(1.0), @@ -406,6 +500,7 @@ mod tests { 0, ); let child_node = 
Node::new( + Some("column".to_string()), Some(0), Some(0.0), Some(1.0), @@ -425,6 +520,7 @@ mod tests { #[test] fn test_grandchild_node_assignment() { let mut node = Node::new( + Some("column".to_string()), Some(0), Some(0.0), Some(1.0), @@ -437,6 +533,7 @@ mod tests { 0, ); let child_node = Node::new( + Some("column".to_string()), Some(0), Some(0.0), Some(1.0), @@ -449,6 +546,7 @@ mod tests { 0, ); let grandchild_node = Node::new( + Some("column".to_string()), Some(0), Some(0.0), Some(1.0), @@ -471,7 +569,8 @@ mod tests { #[test] fn test_node_is_leaf() { let node = Node { - array_column: Some(0), + column: Some("column".to_string()), + column_idx: Some(0), threshold: Some(0.0), prediction: Some(1.0), default_is_left: Some(true), From 9eed72155dd65c63cd4fb9f1692dd39eb0ecd436 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Tue, 16 Jan 2024 13:14:20 -0600 Subject: [PATCH 23/28] fix: debugged version of rust DecisionTreeClassifier and new tests --- nbs/decision-tree-rs.ipynb | 376 +++++++++++++++++++++++++++++++++++++ src/decisiontree.rs | 196 ++++++++++++++++--- src/lib.rs | 2 + src/scoring.rs | 15 +- 4 files changed, 555 insertions(+), 34 deletions(-) create mode 100644 nbs/decision-tree-rs.ipynb diff --git a/nbs/decision-tree-rs.ipynb b/nbs/decision-tree-rs.ipynb new file mode 100644 index 0000000..96ca7cf --- /dev/null +++ b/nbs/decision-tree-rs.ipynb @@ -0,0 +1,376 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Decision tree" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "* https://medium.com/@penggongting/implementing-decision-tree-from-scratch-in-python-c732e7c69aea\n", + "* https://www.kdnuggets.com/2020/01/decision-tree-algorithm-explained.html" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The core algorithm aka the CART 
algorithm\n", + "\n", + "CART = Classification And Regression Tree\n", + "\n", + "Starting with a tabular dataset we have columns / features and rows / observations. Each row has a target value, of which either all are continuous or categorical. \n", + "\n", + "Taking a subset of the observations as a training set, the algorithm iterates:\n", + "\n", + "1. select a feature\n", + "2. select a range of thresholds (e.g. the feature values in the training set) \n", + "3. for each threshold\n", + " * create two groups of observations, one below the threshold and one above and \n", + " * evaluate the split score\n", + "4. select the threshold with the optimal split score (here that always means largest)\n", + "5. select the related group split \n", + "6. continue from 1. for each group whose target values are not yet homogeneous (e.g. not all the same class, or the standard deviation is greater than zero)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "import sklearn.datasets as sk_datasets\n", + "\n", + "# import random_tree_models.decisiontree as dtree\n", + "import random_tree_models._rust as rust\n", + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.RandomState(42)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Classification\n", + "\n", + "split score:\n", + "* gini\n", + "* entropy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X, y = sk_datasets.make_classification(\n", + " n_samples=1_000,\n", + " n_features=2,\n", + " 
n_classes=2,\n", + " n_redundant=0,\n", + " class_sep=2,\n", + " random_state=rng,\n", + ")\n", + "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, alpha=0.3);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = rust.DecisionTreeClassifier(max_depth=4) # measure_name=\"gini\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X = pl.from_numpy(X)\n", + "y = pl.from_numpy(y).to_series()\n", + "display(X.head(2), y.head(2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# dtree.show_tree(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_prob = model.predict_proba(X)\n", + "y_prob.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x0 = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)\n", + "x1 = np.linspace(X[:, 1].min(), X[:, 1].max(), 100)\n", + "X0, X1 = np.meshgrid(x0, x1)\n", + "X_plot = np.array([X0.ravel(), X1.ravel()]).T\n", + "X_plot = pl.from_numpy(X_plot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_prob = model.predict_proba(X_plot)\n", + "y_prob = y_prob.select(\"class_1\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# polars series to numpy array\n", + "y_prob = y_prob.to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_prob" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO - CONTINUE HERE: why is column_1 not used for prediction?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "im = ax.pcolormesh(X0, X1, y_prob.reshape(X0.shape), alpha=0.2)\n", + "fig.colorbar(im, ax=ax)\n", + "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=ax, alpha=0.3)\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Regression\n", + "\n", + "split score:\n", + "\n", + "* variance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X, y, coefs = sk_datasets.make_regression(\n", + " n_samples=1_000, n_features=2, n_targets=1, coef=True, random_state=rng\n", + ")\n", + "sns.scatterplot(x=X[:, 0], y=y, alpha=0.3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = rust.DecisionTreeRegressor(measure_name=\"variance\", max_depth=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# dtree.show_tree(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x0 = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)\n", + "x1 = np.linspace(X[:, 1].min(), X[:, 1].max(), 100)\n", + "X0, X1 = np.meshgrid(x0, x1)\n", + "X_plot = np.array([X0.ravel(), X1.ravel()]).T" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = model.predict(X_plot)\n", + "y_pred[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": 
[], + "source": [ + "fig, axs = plt.subplots(nrows=2, figsize=(12, 6))\n", + "\n", + "ax = axs[0]\n", + "sns.scatterplot(x=X_plot[:, 0], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "\n", + "ax = axs[1]\n", + "sns.scatterplot(x=X_plot[:, 1], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "im = ax.pcolormesh(X0, X1, y_pred.reshape(X0.shape), alpha=0.2)\n", + "fig.colorbar(im, ax=ax)\n", + "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=ax, alpha=0.3)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = model.predict(X)\n", + "\n", + "fig, axs = plt.subplots(nrows=2, figsize=(12, 6))\n", + "\n", + "ax = axs[0]\n", + "sns.scatterplot(x=X[:, 0], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "sns.scatterplot(x=X[:, 0], y=y, ax=ax, alpha=0.1, label=\"actual\")\n", + "\n", + "ax = axs[1]\n", + "sns.scatterplot(x=X[:, 1], y=y_pred, ax=ax, alpha=0.1, label=\"prediction\")\n", + "sns.scatterplot(x=X[:, 1], y=y, ax=ax, alpha=0.1, label=\"actual\")\n", + "\n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/decisiontree.rs b/src/decisiontree.rs index ce7a7a6..2965577 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -1,6 +1,6 @@ use polars::{lazy::dsl::GetOutput, prelude::*}; -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; +// use rand::SeedableRng; +// use 
rand_chacha::ChaCha20Rng; use uuid::Uuid; use crate::{scoring, utils::TreeGrowthParameters}; @@ -117,9 +117,9 @@ pub fn check_is_baselevel( pub fn calc_leaf_weight( y: &Series, - growth_params: &TreeGrowthParameters, - g: Option<&Series>, - h: Option<&Series>, + _growth_params: &TreeGrowthParameters, + _g: Option<&Series>, + _h: Option<&Series>, ) -> f64 { let leaf_weight = y.mean().unwrap(); @@ -142,6 +142,7 @@ pub fn calc_leaf_weight_and_split_score( (leaf_weight, score) } +#[derive(PartialEq, Debug, Clone)] pub struct BestSplit { pub score: f64, pub column: String, @@ -182,28 +183,40 @@ pub fn find_best_split( if y.len() <= 1 { panic!("Something went wrong. The parent_node handed down less than two data points.") } - + // TODO: handle case where there are duplicates + // TODO: handle case where there is only one unique y value but multiple non-duplicate x values let mut best_split: Option = None; - + // println!("finding best split"); for (idx, col) in x.get_column_names().iter().enumerate() { - let mut feature_values = x.select_series(&[col]).unwrap()[0] + // println!("col {:?}, idx {:?}", col, idx); + let feature_values = x.select_series(&[col]).unwrap()[0] .clone() .cast(&DataType::Float64) .unwrap(); - feature_values = feature_values.sort(false); - for value in feature_values.iter() { + let unique_values = feature_values.unique().unwrap().sort(false); + + // skip the below if there is only one unique value + if unique_values.len() == 1 { + continue; + } + + let mut unique_iter = unique_values.iter(); + unique_iter.next().unwrap(); // skipping the first value + for value in unique_iter { let value: f64 = value.try_extract().unwrap(); + // println!("value {:?}", value); let target_groups = feature_values.lt(value).unwrap(); let target_groups = Series::new("target_groups", target_groups); let score = scoring::calc_score(y, &target_groups, growth_params, g, h, incrementing_score); - + // println!("score {:?}", score); match best_split { Some(ref mut best_split) 
=> { - if score < best_split.score { + if score > best_split.score { best_split.score = score; + best_split.column_idx = idx; best_split.column = col.to_string(); best_split.threshold = value; best_split.target_groups = target_groups; @@ -221,12 +234,42 @@ pub fn find_best_split( )); } } + // println!("best_split {:?}", best_split); } } best_split.unwrap() } +pub fn select_arrays_for_child_node( + go_left: bool, + best: &BestSplit, + x: DataFrame, + y: Series, +) -> (DataFrame, Series) { + if go_left { + let x_child = x + .clone() + .filter(&best.target_groups.bool().unwrap()) + .unwrap(); + let y_child = y + .clone() + .filter(&best.target_groups.bool().unwrap()) + .unwrap(); + return (x_child, y_child); + } else { + let x_child = x + .clone() + .filter(&!best.target_groups.bool().unwrap()) + .unwrap(); + let y_child = y + .clone() + .filter(&!best.target_groups.bool().unwrap()) + .unwrap(); + return (x_child, y_child); + } +} + // Inspirations: // * https://rusty-ferris.pages.dev/blog/binary-tree-sum-of-values/ // * https://gist.github.com/aidanhs/5ac9088ca0f6bdd4a370 @@ -234,10 +277,12 @@ pub fn grow_tree( x: &DataFrame, y: &Series, growth_params: &TreeGrowthParameters, - parent_node: Option<&Node>, + _parent_node: Option<&Node>, depth: usize, ) -> Node { let n_obs = x.height(); + // println!("\nn_obs {:?}", n_obs); + // println!("depth {:?}", depth); if n_obs == 0 { panic!("Something went wrong. 
The parent_node handed down an empty set of data points.") } @@ -245,7 +290,7 @@ pub fn grow_tree( let (is_baselevel, reason) = check_is_baselevel(y, depth, growth_params); let (leaf_weight, score) = calc_leaf_weight_and_split_score(y, growth_params, None, None, None); - + // println!("leaf_weight {:?}", leaf_weight); if is_baselevel { let new_node = Node::new( None, @@ -265,20 +310,26 @@ pub fn grow_tree( // find best split let best = find_best_split(x, y, growth_params, None, None, None); + // println!("column {:?}", best.column); + // println!("column_idx {:?}", best.column_idx); + // println!("threshold {:?}", best.threshold); + // println!("target_groups {:?}", best.target_groups); + // let mut rng = ChaCha20Rng::seed_from_u64(42); + let best_ = best.clone(); let mut new_node = Node::new( - Some(best.column), - Some(best.column_idx), - Some(best.threshold), + Some(best_.column), + Some(best_.column_idx), + Some(best_.threshold), Some(leaf_weight), - match best.default_is_left { + match best_.default_is_left { Some(default_is_left) => Some(default_is_left), None => None, }, None, None, - Some(SplitScore::new("neg_entropy".to_string(), best.score)), + Some(SplitScore::new("neg_entropy".to_string(), best_.score)), n_obs, "leaf node".to_string(), depth, @@ -287,11 +338,23 @@ pub fn grow_tree( // check if improvement due to split is below minimum requirement // descend left - let new_left_node = grow_tree(x, y, growth_params, Some(&new_node), &depth + 1); // mut new_node, + let (x_left, y_left) = select_arrays_for_child_node(true, &best, x.clone(), y.clone()); + // println!("x_left {:?}", x_left); + // println!("y_left {:?}", y_left); + let new_left_node = grow_tree(&x_left, &y_left, growth_params, Some(&new_node), &depth + 1); // mut new_node, new_node.insert(new_left_node, true); // descend right - let new_right_node = grow_tree(x, y, growth_params, Some(&new_node), depth + 1); // mut new_node, + let (x_right, y_right) = select_arrays_for_child_node(false, &best, 
x.clone(), y.clone()); + // println!("x_right {:?}", x_right); + // println!("y_right {:?}", y_right); + let new_right_node = grow_tree( + &x_right, + &y_right, + growth_params, + Some(&new_node), + depth + 1, + ); // mut new_node, new_node.insert(new_right_node, false); return new_node; @@ -308,11 +371,13 @@ pub fn predict_for_row_with_tree(row: &Series, tree: &Node) -> f64 { let value: f64 = row.get(idx).expect("Accessing failed."); let threshold = node.threshold.unwrap(); - let is_left = if value < threshold { - node.default_is_left.unwrap() - } else { - !node.default_is_left.unwrap() - }; + let is_left = value < threshold; + // println!("idx {:?} value {:?} threshold {:?} is_left {:?}", idx, value, threshold, is_left); + // let is_left = if value < threshold { + // node.default_is_left.unwrap() + // } else { + // !node.default_is_left.unwrap() + // }; if is_left { node = node.left.as_ref().unwrap(); } else { @@ -337,7 +402,9 @@ pub fn udf<'a, 'b>( row.push(value); } let row = Series::new("", row); + // println!("\nrow {:?}", row); let prediction = predict_for_row_with_tree(&row, tree); + // println!("prediction {:?}", prediction); preds.push(prediction); } @@ -414,9 +481,9 @@ impl DecisionTreeClassifier { } pub fn predict_proba(&self, x: &DataFrame) -> DataFrame { - println!("predict_proba for {:?}", x.shape()); + // println!("predict_proba for {:?}", x.shape()); let class1 = self.decision_tree_core.predict(x); - println!("class1 {:?}", class1.len()); + // println!("class1 {:?}", class1.len()); let y_proba: DataFrame = df!("class_1" => &class1) .unwrap() .lazy() @@ -695,4 +762,77 @@ mod tests { .unwrap(); assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); } + + #[test] + fn test_decision_tree_classifier_1d() { + let df = DataFrame::new(vec![Series::new("a", &[1, 2, 3])]).unwrap(); + let y = Series::new("y", &[0, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + 
assert_eq!(predictions, Series::new("y", &[0, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (3, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_classifier_2d_case1() { + // is given two columns but only needs one + let df = DataFrame::new(vec![ + Series::new("a", &[1, 1, 1]), + Series::new("b", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[0, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (3, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0])); + } + + #[test] + fn test_decision_tree_classifier_2d_case2() { + // is given two columns and needs both + let df = DataFrame::new(vec![ + Series::new("a", &[-1, 1, -1, 1]), + Series::new("b", &[-1, -1, 1, 1]), + ]) + .unwrap(); + let y = Series::new("y", &[0, 1, 1, 1]); + + let mut dtree = DecisionTreeClassifier::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1, 1])); + + let y_proba = dtree.predict_proba(&df); + assert_eq!(y_proba.shape(), (4, 2)); + assert_eq!(y_proba.get_column_names(), &["class_0", "class_1"]); + // assert that y_proba sums to 1 per row + let y_proba_sum = y_proba + .sum_horizontal(polars::frame::NullStrategy::Propagate) + .unwrap() + .unwrap(); + 
assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0, 1.0])); + } } diff --git a/src/lib.rs b/src/lib.rs index 0fde80f..fe296d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,3 +61,5 @@ impl DecisionTreeClassifier { Ok(PyDataFrame(y_pred)) } } + +// TODO: implement DecisionTreeRegressor diff --git a/src/scoring.rs b/src/scoring.rs index c310b4f..6370f1f 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -99,8 +99,8 @@ pub fn entropy_rs(y: &Series, target_groups: &Series) -> f64 { let trues = Series::new("", vec![true; target_groups.len()]); let target_groups = target_groups.equal(&trues).unwrap(); - let mut entropy_left = 0.0; - let mut entropy_right = 0.0; + let entropy_left: f64; + let entropy_right: f64; if w_left > 0. { let y_left = y.filter(&target_groups).unwrap(); let probs = calc_probabilities(&y_left); @@ -122,19 +122,22 @@ pub fn entropy_rs(y: &Series, target_groups: &Series) -> f64 { pub fn calc_score( y: &Series, target_groups: &Series, - growth_params: &TreeGrowthParameters, - g: Option<&Series>, - h: Option<&Series>, - incrementing_score: Option, + _growth_params: &TreeGrowthParameters, + _g: Option<&Series>, + _h: Option<&Series>, + _incrementing_score: Option, ) -> f64 { let score = entropy_rs(y, target_groups); score } +#[cfg(test)] mod tests { + use super::*; // test that gini impurity correctly computes values smaller than zero for a couple of vectors + #[test] fn test_gini_impurity() { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; From 070374bdbe0146e025b710d30c7cc31512f838d0 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Fri, 19 Jan 2024 13:08:08 -0600 Subject: [PATCH 24/28] chore: added .envrc to ignores --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 68bc17f..486fe26 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ celerybeat.pid # Environments .env +.envrc .venv env/ venv/ From 8472ba540a0fbfca26d641a879315a24499def24 Mon Sep 
17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Sat, 20 Jan 2024 07:53:07 -0600 Subject: [PATCH 25/28] feat: added enum SplitScoreMetrics and extended TreeGrowthParameters --- src/decisiontree.rs | 26 +++++++++++++++++++++----- src/scoring.rs | 27 +++++++++++++++++++++------ src/utils.rs | 6 ++++++ 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index 2965577..0a7d6c4 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -3,7 +3,10 @@ use polars::{lazy::dsl::GetOutput, prelude::*}; // use rand_chacha::ChaCha20Rng; use uuid::Uuid; -use crate::{scoring, utils::TreeGrowthParameters}; +use crate::{ + scoring, + utils::{SplitScoreMetrics, TreeGrowthParameters}, +}; #[derive(PartialEq, Debug, Clone)] pub struct SplitScore { @@ -444,6 +447,7 @@ impl DecisionTreeCore { pub fn new(max_depth: usize) -> Self { let growth_params = TreeGrowthParameters { max_depth: Some(max_depth), + split_score_metric: Some(SplitScoreMetrics::Entropy), }; DecisionTreeCore { growth_params, @@ -656,7 +660,10 @@ mod tests { #[test] fn test_calc_leaf_weight_and_split_score() { let y = Series::new("y", &[1, 1, 1]); - let growth_params = TreeGrowthParameters { max_depth: Some(2) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(2), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let (leaf_weight, score) = calc_leaf_weight_and_split_score(&y, &growth_params, None, None, None); assert_eq!(leaf_weight, 1.0); @@ -673,7 +680,10 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 1, 2]); - let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -693,7 +703,10 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 1, 1]); - let growth_params = TreeGrowthParameters { 
max_depth: Some(2) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(2), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -711,7 +724,10 @@ mod tests { ]) .unwrap(); let y = Series::new("y", &[1, 1, 1, 1]); - let growth_params = TreeGrowthParameters { max_depth: Some(2) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(2), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let tree = grow_tree(&df, &y, &growth_params, None, 0); let predictions = predict_with_tree(df, tree); diff --git a/src/scoring.rs b/src/scoring.rs index 6370f1f..13765e0 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -4,7 +4,7 @@ use polars::prelude::*; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use crate::utils::TreeGrowthParameters; +use crate::utils::{SplitScoreMetrics, TreeGrowthParameters}; // compute gini impurity of an array of discrete values #[pyfunction(name = "gini_impurity")] @@ -231,21 +231,30 @@ mod tests { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); - let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, 0.0); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); - let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, -1.0); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; let s = Series::new("y", values); let 
target_groups = Series::new("target_groups", vec![true; 8]); - let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, -3.0); @@ -255,7 +264,10 @@ mod tests { "target_groups", vec![true, true, true, true, false, false, false, false], ); - let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, 0.0); @@ -265,7 +277,10 @@ mod tests { "target_groups", vec![true, true, true, true, false, false, false, false], ); - let growth_params = TreeGrowthParameters { max_depth: Some(1) }; + let growth_params = TreeGrowthParameters { + max_depth: Some(1), + split_score_metric: Some(SplitScoreMetrics::Entropy), + }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, -1.0); } diff --git a/src/utils.rs b/src/utils.rs index 82b4485..5d69a04 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,9 @@ +pub enum SplitScoreMetrics { + Variance, + Entropy, +} + pub struct TreeGrowthParameters { pub max_depth: Option, + pub split_score_metric: Option, } From 8b5f76bab20851a24c1cf4dab43d4ba4f0595133 Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Sat, 20 Jan 2024 08:50:21 -0600 Subject: [PATCH 26/28] feat: added DecisionTreeRegressor --- src/decisiontree.rs | 116 +++++++++++++++++++++++++++++++++++++----- src/lib.rs | 34 ++++++++++++- src/scoring.rs | 121 ++++++++++++++++++++++++++++++++++++++------ src/utils.rs | 15 +++++- 4 files changed, 254 insertions(+), 32 deletions(-) diff --git a/src/decisiontree.rs b/src/decisiontree.rs index 
0a7d6c4..d9ebeb2 100644 --- a/src/decisiontree.rs +++ b/src/decisiontree.rs @@ -120,13 +120,15 @@ pub fn check_is_baselevel( pub fn calc_leaf_weight( y: &Series, - _growth_params: &TreeGrowthParameters, + growth_params: &TreeGrowthParameters, _g: Option<&Series>, _h: Option<&Series>, ) -> f64 { - let leaf_weight = y.mean().unwrap(); - - leaf_weight + match growth_params.split_score_metric { + Some(SplitScoreMetrics::NegEntropy) => y.mean().unwrap(), + Some(SplitScoreMetrics::NegVariance) => y.mean().unwrap(), + _ => panic!("Something went wrong. The split_score_metric is not defined."), + } } pub fn calc_leaf_weight_and_split_score( @@ -140,7 +142,12 @@ pub fn calc_leaf_weight_and_split_score( let target_groups: Series = Series::new("target_groups", vec![true; y.len()]); let score = scoring::calc_score(y, &target_groups, growth_params, g, h, incrementing_score); - let score = SplitScore::new("neg_entropy".to_string(), score); + let name = match &growth_params.split_score_metric { + Some(metric) => metric.to_string(), + _ => panic!("Something went wrong. 
The split_score_metric is not defined."), + }; + + let score = SplitScore::new(name, score); (leaf_weight, score) } @@ -444,10 +451,10 @@ pub struct DecisionTreeCore { } impl DecisionTreeCore { - pub fn new(max_depth: usize) -> Self { + pub fn new(max_depth: usize, split_score_metric: SplitScoreMetrics) -> Self { let growth_params = TreeGrowthParameters { max_depth: Some(max_depth), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(split_score_metric), }; DecisionTreeCore { growth_params, @@ -476,7 +483,7 @@ pub struct DecisionTreeClassifier { impl DecisionTreeClassifier { pub fn new(max_depth: usize) -> Self { DecisionTreeClassifier { - decision_tree_core: DecisionTreeCore::new(max_depth), + decision_tree_core: DecisionTreeCore::new(max_depth, SplitScoreMetrics::NegEntropy), } } @@ -511,6 +518,28 @@ impl DecisionTreeClassifier { } } +pub struct DecisionTreeRegressor { + decision_tree_core: DecisionTreeCore, +} + +impl DecisionTreeRegressor { + pub fn new(max_depth: usize) -> Self { + DecisionTreeRegressor { + decision_tree_core: DecisionTreeCore::new(max_depth, SplitScoreMetrics::NegVariance), + } + } + + pub fn fit(&mut self, x: &DataFrame, y: &Series) { + self.decision_tree_core.fit(x, y); + } + + pub fn predict(&self, x: &DataFrame) -> Series { + let y_pred = Series::new("y", self.decision_tree_core.predict(x)); + + y_pred + } +} + #[cfg(test)] mod tests { // use rand_chacha::ChaCha20Rng; @@ -662,12 +691,12 @@ mod tests { let y = Series::new("y", &[1, 1, 1]); let growth_params = TreeGrowthParameters { max_depth: Some(2), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let (leaf_weight, score) = calc_leaf_weight_and_split_score(&y, &growth_params, None, None, None); assert_eq!(leaf_weight, 1.0); - assert_eq!(score.name, "neg_entropy"); + assert_eq!(score.name, "NegEntropy"); assert_eq!(score.score, 0.0); } @@ -682,7 +711,7 @@ mod tests { let y = 
Series::new("y", &[1, 1, 2]); let growth_params = TreeGrowthParameters { max_depth: Some(1), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -705,7 +734,7 @@ mod tests { let y = Series::new("y", &[1, 1, 1]); let growth_params = TreeGrowthParameters { max_depth: Some(2), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -726,7 +755,7 @@ mod tests { let y = Series::new("y", &[1, 1, 1, 1]); let growth_params = TreeGrowthParameters { max_depth: Some(2), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let tree = grow_tree(&df, &y, &growth_params, None, 0); @@ -747,7 +776,7 @@ mod tests { .unwrap(); let y = Series::new("y", &[1, 1, 1]); - let mut dtree = DecisionTreeCore::new(2); + let mut dtree = DecisionTreeCore::new(2, SplitScoreMetrics::NegEntropy); dtree.fit(&df, &y); let predictions = dtree.predict(&df); assert_eq!(predictions, Series::new("predictions", &[1.0, 1.0, 1.0])); @@ -851,4 +880,63 @@ mod tests { .unwrap(); assert_eq!(y_proba_sum, Series::new("class_0", &[1.0, 1.0, 1.0, 1.0])); } + + #[test] + fn test_decision_tree_regressor() { + let df = DataFrame::new(vec![ + Series::new("a", &[1, 2, 3]), + Series::new("b", &[1, 2, 3]), + Series::new("c", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[1, 1, 1]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[1, 1, 1])); + } + + #[test] + fn test_decision_tree_regressor_1d() { + let df = DataFrame::new(vec![Series::new("a", &[1, 2, 3])]).unwrap(); + let y = Series::new("y", &[-1, 1, 1]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = 
dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[-1, 1, 1])); + } + + #[test] + fn test_decision_tree_regressor_2d_case1() { + // is given two columns but only needs one + let df = DataFrame::new(vec![ + Series::new("a", &[1, 1, 1]), + Series::new("b", &[1, 2, 3]), + ]) + .unwrap(); + let y = Series::new("y", &[-1, 1, 1]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[-1, 1, 1])); + } + + #[test] + fn test_decision_tree_regressor_2d_case2() { + // is given two columns and needs both + let df = DataFrame::new(vec![ + Series::new("a", &[-1, 1, -1, 1]), + Series::new("b", &[-1, -1, 1, 1]), + ]) + .unwrap(); + let y = Series::new("y", &[0, 1, 1, 2]); + + let mut dtree = DecisionTreeRegressor::new(2); + dtree.fit(&df, &y); + let predictions = dtree.predict(&df); + assert_eq!(predictions, Series::new("y", &[0, 1, 1, 2])); + } } diff --git a/src/lib.rs b/src/lib.rs index fe296d6..8da91e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ mod utils; fn random_tree_models(py: Python<'_>, m: &PyModule) -> PyResult<()> { register_scoring_module(py, m)?; m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -62,4 +63,35 @@ impl DecisionTreeClassifier { } } -// TODO: implement DecisionTreeRegressor +#[pyclass] +struct DecisionTreeRegressor { + max_depth: usize, + tree_: Option, +} + +#[pymethods] +impl DecisionTreeRegressor { + #[new] + fn new(max_depth: usize) -> Self { + DecisionTreeRegressor { + max_depth, + tree_: None, + } + } + + fn fit(&mut self, x: PyDataFrame, y: PySeries) -> PyResult<()> { + let mut tree = decisiontree::DecisionTreeRegressor::new(self.max_depth); + let x: DataFrame = x.into(); + let y: Series = y.into(); + tree.fit(&x, &y); + self.tree_ = Some(tree); + Ok(()) + } + + fn predict(&self, x: PyDataFrame) -> PyResult { + let x: DataFrame = x.into(); + let y_pred = self.tree_.as_ref().unwrap().predict(&x); + + 
Ok(PySeries(y_pred)) + } +} diff --git a/src/scoring.rs b/src/scoring.rs index 13765e0..aa1cb74 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -85,7 +85,7 @@ pub fn calc_neg_entropy_series(ps: &Series) -> f64 { neg_entropy } -pub fn entropy_rs(y: &Series, target_groups: &Series) -> f64 { +pub fn neg_entropy_rs(y: &Series, target_groups: &Series) -> f64 { let msg = "Could not cast to f64"; let w_left: f64 = (*target_groups) .cast(&polars::datatypes::DataType::Float64) @@ -119,17 +119,68 @@ pub fn entropy_rs(y: &Series, target_groups: &Series) -> f64 { score } +pub fn neg_variance_rs(y: &Series, target_groups: &Series) -> f64 { + let msg = "Could not cast to f64"; + let w_left: f64 = (*target_groups) + .cast(&polars::datatypes::DataType::Float64) + .expect(msg) + .sum::() + .unwrap() + / y.len() as f64; + let w_right: f64 = 1.0 - w_left; + + // generate boolean chunked array of target_groups + let trues = Series::new("", vec![true; target_groups.len()]); + let target_groups = target_groups.equal(&trues).unwrap(); + + let variance_left: f64; + let variance_right: f64; + if w_left > 0. { + let y_left = y.filter(&target_groups).unwrap(); + let ddof_left: u8 = (y_left.len() - 1).try_into().unwrap(); + variance_left = y_left + .var_as_series(ddof_left) + .unwrap() + .f64() + .expect("not f64") + .get(0) + .expect("was null"); + } else { + variance_left = 0.0; + } + if w_right > 0. 
{ + let y_right = y.filter(&!target_groups).unwrap(); + let ddof_right: u8 = (y_right.len() - 1).try_into().unwrap(); + variance_right = y_right + .var_as_series(ddof_right) + .unwrap() + .f64() + .expect("not f64") + .get(0) + .expect("was null"); + } else { + variance_right = 0.0; + } + let score = (w_left * variance_left) + (w_right * variance_right); + -score +} + pub fn calc_score( y: &Series, target_groups: &Series, - _growth_params: &TreeGrowthParameters, + growth_params: &TreeGrowthParameters, _g: Option<&Series>, _h: Option<&Series>, _incrementing_score: Option, ) -> f64 { - let score = entropy_rs(y, target_groups); - - score + match growth_params.split_score_metric { + Some(SplitScoreMetrics::NegEntropy) => neg_entropy_rs(y, target_groups), + Some(SplitScoreMetrics::NegVariance) => neg_variance_rs(y, target_groups), + _ => panic!( + "split_score_metric {:?} not supported", + growth_params.split_score_metric + ), + } } #[cfg(test)] @@ -233,7 +284,7 @@ mod tests { let target_groups = Series::new("target_groups", vec![true; 8]); let growth_params = TreeGrowthParameters { max_depth: Some(1), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, 0.0); @@ -243,7 +294,7 @@ mod tests { let target_groups = Series::new("target_groups", vec![true; 8]); let growth_params = TreeGrowthParameters { max_depth: Some(1), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, -1.0); @@ -253,7 +304,7 @@ mod tests { let target_groups = Series::new("target_groups", vec![true; 8]); let growth_params = TreeGrowthParameters { max_depth: Some(1), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let 
score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, -3.0); @@ -266,7 +317,7 @@ mod tests { ); let growth_params = TreeGrowthParameters { max_depth: Some(1), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, 0.0); @@ -279,7 +330,7 @@ mod tests { ); let growth_params = TreeGrowthParameters { max_depth: Some(1), - split_score_metric: Some(SplitScoreMetrics::Entropy), + split_score_metric: Some(SplitScoreMetrics::NegEntropy), }; let score = calc_score(&s, &target_groups, &growth_params, None, None, None); assert_eq!(score, -1.0); @@ -291,19 +342,19 @@ mod tests { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); - let score = entropy_rs(&s, &target_groups); + let score = neg_entropy_rs(&s, &target_groups); assert_eq!(score, 0.0); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); - let score = entropy_rs(&s, &target_groups); + let score = neg_entropy_rs(&s, &target_groups); assert_eq!(score, -1.0); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); - let score = entropy_rs(&s, &target_groups); + let score = neg_entropy_rs(&s, &target_groups); assert_eq!(score, -3.0); let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; @@ -312,7 +363,7 @@ mod tests { "target_groups", vec![true, true, true, true, false, false, false, false], ); - let score = entropy_rs(&s, &target_groups); + let score = neg_entropy_rs(&s, &target_groups); assert_eq!(score, 0.0); let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; @@ -321,7 +372,47 @@ mod tests { "target_groups", vec![true, true, true, true, false, false, false, false], ); - let score = 
entropy_rs(&s, &target_groups); + let score = neg_entropy_rs(&s, &target_groups); assert_eq!(score, -1.0); } + + // test neg_variance_rs + #[test] + fn test_neg_variance_rs() { + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -2.); + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let s = Series::new("y", values); + let target_groups = Series::new("target_groups", vec![true; 8]); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -42.); + + let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, 0.0); + + let values = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let s = Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -1.); + } } diff --git a/src/utils.rs b/src/utils.rs index 5d69a04..d3907ef 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,17 @@ +use std::fmt; + +#[derive(Debug)] pub enum SplitScoreMetrics { - Variance, - Entropy, + NegVariance, + NegEntropy, +} + +impl fmt::Display for SplitScoreMetrics { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self) + // or, alternatively: + // fmt::Debug::fmt(self, f) + } } pub struct TreeGrowthParameters { From 7e77c486471b88c98f29585e919eb3f288016678 Mon Sep 17 00:00:00 2001 From: eschmidt42 
<11818904+eschmidt42@users.noreply.github.com> Date: Sat, 20 Jan 2024 10:01:31 -0600 Subject: [PATCH 27/28] feat: DecisionTreeRegressor now working in nb --- nbs/decision-tree-rs.ipynb | 29 ++++++++------- src/scoring.rs | 74 +++++++++++++++++++++++++++----------- 2 files changed, 69 insertions(+), 34 deletions(-) diff --git a/nbs/decision-tree-rs.ipynb b/nbs/decision-tree-rs.ipynb index 96ca7cf..1ef5aee 100644 --- a/nbs/decision-tree-rs.ipynb +++ b/nbs/decision-tree-rs.ipynb @@ -167,15 +167,6 @@ "X_plot = pl.from_numpy(X_plot)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "X_plot" - ] - }, { "cell_type": "code", "execution_count": null, @@ -257,7 +248,18 @@ "metadata": {}, "outputs": [], "source": [ - "model = rust.DecisionTreeRegressor(measure_name=\"variance\", max_depth=2)" + "model = rust.DecisionTreeRegressor(max_depth=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X = pl.from_numpy(X)\n", + "y = pl.from_numpy(y).to_series()\n", + "display(X.head(2), y.head(2))" ] }, { @@ -287,7 +289,8 @@ "x0 = np.linspace(X[:, 0].min(), X[:, 0].max(), 100)\n", "x1 = np.linspace(X[:, 1].min(), X[:, 1].max(), 100)\n", "X0, X1 = np.meshgrid(x0, x1)\n", - "X_plot = np.array([X0.ravel(), X1.ravel()]).T" + "X_plot = np.array([X0.ravel(), X1.ravel()]).T\n", + "X_plot = pl.from_numpy(X_plot)" ] }, { @@ -324,7 +327,7 @@ "outputs": [], "source": [ "fig, ax = plt.subplots()\n", - "im = ax.pcolormesh(X0, X1, y_pred.reshape(X0.shape), alpha=0.2)\n", + "im = ax.pcolormesh(X0, X1, y_pred.to_numpy().reshape(X0.shape), alpha=0.2)\n", "fig.colorbar(im, ax=ax)\n", "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=ax, alpha=0.3)\n", "plt.show()" @@ -336,7 +339,7 @@ "metadata": {}, "outputs": [], "source": [ - "y_pred = model.predict(X)\n", + "y_pred = model.predict(X).to_numpy()\n", "\n", "fig, axs = plt.subplots(nrows=2, figsize=(12, 6))\n", "\n", diff --git 
a/src/scoring.rs b/src/scoring.rs index aa1cb74..172e0ac 100644 --- a/src/scoring.rs +++ b/src/scoring.rs @@ -119,6 +119,22 @@ pub fn neg_entropy_rs(y: &Series, target_groups: &Series) -> f64 { score } +pub fn series_neg_variance(y: Series) -> f64 { + if y.len() == 1 { + return 0.; + } + + let var_series = y.var_as_series(1).unwrap(); + + let variance = var_series + .f64() + .expect("not float64") + .get(0) + .expect("was null"); + + -variance +} + pub fn neg_variance_rs(y: &Series, target_groups: &Series) -> f64 { let msg = "Could not cast to f64"; let w_left: f64 = (*target_groups) @@ -135,34 +151,23 @@ pub fn neg_variance_rs(y: &Series, target_groups: &Series) -> f64 { let variance_left: f64; let variance_right: f64; + if w_left == 1. || w_right == 1. { + return series_neg_variance(y.clone()); + } if w_left > 0. { let y_left = y.filter(&target_groups).unwrap(); - let ddof_left: u8 = (y_left.len() - 1).try_into().unwrap(); - variance_left = y_left - .var_as_series(ddof_left) - .unwrap() - .f64() - .expect("not f64") - .get(0) - .expect("was null"); + variance_left = series_neg_variance(y_left); } else { variance_left = 0.0; } if w_right > 0. 
{ let y_right = y.filter(&!target_groups).unwrap(); - let ddof_right: u8 = (y_right.len() - 1).try_into().unwrap(); - variance_right = y_right - .var_as_series(ddof_right) - .unwrap() - .f64() - .expect("not f64") - .get(0) - .expect("was null"); + variance_right = series_neg_variance(y_right); } else { variance_right = 0.0; } let score = (w_left * variance_left) + (w_right * variance_right); - -score + score } pub fn calc_score( @@ -376,7 +381,25 @@ mod tests { assert_eq!(score, -1.0); } - // test neg_variance_rs + #[test] + fn test_series_neg_variance_rs() { + let s = Series::new("y", vec![1]); + let score = series_neg_variance(s); + assert_eq!(score, 0.0); + + let s = Series::new("y", vec![0, 0, 0, 0, 0, 0, 0, 0]); + let score = series_neg_variance(s); + assert_eq!(score, 0.0); + + let s = Series::new("y", vec![0, 1, 0, 1, 0, 1, 0, 1]); + let score = series_neg_variance(s); + assert_eq!(score, -0.2857142857142857); + + let s = Series::new("y", vec![0, 1, 2, 3, 4, 5, 6, 7]); + let score = series_neg_variance(s); + assert_eq!(score, -6.); + } + #[test] fn test_neg_variance_rs() { let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; @@ -389,13 +412,13 @@ mod tests { let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); let score = neg_variance_rs(&s, &target_groups); - assert_eq!(score, -2.); + assert_eq!(score, -0.2857142857142857); let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; let s = Series::new("y", values); let target_groups = Series::new("target_groups", vec![true; 8]); let score = neg_variance_rs(&s, &target_groups); - assert_eq!(score, -42.); + assert_eq!(score, -6.); let values = vec![0, 0, 0, 0, 0, 0, 0, 0]; let s = Series::new("y", values); @@ -413,6 +436,15 @@ mod tests { vec![true, true, true, true, false, false, false, false], ); let score = neg_variance_rs(&s, &target_groups); - assert_eq!(score, -1.); + assert_eq!(score, -0.3333333333333333); + + let values = vec![0., 1., 0., 1., 0., 1., 0., 1.]; + let s = 
Series::new("y", values); + let target_groups = Series::new( + "target_groups", + vec![true, true, true, true, false, false, false, false], + ); + let score = neg_variance_rs(&s, &target_groups); + assert_eq!(score, -0.3333333333333333); } } From e65fb45c1a3acdcfc3d6250738c7914d1e3b7e7b Mon Sep 17 00:00:00 2001 From: eschmidt42 <11818904+eschmidt42@users.noreply.github.com> Date: Fri, 29 Mar 2024 16:13:23 +0100 Subject: [PATCH 28/28] chore: something --- nbs/decision-tree-rs.ipynb | 39 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/nbs/decision-tree-rs.ipynb b/nbs/decision-tree-rs.ipynb index 1ef5aee..d22721c 100644 --- a/nbs/decision-tree-rs.ipynb +++ b/nbs/decision-tree-rs.ipynb @@ -126,6 +126,26 @@ "display(X.head(2), y.head(2))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def fit(X: pl.DataFrame, y: pl.Series):\n", + " model = rust.DecisionTreeClassifier(max_depth=4) # measure_name=\"gini\"\n", + " model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%timeit fit(X,y)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -144,6 +164,25 @@ "# dtree.show_tree(model)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def predict(X: pl.DataFrame):\n", + " model.predict(X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%timeit predict(X)" + ] + }, { "cell_type": "code", "execution_count": null,