diff --git a/nbs/core/decision-tree.ipynb b/nbs/core/decision-tree.ipynb
index 16e7d9b..4b1e193 100644
--- a/nbs/core/decision-tree.ipynb
+++ b/nbs/core/decision-tree.ipynb
@@ -63,11 +63,11 @@
     "import seaborn as sns\n",
     "import sklearn.datasets as sk_datasets\n",
     "\n",
-    "from random_tree_models.decisiontree import (\n",
+    "from random_tree_models.models.decisiontree import (\n",
     "    DecisionTreeClassifier,\n",
     "    DecisionTreeRegressor,\n",
     ")\n",
-    "from random_tree_models.decisiontree.visualize import show_tree\n",
+    "from random_tree_models.models.decisiontree.visualize import show_tree\n",
     "from random_tree_models.scoring import MetricNames"
    ]
   },
diff --git a/nbs/core/extra-trees.ipynb b/nbs/core/extra-trees.ipynb
index ed44855..77e2a0b 100644
--- a/nbs/core/extra-trees.ipynb
+++ b/nbs/core/extra-trees.ipynb
@@ -41,10 +41,10 @@
     "import seaborn as sns\n",
     "import sklearn.datasets as sk_datasets\n",
     "\n",
-    "import random_tree_models.extratrees as et\n",
+    "import random_tree_models.models.extratrees as et\n",
     "from random_tree_models.scoring import MetricNames\n",
     "from random_tree_models.params import ThresholdSelectionMethod\n",
-    "from random_tree_models.decisiontree.visualize import show_tree"
+    "from random_tree_models.models.decisiontree.visualize import show_tree"
    ]
   },
   {
diff --git a/nbs/core/gradient-boosted-trees.ipynb b/nbs/core/gradient-boosted-trees.ipynb
index 81bd07e..dfd0384 100644
--- a/nbs/core/gradient-boosted-trees.ipynb
+++ b/nbs/core/gradient-boosted-trees.ipynb
@@ -71,8 +71,8 @@
     "import seaborn as sns\n",
     "import sklearn.datasets as sk_datasets\n",
     "\n",
-    "from random_tree_models.decisiontree.visualize import show_tree\n",
-    "import random_tree_models.gradientboostedtrees as gbtree\n",
+    "from random_tree_models.models.decisiontree.visualize import show_tree\n",
+    "import random_tree_models.models.gradientboostedtrees as gbtree\n",
     "from random_tree_models.scoring import MetricNames"
    ]
   },
diff --git a/nbs/core/isolation-forest.ipynb b/nbs/core/isolation-forest.ipynb
index 5ced2af..c19fd32 100644
--- a/nbs/core/isolation-forest.ipynb
+++ b/nbs/core/isolation-forest.ipynb
@@ -56,8 +56,8 @@
     "import seaborn as sns\n",
     "import sklearn.datasets as sk_datasets\n",
     "\n",
-    "from random_tree_models.decisiontree.visualize import show_tree\n",
-    "import random_tree_models.isolationforest as iforest\n",
+    "from random_tree_models.models.decisiontree.visualize import show_tree\n",
+    "import random_tree_models.models.isolationforest as iforest\n",
     "from random_tree_models.params import ColumnSelectionMethod, ThresholdSelectionMethod"
    ]
   },
diff --git a/nbs/core/random-forest.ipynb b/nbs/core/random-forest.ipynb
index 9f5448b..099183a 100644
--- a/nbs/core/random-forest.ipynb
+++ b/nbs/core/random-forest.ipynb
@@ -41,8 +41,8 @@
     "import seaborn as sns\n",
     "import sklearn.datasets as sk_datasets\n",
     "\n",
-    "from random_tree_models.decisiontree.visualize import show_tree\n",
-    "import random_tree_models.randomforest as rf\n",
+    "from random_tree_models.models.decisiontree.visualize import show_tree\n",
+    "import random_tree_models.models.randomforest as rf\n",
     "from random_tree_models.params import MetricNames"
    ]
   },
diff --git a/nbs/core/robust-random-cut-forest.ipynb b/nbs/core/robust-random-cut-forest.ipynb
index 034230f..16e6a79 100644
--- a/nbs/core/robust-random-cut-forest.ipynb
+++ b/nbs/core/robust-random-cut-forest.ipynb
@@ -55,8 +55,8 @@
     "import seaborn as sns\n",
     "import sklearn.datasets as sk_datasets\n",
     "\n",
-    "from random_tree_models.decisiontree.visualize import show_tree\n",
-    "import random_tree_models.isolationforest as iforest\n",
+    "from random_tree_models.models.decisiontree.visualize import show_tree\n",
+    "import random_tree_models.models.isolationforest as iforest\n",
     "from random_tree_models.params import ColumnSelectionMethod, ThresholdSelectionMethod"
    ]
   },
diff --git a/nbs/core/xgboost.ipynb b/nbs/core/xgboost.ipynb
index 4813e46..ec6b05a 100644
--- a/nbs/core/xgboost.ipynb
+++ b/nbs/core/xgboost.ipynb
@@ -284,8 +284,8 @@
     "import sklearn.datasets as sk_datasets\n",
     "from scipy import stats\n",
     "\n",
-    "from random_tree_models.decisiontree.visualize import show_tree\n",
-    "import random_tree_models.xgboost as xgboost\n",
+    "from random_tree_models.models.decisiontree.visualize import show_tree\n",
+    "import random_tree_models.models.xgboost as xgboost\n",
     "from random_tree_models.params import MetricNames"
    ]
   },
diff --git a/nbs/dev/xgboost-profiling-histogramming-yay-or-nay.ipynb b/nbs/dev/xgboost-profiling-histogramming-yay-or-nay.ipynb
index 2b3ddf7..cb0645e 100644
--- a/nbs/dev/xgboost-profiling-histogramming-yay-or-nay.ipynb
+++ b/nbs/dev/xgboost-profiling-histogramming-yay-or-nay.ipynb
@@ -49,7 +49,7 @@
     "import sklearn.datasets as sk_datasets\n",
     "from sklearn import metrics\n",
     "\n",
-    "import random_tree_models.xgboost as xgboost"
+    "import random_tree_models.models.xgboost as xgboost"
    ]
   },
   {
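Note: the notebook edits above are pure import-path updates for the new models
subpackage. A minimal sketch of the layout after this PR, using only modules
that appear in this diff:

    from random_tree_models.models.decisiontree import (
        DecisionTreeClassifier,
        DecisionTreeRegressor,
    )
    from random_tree_models.models.decisiontree.visualize import show_tree
    import random_tree_models.models.xgboost as xgboost
    from random_tree_models.scoring import MetricNames  # scoring stays top-level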
diff --git a/tests/decisiontree/__init__.py b/src/random_tree_models/models/__init__.py
similarity index 100%
rename from tests/decisiontree/__init__.py
rename to src/random_tree_models/models/__init__.py
diff --git a/src/random_tree_models/decisiontree/__init__.py b/src/random_tree_models/models/decisiontree/__init__.py
similarity index 100%
rename from src/random_tree_models/decisiontree/__init__.py
rename to src/random_tree_models/models/decisiontree/__init__.py
diff --git a/src/random_tree_models/decisiontree/estimators.py b/src/random_tree_models/models/decisiontree/estimators.py
similarity index 83%
rename from src/random_tree_models/decisiontree/estimators.py
rename to src/random_tree_models/models/decisiontree/estimators.py
index df267d2..cd5b37f 100644
--- a/src/random_tree_models/decisiontree/estimators.py
+++ b/src/random_tree_models/models/decisiontree/estimators.py
@@ -2,12 +2,17 @@
 import numpy as np
 import sklearn.base as base
+from sklearn.utils import Tags  # type: ignore
 from sklearn.utils.multiclass import check_classification_targets, type_of_target
 from sklearn.utils.validation import check_is_fitted, validate_data  # type: ignore
 
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.predict import predict_with_tree
-from random_tree_models.decisiontree.train import grow_tree
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.predict import predict_with_tree
+from random_tree_models.models.decisiontree.random import (
+    get_random_feature_ids,
+    get_random_sample_ids,
+)
+from random_tree_models.models.decisiontree.train import grow_tree
 from random_tree_models.params import (
     ColumnSelectionMethod,
     ColumnSelectionParameters,
@@ -55,6 +60,7 @@ def __init__(
         random_state: int = 42,
         ensure_all_finite: bool = True,
     ) -> None:
+        # scikit-learn requires __init__ to store parameters as passed; they are only combined into TreeGrowthParameters in _organize_growth_parameters
         self.max_depth = max_depth
         self.measure_name = measure_name
         self.min_improvement = min_improvement
@@ -70,49 +76,43 @@
         self.ensure_all_finite = ensure_all_finite
 
     def _organize_growth_parameters(self):
+        lam = -abs(self.lam)  # stored non-positive: downstream scoring adds lam as the leaf-weight penalty (Chen et al. 2016, eq. 2)
+
+        threshold_params = ThresholdSelectionParameters(
+            method=self.threshold_method,
+            quantile=self.threshold_quantile,
+            n_thresholds=self.n_thresholds,
+            random_state=self.random_state,
+        )
+
+        column_params = ColumnSelectionParameters(
+            method=self.column_method,
+            n_trials=self.n_columns_to_try,
+        )
+
         self.growth_params_ = TreeGrowthParameters(
             max_depth=self.max_depth,
             min_improvement=self.min_improvement,
-            lam=-abs(self.lam),
-            frac_subsamples=float(self.frac_subsamples),
-            frac_features=float(self.frac_features),
-            random_state=int(self.random_state),
-            threshold_params=ThresholdSelectionParameters(
-                method=self.threshold_method,
-                quantile=self.threshold_quantile,
-                n_thresholds=self.n_thresholds,
-                random_state=int(self.random_state),
-            ),
-            column_params=ColumnSelectionParameters(
-                method=self.column_method,
-                n_trials=self.n_columns_to_try,
-            ),
+            lam=lam,
+            frac_subsamples=self.frac_subsamples,
+            frac_features=self.frac_features,
+            random_state=self.random_state,
+            threshold_params=threshold_params,
+            column_params=column_params,
         )
 
     def _select_samples_and_features(
         self, X: np.ndarray, y: np.ndarray
     ) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]:
         "Sub-samples rows and columns from X and y"
+
         if not hasattr(self, "growth_params_"):
             raise ValueError(f"Try calling `fit` first.")
-        ix = np.arange(len(X))
 
         rng = np.random.RandomState(self.growth_params_.random_state)
-        if self.growth_params_.frac_subsamples < 1.0:
-            n_samples = int(self.growth_params_.frac_subsamples * len(X))
-            ix_samples = rng.choice(ix, size=n_samples, replace=False)
-        else:
-            ix_samples = ix
-
-        if self.frac_features < 1.0:
-            n_columns = int(X.shape[1] * self.frac_features)
-            ix_features = rng.choice(
-                np.arange(X.shape[1]),
-                size=n_columns,
-                replace=False,
-            )
-        else:
-            ix_features = np.arange(X.shape[1])
+
+        ix_samples = get_random_sample_ids(X, rng, self.growth_params_.frac_subsamples)
+        ix_features = get_random_feature_ids(X, rng, self.growth_params_.frac_features)
 
         _X = X[ix_samples, :]
         _X = _X[:, ix_features]
@@ -245,16 +245,9 @@ def __init__(
         )
         self.ensure_all_finite = ensure_all_finite
 
-    def _more_tags(self) -> T.Dict[str, bool]:
-        """Describes to scikit-learn parametrize_with_checks the scope of this class
-
-        Reference: https://scikit-learn.org/stable/developers/develop.html#estimator-tags
-        """
-        return {"binary_only": True}
-
     def __sklearn_tags__(self):
-        # https://scikit-learn.org/stable/developers/develop.html
-        tags = super().__sklearn_tags__()  # type: ignore
+        # https://scikit-learn.org/stable/developers/develop.html#estimator-tags
+        tags: Tags = super().__sklearn_tags__()  # type: ignore
         tags.classifier_tags.multi_class = False
         return tags
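Note: the _more_tags removal above follows scikit-learn's estimator-tags
migration: since scikit-learn 1.6, tags are declared via __sklearn_tags__
returning a Tags dataclass, and the old binary_only tag becomes
classifier_tags.multi_class = False. A minimal self-contained sketch of the
pattern, assuming scikit-learn >= 1.6 (class name hypothetical):

    from sklearn.base import BaseEstimator, ClassifierMixin

    class BinaryOnlyClassifier(ClassifierMixin, BaseEstimator):
        def __sklearn_tags__(self):
            # super() returns a sklearn.utils.Tags dataclass
            tags = super().__sklearn_tags__()
            tags.classifier_tags.multi_class = False
            return tags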
diff --git a/src/random_tree_models/decisiontree/node.py b/src/random_tree_models/models/decisiontree/node.py
similarity index 93%
rename from src/random_tree_models/decisiontree/node.py
rename to src/random_tree_models/models/decisiontree/node.py
index 9489d14..5068acd 100644
--- a/src/random_tree_models/decisiontree/node.py
+++ b/src/random_tree_models/models/decisiontree/node.py
@@ -3,7 +3,7 @@
 from pydantic import StrictInt, StrictStr
 from pydantic.dataclasses import dataclass
 
-from random_tree_models.decisiontree.split_objects import SplitScore
+from random_tree_models.models.decisiontree.split_objects import SplitScore
 
 
 @dataclass
diff --git a/src/random_tree_models/decisiontree/predict.py b/src/random_tree_models/models/decisiontree/predict.py
similarity index 96%
rename from src/random_tree_models/decisiontree/predict.py
rename to src/random_tree_models/models/decisiontree/predict.py
index 59fac01..06a7fac 100644
--- a/src/random_tree_models/decisiontree/predict.py
+++ b/src/random_tree_models/models/decisiontree/predict.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from random_tree_models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.node import Node
 
 
 def find_leaf_node(node: Node, x: np.ndarray) -> Node:
diff --git a/src/random_tree_models/models/decisiontree/random.py b/src/random_tree_models/models/decisiontree/random.py
new file mode 100644
index 0000000..fa9fa55
--- /dev/null
+++ b/src/random_tree_models/models/decisiontree/random.py
@@ -0,0 +1,28 @@
+import numpy as np
+
+
+def get_random_sample_ids(
+    X: np.ndarray, rng: np.random.RandomState, frac_subsamples: float
+) -> np.ndarray:
+    ix = np.arange(len(X))
+    if frac_subsamples < 1.0:
+        n_samples = int(frac_subsamples * len(X))
+        ix_samples = rng.choice(ix, size=n_samples, replace=False)
+    else:
+        ix_samples = ix
+    return ix_samples
+
+
+def get_random_feature_ids(
+    X: np.ndarray, rng: np.random.RandomState, frac_features: float
+) -> np.ndarray:
+    if frac_features < 1.0:
+        n_columns = int(X.shape[1] * frac_features)
+        ix_features = rng.choice(
+            np.arange(X.shape[1]),
+            size=n_columns,
+            replace=False,
+        )
+    else:
+        ix_features = np.arange(X.shape[1])
+    return ix_features
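Note: random.py extracts the row/column subsampling previously inlined in
_select_samples_and_features, which makes it unit-testable (see
tests/models/decisiontree/test_random.py below). A usage sketch with made-up
data:

    import numpy as np
    from random_tree_models.models.decisiontree.random import (
        get_random_feature_ids,
        get_random_sample_ids,
    )

    X = np.arange(20).reshape(10, 2)
    rng = np.random.RandomState(42)
    rows = get_random_sample_ids(X, rng, frac_subsamples=0.5)  # 5 distinct row ids
    cols = get_random_feature_ids(X, rng, frac_features=1.0)   # array([0, 1])
    X_sub = X[rows, :][:, cols]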
diff --git a/src/random_tree_models/decisiontree/split.py b/src/random_tree_models/models/decisiontree/split.py
similarity index 98%
rename from src/random_tree_models/decisiontree/split.py
rename to src/random_tree_models/models/decisiontree/split.py
index 25716a9..a01bf14 100644
--- a/src/random_tree_models/decisiontree/split.py
+++ b/src/random_tree_models/models/decisiontree/split.py
@@ -3,8 +3,8 @@
 import numpy as np
 
 import random_tree_models.scoring as scoring
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.split_objects import BestSplit
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.split_objects import BestSplit
 from random_tree_models.params import (
     ColumnSelectionMethod,
     ColumnSelectionParameters,
diff --git a/src/random_tree_models/decisiontree/split_objects.py b/src/random_tree_models/models/decisiontree/split_objects.py
similarity index 100%
rename from src/random_tree_models/decisiontree/split_objects.py
rename to src/random_tree_models/models/decisiontree/split_objects.py
diff --git a/src/random_tree_models/decisiontree/train.py b/src/random_tree_models/models/decisiontree/train.py
similarity index 96%
rename from src/random_tree_models/decisiontree/train.py
rename to src/random_tree_models/models/decisiontree/train.py
index 1f33cc1..15f982a 100644
--- a/src/random_tree_models/decisiontree/train.py
+++ b/src/random_tree_models/models/decisiontree/train.py
@@ -5,13 +5,13 @@
 import random_tree_models.leafweights as leafweights
 import random_tree_models.params
 import random_tree_models.scoring as scoring
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.split import (
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.split import (
     check_if_split_sensible,
     find_best_split,
     select_arrays_for_child_node,
 )
-from random_tree_models.decisiontree.split_objects import SplitScore
+from random_tree_models.models.decisiontree.split_objects import SplitScore
 from random_tree_models.params import MetricNames
diff --git a/src/random_tree_models/decisiontree/visualize.py b/src/random_tree_models/models/decisiontree/visualize.py
similarity index 89%
rename from src/random_tree_models/decisiontree/visualize.py
rename to src/random_tree_models/models/decisiontree/visualize.py
index 8d4aa60..19717f2 100644
--- a/src/random_tree_models/decisiontree/visualize.py
+++ b/src/random_tree_models/models/decisiontree/visualize.py
@@ -1,8 +1,8 @@
 from rich import print as rprint
 from rich.tree import Tree
 
-from random_tree_models.decisiontree.estimators import DecisionTreeTemplate
-from random_tree_models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.estimators import DecisionTreeTemplate
+from random_tree_models.models.decisiontree.node import Node
 
 
 def walk_tree(
diff --git a/src/random_tree_models/extratrees.py b/src/random_tree_models/models/extratrees.py
similarity index 96%
rename from src/random_tree_models/extratrees.py
rename to src/random_tree_models/models/extratrees.py
index 1a67dba..011abad 100644
--- a/src/random_tree_models/extratrees.py
+++ b/src/random_tree_models/models/extratrees.py
@@ -9,8 +9,11 @@
     validate_data,  # type: ignore
 )
 
-import random_tree_models.decisiontree as dtree
 import random_tree_models.params as utils
+from random_tree_models.models.decisiontree import (
+    DecisionTreeClassifier,
+    DecisionTreeRegressor,
+)
 from random_tree_models.params import MetricNames
 
 
@@ -103,11 +106,11 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "ExtraTreesRegressor":
         X, y = validate_data(self, X, y, ensure_all_finite=False)
 
-        self.trees_: T.List[dtree.DecisionTreeRegressor] = []
+        self.trees_: list[DecisionTreeRegressor] = []
 
         rng = np.random.RandomState(self.random_state)
         for _ in track(range(self.n_trees), total=self.n_trees, description="tree"):
             # train decision tree to predict differences
-            new_tree = dtree.DecisionTreeRegressor(
+            new_tree = DecisionTreeRegressor(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
@@ -199,11 +202,11 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "ExtraTreesClassifier":
             raise ValueError("Cannot train with only one class present")
 
         self.classes_, y = np.unique(y, return_inverse=True)
-        self.trees_: T.List[dtree.DecisionTreeClassifier] = []
+        self.trees_: list[DecisionTreeClassifier] = []
 
         rng = np.random.RandomState(self.random_state)
         for _ in track(range(self.n_trees), description="tree", total=self.n_trees):
-            new_tree = dtree.DecisionTreeClassifier(
+            new_tree = DecisionTreeClassifier(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
diff --git a/src/random_tree_models/gradientboostedtrees.py b/src/random_tree_models/models/gradientboostedtrees.py
similarity index 95%
rename from src/random_tree_models/gradientboostedtrees.py
rename to src/random_tree_models/models/gradientboostedtrees.py
index 9857378..46d7d40 100644
--- a/src/random_tree_models/gradientboostedtrees.py
+++ b/src/random_tree_models/models/gradientboostedtrees.py
@@ -13,8 +13,10 @@
     validate_data,  # type: ignore
 )
 
-import random_tree_models.decisiontree as dtree
-from random_tree_models.params import MetricNames
+from random_tree_models.models.decisiontree import (
+    DecisionTreeRegressor,
+)
+from random_tree_models.params import MetricNames, is_greater_zero
 from random_tree_models.utils import bool_to_float
 
 
@@ -33,7 +35,7 @@ def __init__(
         min_improvement: float = 0.0,
         ensure_all_finite: bool = True,
     ) -> None:
-        self.n_trees = n_trees
+        self.n_trees = is_greater_zero(n_trees)
         self.measure_name = measure_name
         self.max_depth = max_depth
         self.min_improvement = min_improvement
@@ -94,7 +96,7 @@ def __init__(
 
     def fit(self, X: np.ndarray, y: np.ndarray) -> "GradientBoostedTreesRegressor":
         X, y = validate_data(self, X, y, ensure_all_finite=False)
-        self.trees_: T.List[dtree.DecisionTreeRegressor] = []
+        self.trees_: list[DecisionTreeRegressor] = []
 
         self.start_estimate_ = np.mean(y)
 
@@ -103,7 +105,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "GradientBoostedTreesRegressor":
 
         for _ in track(range(self.n_trees), total=self.n_trees, description="tree"):
             # train decision tree to predict differences
-            new_tree = dtree.DecisionTreeRegressor(
+            new_tree = DecisionTreeRegressor(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
@@ -239,9 +241,8 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "GradientBoostedTreesClassifier":
         if len(np.unique(y)) == 1:
             raise ValueError("Cannot train with only one class present")
 
-        # self.n_features_in_ = X.shape[1]
         self.classes_, y = np.unique(y, return_inverse=True)
-        self.trees_: T.List[dtree.DecisionTreeRegressor] = []
+        self.trees_: list[DecisionTreeRegressor] = []
         self.gammas_ = []
 
         y = self._bool_to_float(y)
@@ -255,7 +256,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "GradientBoostedTreesClassifier":
                 2 * y / (1 + np.exp(2 * y * yhat))
             )  # dloss/dyhat, g in the xgboost paper
 
-            new_tree = dtree.DecisionTreeRegressor(
+            new_tree = DecisionTreeRegressor(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
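Note: the g computed in GradientBoostedTreesClassifier.fit above is the
negative gradient of the binomial deviance l(y, yhat) = log(1 + exp(-2*y*yhat))
for y in {-1, 1} (the in-code comments state the loss slightly differently).
A quick numerical check, not part of the diff:

    import numpy as np

    y, yhat, eps = 1.0, 0.3, 1e-6
    loss = lambda f: np.log(1.0 + np.exp(-2.0 * y * f))
    num_grad = (loss(yhat + eps) - loss(yhat - eps)) / (2.0 * eps)
    g = 2.0 * y / (1.0 + np.exp(2.0 * y * yhat))  # expression used in fit
    assert np.isclose(-num_grad, g)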
diff --git a/src/random_tree_models/isolationforest.py b/src/random_tree_models/models/isolationforest.py
similarity index 98%
rename from src/random_tree_models/isolationforest.py
rename to src/random_tree_models/models/isolationforest.py
index 4be543d..6ee842c 100644
--- a/src/random_tree_models/isolationforest.py
+++ b/src/random_tree_models/models/isolationforest.py
@@ -5,12 +5,12 @@
 from sklearn import base
 from sklearn.utils.validation import check_is_fitted, validate_data  # type: ignore
 
-from random_tree_models.decisiontree import (
+from random_tree_models.models.decisiontree import (
     DecisionTreeTemplate,
     find_leaf_node,
     grow_tree,
 )
-from random_tree_models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.node import Node
 from random_tree_models.params import (
     ColumnSelectionMethod,
     MetricNames,
diff --git a/src/random_tree_models/randomforest.py b/src/random_tree_models/models/randomforest.py
similarity index 86%
rename from src/random_tree_models/randomforest.py
rename to src/random_tree_models/models/randomforest.py
index dbe0164..f7015b3 100644
--- a/src/random_tree_models/randomforest.py
+++ b/src/random_tree_models/models/randomforest.py
@@ -9,8 +9,11 @@
     validate_data,  # type: ignore
 )
 
-import random_tree_models.decisiontree as dtree
-from random_tree_models.params import MetricNames
+from random_tree_models.models.decisiontree import (
+    DecisionTreeClassifier,
+    DecisionTreeRegressor,
+)
+from random_tree_models.params import MetricNames, is_greater_zero
 
 
 class RandomForestTemplate(base.BaseEstimator):
@@ -34,7 +37,7 @@ def __init__(
         frac_features: float = 1.0,
         random_state: int = 42,
     ) -> None:
-        self.n_trees = n_trees
+        self.n_trees = is_greater_zero(n_trees)
         self.measure_name = measure_name
         self.max_depth = max_depth
         self.min_improvement = min_improvement
@@ -87,10 +90,10 @@ def __init__(
 
     def fit(self, X: np.ndarray, y: np.ndarray) -> "RandomForestRegressor":
         X, y = validate_data(self, X, y, ensure_all_finite=False)
-        self.trees_: T.List[dtree.DecisionTreeRegressor] = []
+        self.trees_: list[DecisionTreeRegressor] = []
         rng = np.random.RandomState(self.random_state)
         for _ in track(range(self.n_trees), total=self.n_trees, description="tree"):
-            new_tree = dtree.DecisionTreeRegressor(
+            new_tree = DecisionTreeRegressor(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
@@ -116,10 +119,13 @@ def predict(self, X: np.ndarray, aggregation: str = "mean") -> np.ndarray:
         ):
             y[:, i] = tree.predict(X)
 
-        if aggregation == "mean":
-            y = np.mean(y, axis=1)
-        elif aggregation == "median":
-            y = np.median(y, axis=1)
+        match aggregation:
+            case "mean":
+                y = np.mean(y, axis=1)
+            case "median":
+                y = np.median(y, axis=1)
+            case _:
+                raise ValueError(f"{aggregation=} expected to be 'mean' or 'median'")
 
         return y
 
@@ -181,11 +187,11 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "RandomForestClassifier":
             raise ValueError("Cannot train with only one class present")
 
         self.classes_, y = np.unique(y, return_inverse=True)
-        self.trees_: T.List[dtree.DecisionTreeClassifier] = []
+        self.trees_: list[DecisionTreeClassifier] = []
 
         rng = np.random.RandomState(self.random_state)
         for _ in track(range(self.n_trees), description="tree", total=self.n_trees):
-            new_tree = dtree.DecisionTreeClassifier(
+            new_tree = DecisionTreeClassifier(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
@@ -211,12 +217,15 @@ def predict_proba(self, X: np.ndarray, aggregation: str = "mean") -> np.ndarray:
         ):
             proba[:, i, :] = tree.predict_proba(X)
 
-        if aggregation == "mean":
-            proba = np.mean(proba, axis=1)
-            proba = proba / np.sum(proba, axis=1)[:, np.newaxis]
-        elif aggregation == "median":
-            proba = np.median(proba, axis=1)
-            proba = proba / np.sum(proba, axis=1)[:, np.newaxis]
+        match aggregation:
+            case "mean":
+                proba = np.mean(proba, axis=1)
+                proba = proba / np.sum(proba, axis=1)[:, np.newaxis]
+            case "median":
+                proba = np.median(proba, axis=1)
+                proba = proba / np.sum(proba, axis=1)[:, np.newaxis]
+            case _:
+                raise ValueError(f"{aggregation=} expected to be 'mean' or 'median'")
 
         return proba
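Note: besides the style change, the match statements above fix a silent
fall-through: an unrecognized aggregation previously returned the unaggregated
(n_samples, n_trees) array unchanged. A standalone sketch of the new behavior
(requires Python >= 3.10):

    import numpy as np

    def aggregate(y: np.ndarray, aggregation: str) -> np.ndarray:
        # mirrors the match block added to RandomForestRegressor.predict
        match aggregation:
            case "mean":
                return np.mean(y, axis=1)
            case "median":
                return np.median(y, axis=1)
            case _:
                raise ValueError(f"{aggregation=} expected to be 'mean' or 'median'")

    y = np.array([[1.0, 3.0], [2.0, 6.0]])  # (n_samples, n_trees)
    print(aggregate(y, "mean"))  # [2. 4.]
    # aggregate(y, "max") now raises ValueError instead of returning y unchanged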
diff --git a/src/random_tree_models/xgboost.py b/src/random_tree_models/models/xgboost.py
similarity index 88%
rename from src/random_tree_models/xgboost.py
rename to src/random_tree_models/models/xgboost.py
index a227f08..2d42d32 100644
--- a/src/random_tree_models/xgboost.py
+++ b/src/random_tree_models/models/xgboost.py
@@ -26,9 +26,9 @@
     validate_data,  # type: ignore
 )
 
-import random_tree_models.decisiontree as dtree
-import random_tree_models.utils as gbt
-from random_tree_models.params import MetricNames
+from random_tree_models.models.decisiontree import DecisionTreeRegressor
+from random_tree_models.params import MetricNames, is_greater_zero
+from random_tree_models.utils import vectorize_bool_to_float
 
 
 class XGBoostTemplate(base.BaseEstimator):
@@ -52,7 +52,7 @@ def __init__(
         use_hist: bool = False,
         n_bins: int = 256,
     ) -> None:
-        self.n_trees = n_trees
+        self.n_trees = is_greater_zero(n_trees)
         self.measure_name = measure_name
         self.max_depth = max_depth
         self.min_improvement = min_improvement
@@ -92,8 +92,10 @@ def xgboost_histogrammify_with_h(
     X: np.ndarray, h: np.ndarray, n_bins: int
 ) -> T.Tuple[np.ndarray, T.List[np.ndarray]]:
     """Converts X into a histogram representation using XGBoost paper eq 8 and 9 using 2nd order gradient statistics as weights"""
+
     X_hist = np.zeros_like(X, dtype=int)
     all_x_bin_edges = []
+
     for i in range(X.shape[1]):
         order = np.argsort(X[:, i])
         h_ordered = h[order]
@@ -122,12 +124,12 @@ def xgboost_histogrammify_with_x_bin_edges(
     X: np.ndarray, all_x_bin_edges: T.List[np.ndarray]
 ) -> np.ndarray:
     """Converts X into a histogram representation using XGBoost paper eq 8 and 9 using 2nd order gradient statistics as weights"""
+
     X_hist = np.zeros_like(X, dtype=int)
 
     for i in range(X.shape[1]):
-        bin_assignments = pd.cut(
-            X[:, i], bins=all_x_bin_edges[i].tolist(), labels=False, include_lowest=True
-        )
+        bins = all_x_bin_edges[i].tolist()
+        bin_assignments = pd.cut(X[:, i], bins=bins, labels=False, include_lowest=True)
 
         X_hist[:, i] = bin_assignments
 
@@ -143,10 +145,8 @@ class XGBoostRegressor(base.RegressorMixin, XGBoostTemplate):
 
     def fit(self, X: np.ndarray, y: np.ndarray) -> "XGBoostRegressor":
         X, y = validate_data(self, X, y, ensure_all_finite=False)
-        # X, y = check_X_y(X, y, ensure_all_finite=self.ensure_all_finite)
-        # self.n_features_in_ = X.shape[1]
 
-        self.trees_: T.List[dtree.DecisionTreeRegressor] = []
+        self.trees_: list[DecisionTreeRegressor] = []
 
         self.start_estimate_: float = float(np.mean(y))
 
@@ -161,7 +161,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "XGBoostRegressor":
 
         for _ in track(range(self.n_trees), total=self.n_trees, description="tree"):
             # train decision tree to predict differences
-            new_tree = dtree.DecisionTreeRegressor(
+            new_tree = DecisionTreeRegressor(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
@@ -211,6 +211,7 @@ def compute_start_estimate_binomial_loglikelihood(y_float: np.ndarray) -> float:
 
     ym = np.mean(y_float)
     start_estimate = 0.5 * math.log((1 + ym) / (1 - ym))
+
     return start_estimate
 
 
@@ -218,13 +219,16 @@ def compute_derivatives_binomial_loglikelihood(
     y_float: np.ndarray, yhat: np.ndarray
 ) -> T.Tuple[np.ndarray, np.ndarray]:
     "loss = - sum log(1+exp(2*y*yhat))"
+
     check_y_float(y_float)
+
     # differences to predict using binomial log-likelihood (yes, the negative of the negative :P)
     exp_y_yhat = np.exp(2 * y_float * yhat)
     g = 2 * y_float / (1 + exp_y_yhat)  # dloss/dyhat, g in the xgboost paper
-    h = -(
-        4 * y_float**2 * exp_y_yhat / (1 + exp_y_yhat) ** 2
-    )  # d^2loss/dyhat^2, h in the xgboost paper
+
+    # d^2loss/dyhat^2, h in the xgboost paper
+    h = -(4 * y_float**2 * exp_y_yhat / (1 + exp_y_yhat) ** 2)
+
     return g, h
 
 
@@ -235,17 +239,6 @@ class XGBoostClassifier(base.ClassifierMixin, XGBoostTemplate):
     https://dl.acm.org/doi/10.1145/2939672.2939785
     """
 
-    def _bool_to_float(self, y: np.ndarray) -> np.ndarray:
-        f = np.vectorize(gbt.bool_to_float)
-        return f(y)
-
-    def _more_tags(self) -> T.Dict[str, bool]:
-        """Describes to scikit-learn parametrize_with_checks the scope of this class
-
-        Reference: https://scikit-learn.org/stable/developers/develop.html#estimator-tags
-        """
-        return {"binary_only": True}
-
     def __sklearn_tags__(self):
         # https://scikit-learn.org/stable/developers/develop.html
         tags = super().__sklearn_tags__()  # type: ignore
         tags.classifier_tags.multi_class = False
         return tags
@@ -267,14 +260,13 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "XGBoostClassifier":
         if len(np.unique(y)) == 1:
             raise ValueError("Cannot train with only one class present")
 
-        # self.n_features_in_ = X.shape[1]
         self.classes_, y = np.unique(y, return_inverse=True)
-        self.trees_: T.List[dtree.DecisionTreeRegressor] = []
+        self.trees_: list[DecisionTreeRegressor] = []
         self.gammas_ = []
         self.all_x_bin_edges_ = []
 
         # convert y from True/False to 1/-1 for binomial log-likelihood
-        y = self._bool_to_float(y)
+        y = vectorize_bool_to_float(y)
 
         # initial estimate
         self.start_estimate_ = compute_start_estimate_binomial_loglikelihood(y)
@@ -291,7 +283,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "XGBoostClassifier":
             else:
                 _X = X
 
-            new_tree = dtree.DecisionTreeRegressor(
+            new_tree = DecisionTreeRegressor(
                 measure_name=self.measure_name,
                 max_depth=self.max_depth,
                 min_improvement=self.min_improvement,
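Note: for the histogrammified path above, pd.cut assigns each value the index
of its precomputed bin. A minimal sketch of just that step (edges invented for
illustration):

    import numpy as np
    import pandas as pd

    x = np.array([0.05, 0.40, 0.35, 0.80])
    edges = [0.0, 0.25, 0.5, 1.0]  # hypothetical per-feature bin edges
    bins = pd.cut(x, bins=edges, labels=False, include_lowest=True)
    print(bins)  # [0 1 1 2]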
diff --git a/src/random_tree_models/params.py b/src/random_tree_models/params.py
index b261bb8..5b44704 100644
--- a/src/random_tree_models/params.py
+++ b/src/random_tree_models/params.py
@@ -1,7 +1,7 @@
 from enum import StrEnum, auto
-from typing import Any
+from typing import Annotated, Any
 
-from pydantic import BaseModel, StrictFloat, StrictInt
+from pydantic import AfterValidator, BaseModel, StrictFloat, StrictInt
 
 
 class ColumnSelectionMethod(StrEnum):
@@ -17,40 +17,54 @@ class ThresholdSelectionMethod(StrEnum):
     uniform = "uniform"
 
 
+def is_quantile(quantile: float) -> float:
+    is_okay = 0.0 < quantile < 1.0
+    if not is_okay:
+        raise ValueError(f"{quantile=} not in (0, 1)")
+
+    is_okay = 1 / quantile % 1 == 0  # 1/quantile must be a whole number of steps
+    if not is_okay:
+        raise ValueError(f"{quantile=} invalid, 1/quantile must be a whole number")
+
+    return quantile
+
+
+def is_fraction(fraction: float) -> float:
+    is_okay = 0.0 < fraction <= 1.0
+    if not is_okay:
+        raise ValueError(f"{fraction=} not in (0, 1]")
+
+    return fraction
+
+
+def is_greater_zero(value: int) -> int:
+    if value <= 0:
+        raise ValueError(f"{value=} not > 0")
+    return value
+
+
+def is_greater_equal_zero(value: int | float) -> int | float:
+    if value < 0:
+        raise ValueError(f"{value=} not >= 0")
+    return value
+
+
+QuantileValidator = AfterValidator(is_quantile)
+GreaterEqualZeroValidator = AfterValidator(is_greater_equal_zero)
+GreaterZeroValidator = AfterValidator(is_greater_zero)
+
+
 class ThresholdSelectionParameters(BaseModel):
     method: ThresholdSelectionMethod
-    quantile: StrictFloat = 0.1
-    random_state: StrictInt = 0
-    n_thresholds: StrictInt = 100
+
+    quantile: Annotated[StrictFloat, QuantileValidator] = 0.1
+
+    random_state: Annotated[StrictInt, GreaterEqualZeroValidator] = 0
+
+    n_thresholds: Annotated[StrictInt, GreaterZeroValidator] = 100
 
     num_quantile_steps: StrictInt = -1
 
     def model_post_init(self, context: Any):
-        # verify method
-        # expected = ThresholdSelectionMethod.__members__.keys()
-        # is_okay = self.method in expected
-        # if not is_okay:
-        #     raise ValueError(
-        #         f"passed value for method ('{self.method}') not one of {expected}"
-        #     )
-
-        # verify quantile
-        is_okay = 0.0 < self.quantile < 1.0
-        if not is_okay:
-            raise ValueError(f"{self.quantile=} not in (0, 1)")
-        is_okay = 1 / self.quantile % 1 == 0
-        if not is_okay:
-            raise ValueError(f"{self.quantile=} not a valid quantile")
-
-        # verify random_state
-        is_okay = self.random_state >= 0
-        if not is_okay:
-            raise ValueError(f"{self.random_state=} not in [0, inf)")
-
-        # verify n_thresholds valid int
-        is_okay = self.n_thresholds > 0
-        if not is_okay:
-            raise ValueError(f"{self.n_thresholds=} not > 0")
-
         # set dq
         self.num_quantile_steps = int(1 / self.quantile) + 1
 
@@ -60,15 +74,19 @@ class ColumnSelectionParameters(BaseModel):
     n_trials: StrictInt | None = None
 
 
+FractionValidator = AfterValidator(is_fraction)
+
+
 class TreeGrowthParameters(BaseModel):
-    max_depth: StrictInt
-    min_improvement: StrictFloat = 0.0
+    max_depth: Annotated[StrictInt, GreaterZeroValidator]
+    min_improvement: Annotated[StrictFloat, GreaterEqualZeroValidator] = 0.0
     # xgboost lambda - multiplied with sum of squares of leaf weights
     # see Chen et al. 2016 equation 2
     lam: StrictFloat = 0.0
-    frac_subsamples: StrictFloat = 1.0
-    frac_features: StrictFloat = 1.0
-    random_state: StrictInt = 0
+
+    frac_subsamples: Annotated[StrictFloat, FractionValidator] = 1.0
+    frac_features: Annotated[StrictFloat, FractionValidator] = 1.0
+    random_state: Annotated[StrictInt, GreaterEqualZeroValidator] = 0
     threshold_params: ThresholdSelectionParameters = ThresholdSelectionParameters(
         method=ThresholdSelectionMethod.bruteforce,
         quantile=0.1,
@@ -79,17 +97,6 @@ class TreeGrowthParameters(BaseModel):
         method=ColumnSelectionMethod.ascending, n_trials=None
     )
 
-    def model_post_init(self, context: Any):
-        # verify frac_subsamples
-        is_okay = 0.0 < self.frac_subsamples <= 1.0
-        if not is_okay:
-            raise ValueError(f"{self.frac_subsamples=} not in (0, 1]")
-
-        # verify frac_features
-        is_okay = 0.0 < self.frac_features <= 1.0
-        if not is_okay:
-            raise ValueError(f"{self.frac_features=} not in (0, 1]")
-
 
 class MetricNames(StrEnum):
     variance = auto()
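Note: the params.py rewrite replaces imperative model_post_init checks with
reusable Annotated validators. A self-contained sketch of the pattern
(pydantic v2; model name hypothetical):

    from typing import Annotated
    from pydantic import AfterValidator, BaseModel, StrictFloat

    def is_fraction(fraction: float) -> float:
        if not 0.0 < fraction <= 1.0:
            raise ValueError(f"{fraction=} not in (0, 1]")
        return fraction

    class Demo(BaseModel):
        frac_subsamples: Annotated[StrictFloat, AfterValidator(is_fraction)] = 1.0

    Demo(frac_subsamples=0.5)  # ok
    Demo(frac_subsamples=1.5)  # raises pydantic.ValidationError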
diff --git a/src/random_tree_models/utils.py b/src/random_tree_models/utils.py
index 58bc1d3..390eec5 100644
--- a/src/random_tree_models/utils.py
+++ b/src/random_tree_models/utils.py
@@ -1,5 +1,6 @@
 import logging
 
+import numpy as np
 from rich.logging import RichHandler
 
 
@@ -28,3 +29,8 @@ def bool_to_float(x: bool) -> float:
         return -1.0
     else:
         raise ValueError(f"{x=}, expected bool")
+
+
+def vectorize_bool_to_float(y: np.ndarray) -> np.ndarray:
+    f = np.vectorize(bool_to_float)
+    return f(y)
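Note: np.vectorize keeps bool_to_float's per-element type check but loops in
Python. If callers can guarantee a bool dtype, an equivalent vectorized
alternative would look like this (a sketch, not part of the diff; helper name
hypothetical):

    import numpy as np

    def bools_to_float(y: np.ndarray) -> np.ndarray:
        # a dtype check replaces the per-element isinstance check
        if y.dtype != np.bool_:
            raise ValueError(f"{y.dtype=}, expected bool")
        return np.where(y, 1.0, -1.0)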
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/decisiontree/__init__.py b/tests/models/decisiontree/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/decisiontree/conftest.py b/tests/models/decisiontree/conftest.py
similarity index 92%
rename from tests/decisiontree/conftest.py
rename to tests/models/decisiontree/conftest.py
index a86f141..b7edf9f 100644
--- a/tests/decisiontree/conftest.py
+++ b/tests/models/decisiontree/conftest.py
@@ -1,4 +1,4 @@
-from random_tree_models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.node import Node
 
 # first value in each tuple is the value to test and the second is the flag indicating if this should work
 BOOL_OPTIONS_NONE_OKAY = [(False, True), (True, True), ("blub", False)]
diff --git a/tests/decisiontree/test_estimators.py b/tests/models/decisiontree/test_estimators.py
similarity index 96%
rename from tests/decisiontree/test_estimators.py
rename to tests/models/decisiontree/test_estimators.py
index 9f6c185..f725c7d 100644
--- a/tests/decisiontree/test_estimators.py
+++ b/tests/models/decisiontree/test_estimators.py
@@ -3,12 +3,12 @@
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
 import random_tree_models.params
-from random_tree_models.decisiontree import (
+from random_tree_models.models.decisiontree import (
     DecisionTreeClassifier,
     DecisionTreeRegressor,
 )
-from random_tree_models.decisiontree.estimators import DecisionTreeTemplate
-from random_tree_models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.estimators import DecisionTreeTemplate
+from random_tree_models.models.decisiontree.node import Node
 from random_tree_models.params import MetricNames
 from tests.conftest import expected_failed_checks
diff --git a/tests/decisiontree/test_node.py b/tests/models/decisiontree/test_node.py
similarity index 94%
rename from tests/decisiontree/test_node.py
rename to tests/models/decisiontree/test_node.py
index 6d92468..085e819 100644
--- a/tests/decisiontree/test_node.py
+++ b/tests/models/decisiontree/test_node.py
@@ -1,8 +1,8 @@
 import pytest
 from pydantic import ValidationError
 
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.split_objects import SplitScore
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.split_objects import SplitScore
 
 from .conftest import (
     BOOL_OPTIONS_NONE_OKAY,
diff --git a/tests/decisiontree/test_predict.py b/tests/models/decisiontree/test_predict.py
similarity index 89%
rename from tests/decisiontree/test_predict.py
rename to tests/models/decisiontree/test_predict.py
index f0806d7..d7dfbe7 100644
--- a/tests/decisiontree/test_predict.py
+++ b/tests/models/decisiontree/test_predict.py
@@ -1,8 +1,11 @@
 import numpy as np
 import pytest
 
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.predict import find_leaf_node, predict_with_tree
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.predict import (
+    find_leaf_node,
+    predict_with_tree,
+)
 
 
 @pytest.mark.parametrize(
diff --git a/tests/models/decisiontree/test_random.py b/tests/models/decisiontree/test_random.py
new file mode 100644
index 0000000..e50603b
--- /dev/null
+++ b/tests/models/decisiontree/test_random.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+from random_tree_models.models.decisiontree.random import (
+    get_random_feature_ids,
+    get_random_sample_ids,
+)
+
+
+def test_get_random_sample_ids():
+    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+    rng = np.random.RandomState(42)
+
+    # Test with frac_subsamples = 1.0
+    ix_samples_full = get_random_sample_ids(X, rng, 1.0)
+    assert np.array_equal(ix_samples_full, np.array([0, 1, 2, 3, 4]))
+    assert len(set(ix_samples_full)) == len(ix_samples_full)
+
+    # Test with frac_subsamples < 1.0
+    ix_samples_partial = get_random_sample_ids(X, rng, 0.5)
+    assert len(ix_samples_partial) == int(0.5 * len(X))
+    assert all(i in np.arange(len(X)) for i in ix_samples_partial)
+    assert len(set(ix_samples_partial)) == len(ix_samples_partial)
+
+
+def test_get_random_feature_ids():
+    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+    rng = np.random.RandomState(42)
+
+    # Test with frac_features = 1.0
+    ix_features_full = get_random_feature_ids(X, rng, 1.0)
+    assert np.array_equal(ix_features_full, np.array([0, 1, 2]))
+    assert len(set(ix_features_full)) == len(ix_features_full)
+
+    # Test with frac_features < 1.0
+    ix_features_partial = get_random_feature_ids(X, rng, 0.5)
+    assert len(ix_features_partial) == int(X.shape[1] * 0.5)
+    assert all(i in np.arange(X.shape[1]) for i in ix_features_partial)
+    assert len(set(ix_features_partial)) == len(ix_features_partial)
diff --git a/tests/decisiontree/test_split.py b/tests/models/decisiontree/test_split.py
similarity index 99%
rename from tests/decisiontree/test_split.py
rename to tests/models/decisiontree/test_split.py
index 6c2653a..6c3c0a2 100644
--- a/tests/decisiontree/test_split.py
+++ b/tests/models/decisiontree/test_split.py
@@ -5,8 +5,8 @@
 from pydantic import ValidationError
 from scipy import stats
 
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.split import (
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.split import (
     BestSplit,
     check_if_split_sensible,
     find_best_split,
@@ -15,7 +15,7 @@
     select_arrays_for_child_node,
     select_thresholds,
 )
-from random_tree_models.decisiontree.split_objects import SplitScore
+from random_tree_models.models.decisiontree.split_objects import SplitScore
 from random_tree_models.params import (
     ColumnSelectionMethod,
     ColumnSelectionParameters,
diff --git a/tests/decisiontree/test_split_objects.py b/tests/models/decisiontree/test_split_objects.py
similarity index 93%
rename from tests/decisiontree/test_split_objects.py
rename to tests/models/decisiontree/test_split_objects.py
index ee0f712..9aecd91 100644
--- a/tests/decisiontree/test_split_objects.py
+++ b/tests/models/decisiontree/test_split_objects.py
@@ -1,7 +1,7 @@
 import pytest
 from pydantic import ValidationError
 
-from random_tree_models.decisiontree.split_objects import SplitScore
+from random_tree_models.models.decisiontree.split_objects import SplitScore
 
 from .conftest import FLOAT_OPTIONS_NONE_OKAY, STR_OPTIONS_NONE_NOT_OKAY
 
diff --git a/tests/decisiontree/test_train.py b/tests/models/decisiontree/test_train.py
similarity index 93%
rename from tests/decisiontree/test_train.py
rename to tests/models/decisiontree/test_train.py
index ef5fb5b..debb309 100644
--- a/tests/decisiontree/test_train.py
+++ b/tests/models/decisiontree/test_train.py
@@ -3,9 +3,9 @@
 from dirty_equals import IsApprox
 from inline_snapshot import snapshot
 
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.split_objects import SplitScore
-from random_tree_models.decisiontree.train import (
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.split_objects import SplitScore
+from random_tree_models.models.decisiontree.train import (
     calc_leaf_weight_and_split_score,
     check_is_baselevel,
     grow_tree,
diff --git a/tests/decisiontree/test_visualize.py b/tests/models/decisiontree/test_visualize.py
similarity index 87%
rename from tests/decisiontree/test_visualize.py
rename to tests/models/decisiontree/test_visualize.py
index b77dda6..c285d6a 100644
--- a/tests/decisiontree/test_visualize.py
+++ b/tests/models/decisiontree/test_visualize.py
@@ -2,9 +2,9 @@
 from inline_snapshot import snapshot
 from pytest import CaptureFixture
 
-from random_tree_models.decisiontree.estimators import DecisionTreeTemplate
-from random_tree_models.decisiontree.node import Node
-from random_tree_models.decisiontree.visualize import show_tree
+from random_tree_models.models.decisiontree.estimators import DecisionTreeTemplate
+from random_tree_models.models.decisiontree.node import Node
+from random_tree_models.models.decisiontree.visualize import show_tree
 from random_tree_models.params import MetricNames
 
 
diff --git a/tests/test_extratrees.py b/tests/models/test_extratrees.py
similarity index 96%
rename from tests/test_extratrees.py
rename to tests/models/test_extratrees.py
index c35a9c5..e33401b 100644
--- a/tests/test_extratrees.py
+++ b/tests/models/test_extratrees.py
@@ -2,8 +2,8 @@
 import pytest
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
-import random_tree_models.extratrees as et
-from random_tree_models.decisiontree import (
+import random_tree_models.models.extratrees as et
+from random_tree_models.models.decisiontree import (
     DecisionTreeClassifier,
     DecisionTreeRegressor,
 )
diff --git a/tests/test_gradientboostedtrees.py b/tests/models/test_gradientboostedtrees.py
similarity index 95%
rename from tests/test_gradientboostedtrees.py
rename to tests/models/test_gradientboostedtrees.py
index d939b7b..3c3f761 100644
--- a/tests/test_gradientboostedtrees.py
+++ b/tests/models/test_gradientboostedtrees.py
@@ -2,8 +2,8 @@
 import pytest
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
-import random_tree_models.gradientboostedtrees as gbt
-from random_tree_models.decisiontree import (
+import random_tree_models.models.gradientboostedtrees as gbt
+from random_tree_models.models.decisiontree import (
     DecisionTreeRegressor,
 )
 from tests.conftest import expected_failed_checks
diff --git a/tests/test_isolationforest.py b/tests/models/test_isolationforest.py
similarity index 97%
rename from tests/test_isolationforest.py
rename to tests/models/test_isolationforest.py
index c127bc0..c3be640 100644
--- a/tests/test_isolationforest.py
+++ b/tests/models/test_isolationforest.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-import random_tree_models.isolationforest as iforest
+import random_tree_models.models.isolationforest as iforest
 from random_tree_models.params import ThresholdSelectionMethod
 
 rng = np.random.RandomState(42)
diff --git a/tests/test_randomforest.py b/tests/models/test_randomforest.py
similarity index 95%
rename from tests/test_randomforest.py
rename to tests/models/test_randomforest.py
index ce377d5..840a96e 100644
--- a/tests/test_randomforest.py
+++ b/tests/models/test_randomforest.py
@@ -2,8 +2,8 @@
 import pytest
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
-import random_tree_models.randomforest as rf
-from random_tree_models.decisiontree import (
+import random_tree_models.models.randomforest as rf
+from random_tree_models.models.decisiontree import (
     DecisionTreeClassifier,
     DecisionTreeRegressor,
 )
diff --git a/tests/test_xgboost.py b/tests/models/test_xgboost.py
similarity index 97%
rename from tests/test_xgboost.py
rename to tests/models/test_xgboost.py
index 188a7f2..83a84c1 100644
--- a/tests/test_xgboost.py
+++ b/tests/models/test_xgboost.py
@@ -2,8 +2,8 @@
 import pytest
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
-import random_tree_models.xgboost as xgboost
-from random_tree_models.decisiontree import DecisionTreeRegressor
+import random_tree_models.models.xgboost as xgboost
+from random_tree_models.models.decisiontree import DecisionTreeRegressor
 from tests.conftest import expected_failed_checks
 
 
diff --git a/tests/test_params.py b/tests/test_params.py
index 547cdc8..ef530c0 100644
--- a/tests/test_params.py
+++ b/tests/test_params.py
@@ -7,6 +7,10 @@
     ThresholdSelectionMethod,
     ThresholdSelectionParameters,
     TreeGrowthParameters,
+    is_fraction,
+    is_greater_equal_zero,
+    is_greater_zero,
+    is_quantile,
 )
 
 
@@ -214,3 +218,81 @@ def test_frac_features(self, frac_features: float, fail: bool):
 
     def test_fail_if_max_depth_missing(self):
         with pytest.raises(ValidationError):
             _ = TreeGrowthParameters()  # type: ignore
+
+
+@pytest.mark.parametrize(
+    "q, ok",
+    [
+        (0.1, True),
+        (0.2, True),
+        (0.3, False),  # 1/0.3 is not a whole number of steps
+        (-0.1, False),
+        (1.1, False),
+        (0.4, False),
+        (0.51, False),
+    ],
+)
+def test_is_quantile(q: float, ok: bool):
+    if ok:
+        assert is_quantile(q) == q
+    else:
+        with pytest.raises(ValueError):
+            is_quantile(q)
+
+
+@pytest.mark.parametrize(
+    "f, ok",
+    [
+        (0.1, True),
+        (0.2, True),
+        (0.3, True),
+        (0.5, True),
+        (0.6, True),
+        (0.8, True),
+        (0.99, True),
+        (1.0, True),
+        (-0.1, False),
+        (1.1, False),
+        (0, False),
+    ],
+)
+def test_is_fraction(f: float, ok: bool):
+    if ok:
+        assert is_fraction(f) == f
+    else:
+        with pytest.raises(ValueError):
+            is_fraction(f)
+
+
+@pytest.mark.parametrize(
+    "v, ok",
+    [
+        (0.1, True),
+        (1.0, True),
+        (-0.1, False),
+        (0, False),
+    ],
+)
+def test_is_greater_zero(v: float, ok: bool):
+    if ok:
+        assert is_greater_zero(v) == v
+    else:
+        with pytest.raises(ValueError):
+            is_greater_zero(v)
+
+
+@pytest.mark.parametrize(
+    "v, ok",
+    [
+        (0.1, True),
+        (1.0, True),
+        (-0.1, False),
+        (0, True),
+    ],
+)
+def test_is_greater_equal_zero(v: float, ok: bool):
+    if ok:
+        assert is_greater_equal_zero(v) == v
+    else:
+        with pytest.raises(ValueError):
+            is_greater_equal_zero(v)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7c852fa..2a11693 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,6 @@
 import logging
 
+import numpy as np
 import pytest
 
 import random_tree_models.utils
@@ -36,3 +37,17 @@ def test_bool_to_float(x, exp, is_bad: bool):
     if is_bad:
         pytest.fail(f"Passed unexpectedly for non-bool value {x} returning {res}")
     assert res == exp
+
+
+def test_vectorize_bool_to_float():
+    y = np.array([True, False, True, False])
+    res = utils.vectorize_bool_to_float(y)
+    assert np.all(res == np.array([1.0, -1.0, 1.0, -1.0]))
+
+    y = np.array([True, False, True, True])
+    res = utils.vectorize_bool_to_float(y)
+    assert np.all(res == np.array([1.0, -1.0, 1.0, 1.0]))
+
+    y = np.array([False, False, True, False])
+    res = utils.vectorize_bool_to_float(y)
+    assert np.all(res == np.array([-1.0, -1.0, 1.0, -1.0]))