# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
Tabular models.

Tabular models of any dimension can be created using `tabular_model`.
For convenience `Tabular1D` and `Tabular2D` are provided.

Examples
--------
>>> table = np.array([[ 3.,  0.,  0.],
...                   [ 0.,  2.,  0.],
...                   [ 0.,  0.,  0.]])
>>> points = ([1, 2, 3], [1, 2, 3])
>>> t2 = Tabular2D(points, lookup_table=table, bounds_error=False,
...                fill_value=None, method='nearest')

"""

# pylint: disable=invalid-name

import numpy as np

from astropy import units as u

from .core import Model

try:
    from scipy.interpolate import interpn

    has_scipy = True
except ImportError:
    has_scipy = False

__all__ = ["tabular_model", "Tabular1D", "Tabular2D"]

__doctest_requires__ = {"tabular_model": ["scipy"]}


class _Tabular(Model):
    """
    Returns an interpolated lookup table value.

    Parameters
    ----------
    points : tuple of ndarray of float, optional
        The points defining the regular grid in n dimensions.
        ndarray must have shapes (m1,), ..., (mn,). If not given, each
        axis defaults to ``np.arange(mi, dtype=float)``.
    lookup_table : array-like
        The data on a regular grid in n dimensions.
        Must have shape (m1, ..., mn, ...).
    method : str, optional
        The method of interpolation to perform. Supported are "linear",
        "nearest", and "splinef2d". "splinef2d" is only supported for
        2-dimensional data. Default is "linear".
    bounds_error : bool, optional
        If True, when interpolated values are requested outside of the
        domain of the input data, a ValueError is raised.
        If False, then ``fill_value`` is used.
    fill_value : float or `~astropy.units.Quantity`, optional
        If provided, the value to use for points outside of the
        interpolation domain. If None, values outside
        the domain are extrapolated.  Extrapolation is not supported by method
        "splinef2d". If Quantity is given, it will be converted to the unit of
        ``lookup_table``, if applicable.

    Returns
    -------
    value : ndarray
        Interpolated values at input coordinates.

    Raises
    ------
    ImportError
        Scipy is not installed.
    ValueError
        On construction, if ``lookup_table`` is missing or has the wrong
        number of dimensions, if ``points`` does not match the table's
        dimensions or mixes units, or if a Quantity ``fill_value`` is given
        for a unitless table.
    NotImplementedError
        On construction, if ``n_models`` > 1.

    Notes
    -----
    Uses `scipy.interpolate.interpn`.

    """

    linear = False
    fittable = False

    # ``evaluate`` broadcasts its inputs itself (``np.broadcast_arrays``).
    standard_broadcasting = False
    _is_dynamic = True

    # Counter used by ``tabular_model`` to generate unique class names.
    _id = 0

    def __init__(
        self,
        points=None,
        lookup_table=None,
        method="linear",
        bounds_error=True,
        fill_value=np.nan,
        **kwargs,
    ):
        n_models = kwargs.get("n_models", 1)
        if n_models > 1:
            raise NotImplementedError("Only n_models=1 is supported.")
        super().__init__(**kwargs)
        self.outputs = ("y",)
        if lookup_table is None:
            raise ValueError("Must provide a lookup table.")

        if not isinstance(lookup_table, u.Quantity):
            lookup_table = np.asarray(lookup_table)

        # ``self.lookup_table`` here is still the class-level placeholder
        # table, whose ndim fixes the dimensionality of this model class.
        if self.lookup_table.ndim != lookup_table.ndim:
            raise ValueError(
                "lookup_table should be an array with "
                f"{self.lookup_table.ndim} dimensions."
            )

        if points is None:
            # Default grid: integer positions along every axis, as floats.
            points = tuple(np.arange(x, dtype=float) for x in lookup_table.shape)
        else:
            # A bare 1D sequence is accepted for 1D tables and wrapped so
            # that ``points`` is always a tuple of per-axis grids.
            if lookup_table.ndim == 1 and not isinstance(points, tuple):
                points = (points,)
            npts = len(points)
            if npts != lookup_table.ndim:
                raise ValueError(
                    "Expected grid points in "
                    f"{lookup_table.ndim} directions, got {npts}."
                )
            if (
                npts > 1
                and isinstance(points[0], u.Quantity)
                and len({getattr(p, "unit", None) for p in points}) > 1
            ):
                raise ValueError("points must all have the same unit.")

        if isinstance(fill_value, u.Quantity):
            if not isinstance(lookup_table, u.Quantity):
                raise ValueError(
                    f"fill value is in {fill_value.unit} but expected to be unitless."
                )
            # interpn receives a plain number in the table's unit.
            fill_value = fill_value.to(lookup_table.unit).value

        self.points = points
        self.lookup_table = lookup_table
        self.bounds_error = bounds_error
        self.method = method
        self.fill_value = fill_value

    def __repr__(self):
        return (
            f"<{self.__class__.__name__}(points={self.points}, "
            f"lookup_table={self.lookup_table})>"
        )

    def __str__(self):
        default_keywords = [
            ("Model", self.__class__.__name__),
            ("Name", self.name),
            ("N_inputs", self.n_inputs),
            ("N_outputs", self.n_outputs),
            ("Parameters", ""),
            ("  points", self.points),
            ("  lookup_table", self.lookup_table),
            ("  method", self.method),
            ("  fill_value", self.fill_value),
            ("  bounds_error", self.bounds_error),
        ]

        # Entries whose value is None (e.g. an unset name) are omitted.
        parts = [
            f"{keyword}: {value}"
            for keyword, value in default_keywords
            if value is not None
        ]

        return "\n".join(parts)

    @property
    def input_units(self):
        """Unit of the grid points for every input, or None if unitless."""
        pts = self.points[0]
        if not isinstance(pts, u.Quantity):
            return None
        return {x: pts.unit for x in self.inputs}

    @property
    def return_units(self):
        """Unit of the lookup table for the output, or None if unitless."""
        if not isinstance(self.lookup_table, u.Quantity):
            return None
        return {self.outputs[0]: self.lookup_table.unit}

    @property
    def bounding_box(self):
        """
        Tuple defining the default ``bounding_box`` limits,
        ``(points_low, points_high)``.

        Examples
        --------
        >>> from astropy.modeling.models import Tabular1D, Tabular2D
        >>> t1 = Tabular1D(points=[1, 2, 3], lookup_table=[10, 20, 30])
        >>> t1.bounding_box
        ModelBoundingBox(
            intervals={
                x: Interval(lower=1, upper=3)
            }
            model=Tabular1D(inputs=('x',))
            order='C'
        )
        >>> t2 = Tabular2D(points=[[1, 2, 3], [2, 3, 4]],
        ...                lookup_table=[[10, 20, 30], [20, 30, 40]])
        >>> t2.bounding_box
        ModelBoundingBox(
            intervals={
                x: Interval(lower=1, upper=3)
                y: Interval(lower=2, upper=4)
            }
            model=Tabular2D(inputs=('x', 'y'))
            order='C'
        )

        """
        # Limits are listed in reverse axis order (the bounding box is
        # reported with order='C' in the examples above).
        bbox = [(min(p), max(p)) for p in self.points][::-1]
        if len(bbox) == 1:
            bbox = bbox[0]
        return bbox

    def evaluate(self, *inputs):
        """
        Return the interpolated values at the input coordinates.

        Parameters
        ----------
        inputs : list of scalar or list of ndarray
            Input coordinates. The number of inputs must be equal
            to the dimensions of the lookup table.

        Returns
        -------
        result : ndarray or `~astropy.units.Quantity`
            Interpolated values, reshaped to the broadcast shape of the inputs.

        Raises
        ------
        ImportError
            If scipy is not installed.
        """
        # Check for scipy before doing any array work.
        if not has_scipy:  # pragma: no cover
            raise ImportError("Tabular model requires scipy.")

        inputs = np.broadcast_arrays(*inputs)

        shape = inputs[0].shape
        # interpn expects an (n_points, ndim) array of coordinates.
        inputs = [inp.flatten() for inp in inputs[: self.n_inputs]]
        inputs = np.array(inputs).T

        result = interpn(
            self.points,
            self.lookup_table,
            inputs,
            method=self.method,
            bounds_error=self.bounds_error,
            fill_value=self.fill_value,
        )

        # return_units not respected when points has no units
        if isinstance(self.lookup_table, u.Quantity) and not isinstance(
            self.points[0], u.Quantity
        ):
            result = result * self.lookup_table.unit

        if self.n_outputs == 1:
            result = result.reshape(shape)
        else:
            result = [r.reshape(shape) for r in result]
        return result

    @property
    def inverse(self):
        """
        Inverse of a 1D tabular model: a `Tabular1D` with the roles of
        ``points`` and ``lookup_table`` swapped.

        Raises
        ------
        NotImplementedError
            If the model is not 1D, or if the lookup table is not strictly
            monotonic.
        """
        if self.n_inputs == 1:
            # If the lookup_table values are descending instead of ascending,
            # both points and lookup_table need to be reversed in the inverse
            # transform for scipy.interpolate to work properly.
            if np.all(np.diff(self.lookup_table) > 0):
                # ascending case
                points = self.lookup_table
                lookup_table = self.points[0]
            elif np.all(np.diff(self.lookup_table) < 0):
                # descending case, reverse order
                points = self.lookup_table[::-1]
                lookup_table = self.points[0][::-1]
            else:
                # equal-valued or double-valued lookup_table
                raise NotImplementedError(
                    "The inverse is only defined for a strictly monotonic "
                    "lookup_table."
                )
            return Tabular1D(
                points=points,
                lookup_table=lookup_table,
                method=self.method,
                bounds_error=self.bounds_error,
                fill_value=self.fill_value,
            )
        raise NotImplementedError(
            "An analytical inverse transform has not been implemented for this model."
        )
def tabular_model(dim, name=None):
    """
    Make a ``Tabular`` model where ``n_inputs`` is
    based on the dimension of the lookup_table.

    This model has to be further initialized and when evaluated
    returns the interpolated values.

    Parameters
    ----------
    dim : int
        Dimensions of the lookup table.
    name : str, optional
        Name for the class. If not given, a name of the form
        ``Tabular<N>`` is generated from a class-level counter.

    Returns
    -------
    model_class : type
        A new subclass of ``_Tabular`` with ``n_inputs=dim`` and
        ``n_outputs=1``.

    Raises
    ------
    ValueError
        If ``dim`` is less than 1.

    Examples
    --------
    >>> table = np.array([[3., 0., 0.],
    ...                   [0., 2., 0.],
    ...                   [0., 0., 0.]])

    >>> tab = tabular_model(2, name='Tabular2D')
    >>> print(tab)
    <class 'astropy.modeling.tabular.Tabular2D'>
    Name: Tabular2D
    N_inputs: 2
    N_outputs: 1

    >>> points = ([1, 2, 3], [1, 2, 3])

    Setting fill_value to None allows extrapolation.
    >>> m = tab(points, lookup_table=table, name='my_table',
    ...         bounds_error=False, fill_value=None, method='nearest')

    >>> xinterp = [0, 1, 1.5, 2.72, 3.14]
    >>> m(xinterp, xinterp)  # doctest: +FLOAT_CMP
    array([3., 3., 3., 0., 0.])

    """
    if dim < 1:
        raise ValueError("Lookup table must have at least one dimension.")

    # Placeholder table: ``_Tabular.__init__`` compares its ``ndim`` with
    # the user-supplied lookup table to validate dimensionality.
    table = np.zeros([2] * dim)
    members = {
        "lookup_table": table,
        "n_inputs": dim,
        "n_outputs": 1,
        # Only a 1D lookup is separable.
        "_separable": dim == 1,
    }

    if name is None:
        # Shared counter on the base class keeps generated names unique.
        model_id = _Tabular._id
        _Tabular._id += 1
        name = f"Tabular{model_id}"

    model_class = type(str(name), (_Tabular,), members)
    model_class.__module__ = "astropy.modeling.tabular"
    return model_class
Tabular1D = tabular_model(1, name="Tabular1D")

Tabular2D = tabular_model(2, name="Tabular2D")

# Parameter/return/notes sections shared by the Tabular1D and Tabular2D
# docstrings below.
_tab_docs = """
    method : str, optional
        The method of interpolation to perform. Supported are "linear",
        "nearest", and "splinef2d". "splinef2d" is only supported for
        2-dimensional data. Default is "linear".
    bounds_error : bool, optional
        If True, when interpolated values are requested outside of the
        domain of the input data, a ValueError is raised.
        If False, then ``fill_value`` is used.
    fill_value : float or `~astropy.units.Quantity`, optional
        If provided, the value to use for points outside of the
        interpolation domain. If None, values outside
        the domain are extrapolated.  Extrapolation is not supported by method
        "splinef2d". If Quantity is given, it will be converted to the unit of
        ``lookup_table``, if applicable.

    Returns
    -------
    value : ndarray
        Interpolated values at input coordinates.

    Raises
    ------
    ImportError
        Scipy is not installed.

    Notes
    -----
    Uses `scipy.interpolate.interpn`.

"""

Tabular1D.__doc__ = (
    """
    Tabular model in 1D.
    Returns an interpolated lookup table value.

    Parameters
    ----------
    points : array-like of float of ndim=1, optional
        The points defining the regular grid in 1 dimension.
    lookup_table : array-like of ndim=1
        The data in one dimension.
"""
    + _tab_docs
)

Tabular2D.__doc__ = (
    """
    Tabular model in 2D.
    Returns an interpolated lookup table value.

    Parameters
    ----------
    points : tuple of ndarray of float, optional
        The points defining the regular grid in 2 dimensions.
        Two ndarrays with shapes (m1,) and (m2,).
    lookup_table : array-like
        The data on a regular grid in 2 dimensions.
        Shape (m1, m2).
"""
    + _tab_docs
)