Coverage for src/scipy_dataclassfitparams/_fit_spec.py: 100%
185 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
1import dataclasses as _dataclasses
2import numpy as _numpy
3import sys as _sys
5from abc import ABC as _ABC, abstractmethod as _abstractmethod
6from dataclasses import dataclass as _dataclass
7from numpy import dtype as _dtype, floating as _floating
8from numpy.typing import NDArray as _NDArray
9from typing import Any as _Any, Dict as _Dict, final as _final, \
10 Generic as _Generic, List as _List, Mapping as _Mapping, \
11 Sequence as _Sequence, Tuple as _Tuple, Type as _Type, \
12 TypeVar as _TypeVar
13from typing import Union as _Union
15from . import _fields
16from . import _depgraph
18from ._fields import FittingField as _FittingField
# Type variable standing for the dataclass type a fit spec describes
# (used by `FitSpecBase.generate` and `FitSpec`).
_T = _TypeVar("_T")
24def _is_dataclass(cls: _Type) -> bool:
25 """Checks whether a class is a dataclass. This wraps
26 `dataclasses.is_dataclass` and removes the TypeGuard since this
27 currently causes issues with mypy.
28 """
29 return _dataclasses.is_dataclass(cls)
BoundsTuple = _Tuple[_NDArray[_floating], _NDArray[_floating]]
"""The type of a tuple defining the fit parameter bounds as
``(lower, upper)``."""
@_dataclass(frozen=True)
class DependentFieldValueResolver(_ABC):
    """Abstract base for objects that compute the value of a field which
    depends on the values of other fields.
    """

    __slots__ = ("target",)

    target: str
    """Name of the field whose value this resolver produces."""

    @_abstractmethod
    def get(self, others: _Dict[str, _Any]) -> _Any:
        """Compute the value of the `target` field.

        Args:
            others (Dict[str, Any]): Field values resolved so far, keyed
                by field name.
        """
        ...
class DefaultResolver(_ABC):
    """Abstract base for objects that supply the default value of a
    special field.
    """

    @_abstractmethod
    def get(self, others: _Dict[str, _Any]) -> _Any:
        """Return the default value.

        Args:
            others (Dict[str, Any]): Values of the fields resolved so far,
                keyed by field name.
        """
        ...
@_final
class UnsetDefaultResolver(DefaultResolver):
    """`DefaultResolver` for fields that have no default set and for
    which no default can otherwise be determined; it always yields
    ``0.0``.
    """

    def get(self, _: _Dict[str, _Any]) -> _Any:
        # Nothing to derive a default from, so fall back to zero.
        fallback = 0.0
        return fallback
@_final
@_dataclass(frozen=True)
class FixedDefaultResolver(DefaultResolver):
    """`DefaultResolver` that always returns the same stored value."""

    __slots__ = ("value",)

    value: _Any
    """The constant returned by `get`."""

    def get(self, _: _Dict[str, _Any]) -> _Any:
        fixed = self.value
        return fixed
@_final
@_dataclass(frozen=True)
class DependentResolver(DependentFieldValueResolver, DefaultResolver):
    """Resolver for a field whose value is identical to the value of
    another field.

    Acts both as the `DefaultResolver` and as the
    `DependentFieldValueResolver` of the depending field `target`.
    """

    __slots__ = ("name",)

    name: str
    """Name of the field whose value is copied into `target`."""

    def get(self, others: _Dict[str, _Any]) -> _Any:
        source_name = self.name
        return others[source_name]
@_dataclass(frozen=True)
class FitSpecBase:
    """The base class for a class containing info about a class defining
    a fit.
    """

    __slots__ = (
        "special_param_count",
        "init_order",
        "fitting_param_count",
        "fitting_params",
        "fitting_fields",
        "default_resolvers",
        "lower_bounds",
        "upper_bounds",
        "bounds",
        "dep_resolvers",
    )

    special_param_count: int
    """The number of fitting parameters, including dependent and const
    parameters."""

    init_order: _Sequence[str]
    """The order in which the special parameters will be initialized."""

    fitting_param_count: int
    """The number of actual fitting parameters."""

    fitting_params: _Sequence[str]
    """The actual parameters that are part of the fitting procedure, in
    the order of the fit parameter array."""

    fitting_fields: _Mapping[str, _Union[_FittingField, None]]
    """A dict mapping field names to the descriptors for the fitting
    fields or None, if it is a regular field."""

    default_resolvers: _Sequence[DefaultResolver]
    """The resolvers for the default values of the fields, aligned
    element-wise with `init_order`."""

    lower_bounds: _Union[_NDArray[_floating], None]
    """The lower bounds (one per fitting parameter) or None, if they do
    not need to be set."""

    upper_bounds: _Union[_NDArray[_floating], None]
    """The upper bounds (one per fitting parameter) or None, if they do
    not need to be set."""

    bounds: BoundsTuple
    """The bounds tuple ``(lower, upper)``. A side without any explicit
    bound is a read-only 0-d array holding -inf or +inf respectively."""

    dep_resolvers: _Sequence[DependentResolver]
    """The resolvers for the dependent fields, in initialization order."""

    @classmethod
    def _process_bounds(
        cls,
        set_bounds: _Dict[int, float],
        num_params: int,
        default: float
    ) -> _Union[_NDArray[_floating], None]:
        """Build a read-only bounds array from the explicitly set bounds.

        Args:
            set_bounds (Dict[int, float]): Explicit bounds keyed by index
                into the fit parameter array.
            num_params (int): The number of fitting parameters.
            default (float): The value used for parameters without an
                explicit bound.

        Returns:
            NDArray[floating] or None: The bounds array of shape
            ``(num_params,)``, or None if no bound was set at all.
        """
        if not set_bounds:
            return None
        result = _numpy.full((num_params,), default)
        for k, v in set_bounds.items():
            result[k] = v
        # Read-only, matching the frozen spec that stores the array.
        result.flags.writeable = False
        return result

    @classmethod
    def generate(cls, t: _Type[_T]):
        """Generate the `FitSpecBase` or a derived class for a type.

        Args:
            t (Type[T]): The type/dataclass for which to generate the
                `FitSpecBase`.

        Raises:
            TypeError: If the provided type `t` is not a dataclass.
            NotImplementedError: If an unknown fitting field description
                (FittingField subclass) is encountered.
            ValueError: If the dependencies between the fields are
                circular, so that some fields cannot be initialized.

        Returns:
            FitSpecBase or subclass: The generated `FitSpecBase`.
        """
        if not _is_dataclass(t):
            raise TypeError(  # pragma: no cover
                f"The type {t.__qualname__!r} is not a dataclass."
            )

        all_fields_ordered = _dataclasses.fields(t)  # type: ignore [arg-type]

        depgraph = _depgraph.DepGraph((f.name for f in all_fields_ordered))

        field_descriptors: _Dict[str, _Union[_FittingField, None]] = dict()

        # Explicit bounds keyed by the index into `fitting_fields`, i.e.
        # the position of the parameter in the fit parameter array.
        set_lower_bounds: _Dict[int, float] = dict()
        set_upper_bounds: _Dict[int, float] = dict()
        # Default resolvers in field declaration order (same order as
        # `special_fields`).
        default_resolvers: _List[DefaultResolver] = list()

        special_fields: _List[str] = list()
        fitting_fields: _List[str] = list()

        dep_resolver_map: _Dict[str, DependentResolver] = dict()

        init_fields = (f for f in all_fields_ordered if f.init)
        for field in init_fields:
            field_name = field.name
            special_fields.append(field_name)

            info = _fields.get_special_fitting_field(field)
            field_descriptors[field_name] = info

            # Index this field occupies in the fit parameter array if it
            # turns out to be a fitting parameter. Const and dependent
            # fields are not part of that array, so the position among
            # all init fields cannot be used to index the bounds.
            fit_index = len(fitting_fields)

            def_res: DefaultResolver
            if (info is None) or isinstance(info, _fields.RegularField):
                field_default = field.default
                if field_default is _dataclasses.MISSING:
                    def_res = UnsetDefaultResolver()
                else:
                    def_res = FixedDefaultResolver(field_default)
                fitting_fields.append(field_name)
            elif isinstance(info, _fields.BoundedField):
                if info.min_finite:
                    v_min: _Any = info.min
                    set_lower_bounds[fit_index] = v_min
                if info.max_finite:
                    v_max: _Any = info.max
                    set_upper_bounds[fit_index] = v_max
                df_val = info.resolve_default(field.default)
                def_res = FixedDefaultResolver(df_val)
                fitting_fields.append(field_name)
            elif isinstance(info, _fields.ConstField):
                def_res = FixedDefaultResolver(info.value)
            elif isinstance(info, _fields.SameAsField):
                dep_name = info.name
                # `field_name` is the depending field, `dep_name` the one
                # it depends on.
                depgraph.add_dependency(field_name, dep_name)
                def_res = DependentResolver(field_name, dep_name)
                dep_resolver_map[field_name] = def_res
            else:
                raise NotImplementedError(
                    f"Unknown FittingField encountered: {info!r}"
                )
            default_resolvers.append(def_res)

        num_special_params = len(special_fields)

        if len(default_resolvers) != num_special_params:
            raise AssertionError

        dep_resolvers_seq: _Tuple[DependentResolver, ...]
        if depgraph.dependency_count > 0:
            if depgraph.has_closed_cycles():
                raise ValueError(
                    "There are circular dependencies in the field "
                    "dependencies meaning that some fields cannot be "
                    "initialized."
                )
            ssf = set(special_fields)
            # Restrict the graph's order to the special (init) fields.
            init_order = tuple(
                f for f in depgraph.get_init_order() if f in ssf
            )
            if len(init_order) != num_special_params:
                raise AssertionError
            # Dependent-field resolvers in construction order.
            dep_resolvers_seq = tuple(
                dep_resolver_map[f]
                for f in init_order
                if f in dep_resolver_map
            )
        else:
            init_order = tuple(special_fields)
            dep_resolvers_seq = ()

        # `FitSpec.create_default_fit_instance` pairs `init_order` with
        # `default_resolvers` element-wise, so the resolvers must be stored
        # in initialization order rather than declaration order. Otherwise
        # a field would be resolved with another field's resolver whenever
        # the dependency graph reorders the fields.
        resolver_by_name = dict(zip(special_fields, default_resolvers))
        t_def_res = tuple(resolver_by_name[f] for f in init_order)

        num_fitting_params = len(fitting_fields)

        lower_bounds = cls._process_bounds(
            set_lower_bounds, num_fitting_params, -_numpy.inf
        )
        upper_bounds = cls._process_bounds(
            set_upper_bounds, num_fitting_params, _numpy.inf
        )

        # The bounds tuple always has both sides; a side without explicit
        # bounds becomes a read-only 0-d array holding that infinity.
        lb = lower_bounds
        if lb is None:
            lb = _numpy.array(-_numpy.inf)
            lb.flags.writeable = False
        ub = upper_bounds
        if ub is None:
            ub = _numpy.array(_numpy.inf)
            ub.flags.writeable = False
        bounds = (lb, ub)

        return cls(
            num_special_params,
            init_order,
            num_fitting_params,
            tuple(fitting_fields),
            field_descriptors,
            t_def_res,
            lower_bounds,
            upper_bounds,
            bounds,
            dep_resolvers_seq
        )
# Float dtype specifications accepted by `FitSpec.new_empty_array`.
# NOTE(review): the `numpy.dtype[...]` member is only included on
# Python >= 3.9, presumably because subscripting `numpy.dtype` is not
# supported at runtime on older interpreters -- confirm.
if _sys.version_info < (3, 9):
    DTypeLikeFloat = _Union[  # type: ignore [misc]
        _Type[float],
        _Type[_floating]
    ]
else:
    DTypeLikeFloat = _Union[  # type: ignore [misc]
        _Type[float],
        _Type[_floating],
        _dtype[_floating],
    ]
@_dataclass(frozen=True)
class FitSpec(FitSpecBase, _Generic[_T]):
    """The specification for a fit class."""

    __slots__ = ("clss",)

    clss: _Type[_T]  # Note: cannot be named cls in python < 3.9
    """The type for which this instance specifies information."""

    def create_default_fit_instance(self) -> _T:
        """Get the default instance for the fit.

        Fields are resolved in `init_order`. The resolver at position ``k``
        of `default_resolvers` is used for the field at position ``k`` of
        `init_order`, so the two sequences must be aligned. Each resolver
        receives the values resolved so far.

        Returns:
            T: The instance.
        """
        params: _Dict[str, _Any] = dict()
        for fname, resolver in zip(self.init_order, self.default_resolvers):
            params[fname] = resolver.get(params)
        return self.clss(**params)

    def new_empty_array(self, dtype: DTypeLikeFloat) -> _NDArray[_floating]:
        """Allocate an uninitialized array with one entry per fitting
        parameter.

        Args:
            dtype (DTypeLikeFloat): The dtype of the array.

        Returns:
            NDArray[floating]: An array of shape ``(fitting_param_count,)``
            with uninitialized contents.
        """
        return _numpy.empty((self.fitting_param_count,), dtype)

    def instance_to_array(
        self,
        instance: _T,
        out: _NDArray[_floating]
    ):
        """Write the fitting parameter values of `instance` into `out`.

        Entry ``i`` of `out` receives the value of the attribute named
        ``fitting_params[i]``.

        Args:
            instance (T): The instance to read the values from.
            out (NDArray[floating]): The output array; must have shape
                ``(fitting_param_count,)``.

        Raises:
            ValueError: If `out` does not have the expected shape.
        """
        expected_shape = (self.fitting_param_count,)
        given_shape = out.shape
        if given_shape != expected_shape:
            raise ValueError(  # pragma: no cover
                f"The provided array has an invalid shape: {given_shape}. "
                f"Expected: {expected_shape}."
            )

        for i, field in enumerate(self.fitting_params):
            out[i] = getattr(instance, field)

    def array_to_instance(
        self,
        array: _Sequence[float]
    ) -> _T:
        """Create an instance of `clss` from fitting parameter values.

        Args:
            array (Sequence[float]): The values of the fitting parameters,
                in the order of `fitting_params`.

        Raises:
            ValueError: If `array` does not contain exactly
                `fitting_param_count` values.

        Returns:
            T: The instance.
        """
        num_params = self.fitting_param_count
        array_len = len(array)
        if array_len != num_params:
            raise ValueError(  # pragma: no cover
                f"Unexpected number of parameters: {array_len}. "
                f"Expected {num_params}."
            )

        kwargs: _Dict[str, _Any] = dict(zip(self.fitting_params, array))

        # Resolve dependent fields in construction order, so a dependent
        # field can read values resolved before it.
        for resolver in self.dep_resolvers:
            kwargs[resolver.target] = resolver.get(kwargs)

        # NOTE(review): only fitting parameters and dependent fields are
        # passed to the constructor. Other fields (e.g. const fields) must
        # therefore have a dataclass default. A dependent field whose
        # source is such a field would make its resolver raise KeyError.
        # Confirm this is intended.
        return self.clss(**kwargs)