Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools", "Cython", "numpy", "wheel"]
requires = ["setuptools", "Cython==3.1.0", "numpy", "wheel"]

[tool.cython-lint]
max-line-length = 127
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy>=1.25.0
matplotlib>=3.7.1
cython>=3.0.0
cython==3.1.0
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os

NAME = "adaXT"
VERSION = "1.5.0"
VERSION = "1.5.1"
DESCRIPTION = "A Python package for tree-based regression and classification"
PROJECT_URLS = {
"Documentation": "https://NiklasPfister.github.io/adaXT/",
Expand Down Expand Up @@ -140,6 +140,7 @@ def run_build():
extensions = cythonize(extensions, **arg_dir)
setup(
name=NAME,
license="BSD-3-clause",
version=VERSION,
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
Expand All @@ -160,7 +161,6 @@ def run_build():
classifiers=[
"Programming Language :: Python :: 3",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
],
extras_require=extras,
Expand Down
17 changes: 9 additions & 8 deletions src/adaXT/decision_tree/_decision_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import sys

cimport numpy as cnp
ctypedef cnp.float64_t DOUBLE_t
ctypedef cnp.int64_t LONG_t
ctypedef cnp.int32_t INT32_T
from libcpp cimport bool


Expand All @@ -16,7 +16,6 @@ from .nodes import DecisionNode

# for c level definitions

cimport cython
from .nodes cimport DecisionNode, Node

from ..utils cimport dsum
Expand All @@ -27,7 +26,7 @@ cdef double EPSILON = np.finfo('double').eps
cdef class refit_object(Node):
cdef public:
list list_idx
bint is_left
bool is_left

def __init__(
self,
Expand All @@ -44,9 +43,7 @@ cdef class refit_object(Node):
def add_idx(self, idx: int) -> None:
self.list_idx.append(idx)


@cython.auto_pickle(True)
cdef class _DecisionTree():
cdef class _DecisionTree:
cdef public:
object criteria
object splitter
Expand Down Expand Up @@ -180,7 +177,7 @@ cdef class _DecisionTree():
cdef void __fit_new_leaf_nodes(self, cnp.ndarray[DOUBLE_t, ndim=2] X,
cnp.ndarray[DOUBLE_t, ndim=2] Y,
cnp.ndarray[DOUBLE_t, ndim=1] sample_weight,
cnp.ndarray[LONG_t, ndim=1] sample_indices):
cnp.ndarray[INT32_T, ndim=1] sample_indices):
cdef:
int idx, n_objs, depth, cur_split_idx
double cur_threshold
Expand Down Expand Up @@ -328,7 +325,7 @@ cdef class _DecisionTree():
cnp.ndarray[DOUBLE_t, ndim=2] X,
cnp.ndarray[DOUBLE_t, ndim=2] Y,
cnp.ndarray[DOUBLE_t, ndim=1] sample_weight,
cnp.ndarray[LONG_t, ndim=1] sample_indices) -> None:
cnp.ndarray[INT32_T, ndim=1] sample_indices) -> None:

if self.root is None:
raise ValueError("The tree has not been trained before trying to\
Expand All @@ -343,6 +340,10 @@ cdef class _DecisionTree():
# Now squash all the DecisionNodes not visited
self.__squash_tree()

# Make sure that predictor_instance points to the same root, if we have
# changed it
self.predictor_instance.root = self.root


# From below here, it is the DepthTreeBuilder
class queue_obj:
Expand Down
2 changes: 1 addition & 1 deletion src/adaXT/predictor/predictor.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ cdef class Predictor():
cnp.ndarray X
cnp.ndarray Y
int n_features
Node root
cdef public Node root

cpdef dict predict_leaf(self, double[:, ::1] X)

Expand Down
14 changes: 6 additions & 8 deletions src/adaXT/random_forest/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ def get_sample_indices(
Assumes there has been a previous call to self.__get_sample_indices on the
RandomForest.
"""
indices = np.arange(0, X_n_rows, dtype=np.int32)
if sampling == "resampling":
ret = (
gen.choice(
np.arange(0, X_n_rows),
indices,
size=sampling_args["size"],
replace=sampling_args["replace"],
),
None,
)
elif sampling == "honest_tree":
indices = np.arange(0, X_n_rows)
gen.shuffle(indices)
if sampling_args["replace"]:
resample_size0 = sampling_args["size"]
Expand All @@ -81,7 +81,6 @@ def get_sample_indices(
)
ret = (fit_indices, pred_indices)
elif sampling == "honest_forest":
indices = np.arange(0, X_n_rows)
if sampling_args["replace"]:
resample_size0 = sampling_args["size"]
resample_size1 = sampling_args["size"]
Expand All @@ -103,7 +102,7 @@ def get_sample_indices(
)
ret = (fit_indices, pred_indices)
else:
ret = (np.arange(0, X_n_rows), None)
ret = (indices, None)

if sampling_args["OOB"]:
# Only fitting indices
Expand Down Expand Up @@ -164,7 +163,6 @@ def build_single_tree(
Y=Y,
sample_weight=sample_weight,
sample_indices=prediction_indices)

return tree


Expand Down Expand Up @@ -366,11 +364,11 @@ def __get_sampling_parameter(self, sampling_args: dict | None) -> dict:
sampling_args["split"] = np.min(
[int(0.5 * self.X_n_rows), self.X_n_rows - 1]
)
elif isinstance(sampling_args["size"], float):
elif isinstance(sampling_args["split"], float):
sampling_args["split"] = np.min(
[int(sampling_args["split"] * self.X_n_rows), self.X_n_rows - 1]
)
elif not isinstance(sampling_args["size"], int):
elif not isinstance(sampling_args["split"], (int, np.integer)):
raise ValueError(
"The provided sampling_args['split'] is not an integer or float as required."
)
Expand All @@ -380,7 +378,7 @@ def __get_sampling_parameter(self, sampling_args: dict | None) -> dict:
sampling_args["size"] = int(
sampling_args["size"] * sampling_args["split"]
)
elif not isinstance(sampling_args["size"], int):
elif not isinstance(sampling_args["size"], (np.integer, int)):
raise ValueError(
"The provided sampling_args['size'] is not an integer or float as required."
)
Expand Down