# Source code for turbustat.statistics.wavelets.wavelet_transform
# Licensed under an MIT open source license - see LICENSE
from __future__ import print_function, absolute_import, division
import numpy as np
import warnings
import astropy.units as u
import statsmodels.api as sm
from warnings import warn
from astropy.utils.console import ProgressBar
from astropy.convolution import convolve_fft
# Use updated kernel name.
try:
from astropy.convolution import RickerWavelet2DKernel
except ImportError:
from astropy.convolution import MexicanHat2DKernel
RickerWavelet2DKernel = MexicanHat2DKernel
try:
from pyfftw.interfaces.numpy_fft import fftn, ifftn
PYFFTW_FLAG = True
except ImportError:
PYFFTW_FLAG = False
from ..base_statistic import BaseStatisticMixIn
from ...io import common_types, twod_types
from ..fitting_utils import check_fit_limits, residual_bootstrap
from ..lm_seg import Lm_Seg
class Wavelet(BaseStatisticMixIn):
    '''
    Compute the wavelet transform of a 2D array.

    Parameters
    ----------
    data : %(dtypes)s
        2D data.
    header : FITS header, optional
        Header for the array.
    scales : `~astropy.units.Quantity`, optional
        The scales where the transform is calculated. Must be convertible
        to pixel units. When not given, `num` logarithmically-spaced scales
        between 5/3 pixels and half of the smallest image dimension are
        used.
    num : int, optional
        Number of scales to compute the transform at. Only used when
        `scales` is not given.
    distance : `~astropy.units.Quantity`, optional
        Physical distance to the region in the data.
    '''

    # NOTE: the class docstring is %-formatted below, so it must not contain
    # any other percent characters.
    __doc__ %= {"dtypes": " or ".join(common_types + twod_types)}

    def __init__(self, data, header=None, scales=None, num=50,
                 distance=None):

        # Provided by the mix-in; `self.data` is read immediately afterwards,
        # so it is presumably set here -- confirm in BaseStatisticMixIn.
        self.input_data_header(data, header)

        # NOTE: can't use nan_interpolating from astropy
        # until the normalization for sum to zeros kernels is fixed!!!
        isnan = np.isnan(self.data)
        if isnan.any():
            # Copy first so the caller's array is not modified in place.
            self.data = self.data.copy()
            self.data[isnan] = 0.

        if distance is not None:
            self.distance = distance

        if scales is None:
            a_min = round((5. / 3.), 3)  # Smallest scale given by paper
            a_max = min(self.data.shape) / 2.
            # Log spaces scales up to half of the smallest size of the array
            scales = \
                np.logspace(np.log10(a_min), np.log10(a_max), num) * u.pix

        self.scales = scales

    @property
    def scales(self):
        '''
        Wavelet scales.
        '''
        return self._scales

    @scales.setter
    def scales(self, values):

        if not isinstance(values, u.Quantity):
            raise TypeError("scales must be given as a "
                            "astropy.units.Quantity.")

        # Make sure that we can convert into pixels before setting. Any
        # conversion error propagates to the caller unchanged.
        pix_scal = self._to_pixel(values)

        # The radius should be larger than a pixel
        if np.any(pix_scal.value < 1):
            raise ValueError("One of the chosen scales is smaller than one "
                             "pixel."
                             " Ensure that all scale values are larger than "
                             "one pixel.")

        # Catch floating point issues in comparing to half the image shape
        half_comp = (np.floor(pix_scal.value) - min(self.data.shape) / 2.)

        if np.any(half_comp > 1e-10):
            raise ValueError("At least one of the scales is larger than half"
                             " of the image size. Remove these scales from "
                             "the array.")

        self._scales = values

    def compute_transform(self, show_progress=True, scale_normalization=True,
                          keep_convolved_arrays=False, convolve_kwargs=None,
                          use_pyfftw=False, threads=1, pyfftw_kwargs=None):
        '''
        Compute the wavelet transform at each scale.

        Parameters
        ----------
        show_progress : bool, optional
            Show a progress bar while looping over the scales.
        scale_normalization: bool, optional
            Compute the transform with the correct scale-invariant
            normalization.
        keep_convolved_arrays: bool, optional
            Keep the image convolved at all wavelet scales. For large images,
            this can require a large amount memory. Default is False.
        convolve_kwargs : dict, optional
            Passed to `~astropy.convolution.convolve_fft`.
        use_pyfftw : bool, optional
            Enable to use pyfftw, if it is installed.
        threads : int, optional
            Number of threads to use in FFT when using pyfftw.
        pyfftw_kwargs : dict, optional
            Extra keyword arguments passed to the pyfftw FFT functions. See
            `here <http://hgomersall.github.io/pyFFTW/pyfftw/builders/builders.html>`_
            for a list of accepted kwargs.
        '''

        if convolve_kwargs is None:
            convolve_kwargs = {}
        if pyfftw_kwargs is None:
            pyfftw_kwargs = {}

        # Default to the numpy FFTs; swap in pyfftw only when requested and
        # importable.
        use_fftn = np.fft.fftn
        use_ifftn = np.fft.ifftn

        if use_pyfftw:
            if PYFFTW_FLAG:
                from functools import partial

                # Bind the threading/planner options onto the pyfftw
                # functions, since convolve_fft calls them with the array
                # only. An explicit 'threads' in pyfftw_kwargs wins.
                fft_kwargs = dict(pyfftw_kwargs)
                fft_kwargs.setdefault('threads', threads)

                use_fftn = partial(fftn, **fft_kwargs)
                use_ifftn = partial(ifftn, **fft_kwargs)
            else:
                warn("pyfftw not installed. Using numpy.fft functions.")

        n0, m0 = self.data.shape
        pix_scales = self._to_pixel(self.scales).value

        if keep_convolved_arrays:
            self._Wf = np.zeros((len(pix_scales), n0, m0), dtype=float)
        else:
            self._Wf = None

        self._values = np.empty_like(self.scales.value)
        self._stddev = np.empty_like(self.scales.value)

        # Power of the scale that the convolved image is multiplied by.
        factor = 2
        if not scale_normalization:
            factor = 4
            warn("Transform values are only reliable with the proper scale"
                 " normalization. When disabled, the slope of the transform"
                 " CANNOT be used for physical interpretation.")

        bar = ProgressBar(len(pix_scales)) if show_progress else None

        for i, an in enumerate(pix_scales):
            psi = RickerWavelet2DKernel(an)

            conv_arr = \
                convolve_fft(self.data, psi, normalize_kernel=False,
                             fftn=use_fftn, ifftn=use_ifftn,
                             nan_treatment='fill',
                             preserve_nan=True,
                             **convolve_kwargs).real * \
                an**factor

            if keep_convolved_arrays:
                self._Wf[i] = conv_arr

            # Only the positive part of the transform is averaged.
            positive = conv_arr[conv_arr > 0]
            self._values[i] = positive.mean()

            # The standard deviation should take into account the number of
            # kernel elements at that scale.
            kern_area = np.ceil(0.5 * np.pi * np.log(2) * an**2).astype(int)
            nindep = np.sqrt(np.isfinite(conv_arr).sum() // kern_area)

            self._stddev[i] = positive.std() / nindep

            if bar is not None:
                bar.update(i + 1)

    @property
    def Wf(self):
        '''
        The wavelet transforms of the image. Each plane is the transform at
        different wavelet sizes. `None` (with a warning) when
        `keep_convolved_arrays` was disabled in `compute_transform`.
        '''
        if self._Wf is None:
            warn("`keep_convolved_arrays` was disabled in "
                 "`compute_transform`.")
        return self._Wf

    @property
    def values(self):
        '''
        The 1-dimensional wavelet transform.
        '''
        return self._values

    @property
    def stddev(self):
        '''
        Standard deviation of the 1-dimensional wavelet transform.
        '''
        return self._stddev

    def fit_transform(self, xlow=None, xhigh=None, brk=None, min_fits_pts=3,
                      weighted_fit=False, bootstrap=False,
                      bootstrap_kwargs=None, **fit_kwargs):
        '''
        Perform a fit to the transform in log-log space.

        Parameters
        ----------
        xlow : `~astropy.units.Quantity`, optional
            Lower scale value to consider in the fit.
        xhigh : `~astropy.units.Quantity`, optional
            Upper scale value to consider in the fit.
        brk : `~astropy.units.Quantity`, optional
            Give an initial guess for a break point. This enables fitting
            with a `turbustat.statistics.Lm_Seg`.
        min_fits_pts : int, optional
            Minimum number of points required below the fitted break
            for it to be considered a valid fit. Only used when a segmented
            line is fit, i.e. when a value for `brk` is given.
        weighted_fit: bool, optional
            Use the `~Wavelet.stddev` to perform a weighted fit.
        bootstrap : bool, optional
            Bootstrap using the model residuals to estimate the standard
            errors.
        bootstrap_kwargs : dict, optional
            Pass keyword arguments to
            `~turbustat.statistics.fitting_utils.residual_bootstrap`.
        fit_kwargs : Passed to `turbustat.statistics.Lm_Seg.fit_model`

        Raises
        ------
        ValueError
            If `brk` lies outside of the fitting range.
        '''

        if bootstrap_kwargs is None:
            bootstrap_kwargs = {}

        pix_scales = self._to_pixel(self.scales)
        x = np.log10(pix_scales.value)
        y = np.log10(self.values)

        if weighted_fit:
            # Propagate the uncertainty into log10 space
            # (0.434 ~ 1 / ln(10)).
            y_err = 0.434 * self.stddev / self.values
            # Zero errors would give infinite weights; flag them as missing.
            y_err[y_err == 0.] = np.nan
            weights = y_err**-2
        else:
            weights = None

        if xlow is not None:
            xlow = self._to_pixel(xlow)
            lower_limit = x >= np.log10(xlow.value)
        else:
            lower_limit = np.ones_like(self.values, dtype=bool)
            # Pad slightly so the smallest scale falls inside the range.
            xlow = pix_scales.min() * 0.99

        if xhigh is not None:
            xhigh = self._to_pixel(xhigh)
            upper_limit = x <= np.log10(xhigh.value)
        else:
            upper_limit = np.ones_like(self.values, dtype=bool)
            xhigh = pix_scales.max() * 1.01

        self._fit_range = [xlow, xhigh]

        within_limits = np.logical_and(lower_limit, upper_limit)

        y = y[within_limits]
        x = x[within_limits]

        if weighted_fit:
            weights = weights[within_limits]

        if brk is not None:
            # Try fitting a segmented model
            pix_brk = self._to_pixel(brk)

            if pix_brk < xlow or pix_brk > xhigh:
                raise ValueError("brk must be within xlow and xhigh.")

            model = Lm_Seg(x, y, np.log10(pix_brk.value), weights=weights)

            fit_kwargs['cov_type'] = 'HC3'
            model.fit_model(**fit_kwargs)

            self.fit = model.fit

            if model.params.size == 5:

                # Check to make sure this leaves enough to fit to.
                if sum(x < model.brk) < min_fits_pts:
                    warnings.warn("Not enough points to fit to." +
                                  " Ignoring break.")

                    self._brk = None
                else:
                    # The fit is performed on log10 of the pixel scales, so
                    # the break point is a scale in pixels.
                    self._brk = 10**model.brk * u.pix

                    self._slope = model.slopes

                    if bootstrap:
                        stderrs = residual_bootstrap(model.fit,
                                                     **bootstrap_kwargs)

                        self._slope_err = stderrs[1:-1]
                        brk_log_err = stderrs[-1]
                    else:
                        self._slope_err = model.slope_errs
                        brk_log_err = model.brk_err

                    # Convert the error on log10(brk) into linear units.
                    self._brk_err = np.log(10) * self.brk.value * \
                        brk_log_err * u.pix

                    self.fit = model.fit

            else:
                self._brk = None
                # Break fit failed, revert to normal model
                warnings.warn("Model with break failed, reverting to model "
                              "without break.")
        else:
            self._brk = None

        # Revert to model without break if none is given, or if the segmented
        # model failed.
        if self.brk is None:

            # No break point, so there is no break error either. This also
            # clears a stale value from a previous segmented fit.
            self._brk_err = None

            x = sm.add_constant(x)

            if weighted_fit:
                model = sm.WLS(y, x, missing='drop', weights=weights)
            else:
                model = sm.OLS(y, x, missing='drop')

            self.fit = model.fit(cov_type='HC3')

            self._slope = self.fit.params[1]

            if bootstrap:
                stderrs = residual_bootstrap(self.fit,
                                             **bootstrap_kwargs)
                self._slope_err = stderrs[1]
            else:
                self._slope_err = self.fit.bse[1]

        self._model = model

        self._bootstrap_flag = bootstrap

    @property
    def slope(self):
        '''
        Fitted slope.
        '''
        return self._slope

    @property
    def slope_err(self):
        '''
        Standard error on the fitted slope.
        '''
        return self._slope_err

    @property
    def brk(self):
        '''
        Break point in the segmented linear model. `None` when no break was
        fit.
        '''
        return self._brk

    @property
    def brk_err(self):
        '''
        1-sigma on the break point in the segmented linear model. `None`
        when no break was fit.
        '''
        return self._brk_err

    @property
    def fit_range(self):
        '''
        Range of scales used in the fit.
        '''
        return self._fit_range

    def fitted_model(self, xvals):
        '''
        Computes the fitted power-law in log-log space using the
        given x values.

        Parameters
        ----------
        xvals : `~numpy.ndarray`
            Values of log(scales) to compute the model at (base 10 log).

        Returns
        -------
        model_values : `~numpy.ndarray`
            Values of the model at the given values.
        '''

        if isinstance(self._model, Lm_Seg):
            return self._model.model(xvals)

        # Plain linear model: intercept + slope * x
        return self.fit.params[0] + self.fit.params[1] * xvals

    def plot_transform(self, save_name=None, xunit=u.pix,
                       color='r', symbol='o', fit_color='k',
                       label=None, show_residual=True):
        '''
        Plot the transform and the fit.

        Parameters
        ----------
        save_name : str, optional
            Save name for the figure. Enables saving the plot.
        xunit : `~astropy.units.Unit`, optional
            Unit that the scales are converted to for the x-axis.
        color : {str, RGB tuple}, optional
            Color to plot the wavelet curve.
        symbol : str, optional
            Symbol to use for the data.
        fit_color : {str, RGB tuple}, optional
            Color of the 1D fit. When `None`, `color` is used.
        label : str, optional
            Label to later be used in a legend.
        show_residual : bool, optional
            Plot the fit residuals.
        '''

        import matplotlib.pyplot as plt

        if fit_color is None:
            fit_color = color

        # Check for already existing subplots
        fig = plt.gcf()
        axes = fig.get_axes()

        ax_r = None
        if len(axes) == 0:
            if show_residual:
                ax = plt.subplot2grid((4, 1), (0, 0), colspan=1, rowspan=3)
                ax_r = plt.subplot2grid((4, 1), (3, 0), colspan=1,
                                        rowspan=1,
                                        sharex=ax)
            else:
                ax = plt.subplot(111)
        elif len(axes) == 1:
            ax = axes[0]
        else:
            ax = axes[0]
            ax_r = axes[1]

        if show_residual and ax_r is None:
            # An existing figure with a single axis has nowhere to draw the
            # residuals; skip them instead of failing on an undefined axis.
            warn("The current figure has no axis for the residuals. "
                 "Residuals will not be plotted.")
            show_residual = False

        ax.set_xscale("log")
        ax.set_yscale("log")

        pix_scales = self._to_pixel(self.scales)
        scales = self._spatial_unit_conversion(pix_scales, xunit).value

        # Check for NaNs
        fin_vals = np.logical_or(np.isfinite(self.values),
                                 np.isfinite(self.stddev))

        ax.errorbar(scales[fin_vals], self.values[fin_vals],
                    yerr=self.stddev[fin_vals],
                    fmt=symbol + "-", color=color,
                    label=label,
                    markersize=5, alpha=0.5, capsize=10,
                    elinewidth=3)

        # Plot the fit within the fitting range.
        low_lim = \
            self._spatial_unit_conversion(self._fit_range[0], xunit).value
        high_lim = \
            self._spatial_unit_conversion(self._fit_range[1], xunit).value

        # The model is evaluated in pixel units, matching the fit.
        model_vals = 10**self.fitted_model(np.log10(pix_scales.value))

        ax.loglog(scales, model_vals,
                  '--', color=fit_color,
                  linewidth=3)

        ax.axvline(low_lim, color=color, alpha=0.5, linestyle='-')
        ax.axvline(high_lim, color=color, alpha=0.5, linestyle='-')

        ax.grid()

        ax.set_ylabel(r"$T_g$")

        if show_residual:
            resids = self.values - model_vals

            # Apply the same mask to x, y and the errors so their shapes
            # always agree.
            ax_r.errorbar(scales[fin_vals], resids[fin_vals],
                          yerr=self.stddev[fin_vals],
                          fmt=symbol + "-", color=color, label=label,
                          markersize=5, alpha=0.5, capsize=10,
                          elinewidth=3)

            ax_r.axvline(low_lim, color=color, alpha=0.5, linestyle='-')
            ax_r.axvline(high_lim, color=color, alpha=0.5, linestyle='-')

            ax_r.axhline(0., color=fit_color, linestyle='--')

            ax_r.grid()

            ax_r.set_ylabel("Residuals")

            ax_r.set_xlabel("Scales ({})".format(xunit))

            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            ax.set_xlabel("Scales ({})".format(xunit))

        plt.tight_layout()

        fig.subplots_adjust(hspace=0.1)

        if save_name is not None:
            plt.savefig(save_name)
            plt.close()
        else:
            plt.show()

    def run(self, show_progress=True, verbose=False, xunit=u.pix,
            convolve_kwargs=None,
            use_pyfftw=False, threads=1,
            pyfftw_kwargs=None, scale_normalization=True,
            xlow=None, xhigh=None, brk=None, fit_kwargs=None,
            save_name=None, **plot_kwargs):
        '''
        Compute the Wavelet transform and fit it.

        Parameters
        ----------
        show_progress : bool, optional
            Show a progress bar while looping over the scales.
        verbose : bool, optional
            Print the fit summary and plot the wavelet transform.
        xunit : u.Unit, optional
            Unit that the scales are converted to in the plot.
        convolve_kwargs : dict, optional
            Passed to `~astropy.convolution.convolve_fft`.
        use_pyfftw : bool, optional
            Enable to use pyfftw, if it is installed.
        threads : int, optional
            Number of threads to use in FFT when using pyfftw.
        pyfftw_kwargs : dict, optional
            Extra keyword arguments passed to the pyfftw FFT functions. See
            `here <http://hgomersall.github.io/pyFFTW/pyfftw/builders/builders.html>`_
            for a list of accepted kwargs.
        scale_normalization: bool, optional
            Compute the transform with the correct scale-invariant
            normalization.
        xlow : `~astropy.units.Quantity`, optional
            Lower scale value to consider in the fit.
        xhigh : `~astropy.units.Quantity`, optional
            Upper scale value to consider in the fit.
        brk : `~astropy.units.Quantity`, optional
            Give an initial guess for a break point. This enables fitting
            with a `turbustat.statistics.Lm_Seg`.
        fit_kwargs : dict, optional
            Passed to `~Wavelet.fit_transform`
        save_name : str,optional
            Save the figure when a file name is given.
        plot_kwargs : Passed to `~Wavelet.plot_transform`.

        Returns
        -------
        self : `Wavelet`
            The instance, with the transform and fit results set.
        '''

        if fit_kwargs is None:
            fit_kwargs = {}

        self.compute_transform(scale_normalization=scale_normalization,
                               convolve_kwargs=convolve_kwargs,
                               use_pyfftw=use_pyfftw, threads=threads,
                               pyfftw_kwargs=pyfftw_kwargs,
                               show_progress=show_progress)

        self.fit_transform(xlow=xlow, xhigh=xhigh, brk=brk, **fit_kwargs)

        if verbose:
            print(self.fit.summary())

            if self._bootstrap_flag:
                print("Bootstrapping used to find stderrs! "
                      "Errors may not equal those shown above.")

            self.plot_transform(save_name=save_name, xunit=xunit,
                                **plot_kwargs)

        return self
class Wavelet_Distance(object):
    '''
    Compute the distance between the two cubes using the Wavelet transform.
    We fit a linear model to the two wavelet transforms. The distance is the
    t-statistic of the interaction term describing the difference in the
    slopes.

    Parameters
    ----------
    dataset1 : %(dtypes)s
        2D image. An already constructed `Wavelet` may also be given; it is
        only re-run when it has no fitted slope.
    dataset2 : %(dtypes)s
        2D image. See `dataset1`.
    scales : `~astropy.units.Quantity`, optional
        The scales where the transform is calculated.
    num : int
        Number of scales to calculate the transform at. Only used when
        `scales` is not given.
    xlow : `astropy.units.Quantity`, optional
        The lower lag fitting limit. An array with 2 elements can be passed to
        give separate lower limits for the datasets.
    xhigh : `astropy.units.Quantity`, optional
        The upper lag fitting limit. See `xlow` above.
    fit_kwargs : dict, optional
        Passed to `~turbustat.statistics.Wavelet.run`.
    fit_kwargs2 : dict, optional
        Passed to `~turbustat.statistics.Wavelet.run` for `dataset2`. When
        `None` is given, `fit_kwargs` is used for `dataset2`.
    '''

    # NOTE: the class docstring is %-formatted below, so it must not contain
    # any other percent characters.
    __doc__ %= {"dtypes": " or ".join(common_types + twod_types)}

    def __init__(self, dataset1, dataset2,
                 scales=None, num=50, xlow=None, xhigh=None,
                 fit_kwargs=None, fit_kwargs2=None):
        super(Wavelet_Distance, self).__init__()

        if fit_kwargs is None:
            fit_kwargs = {}

        if fit_kwargs2 is None:
            fit_kwargs2 = fit_kwargs

        # The returned limits are indexed per dataset below.
        xlow, xhigh = check_fit_limits(xlow, xhigh)

        self.wt1 = self._fitted_wavelet(dataset1, "dataset1", scales, num,
                                        xlow[0], xhigh[0], fit_kwargs)

        self.wt2 = self._fitted_wavelet(dataset2, "dataset2", scales, num,
                                        xlow[1], xhigh[1], fit_kwargs2)

    def _fitted_wavelet(self, dataset, label, scales, num, xlow, xhigh,
                        run_kwargs):
        '''
        Return a `Wavelet` for `dataset` that has a fitted slope, running
        the transform and fit only when they are missing.
        '''

        if isinstance(dataset, Wavelet):
            wavelet = dataset
            needs_run = not hasattr(wavelet, '_slope')
            if needs_run:
                warn("Wavelet class passed as `{}` does not have a "
                     "fitted slope. Computing Wavelet transform."
                     .format(label))
        else:
            wavelet = Wavelet(dataset, scales=scales, num=num)
            needs_run = True

        if needs_run:
            wavelet.run(xlow=xlow, xhigh=xhigh, **run_kwargs)

        return wavelet

    def distance_metric(self, verbose=False, xunit=u.pix,
                        save_name=None, plot_kwargs1=None,
                        plot_kwargs2=None):
        '''
        Implements the distance metric for 2 wavelet transforms.
        We fit the linear portion of the transform to represent the powerlaw

        Parameters
        ----------
        verbose : bool, optional
            Enables plotting.
        xunit : `~astropy.units.Unit`, optional
            Unit of the x-axis in the plot in pixel, angular, or
            physical units.
        save_name : str, optional
            Name of the save file. Enables saving the figure.
        plot_kwargs1 : dict, optional
            Pass kwargs to `~turbustat.statistics.Wavelet.plot_transform` for
            `dataset1`.
        plot_kwargs2 : dict, optional
            Pass kwargs to `~turbustat.statistics.Wavelet.plot_transform` for
            `dataset2`.

        Returns
        -------
        self : `Wavelet_Distance`
            The instance, with `distance` set.
        '''

        # Construct t-statistic
        self.distance = \
            np.abs((self.wt1.slope - self.wt2.slope) /
                   np.sqrt(self.wt1.slope_err**2 +
                           self.wt2.slope_err**2))

        if verbose:

            print(self.wt1.fit.summary())
            print(self.wt2.fit.summary())

            import matplotlib.pyplot as plt

            # Work on copies so that neither the caller's dictionaries nor
            # shared defaults are altered.
            plot_kwargs1 = dict(plot_kwargs1) if plot_kwargs1 else {}
            plot_kwargs2 = dict(plot_kwargs2) if plot_kwargs2 else {}

            defaults1 = {'color': 'b', 'symbol': 'D', 'label': '1'}
            defaults2 = {'color': 'g', 'symbol': 'o', 'label': '2'}

            for key in defaults1:
                plot_kwargs1.setdefault(key, defaults1[key])
            for key in defaults2:
                plot_kwargs2.setdefault(key, defaults2[key])

            # `xunit` is passed explicitly below; drop duplicates.
            plot_kwargs1.pop('xunit', None)
            plot_kwargs2.pop('xunit', None)

            self.wt1.plot_transform(xunit=xunit,
                                    **plot_kwargs1)
            self.wt2.plot_transform(xunit=xunit,
                                    **plot_kwargs2)

            axes = plt.gcf().get_axes()
            axes[0].legend(loc='best', frameon=True)

            if save_name is not None:
                plt.savefig(save_name)
                plt.close()
            else:
                plt.show()

        return self