Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4cafda8
add fair energy score
sallen12 Jul 9, 2025
921f0c3
add fair variogram score
sallen12 Jul 9, 2025
025f104
add fair gufuncs for energy and variogram scores
sallen12 Jul 9, 2025
72fbf0f
add akr estimators for energy score in numba
sallen12 Jul 9, 2025
aa1e22d
fix bug in fair gks
sallen12 Jul 9, 2025
63d294b
update documentation of fair kernel scores
sallen12 Jul 9, 2025
ddda075
add akr estimators for the Gaussian kernel score
sallen12 Jul 10, 2025
0c5a595
estimator bug fix in weighted crps tests
sallen12 Jul 10, 2025
9562a6a
fix bug in fair kernel scores
sallen12 Jul 24, 2025
67b8882
fix bugs in variogram scores
sallen12 Jul 24, 2025
0b41c15
set p = 0.5 as default in variogram scores
sallen12 Jul 24, 2025
a70cecd
fix bug in ow and vr energy score gufuncs
sallen12 Jul 24, 2025
b96a551
reduce tolerance of vres vs es test
sallen12 Jul 24, 2025
6b6c734
add roll function to backends
sallen12 Jul 27, 2025
389af17
add crps estimators for non-numba backends
sallen12 Jul 27, 2025
e0dbe4b
add test of equality of crps estimators
sallen12 Jul 27, 2025
e931f72
add akr estimators for energy score for non-numba backends
sallen12 Jul 27, 2025
f127713
remove TODO in energy score test
sallen12 Jul 27, 2025
1f7d131
fix bug in torch crps estimator test
sallen12 Jul 27, 2025
5e1c009
add akr estimators for gaussian kernel scores for non-numba backends
sallen12 Jul 27, 2025
c356489
update default crps estimator
sallen12 Jul 27, 2025
31b7148
Merge branch 'main' into mv_scores
sallen12 Aug 4, 2025
abaf2b3
change tolerance of vertically rescaled variogram score test
sallen12 Aug 4, 2025
ea31940
add crps int estimator for non-numba backends
sallen12 Aug 5, 2025
d96247b
add tests for energy score estimators
sallen12 Sep 1, 2025
5475c0e
add tests for fair variogram score
sallen12 Sep 1, 2025
dbd5164
add tests for kernel score estimators
sallen12 Sep 1, 2025
edaeee0
Merge branch 'main' into mv_scores
sallen12 Sep 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/crps_estimators.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,26 @@ $$

Some examples are given below.

More generally, for any positive definite kernel $k$, we have that

$$ \mathbb{E} S_{k}(F_{M}, y) = \mathbb{E} S_{k}(F, y) + \frac{1}{2M} \left( \mathbb{E} k(x_{1}, x_{1}) - \mathbb{E} k(x_{1}, x_{2}) \right). $$

Using this, a fair version of the kernel score is

$$
S_{k}^{f}(F_{M}, y) = \frac{1}{M(M - 1)} \sum_{i=1}^{M-1} \sum_{j=i+1}^{M} k(x_{i}, x_{j}) + \frac{1}{2} k(y, y) - \frac{1}{M} \sum_{i=1}^{M} k(x_{i}, y).
$$

The first term on the right-hand side is simply the mean of $k(x_i, x_j)$ over all $i,j = 1, \dots, M$ such that $i \neq j$.

For kernels defined in terms of conditionally negative definite functions, as described above, we have that

$$
\mathbb{E} k(x_{1}, x_{1}) - \mathbb{E} k(x_{1}, x_{2}) = 2 \mathbb{E} \rho(x_1, x_{0}) - \mathbb{E} \left[ \rho(x_1, x_0) + \rho(x_2, x_{0}) - \rho(x_1, x_2) \right] = \mathbb{E} \rho(x_1, x_2),
$$

showing that we recover the previous results in this case.


#### CRPS

Expand Down
24 changes: 4 additions & 20 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def crps_ensemble(
m_axis: int = -1,
*,
sorted_ensemble: bool = False,
estimator: str = "pwm",
estimator: str = "qd",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the Continuous Ranked Probability Score (CRPS) for a finite ensemble.
Expand Down Expand Up @@ -130,7 +130,7 @@ def twcrps_ensemble(
m_axis: int = -1,
*,
v_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None,
estimator: str = "pwm",
estimator: str = "qd",
sorted_ensemble: bool = False,
backend: "Backend" = None,
) -> "Array":
Expand Down Expand Up @@ -231,7 +231,6 @@ def owcrps_ensemble(
m_axis: int = -1,
*,
w_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None,
estimator: tp.Literal["nrg"] = "nrg",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the outcome-weighted CRPS (owCRPS) for a finite ensemble.
Expand Down Expand Up @@ -308,11 +307,6 @@ def owcrps_ensemble(
B = backends.active if backend is None else backends[backend]
obs, fct = map(B.asarray, (obs, fct))

if estimator != "nrg":
raise ValueError(
"Only the energy form of the estimator is available "
"for the outcome-weighted CRPS."
)
if m_axis != -1:
fct = B.moveaxis(fct, m_axis, -1)

Expand All @@ -325,9 +319,7 @@ def w_func(x):
obs_weights, fct_weights = map(B.asarray, (obs_weights, fct_weights))

if backend == "numba":
return crps.estimator_gufuncs["ow" + estimator](
obs, fct, obs_weights, fct_weights
)
return crps.estimator_gufuncs["ownrg"](obs, fct, obs_weights, fct_weights)

return crps.ow_ensemble(obs, fct, obs_weights, fct_weights, backend=backend)

Expand All @@ -341,7 +333,6 @@ def vrcrps_ensemble(
m_axis: int = -1,
*,
w_func: tp.Callable[["ArrayLike"], "ArrayLike"] = None,
estimator: tp.Literal["nrg"] = "nrg",
backend: "Backend" = None,
) -> "Array":
r"""Estimate the vertically re-scaled CRPS (vrCRPS) for a finite ensemble.
Expand Down Expand Up @@ -415,11 +406,6 @@ def vrcrps_ensemble(
B = backends.active if backend is None else backends[backend]
obs, fct = map(B.asarray, (obs, fct))

if estimator != "nrg":
raise ValueError(
"Only the energy form of the estimator is available "
"for the outcome-weighted CRPS."
)
if m_axis != -1:
fct = B.moveaxis(fct, m_axis, -1)

Expand All @@ -432,9 +418,7 @@ def w_func(x):
obs_weights, fct_weights = map(B.asarray, (obs_weights, fct_weights))

if backend == "numba":
return crps.estimator_gufuncs["vr" + estimator](
obs, fct, obs_weights, fct_weights
)
return crps.estimator_gufuncs["vrnrg"](obs, fct, obs_weights, fct_weights)

return crps.vr_ensemble(obs, fct, obs_weights, fct_weights, backend=backend)

Expand Down
27 changes: 20 additions & 7 deletions scoringrules/_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def es_ensemble(
m_axis: int = -2,
v_axis: int = -1,
*,
estimator: str = "nrg",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Energy Score for a finite multivariate ensemble.
Expand All @@ -39,6 +40,8 @@ def es_ensemble(
v_axis : int
The axis corresponding to the variables dimension on the forecasts array (or the observations
array with an extra dimension on `m_axis`). Defaults to -1.
estimator : str
The energy score estimator to be used.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -63,9 +66,14 @@ def es_ensemble(
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)

if backend == "numba":
return energy._energy_score_gufunc(obs, fct)
if estimator not in energy.estimator_gufuncs:
raise ValueError(
f"{estimator} is not a valid estimator. "
f"Must be one of {energy.estimator_gufuncs.keys()}"
)
return energy.estimator_gufuncs[estimator](obs, fct)

return energy.nrg(obs, fct, backend=backend)
return energy.es(obs, fct, estimator=estimator, backend=backend)


def twes_ensemble(
Expand All @@ -76,6 +84,7 @@ def twes_ensemble(
m_axis: int = -2,
v_axis: int = -1,
*,
estimator: str = "nrg",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Threshold-Weighted Energy Score (twES) for a finite multivariate ensemble.
Expand Down Expand Up @@ -105,6 +114,8 @@ def twes_ensemble(
The axis corresponding to the ensemble dimension. Defaults to -2.
v_axis : int or tuple of int
The axis corresponding to the variables dimension. Defaults to -1.
estimator : str
The energy score estimator to be used.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -114,7 +125,9 @@ def twes_ensemble(
The computed Threshold-Weighted Energy Score.
"""
obs, fct = map(v_func, (obs, fct))
return es_ensemble(obs, fct, m_axis=m_axis, v_axis=v_axis, backend=backend)
return es_ensemble(
obs, fct, m_axis=m_axis, v_axis=v_axis, estimator=estimator, backend=backend
)


def owes_ensemble(
Expand Down Expand Up @@ -173,9 +186,9 @@ def owes_ensemble(
obs_weights = B.apply_along_axis(w_func, obs, -1)

if B.name == "numba":
return energy._owenergy_score_gufunc(obs, fct, obs_weights, fct_weights)
return energy.estimator_gufuncs["ownrg"](obs, fct, obs_weights, fct_weights)

return energy.ownrg(obs, fct, obs_weights, fct_weights, backend=backend)
return energy.owes(obs, fct, obs_weights, fct_weights, backend=backend)


def vres_ensemble(
Expand Down Expand Up @@ -235,6 +248,6 @@ def vres_ensemble(
obs_weights = B.apply_along_axis(w_func, obs, -1)

if backend == "numba":
return energy._vrenergy_score_gufunc(obs, fct, obs_weights, fct_weights)
return energy.estimator_gufuncs["vrnrg"](obs, fct, obs_weights, fct_weights)

return energy.vrnrg(obs, fct, obs_weights, fct_weights, backend=backend)
return energy.vres(obs, fct, obs_weights, fct_weights, backend=backend)
31 changes: 12 additions & 19 deletions scoringrules/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,17 @@ def gksuv_ensemble(
B = backends.active if backend is None else backends[backend]
obs, fct = map(B.asarray, (obs, fct))

if m_axis != -1:
fct = B.moveaxis(fct, m_axis, -1)

if backend == "numba":
if estimator not in kernels.estimator_gufuncs:
raise ValueError(
f"{estimator} is not a valid estimator. "
f"Must be one of {kernels.estimator_gufuncs.keys()}"
)
else:
if estimator not in ["fair", "nrg"]:
raise ValueError(
f"{estimator} is not a valid estimator. "
f"Must be one of ['fair', 'nrg']"
)

if m_axis != -1:
fct = B.moveaxis(fct, m_axis, -1)

if backend == "numba":
return kernels.estimator_gufuncs[estimator](obs, fct)
else:
return kernels.estimator_gufuncs[estimator](obs, fct)

return kernels.ensemble_uv(obs, fct, estimator, backend=backend)

Expand Down Expand Up @@ -386,14 +379,14 @@ def gksmv_ensemble(
backend = backend if backend is not None else backends._active
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)

if estimator not in kernels.estimator_gufuncs_mv:
raise ValueError(
f"{estimator} is not a valid estimator. "
f"Must be one of {kernels.estimator_gufuncs_mv.keys()}"
)

if backend == "numba":
return kernels.estimator_gufuncs_mv[estimator](obs, fct)
if estimator not in kernels.estimator_gufuncs_mv:
raise ValueError(
f"{estimator} is not a valid estimator. "
f"Must be one of {kernels.estimator_gufuncs_mv.keys()}"
)
else:
return kernels.estimator_gufuncs_mv[estimator](obs, fct)

return kernels.ensemble_mv(obs, fct, estimator, backend=backend)

Expand Down
25 changes: 18 additions & 7 deletions scoringrules/_variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def vs_ensemble(
m_axis: int = -2,
v_axis: int = -1,
*,
p: float = 1.0,
p: float = 0.5,
estimator: str = "nrg",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Variogram Score for a finite multivariate ensemble.
Expand All @@ -42,6 +43,8 @@ def vs_ensemble(
The axis corresponding to the ensemble dimension. Defaults to -2.
v_axis : int
The axis corresponding to the variables dimension. Defaults to -1.
estimator : str
The variogram score estimator to be used.
backend: str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand Down Expand Up @@ -69,9 +72,12 @@ def vs_ensemble(
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)

if backend == "numba":
return variogram._variogram_score_gufunc(obs, fct, p)
if estimator == "nrg":
return variogram._variogram_score_nrg_gufunc(obs, fct, p)
elif estimator == "fair":
return variogram._variogram_score_fair_gufunc(obs, fct, p)

return variogram.vs(obs, fct, p, backend=backend)
return variogram.vs(obs, fct, p, estimator=estimator, backend=backend)


def twvs_ensemble(
Expand All @@ -82,7 +88,8 @@ def twvs_ensemble(
m_axis: int = -2,
v_axis: int = -1,
*,
p: float = 1.0,
p: float = 0.5,
estimator: str = "nrg",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Threshold-Weighted Variogram Score (twVS) for a finite multivariate ensemble.
Expand Down Expand Up @@ -111,6 +118,8 @@ def twvs_ensemble(
The axis corresponding to the ensemble dimension. Defaults to -2.
v_axis : int
The axis corresponding to the variables dimension. Defaults to -1.
estimator : str
The variogram score estimator to be used.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -137,7 +146,9 @@ def twvs_ensemble(
array([5.94996894, 4.72029765, 6.08947229])
"""
obs, fct = map(v_func, (obs, fct))
return vs_ensemble(obs, fct, m_axis, v_axis, p=p, backend=backend)
return vs_ensemble(
obs, fct, m_axis, v_axis, p=p, estimator=estimator, backend=backend
)


def owvs_ensemble(
Expand All @@ -148,7 +159,7 @@ def owvs_ensemble(
m_axis: int = -2,
v_axis: int = -1,
*,
p: float = 1.0,
p: float = 0.5,
backend: "Backend" = None,
) -> "Array":
r"""Compute the Outcome-Weighted Variogram Score (owVS) for a finite multivariate ensemble.
Expand Down Expand Up @@ -223,7 +234,7 @@ def vrvs_ensemble(
m_axis: int = -2,
v_axis: int = -1,
*,
p: float = 1.0,
p: float = 0.5,
backend: "Backend" = None,
) -> "Array":
r"""Compute the Vertically Re-scaled Variogram Score (vrVS) for a finite multivariate ensemble.
Expand Down
4 changes: 4 additions & 0 deletions scoringrules/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ def size(self, x: "Array") -> int:
def indices(self, x: "Array") -> int:
"""Return an array representing the indices of a grid."""

@abc.abstractmethod
def roll(self, x: "Array", shift: int = 1, axis: int = -1) -> int:
"""Roll elements of an array along a given axis."""

@abc.abstractmethod
def inv(self, x: "Array") -> "Array":
"""Return the inverse of a matrix."""
Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def size(self, x: "Array") -> int:
def indices(self, dimensions: tuple) -> "Array":
return jnp.indices(dimensions)

def roll(self, x: "Array", shift: int = 1, axis: int = -1) -> "Array":
return jnp.roll(x, shift=shift, axis=axis)

def inv(self, x: "Array") -> "Array":
return jnp.linalg.inv(x)

Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def size(self, x: "NDArray") -> int:
def indices(self, dimensions: tuple) -> "NDArray":
return np.indices(dimensions)

def roll(self, x: "NDArray", shift: int = 1, axis: int = -1) -> "NDArray":
return np.roll(x, shift=shift, axis=axis)

def inv(self, x: "NDArray") -> "NDArray":
return np.linalg.inv(x)

Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ def indices(self, dimensions: tuple) -> "Tensor":
indices = tf.stack(index_grids)
return indices

def roll(self, x: "Tensor", shift: int = 1, axis: int = -1) -> "Tensor":
return tf.roll(x, shift=shift, axis=axis)

def inv(self, x: "Tensor") -> "Tensor":
return tf.linalg.inv(x)

Expand Down
3 changes: 3 additions & 0 deletions scoringrules/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ def indices(self, dimensions: tuple) -> "Tensor":
indices = torch.stack(index_grids)
return indices

def roll(self, x: "Tensor", shift: int = 1, axis: int = -1) -> "Tensor":
return torch.roll(x, shifts=shift, dims=axis)

def inv(self, x: "Tensor") -> "Tensor":
return torch.linalg.inv(x)

Expand Down
Loading
Loading