Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/penguins_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ models:
solver: adam
learning_rate: adaptive

# Random baseline for comparison
- name: random
params:
strategy: stratified # Draws predictions at random according to the training class distribution

training:
repetitions: 10 # Train 10 times with different random seeds
random_seed: 42 # Base seed for reproducibility
Expand Down
5 changes: 5 additions & 0 deletions configs/penguins_multilabel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ models:
max_iter: 1000
class_weight: balanced

# Random baseline for comparison
- name: random
params:
strategy: stratified # Draws predictions at random according to the training class distribution

# Training configuration
training:
repetitions: 5 # Train 5 times with different seeds
Expand Down
5 changes: 5 additions & 0 deletions configs/possum_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ models:
params:
fit_intercept: true

# Naive baseline for comparison (model name is `random`, but the `mean` strategy is a constant predictor)
- name: random
params:
strategy: mean # Predicts the mean of training targets

# Training configuration
training:
repetitions: 10 # Train each model 10 times with different seeds
Expand Down
5 changes: 5 additions & 0 deletions configs/possum_multilabel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ models:
- name: linear
params: {}

# Naive baseline for comparison (model name is `random`, but the `mean` strategy is a constant predictor)
- name: random
params:
strategy: mean # Predicts the mean of training targets

# Training configuration
training:
repetitions: 5 # Train 5 times with different seeds
Expand Down
23 changes: 23 additions & 0 deletions ecosci/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- xgboost : gradient boosting trees (requires `xgboost` package)
- logistic : logistic regression (classification baseline)
- linear : linear regression (regression only)
- random : naive baseline (sklearn DummyClassifier for classification, DummyRegressor for regression)

All model hyperparameters come from `models[].params` in the YAML and are passed
through to the underlying scikit-learn/xgboost classes. This keeps the code
Expand Down Expand Up @@ -186,6 +187,28 @@ def get_model(
return MultiOutputRegressor(base_model)
return base_model

if name.lower() == "random":
from sklearn.dummy import DummyClassifier, DummyRegressor

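# DummyClassifier/DummyRegressor act as naive baselines that ignore the input features
# Strategies for classification: "most_frequent", "prior", "stratified", "uniform", "constant"
# Strategies for regression: "mean", "median", "quantile", "constant"
# Note: random_state is only forwarded to DummyClassifier; it is dropped for regression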
if problem_type == "classification":
strategy = params.get("strategy", "stratified")
base_model = DummyClassifier(
strategy=strategy,
random_state=params.get("random_state", 0),
**{k: v for k, v in params.items() if k not in ["strategy", "random_state"]},
)
return ModelZoo.wrap_for_multioutput(base_model, problem_type, n_outputs)
else:
strategy = params.get("strategy", "mean")
base_model = DummyRegressor(
strategy=strategy,
**{k: v for k, v in params.items() if k not in ["strategy", "random_state"]},
)
return ModelZoo.wrap_for_multioutput(base_model, problem_type, n_outputs)

raise ValueError(f"Unknown model name: {name}")

@staticmethod
Expand Down
Loading