Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ All notable changes to this project will be documented in this file.
- [#1508](https://github.com/pints-team/pints/pull/1508) The methods `OptimisationController.max_unchanged_iterations` and `set_max_unchanged_iterations` are deprecated, in favour of `function_tolerance` and `set_function_tolerance` respectively.
### Removed
### Fixed
- [#1729](https://github.com/pints-team/pints/pull/1729) The `rhat` method now raises an error if only 1 chain is passed in.
- [#1713](https://github.com/pints-team/pints/pull/1713) Fixed Numpy 2.4.1 compatibility issues.
- [#1690](https://github.com/pints-team/pints/pull/1690) Fixed bug in optimisation controller if population size left at `None`.

Expand Down
22 changes: 11 additions & 11 deletions pints/_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,28 +201,28 @@
"""
if not (chains.ndim == 2 or chains.ndim == 3):
raise ValueError(
'Dimension of chains is %d. ' % chains.ndim
+ 'Method computes Rhat for one '
'or multiple parameters and therefore only accepts 2 or 3 '
'dimensional arrays.')
f'Dimension of chains is {chains.ndim}. This method computes Rhat'
' for one or multiple parameters and therefore only accepts 2 or 3'
' dimensional arrays.')
if warm_up > 1 or warm_up < 0:
raise ValueError(
'`warm_up` is set to %f. `warm_up` only takes values in [0,1].' %
warm_up)
f'`warm_up` is set to {warm_up}. `warm_up` only takes values in'
' [0,1].')
if chains.shape[0] < 2:
raise ValueError('Number of chains needs to be 2 or higher')

Check failure on line 212 in pints/_diagnostics.py

View workflow job for this annotation

GitHub Actions / Coverage

Number of chains needs to be 2 or higher

Check failure on line 212 in pints/_diagnostics.py

View workflow job for this annotation

GitHub Actions / Python unit tests (3.8)

Number of chains needs to be 2 or higher

Check failure on line 212 in pints/_diagnostics.py

View workflow job for this annotation

GitHub Actions / OS unit tests (ubuntu-22.04)

Number of chains needs to be 2 or higher

Check failure on line 212 in pints/_diagnostics.py

View workflow job for this annotation

GitHub Actions / Python unit tests (3.13)

Number of chains needs to be 2 or higher

Check failure on line 212 in pints/_diagnostics.py

View workflow job for this annotation

GitHub Actions / Python unit tests (3.9)

Number of chains needs to be 2 or higher

# Get number of samples
n = chains.shape[1]
if n < 2:
raise ValueError(
'Number of samples per chain after warm-up and chain splitting is '
f'{n}. Method needs at least 2 samples per chain.')

# Exclude warm-up
chains = chains[:, int(n * warm_up):]
n = chains.shape[1]

# Split chains in half
n = n // 2 # new length of chains
if n < 1:
raise ValueError(
'Number of samples per chain after warm-up and chain splitting is '
'%d. Method needs at least 1 sample per chain.' % n)
chains = np.vstack([chains[:, :n], chains[:, -n:]])

# Compute mean within-chain variance
Expand Down
59 changes: 27 additions & 32 deletions pints/tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
# copyright notice and full license details.
#
import unittest
import pints
import warnings

import numpy as np

import pints
import pints._diagnostics


Expand Down Expand Up @@ -58,7 +61,7 @@ def test_effective_sample_size(self):
# matrix with two columns of samples
x = np.transpose(np.array([[1.0, 1.1, 1.4, 1.3, 1.3],
[1.0, 2.0, 3.0, 4.0, 5.0]]))
y = pints._diagnostics.effective_sample_size(x)
y = pints.effective_sample_size(x)
self.assertAlmostEqual(y[0], 1.439232, 6)
self.assertAlmostEqual(y[1], 1.315789, 6)

Expand Down Expand Up @@ -91,7 +94,7 @@ def test_rhat(self):
chains = np.array([[1.0, 1.1, 1.4, 1.3],
[1.0, 2.0, 3.0, 4.0]])
self.assertAlmostEqual(
pints._diagnostics.rhat(chains), 2.3303847470550716, 6)
pints.rhat(chains), 2.3303847470550716, 6)

# Test Rhat computation for two parameters, chains.shape=(3, 4, 2)
chains = np.array([
Expand All @@ -114,7 +117,7 @@ def test_rhat(self):
[0.89531238, 0.63207977]
]])

y = pints._diagnostics.rhat(chains)
y = pints.rhat(chains)
d = np.array(y) - np.array([0.84735944450487122, 1.1712652416950846])
self.assertLess(np.linalg.norm(d), 0.01)

Expand All @@ -124,40 +127,32 @@ def test_bad_rhat_inputs(self):

# Pass chain of dimension 1
chains = np.empty(shape=1)
message = (
'Dimension of chains is 1. '
+ 'Method computes Rhat for one '
'or multiple parameters and therefore only accepts 2 or 3 '
'dimensional arrays.')
self.assertRaisesRegex(
ValueError, message[0], pints.rhat, chains)
ValueError, 'only accepts 2 or 3 dimensional', pints.rhat, chains)

# Pass chain of dimension 4
chains = np.empty(shape=(1, 1, 1, 1))
message = (
'Dimension of chains is 4. '
+ 'Method computes Rhat for one '
'or multiple parameters and therefore only accepts 2 or 3 '
'dimensional arrays.')
self.assertRaisesRegex(
ValueError, message[0], pints.rhat, chains)
ValueError, 'only accepts 2 or 3 dimensional', pints.rhat, chains)

# Pass only a single chain
chains = np.empty(shape=(1, 5))
self.assertRaisesRegex(
ValueError, '2 or higher', pints.rhat, chains)

# Pass only a single sample
chains = np.empty(shape=(5, 1))
self.assertRaisesRegex(
ValueError, 'at least 2 samples', pints.rhat, chains)

# Pass bad warm-up arguments
chains = np.empty(shape=(2, 4))

# warm-up greater than 100%
warm_up = 1.1
message = (
'`warm_up` is set to 1.1. `warm_up` only takes values in [0,1].')
# warm-up greater than 100% or negative
self.assertRaisesRegex(
ValueError, message[0], pints.rhat, chains, warm_up)

# Negative warm-up
warm_up = -0.1
message = (
'`warm_up` is set to -0.1. `warm_up` only takes values in [0,1].')
ValueError, r'takes values in \[0,1\]', pints.rhat, chains, 1.1)
self.assertRaisesRegex(
ValueError, message[0], pints.rhat, chains, warm_up)
ValueError, r'takes values in \[0,1\]', pints.rhat, chains, -0.1)

# Pass chains with too little samples (n<4)
chains = np.empty(shape=(1, 4))
Expand All @@ -168,8 +163,7 @@ def test_bad_rhat_inputs(self):
self.assertRaisesRegex(
ValueError, message[0], pints.rhat, chains, warm_up)

def test_rhat_all_params(self):
# Tests that rhat_all works
def test_rhat_deprecated_alias(self):

x = np.array([[[-1.10580535, 2.26589882],
[0.35604827, 1.03523364],
Expand All @@ -184,9 +178,10 @@ def test_rhat_all_params(self):
[0.92272047, -1.49997615],
[0.89531238, 0.63207977]]])

y = pints._diagnostics.rhat_all_params(x)
d = np.array(y) - np.array([0.84735944450487122, 1.1712652416950846])
self.assertLess(np.linalg.norm(d), 0.01)
with warnings.catch_warnings(record=True) as w:
z = pints.rhat_all_params(x)
self.assertIn('deprecated', str(w[-1].message))
self.assertEqual(list(pints.rhat(x)), list(z))


if __name__ == '__main__':
Expand Down
11 changes: 4 additions & 7 deletions pints/tests/test_mcmc_summary.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#!/usr/bin/env python3
#
# Tests the basic methods of the adaptive covariance base class.
#
##
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
Expand All @@ -13,10 +11,8 @@
import pints.toy as toy


class TestAdaptiveCovarianceMC(unittest.TestCase):
"""
Tests the basic methods of the adaptive covariance MCMC routine.
"""
class TestMCMCSummary(unittest.TestCase):
""" Tests the MCMCSummary class. """

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -151,6 +147,7 @@ def test_ess_per_second(self):

def test_named_parameters(self):
# tests that parameter names are used when values supplied

parameters = ['rrrr', 'kkkk', 'ssss']
results = pints.MCMCSummary(
self.chains, parameter_names=parameters)
Expand Down
Loading