Coverage for src/scipy_dataclassfitparams/_fit_spec.py: 100%

185 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-01 17:34 +0000

1import dataclasses as _dataclasses 

2import numpy as _numpy 

3import sys as _sys 

4 

5from abc import ABC as _ABC, abstractmethod as _abstractmethod 

6from dataclasses import dataclass as _dataclass 

7from numpy import dtype as _dtype, floating as _floating 

8from numpy.typing import NDArray as _NDArray 

9from typing import Any as _Any, Dict as _Dict, final as _final, \ 

10 Generic as _Generic, List as _List, Mapping as _Mapping, \ 

11 Sequence as _Sequence, Tuple as _Tuple, Type as _Type, \ 

12 TypeVar as _TypeVar 

13from typing import Union as _Union 

14 

15from . import _fields 

16from . import _depgraph 

17 

18from ._fields import FittingField as _FittingField 

19 

20 

21_T = _TypeVar("_T") 

22 

23 

24def _is_dataclass(cls: _Type) -> bool: 

25 """Checks whether a class is a dataclass. This wraps 

26 `dataclasses.is_dataclass` and removes the TypeGuard since this 

27 currently causes issues with mypy. 

28 """ 

29 return _dataclasses.is_dataclass(cls) 

30 

31 

BoundsTuple = _Tuple[_NDArray[_floating], _NDArray[_floating]]
"""The type of a tuple defining the fit parameter bounds as
`(lower, upper)`."""

34 

35 

@_dataclass(frozen=True)
class DependentFieldValueResolver(_ABC):
    """Abstract base for resolvers that compute the value of a field
    whose value depends on other fields.
    """

    __slots__ = ("target",)

    # Name of the field whose value this resolver produces (the field that
    # "depends" on other fields).
    target: str

    @_abstractmethod
    def get(self, others: _Dict[str, _Any]) -> _Any:
        """Return the value for `target` given the known field values in
        `others` (keyed by field name)."""
        ...

48 

49 

class DefaultResolver(_ABC):
    """Abstract base for objects that determine the default value of a
    special field.
    """

    @_abstractmethod
    def get(self, others: _Dict[str, _Any]) -> _Any:
        """Return the default value, given the values already resolved
        for other fields (keyed by field name)."""
        ...

58 

59 

@_final
class UnsetDefaultResolver(DefaultResolver):
    """`DefaultResolver` for fields that have no explicit default and no
    other way of deriving one; it always falls back to ``0.0``.
    """

    def get(self, _: _Dict[str, _Any]) -> _Any:
        fallback = 0.0
        return fallback

68 

69 

@_final
@_dataclass(frozen=True)
class FixedDefaultResolver(DefaultResolver):
    """`DefaultResolver` that always yields the same stored value."""

    __slots__ = ("value",)

    # The constant returned by `get`, independent of the other fields.
    value: _Any

    def get(self, _: _Dict[str, _Any]) -> _Any:
        stored = self.value
        return stored

82 

83 

@_final
@_dataclass(frozen=True)
class DependentResolver(DependentFieldValueResolver, DefaultResolver):
    """Resolver for a field whose value mirrors that of another field.

    Acts both as a `DefaultResolver` and as a
    `DependentFieldValueResolver`. `get` looks up the field named `name` in
    the supplied mapping and raises KeyError if it is absent.
    """

    __slots__ = ("name",)

    # Name of the field whose value is copied (looked up in `others`).
    name: str

    def get(self, others: _Dict[str, _Any]) -> _Any:
        source_field = self.name
        return others[source_field]

98 

99 

@_dataclass(frozen=True)
class FitSpecBase:
    """The base class for a class containing info about a class defining
    a fit.

    Instances are normally created via `generate`, which inspects the
    fields of a dataclass.
    """

    __slots__ = (
        "special_param_count",
        "init_order",
        "fitting_param_count",
        "fitting_params",
        "fitting_fields",
        "default_resolvers",
        "lower_bounds",
        "upper_bounds",
        "bounds",
        "dep_resolvers"
    )

    special_param_count: int
    """The number of init fields, i.e. the actual fitting parameters plus
    the dependent and const parameters."""

    init_order: _Sequence[str]
    """The order in which the special parameters will be initialized."""

    fitting_param_count: int
    """The number of actual fitting parameters."""

    fitting_params: _Sequence[str]
    """The names of the actual fitting parameters (regular and bounded
    fields), in field declaration order."""

    fitting_fields: _Mapping[str, _Union[_FittingField, None]]
    """A dict mapping the name of every init field, in declaration order,
    to its fitting field descriptor, or None if the field has none (it is
    then treated as a regular field)."""

    default_resolvers: _Sequence[DefaultResolver]
    """The resolvers for the default values of the init fields, in field
    declaration order (parallel to the keys of `fitting_fields`; this is
    not necessarily the same as `init_order`)."""

    lower_bounds: _Union[_NDArray[_floating], None]
    """The lower bounds, indexed like `fitting_params`, or None if no
    finite lower bound is set."""

    upper_bounds: _Union[_NDArray[_floating], None]
    """The upper bounds, indexed like `fitting_params`, or None if no
    finite upper bound is set."""

    bounds: BoundsTuple
    """The `(lower, upper)` bounds tuple. A side without any finite bound
    is a read-only 0-d array holding -inf or inf respectively."""

    dep_resolvers: _Sequence[DependentResolver]
    """The resolvers for the dependent fields, in `init_order`."""

    @classmethod
    def _process_bounds(
        cls,
        set_bounds: _Dict[int, float],
        num_params: int,
        default: float
    ) -> _Union[_NDArray[_floating], None]:
        """Build a read-only bounds array of length `num_params`.

        Args:
            set_bounds (Dict[int, float]): Explicit bounds keyed by the
                index of the fitting parameter.
            num_params (int): The number of fitting parameters.
            default (float): The value used for parameters without an
                explicit bound.

        Returns:
            NDArray or None: The bounds array, or None if `set_bounds` is
            empty.
        """
        if not set_bounds:
            return None
        result = _numpy.full((num_params,), default)
        for k, v in set_bounds.items():
            result[k] = v
        result.flags.writeable = False
        return result

    @classmethod
    def generate(cls, t: _Type[_T]):
        """Generate the `FitSpecBase` or a derived class for a type.

        Args:
            t (Type[T]): The type/dataclass for which to generate the
                `FitSpecBase`.

        Raises:
            TypeError: If the provided type `t` is not a dataclass.
            NotImplementedError: If an unknown fitting field description
                (FittingField subclass) is encountered.
            ValueError: If the field dependencies are circular, so that
                some fields cannot be initialized.

        Returns:
            FitSpecBase or subclass: The generated `FitSpecBase`.
        """
        if not _is_dataclass(t):
            raise TypeError(  # pragma: no cover
                f"The type {t.__qualname__!r} is not a dataclass."
            )

        all_fields_ordered = _dataclasses.fields(t)  # type: ignore [arg-type]

        depgraph = _depgraph.DepGraph((f.name for f in all_fields_ordered))

        field_descriptors: _Dict[str, _Union[_FittingField, None]] = dict()

        # Explicit bounds, keyed by the index of the field among the
        # *fitting* parameters (regular and bounded fields). The bound
        # arrays built below have length `len(fitting_fields)` and are
        # aligned with `fitting_params`, so const/dependent fields must not
        # shift these indices.
        set_lower_bounds: _Dict[int, float] = dict()
        set_upper_bounds: _Dict[int, float] = dict()
        default_resolvers: _List[DefaultResolver] = list()

        special_fields: _List[str] = list()
        fitting_fields: _List[str] = list()

        dep_resolver_map: _Dict[str, DependentResolver] = dict()

        init_fields = (f for f in all_fields_ordered if f.init)
        for field in init_fields:
            field_name = field.name
            special_fields.append(field_name)

            info = _fields.get_special_fitting_field(field)
            field_descriptors[field_name] = info

            def_res: DefaultResolver
            if (info is None) or isinstance(info, _fields.RegularField):
                field_default = field.default
                if field_default is _dataclasses.MISSING:
                    def_res = UnsetDefaultResolver()
                else:
                    def_res = FixedDefaultResolver(field_default)
                fitting_fields.append(field_name)
            else:
                if isinstance(info, _fields.BoundedField):
                    # Position this field will take among the fitting
                    # parameters (it is appended right below).
                    fit_idx = len(fitting_fields)
                    if info.min_finite:
                        v_min: _Any = info.min
                        set_lower_bounds[fit_idx] = v_min
                    if info.max_finite:
                        v_max: _Any = info.max
                        set_upper_bounds[fit_idx] = v_max
                    df_val = info.resolve_default(field.default)
                    def_res = FixedDefaultResolver(df_val)
                    fitting_fields.append(field_name)
                elif isinstance(info, _fields.ConstField):
                    def_res = FixedDefaultResolver(info.value)
                elif isinstance(info, _fields.SameAsField):
                    # `field_name` depends on the field named `info.name`.
                    dep_name = info.name
                    depgraph.add_dependency(field_name, dep_name)
                    def_res = DependentResolver(field_name, dep_name)
                    dep_resolver_map[field_name] = def_res
                else:
                    raise NotImplementedError(
                        f"Unknown FittingField encountered: {info!r}"
                    )
            default_resolvers.append(def_res)

        num_special_params = len(special_fields)

        t_def_res = tuple(default_resolvers)
        if len(t_def_res) != num_special_params:
            raise AssertionError

        if depgraph.dependency_count > 0:
            if depgraph.has_closed_cycles():
                raise ValueError(
                    "There are circular dependencies in the field "
                    "dependencies meaning that some fields cannot be "
                    "initialized."
                )
            init_order = depgraph.get_init_order()
            # The graph also knows non-init fields; keep only init fields.
            ssf = set(special_fields)
            init_order = tuple((f for f in init_order if f in ssf))
            if len(init_order) != num_special_params:
                raise AssertionError

            # Resolvers of the dependent fields, in construction order.
            dep_resolvers_seq = tuple(
                dep_resolver_map[f] for f in init_order
                if f in dep_resolver_map
            )
        else:
            init_order = tuple(special_fields)
            dep_resolvers_seq = ()

        num_fitting_params = len(fitting_fields)

        lower_bounds = cls._process_bounds(
            set_lower_bounds, num_fitting_params, -_numpy.inf
        )
        upper_bounds = cls._process_bounds(
            set_upper_bounds, num_fitting_params, _numpy.inf
        )

        # Sides without any finite bound fall back to a read-only scalar
        # (0-d) array of -inf / inf.
        lb = lower_bounds
        if lb is None:
            lb = _numpy.array(-_numpy.inf)
            lb.flags.writeable = False
        ub = upper_bounds
        if ub is None:
            ub = _numpy.array(_numpy.inf)
            ub.flags.writeable = False
        bounds = (lb, ub)

        return cls(
            num_special_params,
            init_order,
            num_fitting_params,
            tuple(fitting_fields),
            field_descriptors,
            t_def_res,
            lower_bounds,
            upper_bounds,
            bounds,
            dep_resolvers_seq
        )

316 

317 

if _sys.version_info >= (3, 9):
    # Newer interpreters additionally accept a parametrized numpy dtype
    # (presumably restricted because `numpy.dtype[...]` is not
    # subscriptable at runtime before 3.9 -- TODO confirm).
    DTypeLikeFloat = _Union[  # type: ignore [misc]
        _Type[float],
        _Type[_floating],
        _dtype[_floating],
    ]
else:
    DTypeLikeFloat = _Union[  # type: ignore [misc]
        _Type[float],
        _Type[_floating]
    ]

329 

330 

@_dataclass(frozen=True)
class FitSpec(FitSpecBase, _Generic[_T]):
    """The specification for a fit class."""

    __slots__ = ("clss",)

    clss: _Type[_T]  # Note: cannot be named cls in python < 3.9
    """The type for which this instance specifies information."""

    def create_default_fit_instance(self) -> _T:
        """Get the default instance for the fit.

        Every init field is resolved with its `DefaultResolver`. The
        resolvers are evaluated in `init_order`, so that a dependent field
        can read the already resolved value of the field it depends on.

        Returns:
            T: The instance.
        """
        # `default_resolvers` is parallel to the keys of `fitting_fields`
        # (`generate` fills both in field declaration order), whereas
        # `init_order` may be a dependency-sorted permutation of those
        # fields. Pair each resolver with its field by name rather than
        # zipping it positionally with `init_order`.
        resolver_by_field = dict(
            zip(self.fitting_fields, self.default_resolvers)
        )

        params: _Dict[str, _Any] = dict()
        for fname in self.init_order:
            params[fname] = resolver_by_field[fname].get(params)

        return self.clss(**params)

    def new_empty_array(self, dtype: DTypeLikeFloat) -> _NDArray[_floating]:
        """Allocate an uninitialized 1-D array with one entry per fitting
        parameter.

        Args:
            dtype (DTypeLikeFloat): The floating point dtype of the array.

        Returns:
            NDArray[floating]: The new, uninitialized array.
        """
        return _numpy.empty((self.fitting_param_count,), dtype)

    def instance_to_array(
        self,
        instance: _T,
        out: _NDArray[_floating]
    ) -> None:
        """Write the fitting parameter values of `instance` into `out`.

        Args:
            instance (T): The instance to read the parameter values from.
            out (NDArray[floating]): The output array; entry `i` receives
                the value of the field `fitting_params[i]`.

        Raises:
            ValueError: If `out` does not have the shape
                `(fitting_param_count,)`.
        """
        num_params = self.fitting_param_count
        given_shape = out.shape
        expected_shape = (num_params,)
        if given_shape != expected_shape:
            raise ValueError(  # pragma: no cover
                f"The provided array has an invalid shape: {given_shape}. "
                f"Expected: {expected_shape}."
            )

        for i, field in enumerate(self.fitting_params):
            out[i] = getattr(instance, field)

    def array_to_instance(
        self,
        array: _Sequence[float]
    ) -> _T:
        """Create an instance from an array of fitting parameter values.

        Entry `i` of `array` becomes the value of `fitting_params[i]`; the
        dependent fields are then filled in via `dep_resolvers`. Fields that
        are neither fitting parameters nor dependent (e.g. const fields) are
        not passed to the constructor. NOTE(review): this presumably relies
        on those fields having a dataclass default -- confirm.

        Args:
            array (Sequence[float]): The fitting parameter values.

        Raises:
            ValueError: If `len(array)` differs from `fitting_param_count`.

        Returns:
            T: The new instance.
        """
        num_params = self.fitting_param_count
        array_len = len(array)
        if array_len != num_params:
            raise ValueError(  # pragma: no cover
                f"Unexpected number of parameters: {array_len}. "
                f"Expected {num_params}."
            )

        kwargs: _Dict[str, _Any] = dict()
        for i, field in enumerate(self.fitting_params):
            kwargs[field] = array[i]

        # Resolve dependent fields (stored in init order, so sources are
        # resolved before the fields depending on them).
        for resolver in self.dep_resolvers:
            kwargs[resolver.target] = resolver.get(kwargs)

        return self.clss(**kwargs)