Coverage for src/scipy_dataclassfitparams/_make_fit.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-01 17:34 +0000

1"""This module contains the function which wraps scipy for the fitting 

2procedure. 

3""" 

4 

5import numpy as _numpy 

6import scipy.optimize as _optimize # type: ignore [import] 

7import scipy.version as _scipy_version # type: ignore [import] 

8 

9from dataclasses import dataclass as _dataclass 

10from numpy import floating as _floating 

11from numpy.typing import NDArray as _NDArray 

12from typing import Any as _Any, Callable as _Callable, Dict as _Dict, \ 

13 Generic as _Generic, Literal as _Literal, Mapping as _Mapping, \ 

14 Optional as _Optional, overload as _overload, Sequence as _Sequence, \ 

15 Type as _Type, TypeVar as _TypeVar, Union as _Union 

16 

17from . import _spec_registry 

18 

19 

# Type variable for the user's fit-parameter dataclass. `make_fit` uses it
# to tie together `fitparams_dc`, `p0`, the fit function `f` and the
# returned result types.
_TFitParams = _TypeVar("_TFitParams")
"""The type of the fit parameter definition class."""


# Signature of the user's fit function: it is called with the x-data and
# an instance of the fit-parameter class and should return the model
# y-values for that set of parameters.
_FitFunc = _Callable[[_Any, _TFitParams], _NDArray[_floating]]
"""The type of the fit function."""


# Values `make_fit` forwards as the `method` argument of
# `scipy.optimize.curve_fit`.
# NOTE(review): the noqa presumably silences pyflakes misreading the
# Literal strings as forward annotations — confirm it is still needed.
_Method = _Literal["lm", "trf", "dogbox"]  # noqa: F821
"""The fitting method."""


# Values `make_fit` forwards as `nan_policy` to `scipy.optimize.curve_fit`
# (only when the installed scipy supports that argument).
_NaNPolicy = _Literal["raise", "omit"]  # noqa: F821, F722
"""The NaN policy values."""

34 

35 

@_dataclass(frozen=True)
class CovMatrix:
    """Encapsulate the covariance matrix for the fit parameters.

    The rows and columns of `cov` are labelled by `fields`; `get_index`
    translates a field name into the corresponding row/column index.

    Two instances compare equal if their `fields` are equal and their
    `cov` arrays are element-wise equal. Instances can be pickled and
    copied with the `copy` module although the class is frozen and uses
    `__slots__`.
    """

    __slots__ = ("fields", "cov", "_mapping")

    fields: _Sequence[str]
    """The fields corresponding to the rows and columns."""

    cov: _NDArray[_floating]
    """The covariance matrix."""

    _mapping: _Mapping[str, int]
    """The mapping of the field names to their index. Basically the
    inverse of `fields`."""

    # Must be defined because "_mapping" cannot be a init=False field.
    def __init__(self, fields: _Sequence[str], cov: _NDArray[_floating]):
        """Construct a new `CovMatrix`.

        Args:
            fields (Sequence[str]): The fields corresponding to the rows
                and columns of the covariance matrix.
            cov (NDArray[floating]): The covariance matrix.
        """
        # The dataclass is frozen, so its own __setattr__ refuses writes;
        # go through object.__setattr__ instead.
        object.__setattr__(self, "fields", fields)
        object.__setattr__(self, "cov", cov)
        # If a name occurs more than once, the last index wins.
        m = {x: i for i, x in enumerate(fields)}
        object.__setattr__(self, "_mapping", m)

    def get_index(self, field: str) -> int:
        """Map a field name to the corresponding index.

        Args:
            field (str): The field name.

        Raises:
            KeyError: If `field` is not a valid field name.

        Returns:
            int: The index in the covariance matrix belonging to
                `field`.
        """
        return self._mapping[field]

    def __getitem__(
        self,
        key: _Any
    ) -> _NDArray[_floating]:  # pragma: no cover
        """Indexes the covariance matrix directly. This is equivalent
        to `self.cov[key]`.

        Args:
            key (Any): The indexing object.

        Returns:
            NDArray[floating]: The result of indexing the cov matrix.
        """
        return self.cov[key]

    def __eq__(self, other: _Any) -> bool:
        """Compare with another `CovMatrix`.

        Replaces the dataclass-generated comparison (the dataclass
        decorator keeps an explicitly defined `__eq__`). The generated
        version compares the `cov` arrays inside a tuple, which raises
        ``ValueError`` ("truth value of an array ... is ambiguous") for
        any matrix with more than one element.

        Args:
            other (Any): The object to compare with.

        Returns:
            bool: True if `fields` are equal and `cov` is the same array
                or element-wise equal. NotImplemented is returned for
                objects of a different class.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        # The identity check mirrors the tuple comparison of the generated
        # __eq__, so an array always equals itself (even if it has NaNs).
        return self.fields == other.fields and (
            self.cov is other.cov
            or bool(_numpy.array_equal(self.cov, other.cov))
        )

    def __reduce__(self) -> tuple:
        """Support `pickle` and the `copy` module.

        The default reduction restores `__slots__` with `setattr`, which
        the frozen dataclass rejects with `FrozenInstanceError`.
        Rebuilding through `__init__` avoids that and also recomputes
        `_mapping`.

        Returns:
            tuple: The class and the positional arguments for its
                constructor.
        """
        return (self.__class__, (self.fields, self.cov))

95 

96 

@_dataclass(frozen=True)
class FitResult(_Generic[_TFitParams]):
    """Encapsulates the simple fit result.

    Instances are frozen and use `__slots__`. `__reduce__` rebuilds them
    through the dataclass-generated `__init__`, so they (and dataclass
    subclasses) can be pickled and copied.
    """

    __slots__ = ("opt_instance", "cov", "sigma", "p0")

    opt_instance: _TFitParams
    """The optimizing instance found via the fit."""

    cov: CovMatrix
    """The covariance matrix."""

    sigma: _Union[_NDArray[_floating], None]
    """The sigma values."""

    p0: _Union[_TFitParams, None]
    """The initial value used for the fit or None, if none was
    provided."""

    def __reduce__(self) -> tuple:
        """Support `pickle` and the `copy` module.

        The default reduction restores `__slots__` with `setattr`, which
        the frozen dataclass rejects with `FrozenInstanceError`. Instead,
        the instance is rebuilt by passing the values of all dataclass
        fields (including those added by dataclass subclasses) positionally
        to the constructor.

        Returns:
            tuple: The class and the positional arguments for its
                constructor.
        """
        from dataclasses import fields as _dc_fields

        args = tuple(getattr(self, fld.name) for fld in _dc_fields(self))
        return (self.__class__, args)

115 

116 

@_dataclass(frozen=True)
class FullFitResult(FitResult[_TFitParams]):
    """Encapsulates the extended fit result. See also the doc of the
    return value of `scipy.optimize.curve_fit` with full_output=True for
    more info.

    On top of the `FitResult` fields it stores the three extra values
    `make_fit` unpacks from `scipy.optimize.curve_fit(...,
    full_output=True)`: the info dict, the message and the status flag.
    """

    # Only the new fields are listed here; the parent's slots are
    # inherited. Field order matters: `make_fit` constructs this class
    # positionally as (opt_instance, cov, sigma, p0, info_dict, mesg, ierr).
    __slots__ = ("info_dict", "mesg", "ierr")

    info_dict: _Dict[str, _Any]
    """The info dict for the full result."""

    mesg: str
    """The message containing information about the solution."""

    ierr: int
    """An integer flag showing the status."""

134 

135 

def _check_has_nan_policy_arg(version: _Optional[str] = None) -> bool:
    """Determine if `scipy.optimize.curve_fit` has the 'nan_policy'
    argument.

    The argument exists from scipy 1.11 on, so the major and minor version
    numbers are compared against (1, 11).

    Args:
        version (str, optional): The scipy version string to check, e.g.
            "1.11.4" or "1.12.0rc1". Defaults to None, which uses
            `scipy.version.short_version` of the installed scipy.

    Raises:
        ValueError: If no "<major>.<minor>" prefix can be parsed from the
            version string.

    Returns:
        bool: Whether `scipy.optimize.curve_fit` has the 'nan_policy'
            argument.
    """
    import re

    if version is None:
        version = _scipy_version.short_version
    # Only the leading "<major>.<minor>" digits matter, so pre-release or
    # development suffixes ("rc1", ".dev0+abc", ...) are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Cannot parse scipy version {version!r}.")
    major, minor = (int(group) for group in match.groups())
    # scipy version must be above or equal to 1.11.
    return (major, minor) >= (1, 11)


_CURVE_FIT_HAS_NAN_POLICY = _check_has_nan_policy_arg()
"""Whether `scipy.optimize.curve_fit` has the 'nan_policy' argument."""

153 

154 

# There seems to be no way to remove the nan_policy argument from the
# signature if it is not supported, so the signatures always accept it
# and `make_fit` rejects non-None values at runtime instead.

157 

@_overload
def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: _Literal[False] = False,
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> FitResult[_TFitParams]:
    # Type-checker-only signature: unless `full_output=True` is passed,
    # the result is a plain `FitResult`. The implementation is below.
    ...

175 

176 

@_overload
def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: _Literal[True],
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> FullFitResult[_TFitParams]:
    # Type-checker-only signature: `full_output=True` (required here)
    # selects the extended `FullFitResult`. The implementation is below.
    ...

194 

195 

def make_fit(
    fitparams_dc: _Type[_TFitParams],
    xdata: _NDArray[_floating],
    ydata: _NDArray[_floating],
    f: _FitFunc[_TFitParams],
    *,
    p0: _Optional[_TFitParams] = None,
    sigma: _Optional[_NDArray[_floating]] = None,
    absolute_sigma: bool = False,
    check_finite: _Optional[bool] = None,
    method: _Optional[_Method] = None,
    full_output: bool = False,
    nan_policy: _Optional[_NaNPolicy] = None,
    **kwargs
) -> _Union[FitResult[_TFitParams], FullFitResult[_TFitParams]]:
    """Perform the fit. For more information on some of the parameters
    of this function see `scipy.optimize.curve_fit` which is used
    internally by this function to perform the fit.

    Args:
        fitparams_dc (type[TFitParams]): The dataclass defining the
            parameters.
        xdata (NDArray[floating]): The x-values.
        ydata (NDArray[floating]): The y-values to fit against.
        f (FitFunc[TFitParams]): The fit func accepting the value of the
            `xdata` parameter, followed by an instance of `fitparams_dc`
            which represents the current set of parameters. The return
            value should be the current y-values obtained from the given
            set of parameters.
        p0 (TFitParams, optional): The initial instance. Defaults to
            None which constructs the instance using default values.
        sigma (NDArray[floating], optional): The uncertainties for the
            `ydata`. Defaults to None.
        absolute_sigma (bool, optional): If True, `sigma` are
            interpreted as absolute values. Defaults to False, which
            means only their relative values matter. This will affect
            the meaning of the returned covariance matrix.
        check_finite (bool, optional): Whether to check that the input
            arrays do not contain NaN or infinite values. Defaults to
            None, which means True if `nan_policy` is unspecified and
            False otherwise, independent of the installed scipy version.
        method (Method, optional): The method to use for the fit.
            Defaults to "lm" if no bounds are provided and "trf"
            otherwise.
        full_output (bool, optional): Whether to return an object with
            additional information. Defaults to False.
        nan_policy (NaNPolicy, optional): How to handle NaN values.
            Defaults to None. This argument is only available in
            scipy >= 1.11 and must be None (unset) otherwise.
        **kwargs: Further keyword arguments, forwarded unchanged to
            `scipy.optimize.curve_fit`.

    Raises:
        NotImplementedError: If `nan_policy` is not None and the scipy
            version is below 1.11.

    Returns:
        FitResult[TFitParams] or FullFitResult[TFitParams]: A
        `FullFitResult` if `full_output` is True, otherwise a `FitResult`.
        Its `p0` is the `p0` argument (None if it was not given) and its
        `sigma` is a read-only copy of the `sigma` argument (or None).
    """

    spec = _spec_registry.get_fit_spec(fitparams_dc)

    # Report the caller's p0 in the result; the default instance is only
    # used as the starting point of the fit.
    if p0 is None:
        p0_res = None
        p0 = spec.create_default_fit_instance()
    else:
        p0_res = p0

    p0_array = spec.new_empty_array(float)
    spec.instance_to_array(p0, p0_array)

    bounds = spec.bounds

    def wrapper(
        x: _NDArray[_floating],
        *params: float
    ) -> _NDArray[_floating]:
        # curve_fit passes the parameters as separate positional values;
        # rebuild an instance of the parameter class for the user function.
        instance = spec.array_to_instance(params)
        return f(x, instance)

    # "nan_policy" should never be in the kwargs.
    ver_kwargs: _Dict[str, _Any] = {}
    if _CURVE_FIT_HAS_NAN_POLICY:
        ver_kwargs["nan_policy"] = nan_policy
    elif nan_policy is not None:
        raise NotImplementedError(
            "nan_policy must be None if the argument is not supported "
            "by scipy."
        )

    # Resolve the documented default here instead of forwarding None:
    # scipy < 1.11 declares check_finite=True and only tests its
    # truthiness, so passing None would silently disable the check. For
    # scipy >= 1.11 this is exactly how curve_fit itself resolves None.
    if check_finite is None:
        check_finite = nan_policy is None

    result = _optimize.curve_fit(
        wrapper,
        xdata,
        ydata,
        p0=p0_array,
        bounds=bounds,
        sigma=sigma,
        absolute_sigma=absolute_sigma,
        check_finite=check_finite,
        method=method,
        full_output=True,  # The method computes the extras anyway
        **ver_kwargs,
        **kwargs
    )

    pcov: _NDArray
    popt, pcov, info_dict, msg, ec = result

    res_instance = spec.array_to_instance(popt)

    # The result objects are frozen; make the arrays they expose
    # read-only as well.
    pcov.flags.writeable = False
    cov_mat = CovMatrix(tuple(spec.fitting_params), pcov)

    if sigma is not None:  # pragma: no cover
        # Copy first so the caller's array is not made read-only.
        sigma = _numpy.copy(sigma)
        sigma.flags.writeable = False

    if not full_output:
        return FitResult(res_instance, cov_mat, sigma, p0_res)

    return FullFitResult(
        res_instance, cov_mat, sigma, p0_res, info_dict, msg, ec
    )