diff --git a/CHANGES.rst b/CHANGES.rst index d9842a8dc..a107cfc07 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -41,6 +41,12 @@ New Features - Added ``__repr__`` methods to ``ImagePSF`` and ``GriddedPSFModel``. [#2134] + - Added a ``shape`` property to ``ImagePSF``. [#2158] + + - ``EPSFBuilder`` now automatically excludes stars that repeatedly + fail fitting and emits warnings with specific failure reasons. + [#2158] + Bug Fixes ^^^^^^^^^ @@ -89,6 +95,21 @@ API Changes - Removed the ``ModelGridPlotMixin`` class. [#2137] + - Removed the ``norm_radius`` keyword from ``EPSFBuilder``. [#2158] + + - Removed the ``build_epsf`` method from ``EPSFBuilder``. Use the + callable interface (``builder(stars)``) instead. [#2158] + + - Removed the deprecated ``FittableImageModel`` and ``EPSFModel`` + classes. Use ``ImagePSF`` instead. [#2158] + + - ``EPSFBuilder`` now returns an ``EPSFBuildResult`` dataclass + containing the ePSF, fitted stars, iteration count, convergence + status, and excluded star diagnostics. Tuple unpacking is still + supported for backward compatibility. [#2158] + + - ``LinkedEPSFStar`` no longer inherits from ``EPSFStars``. [#2158] + 2.3.0 (2025-09-15) ------------------ diff --git a/docs/conf.py b/docs/conf.py index 089d9e546..a7662792b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -74,8 +74,13 @@ extensions += [ 'sphinx_design', + 'sphinx_reredirects', ] +redirects = { + 'user_guide/epsf': 'epsf_building.html', +} + # This is added to the end of RST files - a good place to put # substitutions to be used globally. rst_epilog = """ diff --git a/docs/user_guide/epsf.rst b/docs/user_guide/epsf.rst deleted file mode 100644 index 6378cdcd3..000000000 --- a/docs/user_guide/epsf.rst +++ /dev/null @@ -1,315 +0,0 @@ -.. _build-epsf: - -Building an effective Point Spread Function (ePSF) -================================================== - -The ePSF --------- - -The instrumental PSF is a combination of many factors that are -generally difficult to model. 
`Anderson and King (2000; PASP 112, -1360) -`_ -showed that accurate stellar photometry and astrometry can be derived -by modeling the net PSF, which they call the effective PSF (ePSF). -The ePSF is an empirical model describing what fraction of a star's -light will land in a particular pixel. The constructed ePSF is -typically oversampled with respect to the detector pixels. - - -Building an ePSF ----------------- - -Photutils provides tools for building an ePSF following the -prescription of `Anderson and King (2000; PASP 112, 1360) -`_ -and subsequent enhancements detailed mainly -in `Anderson (2016; WFC3 ISR 2016-12) -`_. The -process involves iterating between the ePSF itself and the stars used to -build it. - -To begin, we must first define a large (e.g., several hundred) -sample of stars used to build the ePSF. Ideally these stars -should be bright (high S/N) and isolated to prevent contamination -from nearby stars. One may use the star-finding tools in -Photutils (e.g., :class:`~photutils.detection.DAOStarFinder` or -:class:`~photutils.detection.IRAFStarFinder`) to identify an initial -sample of stars. However, the step of creating a good sample of stars -generally requires visual inspection and manual selection to ensure -stars are sufficiently isolated and of good quality (e.g., no cosmic -rays, detector artifacts, etc.). To produce a good ePSF, one should have -a large sample (e.g., several hundred) of stars in order to fully sample -the PSF over the oversampled grid and to help reduce the effects of -noise. Otherwise, the resulting ePSF may have holes or may be noisy. 
- -Let's start by loading a simulated HST/WFC3 image in the F160W band:: - - >>> from photutils.datasets import load_simulated_hst_star_image - >>> hdu = load_simulated_hst_star_image() # doctest: +REMOTE_DATA - >>> data = hdu.data # doctest: +REMOTE_DATA - -The simulated image does not contain any background or noise, so let's add -those to the image:: - - >>> from photutils.datasets import make_noise_image - >>> data += make_noise_image(data.shape, distribution='gaussian', - ... mean=10.0, stddev=5.0, seed=123) # doctest: +REMOTE_DATA - -Let's show the image: - -.. plot:: - :include-source: - - import matplotlib.pyplot as plt - from astropy.visualization import simple_norm - from photutils.datasets import (load_simulated_hst_star_image, - make_noise_image) - - hdu = load_simulated_hst_star_image() - data = hdu.data - data += make_noise_image(data.shape, distribution='gaussian', mean=10.0, - stddev=5.0, seed=123) - norm = simple_norm(data, 'sqrt', percent=99.0) - plt.imshow(data, norm=norm, origin='lower', cmap='viridis') - -For this example we'll use the :func:`~photutils.detection.find_peaks` -function to identify the stars and their initial positions. We will -not use the centroiding option in -:func:`~photutils.detection.find_peaks` to simulate the effect of -having imperfect initial guesses for the positions of the stars. Here we -set the detection threshold value to 500.0 to select only the brightest -stars:: - - >>> from photutils.detection import find_peaks - >>> peaks_tbl = find_peaks(data, threshold=500.0) # doctest: +REMOTE_DATA - >>> peaks_tbl['peak_value'].info.format = '%.8g' # for consistent table output # doctest: +REMOTE_DATA - >>> print(peaks_tbl) # doctest: +REMOTE_DATA - id x_peak y_peak peak_value - --- ------ ------ ---------- - 1 849 2 1076.7026 - 2 182 4 1709.5671 - 3 324 4 3006.0086 - 4 100 9 1142.9915 - 5 824 9 1302.8604 - ... ... ... ... 
- 427 751 992 801.23834 - 428 114 994 1595.2804 - 429 299 994 648.18539 - 430 207 998 2810.6503 - 431 691 999 2611.0464 - Length = 431 rows - -Note that the stars are sufficiently separated in the simulated image -that we do not need to exclude any stars due to crowding. In practice -this step will require some manual inspection and selection. - -Next, we need to extract cutouts of the stars using the -:func:`~photutils.psf.extract_stars` function. This function requires -a table of star positions either in pixel or sky coordinates. For -this example we are using the pixel coordinates, which need to be in -table columns called simply ``x`` and ``y``. - -We plan to extract 25 x 25 pixel cutouts of our selected stars, so -let's explicitly exclude stars that are too close to the image -boundaries (because they cannot be extracted):: - - >>> size = 25 - >>> hsize = (size - 1) / 2 - >>> x = peaks_tbl['x_peak'] # doctest: +REMOTE_DATA - >>> y = peaks_tbl['y_peak'] # doctest: +REMOTE_DATA - >>> mask = ((x > hsize) & (x < (data.shape[1] -1 - hsize)) & - ... (y > hsize) & (y < (data.shape[0] -1 - hsize))) # doctest: +REMOTE_DATA - -Now let's create the table of good star positions:: - - >>> from astropy.table import Table - >>> stars_tbl = Table() - >>> stars_tbl['x'] = x[mask] # doctest: +REMOTE_DATA - >>> stars_tbl['y'] = y[mask] # doctest: +REMOTE_DATA - -The star cutouts from which we build the ePSF must have the background -subtracted. Here we'll use the sigma-clipped median value as the -background level. If the background in the image varies across the -image, one should use more sophisticated methods (e.g., -`~photutils.background.Background2D`). 
- -Let's subtract the background from the image:: - - >>> from astropy.stats import sigma_clipped_stats - >>> mean_val, median_val, std_val = sigma_clipped_stats(data, sigma=2.0) # doctest: +REMOTE_DATA - >>> data -= median_val # doctest: +REMOTE_DATA - -The :func:`~photutils.psf.extract_stars` function requires the input -data as an `~astropy.nddata.NDData` object. An -`~astropy.nddata.NDData` object is easy to create from our data -array:: - - >>> from astropy.nddata import NDData - >>> nddata = NDData(data=data) # doctest: +REMOTE_DATA - -We are now ready to create our star cutouts using the -:func:`~photutils.psf.extract_stars` function. For this simple -example we are extracting stars from a single image using a single -catalog. The :func:`~photutils.psf.extract_stars` can also extract -stars from multiple images using a separate catalog for each image or -a single catalog. When using a single catalog, the star positions -must be in sky coordinates (as `~astropy.coordinates.SkyCoord` -objects) and the `~astropy.nddata.NDData` objects must contain valid -`~astropy.wcs.WCS` objects. In the case of using multiple images -(i.e., dithered images) and a single catalog, the same physical star -will be "linked" across images, meaning it will be constrained to have -the same sky coordinate in each input image. - -Let's extract the 25 x 25 pixel cutouts of our selected stars:: - - >>> from photutils.psf import extract_stars - >>> stars = extract_stars(nddata, stars_tbl, size=25) # doctest: +REMOTE_DATA - -The function returns a `~photutils.psf.EPSFStars` object containing -the cutouts of our selected stars. The function extracted 403 stars, -from which we'll build our ePSF. Let's show the first 25 of them: - -.. doctest-skip:: - - >>> import matplotlib.pyplot as plt - >>> from astropy.visualization import simple_norm - >>> nrows = 5 - >>> ncols = 5 - >>> fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), - ... 
squeeze=True) - >>> ax = ax.ravel() - >>> for i in range(nrows * ncols): - ... norm = simple_norm(stars[i], 'log', percent=99.0) - ... ax[i].imshow(stars[i], norm=norm, origin='lower', cmap='viridis') - -.. plot:: - - import matplotlib.pyplot as plt - from astropy.nddata import NDData - from astropy.stats import sigma_clipped_stats - from astropy.table import Table - from astropy.visualization import simple_norm - from photutils.datasets import (load_simulated_hst_star_image, - make_noise_image) - from photutils.detection import find_peaks - from photutils.psf import extract_stars - - hdu = load_simulated_hst_star_image() - data = hdu.data - data += make_noise_image(data.shape, distribution='gaussian', mean=10.0, - stddev=5.0, seed=123) - - peaks_tbl = find_peaks(data, threshold=500.0) - - size = 25 - hsize = (size - 1) / 2 - x = peaks_tbl['x_peak'] - y = peaks_tbl['y_peak'] - mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) - & (y > hsize) & (y < (data.shape[0] - 1 - hsize))) - - stars_tbl = Table() - stars_tbl['x'] = x[mask] - stars_tbl['y'] = y[mask] - - mean_val, median_val, std_val = sigma_clipped_stats(data, sigma=2.0) - data -= median_val - - nddata = NDData(data=data) - - stars = extract_stars(nddata, stars_tbl, size=25) - - nrows = 5 - ncols = 5 - fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), - squeeze=True) - ax = ax.ravel() - for i in range(nrows * ncols): - norm = simple_norm(stars[i], 'log', percent=99.0) - ax[i].imshow(stars[i], norm=norm, origin='lower', cmap='viridis') - -With the star cutouts in hand, we are ready to construct the ePSF with -the :class:`~photutils.psf.EPSFBuilder` class. We'll create an ePSF -with an oversampling factor of 4.0. Here we limit the maximum number of -iterations to 3 (to limit its run time), but in practice one should use -about 10 or more iterations. 
The :class:`~photutils.psf.EPSFBuilder` -class has many other options to control the ePSF build process, -including changing the centering function, the smoothing kernel, and the -centering accuracy. Please see the :class:`~photutils.psf.EPSFBuilder` -documentation for further details. - -We first initialize an :class:`~photutils.psf.EPSFBuilder` instance -with our desired parameters and then input the cutouts of our selected -stars to the instance:: - - >>> from photutils.psf import EPSFBuilder - >>> epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, - ... progress_bar=False) # doctest: +REMOTE_DATA - >>> epsf, fitted_stars = epsf_builder(stars) # doctest: +REMOTE_DATA - -The returned values are the ePSF, as an -:class:`~photutils.psf.EPSFModel` object, and our input stars fitted -with the constructed ePSF, as a new :class:`~photutils.psf.EPSFStars` -object with fitted star positions and fluxes. - -Finally, let's show the constructed ePSF: - -.. doctest-skip:: - - >>> import matplotlib.pyplot as plt - >>> from astropy.visualization import simple_norm - >>> norm = simple_norm(epsf.data, 'log', percent=99.0) - >>> plt.imshow(epsf.data, norm=norm, origin='lower', cmap='viridis') - >>> plt.colorbar() - -.. 
plot:: - - import matplotlib.pyplot as plt - from astropy.nddata import NDData - from astropy.stats import sigma_clipped_stats - from astropy.table import Table - from astropy.visualization import simple_norm - from photutils.datasets import (load_simulated_hst_star_image, - make_noise_image) - from photutils.detection import find_peaks - from photutils.psf import EPSFBuilder, extract_stars - - hdu = load_simulated_hst_star_image() - data = hdu.data - data += make_noise_image(data.shape, distribution='gaussian', mean=10.0, - stddev=5.0, seed=123) - - peaks_tbl = find_peaks(data, threshold=500.0) - - size = 25 - hsize = (size - 1) / 2 - x = peaks_tbl['x_peak'] - y = peaks_tbl['y_peak'] - mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) - & (y > hsize) & (y < (data.shape[0] - 1 - hsize))) - - stars_tbl = Table() - stars_tbl['x'] = x[mask] - stars_tbl['y'] = y[mask] - - mean_val, median_val, std_val = sigma_clipped_stats(data, sigma=2.0) - data -= median_val - - nddata = NDData(data=data) - - stars = extract_stars(nddata, stars_tbl, size=25) - - epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, - progress_bar=False) - epsf, fitted_stars = epsf_builder(stars) - - norm = simple_norm(epsf.data, 'log', percent=99.0) - plt.imshow(epsf.data, norm=norm, origin='lower', cmap='viridis') - plt.colorbar() - -The :class:`~photutils.psf.EPSFModel` object is a subclass of -:class:`~photutils.psf.FittableImageModel`, thus it can be used -as a PSF model for the :ref:`PSF-fitting machinery in Photutils -` (i.e., `~photutils.psf.PSFPhotometry` or -`~photutils.psf.IterativePSFPhotometry`). diff --git a/docs/user_guide/epsf_building.rst b/docs/user_guide/epsf_building.rst new file mode 100644 index 000000000..c0d31ac17 --- /dev/null +++ b/docs/user_guide/epsf_building.rst @@ -0,0 +1,448 @@ +.. 
_build-epsf: + +Building an effective Point Spread Function (ePSF) +================================================== + +The ePSF +-------- + +The instrumental PSF is a combination of many factors that are +generally difficult to model. `Anderson and King 2000 (PASP 112, 1360) +`_ +showed that accurate stellar photometry and astrometry can be derived +by modeling the net PSF, which they call the effective PSF (ePSF). The +ePSF is an empirical model describing what fraction of a star's light +will land in a particular pixel. The constructed ePSF is typically +oversampled with respect to the detector pixels. + +The oversampling in the ePSF is crucial because it captures the PSF +pixel phase effect. Since stars can land at fractional pixel positions +on the detector, the PSF appearance varies depending on the star's +position within a pixel. By building an oversampled ePSF, we capture +this phase information across the full pixel-to-pixel variation. +This allows for more accurate PSF modeling and improved photometric +measurements, as the PSF can be interpolated to the exact position of +any star. + + +Building an ePSF +---------------- + +Photutils provides tools for building an ePSF following the +prescription of `Anderson and King 2000 (PASP 112, 1360) +`_ +and subsequent enhancements detailed mainly +in `Anderson 2016 (WFC3 ISR 2016-12) +`_. +The process iteratively refines the ePSF model and star positions: the +current ePSF is fitted to the stars to improve their centers, and then +the ePSF is rebuilt using the improved star positions. + +To begin, we must first define a sample of stars used to build the +ePSF. Ideally these stars should be bright (high S/N) and isolated to +prevent contamination from nearby stars. One may use the star-finding +tools in Photutils (e.g., :class:`~photutils.detection.DAOStarFinder` +or :class:`~photutils.detection.IRAFStarFinder`) to identify an initial +sample of stars. 
However, the step of creating a good sample of stars +generally requires visual inspection and manual selection to ensure +stars are sufficiently isolated and of good quality (e.g., no cosmic +rays, detector artifacts, etc.). To produce a good ePSF, one should have +a reasonably large sample of stars (e.g., several hundred) in order to +fully sample the PSF over the oversampled grid and to help reduce the +effects of noise. Otherwise, the resulting ePSF may have holes or may be +noisy. + +Let's start by loading a simulated HST/WFC3 image in the F160W band:: + + >>> from photutils.datasets import load_simulated_hst_star_image + >>> hdu = load_simulated_hst_star_image() # doctest: +REMOTE_DATA + >>> data = hdu.data # doctest: +REMOTE_DATA + +The simulated image does not contain any background or noise, so let's +add those to the image:: + + >>> from photutils.datasets import make_noise_image + >>> data += make_noise_image(data.shape, distribution='gaussian', + ... mean=10.0, stddev=5.0, seed=123) # doctest: +REMOTE_DATA + +Let's show the image: + +.. plot:: + :include-source: + + import matplotlib.pyplot as plt + from astropy.visualization import simple_norm + from photutils.datasets import (load_simulated_hst_star_image, + make_noise_image) + + hdu = load_simulated_hst_star_image() + data = hdu.data + data += make_noise_image(data.shape, distribution='gaussian', mean=10.0, + stddev=5.0, seed=123) + norm = simple_norm(data, 'sqrt', percent=99.0) + plt.imshow(data, norm=norm, origin='lower', cmap='viridis') + +For this example we'll use the :func:`~photutils.detection.find_peaks` +function to identify the stars and their initial positions. We will not +use the centroiding option in :func:`~photutils.detection.find_peaks` +to simulate the effect of having imperfect initial guesses for the +positions of the stars. 
Here we set the detection threshold value to +500.0 to select only the brightest stars:: + + >>> from photutils.detection import find_peaks + >>> peaks_tbl = find_peaks(data, threshold=500.0) # doctest: +REMOTE_DATA + >>> peaks_tbl['peak_value'].info.format = '%.8g' # for consistent table output # doctest: +REMOTE_DATA + >>> print(peaks_tbl) # doctest: +REMOTE_DATA + id x_peak y_peak peak_value + --- ------ ------ ---------- + 1 849 2 1076.7026 + 2 182 4 1709.5671 + 3 324 4 3006.0086 + 4 100 9 1142.9915 + 5 824 9 1302.8604 + ... ... ... ... + 427 751 992 801.23834 + 428 114 994 1595.2804 + 429 299 994 648.18539 + 430 207 998 2810.6503 + 431 691 999 2611.0464 + Length = 431 rows + +Note that the stars are sufficiently separated in the simulated image +that we do not need to exclude any stars due to crowding. In practice +this step will require some manual inspection and selection. + + +Extracting Star Cutouts +----------------------- + +Next, we need to extract cutouts of the stars using the +:func:`~photutils.psf.extract_stars` function. This function requires +a table of star positions either in pixel or sky coordinates. For this +example we are using pixel coordinates, which need to be in table +columns called ``x`` and ``y``. + +We'll extract 25 x 25 pixel cutouts of our selected stars. Let's +explicitly exclude stars that are too close to the image boundaries +(because they cannot be extracted):: + + >>> size = 25 + >>> hsize = (size - 1) / 2 + >>> x = peaks_tbl['x_peak'] # doctest: +REMOTE_DATA + >>> y = peaks_tbl['y_peak'] # doctest: +REMOTE_DATA + >>> mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) & + ... 
(y > hsize) & (y < (data.shape[0] - 1 - hsize))) # doctest: +REMOTE_DATA + +Now let's create the table of good star positions:: + + >>> from astropy.table import Table + >>> stars_tbl = Table() + >>> stars_tbl['x'] = x[mask] # doctest: +REMOTE_DATA + >>> stars_tbl['y'] = y[mask] # doctest: +REMOTE_DATA + +The star cutouts from which we build the ePSF must have the +background subtracted. Here we'll use the sigma-clipped median value +as the background level. If the background in the image varies +across the image, one should use more sophisticated methods (e.g., +`~photutils.background.Background2D`). + +Let's subtract the background from the image:: + + >>> from astropy.stats import sigma_clipped_stats + >>> mean_val, median_val, std_val = sigma_clipped_stats(data, sigma=2.0) # doctest: +REMOTE_DATA + >>> data -= median_val # doctest: +REMOTE_DATA + +The :func:`~photutils.psf.extract_stars` function requires the input +data as an `~astropy.nddata.NDData` object. An `~astropy.nddata.NDData` +object is easy to create from our data array:: + + >>> from astropy.nddata import NDData + >>> nddata = NDData(data=data) # doctest: +REMOTE_DATA + +We are now ready to create our star cutouts using the +:func:`~photutils.psf.extract_stars` function. For this simple example +we are extracting stars from a single image using a single catalog. The +:func:`~photutils.psf.extract_stars` function can also extract stars +from multiple images using a separate catalog for each image or a single +catalog. When using a single catalog with multiple images, the star +positions must be in sky coordinates (as `~astropy.coordinates.SkyCoord` +objects) and the `~astropy.nddata.NDData` objects must contain valid +`~astropy.wcs.WCS` objects. In the case of using multiple images (i.e., +dithered images) and a single catalog, the same physical star will be +"linked" across images, meaning it will be constrained to have the same +sky coordinate in each input image. 
+ +Let's extract the 25 x 25 pixel cutouts of our selected stars:: + + >>> from photutils.psf import extract_stars + >>> stars = extract_stars(nddata, stars_tbl, size=25) # doctest: +REMOTE_DATA + +The function returns an `~photutils.psf.EPSFStars` object containing the +cutouts of our selected stars. The function extracted 403 stars, from +which we'll build our ePSF. Let's show the first 25 of them: + +.. doctest-skip:: + + >>> import matplotlib.pyplot as plt + >>> from astropy.visualization import simple_norm + >>> nrows = 5 + >>> ncols = 5 + >>> fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), + ... squeeze=True) + >>> ax = ax.ravel() + >>> for i in range(nrows * ncols): + ... norm = simple_norm(stars[i], 'log', percent=99.0) + ... ax[i].imshow(stars[i], norm=norm, origin='lower', cmap='viridis') + +.. plot:: + + import matplotlib.pyplot as plt + from astropy.nddata import NDData + from astropy.stats import sigma_clipped_stats + from astropy.table import Table + from astropy.visualization import simple_norm + from photutils.datasets import (load_simulated_hst_star_image, + make_noise_image) + from photutils.detection import find_peaks + from photutils.psf import extract_stars + + hdu = load_simulated_hst_star_image() + data = hdu.data + data += make_noise_image(data.shape, distribution='gaussian', mean=10.0, + stddev=5.0, seed=123) + + peaks_tbl = find_peaks(data, threshold=500.0) + + size = 25 + hsize = (size - 1) / 2 + x = peaks_tbl['x_peak'] + y = peaks_tbl['y_peak'] + mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) + & (y > hsize) & (y < (data.shape[0] - 1 - hsize))) + + stars_tbl = Table() + stars_tbl['x'] = x[mask] + stars_tbl['y'] = y[mask] + + mean_val, median_val, std_val = sigma_clipped_stats(data, sigma=2.0) + data -= median_val + + nddata = NDData(data=data) + + stars = extract_stars(nddata, stars_tbl, size=25) + + nrows = 5 + ncols = 5 + fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20), + squeeze=True) + ax 
= ax.ravel() + for i in range(nrows * ncols): + norm = simple_norm(stars[i], 'log', percent=99.0) + ax[i].imshow(stars[i], norm=norm, origin='lower', cmap='viridis') + + +Constructing the ePSF +--------------------- + +With the star cutouts, we are ready to construct the ePSF with the +:class:`~photutils.psf.EPSFBuilder` class. We'll create an ePSF with +an oversampling factor of 4. Here we limit the maximum number of +iterations to 3 (to limit its run time), but in practice one should use +about 10 or more iterations. The :class:`~photutils.psf.EPSFBuilder` +class has many options to control the ePSF build process, including +changing the recentering function, the smoothing kernel, and the +convergence accuracy. Please see the :class:`~photutils.psf.EPSFBuilder` +documentation for further details. + +We first initialize an :class:`~photutils.psf.EPSFBuilder` instance with +our desired parameters and then input the cutouts of our selected stars +to the instance:: + + >>> from photutils.psf import EPSFBuilder + >>> epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, + ... progress_bar=False) # doctest: +REMOTE_DATA + >>> result = epsf_builder(stars) # doctest: +REMOTE_DATA + +The :class:`~photutils.psf.EPSFBuilder` returns an +`~photutils.psf.EPSFBuildResult` object containing the constructed ePSF, +the fitted stars, and detailed information about the build process. 
This +result object supports tuple unpacking, so both of the following work:: + + >>> # New style: access result attributes + >>> epsf = result.epsf # doctest: +REMOTE_DATA + >>> fitted_stars = result.fitted_stars # doctest: +REMOTE_DATA + + >>> # Old style: tuple unpacking still works + >>> epsf, fitted_stars = epsf_builder(stars) # doctest: +REMOTE_DATA + +The `~photutils.psf.EPSFBuildResult` object provides useful diagnostic +information about the build process:: + + >>> result.converged # doctest: +REMOTE_DATA + np.False_ + >>> result.iterations # doctest: +REMOTE_DATA + 3 + >>> result.n_excluded_stars # doctest: +REMOTE_DATA + 0 + +The returned ``epsf`` is an `~photutils.psf.ImagePSF` object, and +``fitted_stars`` is a new `~photutils.psf.EPSFStars` object with the +updated star positions and fluxes from fitting the final ePSF model. + +Finally, let's show the constructed ePSF: + +.. doctest-skip:: + + >>> import matplotlib.pyplot as plt + >>> from astropy.visualization import simple_norm + >>> norm = simple_norm(epsf.data, 'log', percent=99.0) + >>> plt.imshow(epsf.data, norm=norm, origin='lower', cmap='viridis') + >>> plt.colorbar() + +.. 
plot:: + + import matplotlib.pyplot as plt + from astropy.nddata import NDData + from astropy.stats import sigma_clipped_stats + from astropy.table import Table + from astropy.visualization import simple_norm + from photutils.datasets import (load_simulated_hst_star_image, + make_noise_image) + from photutils.detection import find_peaks + from photutils.psf import EPSFBuilder, extract_stars + + hdu = load_simulated_hst_star_image() + data = hdu.data + data += make_noise_image(data.shape, distribution='gaussian', mean=10.0, + stddev=5.0, seed=123) + + peaks_tbl = find_peaks(data, threshold=500.0) + + size = 25 + hsize = (size - 1) / 2 + x = peaks_tbl['x_peak'] + y = peaks_tbl['y_peak'] + mask = ((x > hsize) & (x < (data.shape[1] - 1 - hsize)) + & (y > hsize) & (y < (data.shape[0] - 1 - hsize))) + + stars_tbl = Table() + stars_tbl['x'] = x[mask] + stars_tbl['y'] = y[mask] + + mean_val, median_val, std_val = sigma_clipped_stats(data, sigma=2.0) + data -= median_val + + nddata = NDData(data=data) + + stars = extract_stars(nddata, stars_tbl, size=25) + + epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, + progress_bar=False) + epsf, fitted_stars = epsf_builder(stars) + + norm = simple_norm(epsf.data, 'log', percent=99.0) + plt.imshow(epsf.data, norm=norm, origin='lower', cmap='viridis') + plt.colorbar() + +The `~photutils.psf.ImagePSF` object can be used as a PSF model for +the :ref:`PSF-fitting machinery in Photutils +` (i.e., `~photutils.psf.PSFPhotometry` or +`~photutils.psf.IterativePSFPhotometry`). + + +Customizing the ePSF Builder +---------------------------- + +The :class:`~photutils.psf.EPSFBuilder` class provides several options +to customize the ePSF build process. + +Smoothing Kernel +^^^^^^^^^^^^^^^^ + +The ``smoothing_kernel`` parameter controls the smoothing applied to +the ePSF during each iteration. The smoothing helps to reduce noise in +the ePSF, especially when the number of stars is small. 
The default +is ``'quartic'``, which uses a fourth-degree polynomial kernel. This +kernel was initially developed by Anderson and King for HST data with an +ePSF oversampling factor of 4. It is designed to provide a good balance +between smoothing and preserving the shape of the ePSF. + +You can also use ``'quadratic'`` for a second-degree polynomial kernel, +provide a custom 2D array, or set it to `None` for no smoothing:: + + >>> epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, + ... smoothing_kernel='quadratic', + ... progress_bar=False) # doctest: +REMOTE_DATA + +Using Custom ePSF Fitters +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`~photutils.psf.EPSFBuilder` uses an +`~photutils.psf.EPSFFitter` object to fit the ePSF to the stars during +each iteration. You can customize the fitting process by providing your +own `~photutils.psf.EPSFFitter` instance:: + + >>> from photutils.psf import EPSFFitter + >>> fitter = EPSFFitter(fit_boxsize=7) # doctest: +REMOTE_DATA + >>> epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, + ... fitter=fitter, + ... progress_bar=False) # doctest: +REMOTE_DATA + +The ``fit_boxsize`` parameter specifies the size of the box centered on +each star used for fitting. Using a smaller box can speed up the fitting +process while still capturing the core of the PSF. + +Sigma Clipping +^^^^^^^^^^^^^^ + +The ``sigma_clip`` parameter controls the sigma clipping applied when +stacking the ePSF residuals in each iteration. The default uses sigma +clipping with ``sigma=3.0`` and ``maxiters=10``. You can provide your +own `~astropy.stats.SigmaClip` instance to customize this behavior:: + + >>> from astropy.stats import SigmaClip + >>> sigclip = SigmaClip(sigma=2.5, maxiters=5) # doctest: +REMOTE_DATA + >>> epsf_builder = EPSFBuilder(oversampling=4, maxiters=3, + ... sigma_clip=sigclip, + ... 
progress_bar=False) # doctest: +REMOTE_DATA + + +Including Weights +----------------- + +If your input `~astropy.nddata.NDData` object contains uncertainty +information, the :func:`~photutils.psf.extract_stars` function will +automatically create weights for each star cutout. These weights are +used during the ePSF fitting process to give more weight to pixels with +lower uncertainties. + +To include weights, provide an ``uncertainty`` attribute in +your `~astropy.nddata.NDData` object. The uncertainty can be +any of the `~astropy.nddata.NDUncertainty` subclasses (e.g., +`~astropy.nddata.StdDevUncertainty`):: + + >>> from astropy.nddata import StdDevUncertainty + >>> uncertainty = StdDevUncertainty(np.sqrt(np.abs(data))) # doctest: +REMOTE_DATA, +SKIP + >>> nddata = NDData(data=data, uncertainty=uncertainty) # doctest: +REMOTE_DATA, +SKIP + + +Linked Stars for Dithered Images +-------------------------------- + +When building an ePSF from multiple dithered images, you can link +stars across images to ensure they are constrained to have the same +sky coordinates. This is done by providing a single catalog with sky +coordinates and multiple `~astropy.nddata.NDData` objects, each with a +valid WCS. + +The :func:`~photutils.psf.extract_stars` function will create +`~photutils.psf.LinkedEPSFStar` objects that link the corresponding star +cutouts from each image. During the ePSF building process, linked stars +are constrained to have the same sky coordinate across all images. + +.. 
doctest-skip:: + + >>> from astropy.coordinates import SkyCoord + >>> catalog = Table() + >>> catalog['skycoord'] = SkyCoord(ra=[...]*u.deg, dec=[...]*u.deg) + >>> stars = extract_stars([nddata1, nddata2], catalog, size=25) diff --git a/docs/user_guide/index.rst b/docs/user_guide/index.rst index 43fcde3b1..05fe13484 100644 --- a/docs/user_guide/index.rst +++ b/docs/user_guide/index.rst @@ -58,7 +58,7 @@ PSF Photometry :maxdepth: 1 psf.rst - epsf.rst + epsf_building.rst grouping.rst PSF Matching diff --git a/docs/whats_new/3.0.rst b/docs/whats_new/3.0.rst index 5d05033e9..fd3cf67f0 100644 --- a/docs/whats_new/3.0.rst +++ b/docs/whats_new/3.0.rst @@ -115,16 +115,65 @@ corresponding source. This makes it easy to interpret which conditions were encountered during PSF fitting for each source. +Refactored ePSF Building +========================= + +The ePSF building tools have been significantly refactored for improved +robustness, better diagnostics, and a cleaner API. + +New EPSFBuildResult class +------------------------- + +:class:`~photutils.psf.EPSFBuilder` now returns an +:class:`~photutils.psf.EPSFBuildResult` dataclass instead of a plain +tuple. The result object provides structured access to detailed build +diagnostics:: + + >>> from photutils.psf import EPSFBuilder + >>> builder = EPSFBuilder(oversampling=4) + >>> result = builder(stars) + >>> result.epsf # the constructed ePSF (ImagePSF) + >>> result.fitted_stars # stars with updated centers/fluxes + >>> result.iterations # number of iterations performed + >>> result.converged # whether the build converged + >>> result.final_center_accuracy # max center shift in last iteration + +Backward compatibility is maintained. 
Existing code using tuple +unpacking will continue to work:: + + >>> epsf, stars = builder(stars) + +Improved star exclusion handling +-------------------------------- + +Stars that repeatedly fail fitting are now automatically excluded from +subsequent iterations, with informative warnings indicating the reason +for exclusion (e.g., fit region extends beyond the cutout, or the fit +did not converge). The number of excluded stars and their indices are +reported in the :class:`~photutils.psf.EPSFBuildResult`. + +New ImagePSF shape property +--------------------------- + +:class:`~photutils.psf.ImagePSF` now has a ``shape`` property that +returns the shape of the (oversampled) PSF data array. + + Removed Deprecations ==================== -The following previously deprecated features have been removed: +The following previously deprecated features from the +``background`` package have been removed: * The ``Background2D`` ``edge_method`` keyword argument. * The ``Background2D`` ``background_mesh_masked``, ``background_rms_mesh_masked``, and ``mesh_nmasked`` properties. * The ``BkgZoomInterpolator`` ``grid_mode`` keyword argument. +For the ``psf`` package, the previously deprecated +``FittableImageModel`` and ``EPSFModel`` classes have been removed. Use +:class:`~photutils.psf.ImagePSF` instead. + New Deprecations ================ @@ -145,3 +194,16 @@ be preserved in the default resizing behavior of ``Background2D``. The ``grid_from_epsfs`` helper function is now deprecated. This function creates a ``GriddedPSFModel`` from a list of ePSFs. Instead, use the ``GriddedPSFModel`` class directly. + + +Breaking Changes +================ + +The ``EPSFBuilder.build_epsf()`` method has been removed. Use the +``EPSFBuilder`` callable interface instead (i.e., ``builder(stars)``). + +The ``norm_radius`` keyword has been removed from +:class:`~photutils.psf.EPSFBuilder`. 
This keyword is no +longer relevant because the ePSF is now built directly as an +:class:`~photutils.psf.ImagePSF`, which does not use a normalization +radius. diff --git a/photutils/psf/__init__.py b/photutils/psf/__init__.py index 2c6b2247f..e0ea4db6b 100644 --- a/photutils/psf/__init__.py +++ b/photutils/psf/__init__.py @@ -4,7 +4,7 @@ photometry. """ -from .epsf import * # noqa: F401, F403 +from .epsf_builder import * # noqa: F401, F403 from .epsf_stars import * # noqa: F401, F403 from .flags import * # noqa: F401, F403 from .functional_models import * # noqa: F401, F403 diff --git a/photutils/psf/_components.py b/photutils/psf/_components.py index 6c3023655..ed027e516 100644 --- a/photutils/psf/_components.py +++ b/photutils/psf/_components.py @@ -1525,7 +1525,7 @@ def assemble_results_table(self, init_params, fit_params, data_shape, Returns ------- results_tbl : `~astropy.table.QTable` - Comprehensive results table containing: + Results table containing: - Source ID and group ID - Initial parameter estimates diff --git a/photutils/psf/epsf.py b/photutils/psf/epsf.py deleted file mode 100644 index c1b628561..000000000 --- a/photutils/psf/epsf.py +++ /dev/null @@ -1,852 +0,0 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst -""" -Tools for building and fitting an effective PSF (ePSF) based on Anderson -and King (2000; PASP 112, 1360) and Anderson (2016; WFC3 ISR 2016-12). 
-""" - -import copy -import warnings - -import numpy as np -from astropy.modeling.fitting import TRFLSQFitter -from astropy.nddata import NoOverlapError, PartialOverlapError, overlap_slices -from astropy.stats import SigmaClip -from astropy.utils.exceptions import AstropyUserWarning -from scipy.ndimage import convolve - -from photutils.centroids import centroid_com -from photutils.psf.epsf_stars import EPSFStar, EPSFStars, LinkedEPSFStar -from photutils.psf.image_models import ImagePSF, _LegacyEPSFModel -from photutils.psf.utils import _interpolate_missing_data -from photutils.utils._parameters import (SigmaClipSentinelDefault, as_pair, - create_default_sigmaclip) -from photutils.utils._progress_bars import add_progress_bar -from photutils.utils._round import py2intround -from photutils.utils._stats import nanmedian - -__all__ = ['EPSFBuilder', 'EPSFFitter'] - - -SIGMA_CLIP = SigmaClipSentinelDefault(sigma=3.0, maxiters=10) - - -class EPSFFitter: - """ - Class to fit an ePSF model to one or more stars. - - Parameters - ---------- - fitter : `astropy.modeling.fitting.Fitter`, optional - A `~astropy.modeling.fitting.Fitter` object. If `None`, then the - default `~astropy.modeling.fitting.TRFLSQFitter` will be used. - - fit_boxsize : int, tuple of int, or `None`, optional - The size (in pixels) of the box centered on the star to be used - for ePSF fitting. This allows using only a small number of - central pixels of the star (i.e., where the star is brightest) - for fitting. If ``fit_boxsize`` is a scalar then a square box of - size ``fit_boxsize`` will be used. If ``fit_boxsize`` has two - elements, they must be in ``(ny, nx)`` order. ``fit_boxsize`` - must have odd values and be greater than or equal to 3 for both - axes. If `None`, the fitter will use the entire star image. 
- - **fitter_kwargs : dict, optional - Any additional keyword arguments (except ``x``, ``y``, ``z``, or - ``weights``) to be passed directly to the ``__call__()`` method - of the input ``fitter``. - """ - - def __init__(self, *, fitter=None, fit_boxsize=5, - **fitter_kwargs): - - if fitter is None: - fitter = TRFLSQFitter() - self.fitter = fitter - self.fitter_has_fit_info = hasattr(self.fitter, 'fit_info') - self.fit_boxsize = as_pair('fit_boxsize', fit_boxsize, - lower_bound=(3, 0), check_odd=True) - - # remove any fitter keyword arguments that we need to set - remove_kwargs = ['x', 'y', 'z', 'weights'] - fitter_kwargs = copy.deepcopy(fitter_kwargs) - for kwarg in remove_kwargs: - if kwarg in fitter_kwargs: - del fitter_kwargs[kwarg] - self.fitter_kwargs = fitter_kwargs - - def __call__(self, epsf, stars): - """ - Fit an ePSF model to stars. - - Parameters - ---------- - epsf : `ImagePSF` - An ePSF model to be fitted to the stars. - - stars : `EPSFStars` object - The stars to be fit. The center coordinates for each star - should be as close as possible to actual centers. For stars - than contain weights, a weighted fit of the ePSF to the star - will be performed. - - Returns - ------- - fitted_stars : `EPSFStars` object - The fitted stars. The ePSF-fitted center position and flux - are stored in the ``center`` (and ``cutout_center``) and - ``flux`` attributes. 
- """ - if len(stars) == 0: - return stars - - if not isinstance(epsf, ImagePSF): - msg = 'The input epsf must be an ImagePSF' - raise TypeError(msg) - - epsf = _LegacyEPSFModel(epsf.data, flux=epsf.flux, x_0=epsf.x_0, - y_0=epsf.y_0, oversampling=epsf.oversampling, - fill_value=epsf.fill_value) - - # make a copy of the input ePSF - epsf = epsf.copy() - - # perform the fit - fitted_stars = [] - for star in stars: - if isinstance(star, EPSFStar): - fitted_star = self._fit_star(epsf, star, self.fitter, - self.fitter_kwargs, - self.fitter_has_fit_info, - self.fit_boxsize) - - elif isinstance(star, LinkedEPSFStar): - fitted_star = [] - for linked_star in star: - fitted_star.append( - self._fit_star(epsf, linked_star, self.fitter, - self.fitter_kwargs, - self.fitter_has_fit_info, - self.fit_boxsize)) - - fitted_star = LinkedEPSFStar(fitted_star) - fitted_star.constrain_centers() - - else: - msg = ('stars must contain only EPSFStar and/or ' - 'LinkedEPSFStar objects') - raise TypeError(msg) - - fitted_stars.append(fitted_star) - - return EPSFStars(fitted_stars) - - def _fit_star(self, epsf, star, fitter, fitter_kwargs, - fitter_has_fit_info, fit_boxsize): - """ - Fit an ePSF model to a single star. - - The input ``epsf`` will usually be modified by the fitting - routine in this function. Make a copy before calling this - function if the original is needed. 
- """ - if fit_boxsize is not None: - try: - xcenter, ycenter = star.cutout_center - large_slc, _ = overlap_slices(star.shape, fit_boxsize, - (ycenter, xcenter), - mode='strict') - except (PartialOverlapError, NoOverlapError): - warnings.warn(f'The star at ({star.center[0]}, ' - f'{star.center[1]}) cannot be fit because ' - 'its fitting region extends beyond the star ' - 'cutout image.', AstropyUserWarning) - - star = copy.deepcopy(star) - star._fit_error_status = 1 - - return star - - data = star.data[large_slc] - weights = star.weights[large_slc] - - # define the origin of the fitting region - x0 = large_slc[1].start - y0 = large_slc[0].start - else: - # use the entire cutout image - data = star.data - weights = star.weights - - # define the origin of the fitting region - x0 = 0 - y0 = 0 - - # Define positions in the undersampled grid. The fitter will - # evaluate on the defined interpolation grid, currently in the - # range [0, len(undersampled grid)]. - yy, xx = np.indices(data.shape, dtype=float) - xx = xx + x0 - star.cutout_center[0] - yy = yy + y0 - star.cutout_center[1] - - # define the initial guesses for fitted flux and shifts - epsf.flux = star.flux - epsf.x_0 = 0.0 - epsf.y_0 = 0.0 - - try: - fitted_epsf = fitter(model=epsf, x=xx, y=yy, z=data, - weights=weights, **fitter_kwargs) - except TypeError: - # fitter doesn't support weights - fitted_epsf = fitter(model=epsf, x=xx, y=yy, z=data, - **fitter_kwargs) - - fit_error_status = 0 - if fitter_has_fit_info: - fit_info = copy.copy(fitter.fit_info) - - if 'ierr' in fit_info and fit_info['ierr'] not in [1, 2, 3, 4]: - fit_error_status = 2 # fit solution was not found - else: - fit_info = None - - # compute the star's fitted position - x_center = star.cutout_center[0] + fitted_epsf.x_0.value - y_center = star.cutout_center[1] + fitted_epsf.y_0.value - - star = copy.deepcopy(star) - star.cutout_center = (x_center, y_center) - - # set the star's flux to the ePSF-fitted flux - star.flux = fitted_epsf.flux.value 
- - star._fit_info = fit_info - star._fit_error_status = fit_error_status - - return star - - -class EPSFBuilder: - """ - Class to build an effective PSF (ePSF). - - See `Anderson and King (2000; PASP 112, 1360) - `_ - and `Anderson (2016; WFC3 ISR 2016-12) - `_ - for details. - - Parameters - ---------- - oversampling : int or array_like (int) - The integer oversampling factor(s) of the ePSF relative to the - input ``stars`` along each axis. If ``oversampling`` is a scalar - then it will be used for both axes. If ``oversampling`` has two - elements, they must be in ``(y, x)`` order. - - shape : float, tuple of two floats, or `None`, optional - The shape of the output ePSF. If the ``shape`` is not `None`, it - will be derived from the sizes of the input ``stars`` and the - ePSF oversampling factor. If the size is even along any axis, - it will be made odd by adding one. The output ePSF will always - have odd sizes along both axes to ensure a well-defined central - pixel. - - smoothing_kernel : {'quartic', 'quadratic'}, 2D `~numpy.ndarray`, or `None` - The smoothing kernel to apply to the ePSF. The predefined - ``'quartic'`` and ``'quadratic'`` kernels are derived - from fourth and second degree polynomials, respectively. - Alternatively, a custom 2D array can be input. If `None` then no - smoothing will be performed. - - recentering_func : callable, optional - A callable object (e.g., function or class) that is used to - calculate the centroid of a 2D array. The callable must accept - a 2D `~numpy.ndarray`, have a ``mask`` keyword and optionally - ``error`` and ``oversampling`` keywords. The callable object - must return a tuple of two 1D `~numpy.ndarray` variables, - representing the x and y centroids. - - recentering_maxiters : int, optional - The maximum number of recentering iterations to perform during - each ePSF build iteration. - - fitter : `EPSFFitter` object, optional - A `EPSFFitter` object use to fit the ePSF to stars. 
If `None`, - then the default `EPSFFitter` will be used. To set custom fitter - options, input a new `EPSFFitter` object. See the `EPSFFitter` - documentation for options. - - maxiters : int, optional - The maximum number of iterations to perform. - - progress_bar : bool, option - Whether to print the progress bar during the build - iterations. The progress bar requires that the `tqdm - `_ optional dependency be installed. - - norm_radius : float, optional - The pixel radius over which the ePSF is normalized. - - recentering_boxsize : float or tuple of two floats, optional - The size (in pixels) of the box used to calculate the centroid - of the ePSF during each build iteration. If a single integer - number is provided, then a square box will be used. If two - values are provided, then they must be in ``(ny, nx)`` order. - ``recentering_boxsize`` must have odd values and be greater than - or equal to 3 for both axes. - - center_accuracy : float, optional - The desired accuracy for the centers of stars. The building - iterations will stop if the centers of all the stars change by - less than ``center_accuracy`` pixels between iterations. All - stars must meet this condition for the loop to exit. - - sigma_clip : `astropy.stats.SigmaClip` instance, optional - A `~astropy.stats.SigmaClip` object that defines the sigma - clipping parameters used to determine which pixels are ignored - when stacking the ePSF residuals in each iteration step. If - `None` then no sigma clipping will be performed. - - Notes - ----- - If your image contains NaN values, you may see better performance if - you have the `bottleneck`_ package installed. - - .. 
_bottleneck: https://github.com/pydata/bottleneck - """ - - def __init__(self, *, oversampling=4, shape=None, - smoothing_kernel='quartic', recentering_func=centroid_com, - recentering_maxiters=20, fitter=None, maxiters=10, - progress_bar=True, norm_radius=5.5, - recentering_boxsize=(5, 5), center_accuracy=1.0e-3, - sigma_clip=SIGMA_CLIP): - - if oversampling is None: - msg = "'oversampling' must be specified" - raise ValueError(msg) - self.oversampling = as_pair('oversampling', oversampling, - lower_bound=(0, 1)) - self._norm_radius = norm_radius - if shape is not None: - self.shape = as_pair('shape', shape, lower_bound=(0, 1)) - else: - self.shape = shape - - self.recentering_func = recentering_func - self.recentering_maxiters = recentering_maxiters - self.recentering_boxsize = as_pair('recentering_boxsize', - recentering_boxsize, - lower_bound=(3, 0), check_odd=True) - self.smoothing_kernel = smoothing_kernel - - if fitter is None: - fitter = EPSFFitter() - if not isinstance(fitter, EPSFFitter): - msg = 'fitter must be an EPSFFitter instance' - raise TypeError(msg) - self.fitter = fitter - - if center_accuracy <= 0.0: - msg = 'center_accuracy must be a positive number' - raise ValueError(msg) - self.center_accuracy_sq = center_accuracy**2 - - maxiters = int(maxiters) - if maxiters <= 0: - msg = 'maxiters must be a positive number' - raise ValueError(msg) - self.maxiters = maxiters - - self.progress_bar = progress_bar - - if sigma_clip is SIGMA_CLIP: - sigma_clip = create_default_sigmaclip(sigma=SIGMA_CLIP.sigma, - maxiters=SIGMA_CLIP.maxiters) - if not isinstance(sigma_clip, SigmaClip): - msg = 'sigma_clip must be an astropy.stats.SigmaClip instance' - raise TypeError(msg) - self._sigma_clip = sigma_clip - - # store each ePSF build iteration - self._epsf = [] - - def __call__(self, stars): - return self.build_epsf(stars) - - def _create_initial_epsf(self, stars): - """ - Create an initial `_LegacyEPSFModel` object. - - The initial ePSF data are all zeros. 
- - If ``shape`` is not specified, the shape of the ePSF data array - is determined from the shape of the input ``stars`` and the - oversampling factor. If the size is even along any axis, it will - be made odd by adding one. The output ePSF will always have odd - sizes along both axes to ensure a central pixel. - - Parameters - ---------- - stars : `EPSFStars` object - The stars used to build the ePSF. - - Returns - ------- - epsf : `_LegacyEPSFModel` - The initial ePSF model. - """ - norm_radius = self._norm_radius - oversampling = self.oversampling - shape = self.shape - - # define the ePSF shape - if shape is not None: - shape = as_pair('shape', shape, lower_bound=(0, 1), check_odd=True) - else: - # Stars class should have odd-sized dimensions, and thus we - # get the oversampled shape as oversampling * len + 1; if - # len=25, then newlen=101, for example. - x_shape = (np.ceil(stars._max_shape[1]) * oversampling[1] - + 1).astype(int) - y_shape = (np.ceil(stars._max_shape[0]) * oversampling[0] - + 1).astype(int) - - shape = np.array((y_shape, x_shape)) - - # verify odd sizes of shape - shape = [(i + 1) if i % 2 == 0 else i for i in shape] - - data = np.zeros(shape, dtype=float) - - # ePSF origin should be in the undersampled pixel units, not the - # oversampled grid units. The middle, fractional (as we wish for - # the center of the pixel, so the center should be at (v.5, w.5) - # detector pixels) value is simply the average of the two values - # at the extremes. - xcenter = stars._max_shape[1] / 2.0 - ycenter = stars._max_shape[0] / 2.0 - - return _LegacyEPSFModel(data=data, origin=(xcenter, ycenter), - oversampling=oversampling, - norm_radius=norm_radius) - - def _resample_residual(self, star, epsf): - """ - Compute a normalized residual image in the oversampled ePSF - grid. - - A normalized residual image is calculated by subtracting the - normalized ePSF model from the normalized star at the location - of the star in the undersampled grid. 
The normalized residual - image is then resampled from the undersampled star grid to the - oversampled ePSF grid. - - Parameters - ---------- - star : `EPSFStar` object - A single star object. - - epsf : `_LegacyEPSFModel` object - The ePSF model. - - Returns - ------- - image : 2D `~numpy.ndarray` - A 2D image containing the resampled residual image. The - image contains NaNs where there is no data. - """ - # Compute the normalized residual by subtracting the ePSF model - # from the normalized star at the location of the star in the - # undersampled grid. - - x = star._xidx_centered - y = star._yidx_centered - - stardata = (star._data_values_normalized - - epsf.evaluate(x=x, y=y, flux=1.0, x_0=0.0, y_0=0.0)) - - x = epsf.oversampling[1] * star._xidx_centered - y = epsf.oversampling[0] * star._yidx_centered - - epsf_xcenter, epsf_ycenter = (int((epsf.data.shape[1] - 1) / 2), - int((epsf.data.shape[0] - 1) / 2)) - xidx = py2intround(x + epsf_xcenter) - yidx = py2intround(y + epsf_ycenter) - - resampled_img = np.full(epsf.shape, np.nan) - - mask = np.logical_and(np.logical_and(xidx >= 0, xidx < epsf.shape[1]), - np.logical_and(yidx >= 0, yidx < epsf.shape[0])) - xidx_ = xidx[mask] - yidx_ = yidx[mask] - - resampled_img[yidx_, xidx_] = stardata[mask] - - return resampled_img - - def _resample_residuals(self, stars, epsf): - """ - Compute normalized residual images for all the input stars. - - Parameters - ---------- - stars : `EPSFStars` object - The stars used to build the ePSF. - - epsf : `_LegacyEPSFModel` object - The ePSF model. - - Returns - ------- - epsf_resid : 3D `~numpy.ndarray` - A 3D cube containing the resampled residual images. - """ - shape = (stars.n_good_stars, epsf.shape[0], epsf.shape[1]) - epsf_resid = np.zeros(shape) - for i, star in enumerate(stars.all_good_stars): - epsf_resid[i, :, :] = self._resample_residual(star, epsf) - - return epsf_resid - - def _smooth_epsf(self, epsf_data): - """ - Smooth the ePSF array by convolving it with a kernel. 
- - Parameters - ---------- - epsf_data : 2D `~numpy.ndarray` - A 2D array containing the ePSF image. - - Returns - ------- - result : 2D `~numpy.ndarray` - The smoothed (convolved) ePSF data. - """ - if self.smoothing_kernel is None: - return epsf_data - - # do this check first as comparing a ndarray to string causes a warning - if isinstance(self.smoothing_kernel, np.ndarray): - kernel = self.smoothing_kernel - - elif self.smoothing_kernel == 'quartic': - # from Polynomial2D fit with degree=4 to 5x5 array of - # zeros with 1.0 at the center - # Polynomial2D(4, c0_0=0.04163265, c1_0=-0.76326531, - # c2_0=0.99081633, c3_0=-0.4, c4_0=0.05, - # c0_1=-0.76326531, c0_2=0.99081633, c0_3=-0.4, - # c0_4=0.05, c1_1=0.32653061, c1_2=-0.08163265, - # c1_3=0.0, c2_1=-0.08163265, c2_2=0.02040816, - # c3_1=-0.0)> - kernel = np.array( - [[+0.041632, -0.080816, 0.078368, -0.080816, +0.041632], - [-0.080816, -0.019592, 0.200816, -0.019592, -0.080816], - [+0.078368, +0.200816, 0.441632, +0.200816, +0.078368], - [-0.080816, -0.019592, 0.200816, -0.019592, -0.080816], - [+0.041632, -0.080816, 0.078368, -0.080816, +0.041632]]) - - elif self.smoothing_kernel == 'quadratic': - # from Polynomial2D fit with degree=2 to 5x5 array of - # zeros with 1.0 at the center - # Polynomial2D(2, c0_0=-0.07428571, c1_0=0.11428571, - # c2_0=-0.02857143, c0_1=0.11428571, - # c0_2=-0.02857143, c1_1=-0.0) - kernel = np.array( - [[-0.07428311, 0.01142786, 0.03999952, 0.01142786, - -0.07428311], - [+0.01142786, 0.09714283, 0.12571449, 0.09714283, - +0.01142786], - [+0.03999952, 0.12571449, 0.15428215, 0.12571449, - +0.03999952], - [+0.01142786, 0.09714283, 0.12571449, 0.09714283, - +0.01142786], - [-0.07428311, 0.01142786, 0.03999952, 0.01142786, - -0.07428311]]) - - else: - msg = 'Unsupported kernel' - raise TypeError(msg) - - return convolve(epsf_data, kernel) - - def _recenter_epsf(self, epsf, centroid_func=centroid_com, - box_size=(5, 5), maxiters=20, center_accuracy=1.0e-4): - """ - Calculate the 
center of the ePSF data and shift the data so the - ePSF center is at the center of the ePSF data array. - - Parameters - ---------- - epsf : `_LegacyEPSFModel` object - The ePSF model. - - centroid_func : callable, optional - A callable object (e.g., function or class) that is used - to calculate the centroid of a 2D array. The callable must - accept a 2D `~numpy.ndarray`, have a ``mask`` keyword - and optionally an ``error`` keyword. The callable object - must return a tuple of two 1D `~numpy.ndarray` variables, - representing the x and y centroids. - - box_size : float or tuple of two floats, optional - The size (in pixels) of the box used to calculate the - centroid of the ePSF during each build iteration. If a - single integer number is provided, then a square box will - be used. If two values are provided, then they must be in - ``(ny, nx)`` order. ``box_size`` must have odd values and be - greater than or equal to 3 for both axes. - - maxiters : int, optional - The maximum number of recentering iterations to perform. - - center_accuracy : float, optional - The desired accuracy for the centers of stars. The building - iterations will stop if the center of the ePSF changes by - less than ``center_accuracy`` pixels between iterations. - - Returns - ------- - result : 2D `~numpy.ndarray` - The recentered ePSF data. 
- """ - epsf_data = epsf._data - - epsf = _LegacyEPSFModel(data=epsf._data, origin=epsf.origin, - oversampling=epsf.oversampling, - norm_radius=epsf._norm_radius, normalize=False) - - xcenter, ycenter = epsf.origin - - y, x = np.indices(epsf._data.shape, dtype=float) - x /= epsf.oversampling[1] - y /= epsf.oversampling[0] - - dx_total, dy_total = 0, 0 - iter_num = 0 - center_accuracy_sq = center_accuracy**2 - center_dist_sq = center_accuracy_sq + 1.0e6 - center_dist_sq_prev = center_dist_sq + 1 - while (iter_num < maxiters and center_dist_sq >= center_accuracy_sq): - iter_num += 1 - - # Anderson & King (2000) recentering function depends - # on specific pixels, and thus does not need a cutout - slices_large, _ = overlap_slices(epsf_data.shape, box_size, - (ycenter * self.oversampling[0], - xcenter * self.oversampling[1])) - epsf_cutout = epsf_data[slices_large] - mask = ~np.isfinite(epsf_cutout) - - # find a new center position - xcenter_new, ycenter_new = centroid_func(epsf_cutout, - mask=mask) - xcenter_new /= self.oversampling[1] - ycenter_new /= self.oversampling[0] - - xcenter_new += slices_large[1].start / self.oversampling[1] - ycenter_new += slices_large[0].start / self.oversampling[0] - - # Calculate the shift; dx = i - x_star so if dx was positively - # incremented then x_star was negatively incremented for a given i. - # We will therefore actually subsequently subtract dx from xcenter - # (or x_star). - dx = xcenter_new - xcenter - dy = ycenter_new - ycenter - - center_dist_sq = dx**2 + dy**2 - - if center_dist_sq >= center_dist_sq_prev: # don't shift - break - center_dist_sq_prev = center_dist_sq - - dx_total += dx - dy_total += dy - - epsf_data = epsf.evaluate(x=x, y=y, flux=1.0, - x_0=xcenter - dx_total, - y_0=ycenter - dy_total) - - return epsf_data - - def _build_epsf_step(self, stars, epsf=None): - """ - A single iteration of improving an ePSF. - - Parameters - ---------- - stars : `EPSFStars` object - The stars used to build the ePSF. 
- - epsf : `_LegacyEPSFModel` object, optional - The initial ePSF model. If not input, then the ePSF will be - built from scratch. - - Returns - ------- - epsf : `_LegacyEPSFModel` object - The updated ePSF. - """ - if len(stars) < 1: - msg = ('stars must contain at least one EPSFStar or ' - 'LinkedEPSFStar object') - raise ValueError(msg) - - if epsf is None: - # create an initial ePSF (array of zeros) - epsf = self._create_initial_epsf(stars) - else: - # improve the input ePSF - epsf = copy.deepcopy(epsf) - - # compute a 3D stack of 2D residual images - residuals = self._resample_residuals(stars, epsf) - - # compute the sigma-clipped average along the 3D stack - with warnings.catch_warnings(): - warnings.simplefilter('ignore', category=RuntimeWarning) - warnings.simplefilter('ignore', category=AstropyUserWarning) - residuals = self._sigma_clip(residuals, axis=0, masked=False, - return_bounds=False) - residuals = nanmedian(residuals, axis=0) - - # interpolate any missing data (np.nan) - mask = ~np.isfinite(residuals) - if np.any(mask): - residuals = _interpolate_missing_data(residuals, mask, - method='cubic') - - # fill any remaining nans (outer points) with zeros - residuals[~np.isfinite(residuals)] = 0.0 - - # add the residuals to the previous ePSF image - new_epsf = epsf._data + residuals - - # smooth and recenter the ePSF - new_epsf = self._smooth_epsf(new_epsf) - - epsf = _LegacyEPSFModel(data=new_epsf, origin=epsf.origin, - oversampling=epsf.oversampling, - norm_radius=epsf._norm_radius, normalize=False) - - epsf._data = self._recenter_epsf( - epsf, centroid_func=self.recentering_func, - box_size=self.recentering_boxsize, - maxiters=self.recentering_maxiters) - - # Return the new ePSF object, but with undersampled grid pixel - # coordinates. 
- xcenter = (epsf._data.shape[1] - 1) / 2.0 / epsf.oversampling[1] - ycenter = (epsf._data.shape[0] - 1) / 2.0 / epsf.oversampling[0] - - return _LegacyEPSFModel(data=epsf._data, origin=(xcenter, ycenter), - oversampling=epsf.oversampling, - norm_radius=epsf._norm_radius) - - def build_epsf(self, stars, *, init_epsf=None): - """ - Build iteratively an ePSF from star cutouts. - - Parameters - ---------- - stars : `EPSFStars` object - The stars used to build the ePSF. - - init_epsf : `ImagePSF` object, optional - The initial ePSF model. If not input, then the ePSF will be - built from scratch. - - Returns - ------- - epsf : `ImagePSF` object - The constructed ePSF. - - fitted_stars : `EPSFStars` object - The input stars with updated centers and fluxes derived - from fitting the output ``epsf``. - """ - iter_num = 0 - fit_failed = np.zeros(stars.n_stars, dtype=bool) - epsf = init_epsf - center_dist_sq = self.center_accuracy_sq + 1.0 - centers = stars.cutout_center_flat - - pbar = None - if self.progress_bar: - desc = f'EPSFBuilder ({self.maxiters} maxiters)' - pbar = add_progress_bar(total=self.maxiters, - desc=desc) # pragma: no cover - - if epsf is None: - legacy_epsf = None - else: - legacy_epsf = _LegacyEPSFModel(epsf.data, flux=epsf.flux, - x_0=epsf.x_0, y_0=epsf.y_0, - oversampling=epsf.oversampling, - fill_value=epsf.fill_value) - - while (iter_num < self.maxiters and not np.all(fit_failed) - and np.max(center_dist_sq) >= self.center_accuracy_sq): - - iter_num += 1 - - # build/improve the ePSF - legacy_epsf = self._build_epsf_step(stars, epsf=legacy_epsf) - - # fit the new ePSF to the stars to find improved centers - # we catch fit warnings here -- stars with unsuccessful fits - # are excluded from the ePSF build process - with warnings.catch_warnings(): - message = '.*The fit may be unsuccessful;.*' - warnings.filterwarnings('ignore', message=message, - category=AstropyUserWarning) - - image_psf = ImagePSF(data=legacy_epsf.data, - flux=legacy_epsf.flux, - 
x_0=legacy_epsf.x_0, - y_0=legacy_epsf.y_0, - oversampling=legacy_epsf.oversampling, - fill_value=legacy_epsf.fill_value) - - stars = self.fitter(image_psf, stars) - - # find all stars where the fit failed - fit_failed = np.array([star._fit_error_status > 0 - for star in stars.all_stars]) - if np.all(fit_failed): - msg = 'The ePSF fitting failed for all stars.' - raise ValueError(msg) - - # permanently exclude fitting any star where the fit fails - # after 3 iterations - if iter_num > 3 and np.any(fit_failed): - idx = fit_failed.nonzero()[0] - for i in idx: # pylint: disable=not-an-iterable - stars.all_stars[i]._excluded_from_fit = True - - # if no star centers have moved by more than pixel accuracy, - # stop the iteration loop early - dx_dy = stars.cutout_center_flat - centers - dx_dy = dx_dy[np.logical_not(fit_failed)] - center_dist_sq = np.sum(dx_dy * dx_dy, axis=1, dtype=np.float64) - centers = stars.cutout_center_flat - - self._epsf.append(legacy_epsf) - - if pbar is not None: - pbar.update() - - if pbar is not None: - if iter_num < self.maxiters: - pbar.write(f'EPSFBuilder converged after {iter_num} ' - f'iterations (of {self.maxiters} maximum ' - 'iterations)') - pbar.close() - - epsf = ImagePSF(data=legacy_epsf.data, flux=legacy_epsf.flux, - x_0=legacy_epsf.x_0, y_0=legacy_epsf.y_0, - oversampling=legacy_epsf.oversampling, - fill_value=legacy_epsf.fill_value) - - return epsf, stars diff --git a/photutils/psf/epsf_builder.py b/photutils/psf/epsf_builder.py new file mode 100644 index 000000000..39e08fc45 --- /dev/null +++ b/photutils/psf/epsf_builder.py @@ -0,0 +1,1779 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +Tools to build and fit an effective PSF (ePSF) based on Anderson and +King 2000 (PASP 112, 1360) and Anderson 2016 (WFC3 ISR 2016-12). 
+""" + +import copy +import warnings +from dataclasses import dataclass + +import numpy as np +from astropy.modeling.fitting import TRFLSQFitter +from astropy.nddata import NoOverlapError, PartialOverlapError, overlap_slices +from astropy.stats import SigmaClip +from astropy.utils.exceptions import AstropyUserWarning +from scipy.ndimage import convolve + +from photutils.centroids import centroid_com +from photutils.psf.epsf_stars import EPSFStar, EPSFStars, LinkedEPSFStar +from photutils.psf.image_models import ImagePSF +from photutils.psf.utils import _interpolate_missing_data +from photutils.utils._parameters import (SigmaClipSentinelDefault, as_pair, + create_default_sigmaclip) +from photutils.utils._progress_bars import add_progress_bar +from photutils.utils._round import py2intround +from photutils.utils._stats import nanmedian + +__all__ = ['EPSFBuildResult', 'EPSFBuilder', 'EPSFFitter'] + +SIGMA_CLIP = SigmaClipSentinelDefault(sigma=3.0, maxiters=10) + + +class _SmoothingKernel: + """ + Utility class for ePSF smoothing kernel generation and convolution. + + This class encapsulates the creation of smoothing kernels used in + ePSF building and provides consistent smoothing operations. 
+ """ + + # Pre-computed kernels based on polynomial fits + QUARTIC_KERNEL = np.array([ + [+0.041632, -0.080816, 0.078368, -0.080816, +0.041632], + [-0.080816, -0.019592, 0.200816, -0.019592, -0.080816], + [+0.078368, +0.200816, 0.441632, +0.200816, +0.078368], + [-0.080816, -0.019592, 0.200816, -0.019592, -0.080816], + [+0.041632, -0.080816, 0.078368, -0.080816, +0.041632]]) + + QUADRATIC_KERNEL = np.array([ + [-0.07428311, 0.01142786, 0.03999952, 0.01142786, -0.07428311], + [+0.01142786, 0.09714283, 0.12571449, 0.09714283, +0.01142786], + [+0.03999952, 0.12571449, 0.15428215, 0.12571449, +0.03999952], + [+0.01142786, 0.09714283, 0.12571449, 0.09714283, +0.01142786], + [-0.07428311, 0.01142786, 0.03999952, 0.01142786, -0.07428311]]) + + @classmethod + def get_kernel(cls, kernel_type): + """ + Get a smoothing kernel by type. + + Parameters + ---------- + kernel_type : {'quartic', 'quadratic'} or array_like + The type of kernel to retrieve or a custom kernel array. + + Returns + ------- + kernel : 2D `numpy.ndarray` + The smoothing kernel. + + Raises + ------ + TypeError + If `kernel_type` is not supported. + + Notes + ----- + The predefined kernels are derived from polynomial fits: + + - 'quartic': From Polynomial2D fit with degree=4 to 5x5 array of + zeros with 1.0 at the center. Based on fourth degree polynomial. + + - 'quadratic': From Polynomial2D fit with degree=2 to 5x5 + array of zeros with 1.0 at the center. Based on second degree + polynomial. + """ + if isinstance(kernel_type, np.ndarray): + return kernel_type + if kernel_type == 'quartic': + return cls.QUARTIC_KERNEL + if kernel_type == 'quadratic': + return cls.QUADRATIC_KERNEL + + msg = (f'Unsupported kernel type: {kernel_type}. Supported types ' + 'are "quartic", "quadratic", or ndarray.') + raise TypeError(msg) + + @staticmethod + def apply_smoothing(data, kernel_type): + """ + Apply smoothing to data using the specified kernel. 
+ + Parameters + ---------- + data : 2D `numpy.ndarray` + The data to smooth. + + kernel_type : {'quartic', 'quadratic'}, array_like, or `None` + The type of kernel to use for smoothing, or `None` for no + smoothing. + + Returns + ------- + smoothed_data : 2D `numpy.ndarray` + The smoothed data. Returns original data if `kernel_type` is + `None`. + """ + if kernel_type is None: + return data + + kernel = _SmoothingKernel.get_kernel(kernel_type) + return convolve(data, kernel) + + +class _EPSFValidator: + """ + Class to validate ePSF building parameters and data. + + This class centralizes all validation logic with context-aware error + messages. + """ + + @staticmethod + def validate_oversampling(oversampling, context=''): + """ + Validate oversampling parameters. + + Parameters + ---------- + oversampling : int or tuple + The oversampling factor(s). + + context : str, optional + Additional context for error messages. + + Raises + ------ + ValueError + If oversampling is invalid. + """ + if oversampling is None: + msg = "'oversampling' must be specified" + raise ValueError(msg) + + try: + oversampling = as_pair('oversampling', oversampling, + lower_bound=(0, 1)) + except (TypeError, ValueError) as e: + msg = f'Invalid oversampling parameter - {e}' + if context: + msg = f'{context}: {msg}' + raise ValueError(msg) from None + + return oversampling + + @staticmethod + def validate_shape_compatibility(stars, oversampling, shape=None): + """ + Validate that ePSF shape is compatible with star dimensions. + + Performs validation of shape compatibility between requested + ePSF shape and star cutout dimensions, accounting for + oversampling factors and providing detailed diagnostics. + + Parameters + ---------- + stars : EPSFStars + The input stars. + + oversampling : tuple + The oversampling factors (y, x). + + shape : tuple, optional + Requested ePSF shape (height, width). + + Raises + ------ + ValueError + If shape is incompatible with stars and oversampling. 
+ Error messages include suggested minimum shapes and + detailed diagnostic information. + """ + if not stars: + msg = ('Cannot validate shape compatibility with empty star list. ' + 'Please provide at least one star for ePSF building.') + raise ValueError(msg) + + # Collect star dimension statistics + star_heights = [star.shape[0] for star in stars] + star_widths = [star.shape[1] for star in stars] + max_height = max(star_heights) + max_width = max(star_widths) + + # Check for extremely small stars that may cause issues + min_star_size = 3 # minimum reasonable star cutout size + problematic_stars = [] + for i, star in enumerate(stars): + if min(star.shape) < min_star_size: + problematic_stars.append(f'Star {i}: {star.shape}') + + if problematic_stars: + msg = (f"Found {len(problematic_stars)} star(s) with very small " + f"dimensions (< {min_star_size}x{min_star_size}): " + f"{', '.join(problematic_stars)}. Consider using larger " + 'star cutouts for better ePSF quality.') + raise ValueError(msg) + + # Compute minimum required ePSF shape with proper padding + # The +1 ensures odd dimensions for proper centering + min_epsf_height = max_height * oversampling[0] + 1 + min_epsf_width = max_width * oversampling[1] + 1 + + # Validate requested shape if provided + if shape is not None: + shape = np.array(shape) + if shape.ndim != 1 or len(shape) != 2: + msg = 'Shape must be a 2-element sequence' + raise ValueError(msg) + + if shape[0] < min_epsf_height or shape[1] < min_epsf_width: + # Provide detailed diagnostic information + msg = (f'Requested ePSF shape {shape} is incompatible with ' + f'star dimensions and oversampling.\n\n' + f' Oversampling factors: {oversampling}\n' + f' Minimum required ePSF shape: ' + f'({min_epsf_height}, {min_epsf_width})\n' + f'Solution: Use shape >= ' + f'({min_epsf_height}, {min_epsf_width}) ' + f'or reduce oversampling factors.') + raise ValueError(msg) + + # Check for odd dimensions (for proper centering) + if shape[0] % 2 == 0 or shape[1] 
% 2 == 0:
+                msg = (f'Requested ePSF shape {shape} has even dimensions. '
+                       f'Odd dimensions are recommended for proper ePSF '
+                       f'centering. Consider using '
+                       f'({shape[0] + 1 - shape[0] % 2}, '
+                       f'{shape[1] + 1 - shape[1] % 2}) instead.')
+                warnings.warn(msg, AstropyUserWarning)
+
+    @staticmethod
+    def validate_stars(stars, context=''):
+        """
+        Validate EPSFStars object and individual star data.
+
+        Parameters
+        ----------
+        stars : EPSFStars
+            The stars to validate.
+
+        context : str, optional
+            Additional context for error messages.
+
+        Raises
+        ------
+        ValueError, TypeError
+            If stars are invalid.
+        """
+        # Check basic type and structure
+        if not hasattr(stars, '__len__') or len(stars) == 0:
+            msg = 'EPSFStars object must contain at least one star'
+            if context:
+                msg = f'{context}: {msg}'
+            raise ValueError(msg)
+
+        # Validate individual stars
+        invalid_stars = []
+        for i, star in enumerate(stars):
+            try:
+                # Check for valid data
+                if not hasattr(star, 'data') or star.data is None:
+                    invalid_stars.append((i, 'missing data'))
+                    continue
+
+                # Check for finite values
+                if not np.any(np.isfinite(star.data)):
+                    invalid_stars.append((i, 'no finite data values'))
+                    continue
+
+                # Check for reasonable dimensions
+                if min(star.shape) < 3:
+                    invalid_stars.append((i, f'too small ({star.shape})'))
+                    continue
+
+                # Check for center coordinates
+                if not hasattr(star, 'cutout_center'):
+                    invalid_stars.append((i, 'missing cutout_center'))
+                    continue
+
+            except (AttributeError, TypeError, ValueError) as e:
+                invalid_stars.append((i, f'validation error: {e}'))
+
+        if invalid_stars:
+            error_details = [f'Star {i}: {issue}'
+                             for i, issue in invalid_stars[:5]]
+            if len(invalid_stars) > 5:
+                error_details.append(f'... 
and {len(invalid_stars) - 5} more') + + msg = (f'Found {len(invalid_stars)} invalid stars out of ' + f'{len(stars)} total:\n' + '\n'.join(error_details)) + if context: + msg = f'{context}: {msg}' + raise ValueError(msg) + + @staticmethod + def validate_center_accuracy(center_accuracy): + """ + Validate center accuracy parameter. + + Parameters + ---------- + center_accuracy : float + The center accuracy threshold. + + Raises + ------ + ValueError + If center accuracy is invalid. + """ + if not isinstance(center_accuracy, (int, float)): + msg = (f'center_accuracy must be a number, got ' + f'{type(center_accuracy)}') + raise TypeError(msg) + + if center_accuracy <= 0.0: + msg = ('center_accuracy must be positive, got ' + f'{center_accuracy}. Typical values are 1e-3 to 1e-4.') + raise ValueError(msg) + + if center_accuracy > 1.0: + msg = (f'center_accuracy {center_accuracy} seems unusually large. ' + 'Values > 1.0 may prevent convergence. ' + 'Typical values are 1e-3 to 1e-4.') + warnings.warn(msg, AstropyUserWarning) + + @staticmethod + def validate_maxiters(maxiters): + """ + Validate maximum iterations parameter. + + Parameters + ---------- + maxiters : int + The maximum number of iterations. + + Raises + ------ + ValueError, TypeError + If maxiters is invalid. + """ + if not isinstance(maxiters, int): + msg = f'maxiters must be an integer, got {type(maxiters)}' + raise TypeError(msg) + + if maxiters <= 0: + msg = 'maxiters must be a positive number' + raise ValueError(msg) + + maxiters_warn_threshold = 100 + if maxiters > maxiters_warn_threshold: + msg = (f'maxiters {maxiters} seems unusually large. ' + f'Values > {maxiters_warn_threshold} may indicate ' + 'convergence issues. Consider checking your data and ' + 'parameters.') + warnings.warn(msg, AstropyUserWarning) + + +class _CoordinateTransformer: + """ + Handle coordinate transformations between pixel and oversampled + spaces. 
+ + This class centralizes all coordinate system conversions used in + ePSF building, providing consistent transformations between the + input star coordinate system and the oversampled ePSF coordinate + system. + + Parameters + ---------- + oversampling : tuple of int + The (y, x) oversampling factors for the ePSF. + """ + + def __init__(self, oversampling): + self.oversampling = np.asarray(oversampling) + + def star_to_epsf_coords(self, star_x, star_y, epsf_origin): + """ + Transform star-relative coordinates to ePSF grid coordinates. + + Parameters + ---------- + star_x, star_y : array_like + Star coordinates in undersampled units relative to star + center. + + epsf_origin : tuple + The (x, y) origin of the ePSF in oversampled coordinates. + + Returns + ------- + epsf_x, epsf_y : array_like + Integer coordinates in the oversampled ePSF grid. + """ + # Apply oversampling transformation + x_oversampled = self.oversampling[1] * star_x + y_oversampled = self.oversampling[0] * star_y + + # Add ePSF center offset + epsf_xcenter, epsf_ycenter = epsf_origin + epsf_x = py2intround(x_oversampled + epsf_xcenter).astype(int) + epsf_y = py2intround(y_oversampled + epsf_ycenter).astype(int) + + return epsf_x, epsf_y + + def compute_epsf_shape(self, star_shapes): + """ + Compute the appropriate ePSF shape from input star shapes. + + Parameters + ---------- + star_shapes : list of tuple + List of (height, width) tuples for each star. + + Returns + ------- + epsf_shape : tuple + The (height, width) shape for the oversampled ePSF. 
+ """ + if not star_shapes: + msg = 'Need at least one star to compute ePSF shape' + raise ValueError(msg) + + # Find maximum star dimensions + max_height = max(shape[0] for shape in star_shapes) + max_width = max(shape[1] for shape in star_shapes) + + # Apply oversampling (both are integers, so product is integer) + epsf_height = max_height * self.oversampling[0] + epsf_width = max_width * self.oversampling[1] + + # Ensure odd dimensions for centered origin + if epsf_height % 2 == 0: + epsf_height += 1 + if epsf_width % 2 == 0: + epsf_width += 1 + + return (epsf_height, epsf_width) + + def compute_epsf_origin(self, epsf_shape): + """ + Compute the geometric origin (center) coordinates for an ePSF. + + Parameters + ---------- + epsf_shape : tuple + The (height, width) shape of the ePSF. The shape should have + odd dimensions to ensure a well-defined center. + + Returns + ------- + origin : tuple + The (x, y) origin coordinates in the ePSF coordinate system. + """ + origin_x = (epsf_shape[1] - 1) / 2.0 + origin_y = (epsf_shape[0] - 1) / 2.0 + return (origin_x, origin_y) + + def oversampled_to_undersampled(self, x, y): + """ + Convert oversampled coordinates to undersampled coordinates. + + Parameters + ---------- + x, y : array_like or float + Coordinates in the oversampled grid. + + Returns + ------- + x_under, y_under : array_like or float + Coordinates in the undersampled (original) grid. + """ + return x / self.oversampling[1], y / self.oversampling[0] + + def undersampled_to_oversampled(self, x, y): + """ + Convert undersampled coordinates to oversampled coordinates. + + Parameters + ---------- + x, y : array_like or float + Coordinates in the undersampled (original) grid. + + Returns + ------- + x_over, y_over : array_like or float + Coordinates in the oversampled grid. + """ + return x * self.oversampling[1], y * self.oversampling[0] + + +class _ProgressReporter: + """ + Utility class for managing progress reporting during ePSF building. 
+ + This class encapsulates all progress bar functionality, providing a + clean interface for setting up, updating, and finalizing progress + reporting during the iterative ePSF building process. + + Parameters + ---------- + enabled : bool + Whether progress reporting is enabled. + + maxiters : int + Maximum number of iterations for progress tracking. + + Attributes + ---------- + enabled : bool + Whether progress reporting is active. + + maxiters : int + Maximum iterations for progress bar setup. + + _pbar : progress bar or `None` + The underlying progress bar instance. + """ + + def __init__(self, enabled, maxiters): + """ + Initialize a _ProgressReporter. + + Parameters + ---------- + enabled : bool + Whether progress reporting is enabled. + + maxiters : int + The maximum number of iterations. + """ + self.enabled = enabled + self.maxiters = maxiters + self._pbar = None + + def setup(self): + """ + Initialize the progress bar for ePSF building. + + Sets up the progress bar with appropriate description and + maximum iterations if progress reporting is enabled. + + Returns + ------- + self : _ProgressReporter + Returns `self` for method chaining. + """ + if not self.enabled: + self._pbar = None + return self + + desc = f'EPSFBuilder ({self.maxiters} maxiters)' + self._pbar = add_progress_bar(total=self.maxiters, + desc=desc) + return self + + def update(self): + """ + Update the progress bar by one iteration. + + Only updates if progress reporting is enabled and progress bar + is initialized. + """ + if self._pbar is not None: + self._pbar.update() + + def write_convergence_message(self, iteration): + """ + Write convergence message to progress bar. + + Parameters + ---------- + iteration : int + The iteration number at which convergence occurred. 
+ """ + if self._pbar is not None: + self._pbar.write(f'EPSFBuilder converged after {iteration} ' + f'iterations (of {self.maxiters} maximum ' + 'iterations)') + + def close(self): + """ + Close and finalize the progress bar. + + Should be called when ePSF building is complete, regardless of + convergence status. + """ + if self._pbar is not None: + self._pbar.close() + + +@dataclass +class EPSFBuildResult: + """ + Container for ePSF building results. + + This class provides structured access to the results of the ePSF + building process, including convergence information and diagnostic + data that can help users understand and validate the building + process. + + Attributes + ---------- + epsf : `ImagePSF` object + The final constructed ePSF model. + + fitted_stars : `EPSFStars` object + The input stars with updated centers and fluxes derived from + fitting the final ePSF. + + iterations : int + The number of iterations performed during the building process. + This will be <= maxiters specified in EPSFBuilder. + + converged : bool + Whether the building process converged based on the center + accuracy criterion. `True` if star centers moved less than the + specified accuracy between the final iterations. + + final_center_accuracy : float + The maximum center displacement in the final iteration, in + pixels. This indicates how much the star centers changed in the + last iteration and can be used to assess convergence quality. + + n_excluded_stars : int + The number of individual stars (including those from linked + stars) that were excluded from fitting due to repeated fit + failures. + + excluded_star_indices : list + Indices of stars that were excluded from fitting during the + building process. These correspond to positions in the flattened + star list (stars.all_stars). 
+ + Notes + ----- + This result object maintains backward compatibility by implementing + tuple unpacking, so existing code like: + + epsf, stars = epsf_builder(stars) + + will continue to work unchanged. The additional information is + available as attributes for users who want more detailed results. + + Examples + -------- + >>> from photutils.psf import EPSFBuilder + >>> epsf_builder = EPSFBuilder(oversampling=4) # doctest: +SKIP + >>> result = epsf_builder(stars) # doctest: +SKIP + >>> print(result.iterations) # doctest: +SKIP + >>> print(result.final_center_accuracy) # doctest: +SKIP + >>> print(result.n_excluded_stars) # doctest: +SKIP + """ + + epsf: 'ImagePSF' + fitted_stars: 'EPSFStars' + iterations: int + converged: bool + final_center_accuracy: float + n_excluded_stars: int + excluded_star_indices: list + + def __iter__(self): + """ + Allow tuple unpacking for backward compatibility. + + Returns + ------- + iterator + An iterator that yields (epsf, fitted_stars) for + compatibility with existing code that expects a 2-tuple. + """ + return iter((self.epsf, self.fitted_stars)) + + def __getitem__(self, index): + """ + Allow indexing for backward compatibility. + + Parameters + ---------- + index : int + Index to access (0 for epsf, 1 for fitted_stars). + + Returns + ------- + value + The ePSF (index 0) or fitted stars (index 1). + """ + if index == 0: + return self.epsf + if index == 1: + return self.fitted_stars + + msg = 'EPSFBuildResult index must be 0 (epsf) or 1 (fitted_stars)' + raise IndexError(msg) + + +class EPSFFitter: + """ + Class to fit an ePSF model to one or more stars. + + Parameters + ---------- + fitter : `astropy.modeling.fitting.Fitter`, optional + A `~astropy.modeling.fitting.Fitter` object. If `None`, then the + default `~astropy.modeling.fitting.TRFLSQFitter` will be used. + + fit_boxsize : int, tuple of int, or `None`, optional + The size (in pixels) of the box centered on the star to be used + for ePSF fitting. 
This allows using only a small number of
+        central pixels of the star (i.e., where the star is brightest)
+        for fitting. If ``fit_boxsize`` is a scalar then a square box of
+        size ``fit_boxsize`` will be used. If ``fit_boxsize`` has two
+        elements, they must be in ``(ny, nx)`` order. ``fit_boxsize``
+        must have odd values and be greater than or equal to 3 for both
+        axes. If `None`, the fitter will use the entire star image.
+
+    **fitter_kwargs : dict, optional
+        Any additional keyword arguments (except ``x``, ``y``, ``z``, or
+        ``weights``) to be passed directly to the ``__call__()`` method
+        of the input ``fitter``.
+    """
+
+    def __init__(self, *, fitter=None, fit_boxsize=5, **fitter_kwargs):
+
+        if fitter is None:
+            fitter = TRFLSQFitter()
+        self.fitter = fitter
+        self.fitter_has_fit_info = hasattr(self.fitter, 'fit_info')
+        if fit_boxsize is not None:
+            self.fit_boxsize = as_pair('fit_boxsize', fit_boxsize,
+                                       lower_bound=(3, 0), check_odd=True)
+        else:
+            self.fit_boxsize = None
+
+        # Remove any fitter keyword arguments that we need to set
+        remove_kwargs = ['x', 'y', 'z', 'weights']
+        fitter_kwargs = copy.deepcopy(fitter_kwargs)
+        for kwarg in remove_kwargs:
+            if kwarg in fitter_kwargs:
+                del fitter_kwargs[kwarg]
+        self.fitter_kwargs = fitter_kwargs
+
+    def __call__(self, epsf, stars):
+        """
+        Fit an ePSF model to stars.
+
+        Parameters
+        ----------
+        epsf : `ImagePSF`
+            An ePSF model to be fitted to the stars.
+
+        stars : `EPSFStars` object
+            The stars to be fit. The center coordinates for each star
+            should be as close as possible to actual centers. For stars
+            that contain weights, a weighted fit of the ePSF to the star
+            will be performed.
+
+        Returns
+        -------
+        fitted_stars : `EPSFStars` object
+            The fitted stars. The ePSF-fitted center position and flux
+            are stored in the ``center`` (and ``cutout_center``) and
+            ``flux`` attributes.
+ """ + if len(stars) == 0: + return stars + + if not isinstance(epsf, ImagePSF): + msg = 'The input epsf must be an ImagePSF' + raise TypeError(msg) + + # Perform the fit + fitted_stars = [] + for star in stars: + if isinstance(star, EPSFStar): + # Skip fitting stars that have been excluded; return + # directly since no modification is needed + if star._excluded_from_fit: + fitted_star = star + else: + fitted_star = self._fit_star(epsf, star, self.fitter, + self.fitter_kwargs, + self.fitter_has_fit_info, + self.fit_boxsize) + + elif isinstance(star, LinkedEPSFStar): + fitted_star = [] + for linked_star in star: + # Skip fitting stars that have been excluded; return + # directly since no modification is needed + if linked_star._excluded_from_fit: + fitted_star.append(linked_star) + else: + fitted_star.append( + self._fit_star(epsf, linked_star, self.fitter, + self.fitter_kwargs, + self.fitter_has_fit_info, + self.fit_boxsize)) + + fitted_star = LinkedEPSFStar(fitted_star) + fitted_star.constrain_centers() + + else: + msg = ('stars must contain only EPSFStar and/or ' + 'LinkedEPSFStar objects') + raise TypeError(msg) + + fitted_stars.append(fitted_star) + + return EPSFStars(fitted_stars) + + def _fit_star(self, epsf, star, fitter, fitter_kwargs, + fitter_has_fit_info, fit_boxsize): + """ + Fit an ePSF model to a single star. 
+ """ + if fit_boxsize is not None: + try: + xcenter, ycenter = star.cutout_center + large_slc, _ = overlap_slices(star.shape, fit_boxsize, + (ycenter, xcenter), + mode='strict') + except (PartialOverlapError, NoOverlapError): + star = copy.copy(star) + star._fit_error_status = 1 + + return star + + data = star.data[large_slc] + weights = star.weights[large_slc] + + # Define the origin of the fitting region + x0 = large_slc[1].start + y0 = large_slc[0].start + else: + # Use the entire cutout image + data = star.data + weights = star.weights + + # Define the origin of the fitting region + x0 = 0 + y0 = 0 + + # Define positions in the undersampled grid. The fitter will + # evaluate on the defined interpolation grid, currently in the + # range [0, len(undersampled grid)]. + yy, xx = np.indices(data.shape, dtype=float) + xx = xx + x0 - star.cutout_center[0] + yy = yy + y0 - star.cutout_center[1] + + # Define the initial guesses for fitted flux and shifts + epsf.flux = star.flux + epsf.x_0 = 0.0 + epsf.y_0 = 0.0 + + try: + fitted_epsf = fitter(model=epsf, x=xx, y=yy, z=data, + weights=weights, **fitter_kwargs) + except TypeError: + # Handle case where the fitter does not support weights + fitted_epsf = fitter(model=epsf, x=xx, y=yy, z=data, + **fitter_kwargs) + + fit_error_status = 0 + if fitter_has_fit_info: + fit_info = copy.copy(fitter.fit_info) + + if 'ierr' in fit_info and fit_info['ierr'] not in [1, 2, 3, 4]: + fit_error_status = 2 # fit solution was not found + else: + fit_info = None + + # Compute the star's fitted position + x_center = star.cutout_center[0] + fitted_epsf.x_0.value + y_center = star.cutout_center[1] + fitted_epsf.y_0.value + + star = copy.copy(star) + star.cutout_center = (x_center, y_center) + + # Set the star's flux to the ePSF-fitted flux + star.flux = fitted_epsf.flux.value + + star._fit_info = fit_info + star._fit_error_status = fit_error_status + + return star + + +class EPSFBuilder: + """ + Class to build an effective PSF (ePSF). 
+ + See `Anderson and King 2000 (PASP 112, 1360) + `_ + and `Anderson 2016 (WFC3 ISR 2016-12) + `_ + for details. + + Parameters + ---------- + oversampling : int or array_like (int) + The integer oversampling factor(s) of the output ePSF relative + to the input ``stars`` along each axis. If ``oversampling`` is a + scalar then it will be used for both axes. If ``oversampling`` + has two elements, they must be in ``(y, x)`` order. + + shape : float, tuple of two floats, or `None`, optional + The (ny, nx) shape of the output ePSF. If the input shape is + even along any axis, it will be made odd by adding one. If the + ``shape`` is `None`, it will be derived from the sizes of the + input ``stars`` and the ePSF ``oversampling`` factor. The output + ePSF will always have odd sizes along both axes to ensure a + well-defined central pixel. + + smoothing_kernel : {'quartic', 'quadratic'}, 2D `~numpy.ndarray`, or `None` + The smoothing kernel to apply to the ePSF during each iteration + step. The predefined ``'quartic'`` and ``'quadratic'`` kernels + are derived from fourth and second degree polynomials, + respectively. Alternatively, a custom 2D array can be input. If + `None` then no smoothing will be performed. + + sigma_clip : `astropy.stats.SigmaClip` instance, optional + A `~astropy.stats.SigmaClip` object that defines the sigma + clipping parameters used to determine which pixels are ignored + when stacking the ePSF residuals in each iteration step. If + `None` then no sigma clipping will be performed. + + recentering_func : callable, optional + A callable object that is used to calculate the centroid of a + 2D array. The callable must accept a 2D `~numpy.ndarray`, have + a ``mask`` keyword and optionally an ``error`` keyword. The + callable object must return a tuple of (x, y) centroids. + + recentering_boxsize : float or tuple of two floats, optional + The size (in pixels) of the box used to calculate the centroid + of the ePSF during each build iteration. 
If a single integer
+        number is provided, then a square box will be used. If two
+        values are provided, then they must be in ``(ny, nx)`` order.
+        ``recentering_boxsize`` must have odd values and be greater than
+        or equal to 3 for both axes.
+
+    recentering_maxiters : int, optional
+        The maximum number of recentering iterations to perform during
+        each ePSF build iteration.
+
+    center_accuracy : float, optional
+        The desired accuracy for the centers of stars. The building
+        iterations will stop if the centers of all the stars change by
+        less than ``center_accuracy`` pixels between iterations. All
+        stars must meet this condition for the building iterations to
+        stop.
+
+    fitter : `EPSFFitter` object, optional
+        An `EPSFFitter` object used to fit the ePSF to stars. If `None`,
+        then the default `EPSFFitter` will be used. To set custom fitter
+        options, input a new `EPSFFitter` object. See the `EPSFFitter`
+        documentation for options.
+
+    maxiters : int, optional
+        The maximum number of ePSF building iterations to perform.
+
+    progress_bar : bool, optional
+        Whether to print the progress bar during the build
+        iterations. The progress bar requires that the `tqdm
+        `_ optional dependency be installed.
+
+    Notes
+    -----
+    If your image contains NaN values, you may see better performance if
+    you have the `bottleneck`_ package installed.
+
+    .. 
_bottleneck: https://github.com/pydata/bottleneck + """ + + def __init__(self, *, oversampling=4, shape=None, + smoothing_kernel='quartic', sigma_clip=SIGMA_CLIP, + recentering_func=centroid_com, recentering_boxsize=(5, 5), + recentering_maxiters=20, center_accuracy=1.0e-3, + fitter=None, maxiters=10, progress_bar=True): + + # Validate and store oversampling using the validator + self.oversampling = _EPSFValidator.validate_oversampling( + oversampling, 'EPSFBuilder initialization') + + # Initialize coordinate transformer for consistent transformations + self.coord_transformer = _CoordinateTransformer(self.oversampling) + + if shape is not None: + self.shape = as_pair('shape', shape, lower_bound=(0, 1)) + else: + self.shape = shape + + self.recentering_func = recentering_func + self.recentering_maxiters = recentering_maxiters + self.recentering_boxsize = as_pair('recentering_boxsize', + recentering_boxsize, + lower_bound=(3, 0), check_odd=True) + self.smoothing_kernel = smoothing_kernel + + if fitter is None: + fitter = EPSFFitter() + if not isinstance(fitter, EPSFFitter): + msg = 'fitter must be an EPSFFitter instance' + raise TypeError(msg) + self.fitter = fitter + + # Validate center accuracy using the validator + _EPSFValidator.validate_center_accuracy(center_accuracy) + self.center_accuracy_sq = center_accuracy**2 + + # Validate maxiters using the validator + _EPSFValidator.validate_maxiters(maxiters) + self.maxiters = maxiters + + self.progress_bar = progress_bar + + if sigma_clip is SIGMA_CLIP: + sigma_clip = create_default_sigmaclip(sigma=SIGMA_CLIP.sigma, + maxiters=SIGMA_CLIP.maxiters) + if not isinstance(sigma_clip, SigmaClip): + msg = 'sigma_clip must be an astropy.stats.SigmaClip instance' + raise TypeError(msg) + self._sigma_clip = sigma_clip + + # store each ePSF build iteration + self._epsf = [] + + def __call__(self, stars): + """ + Build an ePSF from input stars. + + Parameters + ---------- + stars : `EPSFStars` + The stars used to build the ePSF. 
+ + Returns + ------- + result : `EPSFBuildResult` + The result of the ePSF building process. + """ + return self._build_epsf(stars) + + def _create_initial_epsf(self, stars): + """ + Create an initial `ImagePSF` object with zero data. + + This method initializes the ePSF building process by creating a + blank ImagePSF model with the appropriate size and coordinate + system. The initial ePSF data are all zeros and will be + populated through the iterative building process. + + Shape Determination Algorithm + ----------------------------- + 1. If shape is explicitly provided, use it (ensuring odd + dimensions) + + 2. Otherwise, determine shape from input stars and oversampling: + - Take the maximum star cutout dimensions + - Apply oversampling factor: new_size = old_size * oversampling + - Ensure resulting dimensions are odd (add 1 if even) + + This ensures that oversampled arrays have a well-defined center + pixel, which is crucial for PSF modeling and fitting. + + Coordinate System Setup + ----------------------- + The method establishes the coordinate system for the ImagePSF. + The origin is set to the geometric center of the data array, + which ensures that the PSF center aligns with the array center. + The coordinate system is consistent with the expectations of the + ImagePSF class and allows for straightforward mapping between + star-relative coordinates and ePSF grid coordinates during the + building process. + + Parameters + ---------- + stars : `EPSFStars` object + The stars used to build the ePSF. The method uses + stars._max_shape to ensure the ePSF is large enough to + contain all stars. + + Returns + ------- + epsf : `ImagePSF` object + The initial ePSF model with: + - data: Zero-filled array of appropriate dimensions + - origin: Set to the array center in (x, y) order + - oversampling: Copied from the EPSFBuilder configuration + - fill_value: Set to 0.0 for regions outside the PSF + + Notes + ----- + The initial ePSF has zero flux and data values. 
These will be + populated through the iterative building process as residuals + from individual stars are combined. + + The method ensures that: + - Array dimensions are always odd (ensuring a center pixel) + - The coordinate system is properly established + - All necessary attributes are set for downstream processing + + Examples + -------- + For stars with maximum shape (25, 25) and oversampling=4: + - x_shape = 25 * 4 = 100 (even), add 1 -> 101 + - y_shape = 25 * 4 = 100 (even), add 1 -> 101 + - Final shape: (101, 101) + - Origin: (50.0, 50.0) + + For stars with maximum shape (25, 25) and oversampling=3: + - x_shape = 25 * 3 = 75 (already odd) + - y_shape = 25 * 3 = 75 (already odd) + - Final shape: (75, 75) + - Origin: (37.0, 37.0) + """ + oversampling = self.oversampling + shape = self.shape + + # Define the ePSF shape using coordinate transformer + if shape is not None: + shape = as_pair('shape', shape, lower_bound=(0, 1), check_odd=True) + else: + # Use coordinate transformer to compute shape from star + # dimensions + star_shapes = [star.shape for star in stars] + shape = self.coord_transformer.compute_epsf_shape(star_shapes) + + # Initialize with zeros + data = np.zeros(shape, dtype=float) + + # Use coordinate transformer to compute origin + origin_xy = self.coord_transformer.compute_epsf_origin(shape) + + return ImagePSF(data=data, origin=origin_xy, oversampling=oversampling, + fill_value=0.0) + + def _resample_residual(self, star, epsf, out_image=None): + """ + Compute a normalized residual image in the oversampled ePSF + grid. + + A normalized residual image is calculated by subtracting the + normalized ePSF model from the normalized star at the location + of the star in the undersampled grid. The normalized residual + image is then resampled from the undersampled star grid to the + oversampled ePSF grid. + + Parameters + ---------- + star : `EPSFStar` object + A single star object. + + epsf : `ImagePSF` object + The ePSF model. 
+ + out_image : 2D `~numpy.ndarray`, optional + A 2D array to hold the resampled residual image. If `None`, + a new array will be created. + + Returns + ------- + image : 2D `~numpy.ndarray` + A 2D image containing the resampled residual image. The + image contains NaNs where there is no data. + """ + # Compute the normalized residual by subtracting the ePSF model + # from the normalized star at the location of the star in the + # undersampled grid. + xidx_centered, yidx_centered = star._xyidx_centered + stardata = (star._data_values_normalized + - epsf.evaluate(x=xidx_centered, + y=yidx_centered, + flux=1.0, x_0=0.0, y_0=0.0)) + + # Use coordinate transformer to map to the oversampled ePSF grid + xidx, yidx = self.coord_transformer.star_to_epsf_coords( + xidx_centered, yidx_centered, epsf.origin) + + epsf_shape = epsf.data.shape + if out_image is None: + out_image = np.full(epsf_shape, np.nan) + + mask = np.logical_and(np.logical_and(xidx >= 0, xidx < epsf_shape[1]), + np.logical_and(yidx >= 0, yidx < epsf_shape[0])) + xidx_ = xidx[mask] + yidx_ = yidx[mask] + + out_image[yidx_, xidx_] = stardata[mask] + + return out_image + + def _resample_residuals(self, stars, epsf): + """ + Compute normalized residual images for all the input stars. + + Optimized to minimize memory allocations. + + Parameters + ---------- + stars : `EPSFStars` object + The stars used to build the ePSF. + + epsf : `ImagePSF` object + The ePSF model. + + Returns + ------- + epsf_resid : 3D `~numpy.ndarray` + A 3D cube containing the resampled residual images. 
+ """ + epsf_shape = epsf.data.shape + n_good_stars = stars.n_good_stars + + if n_good_stars == 0: + # Return empty array with correct shape + return np.zeros((0, epsf_shape[0], epsf_shape[1])) + + # Pre-allocate with NaN (default for missing data) + shape = (n_good_stars, epsf_shape[0], epsf_shape[1]) + epsf_resid = np.full(shape, np.nan) + + # Loop over stars and compute residuals directly into the + # pre-allocated array + for i, star in enumerate(stars.all_good_stars): + self._resample_residual(star, epsf, out_image=epsf_resid[i]) + + return epsf_resid + + def _smooth_epsf(self, epsf_data): + """ + Smooth the ePSF array by convolving it with a kernel. + + Parameters + ---------- + epsf_data : 2D `~numpy.ndarray` + A 2D array containing the ePSF image. + + Returns + ------- + result : 2D `~numpy.ndarray` + The smoothed (convolved) ePSF data. + """ + return _SmoothingKernel.apply_smoothing(epsf_data, + self.smoothing_kernel) + + def _normalize_epsf(self, epsf_data): + """ + Normalize the ePSF data so that the sum of the array values + equals the product of the oversampling factors. + + The normalization accounts for oversampling. For proper + normalization with flux=1.0, the sum of the ePSF data array + should equal the product of the oversampling factors. + + Parameters + ---------- + epsf_data : 2D `~numpy.ndarray` + A 2D array containing the ePSF image. + + Returns + ------- + result : 2D `~numpy.ndarray` + The normalized ePSF data. + + Notes + ----- + For an oversampled PSF image, the sum of array values should + equal the product of the oversampling factors (e.g., for + oversampling=(4, 4), sum should be 16.0). This ensures that the + ImagePSF model with flux=1.0 represents a properly normalized + PSF. 
+ """ + oversampling_product = np.prod(self.oversampling) + current_sum = np.sum(epsf_data) + + if current_sum == 0: + msg = 'Cannot normalize ePSF: data sum is zero' + raise ValueError(msg) + + return epsf_data * (oversampling_product / current_sum) + + def _recenter_epsf(self, epsf, centroid_func=None, box_size=None, + maxiters=None, center_accuracy=None): + """ + Recenter the ePSF data by shifting to the array center. + + This method uses iterative centroiding to find the center of + the ePSF and applies sub-pixel shifts using interpolation. + This provides accurate centering even when the PSF is offset + by fractional pixels. + + Algorithm Overview + ------------------ + 1. Find the centroid of the ePSF using the centroid function + 2. Calculate the sub-pixel shift needed to center the PSF + 3. Apply the shift using spline interpolation via epsf.evaluate() + 4. Iterate until convergence or max iterations reached + + Parameters + ---------- + epsf : `ImagePSF` object + The ePSF model containing the data to be recentered. + + centroid_func : callable, optional + A callable object (e.g., function or class) that is used + to calculate the centroid of a 2D array. The callable must + accept a 2D `~numpy.ndarray`, have a ``mask`` keyword + and optionally an ``error`` keyword. The callable object + must return a tuple of two 1D `~numpy.ndarray` variables, + representing the x and y centroids. If `None`, uses the + builder's configured recentering_func. + + box_size : float or tuple of two floats, optional + The size (in pixels) of the box used to calculate the + centroid of the ePSF during each iteration. If a single + integer number is provided, then a square box will be used. + If two values are provided, then they must be in ``(ny, + nx)`` order. ``box_size`` must have odd values and be + greater than or equal to 3 for both axes. If `None`, uses + the builder's configured recentering_boxsize. 
+
+        maxiters : int, optional
+            The maximum number of recentering iterations to perform. If
+            `None`, uses the builder's configured recentering_maxiters.
+
+        center_accuracy : float, optional
+            The desired accuracy for the center position. The centering
+            iterations will stop if the center of the ePSF changes by
+            less than ``center_accuracy`` pixels between iterations. If
+            `None`, uses 1.0e-4.
+
+        Returns
+        -------
+        result : 2D `~numpy.ndarray`
+            The recentered ePSF data array with the same shape as input.
+
+        Notes
+        -----
+        This method uses spline interpolation to apply sub-pixel shifts,
+        which preserves the PSF shape more accurately than integer
+        pixel shifting. The interpolation is done using the ImagePSF's
+        evaluate method.
+        """
+        # Use instance defaults if not specified
+        if centroid_func is None:
+            centroid_func = self.recentering_func
+        if box_size is None:
+            box_size = self.recentering_boxsize
+        if maxiters is None:
+            maxiters = self.recentering_maxiters
+        if center_accuracy is None:
+            center_accuracy = 1.0e-4
+
+        # The center of the ePSF in oversampled pixel coordinates.
+        # This is where we want the PSF center to be.
+ xcenter, ycenter = self.coord_transformer.compute_epsf_origin( + epsf.data.shape) + + # Create coordinate grids in undersampled units for evaluate() + y, x = np.indices(epsf.data.shape, dtype=float) + x, y = self.coord_transformer.oversampled_to_undersampled(x, y) + + # The origin in undersampled units (for use with evaluate) + x_origin, y_origin = ( + self.coord_transformer.oversampled_to_undersampled(xcenter, + ycenter)) + + dx_total, dy_total = 0.0, 0.0 + iter_num = 0 + center_accuracy_sq = center_accuracy ** 2 + center_dist_sq = center_accuracy_sq + 1.0e6 + center_dist_sq_prev = center_dist_sq + 1 + + epsf_data = epsf.data + while (iter_num < maxiters and center_dist_sq >= center_accuracy_sq): + iter_num += 1 + + # Get a cutout around the expected center for centroiding + slices_large, _ = overlap_slices( + epsf_data.shape, box_size, + (ycenter, xcenter)) + epsf_cutout = epsf_data[slices_large] + mask = ~np.isfinite(epsf_cutout) + + # Find the centroid in the cutout (in oversampled pixel coords) + xcenter_new, ycenter_new = centroid_func(epsf_cutout, mask=mask) + + # Convert cutout coordinates to full array coordinates + xcenter_new += slices_large[1].start + ycenter_new += slices_large[0].start + + # Calculate the shift in oversampled pixels + dx = xcenter_new - xcenter + dy = ycenter_new - ycenter + + center_dist_sq = dx ** 2 + dy ** 2 + + if center_dist_sq >= center_dist_sq_prev: + # Shift is getting larger, stop iterating + break + center_dist_sq_prev = center_dist_sq + + # Accumulate total shift in undersampled units + dx_under, dy_under = ( + self.coord_transformer.oversampled_to_undersampled(dx, dy)) + dx_total += dx_under + dy_total += dy_under + + # Apply the shift using evaluate (uses spline + # interpolation). The shift is applied by moving the origin. 
+ epsf_data = epsf.evaluate(x=x, y=y, flux=1.0, + x_0=x_origin - dx_total, + y_0=y_origin - dy_total) + + return epsf_data + + def _build_epsf_step(self, stars, epsf=None): + """ + A single iteration of improving an ePSF. + + Parameters + ---------- + stars : `EPSFStars` object + The stars used to build the ePSF. + + epsf : `ImagePSF` object, optional + The initial ePSF model. If not input, then the ePSF will be + built from scratch. + + Returns + ------- + epsf : `ImagePSF` object + The updated ePSF. + """ + if epsf is None: + # Create an initial ePSF (array of zeros) + epsf = self._create_initial_epsf(stars) + + # Compute a 3D stack of 2D residual images + residuals = self._resample_residuals(stars, epsf) + + # Compute the sigma-clipped median along the 3D stack + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=RuntimeWarning) + warnings.simplefilter('ignore', category=AstropyUserWarning) + residuals = self._sigma_clip(residuals, axis=0, masked=False, + return_bounds=False) + residuals = nanmedian(residuals, axis=0) + + # Interpolate any missing data (np.nan values) in the residual + # image + mask = ~np.isfinite(residuals) + if np.any(mask): + residuals = _interpolate_missing_data(residuals, mask, + method='cubic') + + # Add the residuals to the previous ePSF image + new_epsf = epsf.data + residuals + + # Smooth the ePSF + smoothed_data = self._smooth_epsf(new_epsf) + + # Recenter the ePSF + # Create an intermediate ePSF for recentering operations. + # Use the current epsf's origin if it exists, otherwise compute + # center. 
+ temp_epsf = ImagePSF(data=smoothed_data, + origin=epsf.origin, + oversampling=self.oversampling, + fill_value=0.0) + + # Apply recentering to the smoothed data + recentered_data = self._recenter_epsf(temp_epsf) + + # Normalize the ePSF data + normalized_data = self._normalize_epsf(recentered_data) + + return ImagePSF(data=normalized_data, + oversampling=self.oversampling, + fill_value=0.0) + + def _check_convergence(self, stars, centers, fit_failed): + """ + Check if the ePSF building has converged. + + Convergence is determined by checking the movement of star + centers between iterations. The method calculates the squared + distance of center movements for successfully fitted stars and + applies enhanced convergence criteria that consider both the + maximum movement and the overall stability of the star centers. + This provides a more robust convergence detection mechanism that + is less sensitive to outliers and provides better diagnostic + information on the quality of convergence. + + Parameters + ---------- + stars : `EPSFStars` object + The stars used to build the ePSF. + + centers : `~numpy.ndarray` + Previous star center positions. + + fit_failed : `~numpy.ndarray` + Boolean array tracking failed fits. + + Returns + ------- + converged : bool + `True` if convergence criteria are met. + + center_dist_sq : `~numpy.ndarray` + Squared distances of center movements. + + new_centers : `~numpy.ndarray` + Updated star center positions. 
+ """ + # Calculate center movements for successfully fitted stars only + new_centers = stars.cutout_center_flat + dx_dy = new_centers - centers + + # Filter out failed fits for convergence calculation + good_stars = np.logical_not(fit_failed) + + if not np.any(good_stars): + # No good stars - cannot determine convergence + # Return high values to prevent false convergence + return False, np.array([self.center_accuracy_sq * 10]), new_centers + + dx_dy_good = dx_dy[good_stars] + center_dist_sq = np.sum(dx_dy_good * dx_dy_good, axis=1, + dtype=np.float64) + + # Enhanced convergence criteria + max_movement = np.max(center_dist_sq) + + # Primary convergence check + primary_converged = max_movement < self.center_accuracy_sq + + # Secondary check: ensure most stars are stable + # 80% of stars must be stable + stable_fraction_threshold = 0.8 + stable_fraction = (np.sum(center_dist_sq < self.center_accuracy_sq) + / len(center_dist_sq)) + stability_converged = stable_fraction > stable_fraction_threshold + + # Combined convergence: both criteria must be met for robust + # results + converged = primary_converged and stability_converged + + return converged, center_dist_sq, new_centers + + def _process_iteration(self, stars, epsf, iter_num): + """ + Process a single iteration of ePSF building. + + Parameters + ---------- + stars : `EPSFStars` object + The stars used to build the ePSF. + + epsf : `ImagePSF` object + Current ePSF model. + + iter_num : int + Current iteration number. + + Returns + ------- + epsf : `ImagePSF` object + Updated ePSF model. + + stars : `EPSFStars` object + Updated stars with new fitted centers. + + fit_failed : `~numpy.ndarray` + Boolean array tracking failed fits. 
+ """ + # Build/improve the ePSF + epsf = self._build_epsf_step(stars, epsf=epsf) + + # Fit the new ePSF to the stars to find improved centers + with warnings.catch_warnings(): + message = '.*The fit may be unsuccessful;.*' + warnings.filterwarnings('ignore', message=message, + category=AstropyUserWarning) + + stars = self.fitter(epsf, stars) + + # Reset ePSF flux to 1.0 after fitting (fitting modifies the + # flux) + epsf.flux = 1.0 + + # Find all stars where the fit failed + fit_failed = np.array([star._fit_error_status > 0 + for star in stars.all_stars]) + + if np.all(fit_failed): + msg = 'The ePSF fitting failed for all stars.' + raise ValueError(msg) + + # Permanently exclude fitting any star where the fit fails + # after 3 iterations + if iter_num > 3 and np.any(fit_failed): + for i in fit_failed.nonzero()[0]: + star = stars.all_stars[i] + # Only warn for stars being newly excluded + if not star._excluded_from_fit: + if star._fit_error_status == 1: + reason = ('its fitting region extends beyond the ' + 'star cutout image') + else: # _fit_error_status == 2 + reason = 'the fit did not converge' + warnings.warn(f'The star at ({star.center[0]}, ' + f'{star.center[1]}) has been excluded ' + f'from ePSF fitting because {reason}.', + AstropyUserWarning) + star._excluded_from_fit = True + + # Store the ePSF from this iteration + self._epsf.append(epsf) + + return epsf, stars, fit_failed + + def _finalize_build(self, epsf, stars, progress_reporter, iter_num, + converged, final_center_accuracy, + excluded_star_indices): + """ + Finalize the ePSF building process and create result object. + + Parameters + ---------- + epsf : `ImagePSF` object + Final ePSF model. + + stars : `EPSFStars` object + Final fitted stars. + + progress_reporter : `_ProgressReporter` + Progress reporter instance for handling completion messages. + + iter_num : int + Number of completed iterations. + + converged : bool + Whether the building process converged. 
+ + final_center_accuracy : float + Final center accuracy achieved. + + excluded_star_indices : list + Indices of excluded stars. + + Returns + ------- + result : `EPSFBuildResult` + Structured result containing ePSF, stars, and build + diagnostics. + """ + # Handle progress reporting completion + if iter_num < self.maxiters: + progress_reporter.write_convergence_message(iter_num) + progress_reporter.close() + + # Create structured result + return EPSFBuildResult( + epsf=epsf, + fitted_stars=stars, + iterations=iter_num, + converged=converged, + final_center_accuracy=final_center_accuracy, + n_excluded_stars=len(excluded_star_indices), + excluded_star_indices=excluded_star_indices, + ) + + def _build_epsf(self, stars, *, epsf=None): + """ + Build iteratively an ePSF from star cutouts. + + Parameters + ---------- + stars : `EPSFStars` object + The stars used to build the ePSF. + + epsf : `ImagePSF` object, optional + The initial ePSF model. If not input, then the ePSF will be + built from scratch. + + Returns + ------- + result : `EPSFBuildResult` or tuple + The ePSF building results. Returns an `EPSFBuildResult` object + with detailed information about the building process. For + backward compatibility, the result can be unpacked as a tuple: + ``(epsf, fitted_stars) = epsf_builder(stars)``. 
+ + Notes + ----- + The structured result object contains: + - epsf: The final constructed ePSF + - fitted_stars: Stars with updated centers/fluxes + - iterations: Number of iterations performed + - converged: Whether convergence was achieved + - final_center_accuracy: Final center movement accuracy + - n_excluded_stars: Number of stars excluded due to fit failures + - excluded_star_indices: Indices of excluded stars + """ + _EPSFValidator.validate_stars(stars, 'ePSF building') + _EPSFValidator.validate_shape_compatibility(stars, self.oversampling, + self.shape) + + # Initialize variables for building process + fit_failed = np.zeros(stars.n_stars, dtype=bool) + centers = stars.cutout_center_flat + + # Setup progress tracking + progress_reporter = _ProgressReporter(self.progress_bar, + self.maxiters).setup() + + # Initialize iteration variables and tracking + iter_num = 0 + converged = False + center_dist_sq = np.array([self.center_accuracy_sq + 1.0]) + excluded_star_indices = [] + + # Main iteration loop + while (iter_num < self.maxiters and not np.all(fit_failed) + and not converged): + + iter_num += 1 + + # Process one iteration + epsf, stars, fit_failed = self._process_iteration( + stars, epsf, iter_num) + + # Track newly excluded stars + if iter_num > 3 and np.any(fit_failed): + new_excluded = fit_failed.nonzero()[0] + for idx in new_excluded: + if idx not in excluded_star_indices: + excluded_star_indices.append(idx) + + # Check convergence based on center movements + converged, center_dist_sq, centers = self._check_convergence( + stars, centers, fit_failed) + + # Update progress bar + progress_reporter.update() + + # Calculate the final center accuracy + final_converged = converged + final_center_accuracy = np.max(center_dist_sq) ** 0.5 + + # Finalize and return structured results + return self._finalize_build(epsf, stars, progress_reporter, + iter_num, final_converged, + final_center_accuracy, + excluded_star_indices) diff --git a/photutils/psf/epsf_stars.py 
b/photutils/psf/epsf_stars.py index c72df1f1f..464f3495c 100644 --- a/photutils/psf/epsf_stars.py +++ b/photutils/psf/epsf_stars.py @@ -14,7 +14,6 @@ from astropy.utils.exceptions import AstropyUserWarning from photutils.aperture import BoundingBox -from photutils.psf.image_models import _LegacyEPSFModel from photutils.psf.utils import _interpolate_missing_data from photutils.utils._parameters import as_pair @@ -39,6 +38,10 @@ class EPSFStar: input cutout ``data`` array. If `None`, then the center of the input cutout ``data`` array will be used. + flux : float or `None`, optional + The flux of the star. If `None`, then the flux will be estimated + from the input ``data``. + origin : tuple of two int, optional The ``(x, y)`` index of the origin (bottom-left corner) pixel of the input cutout array with respect to the original array @@ -61,43 +64,101 @@ class EPSFStar: An optional identification number or label for the star. """ - def __init__(self, data, *, weights=None, cutout_center=None, + def __init__(self, data, *, weights=None, cutout_center=None, flux=None, origin=(0, 0), wcs_large=None, id_label=None): self._data = np.asanyarray(data) + + # Validate data dimensionality and shape + if self._data.ndim != 2: + msg = f'Input data must be 2-dimensional, got {self._data.ndim}D' + raise ValueError(msg) + if self._data.size == 0: + msg = 'Input data cannot be empty' + raise ValueError(msg) + self.shape = self._data.shape + # Validate and process weights if weights is not None: - if weights.shape != data.shape: - msg = ('weights must have the same shape as the input ' - 'data array') + weights = np.asanyarray(weights) + if weights.shape != self._data.shape: + msg = (f'Weights shape {weights.shape} must match data shape ' + f'{self._data.shape}') raise ValueError(msg) - self.weights = np.asanyarray(weights, dtype=float).copy() + + # Check for valid weight values + if not np.all(np.isfinite(weights)): + warnings.warn('Non-finite weight values detected. 
These will ' + 'be set to zero.', AstropyUserWarning) + weights = np.where(np.isfinite(weights), weights, 0.0) + + # Copy to avoid modifying the input weights + self.weights = weights.astype(float, copy=True) else: self.weights = np.ones_like(self._data, dtype=float) + # Create initial mask from weights self.mask = (self.weights <= 0.0) - # mask out invalid image data - invalid_data = np.logical_not(np.isfinite(self._data)) + # Mask out invalid image data and provide informative warning + invalid_data = ~np.isfinite(self._data) if np.any(invalid_data): self.weights[invalid_data] = 0.0 self.mask[invalid_data] = True + warnings.warn('Input data array contains invalid data that ' + 'will be masked.', AstropyUserWarning) + + # Validate origin + origin = np.asarray(origin) + if origin.shape != (2,): + msg = f'Origin must have exactly 2 elements, got {len(origin)}' + raise ValueError(msg) + if not np.all(np.isfinite(origin)): + msg = 'Origin coordinates must be finite' + raise ValueError(msg) + self.origin = origin.astype(int) - self._cutout_center = cutout_center - self.origin = np.asarray(origin) self.wcs_large = wcs_large self.id_label = id_label - self.flux = self.estimate_flux() + # Set cutout_center (triggers validation via setter) + self.cutout_center = cutout_center + + if flux is not None: + self.flux = float(flux) + self._has_all_zero_data = False # Unknown for explicit flux + else: + # Check if completely masked before attempting flux estimation + if np.all(self.mask): + msg = ('Star cutout is completely masked; no valid data ' + 'available') + raise ValueError(msg) + + # Check if all unmasked data values are exactly zero + # Store flag for later warning (to avoid duplicate warnings) + unmasked_data = self._data[~self.mask] + self._has_all_zero_data = bool(np.all(unmasked_data == 0.0)) + + # Warn if all data is zero + if self._has_all_zero_data: + warnings.warn('All unmasked data values in star cutout ' + 'are zero', AstropyUserWarning) + + # Estimate flux + 
self.flux = self.estimate_flux() + + # Note: We allow flux <= 0 for real sources that may have + # negative net flux due to background subtraction or similar + # effects self._excluded_from_fit = False + self._fit_error_status = 0 # 0: no error, >0: error during fitting self._fitinfo = None def __array__(self): """ - Array representation of the mask data array (e.g., for - matplotlib). + Array representation of the data array (e.g., for matplotlib). """ return self._data @@ -113,6 +174,10 @@ def cutout_center(self): """ A `~numpy.ndarray` of the ``(x, y)`` position of the star's center with respect to the input cutout ``data`` array. + + Initially set to the geometric center of the cutout, this value + is updated during ePSF building iterations to reflect the fitted + center position as the star is aligned with the ePSF model. """ return self._cutout_center @@ -120,10 +185,31 @@ def cutout_center(self): def cutout_center(self, value): if value is None: value = ((self.shape[1] - 1) / 2.0, (self.shape[0] - 1) / 2.0) - elif len(value) != 2: - msg = ('The "cutout_center" attribute must have two elements ' - 'in (x, y) form.') - raise ValueError(msg) + else: + # Convert to array-like for validation + value = np.asarray(value) + + # Validate shape + if value.shape != (2,): + msg = ('cutout_center must have exactly two elements in ' + f'(x, y) form, got shape {value.shape}') + raise ValueError(msg) + + # Validate finite values + if not np.all(np.isfinite(value)): + msg = 'All cutout_center coordinates must be finite' + raise ValueError(msg) + + # Validate bounds (should be within the cutout image) + x, y = value + if not (0 <= x < self.shape[1]): + warnings.warn(f'cutout_center x-coordinate {x} is outside ' + f'the cutout bounds [0, {self.shape[1]})', + AstropyUserWarning) + if not (0 <= y < self.shape[0]): + warnings.warn(f'cutout_center y-coordinate {y} is outside ' + f'the cutout bounds [0, {self.shape[0]})', + AstropyUserWarning) self._cutout_center = 
np.asarray(value) @@ -164,20 +250,16 @@ def estimate_flux(self): Returns ------- flux : float - The estimated star's flux. + The estimated star's flux. If there is no valid data in the + cutout, `numpy.nan` will be returned. """ - if np.any(self.mask): - data_interp = _interpolate_missing_data(self.data, method='cubic', - mask=self.mask) - data_interp = _interpolate_missing_data(data_interp, - method='nearest', - mask=self.mask) - flux = np.sum(data_interp, dtype=float) - - else: - flux = np.sum(self.data, dtype=float) + if not np.any(self.mask): + return float(np.sum(self.data)) - return flux + # Interpolate missing data to estimate total flux + data_interp = _interpolate_missing_data(self.data, mask=self.mask, + method='cubic') + return float(np.sum(data_interp)) def register_epsf(self, epsf): """ @@ -193,17 +275,11 @@ def register_epsf(self, epsf): data : `~numpy.ndarray` A 2D array of the registered/scaled ePSF. """ - legacy_epsf = _LegacyEPSFModel(epsf.data, flux=epsf.flux, x_0=epsf.x_0, - y_0=epsf.y_0, - oversampling=epsf.oversampling, - fill_value=epsf.fill_value) - + # evaluate the input ePSF on the star cutout grid yy, xx = np.indices(self.shape, dtype=float) - xx = xx - self.cutout_center[0] - yy = yy - self.cutout_center[1] - - return self.flux * legacy_epsf.evaluate(xx, yy, flux=1.0, x_0=0.0, - y_0=0.0) + return epsf.evaluate(xx, yy, flux=self.flux, + x_0=self.cutout_center[0], + y_0=self.cutout_center[1]) def compute_residual_image(self, epsf): """ @@ -222,53 +298,21 @@ def compute_residual_image(self, epsf): """ return self.data - self.register_epsf(epsf) - @lazyproperty - def _xy_idx(self): - """ - 1D arrays of x and y indices of unmasked pixels in the cutout - reference frame. - """ - yidx, xidx = np.indices(self._data.shape) - return xidx[~self.mask].ravel(), yidx[~self.mask].ravel() - - @lazyproperty - def _xidx(self): - """ - 1D arrays of x indices of unmasked pixels in the cutout - reference frame. 
- """ - return self._xy_idx[0] - - @lazyproperty - def _yidx(self): - """ - 1D arrays of y indices of unmasked pixels in the cutout - reference frame. - """ - return self._xy_idx[1] - @property - def _xidx_centered(self): - """ - 1D array of x indices of unmasked pixels, with respect to the - star center, in the cutout reference frame. - """ - return self._xy_idx[0] - self.cutout_center[0] - - @property - def _yidx_centered(self): - """ - 1D array of y indices of unmasked pixels, with respect to the - star center, in the cutout reference frame. + def _xyidx_centered(self): """ - return self._xy_idx[1] - self.cutout_center[1] + 1D arrays of x and y indices of unmasked pixels, with respect + to the star center, in the cutout reference frame. - @lazyproperty - def _data_values(self): - """ - 1D array of unmasked cutout data values. + Returns + ------- + x_centered, y_centered : tuple of `~numpy.ndarray` + The x and y indices centered on the star position. """ - return self.data[~self.mask].ravel() + yidx, xidx = np.indices(self._data.shape) + x_centered = xidx[~self.mask].ravel() - self.cutout_center[0] + y_centered = yidx[~self.mask].ravel() - self.cutout_center[1] + return x_centered, y_centered @lazyproperty def _data_values_normalized(self): @@ -276,14 +320,7 @@ def _data_values_normalized(self): 1D array of unmasked cutout data values, normalized by the star's total flux. """ - return self._data_values / self.flux - - @lazyproperty - def _weight_values(self): - """ - 1D array of unmasked weight values. - """ - return self.weights[~self.mask].ravel() + return self.data[~self.mask].ravel() / self.flux class EPSFStars: @@ -307,45 +344,56 @@ def __init__(self, stars_list): raise TypeError(msg) def __len__(self): + """ + Return the number of stars in this container. + """ return len(self._data) def __getitem__(self, index): + """ + Return a new EPSFStars instance containing the indexed star(s). 
+ """ return self.__class__(self._data[index]) def __delitem__(self, index): + """ + Delete the star at the given index. + """ del self._data[index] def __iter__(self): + """ + Iterate over the stars in this container. + """ yield from self._data - # explicit set/getstate to avoid infinite recursion - # from pickler using __getattr__ def __getstate__(self): + """ + Return state for pickling (avoids __getattr__ recursion). + """ return self.__dict__ def __setstate__(self, d): + """ + Restore state from pickling. + """ self.__dict__ = d def __getattr__(self, attr): - if attr in ['cutout_center', 'center', 'flux', - '_excluded_from_fit']: - result = np.array([getattr(star, attr) for star in self._data]) - else: - result = [getattr(star, attr) for star in self._data] + """ + Delegate attribute access to the underlying star list. + + This allows accessing star attributes (like ``cutout_center``, + ``center``, ``flux``) directly on the EPSFStars container, + returning an array of values from all contained stars. + """ + result = [getattr(star, attr) for star in self._data] + if attr in ['cutout_center', 'center', 'flux', '_excluded_from_fit']: + result = np.array(result) if len(self._data) == 1: result = result[0] return result - def _getattr_flat(self, attr): - values = [] - for item in self._data: - if isinstance(item, LinkedEPSFStar): - values.extend(getattr(item, attr)) - else: - values.append(getattr(item, attr)) - - return np.array(values) - @property def cutout_center_flat(self): """ @@ -356,7 +404,7 @@ def cutout_center_flat(self): Note that when `EPSFStars` contains any `LinkedEPSFStar`, the ``cutout_center`` attribute will be a nested 3D array. """ - return self._getattr_flat('cutout_center') + return np.array([star.cutout_center for star in self.all_stars]) @property def center_flat(self): @@ -369,7 +417,7 @@ def center_flat(self): Note that when `EPSFStars` contains any `LinkedEPSFStar`, the ``center`` attribute will be a nested 3D array. 
""" - return self._getattr_flat('center') + return np.array([star.center for star in self.all_stars]) @lazyproperty def all_stars(self): @@ -384,7 +432,6 @@ def all_stars(self): stars.extend(item.all_stars) else: stars.append(item) - return stars @property @@ -399,7 +446,6 @@ def all_good_stars(self): if star._excluded_from_fit: continue stars.append(star) - return stars @lazyproperty @@ -432,17 +478,8 @@ def n_good_stars(self): """ return len(self.all_good_stars) - @lazyproperty - def _max_shape(self): - """ - The maximum x and y shapes of all the `EPSFStar` objects - (including linked stars). - """ - return np.max([star.shape for star in self.all_stars], - axis=0) - -class LinkedEPSFStar(EPSFStars): +class LinkedEPSFStar: """ A class to hold a list of `EPSFStar` objects for linked stars. @@ -450,6 +487,10 @@ class LinkedEPSFStar(EPSFStars): represent the same physical star. When building the ePSF, linked stars are constrained to have the same sky coordinates. + Note that unlike `EPSFStars` (which is a collection of potentially + unrelated stars), `LinkedEPSFStar` represents a single logical star + observed in multiple images. + Parameters ---------- stars_list : list of `EPSFStar` objects @@ -468,7 +509,126 @@ def __init__(self, stars_list): 'attribute') raise ValueError(msg) - super().__init__(stars_list) + self._data = list(stars_list) + + def __len__(self): + """ + Return the number of EPSFStar objects in this linked star. + """ + return len(self._data) + + def __getitem__(self, index): + """ + Return the EPSFStar at the given index. + """ + return self._data[index] + + def __iter__(self): + """ + Iterate over the EPSFStar objects in this linked star. + """ + yield from self._data + + def __getattr__(self, attr): + """ + Delegate attribute access to the underlying star list. + + This provides access to common star attributes like cutout_center, + center, flux, etc. as arrays when accessed on the LinkedEPSFStar. 
+ """ + if attr.startswith('_'): + msg = f"'{type(self).__name__}' object has no attribute '{attr}'" + raise AttributeError(msg) + result = [getattr(star, attr) for star in self._data] + if attr in ('cutout_center', 'center', 'flux', '_excluded_from_fit'): + result = np.array(result) + if len(self._data) == 1: + result = result[0] + return result + + def __getstate__(self): + """ + Return state for pickling (avoids __getattr__ recursion). + """ + return self.__dict__ + + def __setstate__(self, d): + """ + Restore state from pickling. + """ + self.__dict__ = d + + @property + def all_stars(self): + """ + A flat list of all `EPSFStar` objects in this linked star. + + Since LinkedEPSFStar only contains EPSFStar objects (not nested + LinkedEPSFStar), this is simply the internal list. + """ + return self._data + + @property + def cutout_center_flat(self): + """ + A `~numpy.ndarray` of the ``(x, y)`` position of all the stars' + centers with respect to the input cutout ``data`` array, as a + 2D array (``n_all_stars`` x 2). + """ + return np.array([star.cutout_center for star in self._data]) + + @property + def center_flat(self): + """ + A `~numpy.ndarray` of the ``(x, y)`` position of all the stars' + centers with respect to the original (large) image (not the + cutout image) as a 2D array (``n_all_stars`` x 2). + """ + return np.array([star.center for star in self._data]) + + @property + def n_stars(self): + """ + The number of `EPSFStar` objects in this linked star. + + For LinkedEPSFStar this is the same as n_all_stars since there + is no nesting. + """ + return len(self._data) + + @property + def n_all_stars(self): + """ + The total number of `EPSFStar` objects in this linked star. + + For LinkedEPSFStar this is the same as n_stars since there + is no nesting. + """ + return len(self._data) + + @property + def n_good_stars(self): + """ + The number of `EPSFStar` objects that have not been excluded + from fitting. 
+ """ + return len(self.all_good_stars) + + @property + def all_good_stars(self): + """ + A list of all `EPSFStar` objects that have not been excluded + from fitting. + """ + return [star for star in self._data if not star._excluded_from_fit] + + @property + def all_excluded(self): + """ + Whether all `EPSFStar` objects in this linked star have been + excluded from fitting during the ePSF build process. + """ + return all(star._excluded_from_fit for star in self._data) def constrain_centers(self): """ @@ -484,43 +644,243 @@ def constrain_centers(self): if len(self._data) < 2: # no linked stars return - idx = np.logical_not(self._excluded_from_fit).nonzero()[0] - if idx.shape == (0,): # pylint: disable=no-member + if self.all_excluded: warnings.warn('Cannot constrain centers of linked stars because ' - 'all the stars have been excluded during the ePSF ' + 'they have all been excluded during the ePSF ' 'build process.', AstropyUserWarning) return - good_stars = [self._data[i] - for i in idx] # pylint: disable=not-an-iterable + # Convert pixel coordinates to sky coordinates + # Note: each star may have a different WCS, so we cannot + # vectorize + good_stars = self.all_good_stars + sky_coords = np.array([ + star.wcs_large.pixel_to_world_values(*star.center) + for star in good_stars]) - coords = [] - for star in good_stars: - wcs = star.wcs_large - xposition = star.center[0] - yposition = star.center[1] - coords.append(wcs.pixel_to_world_values(xposition, yposition)) - - # compute mean cartesian coordinates - lon, lat = np.transpose(coords) - lon *= np.pi / 180.0 - lat *= np.pi / 180.0 - x_mean = np.mean(np.cos(lat) * np.cos(lon)) - y_mean = np.mean(np.cos(lat) * np.sin(lon)) - z_mean = np.mean(np.sin(lat)) - - # convert mean cartesian coordinates back to spherical - hypot = np.hypot(x_mean, y_mean) - mean_lon = np.arctan2(y_mean, x_mean) - mean_lat = np.arctan2(z_mean, hypot) - mean_lon *= 180.0 / np.pi - mean_lat *= 180.0 / np.pi - - # convert mean sky coordinates 
back to center pixel coordinates - # for each star + # Compute mean sky coordinate using spherical averaging + mean_lon, mean_lat = _compute_mean_sky_coordinate(sky_coords) + + # Convert mean sky coordinate back to pixel coordinates for each + # star for star in good_stars: - center = star.wcs_large.world_to_pixel_values(mean_lon, mean_lat) - star.cutout_center = np.array(center) - star.origin + pixel_center = star.wcs_large.world_to_pixel_values( + mean_lon, mean_lat) + star.cutout_center = np.asarray(pixel_center) - star.origin + + +def _compute_mean_sky_coordinate(sky_coords): + """ + Compute the mean sky coordinate using spherical trigonometry. + + This method properly handles coordinate system singularities by + converting to Cartesian coordinates for averaging, then converting + back to spherical coordinates. + + Parameters + ---------- + sky_coords : array-like, shape (N, 2) + Array of sky coordinates in degrees, where each row contains + (longitude, latitude). + + Returns + ------- + mean_lon, mean_lat : float + Mean longitude and latitude in degrees. + """ + lon, lat = sky_coords.T + lon_rad = np.deg2rad(lon) + lat_rad = np.deg2rad(lat) + + # Convert to Cartesian coordinates for averaging + x_cart = np.cos(lat_rad) * np.cos(lon_rad) + y_cart = np.cos(lat_rad) * np.sin(lon_rad) + z_cart = np.sin(lat_rad) + + # Compute mean Cartesian coordinates + mean_x = np.mean(x_cart) + mean_y = np.mean(y_cart) + mean_z = np.mean(z_cart) + + # Convert mean Cartesian coordinates back to spherical + hypot = np.hypot(mean_x, mean_y) + mean_lon = np.rad2deg(np.arctan2(mean_y, mean_x)) + mean_lat = np.rad2deg(np.arctan2(mean_z, hypot)) + + return mean_lon, mean_lat + + +def _normalize_data_input(data): + """ + Normalize the input data to a list of NDData objects. + + Parameters + ---------- + data : `~astropy.nddata.NDData` or list of `~astropy.nddata.NDData` + The input data to normalize. 
+ + Returns + ------- + data : list of `~astropy.nddata.NDData` + The normalized list of NDData objects. + + Raises + ------ + TypeError + If the input data is not an NDData object or list of NDData + objects. + """ + if isinstance(data, NDData): + return [data] + if isinstance(data, list): + return data + msg = 'data must be a single NDData object or list of NDData objects' + raise TypeError(msg) + + +def _normalize_catalog_input(catalogs): + """ + Normalize the input catalogs to a list of Table objects. + + Parameters + ---------- + catalogs : `~astropy.table.Table` or list of `~astropy.table.Table` + The input catalogs to normalize. + + Returns + ------- + catalogs : list of `~astropy.table.Table` + The normalized list of Table objects. + + Raises + ------ + TypeError + If the input catalogs is not a Table object or list of Table + objects. + """ + if isinstance(catalogs, Table): + return [catalogs] + if isinstance(catalogs, list): + return catalogs + msg = 'catalogs must be a single Table object or list of Table objects' + raise TypeError(msg) + + +def _validate_nddata_list(data): + """ + Validate that a list contains only valid NDData objects. + + Parameters + ---------- + data : list of `~astropy.nddata.NDData` + The list of NDData objects to validate. + + Raises + ------ + TypeError + If any element is not an NDData object. + ValueError + If any NDData object has no data array or non-2D data. + """ + for i, img in enumerate(data): + if not isinstance(img, NDData): + msg = (f'All data elements must be NDData objects. ' + f'Element {i} is {type(img)}') + raise TypeError(msg) + if img.data.ndim != 2: + msg = (f'All NDData objects must contain 2D data. ' + f'Object at index {i} has {img.data.ndim}D data') + raise ValueError(msg) + + +def _validate_catalog_list(catalogs): + """ + Validate that a list contains only valid Table objects. + + Parameters + ---------- + catalogs : list of `~astropy.table.Table` + The list of Table objects to validate. 
+ + Raises + ------ + TypeError + If any element is not a Table object. + """ + for i, cat in enumerate(catalogs): + if not isinstance(cat, Table): + msg = (f'All catalog elements must be Table objects. ' + f'Element {i} is {type(cat)}') + raise TypeError(msg) + if len(cat) == 0: + warnings.warn(f'Catalog at index {i} is empty. No stars will ' + 'be extracted from this catalog.', + AstropyUserWarning) + + +def _validate_coordinate_consistency(data, catalogs): + """ + Validate coordinate system consistency between data and catalogs. + + This function ensures that the necessary coordinate information + (either pixel coordinates or WCS for sky coordinates) is available + to extract stars. + + Parameters + ---------- + data : list of `~astropy.nddata.NDData` + The list of NDData objects. + + catalogs : list of `~astropy.table.Table` + The list of Table catalogs. + + Raises + ------ + ValueError + If the coordinate information is inconsistent or missing. + """ + if len(catalogs) == 1 and len(data) > 1: + # Single catalog with multiple images requires skycoord and WCS + if 'skycoord' not in catalogs[0].colnames: + msg = ('When inputting a single catalog with multiple NDData ' + 'objects, the catalog must have a "skycoord" column.') + raise ValueError(msg) + + if any(img.wcs is None for img in data): + msg = ('When inputting a single catalog with multiple NDData ' + 'objects, each NDData object must have a wcs attribute.') + raise ValueError(msg) + else: + # Multiple catalogs (or single catalog with single image) + for i, cat in enumerate(catalogs): + has_xy = 'x' in cat.colnames and 'y' in cat.colnames + has_skycoord = 'skycoord' in cat.colnames + + if not has_xy and not has_skycoord: + msg = (f'Catalog at index {i} must have either ' + '"x" and "y" columns or a "skycoord" column.') + raise ValueError(msg) + + # If only skycoord is available, ensure WCS is present + if has_skycoord and not has_xy: + data_idx = i if len(data) == len(catalogs) else 0 + if (data_idx < 
len(data) + and data[data_idx].wcs is None): + msg = (f'When catalog at index {i} contains only skycoord ' + f'positions, the corresponding NDData object must ' + 'have a wcs attribute.') + raise ValueError(msg) + + if any(img.wcs is None for img in data): + msg = ('When inputting catalog(s) with only skycoord ' + 'positions, each NDData object must have a ' + 'wcs attribute.') + raise ValueError(msg) + + if len(data) != len(catalogs): + msg = ('When inputting multiple catalogs, the number of ' + 'catalogs must match the number of input images.') + raise ValueError(msg) def extract_stars(data, catalogs, *, size=(11, 11)): @@ -583,99 +943,139 @@ def extract_stars(data, catalogs, *, size=(11, 11)): stars : `EPSFStars` instance A `EPSFStars` instance containing the extracted stars. """ - if isinstance(data, NDData): - data = [data] + data = _normalize_data_input(data) + catalogs = _normalize_catalog_input(catalogs) + _validate_nddata_list(data) + _validate_catalog_list(catalogs) + _validate_coordinate_consistency(data, catalogs) + size = as_pair('size', size, lower_bound=(3, 0), check_odd=True) - if isinstance(catalogs, Table): - catalogs = [catalogs] + if len(catalogs) == 1: # may include linked stars + stars_out, overlap_fail_count = _extract_linked_stars( + data, catalogs[0], size) + else: # no linked stars + stars_out, overlap_fail_count = _extract_unlinked_stars( + data, catalogs, size) - for img in data: - if not isinstance(img, NDData): - msg = 'data must be a single NDData or list of NDData objects' - raise TypeError(msg) + if overlap_fail_count > 0: + warnings.warn(f'{overlap_fail_count} star(s) were not extracted ' + 'because their cutout region extended beyond the ' + 'input image.', AstropyUserWarning) - for cat in catalogs: - if not isinstance(cat, Table): - msg = 'catalogs must be a single Table or list of Table objects' - raise TypeError(msg) + return EPSFStars(stars_out) - if len(catalogs) == 1 and len(data) > 1: - if 'skycoord' not in 
catalogs[0].colnames: - msg = ('When inputting a single catalog with multiple NDData ' - 'objects, the catalog must have a "skycoord" column.') - raise ValueError(msg) - if any(img.wcs is None for img in data): - msg = ('When inputting a single catalog with multiple NDData ' - 'objects, each NDData object must have a wcs attribute.') - raise ValueError(msg) - else: - for cat in catalogs: - if 'x' not in cat.colnames or 'y' not in cat.colnames: - if 'skycoord' not in cat.colnames: - msg = ('When inputting multiple catalogs, each one ' - 'must have a "x" and "y" column or a ' - '"skycoord" column.') - raise ValueError(msg) +def _extract_linked_stars(data, catalog, size): + """ + Extract stars that may be linked across multiple images. - if any(img.wcs is None for img in data): - msg = ('When inputting catalog(s) with only skycoord ' - 'positions, each NDData object must have a ' - 'wcs attribute.') - raise ValueError(msg) + Parameters + ---------- + data : list of `~astropy.nddata.NDData` + A list of `~astropy.nddata.NDData` objects containing + the 2D images from which to extract the stars. Each + `~astropy.nddata.NDData` object must have a valid ``wcs`` + attribute. - if len(data) != len(catalogs): - msg = ('When inputting multiple catalogs, the number of ' - 'catalogs must match the number of input images.') - raise ValueError(msg) + catalog : `~astropy.table.Table` + A single catalog of sources to be extracted from the input + ``data``. The center of each source must be defined in + sky coordinates (in a ``skycoord`` column containing a + `~astropy.coordinates.SkyCoord` object). - size = as_pair('size', size, lower_bound=(3, 0), check_odd=True) + size : int or array_like (int) + The extraction box size along each axis. If ``size`` is a scalar + then a square box of size ``size`` will be used. If ``size`` has + two elements, they must be in ``(ny, nx)`` order. 
- if len(catalogs) == 1: # may included linked stars - use_xy = True - if len(data) > 1: - use_xy = False # linked stars require skycoord positions - - # stars is a list of lists, one list of stars in each image - stars = [_extract_stars(img, catalogs[0], size=size, use_xy=use_xy) - for img in data] - - # transpose the list of lists, to associate linked stars - stars = list(map(list, zip(*stars, strict=True))) - - # remove 'None' stars (i.e., no or partial overlap in one or - # more images) and handle the case of only one "linked" star - stars_out = [] - n_input = len(catalogs[0]) * len(data) - n_extracted = 0 - for star in stars: - good_stars = [i for i in star if i is not None] - n_extracted += len(good_stars) - if not good_stars: - continue # no overlap in any image - - if len(good_stars) == 1: - good_stars = good_stars[0] # only one star, cannot be linked - else: - good_stars = LinkedEPSFStar(good_stars) + Returns + ------- + stars : list of `EPSFStar` or `LinkedEPSFStar` objects + A list of `EPSFStar` and/or `LinkedEPSFStar` instances + containing the extracted stars. Stars that are linked across + multiple images will be represented as a single `LinkedEPSFStar` + instance containing the corresponding `EPSFStar` instances from + each image. Failed extractions are represented as `None`. + + overlap_fail_count : int + The number of stars that failed extraction because their cutout + region extended beyond the input image. 
+ """ + # Use pixel coords only for single image + use_xy = len(data) == 1 - stars_out.append(good_stars) - else: # no linked stars - stars_out = [] - for img, cat in zip(data, catalogs, strict=True): - stars_out.extend(_extract_stars(img, cat, size=size, use_xy=True)) - - n_input = len(stars_out) - stars_out = [star for star in stars_out if star is not None] - n_extracted = len(stars_out) - - n_excluded = n_input - n_extracted - if n_excluded > 0: - warnings.warn(f'{n_excluded} star(s) were not extracted because ' - 'their cutout region extended beyond the input image.', - AstropyUserWarning) + # Extract stars from each image + results = [_extract_stars(img, catalog, size=size, use_xy=use_xy) + for img in data] + stars = [r[0] for r in results] + overlap_fail_count = sum(r[1] for r in results) - return EPSFStars(stars_out) + # Transpose to associate linked stars across images + stars = list(map(list, zip(*stars, strict=True))) + + # Process each potential linked star group + stars_out = [] + for star_group in stars: + good_stars = [star for star in star_group if star is not None] + + if not good_stars: + continue # No valid stars in any image + + if len(good_stars) == 1: + # Single star, not linked + stars_out.append(good_stars[0]) + else: + # Multiple stars - create linked star + stars_out.append(LinkedEPSFStar(good_stars)) + + return stars_out, overlap_fail_count + + +def _extract_unlinked_stars(data, catalogs, size): + """ + Extract stars from individual catalogs (no linking). + + Parameters + ---------- + data : list of `~astropy.nddata.NDData` + A list of `~astropy.nddata.NDData` objects containing + the 2D images from which to extract the stars. + + catalogs : list of `~astropy.table.Table` + A list of catalogs of sources to be extracted from the + input ``data``. Each catalog corresponds to the list of + `~astropy.nddata.NDData` objects input in ``data`` (i.e., a + separate source catalog for each 2D image). 
The center of each + source can be defined either in pixel coordinates (in ``x`` and + ``y`` columns) or sky coordinates (in a ``skycoord`` column + containing a `~astropy.coordinates.SkyCoord`. + + size : int or array_like (int) + The extraction box size along each axis. If ``size`` is a scalar + then a square box of size ``size`` will be used. If ``size`` has + two elements, they must be in ``(ny, nx)`` order. + + Returns + ------- + stars : list of `EPSFStar` objects + A list of `EPSFStar` instances containing the extracted stars. + Failed extractions are represented as `None`. + + overlap_fail_count : int + The number of stars that failed extraction because their cutout + region extended beyond the input image. + """ + stars_out = [] + total_overlap_fail_count = 0 + for img, cat in zip(data, catalogs, strict=True): + extracted, overlap_fail_count = _extract_stars( + img, cat, size=size, use_xy=True) + stars_out.extend(extracted) + total_overlap_fail_count += overlap_fail_count + + # Filter out None values + return ([star for star in stars_out if star is not None], + total_overlap_fail_count) def _extract_stars(data, catalog, *, size=(11, 11), use_xy=True): @@ -717,57 +1117,205 @@ def _extract_stars(data, catalog, *, size=(11, 11), use_xy=True): ------- stars : list of `EPSFStar` objects A list of `EPSFStar` instances containing the extracted stars. - """ - size = as_pair('size', size, lower_bound=(3, 0), check_odd=True) + Failed extractions are represented as `None`. + overlap_fail_count : int + The number of stars that failed extraction because their cutout + region extended beyond the input image. 
+ """ colnames = catalog.colnames if ('x' not in colnames or 'y' not in colnames) or not use_xy: xcenters, ycenters = data.wcs.world_to_pixel(catalog['skycoord']) + # Convert to numpy arrays if not already + xcenters = np.asarray(xcenters, dtype=float) + ycenters = np.asarray(ycenters, dtype=float) else: - xcenters = catalog['x'].data.astype(float) - ycenters = catalog['y'].data.astype(float) + # Avoid unnecessary copying by getting data directly + xcenters = np.asarray(catalog['x'], dtype=float) + ycenters = np.asarray(catalog['y'], dtype=float) if 'id' in colnames: ids = catalog['id'] else: ids = np.arange(len(catalog), dtype=int) + 1 - if data.uncertainty is None: - weights = np.ones_like(data.data) - elif data.uncertainty.uncertainty_type == 'weights': - weights = np.asanyarray(data.uncertainty.array, dtype=float) - else: - # other uncertainties are converted to the inverse standard - # deviation as the weight; ignore divide-by-zero RuntimeWarning - with warnings.catch_warnings(): - warnings.simplefilter('ignore', RuntimeWarning) - weights = data.uncertainty.represent_as(StdDevUncertainty) - weights = 1.0 / weights.array - if np.any(~np.isfinite(weights)): - warnings.warn('One or more weight values is not finite. 
Please ' - 'check the input uncertainty values in the input ' - 'NDData object.', AstropyUserWarning) + fluxes = catalog['flux'] if 'flux' in colnames else None - if data.mask is not None: - weights[data.mask] = 0.0 + # Prepare uncertainty handling - defer weight array creation + # until we know which cutouts we need + uncertainty_info = _prepare_uncertainty_info(data) + data_mask = data.mask # Cache mask reference stars = [] - for xcenter, ycenter, obj_id in zip(xcenters, ycenters, ids, strict=True): + nonfinite_weights_count = 0 + overlap_fail_count = 0 + flux_failures = [] # Collect flux estimation failures + all_zero_stars = [] # Collect stars with all-zero data + for i, (xcenter, ycenter) in enumerate(zip(xcenters, ycenters, + strict=True)): try: large_slc, _ = overlap_slices(data.data.shape, size, (ycenter, xcenter), mode='strict') - data_cutout = data.data[large_slc] - weights_cutout = weights[large_slc] except (PartialOverlapError, NoOverlapError): stars.append(None) + overlap_fail_count += 1 continue + # Extract data cutout + data_cutout = data.data[large_slc].copy() # Explicit copy for safety + + # Create weights cutout only for this specific region + weights_cutout, has_nonfinite = _create_weights_cutout( + uncertainty_info, data_mask, large_slc) + if has_nonfinite: + nonfinite_weights_count += 1 + origin = (large_slc[1].start, large_slc[0].start) cutout_center = (xcenter - origin[0], ycenter - origin[1]) - star = EPSFStar(data_cutout, weights=weights_cutout, - cutout_center=cutout_center, origin=origin, - wcs_large=data.wcs, id_label=obj_id) + flux = fluxes[i] if fluxes is not None else None + + try: + # Suppress all-zero warning in EPSFStar (we emit our own below) + with warnings.catch_warnings(): + msg = 'All unmasked data values in star cutout are zero' + warnings.filterwarnings('ignore', message=msg, + category=AstropyUserWarning) + star = EPSFStar(data_cutout, weights=weights_cutout, + cutout_center=cutout_center, origin=origin, + 
wcs_large=data.wcs, id_label=ids[i], flux=flux) + stars.append(star) + + # Track stars with all-zero data + if hasattr(star, '_has_all_zero_data') and star._has_all_zero_data: + all_zero_stars.append((xcenter, ycenter)) + except ValueError as exc: + # Collect flux estimation failures; emit warnings later + flux_failures.append((xcenter, ycenter, exc)) + stars.append(None) + + # Emit consolidated warning for non-finite weights + if nonfinite_weights_count > 0: + warnings.warn(f'{nonfinite_weights_count} star cutout(s) had ' + 'non-finite weight values which were set to zero. ' + 'Please check the input uncertainty values in the ' + 'NDData object.', AstropyUserWarning) + + # Emit individual flux estimation failure warnings. These may be a + # consequence of having all non-finite weights (data then becomes + # completely masked), so we emit them after the non-finite weights + # warning. + for xcenter, ycenter, exc in flux_failures: + warnings.warn(f'Failed to create EPSFStar for object at ' + f'({xcenter:.1f}, {ycenter:.1f}): {exc}', + AstropyUserWarning) + + # Emit warnings for stars with all-zero data + for xcenter, ycenter in all_zero_stars: + warnings.warn(f'Star at ({xcenter:.1f}, {ycenter:.1f}) has all ' + 'unmasked data values equal to zero', + AstropyUserWarning) + + return stars, overlap_fail_count + + +def _prepare_uncertainty_info(data): + """ + Prepare uncertainty information for efficient weight computation. + + This function analyzes the input NDData's uncertainty and returns + a dictionary with information needed to compute weights for cutout + regions without creating the full weight array. + + Parameters + ---------- + data : `~astropy.nddata.NDData` + The NDData object containing the data and possibly uncertainty. + + Returns + ------- + info : dict + A dictionary with keys: + - 'type' : str + One of 'none', 'weights', or 'uncertainty'. + - 'array' : `~numpy.ndarray` (only if type='weights') + The weight array from the input data. 
+ - 'uncertainty' : `~astropy.nddata.NDUncertainty` (only if + type='uncertainty') + The uncertainty object for on-the-fly conversion to weights. + """ + if data.uncertainty is None: + return {'type': 'none'} - stars.append(star) + if data.uncertainty.uncertainty_type == 'weights': + return { + 'type': 'weights', + 'array': data.uncertainty.array, + } - return stars + # For other uncertainties, prepare the conversion + return { + 'type': 'uncertainty', + 'uncertainty': data.uncertainty, + } + + +def _create_weights_cutout(uncertainty_info, data_mask, slices): + """ + Create a weights cutout for a specific region. + + This avoids creating the full weights array when only a small cutout + is needed, improving memory efficiency. + + Parameters + ---------- + uncertainty_info : dict + Dictionary containing uncertainty information. + + data_mask : `~numpy.ndarray` or None + Mask array for the data. + + slices : tuple of slice + Slices defining the cutout region. + + Returns + ------- + weights_cutout : `~numpy.ndarray` + The weights array for the cutout region. + + has_nonfinite : bool + True if non-finite weights were found and set to zero. 
+ """ + cutout_shape = (slices[0].stop - slices[0].start, + slices[1].stop - slices[1].start) + + if uncertainty_info['type'] == 'none': + weights_cutout = np.ones(cutout_shape, dtype=float) + elif uncertainty_info['type'] == 'weights': + weights_cutout = np.asarray( + uncertainty_info['array'][slices], dtype=float) + else: + # Convert uncertainty to weights for this cutout only + uncertainty_cutout = uncertainty_info['uncertainty'].array[slices] + with warnings.catch_warnings(): + warnings.simplefilter('ignore', RuntimeWarning) + # Convert to standard deviation representation if needed + if hasattr(uncertainty_info['uncertainty'], 'represent_as'): + uncertainty_cutout = ( + uncertainty_info['uncertainty'] + .represent_as(StdDevUncertainty).array[slices]) + # First compute weights, then check for non-finite values + weights_cutout = 1.0 / uncertainty_cutout + + # Check for non-finite weights and track if found + has_nonfinite = not np.all(np.isfinite(weights_cutout)) + if has_nonfinite: + # Set non-finite weights to 0 + weights_cutout = np.where(np.isfinite(weights_cutout), + weights_cutout, 0.0) + + # Apply mask if present + if data_mask is not None: + mask_cutout = data_mask[slices] + weights_cutout[mask_cutout] = 0.0 + + return weights_cutout, has_nonfinite diff --git a/photutils/psf/image_models.py b/photutils/psf/image_models.py index 1e042e76d..44a401140 100644 --- a/photutils/psf/image_models.py +++ b/photutils/psf/image_models.py @@ -4,18 +4,15 @@ """ import copy -import warnings import numpy as np from astropy.modeling import Fittable2DModel, Parameter -from astropy.utils.decorators import deprecated, lazyproperty -from astropy.utils.exceptions import AstropyUserWarning +from astropy.utils.decorators import lazyproperty from scipy.interpolate import RectBivariateSpline -from photutils.aperture import CircularAperture from photutils.utils._parameters import as_pair -__all__ = ['EPSFModel', 'FittableImageModel', 'ImagePSF'] +__all__ = ['ImagePSF'] class 
ImagePSF(Fittable2DModel): @@ -223,6 +220,18 @@ def deepcopy(self): """ return copy.deepcopy(self) + @property + def shape(self): + """ + The shape of the (oversampled) PSF data array. + + Returns + ------- + shape : tuple + The shape of the (oversampled) PSF data array. + """ + return self.data.shape + @property def origin(self): """ @@ -363,1023 +372,3 @@ def evaluate(self, x, y, flux, x_0, y_0): evaluated_model[invalid] = self.fill_value return evaluated_model - - -@deprecated('2.0.0', alternative='`ImagePSF`') -class FittableImageModel(Fittable2DModel): - r""" - A fittable image model allowing for intensity scaling and - translations. - - This class takes 2D image data and computes the values of - the model at arbitrary locations, including fractional pixel - positions, within the image using spline interpolation provided by - :py:class:`~scipy.interpolate.RectBivariateSpline`. - - The fittable model provided by this class has three model - parameters: an image intensity scaling factor (``flux``) which - is applied to (normalized) image, and two positional parameters - (``x_0`` and ``y_0``) indicating the location of a feature in the - coordinate grid on which the model is to be evaluated. - - Parameters - ---------- - data : 2D `~numpy.ndarray` - Array containing the 2D image. - - flux : float, optional - Intensity scaling factor for image data. If ``flux`` is `None`, - then the normalization constant will be computed so that the - total flux of the model's image data is 1.0. - - x_0, y_0 : float, optional - Position of a feature in the image in the output coordinate grid - on which the model is evaluated. - - normalize : bool, optional - Indicates whether or not the model should be build on normalized - input image data. If true, then the normalization constant (*N*) - is computed so that - - .. 
math:: - - N \cdot C \cdot \sum\limits_{i,j} D_{i,j} = 1, - - where *N* is the normalization constant, *C* is correction - factor given by the parameter ``normalization_correction``, and - :math:`D_{i,j}` are the elements of the input image ``data`` - array. - - normalization_correction : float, optional - A strictly positive number that represents correction that needs - to be applied to model's data normalization (see *C* in the - equation in the comments to ``normalize`` for more details). - A possible application for this parameter is to account for - aperture correction. Assuming model's data represent a PSF to be - fitted to some target star, we set ``normalization_correction`` - to the aperture correction that needs to be applied to the - model. That is, ``normalization_correction`` in this case should - be set to the ratio between the total flux of the PSF (including - flux outside model's data) to the flux of model's data. Then, - best fitted value of the ``flux`` model parameter will represent - an aperture-corrected flux of the target star. In the case of - aperture correction, ``normalization_correction`` should be a - value larger than one, as the total flux, including regions - outside the aperture, should be larger than the flux inside the - aperture, and thus the correction is applied as an inversely - multiplied factor. - - origin : tuple, None, optional - A reference point in the input image ``data`` array. When origin - is `None`, origin will be set at the middle of the image array. - - If ``origin`` represents the location of a feature (e.g., the - position of an intensity peak) in the input ``data``, then - model parameters ``x_0`` and ``y_0`` show the location of this - peak in an another target image to which this model was fitted. - Fundamentally, it is the coordinate in the model's image data - that should map to coordinate (``x_0``, ``y_0``) of the output - coordinate system on which the model is evaluated. 
- - Alternatively, when ``origin`` is set to ``(0, 0)``, then model - parameters ``x_0`` and ``y_0`` are shifts by which model's image - should be translated in order to match a target image. - - oversampling : int or array_like (int) - The integer oversampling factor(s) of the ePSF relative to the - input ``stars`` along each axis. If ``oversampling`` is a scalar - then it will be used for both axes. If ``oversampling`` has two - elements, they must be in ``(y, x)`` order. - - fill_value : float, optional - The value to be returned by the `evaluate` or - ``astropy.modeling.Model.__call__`` methods when evaluation is - performed outside the definition domain of the model. - - **kwargs : dict, optional - Additional optional keyword arguments to be passed directly to - the `compute_interpolator` method. See `compute_interpolator` - for more details. - """ - - flux = Parameter(description='Intensity scaling factor for image data.', - default=1.0) - x_0 = Parameter(description='X-position of a feature in the image in ' - 'the output coordinate grid on which the model is ' - 'evaluated.', default=0.0) - y_0 = Parameter(description='Y-position of a feature in the image in ' - 'the output coordinate grid on which the model is ' - 'evaluated.', default=0.0) - - def __init__(self, data, *, flux=flux.default, x_0=x_0.default, - y_0=y_0.default, normalize=False, - normalization_correction=1.0, origin=None, oversampling=1, - fill_value=0.0, **kwargs): - - self._fill_value = fill_value - self._img_norm = None - self._normalization_status = 0 if normalize else 2 - self._store_interpolator_kwargs(**kwargs) - self._oversampling = as_pair('oversampling', oversampling, - lower_bound=(0, 1)) - - if normalization_correction <= 0: - msg = 'normalization_correction must be strictly positive' - raise ValueError(msg) - self._normalization_correction = normalization_correction - self._normalization_constant = 1.0 / self._normalization_correction - - self._data = np.array(data, copy=True, 
dtype=float) - - if not np.all(np.isfinite(self._data)): - msg = 'All elements of input data must be finite' - raise ValueError(msg) - - # set input image related parameters: - self._ny, self._nx = self._data.shape - self._shape = self._data.shape - if self._data.size < 1: - msg = 'Image data array cannot be zero-sized' - raise ValueError(msg) - - # set the origin of the coordinate system in image's pixel grid: - self.origin = origin - - flux = self._initial_norm(flux, normalize) - - super().__init__(flux, x_0, y_0) - - # initialize interpolator: - self.compute_interpolator(**kwargs) - - def _initial_norm(self, flux, normalize): - - if flux is None: - if self._img_norm is None: - self._img_norm = self._compute_raw_image_norm() - flux = self._img_norm - - self._compute_normalization(normalize) - - return flux - - def _compute_raw_image_norm(self): - """ - Helper function that computes the uncorrected inverse - normalization factor of input image data. This quantity is - computed as the *sum of all pixel values*. - - .. note:: - This function is intended to be overridden in a subclass if - one desires to change the way the normalization factor is - computed. - """ - return np.sum(self._data, dtype=float) - - def _compute_normalization(self, normalize=True): - r""" - Helper function that computes (corrected) normalization factor - of the original image data. - - This quantity is computed as the inverse "raw image norm" - (or total "flux" of model's image) corrected by the - ``normalization_correction``: - - .. math:: - - N = 1/(\Phi * C), - - where :math:`\Phi` is the "total flux" of model's image as - computed by `_compute_raw_image_norm` and *C* is the - normalization correction factor. :math:`\Phi` is computed only - once if it has not been previously computed. Otherwise, the - existing (stored) value of :math:`\Phi` is not modified as - :py:class:`FittableImageModel` does not allow image data to be - modified after the object is created. - - .. 
note:: - Normally, this function should not be called by the - end-user. It is intended to be overridden in a subclass if - one desires to change the way the normalization factor is - computed. - """ - self._normalization_constant = 1.0 / self._normalization_correction - - if normalize: - # compute normalization constant so that - # N*C*sum(data) = 1: - if self._img_norm is None: - self._img_norm = self._compute_raw_image_norm() - - if self._img_norm != 0.0 and np.isfinite(self._img_norm): - self._normalization_constant /= self._img_norm - self._normalization_status = 0 - - else: - self._normalization_constant = 1.0 - self._normalization_status = 1 - warnings.warn('Overflow encountered while computing ' - 'normalization constant. Normalization ' - 'constant will be set to 1.', AstropyUserWarning) - - else: - self._normalization_status = 2 - - @property - def oversampling(self): - """ - The factor by which the stored image is oversampled. - - An input to this model is multiplied by this factor to yield the - index into the stored image. - """ - return self._oversampling - - @property - def data(self): - """ - Get original image data. - """ - return self._data - - @property - def normalized_data(self): - """ - Get normalized and/or intensity-corrected image data. - """ - return self._normalization_constant * self._data - - @property - def normalization_constant(self): - """ - Get normalization constant. - """ - return self._normalization_constant - - @property - def normalization_status(self): - """ - Get normalization status. - - Possible status values are: - - * 0: **Performed**. Model has been successfully normalized at - user's request. - * 1: **Failed**. Attempt to normalize has failed. - * 2: **NotRequested**. User did not request model to be normalized. - """ - return self._normalization_status - - @property - def normalization_correction(self): - """ - Set/Get flux correction factor. - - .. 
note:: - When setting correction factor, model's flux will be - adjusted accordingly such that if this model was a good fit - to some target image before, then it will remain a good fit - after correction factor change. - """ - return self._normalization_correction - - @normalization_correction.setter - def normalization_correction(self, normalization_correction): - old_cf = self._normalization_correction - self._normalization_correction = normalization_correction - self._compute_normalization(normalize=self._normalization_status != 2) - - # adjust model's flux so that if this model was a good fit to - # some target image, then it will remain a good fit after - # correction factor change: - self.flux *= normalization_correction / old_cf - - @property - def shape(self): - """ - A tuple of dimensions of the data array in numpy style (ny, nx). - """ - return self._shape - - @property - def nx(self): - """ - Number of columns in the data array. - """ - return self._nx - - @property - def ny(self): - """ - Number of rows in the data array. - """ - return self._ny - - @property - def origin(self): - """ - A tuple of ``x`` and ``y`` coordinates of the origin of the - coordinate system in terms of pixels of model's image. - - When setting the coordinate system origin, a tuple of two - integers or floats may be used. If origin is set to `None`, the - origin of the coordinate system will be set to the middle of the - data array (``(npix-1)/2.0``). - - .. warning:: - Modifying ``origin`` will not adjust (modify) model's - parameters ``x_0`` and ``y_0``. 
- """ - return (self._x_origin, self._y_origin) - - @origin.setter - def origin(self, origin): - if origin is None: - self._x_origin = (self._nx - 1) / 2.0 - self._y_origin = (self._ny - 1) / 2.0 - elif hasattr(origin, '__iter__') and len(origin) == 2: - self._x_origin, self._y_origin = origin - else: - msg = ('Parameter "origin" must be either None or an iterable ' - 'with two elements') - raise TypeError(msg) - - @property - def x_origin(self): - """ - X-coordinate of the origin of the coordinate system. - """ - return self._x_origin - - @property - def y_origin(self): - """ - Y-coordinate of the origin of the coordinate system. - """ - return self._y_origin - - @property - def fill_value(self): - """ - Fill value to be returned for coordinates outside the domain of - definition of the interpolator. - - If ``fill_value`` is `None`, then values outside the domain of - definition are the ones returned by the interpolator. - """ - return self._fill_value - - @fill_value.setter - def fill_value(self, fill_value): - self._fill_value = fill_value - - def _store_interpolator_kwargs(self, **kwargs): - """ - Store interpolator keyword arguments. - - This function should be called in a subclass whenever model's - interpolator is (re)computed. - """ - self._interpolator_kwargs = copy.deepcopy(kwargs) - - @property - def interpolator_kwargs(self): - """ - Get current interpolator's arguments used when interpolator was - created. - """ - return self._interpolator_kwargs - - def compute_interpolator(self, **kwargs): - """ - Compute/define the interpolating spline. - - This function can be overridden in a subclass to define custom - interpolators. - - Parameters - ---------- - **kwargs : dict, optional - Additional optional keyword arguments: - - * **degree** : int, tuple, optional - Degree of the interpolating spline. A tuple can be used - to provide different degrees for the X- and Y-axes. - Default value is degree=3. 
- - * **s** : float, optional - Non-negative smoothing factor. Default - value s=0 corresponds to interpolation. See - :py:class:`~scipy.interpolate.RectBivariateSpline` for - more details. - - Notes - ----- - * When subclassing :py:class:`FittableImageModel` for the - purpose of overriding :py:func:`compute_interpolator`, the - :py:func:`evaluate` may need be to overridden depending - on the behavior of the new interpolator. In addition, for - improved future compatibility, make sure that the overriding - method stores keyword arguments ``kwargs`` by calling - ``_store_interpolator_kwargs`` method. - - * Use caution when modifying interpolator's degree or smoothness - in a computationally intensive part of the code as it may - decrease code performance due to the need to recompute - interpolator. - """ - if 'degree' in kwargs: - degree = kwargs['degree'] - if hasattr(degree, '__iter__') and len(degree) == 2: - degx = int(degree[0]) - degy = int(degree[1]) - else: - degx = int(degree) - degy = int(degree) - if degx < 0 or degy < 0: - msg = 'Interpolator degree must be a non-negative integer' - raise ValueError(msg) - else: - degx = 3 - degy = 3 - - smoothness = kwargs.get('s', 0) - - x = np.arange(self._nx, dtype=float) - y = np.arange(self._ny, dtype=float) - self.interpolator = RectBivariateSpline( - x, y, self._data.T, kx=degx, ky=degy, s=smoothness, - ) - - self._store_interpolator_kwargs(**kwargs) - - def evaluate(self, x, y, flux, x_0, y_0, *, use_oversampling=True): - """ - Calculate the value of the image model at the input coordinates - for the given model parameters. - - Parameters - ---------- - x, y : float or array_like - The x and y coordinates at which to evaluate the model. - - flux : float - The total flux of the source. - - x_0, y_0 : float - The x and y positions of the feature in the image in the - output coordinate grid on which the model is evaluated. 
- - use_oversampling : bool, optional - Whether to use the oversampling factor to calculate the - model pixel indices. The default is `True`, which means the - input indices will be multiplied by this factor. - - Returns - ------- - evaluated_model : `~numpy.ndarray` - The evaluated model. - """ - if use_oversampling: - xi = self._oversampling[1] * (np.asarray(x) - x_0) - yi = self._oversampling[0] * (np.asarray(y) - y_0) - else: - xi = np.asarray(x) - x_0 - yi = np.asarray(y) - y_0 - - xi = xi.astype(float) - yi = yi.astype(float) - xi += self._x_origin - yi += self._y_origin - - f = flux * self._normalization_constant - evaluated_model = f * self.interpolator.ev(xi, yi) - - if self._fill_value is not None: - # find indices of pixels that are outside the input pixel grid and - # set these pixels to the 'fill_value': - invalid = (((xi < 0) | (xi > self._nx - 1)) - | ((yi < 0) | (yi > self._ny - 1))) - evaluated_model[invalid] = self._fill_value - - return evaluated_model - - -class _LegacyEPSFModel(Fittable2DModel): - """ - A class that models an effective PSF (ePSF). - - This class will be removed when the deprecated EPSFModel is removed, - which will require the EPSFBuilder class to be - rewritten/refactored/replaced. - - The EPSFModel is normalized such that the sum of the PSF over the - (undersampled) pixels within the input ``norm_radius`` is 1.0. - This means that when the EPSF is fit to stars, the resulting flux - corresponds to aperture photometry within a circular aperture of - radius ``norm_radius``. - - While this class is a subclass of `FittableImageModel`, it is very - similar. The primary differences/motivation are a few additional - parameters necessary specifically for ePSFs. - - Parameters - ---------- - data : 2D `~numpy.ndarray` - Array containing the 2D image. - - flux : float, optional - Intensity scaling factor for image data. 
- - x_0, y_0 : float, optional - Position of a feature in the image in the output coordinate grid - on which the model is evaluated. - - normalize : bool, optional - Indicates whether or not the model should be build on normalized - input image data. - - normalization_correction : float, optional - A strictly positive number that represents correction that needs - to be applied to model's data normalization. - - origin : tuple, None, optional - A reference point in the input image ``data`` array. When origin - is `None`, origin will be set at the middle of the image array. - - oversampling : int or array_like (int) - The integer oversampling factor(s) of the ePSF relative to the - input ``stars`` along each axis. If ``oversampling`` is a scalar - then it will be used for both axes. If ``oversampling`` has two - elements, they must be in ``(y, x)`` order. - - fill_value : float, optional - The value to be returned when evaluation is performed outside - the domain of the model. - - norm_radius : float, optional - The radius inside which the ePSF is normalized by the sum over - undersampled integer pixel values inside a circular aperture. - - **kwargs : dict, optional - Additional optional keyword arguments to be passed directly to - the `compute_interpolator` method. See `compute_interpolator` - for more details. 
- """ - - flux = Parameter(description='Intensity scaling factor for image data.', - default=1.0) - x_0 = Parameter(description='X-position of a feature in the image in ' - 'the output coordinate grid on which the model is ' - 'evaluated.', default=0.0) - y_0 = Parameter(description='Y-position of a feature in the image in ' - 'the output coordinate grid on which the model is ' - 'evaluated.', default=0.0) - - def __init__(self, data, *, flux=flux.default, x_0=x_0.default, - y_0=y_0.default, normalize=False, - normalization_correction=1.0, origin=None, oversampling=1, - fill_value=0.0, norm_radius=5.5, **kwargs): - - self._norm_radius = norm_radius - self._fill_value = fill_value - self._img_norm = None - self._normalization_status = 0 if normalize else 2 - self._store_interpolator_kwargs(**kwargs) - self._oversampling = as_pair('oversampling', oversampling, - lower_bound=(0, 1)) - - if normalization_correction <= 0: - msg = 'normalization_correction must be strictly positive' - raise ValueError(msg) - self._normalization_correction = normalization_correction - self._normalization_constant = 1.0 / self._normalization_correction - - self._data = np.array(data, copy=True, dtype=float) - - if not np.all(np.isfinite(self._data)): - msg = 'All elements of input data must be finite' - raise ValueError(msg) - - # set input image related parameters: - self._ny, self._nx = self._data.shape - self._shape = self._data.shape - if self._data.size < 1: - msg = 'Image data array cannot be zero-sized' - raise ValueError(msg) - - # set the origin of the coordinate system in image's pixel grid: - self.origin = origin - - flux = self._initial_norm(flux, normalize) - - super().__init__(flux, x_0, y_0) - - # initialize interpolator: - self.compute_interpolator(**kwargs) - - def _initial_norm(self, flux, normalize): - if flux is None: - if self._img_norm is None: - self._img_norm = self._compute_raw_image_norm() - flux = self._img_norm - - if normalize: - self._compute_normalization() - 
else: - self._img_norm = self._compute_raw_image_norm() - - return flux - - def _compute_raw_image_norm(self): - """ - Compute the normalization of input image data as the flux within - a given radius. - """ - xypos = (self._nx / 2.0, self._ny / 2.0) - # How to generalize "radius" if oversampling is - # different along x/y axes (ellipse?) - radius = self._norm_radius * self.oversampling[0] - aper = CircularAperture(xypos, r=radius) - flux, _ = aper.do_photometry(self._data, method='exact') - return flux[0] / np.prod(self.oversampling) - - def _compute_normalization(self, normalize=True): - """ - Helper function that computes (corrected) normalization factor - of the original image data. - - For the ePSF this is defined as the sum over the inner N - (default=5.5) pixels of the non-oversampled image. Will - renormalize the data to the value calculated. - """ - if normalize: - if self._img_norm is None: - if np.sum(self._data) == 0: - self._img_norm = 1 - else: - self._img_norm = self._compute_raw_image_norm() - - if self._img_norm != 0.0 and np.isfinite(self._img_norm): - self._data /= (self._img_norm * self._normalization_correction) - self._normalization_status = 0 - else: - self._normalization_status = 1 - self._img_norm = 1 - warnings.warn('Overflow encountered while computing ' - 'normalization constant. Normalization ' - 'constant will be set to 1.', AstropyUserWarning) - else: - self._normalization_status = 2 - - @property - def normalized_data(self): - """ - Overloaded dummy function that also returns self._data, as the - normalization occurs within _compute_normalization in EPSFModel, - and as such self._data will sum, accounting for - under/oversampled pixels, to 1/self._normalization_correction. - """ - return self._data - - @property - def oversampling(self): - """ - The factor by which the stored image is oversampled. - - An input to this model is multiplied by this factor to yield the - index into the stored image. 
- """ - return self._oversampling - - @property - def data(self): - """ - Get original image data. - """ - return self._data - - @property - def normalization_constant(self): - """ - Get normalization constant. - """ - return self._normalization_constant - - @property - def normalization_status(self): - """ - Get normalization status. - - Possible status values are: - - * 0: **Performed**. Model has been successfully normalized at - user's request. - * 1: **Failed**. Attempt to normalize has failed. - * 2: **NotRequested**. User did not request model to be normalized. - """ - return self._normalization_status - - @property - def normalization_correction(self): - """ - Set/Get flux correction factor. - - .. note:: - When setting correction factor, model's flux will be - adjusted accordingly such that if this model was a good fit - to some target image before, then it will remain a good fit - after correction factor change. - """ - return self._normalization_correction - - @normalization_correction.setter - def normalization_correction(self, normalization_correction): - old_cf = self._normalization_correction - self._normalization_correction = normalization_correction - self._compute_normalization(normalize=self._normalization_status != 2) - - # adjust model's flux so that if this model was a good fit to - # some target image, then it will remain a good fit after - # correction factor change: - self.flux *= normalization_correction / old_cf - - @property - def shape(self): - """ - A tuple of dimensions of the data array in numpy style (ny, nx). - """ - return self._shape - - @property - def nx(self): - """ - Number of columns in the data array. - """ - return self._nx - - @property - def ny(self): - """ - Number of rows in the data array. - """ - return self._ny - - @property - def origin(self): - """ - A tuple of ``x`` and ``y`` coordinates of the origin of the - coordinate system in terms of pixels of model's image. 
- - When setting the coordinate system origin, a tuple of two - integers or floats may be used. If origin is set to `None`, the - origin of the coordinate system will be set to the middle of the - data array (``(npix-1)/2.0``). - - .. warning:: - Modifying ``origin`` will not adjust (modify) model's - parameters ``x_0`` and ``y_0``. - """ - return (self._x_origin, self._y_origin) - - @origin.setter - def origin(self, origin): - if origin is None: - self._x_origin = (self._nx - 1) / 2.0 / self.oversampling[1] - self._y_origin = (self._ny - 1) / 2.0 / self.oversampling[0] - elif (hasattr(origin, '__iter__') and len(origin) == 2): - self._x_origin, self._y_origin = origin - else: - msg = ('Parameter "origin" must be either None or an iterable ' - 'with two elements') - raise TypeError(msg) - - @property - def x_origin(self): - """ - X-coordinate of the origin of the coordinate system. - """ - return self._x_origin - - @property - def y_origin(self): - """ - Y-coordinate of the origin of the coordinate system. - """ - return self._y_origin - - @property - def fill_value(self): - """ - Fill value to be returned for coordinates outside the domain of - definition of the interpolator. - - If ``fill_value`` is `None`, then values outside the domain of - definition are the ones returned by the interpolator. - """ - return self._fill_value - - @fill_value.setter - def fill_value(self, fill_value): - self._fill_value = fill_value - - def _store_interpolator_kwargs(self, **kwargs): - """ - Store interpolator keyword arguments. - - This function should be called in a subclass whenever model's - interpolator is (re)computed. - """ - self._interpolator_kwargs = copy.deepcopy(kwargs) - - @property - def interpolator_kwargs(self): - """ - Get current interpolator's arguments used when interpolator was - created. - """ - return self._interpolator_kwargs - - def compute_interpolator(self, **kwargs): - """ - Compute/define the interpolating spline. 
- - This function can be overridden in a subclass to define custom - interpolators. - - Parameters - ---------- - **kwargs : dict, optional - Additional optional keyword arguments: - - * **degree** : int, tuple, optional - Degree of the interpolating spline. A tuple can be used - to provide different degrees for the X- and Y-axes. - Default value is degree=3. - - * **s** : float, optional - Non-negative smoothing factor. Default - value s=0 corresponds to interpolation. See - :py:class:`~scipy.interpolate.RectBivariateSpline` for - more details. - - Notes - ----- - * When subclassing :py:class:`FittableImageModel` for the - purpose of overriding :py:func:`compute_interpolator`, the - :py:func:`evaluate` may need to be overridden depending - on the behavior of the new interpolator. In addition, for - improved future compatibility, make sure that the overriding - method stores keyword arguments ``kwargs`` by calling - ``_store_interpolator_kwargs`` method. - - * Use caution when modifying interpolator's degree or smoothness - in a computationally intensive part of the code as it may - decrease code performance due to the need to recompute - interpolator. 
- """ - if 'degree' in kwargs: - degree = kwargs['degree'] - if hasattr(degree, '__iter__') and len(degree) == 2: - degx = int(degree[0]) - degy = int(degree[1]) - else: - degx = int(degree) - degy = int(degree) - if degx < 0 or degy < 0: - msg = 'Interpolator degree must be a non-negative integer' - raise ValueError(msg) - else: - degx = 3 - degy = 3 - - smoothness = kwargs.get('s', 0) - - # Interpolator must be set to interpolate on the undersampled - # pixel grid, going from 0 to len(undersampled_grid) - x = np.arange(self._nx, dtype=float) / self.oversampling[1] - y = np.arange(self._ny, dtype=float) / self.oversampling[0] - self.interpolator = RectBivariateSpline( - x, y, self._data.T, kx=degx, ky=degy, s=smoothness, - ) - - self._store_interpolator_kwargs(**kwargs) - - def evaluate(self, x, y, flux, x_0, y_0): - """ - Calculate the value of the image model at the input coordinates - for the given model parameters. - - Parameters - ---------- - x, y : float or array_like - The x and y coordinates at which to evaluate the model. - - flux : float - The total flux of the source. - - x_0, y_0 : float - The x and y positions of the feature in the image in the - output coordinate grid on which the model is evaluated. - - Returns - ------- - evaluated_model : `~numpy.ndarray` - The evaluated model. - """ - xi = np.asarray(x) - x_0 + self._x_origin - yi = np.asarray(y) - y_0 + self._y_origin - - evaluated_model = flux * self.interpolator.ev(xi, yi) - - if self._fill_value is not None: - # find indices of pixels that are outside the input pixel - # grid and set these pixels to the 'fill_value': - invalid = (((xi < 0) | (xi > (self._nx - 1) - / self.oversampling[1])) - | ((yi < 0) | (yi > (self._ny - 1) - / self.oversampling[0]))) - evaluated_model[invalid] = self._fill_value - - return evaluated_model - - -@deprecated('2.0.0', alternative='`ImagePSF`') -class EPSFModel(_LegacyEPSFModel): - """ - A class that models an effective PSF (ePSF). 
- - The EPSFModel is normalized such that the sum of the PSF over the - (undersampled) pixels within the input ``norm_radius`` is 1.0. - This means that when the EPSF is fit to stars, the resulting flux - corresponds to aperture photometry within a circular aperture of - radius ``norm_radius``. - - While this class is a subclass of `FittableImageModel`, it is very - similar. The primary differences/motivation are a few additional - parameters necessary specifically for ePSFs. - - Parameters - ---------- - data : 2D `~numpy.ndarray` - Array containing the 2D image. - - flux : float, optional - Intensity scaling factor for image data. - - x_0, y_0 : float, optional - Position of a feature in the image in the output coordinate grid - on which the model is evaluated. - - normalize : bool, optional - Indicates whether or not the model should be build on normalized - input image data. - - normalization_correction : float, optional - A strictly positive number that represents correction that needs - to be applied to model's data normalization. - - origin : tuple, None, optional - A reference point in the input image ``data`` array. When origin - is `None`, origin will be set at the middle of the image array. - - oversampling : int or array_like (int) - The integer oversampling factor(s) of the ePSF relative to the - input ``stars`` along each axis. If ``oversampling`` is a scalar - then it will be used for both axes. If ``oversampling`` has two - elements, they must be in ``(y, x)`` order. - - fill_value : float, optional - The value to be returned when evaluation is performed outside - the domain of the model. - - norm_radius : float, optional - The radius inside which the ePSF is normalized by the sum over - undersampled integer pixel values inside a circular aperture. - - **kwargs : dict, optional - Additional optional keyword arguments to be passed directly to - the "compute_interpolator" method. See "compute_interpolator" - for more details. 
- """ - - flux = Parameter(description='Intensity scaling factor for image data.', - default=1.0) - x_0 = Parameter(description='X-position of a feature in the image in ' - 'the output coordinate grid on which the model is ' - 'evaluated.', default=0.0) - y_0 = Parameter(description='Y-position of a feature in the image in ' - 'the output coordinate grid on which the model is ' - 'evaluated.', default=0.0) - - def __init__(self, data, *, flux=flux.default, x_0=x_0.default, - y_0=y_0.default, normalize=True, normalization_correction=1.0, - origin=None, oversampling=1, fill_value=0.0, - norm_radius=5.5, **kwargs): - - super().__init__(data=data, flux=flux, x_0=x_0, y_0=y_0, - normalize=normalize, - normalization_correction=normalization_correction, - origin=origin, oversampling=oversampling, - fill_value=fill_value, norm_radius=norm_radius, - **kwargs) diff --git a/photutils/psf/iterative.py b/photutils/psf/iterative.py index 1573cd0e2..e2cd3d7a4 100644 --- a/photutils/psf/iterative.py +++ b/photutils/psf/iterative.py @@ -93,8 +93,10 @@ class for more information. fitter_maxiters : int, optional The maximum number of iterations in which the ``fitter`` is - called for each source. The value can be increased if the fit is - not converging for sources. + called for each source. The value can be increased if the fit + is not converging for sources. This parameter is passed to the + ``fitter`` if it supports the ``maxiter`` parameter and ignored + otherwise. xy_bounds : `None`, float, or 2-tuple of float, optional The maximum distance in pixels that a fitted source can be from diff --git a/photutils/psf/photometry.py b/photutils/psf/photometry.py index 7a6e4ec75..2c194dd37 100644 --- a/photutils/psf/photometry.py +++ b/photutils/psf/photometry.py @@ -279,8 +279,10 @@ class for more information. fitter_maxiters : int, optional The maximum number of iterations in which the ``fitter`` is - called for each source. 
The value can be increased if the fit is - not converging for sources. + called for each source. The value can be increased if the fit + is not converging for sources. This parameter is passed to the + ``fitter`` if it supports the ``maxiter`` parameter and ignored + otherwise. xy_bounds : `None`, float, or 2-tuple of float, optional The maximum distance in pixels that a fitted source can be from diff --git a/photutils/psf/tests/test_epsf.py b/photutils/psf/tests/test_epsf.py deleted file mode 100644 index f56a0750d..000000000 --- a/photutils/psf/tests/test_epsf.py +++ /dev/null @@ -1,238 +0,0 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst -""" -Tests for the epsf module. -""" - -import itertools - -import numpy as np -import pytest -from astropy.modeling.fitting import TRFLSQFitter -from astropy.nddata import (InverseVariance, NDData, StdDevUncertainty, - VarianceUncertainty) -from astropy.stats import SigmaClip -from astropy.table import Table -from astropy.utils.exceptions import AstropyUserWarning -from numpy.testing import assert_allclose - -from photutils.datasets import make_model_image -from photutils.psf import CircularGaussianPRF, make_psf_model_image -from photutils.psf.epsf import EPSFBuilder, EPSFFitter -from photutils.psf.epsf_stars import EPSFStars, extract_stars - - -@pytest.fixture -def epsf_test_data(): - """ - Create a simulated image for testing. 
- """ - fwhm = 2.7 - psf_model = CircularGaussianPRF(flux=1, fwhm=fwhm) - model_shape = (9, 9) - n_sources = 100 - shape = (750, 750) - data, true_params = make_psf_model_image(shape, psf_model, n_sources, - model_shape=model_shape, - flux=(500, 700), - min_separation=25, - border_size=25, seed=0) - - nddata = NDData(data) - init_stars = Table() - init_stars['x'] = true_params['x_0'].astype(int) - init_stars['y'] = true_params['y_0'].astype(int) - - return { - 'fwhm': fwhm, - 'data': data, - 'nddata': nddata, - 'init_stars': init_stars, - } - - -class TestEPSFBuild: - - def test_extract_stars(self, epsf_test_data): - size = 25 - stars = extract_stars(epsf_test_data['nddata'], - epsf_test_data['init_stars'], - size=size) - - assert len(stars) == len(epsf_test_data['init_stars']) - assert isinstance(stars, EPSFStars) - assert isinstance(stars[0], EPSFStars) - assert stars[0].data.shape == (size, size) - - def test_extract_stars_uncertainties(self, epsf_test_data): - rng = np.random.default_rng(0) - shape = epsf_test_data['nddata'].data.shape - error = np.abs(rng.normal(loc=0, scale=1, size=shape)) - uncertainty1 = StdDevUncertainty(error) - uncertainty2 = uncertainty1.represent_as(VarianceUncertainty) - uncertainty3 = uncertainty1.represent_as(InverseVariance) - ndd1 = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty1) - ndd2 = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty2) - ndd3 = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty3) - - size = 25 - match = 'were not extracted because their cutout region extended' - ndd_inputs = (ndd1, ndd2, ndd3) - - outputs = [extract_stars(ndd_input, epsf_test_data['init_stars'], - size=size) for ndd_input in ndd_inputs] - - for stars in outputs: - assert len(stars) == len(epsf_test_data['init_stars']) - assert isinstance(stars, EPSFStars) - assert isinstance(stars[0], EPSFStars) - assert stars[0].data.shape == (size, size) - assert stars[0].weights.shape == (size, size) - - 
assert_allclose(outputs[0].weights, outputs[1].weights) - assert_allclose(outputs[0].weights, outputs[2].weights) - - uncertainty = StdDevUncertainty(np.zeros(shape)) - ndd = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty) - - match = 'One or more weight values is not finite' - with pytest.warns(AstropyUserWarning, match=match): - stars = extract_stars(ndd, epsf_test_data['init_stars'][0:3], - size=size) - - @pytest.mark.parametrize('shape', [(25, 25), (19, 25), (25, 19)]) - def test_epsf_build(self, epsf_test_data, shape): - """ - This is an end-to-end test of EPSFBuilder on a simulated image. - """ - oversampling = 2 - stars = extract_stars(epsf_test_data['nddata'], - epsf_test_data['init_stars'][:10], - size=shape) - epsf_builder = EPSFBuilder(oversampling=oversampling, maxiters=5, - progress_bar=False, norm_radius=10, - recentering_maxiters=5) - epsf, fitted_stars = epsf_builder(stars) - - ref_size = np.array(shape) * oversampling + 1 - assert epsf.data.shape == tuple(ref_size) - - # Verify basic EPSF properties - assert len(fitted_stars) == 10 - assert epsf.data.sum() > 2 # Check it has reasonable total flux - assert epsf.data.max() > 0.01 # Should have a peak - - # Check that the center region has higher values than edges - center_y, center_x = np.array(ref_size) // 2 - center_val = epsf.data[center_y, center_x] - edge_val = epsf.data[0, 0] - assert center_val > edge_val # Center should be brighter than edge - - # Test that residual computation works (basic functionality test) - resid_star = fitted_stars[0].compute_residual_image(epsf) - assert isinstance(resid_star, np.ndarray) - assert resid_star.shape == fitted_stars[0].data.shape - - def test_epsf_fitting_bounds(self, epsf_test_data): - size = 25 - oversampling = 4 - stars = extract_stars(epsf_test_data['nddata'], - epsf_test_data['init_stars'], - size=size) - - epsf_builder = EPSFBuilder(oversampling=oversampling, maxiters=8, - progress_bar=True, norm_radius=25, - recentering_maxiters=5, - 
fitter=EPSFFitter(fit_boxsize=31), - smoothing_kernel='quadratic') - - # With a boxsize larger than the cutout we expect the fitting to - # fail for all stars, due to star._fit_error_status - match1 = 'The ePSF fitting failed for all stars' - match2 = r'The star at .* cannot be fit because its fitting region ' - with (pytest.raises(ValueError, match=match1), - pytest.warns(AstropyUserWarning, match=match2)): - epsf_builder(stars) - - def test_epsf_build_invalid_fitter(self): - """ - Test that the input fitter is an EPSFFitter instance. - """ - match = 'fitter must be an EPSFFitter instance' - with pytest.raises(TypeError, match=match): - EPSFBuilder(fitter=EPSFFitter, maxiters=3) - - with pytest.raises(TypeError, match=match): - EPSFBuilder(fitter=TRFLSQFitter(), maxiters=3) - - with pytest.raises(TypeError, match=match): - EPSFBuilder(fitter=TRFLSQFitter, maxiters=3) - - -def test_epsfbuilder_inputs(): - # invalid inputs - match = "'oversampling' must be specified" - with pytest.raises(ValueError, match=match): - EPSFBuilder(oversampling=None) - match = 'oversampling must be > 0' - with pytest.raises(ValueError, match=match): - EPSFBuilder(oversampling=-1) - match = 'maxiters must be a positive number' - with pytest.raises(ValueError, match=match): - EPSFBuilder(maxiters=-1) - match = 'oversampling must be > 0' - with pytest.raises(ValueError, match=match): - EPSFBuilder(oversampling=[-1, 4]) - - # valid inputs - EPSFBuilder(oversampling=6) - EPSFBuilder(oversampling=[4, 6]) - - # invalid inputs - for sigma_clip in [None, [], 'a']: - match = 'sigma_clip must be an astropy.stats.SigmaClip instance' - with pytest.raises(TypeError, match=match): - EPSFBuilder(sigma_clip=sigma_clip) - - # valid inputs - EPSFBuilder(sigma_clip=SigmaClip(sigma=2.5, cenfunc='mean', maxiters=2)) - - -@pytest.mark.parametrize('oversamp', [3, 4]) -def test_epsf_build_oversampling(oversamp): - offsets = (np.arange(oversamp) * 1.0 / oversamp - 0.5 + 1.0 - / (2.0 * oversamp)) - xydithers = 
np.array(list(itertools.product(offsets, offsets))) - xdithers = np.transpose(xydithers)[0] - ydithers = np.transpose(xydithers)[1] - - nstars = oversamp**2 - fwhm = 7.0 - sources = Table() - offset = 50 - size = oversamp * offset + offset - y, x = np.mgrid[0:oversamp, 0:oversamp] * offset + offset - sources['x_0'] = x.ravel() + xdithers - sources['y_0'] = y.ravel() + ydithers - sources['fwhm'] = np.full((nstars,), fwhm) - - psf_model = CircularGaussianPRF(fwhm=fwhm) - shape = (size, size) - data = make_model_image(shape, psf_model, sources) - nddata = NDData(data=data) - stars_tbl = Table() - stars_tbl['x'] = sources['x_0'] - stars_tbl['y'] = sources['y_0'] - stars = extract_stars(nddata, stars_tbl, size=25) - epsf_builder = EPSFBuilder(oversampling=oversamp, maxiters=15, - progress_bar=False, recentering_maxiters=20) - epsf, _ = epsf_builder(stars) - - # input PSF shape - size = epsf.data.shape[0] - cen = (size - 1) / 2 - fwhm2 = oversamp * fwhm - m = CircularGaussianPRF(flux=1, x_0=cen, y_0=cen, fwhm=fwhm2) - yy, xx = np.mgrid[0:size, 0:size] - psf = m(xx, yy) - - assert_allclose(epsf.data, psf * epsf.data.sum(), atol=2.5e-4) diff --git a/photutils/psf/tests/test_epsf_builder.py b/photutils/psf/tests/test_epsf_builder.py new file mode 100644 index 000000000..dff289ff2 --- /dev/null +++ b/photutils/psf/tests/test_epsf_builder.py @@ -0,0 +1,1751 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +Tests for the epsf_builder module. 
+""" + +import itertools +import warnings +from unittest.mock import patch + +import numpy as np +import pytest +from astropy.modeling.fitting import TRFLSQFitter +from astropy.nddata import NDData +from astropy.table import Table +from astropy.utils.exceptions import AstropyUserWarning +from numpy.testing import assert_allclose + +from photutils.centroids import (centroid_1dg, centroid_2dg, centroid_com, + centroid_quadratic) +from photutils.datasets import make_model_image +from photutils.psf import (CircularGaussianPRF, EPSFBuilder, EPSFBuildResult, + EPSFFitter, EPSFStar, EPSFStars, ImagePSF, + extract_stars, make_psf_model_image) +from photutils.psf.epsf_builder import (_CoordinateTransformer, _EPSFValidator, + _ProgressReporter, _SmoothingKernel) +from photutils.psf.epsf_stars import LinkedEPSFStar +from photutils.utils._optional_deps import HAS_TQDM + + +@pytest.fixture +def epsf_test_data(): + """ + Create a simulated image for testing. + """ + fwhm = 2.7 + psf_model = CircularGaussianPRF(flux=1, fwhm=fwhm) + model_shape = (9, 9) + n_sources = 100 + shape = (750, 750) + data, true_params = make_psf_model_image(shape, psf_model, n_sources, + model_shape=model_shape, + flux=(500, 700), + min_separation=25, + border_size=25, seed=0) + + nddata = NDData(data) + init_stars = Table() + init_stars['x'] = true_params['x_0'] + init_stars['y'] = true_params['y_0'] + + return { + 'fwhm': fwhm, + 'data': data, + 'nddata': nddata, + 'init_stars': init_stars, + } + + +@pytest.fixture +def epsf_fitter_data(epsf_test_data): + """ + Create extracted stars and an ePSF for testing EPSFFitter. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:4], size=11) + builder = EPSFBuilder(oversampling=1, maxiters=2, progress_bar=False) + epsf, _ = builder(stars) + return {'stars': stars, 'epsf': epsf} + + +class TestSmoothingKernel: + """ + Tests for the _SmoothingKernel class. 
+ """ + + @pytest.mark.parametrize('kernel_type', ['quartic', 'quadratic']) + def test_get_kernel(self, kernel_type): + """ + Test quartic kernel retrieval. + """ + kernel = _SmoothingKernel.get_kernel(kernel_type) + assert isinstance(kernel, np.ndarray) + assert kernel.shape == (5, 5) + if kernel_type == 'quartic': + expected_sum = _SmoothingKernel.QUARTIC_KERNEL.sum() + else: + expected_sum = _SmoothingKernel.QUADRATIC_KERNEL.sum() + assert np.isclose(kernel.sum(), expected_sum) + + def test_get_kernel_custom_array(self): + """ + Test custom array kernel retrieval. + """ + custom_kernel = np.ones((3, 3)) / 9.0 + kernel = _SmoothingKernel.get_kernel(custom_kernel) + assert isinstance(kernel, np.ndarray) + assert kernel.shape == (3, 3) + assert np.allclose(kernel, custom_kernel) + + def test_get_kernel_invalid_type(self): + """ + Test invalid kernel type raises TypeError. + """ + with pytest.raises(TypeError, match='Unsupported kernel type'): + _SmoothingKernel.get_kernel('invalid') + + @pytest.mark.parametrize('kernel_type', ['quartic', 'quadratic']) + def test_apply_smoothing(self, kernel_type): + """ + Test smoothing with quartic kernel. + """ + data = np.ones((10, 10)) + smoothed = _SmoothingKernel.apply_smoothing(data, kernel_type) + assert isinstance(smoothed, np.ndarray) + assert smoothed.shape == data.shape + assert_allclose(smoothed.sum(), data.sum()) + + def test_apply_smoothing_custom_kernel(self): + """ + Test smoothing with custom kernel. + """ + data = np.ones((10, 10)) + kernel = np.array([[0, 0.1, 0], [0.1, 0.6, 0.1], [0, 0.1, 0]]) + smoothed = _SmoothingKernel.apply_smoothing(data, kernel) + assert isinstance(smoothed, np.ndarray) + assert smoothed.shape == data.shape + assert_allclose(smoothed.sum(), data.sum()) + + def test_apply_smoothing_none(self): + """ + Test smoothing with None returns original data. 
+ """ + data = np.ones((10, 10)) + result = _SmoothingKernel.apply_smoothing(data, None) + assert result is data # Should return same object + + +class TestEPSFValidator: + """ + Tests for the _EPSFValidator class. + """ + + def test_validate_oversampling_valid(self): + """ + Test valid oversampling validation. + """ + result = _EPSFValidator.validate_oversampling(2) + assert np.array_equal(result, (2, 2)) + + result = _EPSFValidator.validate_oversampling((3, 4)) + assert np.array_equal(result, (3, 4)) + + def test_validate_oversampling_none(self): + """ + Test validate_oversampling with None input. + """ + match = "'oversampling' must be specified" + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_oversampling(None) + + def test_validate_oversampling_invalid_exception(self): + """ + Test oversampling validation with invalid input. + """ + # Test with invalid input that should raise exception from + # as_pair + match = 'Invalid oversampling parameter' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_oversampling('invalid') + + def test_validate_oversampling_invalid_exception_with_context(self): + """ + Test oversampling validation with context and invalid input. + """ + msg = 'test_context: Invalid oversampling parameter' + with pytest.raises(ValueError, match=msg): + _EPSFValidator.validate_oversampling('invalid', + context='test_context') + + def test_validate_oversampling_zero_values(self): + """ + Test oversampling validation with zero values. + """ + with pytest.raises(ValueError, match='oversampling must be > 0'): + _EPSFValidator.validate_oversampling((0, 2)) + + msg = ('test_context: Invalid oversampling parameter - ' + 'oversampling must be > 0') + with pytest.raises(ValueError, match=msg): + _EPSFValidator.validate_oversampling((0, 2), + context='test_context') + + def test_validate_oversampling_as_pair_exception_with_context(self): + """ + Test oversampling validation when as_pair raises exception. 
+ """ + # Use a tuple with wrong number of elements to trigger as_pair error + msg = 'test_ctx: Invalid oversampling parameter' + with pytest.raises(ValueError, match=msg): + _EPSFValidator.validate_oversampling((1, 2, 3), + context='test_ctx') + + def test_validate_shape_compatibility(self, epsf_test_data): + """ + Test shape compatibility validation. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + # Should not raise an exception for compatible shapes + _EPSFValidator.validate_shape_compatibility(stars, (1, 1)) + + def test_validate_shape_compatibility_custom_shape(self, epsf_test_data): + """ + Test shape compatibility with custom shape. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + # Test with specific shape + _EPSFValidator.validate_shape_compatibility(stars, (1, 1), + shape=(21, 21)) + + def test_validate_shape_compatibility_empty_stars(self): + """ + Test shape compatibility with empty star list. + """ + empty_stars = EPSFStars([]) + match = 'Cannot validate shape compatibility with empty star list' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_shape_compatibility(empty_stars, (1, 1)) + + def test_validate_shape_compatibility_small_stars(self): + """ + Test shape compatibility with very small star cutouts. + """ + # Create very small star (2x2 pixels) + small_data = np.ones((2, 2)) + small_star = EPSFStar(small_data, cutout_center=(1, 1)) + small_stars = EPSFStars([small_star]) + + match = r'Found .* star.*with very small dimensions' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_shape_compatibility(small_stars, (1, 1)) + + def test_validate_shape_compatibility_invalid_shape_type(self): + """ + Test shape compatibility with invalid shape type. 
+ """ + data = np.ones((5, 5)) + star = EPSFStar(data, cutout_center=(2, 2)) + stars = EPSFStars([star]) + + match = 'Shape must be a 2-element sequence' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_shape_compatibility(stars, (1, 1), + shape=(10, 10, 10)) + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_shape_compatibility(stars, (1, 1), + shape='invalid') + + def test_validate_shape_compatibility_incompatible_shape(self): + """ + Test shape compatibility with incompatible shape. + """ + data = np.ones((5, 5)) + star = EPSFStar(data, cutout_center=(2, 2)) + stars = EPSFStars([star]) + + # Request shape that's too small + match = r'Requested ePSF shape .* is incompatible' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_shape_compatibility(stars, (2, 2), + shape=(5, 5)) + + def test_validate_shape_compatibility_even_dimensions_warning(self): + """ + Test shape compatibility with even dimensions warning. + """ + data = np.ones((5, 5)) + star = EPSFStar(data, cutout_center=(2, 2)) + stars = EPSFStars([star]) + + # Test even dimensions trigger warning + match = 'ePSF shape .* has even dimensions' + with pytest.warns(UserWarning, match=match): + _EPSFValidator.validate_shape_compatibility(stars, (1, 1), + shape=(20, 20)) + + def test_validate_stars_empty_list(self): + """ + Test validate_stars with empty star list. + """ + empty_stars = EPSFStars([]) + match = 'EPSFStars object must contain at least one star' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(empty_stars) + + match = 'test_context: EPSFStars object must contain at least one star' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(empty_stars, context='test_context') + + def test_validate_stars_non_finite_data(self): + """ + Test validate_stars with non-finite data. 
+ """ + # Create star with all NaN data - need to provide explicit flux + # since flux estimation would fail with all NaN data + data = np.full((5, 5), np.nan) + match = 'Input data array contains invalid data that will be masked' + with pytest.warns(AstropyUserWarning, match=match): + star = EPSFStar(data, cutout_center=(2, 2), flux=1.0) + + match = r'Found [\s\S]* invalid stars [\s\S]* no finite data values' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars([star]) + + def test_validate_stars_too_small(self): + """ + Test validate_stars with very small stars. + """ + # Create very small star (2x2 pixels) + data = np.ones((2, 2)) + star = EPSFStar(data, cutout_center=(1, 1)) + + match = r'Found [\s\S]* invalid stars [\s\S]* too small' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars([star]) + + def test_validate_stars_missing_cutout_center(self): + """ + Test validate_stars with star missing cutout_center. + """ + # Create mock star without cutout_center + class MockStar: + def __init__(self): + self.data = np.ones((5, 5)) + self.shape = (5, 5) + + mock_stars = [MockStar()] + + match = r'Found .* invalid stars [\s\S]* missing cutout_center' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(mock_stars) + + def test_validate_stars_validation_error(self): + """ + Test validate_stars with validation error during processing. + """ + # Create mock star that raises error during validation + class MockStar: + def __init__(self): + self.data = np.ones((5, 5)) + + @property + def shape(self): + msg = 'Test error' + raise ValueError(msg) + + mock_stars = [MockStar()] + + match = r'Found .* invalid stars [\s\S]* validation error' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(mock_stars) + + def test_validate_stars_multiple_invalid(self): + """ + Test validate_stars with multiple invalid stars. 
+ """ + # Create multiple mock stars with different issues + class MockStar1: + def __init__(self): + self.data = None + + class MockStar2: + def __init__(self): + self.data = np.ones((2, 2)) # Too small + self.shape = (2, 2) + + mock_stars = [MockStar1(), MockStar2()] + + match = r'Found 2 invalid stars [\s\S]* too small' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(mock_stars) + + def test_validate_stars_more_than_5_invalid(self): + """ + Test validate_stars with more than 5 invalid stars. + """ + # Create 7 mock stars with missing data + class MockStar: + def __init__(self): + self.data = None + + mock_stars = [MockStar() for _ in range(7)] + + match = r'Found 7 invalid stars [\s\S]* missing data' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(mock_stars) + + def test_validate_stars_context_with_invalid(self): + """ + Test validate_stars with context and invalid stars. + """ + class MockStar: + def __init__(self): + self.data = None + + mock_stars = [MockStar()] + + match = 'my_context: Found 1 invalid stars' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_stars(mock_stars, context='my_context') + + def test_validate_stars_valid(self): + """ + Test validate_stars with valid stars. + """ + # Create valid stars + data1 = np.ones((5, 5)) + data2 = np.ones((6, 6)) + star1 = EPSFStar(data1, cutout_center=(2, 2)) + star2 = EPSFStar(data2, cutout_center=(3, 3)) + + # Should not raise any exception + _EPSFValidator.validate_stars([star1, star2]) + + def test_validate_center_accuracy_valid(self): + """ + Test validate_center_accuracy with valid inputs. + """ + # Test valid values + _EPSFValidator.validate_center_accuracy(0.001) + _EPSFValidator.validate_center_accuracy(0.01) + _EPSFValidator.validate_center_accuracy(0.1) + _EPSFValidator.validate_center_accuracy(1.0) + + def test_validate_center_accuracy_invalid_type(self): + """ + Test validate_center_accuracy with invalid type. 
+ """ + match = 'center_accuracy must be a number' + with pytest.raises(TypeError, match=match): + _EPSFValidator.validate_center_accuracy('0.001') + + def test_validate_center_accuracy_non_positive(self): + """ + Test validate_center_accuracy with non-positive values. + """ + match = 'center_accuracy must be positive' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_center_accuracy(0.0) + + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_center_accuracy(-0.001) + + def test_validate_center_accuracy_too_large(self): + """ + Test validate_center_accuracy with values too large. + """ + match = r'center_accuracy .* seems unusually large' + with pytest.warns(AstropyUserWarning, match=match): + _EPSFValidator.validate_center_accuracy(1.1) + + def test_validate_maxiters_valid(self): + """ + Test validate_maxiters with valid inputs. + """ + # Test valid values (these should not raise or warn) + _EPSFValidator.validate_maxiters(1) + _EPSFValidator.validate_maxiters(10) + _EPSFValidator.validate_maxiters(100) + + def test_validate_maxiters_invalid_type(self): + """ + Test validate_maxiters with invalid type. + """ + match = 'maxiters must be an integer' + with pytest.raises(TypeError, match=match): + _EPSFValidator.validate_maxiters(10.5) + + with pytest.raises(TypeError, match=match): + _EPSFValidator.validate_maxiters('10') + + def test_validate_maxiters_non_positive(self): + """ + Test validate_maxiters with non-positive values. + """ + match = 'maxiters must be a positive number' + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_maxiters(0) + + with pytest.raises(ValueError, match=match): + _EPSFValidator.validate_maxiters(-5) + + def test_validate_maxiters_too_large(self): + """ + Test validate_maxiters with values too large triggers warning. 
+ """ + match = r'maxiters .* seems unusually large' + with pytest.warns(AstropyUserWarning, match=match): + _EPSFValidator.validate_maxiters(101) + + +class TestCoordinateTransformer: + """ + Tests for the _CoordinateTransformer class. + """ + + @pytest.mark.parametrize('oversampling', [(2, 2), (3, 4), (5, 1)]) + def test_basic(self, oversampling): + """ + Test basic coordinate transformation. + """ + # Create transformer + transformer = _CoordinateTransformer(oversampling=oversampling) + assert np.array_equal(transformer.oversampling, oversampling) + assert transformer.oversampling[0] == oversampling[0] + assert transformer.oversampling[1] == oversampling[1] + + def test_empty_star_shapes(self): + """ + Test compute_epsf_shape with empty star_shapes list. + """ + transformer = _CoordinateTransformer(oversampling=(2, 2)) + match = 'Need at least one star to compute ePSF shape' + with pytest.raises(ValueError, match=match): + transformer.compute_epsf_shape([]) + + def test_oversampled_to_undersampled(self): + """ + Test oversampled_to_undersampled conversion. + """ + transformer = _CoordinateTransformer(oversampling=(4, 2)) + x_under, y_under = transformer.oversampled_to_undersampled(8.0, 16.0) + assert x_under == 4.0 # 8 / 2 + assert y_under == 4.0 # 16 / 4 + + def test_undersampled_to_oversampled(self): + """ + Test undersampled_to_oversampled conversion. + """ + transformer = _CoordinateTransformer(oversampling=(4, 2)) + x_over, y_over = transformer.undersampled_to_oversampled(4.0, 4.0) + assert x_over == 8.0 # 4 * 2 + assert y_over == 16.0 # 4 * 4 + + def test_star_to_epsf_coords(self): + """ + Test star_to_epsf_coords method of _CoordinateTransformer. 
+ """ + transformer = _CoordinateTransformer(oversampling=(2, 2)) + + # Test coordinate transformation + star_x = np.array([0.0, 1.0, 2.0]) + star_y = np.array([0.0, 1.0, 2.0]) + epsf_origin = (10.0, 10.0) + + epsf_x, epsf_y = transformer.star_to_epsf_coords( + star_x, star_y, epsf_origin) + + # Check output shape + assert epsf_x.shape == star_x.shape + assert epsf_y.shape == star_y.shape + + # Check values: with oversampling=2 and origin=(10, 10), + # the formula computes round(oversampling * star_x + origin_x) + expected_x = np.array([10, 12, 14]) + expected_y = np.array([10, 12, 14]) + assert np.array_equal(epsf_x, expected_x) + assert np.array_equal(epsf_y, expected_y) + + def test_compute_epsf_origin(self): + """ + Test compute_epsf_origin method of _CoordinateTransformer. + """ + transformer = _CoordinateTransformer(oversampling=(2, 2)) + + # Test with odd shape + origin = transformer.compute_epsf_origin((11, 11)) + assert origin == (5.0, 5.0) + + # Test with different shape + origin = transformer.compute_epsf_origin((21, 31)) + assert origin == (15.0, 10.0) + + +class TestProgressReporter: + """ + Tests for the _ProgressReporter class. + """ + + def test_progress_reporter(self): + """ + Test basic functionality of _ProgressReporter. + """ + # Test with enabled=True + reporter = _ProgressReporter(enabled=True, maxiters=10) + assert reporter.enabled is True + assert reporter.maxiters == 10 + reporter.setup() + reporter.update() + reporter.write_convergence_message(5) + reporter.close() + + reporter = _ProgressReporter(enabled=False, maxiters=5) + assert reporter.enabled is False + + +class TestEPSFBuildResult: + """ + Tests for the EPSFBuildResult class. + """ + + def test_creation(self): + """ + Test EPSFBuildResult creation. 
+ """ + # Create a simple PSF model for testing + data = np.ones((5, 5)) + psf = ImagePSF(data) + + # Create stars list (can be empty for this test) + stars = [] + + result = EPSFBuildResult( + epsf=psf, + fitted_stars=stars, + iterations=5, + converged=True, + final_center_accuracy=0.01, + n_excluded_stars=0, + excluded_star_indices=[], + ) + assert result.epsf is psf + assert result.fitted_stars == stars + assert result.iterations == 5 + assert result.converged is True + + def test_with_data(self, epsf_test_data): + """ + Test EPSFBuildResult with actual data. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + builder = EPSFBuilder(oversampling=1, maxiters=2, progress_bar=False) + epsf, fitted_stars = builder(stars) + + result = EPSFBuildResult( + epsf=epsf, + fitted_stars=fitted_stars, + iterations=2, + converged=False, + final_center_accuracy=0.1, + n_excluded_stars=0, + excluded_star_indices=[], + ) + assert result.epsf is not None + assert result.epsf.data.shape == (11, 11) + assert result.fitted_stars is not None + assert len(result.fitted_stars) == len(stars) + + def test_getitem_invalid_index(self): + """ + Test EPSFBuildResult.__getitem__ with invalid index. + """ + data = np.ones((5, 5)) + psf = ImagePSF(data) + stars = EPSFStars([]) + + result = EPSFBuildResult( + epsf=psf, + fitted_stars=stars, + iterations=5, + converged=True, + final_center_accuracy=0.01, + n_excluded_stars=0, + excluded_star_indices=[], + ) + + # Valid indices + assert result[0] is psf + assert result[1] is stars + + # Invalid index + match = 'EPSFBuildResult index must be 0' + with pytest.raises(IndexError, match=match): + result[2] + + with pytest.raises(IndexError, match=match): + result[-1] + + def test_iteration(self): + """ + Test EPSFBuildResult iteration (tuple unpacking). 
+ """ + data = np.ones((5, 5)) + psf = ImagePSF(data) + stars = EPSFStars([]) + + result = EPSFBuildResult( + epsf=psf, + fitted_stars=stars, + iterations=5, + converged=True, + final_center_accuracy=0.01, + n_excluded_stars=0, + excluded_star_indices=[], + ) + + # Test tuple unpacking via iteration + epsf_out, stars_out = result + assert epsf_out is psf + assert stars_out is stars + + # Test list conversion + result_list = list(result) + assert len(result_list) == 2 + assert result_list[0] is psf + assert result_list[1] is stars + + def test_attributes(self, epsf_test_data): + """ + Test EPSFBuildResult has all expected attributes. + """ + builder = EPSFBuilder(oversampling=1, maxiters=3, progress_bar=False) + + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + result = builder(stars) + + # Check all attributes exist + assert result.epsf is not None + assert result.fitted_stars is not None + assert isinstance(result.iterations, int) + assert isinstance(result.converged, (bool, np.bool_)) + assert isinstance(result.final_center_accuracy, (float, np.floating)) + assert isinstance(result.n_excluded_stars, int) + assert isinstance(result.excluded_star_indices, list) + + +class TestEPSFFitter: + """ + Tests for the EPSFFitter class. + """ + + def test_fit_stars(self, epsf_fitter_data): + """ + Test EPSFFitter __call__ method. + """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + fitter = EPSFFitter() + fitted_stars = fitter(epsf, stars) + assert fitted_stars is not None + assert len(fitted_stars) == len(stars) + + def test_empty_stars(self): + """ + Test EPSFFitter with empty stars. + """ + data = np.ones((11, 11)) + epsf = ImagePSF(data) + fitter = EPSFFitter() + + empty_stars = EPSFStars([]) + result = fitter(epsf, empty_stars) + assert len(result) == 0 + + def test_invalid_epsf_type(self): + """ + Test EPSFFitter with invalid epsf type. 
+ """ + data = np.ones((5, 5)) + star = EPSFStar(data, cutout_center=(2, 2)) + stars = EPSFStars([star]) + fitter = EPSFFitter() + + match = 'The input epsf must be an ImagePSF' + with pytest.raises(TypeError, match=match): + fitter('not_an_epsf', stars) + + def test_fit_boxsize_none(self, epsf_fitter_data): + """ + Test EPSFFitter with fit_boxsize=None. + """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + # Test fitter with fit_boxsize=None (use entire star image) + fitter = EPSFFitter(fit_boxsize=None) + fitted_stars = fitter(epsf, stars) + assert len(fitted_stars) == len(stars) + + def test_invalid_star_type(self, epsf_fitter_data): + """ + Test EPSFFitter with invalid star type. + """ + epsf = epsf_fitter_data['epsf'] + + # Create mock invalid star type + class InvalidStar: + pass + + # Create an EPSFStars-like object with invalid star + invalid_stars = [InvalidStar()] + + fitter = EPSFFitter() + match = 'stars must contain only EPSFStar' + with pytest.raises(TypeError, match=match): + fitter(epsf, invalid_stars) + + def test_fit_info_ierr(self, epsf_fitter_data): + """ + Test EPSFFitter handling of fit_info with ierr. + """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + # Test fitter - the fit_info handling is automatic + fitter = EPSFFitter() + assert fitter.fitter_has_fit_info is True + + fitted_stars = fitter(epsf, stars) + # Check that fit_error_status is set + for star in fitted_stars.all_stars: + assert hasattr(star, '_fit_error_status') + + def test_fitter_without_fit_info(self, epsf_fitter_data): + """ + Test EPSFFitter with a fitter that doesn't have fit_info. 
+ """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + # Create a mock fitter without fit_info attribute + class MockFitter: + def __call__( + self, model, x, y, z, weights=None, **kwargs, # noqa: ARG002 + ): + return model + + mock_fitter = MockFitter() + fitter = EPSFFitter(fitter=mock_fitter) + + # Verify that fitter_has_fit_info is False + assert fitter.fitter_has_fit_info is False + + # Fit the stars + fitted_stars = fitter(epsf, stars) + assert len(fitted_stars) == len(stars) + + # Check that _fit_info is None for stars fit without fit_info + for star in fitted_stars.all_stars: + assert star._fit_info is None + + def test_weights_not_supported(self, epsf_fitter_data): + """ + Test EPSFFitter when fitter raises TypeError for weights. + """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + # Create a fitter that raises TypeError when weights is passed + class NoWeightsFitter: + def __init__(self): + self.fit_info = {'ierr': 1} + + def __call__(self, model, *_args, **kwargs): + if 'weights' in kwargs: + msg = 'weights not supported' + raise TypeError(msg) + return model + + no_weights_fitter = NoWeightsFitter() + fitter = EPSFFitter(fitter=no_weights_fitter) + + # Fit the stars - should handle TypeError gracefully + fitted_stars = fitter(epsf, stars) + assert len(fitted_stars) == len(stars) + + def test_invalid_ierr(self, epsf_fitter_data): + """ + Test EPSFFitter when fitter returns invalid ierr value. 
+ """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + # Create a fitter that returns invalid ierr (not in [1, 2, 3, 4]) + class BadIerrFitter: + def __init__(self): + self.fit_info = {'ierr': 0} # Invalid ierr value + + def __call__(self, model, *_args, **_kwargs): + return model + + bad_ierr_fitter = BadIerrFitter() + fitter = EPSFFitter(fitter=bad_ierr_fitter) + + # Fit the stars - should set fit_error_status = 2 + fitted_stars = fitter(epsf, stars) + assert len(fitted_stars) == len(stars) + + # Check that fit_error_status is set to 2 for all stars + for star in fitted_stars.all_stars: + assert star._fit_error_status == 2 + + def test_removes_fitter_kwargs(self): + """ + Test that EPSFFitter removes reserved kwargs. + """ + # Pass kwargs that should be removed + fitter = EPSFFitter(x=1, y=2, z=3, weights=4, calc_uncertainties=False) + + # These should be removed from fitter_kwargs + assert 'x' not in fitter.fitter_kwargs + assert 'y' not in fitter.fitter_kwargs + assert 'z' not in fitter.fitter_kwargs + assert 'weights' not in fitter.fitter_kwargs + # Other kwargs should be preserved + assert fitter.fitter_kwargs.get('calc_uncertainties') is False + + def test_with_linked_star_mock_wcs(self, epsf_fitter_data): + """ + Test EPSFFitter with LinkedEPSFStar using mock WCS. 
+ """ + stars = epsf_fitter_data['stars'] + epsf = epsf_fitter_data['epsf'] + + # Create mock WCS that returns identity transform + class MockWCS: + def pixel_to_world_values(self, x, y): + return x, y + + def world_to_pixel_values(self, ra, dec): + return ra, dec + + mock_wcs = MockWCS() + + # Create EPSFStar objects with mock WCS + linked_stars_list = [] + for i in range(2): + star_data = stars.all_stars[i].data.copy() + center = stars.all_stars[i].cutout_center + # Use origin that places star in a reasonable position + origin = (0, 0) + star = EPSFStar(star_data, cutout_center=center, + origin=origin, wcs_large=mock_wcs) + linked_stars_list.append(star) + + # Create LinkedEPSFStar + linked_star = LinkedEPSFStar(linked_stars_list) + + # Create EPSFStars with the LinkedEPSFStar + stars_with_linked = EPSFStars([linked_star]) + + # Fit the linked stars + fitter = EPSFFitter() + fitted_stars = fitter(epsf, stars_with_linked) + + assert len(fitted_stars) == 1 + # fitted_stars is an EPSFStars; the first item wraps LinkedEPSFStar + assert len(fitted_stars.all_stars) == 2 # 2 stars in the linked star + + +class TestEPSFBuilder: + """ + Tests for the EPSFBuilder class. + """ + + @pytest.mark.parametrize('extract_shape', [(25, 25), (19, 25), (25, 19)]) + def test_build(self, epsf_test_data, extract_shape): + """ + Test EPSFBuilder build process on a simulated image. 
+ """ + oversampling = 2 + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:10], + size=extract_shape) + epsf_builder = EPSFBuilder(oversampling=oversampling, maxiters=5, + progress_bar=False, + recentering_maxiters=5) + epsf, fitted_stars = epsf_builder(stars) + + # Verify EPSF properties with default settings + assert isinstance(epsf, ImagePSF) + assert epsf.x_0 == 0.0 + assert epsf.y_0 == 0.0 + assert epsf.flux == 1.0 + + # Shape is star_shape * oversampling, then ensure odd + ref_size = np.array(extract_shape) * oversampling + ref_size = np.where(ref_size % 2 == 0, ref_size + 1, ref_size) + assert epsf.data.shape == tuple(ref_size) + + # Verify basic EPSF properties + assert len(fitted_stars) == 10 + # ePSF should sum to ~oversamp^2 for properly normalized + # oversampled PSF + expected_sum = oversampling ** 2 + assert 0.9 * expected_sum < epsf.data.sum() < 1.1 * expected_sum + assert epsf.data.max() > 0.01 # Should have a peak + + # Check that the center region has higher values than edges + center_y, center_x = np.array(ref_size) // 2 + center_val = epsf.data[center_y, center_x] + edge_val = epsf.data[0, 0] + assert center_val > edge_val # Center should be brighter than edge + + # Test that residual computation works (basic functionality test) + resid_star = fitted_stars[0].compute_residual_image(epsf) + assert isinstance(resid_star, np.ndarray) + assert resid_star.shape == fitted_stars[0].data.shape + + def test_invalid_inputs(self): + """ + Test EPSFBuilder with various invalid inputs. 
+ """ + match = "'oversampling' must be specified" + with pytest.raises(ValueError, match=match): + EPSFBuilder(oversampling=None) + + match = 'oversampling must be > 0' + with pytest.raises(ValueError, match=match): + EPSFBuilder(oversampling=-1) + + match = 'maxiters must be a positive number' + with pytest.raises(ValueError, match=match): + EPSFBuilder(maxiters=-1) + + match = 'oversampling must be > 0' + with pytest.raises(ValueError, match=match): + EPSFBuilder(oversampling=[-1, 4]) + + for sigma_clip in [None, [], 'a']: + match = 'sigma_clip must be an astropy.stats.SigmaClip instance' + with pytest.raises(TypeError, match=match): + EPSFBuilder(sigma_clip=sigma_clip) + + def test_fitter_options(self): + """ + Test EPSFBuilder with different EPSFFitter options. + """ + # Test with default EPSFFitter + builder1 = EPSFBuilder(maxiters=3) + assert isinstance(builder1.fitter, EPSFFitter) + # Default fit_boxsize is 5 + np.testing.assert_array_equal(builder1.fitter.fit_boxsize, (5, 5)) + + # Test with EPSFFitter instance + epsf_fitter = EPSFFitter() + builder2 = EPSFBuilder(fitter=epsf_fitter, maxiters=3) + assert isinstance(builder2.fitter, EPSFFitter) + + # Test with custom fit_boxsize + epsf_fitter = EPSFFitter(fit_boxsize=7) + builder2 = EPSFBuilder(fitter=epsf_fitter, maxiters=3) + np.testing.assert_array_equal(builder2.fitter.fit_boxsize, (7, 7)) + + # Test with tuple fit_boxsize + epsf_fitter = EPSFFitter(fit_boxsize=(5, 7)) + builder3 = EPSFBuilder(fitter=epsf_fitter, maxiters=3) + np.testing.assert_array_equal(builder3.fitter.fit_boxsize, (5, 7)) + + # Test with None fit_boxsize + epsf_fitter = EPSFFitter(fit_boxsize=None) + builder4 = EPSFBuilder(fitter=epsf_fitter, maxiters=3) + assert builder4.fitter.fit_boxsize is None + + # Test with invalid fitter type (should fail) + with pytest.raises(TypeError, + match='fitter must be an EPSFFitter instance'): + EPSFBuilder(fitter='invalid_fitter', maxiters=3) + + # Test with astropy fitter directly (should fail) + 
with pytest.raises(TypeError, + match='fitter must be an EPSFFitter instance'): + EPSFBuilder(fitter=TRFLSQFitter(), maxiters=3) + + def test_fitting_bounds(self, epsf_test_data): + """ + Test EPSFBuilder with EPSFFitter that has fit_boxsize larger + than star cutouts. + """ + size = 25 + oversampling = 4 + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'], + size=size) + + # Create EPSFFitter with fit_boxsize larger than cutout + epsf_fitter = EPSFFitter(fit_boxsize=31) + + epsf_builder = EPSFBuilder(oversampling=oversampling, maxiters=8, + progress_bar=True, + recentering_maxiters=5, + fitter=epsf_fitter, + smoothing_kernel='quadratic') + + # With a boxsize larger than the cutout we expect the fitting to + # fail for all stars. The ValueError is raised before any star + # can be excluded (exclusion only happens after iter > 3). + match = 'The ePSF fitting failed for all stars' + with pytest.raises(ValueError, match=match): + epsf_builder(stars) + + @pytest.mark.parametrize(('oversamp', 'star_size', 'expected_shape'), [ + # oversampling=1: shape should be odd (add 1 to even product) + (1, 25, (25, 25)), # 25*1 = 25 (odd) -> 25 + (1, 24, (25, 25)), # 24*1 = 24 (even) -> 25 + (1, 26, (27, 27)), # 26*1 = 26 (even) -> 27 + # oversampling=2: product is even, add 1 + (2, 25, (51, 51)), # 25*2 = 50 (even) -> 51 + (2, 24, (49, 49)), # 24*2 = 48 (even) -> 49 + # oversampling=3: product is odd for odd star size + (3, 25, (75, 75)), # 25*3 = 75 (odd) -> 75 + (3, 24, (73, 73)), # 24*3 = 72 (even) -> 73 + # oversampling=4: product is even, add 1 + (4, 25, (101, 101)), # 25*4 = 100 (even) -> 101 + (4, 24, (97, 97)), # 24*4 = 96 (even) -> 97 + # oversampling=5: product is odd for odd star size + (5, 25, (125, 125)), # 25*5 = 125 (odd) -> 125 + (5, 24, (121, 121)), # 24*5 = 120 (even) -> 121 + ]) + def test_shape_calculation(self, oversamp, star_size, expected_shape): + """ + Test that the ePSF shape is correctly calculated for various + oversampling 
factors. + + The ePSF shape should be: + - star_size * oversampling for each dimension + - Then ensure odd dimensions (add 1 if even) + """ + # Test the shape calculation directly via _CoordinateTransformer + transformer = _CoordinateTransformer(oversampling=(oversamp, oversamp)) + star_shapes = [(star_size, star_size)] + computed_shape = transformer.compute_epsf_shape(star_shapes) + + assert computed_shape == expected_shape, ( + f'For oversamp={oversamp}, star_size={star_size}: ' + f'expected {expected_shape}, got {computed_shape}' + ) + + @pytest.mark.parametrize('kernel_type', ['quadratic', 'quartic', + 'custom']) + def test_smoothing_kernel(self, epsf_test_data, kernel_type): + """ + Test EPSFBuilder with smoothing kernel. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:3], size=11) + + if kernel_type == 'custom': + kernel = np.ones((3, 3)) / 9.0 + else: + kernel = kernel_type + + builder = EPSFBuilder( + smoothing_kernel=kernel, + maxiters=1, + progress_bar=False, + ) + + epsf, _ = builder(stars) + assert epsf is not None + assert epsf.data.shape == (45, 45) + + @pytest.mark.parametrize('centering_func', [centroid_com, + centroid_1dg, + centroid_2dg, + centroid_quadratic, + ]) + def test_recentering(self, epsf_test_data, centering_func): + """ + Test EPSFBuilder with different recentering function. 
+ """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:4], size=11) + + # Setting oversampling=1 is required for centroid_quadratic to + # work because its default fit_boxsize=5 is in native pixels and + # we cannot adjust it here + builder = EPSFBuilder( + oversampling=1, + recentering_func=centering_func, + maxiters=5, + progress_bar=False, + ) + + epsf, _ = builder(stars) + assert epsf is not None + assert epsf.data.shape == (11, 11) + + @pytest.mark.parametrize('shape', [(25, 25), (19, 25), (25, 19)]) + def test_shape_parameters(self, epsf_test_data, shape): + """ + Test EPSFBuilder with explicit shape parameters. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:3], size=11) + + # Test with explicit shape + builder = EPSFBuilder( + shape=shape, + oversampling=1, + maxiters=1, + progress_bar=False, + ) + + epsf, _ = builder(stars) + assert epsf is not None + assert epsf.data.shape == shape + + def test_check_convergence_no_good_stars(self): + """ + Test EPSFBuilder._check_convergence with no good stars. + """ + builder = EPSFBuilder(maxiters=1, progress_bar=False) + + # Create stars and mark all as fit_failed + data = np.ones((5, 5)) + star = EPSFStar(data, cutout_center=(2, 2)) + stars = EPSFStars([star]) + + centers = np.array([[2.0, 2.0]]) + fit_failed = np.array([True]) # All stars failed + + converged, center_dist_sq, _ = builder._check_convergence( + stars, centers, fit_failed) + + # Should return False (not converged) when no good stars + assert converged is False + # center_dist_sq should be high to prevent false convergence + assert center_dist_sq[0] > builder.center_accuracy_sq + + def test_resample_residuals_no_good_stars(self, epsf_test_data): + """ + Test EPSFBuilder._resample_residuals with no good stars. 
+ """ + builder = EPSFBuilder(maxiters=1, progress_bar=False) + + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:2], size=11) + + # Create an initial ePSF + epsf = builder._create_initial_epsf(stars) + + # Mark all stars as excluded + for star in stars.all_stars: + star._excluded_from_fit = True + + # Now resample residuals should handle no good stars + result = builder._resample_residuals(stars, epsf) + assert result.shape[0] == 0 # No good stars + + def test_resample_residual_output(self, epsf_test_data): + """ + Test EPSFBuilder._resample_residual creates output image if None + is passed. + """ + builder = EPSFBuilder(maxiters=1, progress_bar=False) + + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:2], size=11) + + # Create an initial ePSF + epsf = builder._create_initial_epsf(stars) + + # Call _resample_residual without out_image (should create one) + star = stars.all_stars[0] + result = builder._resample_residual(star, epsf, out_image=None) + + assert result is not None + assert result.shape == epsf.data.shape + + def test_build_step_with_epsf(self, epsf_test_data): + """ + Test EPSFBuilder._build_epsf_step with existing ePSF. + """ + builder = EPSFBuilder(maxiters=1, progress_bar=False) + + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + # Create an initial ePSF + epsf = builder._create_initial_epsf(stars) + + # Now build with existing ePSF + improved_epsf = builder._build_epsf_step(stars, epsf=epsf) + + assert improved_epsf is not None + assert improved_epsf.data.shape == epsf.data.shape + + def test_star_exclusion(self, epsf_test_data): + """ + Test that stars are excluded after repeated fit failures. + + Here, we modify the first star's position such that star is + centered near the corner of extracted cutout image. 
This will + cause the fitting to fail for that star because its fitting + region extends beyond the cutout boundaries, and it should be + excluded from subsequent iterations. + """ + tbl = epsf_test_data['init_stars'][:5].copy() + tbl['x'][0] = 465 + tbl['y'][0] = 30 + stars = extract_stars(epsf_test_data['nddata'], tbl, size=11) + + builder = EPSFBuilder(oversampling=1, maxiters=5, progress_bar=False) + match = ('has been excluded from ePSF fitting because its fitting ' + 'region extends') + with pytest.warns(AstropyUserWarning, match=match): + result = builder(stars) + + assert result.n_excluded_stars == 1 + assert result.excluded_star_indices == [0] + assert result.epsf is not None + assert result.epsf.data.shape == (11, 11) + assert result.fitted_stars.n_good_stars == 4 + assert result.fitted_stars.n_all_stars == 5 + + def test_star_exclusion_single_warning(self, epsf_test_data): + """ + Test that only a single warning is emitted per excluded star. + + When a star repeatedly fails fitting across iterations, the + warning should only be emitted when the star is actually + excluded (after more than 3 iterations of failure). + """ + tbl = epsf_test_data['init_stars'][:5].copy() + tbl['x'][0] = 465 + tbl['y'][0] = 30 + stars = extract_stars(epsf_test_data['nddata'], tbl, size=11) + + builder = EPSFBuilder(oversampling=1, maxiters=6, progress_bar=False) + + # Capture all warnings + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter('always') + builder(stars) + + # Filter for the specific warning about exclusion + fit_warnings = [w for w in warning_list + if 'has been excluded from ePSF fitting' in + str(w.message)] + + # Should only have 1 warning despite multiple iterations + assert len(fit_warnings) == 1 + + def test_excluded_star_no_copy(self, epsf_test_data): + """ + Test that excluded stars are returned without copying. + + When a star is excluded from fitting, the fitter should return + the same star object directly, not a copy. 
This is more + efficient than creating unnecessary copies. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:3], size=11) + + # Mark one star as excluded + original_star = stars.all_stars[0] + original_star._excluded_from_fit = True + + # Create an ePSF for fitting + builder = EPSFBuilder(oversampling=1, maxiters=1, progress_bar=False) + epsf = builder._create_initial_epsf(stars) + + # Fit the stars + fitter = EPSFFitter() + fitted_stars = fitter(epsf, stars) + + # The excluded star should be the exact same object (identity) + assert fitted_stars.all_stars[0] is original_star + + def test_process_iteration_with_fit_failures(self, epsf_test_data): + """ + Test _process_iteration marks stars excluded after iter > 3. + + This test covers both types of fit failures: + 1. Fitting region extends beyond cutout (status=1) + 2. Fit did not converge due to invalid ierr (status=2) + """ + # Create stars with one positioned near corner to cause overlap + # error + tbl = epsf_test_data['init_stars'][:5].copy() + tbl['x'][0] = 465 # Position near corner to cause overlap error + tbl['y'][0] = 30 + stars = extract_stars(epsf_test_data['nddata'], tbl, size=11) + + # Build initial ePSF. This will fit the stars and move their + # centers. Star 0 will have its center moved near the edge of + # the cutout, which will cause overlap errors in subsequent + # iterations. + builder_init = EPSFBuilder(oversampling=1, maxiters=2, + progress_bar=False) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + epsf, fitted_stars = builder_init(stars) + + # Create a fitter that returns invalid ierr for the first call + # only. + # Star 0 will fail due to overlap error (status=1) before the + # fitter is called (because its center moved near edge). + # Star 1 is the first to reach the fitter, and will fail with + # invalid ierr (status=2). + # Subsequent stars will get valid ierr. 
+ class FirstCallFailingFitter: + def __init__(self): + self.call_count = 0 + self.fit_info = {'ierr': 1} # Valid by default + + def __call__(self, model, *_args, **_kwargs): + self.call_count += 1 + # Fail only on the first fitter call (which is star 1, + # since star 0 fails with overlap error before reaching + # fitter) + if self.call_count == 1: + self.fit_info = {'ierr': 0} # Invalid ierr + else: + self.fit_info = {'ierr': 1} # Valid ierr + return model + + failing_fitter = FirstCallFailingFitter() + epsf_fitter = EPSFFitter(fitter=failing_fitter) + builder = EPSFBuilder(oversampling=1, maxiters=1, progress_bar=False, + fitter=epsf_fitter) + + # Process iteration with iter_num > 3 to trigger exclusion. Use + # fitted_stars (which has moved centers) to trigger overlap error. + # Capture warnings to verify both types are emitted. + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter('always') + _, stars_new, fit_failed = builder._process_iteration( + fitted_stars, epsf, iter_num=4) + + # Check that stars 0 and 1 failed + assert fit_failed[0] # Overlap error (status=1) + assert fit_failed[1] # Invalid ierr (status=2) + + # Verify both stars are marked for exclusion + assert stars_new.all_stars[0]._excluded_from_fit + assert stars_new.all_stars[1]._excluded_from_fit + + # Verify correct error status for each failure type + assert stars_new.all_stars[0]._fit_error_status == 1 # Overlap error + assert stars_new.all_stars[1]._fit_error_status == 2 # Fit failure + + # Verify both warning types were emitted + warning_messages = [str(w.message) for w in warning_list] + overlap_warnings = [m for m in warning_messages + if 'fitting region extends beyond' in m] + converge_warnings = [m for m in warning_messages + if 'fit did not converge' in m] + assert len(overlap_warnings) == 1 + assert len(converge_warnings) == 1 + + def test_star_exclusion_fit_failure(self, epsf_test_data): + """ + Test that stars are excluded with appropriate message 
when fit + does not converge (ierr error). + + This tests exclusion due to fit failure (status=2), as opposed + to the fitting region extending beyond the cutout (status=1). + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + n_stars = len(stars.all_stars) + + # Create a fitter that fails for the first star (invalid ierr) + # but succeeds for others. Add small offsets to x_0/y_0 to + # prevent early convergence, ensuring we reach iteration > 3. + class PartialFailingFitter: + def __init__(self): + self.call_count = 0 + self.fit_info = {'ierr': 1} # Valid by default + + def __call__(self, model, *_args, **_kwargs): + self.call_count += 1 + star_idx = (self.call_count - 1) % n_stars + # Fail only the first star + if star_idx == 0: + self.fit_info = {'ierr': 0} # Invalid ierr + else: + self.fit_info = {'ierr': 1} # Valid ierr + # Add small offset to prevent early convergence + model.x_0 = model.x_0 + 0.01 + model.y_0 = model.y_0 + 0.01 + return model + + failing_fitter = PartialFailingFitter() + epsf_fitter = EPSFFitter(fitter=failing_fitter) + + # Use maxiters=5 so we reach iter > 3 to trigger exclusion + builder = EPSFBuilder(oversampling=1, maxiters=5, progress_bar=False, + fitter=epsf_fitter) + + # Should warn about fit not converging + match = ('has been excluded from ePSF fitting because the fit did ' + 'not converge') + with pytest.warns(AstropyUserWarning, match=match): + result = builder(stars) + + # At least the first star (with ierr=0) should be excluded + assert result.n_excluded_stars >= 1 + assert 0 in result.excluded_star_indices + # Check that the first star has fit_error_status=2 (fit failure) + assert result.fitted_stars.all_stars[0]._fit_error_status == 2 + assert result.fitted_stars.all_stars[0]._excluded_from_fit + + def test_build_tracks_excluded_indices(self, epsf_test_data): + """ + Test that _build_epsf properly tracks excluded star indices. 
+ """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:10], size=11) + + # Create a fitter that: + # 1. Adds noise to centers to prevent convergence + # 2. Fails first 3 stars on iteration 4+ + n_stars = len(stars.all_stars) + + class NoConvergeFitter: + def __init__(self): + self.star_count = 0 + self.fit_info = {'ierr': 1} + + def __call__(self, model, *_args, **_kwargs): + self.star_count += 1 + iteration = self.star_count // n_stars + 1 + star_idx = (self.star_count - 1) % n_stars + + # On iteration 4+, fail first 3 stars + if iteration > 4 and star_idx < 3: + self.fit_info = {'ierr': 0} # Invalid + else: + self.fit_info = {'ierr': 1} # Valid + + # Add slight offset to x_0 to prevent convergence + model.x_0 = model.x_0 + 0.01 * (iteration % 2) + return model + + fitter_obj = NoConvergeFitter() + epsf_fitter = EPSFFitter(fitter=fitter_obj) + builder = EPSFBuilder(oversampling=1, maxiters=7, progress_bar=False, + fitter=epsf_fitter, center_accuracy=1e-6) + + # Build - this should trigger exclusion tracking + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + result = builder(stars) + + # Check that excluded_star_indices was populated + assert hasattr(result, 'excluded_star_indices') + assert isinstance(result.excluded_star_indices, list) + # We may or may not have excluded stars depending on exact timing + assert result.n_excluded_stars >= 0 + + def test_build_step_origin_is_none_branch(self, epsf_test_data): + """ + Test _build_epsf_step else branch when origin is None. 
+ """ + builder = EPSFBuilder(maxiters=1, progress_bar=False) + + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:3], size=11) + + # Create ePSF and verify origin condition + epsf = builder._create_initial_epsf(stars) + + # Verify the branch condition logic + has_valid_origin = hasattr(epsf, 'origin') and epsf.origin is not None + assert has_valid_origin # Normal case, origin exists + + # The else branch is only reached when origin is None + # This line calculates origin from shape + expected_origin_y = (epsf.data.shape[0] - 1) / 2.0 + expected_origin_x = (epsf.data.shape[1] - 1) / 2.0 + np.testing.assert_allclose(epsf.origin, + (expected_origin_x, expected_origin_y)) + + @pytest.mark.skipif(not HAS_TQDM, reason='tqdm is required') + def test_with_progress_bar(self, epsf_test_data): + """ + Test EPSFBuilder with progress_bar=True. + """ + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + # Build with progress bar enabled and high center_accuracy to + # prevent convergence + builder = EPSFBuilder(oversampling=1, maxiters=3, progress_bar=True, + center_accuracy=1e-10) + result = builder(stars) + assert result.epsf is not None + assert result.epsf.data.shape == (11, 11) + + def test_recenter_shift_increase(self, epsf_test_data): + """ + Test early exit in _recenter_epsf when shift increases. + + Uses mock to force the centroid function to return values + that cause shift to increase on second iteration. 
+ """ + builder = EPSFBuilder(oversampling=1, maxiters=2, progress_bar=False, + recentering_maxiters=10) + + stars = extract_stars(epsf_test_data['nddata'], + epsf_test_data['init_stars'][:5], size=11) + + epsf, _ = builder(stars) + + # Create a mock centroid function that returns oscillating values + # First call: shift by 0.5 pixels + # Second call: shift back by 1.0 (larger shift, triggers break) + call_count = [0] + center = np.array(epsf.data.shape) / 2.0 + + def mock_centroid(data, mask=None): # noqa: ARG001 + call_count[0] += 1 + if call_count[0] == 1: + # First iteration: small shift + return (center[1] + 0.5, center[0] + 0.5) + # Second iteration: shift back (larger distance) + return (center[1] - 0.5, center[0] - 0.5) + + with patch.object(builder, 'recentering_func', mock_centroid): + recentered = builder._recenter_epsf(epsf) + + assert recentered is not None + assert recentered.shape == epsf.data.shape + # The mock should have been called at least twice + assert call_count[0] >= 2 + + def test_very_small_sources(self): + """ + Test EPSFBuilder with very small sources that may cause + numerical issues. + """ + fwhm = 1.5 + psf_model = CircularGaussianPRF(flux=1, fwhm=fwhm) + + shape = (50, 50) + sources = Table() + sources['x_0'] = [25] + sources['y_0'] = [25] + sources['fwhm'] = [fwhm] + + data = make_model_image(shape, psf_model, sources) + nddata = NDData(data=data) + + stars_tbl = Table() + stars_tbl['x'] = sources['x_0'] + stars_tbl['y'] = sources['y_0'] + stars = extract_stars(nddata, stars_tbl, size=11) + + # Should handle numerical edge cases gracefully + builder = EPSFBuilder(oversampling=1, maxiters=5, progress_bar=False) + + epsf, _ = builder(stars) + assert epsf is not None + assert epsf.data.shape == (11, 11) + + @pytest.mark.parametrize('oversamp', [1, 2, 3, 4, 5]) + def test_build_oversampling(self, oversamp): + """ + Test that the ePSF built with oversampling has the expected + shape and properties. 
+ + Sources are placed on a regular grid with exact subpixel offsets + to ensure that the ePSF is properly sampled. The test checks + that the resulting ePSF has the expected shape, that it sums to + the expected value for an oversampled PSF, and that its shape + matches the input PSF model when scaled by the sum of the ePSF + data. + """ + offsets = (np.arange(oversamp) * 1.0 / oversamp - 0.5 + 1.0 + / (2.0 * oversamp)) + xydithers = np.array(list(itertools.product(offsets, offsets))) + xdithers = np.transpose(xydithers)[0] + ydithers = np.transpose(xydithers)[1] + + nstars = oversamp**2 + fwhm = 7.0 + sources = Table() + offset = 50 + size = oversamp * offset + offset + y, x = np.mgrid[0:oversamp, 0:oversamp] * offset + offset + sources['x_0'] = x.ravel() + xdithers + sources['y_0'] = y.ravel() + ydithers + sources['fwhm'] = np.full((nstars,), fwhm) + + psf_model = CircularGaussianPRF(fwhm=fwhm) + shape = (size, size) + data = make_model_image(shape, psf_model, sources) + nddata = NDData(data=data) + stars_tbl = Table() + stars_tbl['x'] = sources['x_0'] + stars_tbl['y'] = sources['y_0'] + star_size = 25 + stars = extract_stars(nddata, stars_tbl, size=star_size) + + epsf_builder = EPSFBuilder(oversampling=oversamp, maxiters=15, + progress_bar=False, recentering_maxiters=20) + epsf, results = epsf_builder(stars) + + # Verify EPSF properties with default settings + assert isinstance(epsf, ImagePSF) + assert epsf.x_0 == 0.0 + assert epsf.y_0 == 0.0 + assert epsf.flux == 1.0 + + # Check expected shape of ePSF data + # The shape should be star_size * oversamp, then ensure odd + # dimensions by adding 1 if even. + expected_dim = star_size * oversamp + if expected_dim % 2 == 0: + expected_dim += 1 + expected_shape = (expected_dim, expected_dim) + assert epsf.data.shape == expected_shape + + # Check expected sum of ePSF data. 
+ # For an oversampled PSF, the sum of the array values should + # equal the product of the oversampling factors (oversamp^2 for + # symmetric oversampling). + expected_sum = oversamp**2 + assert_allclose(epsf.data.sum(), expected_sum, rtol=0.02) + + # Check that the shape of the ePSF matches the input PSF model + # when scaled by the sum of the ePSF data. The input PSF model + # is a circular Gaussian with the specified FWHM, and the ePSF + # should approximate this shape when scaled by the total flux. + + # Calculate the expected PSF shape based on the input model and + # the oversampling factor. The FWHM should be scaled by the + # oversampling factor to match the ePSF sampling. + size = epsf.data.shape[0] + cen = (size - 1) / 2 + fwhm2 = oversamp * fwhm + model = CircularGaussianPRF(flux=1, x_0=cen, y_0=cen, fwhm=fwhm2) + yy, xx = np.mgrid[0:size, 0:size] + psf = model(xx, yy) * oversamp**2 + assert_allclose(epsf.data, psf, atol=2e-4) + + # Check that the fitted centers are close to the true source + # positions + assert_allclose(results.center_flat[:, 0], sources['x_0'], atol=0.005) + assert_allclose(results.center_flat[:, 1], sources['y_0'], atol=0.005) diff --git a/photutils/psf/tests/test_epsf_stars.py b/photutils/psf/tests/test_epsf_stars.py index 85d407838..e8735099b 100644 --- a/photutils/psf/tests/test_epsf_stars.py +++ b/photutils/psf/tests/test_epsf_stars.py @@ -3,36 +3,1149 @@ Tests for the epsf_stars module. 
""" +import warnings +from multiprocessing.reduction import ForkingPickler + import numpy as np import pytest -from astropy.modeling.models import Moffat2D -from astropy.nddata import NDData +from astropy.coordinates import SkyCoord +from astropy.nddata import (InverseVariance, NDData, StdDevUncertainty, + VarianceUncertainty) from astropy.table import Table -from numpy.testing import assert_allclose +from astropy.utils.exceptions import AstropyUserWarning +from astropy.wcs import WCS +from numpy.testing import assert_allclose, assert_array_equal -from photutils.psf.epsf_stars import EPSFStars, extract_stars +from photutils.psf import make_psf_model_image +from photutils.psf.epsf_stars import (EPSFStar, EPSFStars, LinkedEPSFStar, + _compute_mean_sky_coordinate, + _create_weights_cutout, + _prepare_uncertainty_info, extract_stars) from photutils.psf.functional_models import CircularGaussianPRF from photutils.psf.image_models import ImagePSF -class TestExtractStars: - def setup_class(self): - stars_tbl = Table() - stars_tbl['x'] = [15, 15, 35, 35] - stars_tbl['y'] = [15, 35, 40, 10] - self.stars_tbl = stars_tbl +@pytest.fixture +def epsf_test_data(): + """ + Create a simulated image for testing. + """ + fwhm = 2.7 + psf_model = CircularGaussianPRF(flux=1, fwhm=fwhm) + model_shape = (9, 9) + n_sources = 100 + shape = (750, 750) + data, true_params = make_psf_model_image(shape, psf_model, n_sources, + model_shape=model_shape, + flux=(500, 700), + min_separation=25, + border_size=25, seed=0) + + nddata = NDData(data) + init_stars = Table() + init_stars['x'] = true_params['x_0'].astype(int) + init_stars['y'] = true_params['y_0'].astype(int) + + return { + 'fwhm': fwhm, + 'data': data, + 'nddata': nddata, + 'init_stars': init_stars, + } + + +@pytest.fixture +def simple_wcs(): + """ + Create a simple WCS for testing. 
+ """ + wcs = WCS(naxis=2) + wcs.wcs.crpix = [25, 25] + wcs.wcs.crval = [0, 0] + wcs.wcs.cdelt = [1, 1] + wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN'] + return wcs + + +@pytest.fixture +def simple_data(): + """ + Create simple 50x50 array of ones for testing. + """ + return np.ones((50, 50)) + + +@pytest.fixture +def simple_nddata(simple_data): + """ + Create simple NDData object for testing. + """ + return NDData(simple_data) + + +@pytest.fixture +def simple_table(): + """ + Create simple table with single star at center. + """ + return Table({'x': [25], 'y': [25]}) + + +@pytest.fixture +def stars_table(): + """ + Create table with multiple star positions. + """ + table = Table() + table['x'] = [15, 15, 35, 35] + table['y'] = [15, 35, 40, 10] + return table + + +@pytest.fixture +def stars_data(stars_table): + """ + Create image data with stars using CircularGaussianPRF model. + """ + yy, xx = np.mgrid[0:51, 0:55] + data = np.zeros(xx.shape) + model = CircularGaussianPRF(fwhm=3.5) + for xi, yi in zip(stars_table['x'], stars_table['y'], strict=True): + data += model.evaluate(xx, yy, 100, xi, yi, 3.5) + return data + + +@pytest.fixture +def stars_nddata(stars_data): + """ + Create NDData object with star data. + """ + return NDData(data=stars_data) + + +def test_compute_mean_sky_coordinate(): + """ + Test spherical coordinate averaging. + """ + delta = 0.5 / 3600.0 # 0.5 arcsec in degrees + ra = 10.0 + dec = 30.0 + coords = np.array([ + [ra - delta, dec - delta], + [ra + delta, dec - delta], + [ra - delta, dec + delta], + [ra + delta, dec + delta], + ]) + mean_lon, mean_lat = _compute_mean_sky_coordinate(coords) + assert_allclose(mean_lon, ra) + assert_allclose(mean_lat, dec) + + +def test_compute_mean_sky_coordinate_edge_cases(): + """ + Test mean sky coordinate calculation edge cases. 
+ """ + # Test coordinates near poles + coords = np.array([ + [0.0, 89.0], + [90.0, 89.0], + [180.0, 89.0], + [270.0, 89.0], + ]) + # Mean latitude should be close to 89 - relax tolerance for edge case + _, mean_lat = _compute_mean_sky_coordinate(coords) + assert abs(mean_lat - 89.0) < 1.1 + + # Test with single coordinate + single_coord = np.array([[45.0, 30.0]]) + mean_lon, mean_lat = _compute_mean_sky_coordinate(single_coord) + assert abs(mean_lon - 45.0) < 1e-10 + assert abs(mean_lat - 30.0) < 1e-10 + + +def test_prepare_uncertainty_info(): + """ + Test uncertainty info preparation. + """ + # Test with no uncertainty + data = NDData(np.ones((5, 5))) + info = _prepare_uncertainty_info(data) + assert info['type'] == 'none' + + # Test with weight-like uncertainty by creating custom uncertainty + class WeightsUncertainty(StdDevUncertainty): + @property + def uncertainty_type(self): + return 'weights' + + weights = np.ones((5, 5)) * 2 + data.uncertainty = WeightsUncertainty(weights) + + info = _prepare_uncertainty_info(data) + assert info['type'] == 'weights' + assert_array_equal(info['array'], weights) + + +def test_prepare_uncertainty_info_variants(): + """ + Test uncertainty preparation for different uncertainty types. + """ + # Test standard deviation uncertainty + data = NDData(np.ones((5, 5))) + data.uncertainty = StdDevUncertainty(np.ones((5, 5)) * 0.1) + + info = _prepare_uncertainty_info(data) + assert info['type'] == 'uncertainty' + assert 'uncertainty' in info + + +def test_create_weights_cutout(): + """ + Test weights cutout creation. 
+ """ + # Test with no uncertainty + info = {'type': 'none'} + slices = (slice(1, 4), slice(1, 4)) # 3x3 cutout + mask = None + + weights, has_nonfinite = _create_weights_cutout(info, mask, slices) + assert weights.shape == (3, 3) + assert_array_equal(weights, np.ones((3, 3))) + assert not has_nonfinite + + # Test with mask + full_mask = np.zeros((5, 5), dtype=bool) + full_mask[2, 2] = True # Mask center of cutout + + weights, has_nonfinite = _create_weights_cutout(info, full_mask, slices) + assert weights[1, 1] == 0.0 # Should be masked + assert not has_nonfinite + + +def test_create_weights_cutout_with_uncertainty(): + """ + Test weights cutout creation with uncertainty. + """ + # Create uncertainty info + uncertainty = StdDevUncertainty(np.ones((5, 5)) * 0.1) + info = { + 'type': 'uncertainty', + 'uncertainty': uncertainty, + } + + slices = (slice(1, 4), slice(1, 4)) + mask = None + + weights, has_nonfinite = _create_weights_cutout(info, mask, slices) + assert weights.shape == (3, 3) + # Should be inverse of uncertainty values (1/0.1 = 10) + assert_allclose(weights, np.ones((3, 3)) * 10) + assert not has_nonfinite + + +def test_create_weights_cutout_non_finite_warning(): + """ + Test detection of non-finite weights. + """ + # Create weights with non-finite values + bad_weights = np.ones((5, 5)) + bad_weights[2, 2] = np.inf + + info = { + 'type': 'weights', + 'array': bad_weights, + } + + slices = (slice(1, 4), slice(1, 4)) + mask = None + + # Function should return has_nonfinite=True (warning is now + # emitted by caller) + weights, has_nonfinite = _create_weights_cutout(info, mask, slices) + assert has_nonfinite + # Non-finite value should be set to zero + assert weights[1, 1] == 0.0 + + +class TestEPSFStar: + """ + Tests for EPSFStar class functionality. + """ + + def test_basic_initialization(self): + """ + Test basic EPSFStar initialization. 
+ """ + data = np.ones((11, 11)) + star = EPSFStar(data) + + assert star.data.shape == (11, 11) + assert star.cutout_center is not None + assert star.weights.shape == data.shape + assert star.flux > 0 + + def test_explicit_flux(self): + """ + Test EPSFStar initialization with explicit flux value. + """ + data = np.ones((5, 5)) + explicit_flux = 100.0 + star = EPSFStar(data, flux=explicit_flux) + + # Flux should be the explicitly provided value + assert star.flux == explicit_flux + + def test_invalid_data_input(self): + """ + Test EPSFStar initialization with invalid data. + """ + # Test 1D data + match = 'Input data must be 2-dimensional' + with pytest.raises(ValueError, match=match): + EPSFStar(np.ones(10)) + + # Test 3D data + with pytest.raises(ValueError, match=match): + EPSFStar(np.ones((5, 5, 5))) + + # Test empty data + with pytest.raises(ValueError, match='Input data cannot be empty'): + EPSFStar(np.array([]).reshape(0, 0)) + + def test_weights_validation(self): + """ + Test weight validation in EPSFStar. + """ + data = np.ones((5, 5)) + + # Test mismatched weights shape + wrong_weights = np.ones((3, 3)) + match = 'Weights shape .* must match data shape' + with pytest.raises(ValueError, match=match): + EPSFStar(data, weights=wrong_weights) + + # Test non-finite weights (should generate warning) + bad_weights = np.ones((5, 5)) + bad_weights[2, 2] = np.inf + bad_weights[1, 1] = np.nan + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + star = EPSFStar(data, weights=bad_weights) + assert len(w) == 1 + assert issubclass(w[0].category, AstropyUserWarning) + assert 'Non-finite weight values' in str(w[0].message) + + # Check that non-finite weights were set to zero + assert star.weights[2, 2] == 0.0 + assert star.weights[1, 1] == 0.0 + + def test_invalid_data_handling(self): + """ + Test handling of invalid pixel values. 
+ """ + data = np.ones((5, 5)) + data[1, 1] = np.nan + data[2, 2] = np.inf + data[3, 3] = np.nan + data[4, 4] = np.inf + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + star = EPSFStar(data) + # Should mask invalid pixels + assert star.mask[1, 1] + assert star.mask[2, 2] + assert star.mask[3, 3] + assert star.mask[4, 4] + assert star.weights[1, 1] == 0.0 + assert star.weights[2, 2] == 0.0 + assert star.weights[3, 3] == 0.0 + assert star.weights[4, 4] == 0.0 + # Check that warning was issued about invalid data + assert len(w) > 0 + + def test_cutout_center_validation(self): + """ + Test cutout_center validation. + """ + data = np.ones((5, 5)) + star = EPSFStar(data) + + # Test invalid shape + match = 'cutout_center must have exactly two elements' + with pytest.raises(ValueError, match=match): + star.cutout_center = [1, 2, 3] + + # Test non-finite values + with pytest.raises(ValueError, match='must be finite'): + star.cutout_center = [np.nan, 2.0] + + # Test bounds warnings (should warn but not raise) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + star.cutout_center = [-1, 2] # Outside bounds + assert len(w) >= 1 + # Check that warning mentions coordinates outside bounds + warning_messages = [str(warning.message) for warning in w] + assert any('outside the cutout bounds' in msg + for msg in warning_messages) + + def test_origin_validation(self): + """ + Test origin parameter validation. + """ + data = np.ones((5, 5)) + + # Test invalid origin shape + match = 'Origin must have exactly 2 elements' + with pytest.raises(ValueError, match=match): + EPSFStar(data, origin=[1, 2, 3]) + + # Test non-finite origin + match = 'Origin coordinates must be finite' + with pytest.raises(ValueError, match=match): + EPSFStar(data, origin=[np.inf, 2]) + + def test_estimate_flux_masked_data(self): + """ + Test flux estimation with masked data. 
+ """ + data = np.ones((5, 5)) * 10 + + # Create weights that mask some pixels + weights = np.ones((5, 5)) + weights[1:3, 1:3] = 0 # Mask central 2x2 region + + star = EPSFStar(data, weights=weights) + + # Flux should be estimated via interpolation + assert star.flux > 0 + # Should be close to total flux despite masking + assert star.flux == pytest.approx(250, rel=0.1) # 5*5*10 = 250 + + def test_data_shape_validation(self): + """ + Test EPSFStar validation for various data shapes. + """ + # Test zero-dimension data - this actually triggers "empty" + # error + with pytest.raises(ValueError, match='Input data cannot be empty'): + EPSFStar(np.zeros((0, 5))) + + with pytest.raises(ValueError, match='Input data cannot be empty'): + EPSFStar(np.zeros((5, 0))) + + def test_flux_estimation_failure(self): + """ + Test flux estimation behavior with all masked data. + """ + # Create data with all masked pixels - this should raise + # ValueError because the star cutout is completely masked + data = np.ones((5, 5)) + weights = np.zeros((5, 5)) # All masked data + + # This should raise ValueError because all data is masked + match = 'Star cutout is completely masked; no valid data available' + with pytest.raises(ValueError, match=match): + EPSFStar(data, weights=weights) + + def test_completely_masked_star(self): + """ + Test that completely masked stars are properly rejected. + """ + # Create star data with all weights zero (completely masked) + data = np.ones((7, 7)) * 100.0 + weights = np.zeros((7, 7)) + + # Should raise ValueError with appropriate message + match = 'Star cutout is completely masked; no valid data available' + with pytest.raises(ValueError, match=match): + EPSFStar(data, weights=weights) + + def test_negative_flux_allowed(self): + """ + Test that negative flux is allowed for valid sources. + + Negative flux can occur legitimately with background + oversubtraction or similar effects. 
+ """ + # Create data with negative net flux + data = np.ones((5, 5)) * -10.0 + star = EPSFStar(data, flux=-50.0) + + # Should not raise an error + assert star.flux == -50.0 + + # Also test with estimated flux + star2 = EPSFStar(data) + assert star2.flux == -250.0 # sum of 25 pixels * -10 + + def test_all_zero_data_warning(self): + """ + Test that all-zero data emits a warning when EPSFStar is called + directly, but flag is set for extract_stars to handle. + + All-zero unmasked data is unusual and likely indicates a problem, + but it's not technically invalid, so we allow star creation. + """ + data = np.zeros((5, 5)) + + # EPSFStar should emit warning when called directly + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + star = EPSFStar(data) + # Should have warning about all-zero data + warning_messages = [str(warning.message) for warning in w] + assert any('All unmasked data values' in msg and 'zero' in msg + for msg in warning_messages) + + # Star should be created with flux=0 and flag set + assert star.flux == 0.0 + assert hasattr(star, '_has_all_zero_data') + assert star._has_all_zero_data is True + + def test_array_method(self): + """ + Test the __array__ method. + """ + data = np.random.default_rng(42).random((5, 5)) + star = EPSFStar(data) + + # Test that __array__ returns the data + star_array = star.__array__() + assert_array_equal(star_array, data) + + def test_properties(self): + """ + Test star properties. 
+ """ + data = np.ones((7, 9)) + origin = (10, 20) + star = EPSFStar(data, origin=origin) + + # Test shape property + assert star.shape == (7, 9) + + # Test center property (different from cutout_center) + expected_center = star.cutout_center + np.array(origin) + assert_array_equal(star.center, expected_center) + + # Test slices property + # Implementation uses (origin_y to origin_y+shape[0], + # origin_x to origin_x+shape[1]) + expected_slices = (slice(20, 29), slice(10, 17)) + assert star.slices == expected_slices + + # Test bbox property + bbox = star.bbox + assert bbox.ixmin == 10 + assert bbox.ixmax == 17 + assert bbox.iymin == 20 + assert bbox.iymax == 29 + + def test_flux_estimation_interpolation_fallback(self): + """ + Test flux estimation with interpolation fallbacks. + """ + data = np.ones((5, 5)) * 10 + weights = np.ones((5, 5)) + weights[2, 2] = 0 # Mask center pixel + + star = EPSFStar(data, weights=weights) + + # Should estimate flux using interpolation + # Flux should be close to total despite masked pixel + assert star.flux == pytest.approx(250, rel=0.1) + + def test_register_epsf(self): + """ + Test ePSF registration and scaling. + """ + data = np.ones((11, 11)) + star = EPSFStar(data) + + # Create a simple ePSF model + epsf_data = np.zeros((5, 5)) + epsf_data[2, 2] = 1 # Central peak + epsf = ImagePSF(epsf_data) + + # Register the ePSF + registered = star.register_epsf(epsf) + + assert registered.shape == data.shape + assert isinstance(registered, np.ndarray) + + def test_private_properties(self): + """ + Test private properties. 
+ """ + data = np.random.default_rng(42).random((5, 5)) + weights = np.ones((5, 5)) + weights[1, 1] = 0 # Mask one pixel + star = EPSFStar(data, weights=weights) + + # Test _xyidx_centered + x_centered, y_centered = star._xyidx_centered + assert len(x_centered) == len(y_centered) + assert len(x_centered) == np.sum(~star.mask) + + # Verify centering is correct + yidx, xidx = np.indices(data.shape) + expected_x = xidx[~star.mask].ravel() - star.cutout_center[0] + expected_y = yidx[~star.mask].ravel() - star.cutout_center[1] + assert_array_equal(x_centered, expected_x) + assert_array_equal(y_centered, expected_y) + + # Test normalized data values + expected_values = data[~star.mask].ravel() + normalized = star._data_values_normalized + expected_normalized = expected_values / star.flux + assert_allclose(normalized, expected_normalized) + + def test_flux_estimation_exception_handling(self): + """ + Test flux estimation exception handling when estimate_flux returns + invalid values. + """ + # Test with data that results in zero flux - this is now ALLOWED + # since zero flux is a valid (though not useful) value + data = np.zeros((3, 3)) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', AstropyUserWarning) + star = EPSFStar(data) + assert star.flux == 0.0 # Zero flux is allowed + + # Test that completely invalid (NaN) data is rejected + # (NaN data gets masked, then completely masked raises error) + data_nan = np.full((3, 3), np.nan) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', AstropyUserWarning) + match = 'Star cutout is completely masked' + with pytest.raises(ValueError, match=match): + EPSFStar(data_nan) + + def test_cutout_center_out_of_bounds_y(self): + """ + Test cutout_center validation for y-coordinate out of bounds. 
+ """ + data = np.ones((5, 5)) + star = EPSFStar(data) + + # Test y-coordinate outside bounds + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + star.cutout_center = (2.0, -1.0) # y < 0 + assert len(w) >= 1 + warning_messages = [str(warning.message) for warning in w] + assert any('y-coordinate' in msg and 'outside' in msg + for msg in warning_messages) + + # Test y-coordinate at upper bound + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + star.cutout_center = (2.0, 6.0) # y >= shape[0] + assert len(w) >= 1 + warning_messages = [str(warning.message) for warning in w] + assert any('y-coordinate' in msg and 'outside' in msg + for msg in warning_messages) + + def test_empty_data_validation(self): + """ + Test empty data validation. + """ + data = np.array([[]]) # Empty 2D array + with pytest.raises(ValueError, match='Input data cannot be empty'): + EPSFStar(data) + + def test_residual_image(self): + """ + Test to ensure ``compute_residual_image`` gives correct + residuals. + """ + size = 100 + yy, xx, = np.mgrid[0:size + 1, 0:size + 1] / 4 + gmodel = CircularGaussianPRF().evaluate(xx, yy, 1, 12.5, 12.5, 2.5) + epsf = ImagePSF(gmodel, oversampling=4) + _size = 25 + data = np.zeros((_size, _size)) + _yy, _xx, = np.mgrid[0:_size, 0:_size] + data += epsf.evaluate(x=_xx, y=_yy, flux=16, x_0=12, y_0=12) + tbl = Table() + tbl['x'] = [12] + tbl['y'] = [12] + stars = extract_stars(NDData(data), tbl, size=23) + residual = stars[0].compute_residual_image(epsf) + assert_allclose(np.sum(residual), 0.0) + + +class TestEPSFStars: + """ + Tests for EPSFStars collection class functionality. + """ + + def test_initialization_variants(self): + """ + Test different initialization methods. 
+ """ + data1 = np.ones((5, 5)) + data2 = np.ones((7, 7)) + star1 = EPSFStar(data1) + star2 = EPSFStar(data2) + + # Test single star initialization + stars_single = EPSFStars(star1) + assert len(stars_single) == 1 + + # Test list initialization + stars_list = EPSFStars([star1, star2]) + assert len(stars_list) == 2 + + # Test invalid initialization + with pytest.raises(TypeError, match='stars_list must be a list'): + EPSFStars('invalid') + + def test_indexing_operations(self): + """ + Test indexing and slicing operations. + """ + stars = [EPSFStar(np.ones((5, 5))) for _ in range(3)] + stars_obj = EPSFStars(stars) + + # Test getitem + first = stars_obj[0] + assert isinstance(first, EPSFStars) + assert len(first) == 1 + + # Test delitem + del stars_obj[1] + assert len(stars_obj) == 2 + + # Test iteration + count = 0 + for star in stars_obj: + count += 1 + assert isinstance(star, EPSFStar) + assert count == 2 + + def test_pickle_operations(self): + """ + Test pickle state management. + """ + stars = [EPSFStar(np.ones((5, 5))) for _ in range(2)] + stars_obj = EPSFStars(stars) + + # Test getstate/setstate + state = stars_obj.__getstate__() + new_obj = EPSFStars([]) + new_obj.__setstate__(state) + assert len(new_obj) == len(stars_obj) + + def test_attribute_access(self): + """ + Test dynamic attribute access. 
+ """ + data1 = np.ones((5, 5)) + data2 = np.ones((7, 7)) * 2 + stars = EPSFStars([EPSFStar(data1), EPSFStar(data2)]) - yy, xx = np.mgrid[0:51, 0:55] - self.data = np.zeros(xx.shape) - for (xi, yi) in zip(stars_tbl['x'], stars_tbl['y'], strict=True): - m = Moffat2D(100, xi, yi, 3, 3) - self.data += m(xx, yy) + # Test accessing cutout_center attribute + centers = stars.cutout_center + assert len(centers) == 2 + assert centers.shape == (2, 2) - self.nddata = NDData(data=self.data) + # Test accessing flux attribute + fluxes = stars.flux + assert len(fluxes) == 2 - def test_extract_stars(self): + # Test accessing _excluded_from_fit attribute + excluded = stars._excluded_from_fit + assert len(excluded) == 2 + assert not any(excluded) # Should all be False initially + + def test_flat_attributes(self): + """ + Test flat attribute access methods. + """ + stars = [EPSFStar(np.ones((5, 5))) for _ in range(2)] + stars_obj = EPSFStars(stars) + + # Test cutout_center_flat + centers_flat = stars_obj.cutout_center_flat + assert centers_flat.shape == (2, 2) + + # Test center_flat + centers_flat = stars_obj.center_flat + assert centers_flat.shape == (2, 2) + + def test_star_counting(self): + """ + Test star counting properties. + """ + stars = [EPSFStar(np.ones((5, 5))) for _ in range(3)] + stars_obj = EPSFStars(stars) + + # Test counting properties + assert stars_obj.n_stars == 3 + assert stars_obj.n_all_stars == 3 + assert stars_obj.n_good_stars == 3 + + # Test all_stars and all_good_stars properties + all_stars = stars_obj.all_stars + assert len(all_stars) == 3 + + good_stars = stars_obj.all_good_stars + assert len(good_stars) == 3 + + # Mark one star as excluded + stars[1]._excluded_from_fit = True + assert stars_obj.n_good_stars == 2 + + def test_shape_attribute(self): + """ + Test accessing shape attribute through EPSFStars. 
+ """ + stars = [EPSFStar(np.ones((5, 5))), EPSFStar(np.ones((7, 9)))] + stars_obj = EPSFStars(stars) + + # Access individual star shapes through the container + shapes = stars_obj.shape + assert len(shapes) == 2 + assert shapes[0] == (5, 5) + assert shapes[1] == (7, 9) + + def test_pickleable(self): + """ + Verify that EPSFStars can be successfully pickled/unpickled for + multiprocessing. + """ + # This should not fail + stars = EPSFStars([1]) + ForkingPickler.loads(ForkingPickler.dumps(stars)) + + def test_cutout_center_flat_with_linked_stars(self, simple_wcs): + """ + Test cutout_center_flat property with LinkedEPSFStar objects. + """ + # Create regular stars + star1 = EPSFStar(np.ones((5, 5))) + star2 = EPSFStar(np.ones((7, 7))) + + # Create linked stars + linked_star1 = EPSFStar(np.ones((6, 6)), wcs_large=simple_wcs) + linked_star2 = EPSFStar(np.ones((8, 8)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([linked_star1, linked_star2]) + + # Create EPSFStars collection with mix of regular and linked stars + stars = EPSFStars([star1, linked, star2]) + + # Test cutout_center_flat property + centers_flat = stars.cutout_center_flat + # Should have 4 centers: star1, linked_star1, linked_star2, star2 + assert len(centers_flat) == 4 + assert centers_flat.shape == (4, 2) + + def test_all_stars_with_linked_stars(self, simple_wcs): + """ + Test all_stars property with LinkedEPSFStar objects. 
+ """ + # Create regular stars + star1 = EPSFStar(np.ones((5, 5))) + star2 = EPSFStar(np.ones((7, 7))) + + # Create linked stars + linked_star1 = EPSFStar(np.ones((6, 6)), wcs_large=simple_wcs) + linked_star2 = EPSFStar(np.ones((8, 8)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([linked_star1, linked_star2]) + + # Create EPSFStars collection with mix of regular and linked + # stars + stars = EPSFStars([star1, linked, star2]) + + # Test all_stars property + all_stars_list = stars.all_stars + # Should have 4 stars total: star1, linked_star1, linked_star2, + # star2 + assert len(all_stars_list) == 4 + + # Verify they are all EPSFStar instances + for star in all_stars_list: + assert isinstance(star, EPSFStar) + + +class TestLinkedEPSFStar: + """ + Tests for LinkedEPSFStar functionality. + """ + + def test_initialization_validation(self): + """ + Test LinkedEPSFStar initialization validation. + """ + # Test with non-EPSFStar objects + with pytest.raises(TypeError, match='must contain only EPSFStar'): + LinkedEPSFStar(['not_a_star', 'also_not_a_star']) + + # Test with EPSFStar without WCS + star_no_wcs = EPSFStar(np.ones((5, 5))) + with pytest.raises(ValueError, match='must have a valid wcs_large'): + LinkedEPSFStar([star_no_wcs]) + + def test_constraint_no_good_stars(self, simple_wcs): + """ + Test constraining centers with no good stars. 
+ """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + + # Mark both as excluded + star1._excluded_from_fit = True + star2._excluded_from_fit = True + + linked = LinkedEPSFStar([star1, star2]) + + # Should warn about no good stars + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + linked.constrain_centers() + assert len(w) >= 1 + warning_messages = [str(warning.message) for warning in w] + assert any('have all been excluded' in msg + for msg in warning_messages) + + def test_constraint_single_star(self, simple_wcs): + """ + Test constraining centers with single star (no-op). + """ + star = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star]) + + # Should do nothing for single star + original_center = star.cutout_center.copy() + linked.constrain_centers() + assert_array_equal(star.cutout_center, original_center) + + def test_all_excluded_property(self, simple_wcs): + """ + Test the all_excluded property. + """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star1, star2]) + + # Initially, no stars are excluded + assert not linked.all_excluded + + # Exclude one star + star1._excluded_from_fit = True + assert not linked.all_excluded + + # Exclude both stars + star2._excluded_from_fit = True + assert linked.all_excluded + + def test_constrain_centers_with_good_stars(self, simple_wcs): + """ + Test constrain_centers method with good stars. 
+ """ + # Create multiple stars with different positions (within bounds) + star1 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs, + cutout_center=(3.1, 3.1), origin=(20, 20)) + star2 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs, + cutout_center=(2.9, 2.9), origin=(20, 20)) + star3 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs, + cutout_center=(3.0, 3.2), origin=(20, 20)) + + # Make sure none are excluded + star1._excluded_from_fit = False + star2._excluded_from_fit = False + star3._excluded_from_fit = False + + linked = LinkedEPSFStar([star1, star2, star3]) + + # Test constrain_centers (should execute without error) + linked.constrain_centers() + + def test_constrain_centers_with_some_excluded_stars(self, simple_wcs): + """ + Test constrain_centers with some excluded stars. + """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star3 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + + # Exclude some stars but not all + star1._excluded_from_fit = True # Excluded + star2._excluded_from_fit = False # Good + star3._excluded_from_fit = False # Good + + linked = LinkedEPSFStar([star1, star2, star3]) + + # This should process only the good stars; should not raise + # warnings since there are good stars + linked.constrain_centers() + + def test_constrain_all_excluded(self, simple_wcs): + """ + Test constrain_centers when all stars excluded. 
+ """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + + # Exclude all stars + star1._excluded_from_fit = True + star2._excluded_from_fit = True + + linked = LinkedEPSFStar([star1, star2]) + + # Should trigger early return and emit warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + linked.constrain_centers() + # Should get warning about no good stars + warning_messages = [str(warning.message) for warning in w] + has_warning = any('Cannot constrain centers' in msg + for msg in warning_messages) + assert has_warning + + def test_len_getitem_iter(self, simple_wcs): + """ + Test __len__, __getitem__, and __iter__ methods. + """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star1, star2]) + + # Test __len__ + assert len(linked) == 2 + + # Test __getitem__ + assert linked[0] is star1 + assert linked[1] is star2 + + # Test __iter__ + stars_list = list(linked) + assert len(stars_list) == 2 + assert stars_list[0] is star1 + assert stars_list[1] is star2 + + def test_getattr_delegation(self, simple_wcs): + """ + Test __getattr__ delegation for various attributes. + """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star1, star2]) + + # Test accessing flux attribute (should be array) + fluxes = linked.flux + assert len(fluxes) == 2 + assert fluxes[0] == star1.flux + assert fluxes[1] == star2.flux + + # Test accessing cutout_center (should be array) + centers = linked.cutout_center + assert centers.shape == (2, 2) + + # Test accessing center (should be array) + centers = linked.center + assert centers.shape == (2, 2) + + def test_getattr_single_star(self, simple_wcs): + """ + Test __getattr__ with single star (returns scalar not array). 
+ """ + star = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star]) + + # With single star, should return single value not array + flux = linked.flux + assert flux == star.flux + assert not isinstance(flux, np.ndarray) + + def test_getattr_private_attribute_error(self, simple_wcs): + """ + Test that accessing non-existent private attributes raises + error. + """ + star = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star]) + + # Accessing non-existent private attribute should raise + match = "'LinkedEPSFStar' object has no attribute" + with pytest.raises(AttributeError, match=match): + _ = linked._nonexistent_attribute + + def test_pickle_operations(self, simple_wcs): + """ + Test __getstate__ and __setstate__ for pickling. + """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star1, star2]) + + # Test getstate/setstate + state = linked.__getstate__() + new_linked = LinkedEPSFStar([EPSFStar(np.ones((3, 3)), + wcs_large=simple_wcs)]) + new_linked.__setstate__(state) + assert len(new_linked) == 2 + + def test_flat_properties(self, simple_wcs): + """ + Test cutout_center_flat and center_flat properties. + """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs, + origin=(10, 20)) + star2 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs, + origin=(30, 40)) + linked = LinkedEPSFStar([star1, star2]) + + # Test cutout_center_flat + centers_flat = linked.cutout_center_flat + assert centers_flat.shape == (2, 2) + assert_array_equal(centers_flat[0], star1.cutout_center) + assert_array_equal(centers_flat[1], star2.cutout_center) + + # Test center_flat + centers = linked.center_flat + assert centers.shape == (2, 2) + assert_array_equal(centers[0], star1.center) + assert_array_equal(centers[1], star2.center) + + def test_counting_properties(self, simple_wcs): + """ + Test n_stars, n_all_stars, and n_good_stars properties. 
+ """ + star1 = EPSFStar(np.ones((5, 5)), wcs_large=simple_wcs) + star2 = EPSFStar(np.ones((7, 7)), wcs_large=simple_wcs) + star3 = EPSFStar(np.ones((6, 6)), wcs_large=simple_wcs) + linked = LinkedEPSFStar([star1, star2, star3]) + + # Test n_stars and n_all_stars (should be same for + # LinkedEPSFStar) + assert linked.n_stars == 3 + assert linked.n_all_stars == 3 + + # Test n_good_stars + assert linked.n_good_stars == 3 + + # Exclude one star + star2._excluded_from_fit = True + assert linked.n_good_stars == 2 + + +class TestExtractStars: + """ + Tests for extract_stars function. + """ + + def test_extract_stars(self, stars_nddata, stars_table): + """ + Test basic star extraction functionality. + """ size = 11 - stars = extract_stars(self.nddata, self.stars_tbl, size=size) + stars = extract_stars(stars_nddata, stars_table, size=size) assert len(stars) == 4 assert isinstance(stars, EPSFStars) assert isinstance(stars[0], EPSFStars) @@ -41,56 +1154,565 @@ def test_extract_stars(self): assert stars.n_stars == stars.n_good_stars assert stars.center.shape == (len(stars), 2) - def test_extract_stars_inputs(self): - match = 'data must be a single NDData or list of NDData objects' + def test_extract_stars_inputs(self, stars_nddata, stars_table): + """ + Test extract_stars input validation. 
+ """ + match = 'data must be a single NDData object or list of NDData objects' with pytest.raises(TypeError, match=match): - extract_stars(np.ones(3), self.stars_tbl) + extract_stars(np.ones(3), stars_table) - match = 'catalogs must be a single Table or list of Table objects' + match = 'All catalog elements must be Table objects' with pytest.raises(TypeError, match=match): - extract_stars(self.nddata, [(1, 1), (2, 2), (3, 3)]) + extract_stars(stars_nddata, [(1, 1), (2, 2), (3, 3)]) match = 'number of catalogs must match the number of input images' with pytest.raises(ValueError, match=match): - extract_stars(self.nddata, [self.stars_tbl, self.stars_tbl]) + extract_stars(stars_nddata, [stars_table, stars_table]) match = 'the catalog must have a "skycoord" column' with pytest.raises(ValueError, match=match): - extract_stars([self.nddata, self.nddata], self.stars_tbl) + extract_stars([stars_nddata, stars_nddata], stars_table) + def test_empty_catalog(self, simple_nddata): + """ + Test extraction with empty catalog. + """ + empty_table = Table() + empty_table['x'] = [] + empty_table['y'] = [] -def test_epsf_star_residual_image(): - """ - Test to ensure ``compute_residual_image`` gives correct residuals. - """ - size = 100 - yy, xx, = np.mgrid[0:size + 1, 0:size + 1] / 4 - gmodel = CircularGaussianPRF().evaluate(xx, yy, 1, 12.5, 12.5, 2.5) - epsf = ImagePSF(gmodel, oversampling=4) - _size = 25 - data = np.zeros((_size, _size)) - _yy, _xx, = np.mgrid[0:_size, 0:_size] - data += epsf.evaluate(x=_xx, y=_yy, flux=16, x_0=12, y_0=12) - tbl = Table() - tbl['x'] = [12] - tbl['y'] = [12] - stars = extract_stars(NDData(data), tbl, size=23) - residual = stars[0].compute_residual_image(epsf) - # As current EPSFStar instances cannot accept CircularGaussianPRF - # as input, we have to accept some loss of precision from the - # conversion to ePSF, and spline fitting (twice), so assert_allclose - # cannot be more precise than 0.001 currently. 
- assert_allclose(np.sum(residual), 0.0, atol=1.0e-3, rtol=1e-3) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + stars = extract_stars(simple_nddata, empty_table) + assert len(stars) == 0 + # Should warn about empty catalog + assert len(w) >= 1 + warning_messages = [str(warning.message) for warning in w] + assert any('empty' in msg.lower() for msg in warning_messages) + def test_stars_outside_image(self, simple_nddata): + """ + Test extraction with stars outside image bounds. + """ + table = Table() + table['x'] = [-10, 100] # Outside image bounds + table['y'] = [25, 25] -def test_stars_pickleable(): - """ - Verify that EPSFStars can be successfully pickled/unpickled for use - multiprocessing. - """ - from multiprocessing.reduction import ForkingPickler + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + stars = extract_stars(simple_nddata, table, size=11) + assert len(stars) == 0 + # Should warn about excluded stars + assert len(w) >= 1 + warning_messages = [str(warning.message) for warning in w] + assert any('not extracted' in msg for msg in warning_messages) + + def test_invalid_input_types(self, simple_nddata): + """ + Test extraction with invalid input types. + """ + table = Table() + table['x'] = [25] + table['y'] = [25] + + # Test invalid data type + with pytest.raises(TypeError, match='must be a single NDData object'): + extract_stars('not_nddata', table) + + # Test invalid catalog type + with pytest.raises(TypeError, match='must be a single Table object'): + extract_stars(simple_nddata, 'not_table') + + def test_coordinate_validation(self, simple_nddata): + """ + Test coordinate system validation. 
+ """ + table = Table() + table['x'] = [25] + table['y'] = [25] + + # Test missing skycoord for multiple images + with pytest.raises(ValueError, match='must have a "skycoord" column'): + extract_stars([simple_nddata, simple_nddata], table) + + # Test missing coordinate columns + bad_table = Table() + bad_table['flux'] = [100] # No x, y, or skycoord + + with pytest.raises(ValueError, match='must have either'): + extract_stars(simple_nddata, bad_table) + + def test_data_validation(self, simple_table): + """ + Test data input validation. + """ + # Test invalid data types in list + with pytest.raises(TypeError, match='All data elements must be'): + extract_stars(['not_nddata'], simple_table) + + # Test NDData with no data array + empty_nddata = NDData(np.array([])) # Provide empty array + with pytest.raises(ValueError, match='must contain 2D data'): + extract_stars(empty_nddata, simple_table) + + # Test NDData with wrong dimensions + nddata_1d = NDData(np.ones(50)) + with pytest.raises(ValueError, match='must contain 2D data'): + extract_stars(nddata_1d, simple_table) + + def test_catalog_validation(self, simple_nddata): + """ + Test catalog input validation. + """ + # Test invalid catalog types in list + with pytest.raises(TypeError, match='All catalog elements must be'): + extract_stars(simple_nddata, ['not_table']) + + def test_coordinate_system_validation(self, simple_nddata): + """ + Test coordinate system validation for complex cases. 
+ """ + # Test skycoord-only catalog without WCS + skycoord_table = Table() + skycoord_table['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + with pytest.raises(ValueError, + match='must have a wcs attribute'): + extract_stars(simple_nddata, skycoord_table) + + # Test multiple catalogs with mismatched count + table1 = Table({'x': [25], 'y': [25]}) + table2 = Table({'x': [25], 'y': [25]}) + with pytest.raises(ValueError, + match='number of catalogs must match'): + extract_stars(simple_nddata, [table1, table2]) + + def test_extract_stars_skycoord_and_wcs(self, simple_data, simple_wcs): + """ + Test extract_stars with skycoord input and WCS. + """ + nddata_with_wcs = NDData(simple_data) + nddata_with_wcs.wcs = simple_wcs + + table = Table() + table['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + stars = extract_stars(nddata_with_wcs, table, size=(11, 11)) + + valid_stars = [s for s in stars.all_stars if s is not None] + assert len(valid_stars) >= 1 + + def test_extract_stars_size_validation_coverage(self, simple_nddata): + """ + Test size validation paths in extract_stars. + """ + table = Table({'x': [25], 'y': [25]}) + + # Test various size configurations to hit validation paths. + # This should exercise the as_pair validation. + stars = extract_stars(simple_nddata, table, size=11) + assert len(stars) == 1 + + # Test tuple size + stars = extract_stars(simple_nddata, table, size=(11, 13)) + assert len(stars) == 1 + assert stars[0].data.shape == (11, 13) + + def test_extract_stars_coordinate_conversion_paths(self, + simple_data, + simple_wcs): + """ + Test coordinate conversion paths in extract_stars. 
+ """ + nddata_with_wcs = NDData(simple_data) + nddata_with_wcs.wcs = simple_wcs + + # Test with both x,y and skycoord present (should prefer x,y) + table = Table() + table['x'] = [25.0] + table['y'] = [25.0] + table['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + stars = extract_stars(nddata_with_wcs, table, size=11) + assert len(stars) == 1 + + def test_extract_stars_id_handling(self, simple_nddata): + """ + Test ID handling in extract_stars. + """ + # Test with explicit IDs + table = Table() + table['x'] = [25, 30] + table['y'] = [25, 30] + table['id'] = ['star_a', 'star_b'] + + stars = extract_stars(simple_nddata, table, size=11) + assert len(stars) == 2 + assert stars[0].id_label == 'star_a' + assert stars[1].id_label == 'star_b' + + # Test without IDs (should auto-generate) + table_no_id = Table() + table_no_id['x'] = [25, 30] + table_no_id['y'] = [25, 30] + + stars = extract_stars(simple_nddata, table_no_id, size=11) + assert len(stars) == 2 + assert stars[0].id_label == 1 # Auto-generated starting from 1 + assert stars[1].id_label == 2 + + def test_extract_linked_stars_multiple_images(self, simple_wcs): + """ + Test extracting linked stars from multiple images with single + catalog. 
+ """ + # Create two images with WCS + data1 = np.ones((50, 50)) * 10 + data2 = np.ones((50, 50)) * 20 + nddata1 = NDData(data1) + nddata1.wcs = simple_wcs + nddata2 = NDData(data2) + nddata2.wcs = simple_wcs + + # Create catalog with skycoord at center of image + table = Table() + table['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + # Extract linked stars (suppress warnings to avoid pytest error) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', AstropyUserWarning) + stars = extract_stars([nddata1, nddata2], table, size=11) + + # Should have 1 linked star containing 2 EPSFStar objects + assert len(stars) == 1 + assert isinstance(stars._data[0], LinkedEPSFStar) + assert len(stars._data[0]) == 2 + + def test_extract_unlinked_stars_multiple_catalogs(self): + """ + Test extracting stars with multiple catalogs (no linking). + """ + # Create two images + data1 = np.ones((50, 50)) * 10 + data2 = np.ones((50, 50)) * 20 + nddata1 = NDData(data1) + nddata2 = NDData(data2) + + # Create two catalogs with different stars + table1 = Table({'x': [25], 'y': [25]}) + table2 = Table({'x': [30], 'y': [30]}) + + # Extract stars + stars = extract_stars([nddata1, nddata2], [table1, table2], size=11) + + # Should have 2 separate (not linked) stars + assert len(stars) == 2 + assert all(isinstance(s, EPSFStar) for s in stars._data) + + def test_extract_linked_stars_partial_extraction(self, simple_wcs): + """ + Test linked star extraction where star is valid in one image but + not another (edge case). 
+ """ + # Create two images - second one is smaller so star near edge + # won't be extractable + data1 = np.ones((50, 50)) * 10 + data2 = np.ones((20, 20)) * 20 # Smaller image + nddata1 = NDData(data1) + nddata1.wcs = simple_wcs + nddata2 = NDData(data2) + nddata2.wcs = simple_wcs + + # Create catalog with star at position that's valid in first + # but not second + table = Table() + table['skycoord'] = [SkyCoord(0, 0, unit='deg')] # Center + + with warnings.catch_warnings(record=True): + warnings.simplefilter('always') + stars = extract_stars([nddata1, nddata2], table, size=11) + + # Should have extracted at least 1 star (from first image) + # The second image star is outside bounds so only 1 is extracted + assert len(stars) >= 1 + + def test_extract_stars_flux_estimation_failure(self): + """ + Test that EPSFStar creation failure emits warning for completely + masked stars. + """ + # Create data with explicit zero weights (completely masked) + data = np.ones((50, 50)) * 100.0 + nddata = NDData(data) + # Use zero uncertainty which causes infinite weights, + # which are then set to zero (completely masked) + uncertainty = StdDevUncertainty(np.zeros((50, 50))) + nddata.uncertainty = uncertainty + + table = Table({'x': [25], 'y': [25]}) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + stars = extract_stars(nddata, table, size=11) + # Should warn about failed EPSFStar creation + warning_messages = [str(warning.message) for warning in w] + assert any('Failed to create EPSFStar' in msg + for msg in warning_messages) + # Should NOT have duplicate warnings about completely masked + masked_warnings = [msg for msg in warning_messages + if 'completely masked' in msg] + # Should only have one warning per failed star + assert len(masked_warnings) == 1 + + # No valid stars should be extracted + assert len(stars) == 0 + + def test_extract_stars_completely_masked(self): + """ + Test extract_stars with completely masked cutouts. 
+ """ + # Create data with zeros and zero weights + data = np.zeros((50, 50)) + uncertainty = StdDevUncertainty(np.zeros((50, 50))) + nddata = NDData(data, uncertainty=uncertainty) + + table = Table({'x': [25, 30, 35], 'y': [25, 30, 35]}) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + stars = extract_stars(nddata, table, size=11) + + # Check warnings + warning_messages = [str(warning.message) for warning in w] + + # Should have one warning about non-finite weights + nonfinite_warnings = [msg for msg in warning_messages + if 'non-finite weight values' in msg] + assert len(nonfinite_warnings) == 1 + + # Should have warnings about failed EPSFStar creation + failed_warnings = [msg for msg in warning_messages + if 'Failed to create EPSFStar' in msg] + assert len(failed_warnings) == 3 # One per star + + # Each warning should mention completely masked + for msg in failed_warnings: + assert 'completely masked' in msg + + # No valid stars should be extracted + assert len(stars) == 0 + + def test_extract_stars_nonfinite_weights_warning(self): + """ + Test that non-finite weights in uncertainty emit warning. 
+ """ + data = np.ones((50, 50)) * 100 + nddata = NDData(data) + + # Create uncertainty with non-finite values + uncertainty = np.ones((50, 50)) * 0.1 + uncertainty[20:30, 20:30] = 0 # Will cause 1/0 = inf in weights + nddata.uncertainty = StdDevUncertainty(uncertainty) + + table = Table({'x': [25], 'y': [25]}) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + stars = extract_stars(nddata, table, size=11) + # Should warn about non-finite weights + warning_messages = [str(warning.message) for warning in w] + assert any('non-finite weight values' in msg + for msg in warning_messages) + + # Star should still be extracted (non-finite weights set to 0) + assert len(stars) == 1 + + def test_extract_stars_all_zero_data_warnings(self): + """ + Test that extract_stars emits individual warnings for stars with + all-zero data, including their positions. + """ + # Create an all-zero image + data = np.zeros((50, 50)) + nddata = NDData(data) + + # Create a table with 3 stars + table = Table({'x': [10.5, 25.0, 40.8], 'y': [15.2, 30.0, 35.6]}) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + stars = extract_stars(nddata, table, size=11) + + # Should have 3 warnings about all-zero data + warning_messages = [str(warning.message) for warning in w] + zero_warnings = [ + msg for msg in warning_messages + if 'all unmasked data values equal to zero' in msg] + assert len(zero_warnings) == 3 + + # Check that each warning includes the star position + assert any('10.5' in msg and '15.2' in msg + for msg in zero_warnings) + assert any('25.0' in msg and '30.0' in msg + for msg in zero_warnings) + assert any('40.8' in msg and '35.6' in msg + for msg in zero_warnings) + + # All stars should still be extracted with flux=0 + assert len(stars) == 3 + for star in stars: + assert star.flux == 0.0 + + def test_validate_single_catalog_multiple_images_no_wcs(self): + """ + Test validation error when single catalog with multiple 
images + but images lack WCS. + """ + # Create two images without WCS + data1 = np.ones((50, 50)) + data2 = np.ones((50, 50)) + nddata1 = NDData(data1) + nddata2 = NDData(data2) + + # Create catalog with skycoord + table = Table() + table['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + # Should raise because images don't have WCS + with pytest.raises(ValueError, match='must have a wcs attribute'): + extract_stars([nddata1, nddata2], table, size=11) + + def test_validate_skycoord_only_catalog_no_wcs(self): + """ + Test validation when catalog has only skycoord but NDData lacks + WCS. + """ + # Create NDData without WCS + nddata = NDData(np.ones((50, 50))) + + # Create catalog with only skycoord (no x, y columns) + table = Table() + table['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + # Should raise because NDData does not have WCS + with pytest.raises(ValueError, match='must have a wcs attribute'): + extract_stars(nddata, table, size=11) + + def test_validate_multiple_catalogs_skycoord_only_no_wcs(self, simple_wcs): + """ + Test validation when catalog has only skycoord and some NDData + objects lack WCS. + + This tests the branch where the corresponding NDData has WCS, + but another NDData in the list does not have WCS. + """ + nddata1 = NDData(np.ones((50, 50))) + # nddata1 intentionally has no WCS + nddata2 = NDData(np.ones((50, 50))) + nddata2.wcs = simple_wcs # Second image has WCS + + # First catalog has x,y (does not need WCS), second has only + # skycoord + table1 = Table({'x': [25], 'y': [25]}) + table2 = Table() + table2['skycoord'] = [SkyCoord(0, 0, unit='deg')] + + # nddata2 has WCS, but nddata1 does not + with pytest.raises(ValueError, + match='each NDData object must have a wcs'): + extract_stars([nddata1, nddata2], [table1, table2], size=11) + + def test_extract_stars_uncertainties(self, epsf_test_data): + """ + Test extract_stars with various uncertainty types. 
+ """ + rng = np.random.default_rng(seed=0) + shape = epsf_test_data['nddata'].data.shape + error = np.abs(rng.normal(loc=0, scale=1, size=shape)) + uncertainty1 = StdDevUncertainty(error) + uncertainty2 = uncertainty1.represent_as(VarianceUncertainty) + uncertainty3 = uncertainty1.represent_as(InverseVariance) + ndd1 = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty1) + ndd2 = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty2) + ndd3 = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty3) + + size = 25 + ndd_inputs = (ndd1, ndd2, ndd3) + + outputs = [extract_stars(ndd_input, epsf_test_data['init_stars'], + size=size) for ndd_input in ndd_inputs] + + for stars in outputs: + assert len(stars) == len(epsf_test_data['init_stars']) + assert isinstance(stars, EPSFStars) + assert isinstance(stars[0], EPSFStars) + assert stars[0].data.shape == (size, size) + assert stars[0].weights.shape == (size, size) + + assert_allclose(outputs[0].weights, outputs[1].weights) + assert_allclose(outputs[0].weights, outputs[2].weights) + + def test_extract_stars_nonfinite_weights(self, epsf_test_data): + """ + Test extract_stars with sparse zero uncertainty values that create + non-finite weights at specific locations. The stars should still + be extracted successfully, with only the expected warning about + non-finite weights being set to zero. 
+ """ + shape = epsf_test_data['nddata'].data.shape + init = epsf_test_data['init_stars'] + + # Create an uncertainty array with mostly valid (non-zero) values, + # but include some zero uncertainty values at specific locations + # within the star cutout regions to trigger non-finite weights + uncertainty_data = np.ones(shape) + for i in range(min(3, len(init))): + x_pix = int(init['x'][i]) + y_pix = int(init['y'][i]) + # Set a small region around the star center to zero uncertainty + uncertainty_data[y_pix - 2:y_pix + 3, x_pix - 2:x_pix + 3] = 0.0 + + uncertainty = StdDevUncertainty(uncertainty_data) + ndd = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty) + size = 25 + + # Should only get the non-finite weights warning; stars should + # still be extracted successfully + match = 'non-finite weight values' + with pytest.warns(AstropyUserWarning, match=match): + stars = extract_stars(ndd, init[0:3], size=size) + + # All 3 stars should be successfully extracted + assert len(stars) == 3 + for i in range(3): + assert stars[i] is not None + assert stars[i].data.shape == (size, size) + + def test_extract_stars_all_zero_uncertainty(self, epsf_test_data): + """ + Test extract_stars with all-zero uncertainty values. + + When all uncertainty values are zero, all weights become infinite + and are then set to zero, resulting in fully-masked cutouts. This + causes flux estimation to fail because there is no valid data. + """ + shape = epsf_test_data['nddata'].data.shape + + uncertainty = StdDevUncertainty(np.zeros(shape)) + ndd = NDData(epsf_test_data['nddata'].data, uncertainty=uncertainty) + size = 25 + + # With all-zero uncertainty, stars will fail with completely + # masked errors because all weights are set to zero (fully + # masked data). 
+ match1 = 'Star cutout is completely masked' + match2 = 'non-finite weight values' + with (pytest.warns(AstropyUserWarning, match=match1), + pytest.warns(AstropyUserWarning, match=match2)): + stars = extract_stars(ndd, + epsf_test_data['init_stars'][0:3], + size=size) - # Doesn't need to actually contain anything useful - stars = EPSFStars([1]) - # This should not blow up - ForkingPickler.loads(ForkingPickler.dumps(stars)) + # All stars should fail (None) because they are completely masked + assert len(stars) == 0 diff --git a/photutils/psf/tests/test_image_models.py b/photutils/psf/tests/test_image_models.py index f1518b514..506a116ee 100644 --- a/photutils/psf/tests/test_image_models.py +++ b/photutils/psf/tests/test_image_models.py @@ -5,18 +5,9 @@ import numpy as np import pytest -from astropy.modeling.models import Gaussian2D -from astropy.utils.exceptions import AstropyDeprecationWarning from numpy.testing import assert_allclose, assert_equal -from photutils.psf import (CircularGaussianPSF, EPSFModel, FittableImageModel, - ImagePSF) - - -@pytest.fixture(name='gmodel_old') -def fixture_gmodel_old(): - # remove when FittableImageModel is removed - return Gaussian2D(x_stddev=3, y_stddev=3) +from photutils.psf import CircularGaussianPSF, ImagePSF @pytest.fixture(name='gaussian_psf') @@ -206,186 +197,3 @@ def test_str(self, image_psf): assert key in model_str for param in image_psf.param_names: assert param in model_str - - -class TestFittableImageModel: - """ - Tests for FittableImageModel. 
- """ - - def test_fittable_image_model(self, gmodel_old): - yy, xx = np.mgrid[-2:3, -2:3] - with pytest.warns(AstropyDeprecationWarning): - model_nonorm = FittableImageModel(gmodel_old(xx, yy)) - - assert_allclose(model_nonorm(0, 0), gmodel_old(0, 0)) - assert_allclose(model_nonorm(1, 1), gmodel_old(1, 1)) - assert_allclose(model_nonorm(-2, 1), gmodel_old(-2, 1)) - - # subpixel should *not* match, but be reasonably close - # in this case good to ~0.1% seems to be fine - assert_allclose(model_nonorm(0.5, 0.5), gmodel_old(0.5, 0.5), - rtol=.001) - assert_allclose(model_nonorm(-0.5, 1.75), gmodel_old(-0.5, 1.75), - rtol=.001) - - with pytest.warns(AstropyDeprecationWarning): - model_norm = FittableImageModel(gmodel_old(xx, yy), normalize=True) - assert not np.allclose(model_norm(0, 0), gmodel_old(0, 0)) - assert_allclose(np.sum(model_norm(xx, yy)), 1) - - with pytest.warns(AstropyDeprecationWarning): - model_norm2 = FittableImageModel(gmodel_old(xx, yy), - normalize=True, - normalization_correction=2) - assert not np.allclose(model_norm2(0, 0), gmodel_old(0, 0)) - assert_allclose(model_norm(0, 0), model_norm2(0, 0) * 2) - assert_allclose(np.sum(model_norm2(xx, yy)), 0.5) - - def test_fittable_image_model_oversampling(self, gmodel_old): - oversamp = 3 # oversampling factor - yy, xx = np.mgrid[-3:3.00001:(1 / oversamp), -3:3.00001:(1 / oversamp)] - - im = gmodel_old(xx, yy) - assert im.shape[0] > 7 - - with pytest.warns(AstropyDeprecationWarning): - model_oversampled = FittableImageModel(im, oversampling=oversamp) - assert_allclose(model_oversampled(0, 0), gmodel_old(0, 0)) - assert_allclose(model_oversampled(1, 1), gmodel_old(1, 1)) - assert_allclose(model_oversampled(-2, 1), gmodel_old(-2, 1)) - assert_allclose(model_oversampled(0.5, 0.5), gmodel_old(0.5, 0.5), - rtol=.001) - assert_allclose(model_oversampled(-0.5, 1.75), - gmodel_old(-0.5, 1.75), rtol=.001) - - # without oversampling the same tests should fail except for at - # the origin - with 
pytest.warns(AstropyDeprecationWarning): - model_wrongsampled = FittableImageModel(im) - assert_allclose(model_wrongsampled(0, 0), gmodel_old(0, 0)) - assert not np.allclose(model_wrongsampled(1, 1), gmodel_old(1, 1)) - assert not np.allclose(model_wrongsampled(-2, 1), - gmodel_old(-2, 1)) - assert not np.allclose(model_wrongsampled(0.5, 0.5), - gmodel_old(0.5, 0.5), rtol=.001) - assert not np.allclose(model_wrongsampled(-0.5, 1.75), - gmodel_old(-0.5, 1.75), rtol=.001) - - def test_centering_oversampled(self, gmodel_old): - oversamp = 3 - yy, xx = np.mgrid[-3:3.00001:(1 / oversamp), -3:3.00001:(1 / oversamp)] - - with pytest.warns(AstropyDeprecationWarning): - model_oversampled = FittableImageModel(gmodel_old(xx, yy), - oversampling=oversamp) - - valcen = gmodel_old(0, 0) - val36 = gmodel_old(0.66, 0.66) - - assert_allclose(valcen, model_oversampled(0, 0)) - assert_allclose(val36, model_oversampled(0.66, 0.66), rtol=1.0e-6) - - model_oversampled.x_0 = 2.5 - model_oversampled.y_0 = -3.5 - - assert_allclose(valcen, model_oversampled(2.5, -3.5)) - assert_allclose(val36, model_oversampled(2.5 + 0.66, -3.5 + 0.66), - rtol=1.0e-6) - - def test_oversampling_inputs(self): - data = np.arange(30).reshape(5, 6) - for oversampling in [4, (3, 3), (3, 4)]: - with pytest.warns(AstropyDeprecationWarning): - fim = FittableImageModel(data, oversampling=oversampling) - if not hasattr(oversampling, '__len__'): - _oversamp = float(oversampling) - else: - _oversamp = tuple(float(o) for o in oversampling) - assert np.all(fim._oversampling == _oversamp) - - match = 'oversampling must be > 0' - for oversampling in [-1, [-2, 4]]: - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - FittableImageModel(data, oversampling=oversampling) - - match = 'oversampling must have 1 or 2 elements' - oversampling = (1, 4, 8) - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - FittableImageModel(data, 
oversampling=oversampling) - - match = 'oversampling must be 1D' - for oversampling in [((1, 2), (3, 4)), np.ones((2, 2, 2))]: - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - FittableImageModel(data, oversampling=oversampling) - - match = 'oversampling must have integer values' - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - FittableImageModel(data, oversampling=2.1) - - match = 'oversampling must be a finite value' - for oversampling in [np.nan, (1, np.inf)]: - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - FittableImageModel(data, oversampling=oversampling) - - -def test_epsfmodel_inputs(): - data = np.array([[], []]) - match = 'Image data array cannot be zero-sized' - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data) - - data = np.ones((5, 5), dtype=float) - data[2, 2] = np.inf - match = 'must be finite' - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data) - - data[2, 2] = np.nan - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, flux=None) - - data[2, 2] = 1 - match = 'oversampling must be > 0' - for oversampling in [-1, [-2, 4]]: - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, oversampling=oversampling) - - match = 'oversampling must have 1 or 2 elements' - oversampling = (1, 4, 8) - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, oversampling=oversampling) - - match = 'oversampling must have integer values' - oversampling = 2.1 - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, oversampling=oversampling) - - match = 'oversampling must be 1D' - for oversampling in [((1, 2), (3, 
4)), np.ones((2, 2, 2))]: - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, oversampling=oversampling) - - match = 'oversampling must be a finite value' - for oversampling in [np.nan, (1, np.inf)]: - with (pytest.raises(ValueError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, oversampling=oversampling) - - origin = (1, 2, 3) - match = 'Parameter "origin" must be either None or an iterable with' - with (pytest.raises(TypeError, match=match), - pytest.warns(AstropyDeprecationWarning)): - EPSFModel(data, origin=origin) diff --git a/photutils/psf/tests/test_utils.py b/photutils/psf/tests/test_utils.py index 36a32ef36..b8cfbcf5b 100644 --- a/photutils/psf/tests/test_utils.py +++ b/photutils/psf/tests/test_utils.py @@ -150,6 +150,55 @@ def test_interpolate_missing_data(): _interpolate_missing_data(data, mask, method='invalid') +def test_interpolate_missing_data_edge_pixels(): + """ + Test that edge pixels are always filled with cubic interpolation. + """ + data = np.arange(100, dtype=float).reshape(10, 10) + mask = np.zeros_like(data, dtype=bool) + + # Mask corner and edge pixels where cubic interpolation typically + # fails + mask[0, 0] = True # corner + mask[0, 5] = True # top edge + mask[9, 9] = True # corner + mask[5, 9] = True # right edge + + data_int = _interpolate_missing_data(data, mask, method='cubic') + + # All masked pixels should be filled (no NaN values) + assert np.all(np.isfinite(data_int)) + assert not np.any(np.isnan(data_int[mask])) + + +def test_interpolate_missing_data_no_mask(): + """ + Test that data is returned unchanged when no pixels are masked. + """ + data = np.arange(100, dtype=float).reshape(10, 10) + mask = np.zeros_like(data, dtype=bool) + + data_int = _interpolate_missing_data(data, mask, method='cubic') + assert np.array_equal(data, data_int) + + +def test_interpolate_missing_data_all_masked(): + """ + Test that all-masked data returns NaN array. 
+ """ + data = np.arange(100, dtype=float).reshape(10, 10) + mask = np.ones_like(data, dtype=bool) # All pixels masked + + data_int = _interpolate_missing_data(data, mask, method='cubic') + + # All values should be NaN when all data is masked + assert np.all(np.isnan(data_int)) + + # Same for nearest-neighbor method + data_int = _interpolate_missing_data(data, mask, method='nearest') + assert np.all(np.isnan(data_int)) + + def test_validate_psf_model(): model = np.arange(10) diff --git a/photutils/psf/utils.py b/photutils/psf/utils.py index 3a1dc7cd4..b186aaeb3 100644 --- a/photutils/psf/utils.py +++ b/photutils/psf/utils.py @@ -363,7 +363,9 @@ def _interpolate_missing_data(data, mask, method='cubic'): The method of used to interpolate the missing data: * ``'cubic'``: Masked data are interpolated using 2D cubic - splines. This is the default. + splines. If any masked pixels cannot be interpolated using + cubic interpolation (e.g., at the edges), they will be filled + using nearest-neighbor interpolation as a fallback. * ``'nearest'``: Masked data are interpolated using nearest-neighbor interpolation. @@ -371,7 +373,9 @@ def _interpolate_missing_data(data, mask, method='cubic'): Returns ------- data_interp : 2D `~numpy.ndarray` - The interpolated 2D image. + The interpolated 2D image. All masked pixels are guaranteed + to be filled if there are any valid (unmasked) pixels. If all + pixels are masked, the returned array will contain NaN values. 
""" data_interp = np.copy(data) @@ -383,6 +387,14 @@ def _interpolate_missing_data(data, mask, method='cubic'): msg = 'mask and data must have the same shape' raise ValueError(msg) + if not np.any(mask): + return data_interp + + # Check if all pixels are masked - cannot interpolate + if np.all(mask): + data_interp[:] = np.nan + return data_interp + # initialize the interpolator y, x = np.indices(data_interp.shape) xy = np.dstack((x[~mask].ravel(), y[~mask].ravel()))[0] @@ -400,6 +412,19 @@ def _interpolate_missing_data(data, mask, method='cubic'): xy_missing = np.dstack((x[mask].ravel(), y[mask].ravel()))[0] data_interp[mask] = interpol(xy_missing) + # For cubic interpolation, some edge pixels may not be interpolated + # (NaN values). Use nearest-neighbor interpolation as a fallback. + if method == 'cubic': + remaining_mask = ~np.isfinite(data_interp) + if np.any(remaining_mask): + xy_valid = np.dstack((x[~remaining_mask].ravel(), + y[~remaining_mask].ravel()))[0] + z_valid = data_interp[~remaining_mask].ravel() + interpol_nn = interpolate.NearestNDInterpolator(xy_valid, z_valid) + xy_remaining = np.dstack((x[remaining_mask].ravel(), + y[remaining_mask].ravel()))[0] + data_interp[remaining_mask] = interpol_nn(xy_remaining) + return data_interp diff --git a/pyproject.toml b/pyproject.toml index dc10d87e1..8186f3960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,8 @@ docs = [ 'photutils[all]', 'sphinx >= 8.2', # keep in sync with docs/conf.py 'sphinx-astropy[confv2] >= 1.9.1', - 'sphinx_design', + 'sphinx_design >= 0.6', + 'sphinx-reredirects >= 1.1', ] dev = [ 'photutils[docs,test]',