Coverage for src/scipy_dataclassfitparams/_make_fit.py: 100%
83 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
1"""This module contains the function which wraps scipy for the fitting
2procedure.
3"""
5import numpy as _numpy
6import scipy.optimize as _optimize # type: ignore [import]
7import scipy.version as _scipy_version # type: ignore [import]
9from dataclasses import dataclass as _dataclass
10from numpy import floating as _floating
11from numpy.typing import NDArray as _NDArray
12from typing import Any as _Any, Callable as _Callable, Dict as _Dict, \
13 Generic as _Generic, Literal as _Literal, Mapping as _Mapping, \
14 Optional as _Optional, overload as _overload, Sequence as _Sequence, \
15 Type as _Type, TypeVar as _TypeVar, Union as _Union
17from . import _spec_registry
_TFitParams = _TypeVar("_TFitParams")
"""Type variable standing for the user's fit-parameter dataclass."""

_FitFunc = _Callable[[_Any, _TFitParams], _NDArray[_floating]]
"""Signature of a user fit function: it receives the x-data and an
instance of the parameter dataclass and returns the model y-values."""

_Method = _Literal["lm", "trf", "dogbox"]  # noqa: F821
"""Optimization algorithms accepted for the ``method`` argument."""

_NaNPolicy = _Literal["raise", "omit"]  # noqa: F821, F722
"""Values accepted for the ``nan_policy`` argument."""
@_dataclass(frozen=True)
class CovMatrix:
    """Encapsulate the covariance matrix for the fit parameters.

    Row/column ``i`` of `cov` belongs to the parameter named
    ``fields[i]``; `get_index` translates a name into that index.
    """

    __slots__ = ("fields", "cov", "_mapping")

    fields: _Sequence[str]
    """The fields corresponding to the rows and columns."""

    cov: _NDArray[_floating]
    """The covariance matrix."""

    _mapping: _Mapping[str, int]
    """The mapping of the field names to their index. Basically the
    inverse of `fields`."""

    # Must be defined because "_mapping" cannot be an init=False field:
    # a field default would clash with its entry in __slots__.
    def __init__(self, fields: _Sequence[str], cov: _NDArray[_floating]):
        """Construct a new `CovMatrix`.

        Args:
            fields (Sequence[str]): The fields corresponding to the rows
                and columns of the covariance matrix.
            cov (NDArray[floating]): The covariance matrix.
        """
        # The dataclass is frozen, so normal attribute assignment would
        # raise; go through object.__setattr__ instead.
        object.__setattr__(self, "fields", fields)
        object.__setattr__(self, "cov", cov)
        mapping = {name: index for index, name in enumerate(fields)}
        object.__setattr__(self, "_mapping", mapping)

    def __eq__(self, other: object) -> bool:
        """Compare two covariance matrices by fields and values.

        The dataclass-generated ``__eq__`` compares the attributes as a
        tuple, which calls ``bool()`` on the element-wise ``==`` of the
        two `cov` arrays and raises ``ValueError`` for any matrix larger
        than 1x1. The arrays are compared with `numpy.array_equal`
        instead. `_mapping` is derived from `fields` and is not compared
        separately.

        Args:
            other (object): The object to compare with.

        Returns:
            bool: Whether both have equal fields and covariance values,
            or NotImplemented if `other` is not of the same class.
        """
        if other is self:
            return True
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.fields == other.fields
            and bool(_numpy.array_equal(self.cov, other.cov))
        )

    def __reduce__(self) -> tuple:
        """Support `copy` and `pickle`.

        The default protocol restores ``__slots__`` through ``setattr``,
        which the frozen dataclass rejects with ``FrozenInstanceError``.
        Rebuilding through the constructor also recomputes `_mapping`.

        Returns:
            tuple: The callable and its arguments to recreate the
            instance.
        """
        return (self.__class__, (self.fields, self.cov))

    def get_index(self, field: str) -> int:
        """Map a field name to the corresponding index.

        Args:
            field (str): The field name.

        Raises:
            KeyError: If `field` is not a valid field name.

        Returns:
            int: The index in the covariance matrix belonging to
            `field`.
        """
        return self._mapping[field]

    def __getitem__(
        self,
        key: _Any
    ) -> _NDArray[_floating]:  # pragma: no cover
        """Indexes the covariance matrix directly. This is equivalent
        to `self.cov[key]`.

        Args:
            key (Any): The indexing object.

        Returns:
            NDArray[floating]: The result of indexing the cov matrix.
        """
        return self.cov[key]
@_dataclass(frozen=True)
class FitResult(_Generic[_TFitParams]):
    """Encapsulates the simple fit result, as returned by `make_fit`
    with ``full_output=False``."""

    __slots__ = ("opt_instance", "cov", "sigma", "p0")

    opt_instance: _TFitParams
    """The optimizing instance found via the fit."""

    cov: CovMatrix
    """The covariance matrix of the fitted parameters."""

    sigma: _Union[_NDArray[_floating], None]
    """The sigma values passed to the fit, or None if none were given."""

    p0: _Union[_TFitParams, None]
    """The initial value used for the fit or None, if none was
    provided."""

    def __reduce__(self) -> tuple:
        """Support `copy` and `pickle`.

        The default protocol restores ``__slots__`` through ``setattr``,
        which the frozen dataclass rejects with ``FrozenInstanceError``.
        This rebuilds the instance through the constructor instead.
        Collecting the values via `dataclasses.fields` keeps this correct
        for subclasses that add fields (e.g. `FullFitResult`).

        Returns:
            tuple: The callable and its positional arguments to recreate
            the instance.
        """
        import dataclasses

        args = tuple(
            getattr(self, fld.name) for fld in dataclasses.fields(self)
        )
        return (self.__class__, args)
@_dataclass(frozen=True)
class FullFitResult(FitResult[_TFitParams]):
    """Fit result extended with the diagnostics of the fit.

    The three extra fields hold the ``infodict``, ``mesg`` and ``ier``
    values that `scipy.optimize.curve_fit` returns with
    ``full_output=True``; see its documentation for details.
    """

    __slots__ = ("info_dict", "mesg", "ierr")

    info_dict: _Dict[str, _Any]
    """Dictionary of additional solver outputs (scipy's ``infodict``)."""

    mesg: str
    """Message describing the solution that was found."""

    ierr: int
    """Integer status flag reported by the solver (scipy's ``ier``)."""
def _check_has_nan_policy_arg(version: _Optional[str] = None) -> bool:
    """Determine if `scipy.optimize.curve_fit` has the 'nan_policy'
    argument.

    The argument exists from scipy 1.11 on. This function therefore
    compares the (major, minor) part of the version against (1, 11).

    Args:
        version (str, optional): The version string to check. Defaults
            to None, which uses ``scipy.version.short_version`` of the
            installed scipy.

    Returns:
        bool: Whether `scipy.optimize.curve_fit` has the 'nan_policy'
        argument.
    """
    import re

    if version is None:
        version = _scipy_version.short_version

    numbers = []
    # Only (major, minor) matters. Keep just the leading digits of each
    # component so that pre-release suffixes such as "12rc1" do not make
    # int() fail; a component without leading digits counts as 0.
    for component in version.split(".")[:2]:
        match = re.match(r"\d+", component)
        numbers.append(int(match.group()) if match else 0)

    # Lexicographic tuple comparison; a bare major version such as "2"
    # yields (2,), which still compares correctly against (1, 11).
    return tuple(numbers) >= (1, 11)
_CURVE_FIT_HAS_NAN_POLICY: bool = _check_has_nan_policy_arg()
"""True when the installed `scipy.optimize.curve_fit` accepts the
'nan_policy' argument; evaluated once at import time."""
# There seems to be no way to drop the nan_policy argument from the
# signatures below when the installed scipy does not support it;
# make_fit raises NotImplementedError instead if it is used.
@_overload
def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: _Literal[False] = False,
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> FitResult[_TFitParams]:
    """Typing overload: with ``full_output=False`` (the default) a plain
    `FitResult` is returned."""
    ...
@_overload
def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: _Literal[True],
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> FullFitResult[_TFitParams]:
    """Typing overload: ``full_output=True`` yields a `FullFitResult`."""
    ...


# Fallback for a ``full_output`` whose value is only known at runtime
# (a plain ``bool``); without it, type checkers reject such calls
# because neither Literal overload matches.
@_overload
def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: bool = False,
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> _Union[FitResult[_TFitParams], FullFitResult[_TFitParams]]:
    """Typing overload: a non-literal ``full_output`` yields either
    result type."""
    ...
def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: bool = False,
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> _Union[FitResult[_TFitParams], FullFitResult[_TFitParams]]:
    """Perform the fit. For more information on some of the parameters
    of this function see `scipy.optimize.curve_fit` which is used
    internally by this function to perform the fit.

    Args:
        fitparams_dc (type[TFitParams]): The dataclass defining the
            parameters.
        xdata (NDArray[floating]): The x-values.
        ydata (NDArray[floating]): The y-values to fit against.
        f (FitFunc[TFitParams]): The fit func accepting the value of the
            `xdata` parameter, followed by an instance of `fitparams_dc`
            which represents the current set of parameters. The return
            value should be the current y-values obtained from the given
            set of parameters.
        p0 (TFitParams, optional): The initial instance. Defaults to
            None which constructs the instance using default values.
        sigma (NDArray[floating], optional): The uncertainties for the
            `ydata`. Defaults to None. A read-only copy is stored in the
            result.
        absolute_sigma (bool, optional): If True, `sigma` are
            interpreted as absolute values. Defaults to False, which
            means only their relative values matter. This will affect
            the meaning of the returned covariance matrix.
        check_finite (bool, optional): Whether to check that the input
            arrays do not contain NaN or inf values. Defaults to None,
            which resolves to True if `nan_policy` is None and to False
            otherwise, regardless of the installed scipy version.
        method (Method, optional): The method to use for the fit.
            Defaults to "lm" if no bounds are provided and "trf"
            otherwise.
        full_output (bool, optional): Whether to return an object with
            additional information. Defaults to False.
        nan_policy (NaNPolicy, optional): How to handle NaN values.
            Defaults to None. This argument is only available in
            scipy >= 1.11 and must be None (unset) otherwise.
        **kwargs: Forwarded to `scipy.optimize.curve_fit`. They must not
            repeat an argument this function already passes, such as
            `bounds`, which is taken from the parameter spec. Otherwise
            Python raises a TypeError.

    Raises:
        NotImplementedError: If `nan_policy` is not None and the scipy
            version is below 1.11.

    Returns:
        FitResult[TFitParams] or FullFitResult[TFitParams]: A
        `FullFitResult` if `full_output` is True, otherwise a
        `FitResult`. The `FullFitResult` additionally carries scipy's
        info dict, message and status flag. The covariance matrix in the
        result is read-only.
    """
    spec = _spec_registry.get_fit_spec(fitparams_dc)

    # The result reports the caller's p0, i.e. None when the default
    # instance was used.
    if p0 is None:
        p0_res = None
        p0 = spec.create_default_fit_instance()
    else:
        p0_res = p0

    p0_array = spec.new_empty_array(float)
    spec.instance_to_array(p0, p0_array)

    bounds = spec.bounds

    def wrapper(
        x: _NDArray[_floating],
        *params: float
    ) -> _NDArray[_floating]:
        # curve_fit passes the parameters as separate positional values;
        # rebuild the instance the user's fit function expects.
        instance = spec.array_to_instance(params)
        return f(x, instance)

    # "nan_policy" is a named parameter, so it can never be in kwargs.
    ver_kwargs: _Dict[str, _Any] = {}
    if _CURVE_FIT_HAS_NAN_POLICY:
        ver_kwargs["nan_policy"] = nan_policy
    elif nan_policy is not None:
        raise NotImplementedError(
            "nan_policy must be None if the argument is not supported "
            "by scipy."
        )

    # Resolve the documented default here. scipy >= 1.11 performs this
    # exact resolution itself, but older versions only test the value's
    # truthiness, so passing None there would silently skip the check.
    if check_finite is None:
        check_finite = nan_policy is None

    result = _optimize.curve_fit(
        wrapper,
        xdata,
        ydata,
        p0=p0_array,
        bounds=bounds,
        sigma=sigma,
        absolute_sigma=absolute_sigma,
        check_finite=check_finite,
        method=method,
        full_output=True,  # The method computes the extras anyway
        **ver_kwargs,
        **kwargs
    )

    pcov: _NDArray
    popt, pcov, info_dict, msg, ec = result

    res_instance = spec.array_to_instance(popt)

    # Make the matrix read-only so the frozen result cannot be mutated
    # through it.
    pcov.flags.writeable = False
    cov_mat = CovMatrix(tuple(spec.fitting_params), pcov)

    # Store a read-only copy so later changes to the caller's array do
    # not leak into the result.
    if sigma is not None:  # pragma: no cover
        sigma = _numpy.copy(sigma)
        sigma.flags.writeable = False

    if not full_output:
        return FitResult(res_instance, cov_mat, sigma, p0_res)

    return FullFitResult(
        res_instance, cov_mat, sigma, p0_res, info_dict, msg, ec
    )