Skip to content

Commit 1930efd

Browse files
authored
Merge pull request #1946 from abhisrkckl/downhill
Change `StepProblem` and `MaxIterReached` into warnings
2 parents 3c72474 + da031b0 commit 1930efd

File tree

7 files changed

+26
-40
lines changed

7 files changed

+26
-40
lines changed

CHANGELOG-unreleased.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ the released changes.
99

1010
## Unreleased
1111
### Changed
12+
- Change `StepProblem` and `MaxIterReached` into warnings
1213
### Added
1314
- Anderson-Darling test for normal data with fixed mean/variance
1415
- KS test to check if the whitened residuals are unit-normal distributed

src/pint/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class DegeneracyWarning(UserWarning):
2727
pass
2828

2929

30-
class ConvergenceFailure(ValueError):
30+
class ConvergenceFailure(UserWarning):
3131
pass
3232

3333

src/pint/fitter.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@
6060

6161
import contextlib
6262
import copy
63+
from functools import cached_property
6364
from typing import Dict, List, Literal, Optional, Tuple, Union
6465
from warnings import warn
65-
from functools import cached_property
6666

6767
import astropy.units as u
6868
import numpy as np
@@ -72,21 +72,16 @@
7272
from numdifftools import Hessian
7373

7474
import pint
75-
from pint.models.timing_model import TimingModel
7675
from pint.exceptions import (
7776
ConvergenceFailure,
7877
CorrelatedErrors,
7978
DegeneracyWarning,
79+
InvalidModelParameters,
8080
MaxiterReached,
8181
StepProblem,
8282
)
83-
from pint.models.parameter import (
84-
AngleParameter,
85-
InvalidModelParameters,
86-
Parameter,
87-
boolParameter,
88-
strParameter,
89-
)
83+
from pint.models.parameter import AngleParameter, Parameter, boolParameter, strParameter
84+
from pint.models.timing_model import TimingModel
9085
from pint.pint_matrix import (
9186
CorrelationMatrix,
9287
CovarianceMatrix,
@@ -944,7 +939,7 @@ def _fit_toas(
944939
maxiter=20,
945940
required_chi2_decrease=1e-2,
946941
max_chi2_increase=1e-2,
947-
min_lambda=1e-3,
942+
min_lambda=1e-4,
948943
debug=False,
949944
) -> bool:
950945
"""Downhill fit implementation for fitting the timing model parameters.
@@ -955,12 +950,11 @@ def _fit_toas(
955950
# setup
956951
self.model.validate()
957952
self.model.validate_toas(self.toas)
953+
958954
current_state = self.create_state()
959955
best_state = current_state
960956
self.converged = False
961-
# algorithm
962957
exception = None
963-
964958
for i in range(maxiter):
965959
step = current_state.step
966960
lambda_ = 1
@@ -973,14 +967,10 @@ def _fit_toas(
973967
best_state = new_state
974968
if chi2_decrease < -max_chi2_increase:
975969
raise InvalidModelParameters(
976-
f"chi2 increased from {current_state.chi2} to {new_state.chi2} "
977-
f"when trying to take a step with lambda {lambda_}"
970+
f"chi2 increased from {current_state.chi2} to {new_state.chi2} when trying to take a step with lambda {lambda_}"
978971
)
979972
log.trace(
980-
f"Iteration {i}: "
981-
f"Updating state, chi2 goes down by {chi2_decrease} "
982-
f"from {current_state.chi2} "
983-
f"to {new_state.chi2}"
973+
f"Iteration {i}: Updating state, chi2 goes down by {chi2_decrease} from {current_state.chi2} to {new_state.chi2}"
984974
)
985975
exception = None
986976
current_state = new_state
@@ -989,13 +979,9 @@ def _fit_toas(
989979
# This could be an exception evaluating new_state.chi2 or an increase in value
990980
# If bad parameter values escape, look in ModelState.resids for the except
991981
# that should catch them
992-
lambda_ /= 2
982+
lambda_ /= 1.5
993983
log.trace(f"Iteration {i}: Shortening step to {lambda_}: {e}")
994984
if lambda_ < min_lambda:
995-
log.warning(
996-
f"Unable to improve chi2 even with very small steps, stopping "
997-
f"but keeping best state, message was: {e}"
998-
)
999985
exception = e
1000986
break
1001987
if (
@@ -1042,11 +1028,11 @@ def _fit_toas(
10421028
self.update_model(self.current_state.chi2)
10431029

10441030
if exception is not None:
1045-
raise StepProblem(
1046-
"Unable to improve chi2 even with very small steps"
1047-
) from exception
1031+
warn("Unable to improve chi2 even with very small steps", StepProblem)
1032+
return False
1033+
10481034
if not self.converged:
1049-
raise MaxiterReached(f"Convergence not detected after {maxiter} steps.")
1035+
warn(f"Convergence not detected after {maxiter} steps.", MaxiterReached)
10501036

10511037
return self.converged
10521038

src/pint/models/parameter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from uncertainties import ufloat
3434

3535
from pint import pint_units
36-
from pint.exceptions import InvalidModelParameters
3736
from pint.models import priors
3837
from pint.observatory import get_observatory
3938
from pint.pulsar_mjd import (

src/pint/models/stand_alone_psr_binaries/DD_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from pint import Tsun
8-
from pint.models.parameter import InvalidModelParameters
8+
from pint.exceptions import InvalidModelParameters
99

1010
from .binary_generic import PSR_BINARY
1111

src/pint/models/stand_alone_psr_binaries/ELL1_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import astropy.units as u
55
import numpy as np
66

7-
from pint.models.parameter import InvalidModelParameters
7+
from pint.exceptions import InvalidModelParameters
88

99
from .binary_generic import PSR_BINARY
1010

tests/test_downhill_fitter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,15 @@ def test_wls_two_step(model_eccentric_toas):
177177

178178
f = pint.fitter.DownhillWLSFitter(toas, model_wrong)
179179
f.model.free_params = ["ECC"]
180-
with pytest.raises(pint.fitter.MaxiterReached):
180+
with pytest.warns(pint.fitter.MaxiterReached):
181181
f.fit_toas(maxiter=2)
182182
assert not f.converged
183183

184184
f2 = pint.fitter.DownhillWLSFitter(toas, model_wrong)
185185
f2.model.free_params = ["ECC"]
186-
with pytest.raises(pint.fitter.MaxiterReached):
186+
with pytest.warns(pint.fitter.MaxiterReached):
187187
f2.fit_toas(maxiter=1)
188-
with pytest.raises(pint.fitter.MaxiterReached):
188+
with pytest.warns(pint.fitter.MaxiterReached):
189189
f2.fit_toas(maxiter=1)
190190
assert np.abs(f.model.ECC.value - f2.model.ECC.value) < 1e-12
191191

@@ -198,14 +198,14 @@ def test_gls_two_step(model_eccentric_toas_ecorr, full_cov):
198198

199199
f = pint.fitter.DownhillGLSFitter(toas, model_wrong)
200200
f.model.free_params = ["ECC"]
201-
with pytest.raises(pint.fitter.MaxiterReached):
201+
with pytest.warns(pint.fitter.MaxiterReached):
202202
f.fit_toas(maxiter=2, full_cov=full_cov)
203203
assert not f.converged
204204
f2 = pint.fitter.DownhillGLSFitter(toas, model_wrong)
205205
f2.model.free_params = ["ECC"]
206-
with pytest.raises(pint.fitter.MaxiterReached):
206+
with pytest.warns(pint.fitter.MaxiterReached):
207207
f2.fit_toas(maxiter=1, full_cov=full_cov)
208-
with pytest.raises(pint.fitter.MaxiterReached):
208+
with pytest.warns(pint.fitter.MaxiterReached):
209209
f2.fit_toas(maxiter=1, full_cov=full_cov)
210210
assert np.abs(f.model.ECC.value - f2.model.ECC.value) < 1e-12
211211

@@ -218,14 +218,14 @@ def test_wb_two_step(model_eccentric_toas_wb, full_cov):
218218

219219
f = pint.fitter.WidebandDownhillFitter(toas, model_wrong)
220220
f.model.free_params = ["ECC"]
221-
with pytest.raises(pint.fitter.MaxiterReached):
221+
with pytest.warns(pint.fitter.MaxiterReached):
222222
f.fit_toas(maxiter=2, full_cov=full_cov)
223223
assert not f.converged
224224
f2 = pint.fitter.WidebandDownhillFitter(toas, model_wrong)
225225
f2.model.free_params = ["ECC"]
226-
with pytest.raises(pint.fitter.MaxiterReached):
226+
with pytest.warns(pint.fitter.MaxiterReached):
227227
f2.fit_toas(maxiter=1, full_cov=full_cov)
228-
with pytest.raises(pint.fitter.MaxiterReached):
228+
with pytest.warns(pint.fitter.MaxiterReached):
229229
f2.fit_toas(maxiter=1, full_cov=full_cov)
230230
# FIXME: The full_cov version differs at the 1e-10 level for some reason, is it a failure really?
231231
assert np.abs(f.model.ECC.value - f2.model.ECC.value) < 1e-9

0 commit comments

Comments
 (0)