diff --git a/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py b/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py index 83404ab00..48f3924fd 100644 --- a/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py +++ b/Wrappers/Python/cil/optimisation/algorithms/Algorithm.py @@ -29,19 +29,17 @@ class Algorithm: r"""Base class providing minimal infrastructure for iterative algorithms. An iterative algorithm is designed to solve an optimization problem by repeatedly refining a solution. In CIL, we use iterative algorithms to minimize an objective function, often referred to as a loss. The process begins with an initial guess, and with each iteration, the algorithm updates the current solution based on the results of previous iterations (previous iterates). Iterative algorithms typically continue until a stopping criterion is met, indicating that an optimal or sufficiently good solution has been found. In CIL, stopping criteria can be implemented using a callback function (`cil.optimisation.utilities.callbacks`). - + The user is required to implement the :code:`set_up`, :code:`__init__`, :code:`update` and :code:`update_objective` methods. The method :code:`run` is available to run :code:`n` iterations. The method accepts :code:`callbacks`: a list of callables, each of which receive the current Algorithm object (which in turn contains the iteration number and the actual objective value) and can be used to trigger print to screens and other user interactions. The :code:`run` method will stop when the stopping criterion is met or `StopIteration` is raised. - - Parameters - ---------- - update_objective_interval: int, optional, default 1 - The objective (or loss) is calculated and saved every `update_objective_interval`. 1 means every iteration, 2 every 2 iterations and so forth. This is by default 1 and should be increased when evaluating the objective is computationally expensive. """ - def __init__(self, update_objective_interval=1, max_iteration=None, log_file=None): - + def __init__(self, update_objective_interval=None, max_iteration=None, log_file=None): + if update_objective_interval is None: + update_objective_interval = 1 + else: + warn("use `Algorithm.run(update_objective_interval)` instead of `update_objective_interval`", DeprecationWarning, stacklevel=2) self.iteration = -1 self.__max_iteration = 1 if max_iteration is not None: @@ -223,31 +221,30 @@ def update_objective_interval(self): @update_objective_interval.setter def update_objective_interval(self, value): '''sets the update_objective_interval''' - if not isinstance(value, Integral) or value < 0: + if not ((isinstance(value, Integral) and value >= 0) or np.isposinf(value)): raise ValueError('interval must be an integer >= 0') self.__update_objective_interval = value - def run(self, iterations=None, callbacks: Optional[List[Callback]]=None, verbose=1, **kwargs): - r"""run upto :code:`iterations` with callbacks/logging. - + def run(self, iterations: int, update_objective_interval=None, callbacks: Optional[List[Callback]]=None, verbose=1, **kwargs): + """run upto :code:`iterations` with callbacks/logging. + For a demonstration of callbacks see https://github.com/TomographicImaging/CIL-Demos/blob/main/misc/callback_demonstration.ipynb Parameters ----------- - iterations: int, default is None + iterations: int Number of iterations to run. If a positive infinity is passed, the algorithm will run indefinitely until a callback raises `StopIteration`. + update_objective_interval: int, optional, default 1 + The objective (or loss) is calculated and saved every `update_objective_interval`. 1 means every iteration, 2 every 2 iterations and so forth. This is by default 1 and should be increased when evaluating the objective is computationally expensive. callbacks: list of callables, default is Defaults to :code:`[ProgressCallback(verbose)]` List of callables which are passed the current Algorithm object each iteration. Defaults to :code:`[ProgressCallback(verbose)]`. verbose: 0=quiet, 1=info, 2=debug - Passed to the default callback to determine the verbosity of the printed output. + Passed to the default callback to determine the verbosity of the printed output. """ - - if iterations is None: - raise ValueError("`run()` missing number of `iterations`") - + if update_objective_interval is not None: + self.update_objective_interval = update_objective_interval if 'print_interval' in kwargs: - warn("use `TextProgressCallback(miniters)` instead of `run(print_interval)`", - DeprecationWarning, stacklevel=2) + warn("use `TextProgressCallback(miniters)` instead of `run(print_interval)`", DeprecationWarning, stacklevel=2) if np.isposinf(iterations): if callbacks is None: raise ValueError("Infinite iterations require a callback with a stopping criterion that raises `StopIteration`")