Coverage for src/scipy_dataclassfitparams/_fields.py: 100%
135 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
"""This module defines the functions that can be used for the field
definitions and the corresponding classes that encapsulate the dataclass
metadata values.
"""
6import cmath as _cmath
7import dataclasses as _dataclasses
8import sys as _sys
9import warnings as _warnings
11from dataclasses import dataclass as _dataclass, Field as _Field
12from typing import Any as _Any, Callable as _Callable, ClassVar as _ClassVar, \
13 Mapping as _Mapping, Optional as _Optional, Union as _Union
15from . import _warning_types
_METADATA_KEY = object()
"""Unique sentinel object under which the fitting description of a
field is stored in the field's metadata mapping."""
@_dataclass(frozen=True)
class FittingField:
    """Base class for the field metadata values.

    Instances of the subclasses are attached to dataclass fields
    through the fields' metadata mapping.
    """

    # Declare empty slots so that the `__slots__` declared by the
    # subclasses actually take effect: instances only lose their
    # per-instance `__dict__` if every class in the inheritance chain
    # defines `__slots__`.
    __slots__ = ()
def get_special_fitting_field(f: _Field) -> _Union[FittingField, None]:
    """Look up the `FittingField` stored in the metadata of a dataclass
    `Field`.

    Args:
        f (Field): The dataclass field to inspect.

    Raises:
        AssertionError: If the metadata contains the `_METADATA_KEY`
            but the associated value is not a `FittingField`
            instance.

    Returns:
        FittingField | None: The stored `FittingField`, or None if the
            field's metadata has no entry for `_METADATA_KEY`.
    """
    metadata = f.metadata
    if _METADATA_KEY not in metadata:
        return None
    candidate = metadata[_METADATA_KEY]
    if isinstance(candidate, FittingField):
        return candidate
    raise AssertionError(
        f"Found invalid object {candidate!r} where a FittingField was "
        f"expected. Dataclass field: {f.name!r}"
    ) from TypeError
_MISSING = _dataclasses.MISSING
"""Module-local alias of the `dataclasses.MISSING` sentinel."""


_MissingType = type(_MISSING)
"""The runtime type of `dataclasses.MISSING` (and thus of `_MISSING`)."""


_FloatOrMissing = _Union[float, _MissingType]  # type: ignore [valid-type]
"""Type alias for a value that is a float or `dataclasses.MISSING`."""
@_dataclass(frozen=True)
class BoundedField(FittingField):
    """Metadata value for a fit parameter limited to the closed range
    between `min` and `max`.
    """

    __slots__ = ("min", "max")

    min: float
    """The lower bound; negative infinity if no lower bound was set."""

    max: float
    """The upper bound; positive infinity if no upper bound was set."""

    POS_INF: _ClassVar[float] = float("inf")
    """Positive infinity, used in place of a missing upper bound."""

    NEG_INF: _ClassVar[float] = float("-inf")
    """Negative infinity, used in place of a missing lower bound."""

    def __init__(self, min: _Optional[float], max: _Optional[float]):
        """Create the description of the bounds.

        Args:
            min (float | None): The lower bound or None for no lower
                bound.
            max (float | None): The upper bound or None for no upper
                bound.

        Raises:
            ValueError: If `min` is greater than `max`.
        """
        lower = self.NEG_INF if min is None else min
        upper = self.POS_INF if max is None else max
        if lower > upper:
            raise ValueError(
                "The provided max value was smaller than the provided "
                "min value."
            )
        # The dataclass is frozen, so plain attribute assignment is
        # blocked; the values are written via `object.__setattr__`.
        object.__setattr__(self, "min", lower)
        object.__setattr__(self, "max", upper)

    @property
    def actually_bounded(self) -> bool:
        """Whether at least one of the two bounds is finite, i.e.
        whether this instance restricts the fit parameter at all.
        """
        if self.min_finite:
            return True
        return self.max_finite

    @property
    def min_finite(self) -> bool:
        """Whether the lower bound is a finite number.

        Returns:
            bool: True if `min` is finite.
        """
        return _cmath.isfinite(self.min)

    @property
    def max_finite(self) -> bool:
        """Whether the upper bound is a finite number.

        Returns:
            bool: True if `max` is finite.
        """
        return _cmath.isfinite(self.max)

    def resolve_default(self, field_default: _FloatOrMissing) -> float:
        """Determine the default value of a field described by this
        `BoundedField` instance.

        Args:
            field_default (float or MISSING): The default value given
                to the dataclass `Field`, or `dataclasses.MISSING` if
                there is none.

        Returns:
            float: `field_default` if it was given and lies within the
                bounds of this instance; the result of
                `get_bounds_default()` in every other case.
        """
        if field_default is _dataclasses.MISSING:
            return self.get_bounds_default()
        if self.contains(field_default):  # type: ignore
            return field_default  # type: ignore
        return self.get_bounds_default()

    def get_bounds_default(self) -> float:
        """Derive a default value from the bounds alone.

        Returns:
            float: 0 if neither bound is finite, the mean of both
                bounds if both are finite, and otherwise the value at
                a distance of 1 from the single finite bound, on the
                allowed side of it.
        """
        lower_finite = self.min_finite
        upper_finite = self.max_finite
        if not (lower_finite or upper_finite):
            return 0.0
        if lower_finite and upper_finite:
            return 0.5 * (self.min + self.max)
        if lower_finite:
            return self.min + 1
        return self.max - 1

    def contains(self, value: float) -> bool:
        """Test whether a value lies within the bounds of this
        `BoundedField` instance.

        Args:
            value (float): The value to test.

        Returns:
            bool: Whether `value` is not below a finite lower bound
                and not above a finite upper bound. Non-finite bounds
                are not checked.
        """
        above_lower = value >= self.min if self.min_finite else True
        below_upper = value <= self.max if self.max_finite else True
        return above_lower & below_upper
@_dataclass(frozen=True)
class ConstField(FittingField):
    """Metadata value for a field that holds a constant value."""

    __slots__ = ("value",)

    value: float
    """The constant value assigned to the field."""
@_dataclass(frozen=True)
class SameAsField(FittingField):
    """Metadata value for a field that depends on another field."""

    __slots__ = ("name",)

    name: str
    """The name of the field this field depends on."""
@_dataclass(frozen=True)
class RegularField(FittingField):
    """Metadata value without any attributes of its own; a single
    cached instance is handed out by `get_instance()`.
    """

    _instance: _ClassVar[_Optional["RegularField"]] = None
    """The cached instance, or None while none has been created."""

    @classmethod
    def get_instance(cls) -> "RegularField":  # pragma: no cover
        """Return the cached instance, creating it on first use.

        Returns:
            RegularField: The (possibly cached) instance of the
                `RegularField` class.
        """
        cached = cls._instance
        if cached is not None:
            return cached
        cached = cls._instance = cls()
        return cached
_BoolOrMissing = _Union[bool, _MissingType]  # type: ignore [valid-type]
"""Type alias for a value that is a `bool` or `dataclasses.MISSING`."""


_FloatDefaultFactoryOrMissing = _Union[
    _Callable[[], float],
    _MissingType  # type: ignore [valid-type]
]
"""Type alias for a dataclass default factory returning a float, or
`dataclasses.MISSING`."""
def _merge_medatada(
    field_metadata_instance: FittingField,
    provided: _Optional[_Mapping[_Any, _Any]]
) -> _Mapping[_Any, _Any]:
    """Build the metadata mapping for a field from the user-provided
    metadata and the fitting description of the field.

    Args:
        field_metadata_instance (FittingField): The fitting
            description to store under `_METADATA_KEY`.
        provided (Mapping[Any, Any] | None): The user-provided metadata
            or None. The mapping itself is not modified.

    Returns:
        Mapping[Any, Any]: A new dict holding the provided entries and
            the fitting description. A provided entry under
            `_METADATA_KEY` is replaced after a
            `MetadataKeyOverwrittenWarning` was issued.
    """
    merged: dict = {}
    if provided is not None:  # pragma: no cover
        if _METADATA_KEY in provided:
            # This function is called directly by the public field
            # functions, so a stacklevel of 3 attributes the warning
            # to the code calling those.
            _warnings.warn(
                "The metadata key for the fit field definitions was "
                "already contained in the provided 'metadata' mapping. "
                "The corresponding value will be overwritten.",
                stacklevel=3,
                category=_warning_types.MetadataKeyOverwrittenWarning
            )
        merged.update(provided)
    merged[_METADATA_KEY] = field_metadata_instance
    return merged
if _sys.version_info >= (3, 10):

    # `dataclasses.field` accepts `kw_only` starting with Python 3.10,
    # so the parameter is only offered (and forwarded) in this branch.

    def bounded(
        min: _Optional[float] = None,
        max: _Optional[float] = None,
        *,
        default: _FloatOrMissing = _MISSING,
        default_factory: _FloatDefaultFactoryOrMissing = _MISSING,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None,
        kw_only: _BoolOrMissing = _MISSING
    ) -> _Any:
        """Define a fit parameter as bounded. Either one, two or no
        boundaries may be provided. The remaining unset boundaries are
        interpreted to be -inf for the lower and +inf for the upper
        boundary.

        Args:
            min (float, optional): The lower bound. Defaults to -inf.
            max (float, optional): The upper bound. Defaults to +inf.
            default (float, optional): The default (initial) value for
                the field. If omitted, no dataclass default is set;
                `BoundedField.resolve_default` then yields a value
                derived from the bounds.
            default_factory (Callable[[], float], optional): A function
                to be called to produce the field's default value.
                Must not be combined with `default`.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.
            kw_only (bool, optional): Whether the field will become a
                keyword-only parameter to __init__(). If omitted, the
                `kw_only` setting of the dataclass applies.

        Raises:
            ValueError: If `min` is greater than `max` or if both
                `default` and `default_factory` are given.

        Returns:
            The dataclass field.
        """
        m = BoundedField(min, max)
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=default,
            default_factory=default_factory,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata,
            kw_only=kw_only
        )

    def const(
        value: float,
        *,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None,
        kw_only: _BoolOrMissing = _MISSING
    ) -> float:
        """Define a fit parameter as constant. The parameter is fixed
        and will not be fitted against. The value will also be set as
        the default value for the field at the constructor.

        Args:
            value (float): The value of the field.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.
            kw_only (bool, optional): Whether the field will become a
                keyword-only parameter to __init__(). If omitted, the
                `kw_only` setting of the dataclass applies.

        Returns:
            The dataclass field.
        """
        m = ConstField(value)
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=value,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata,
            kw_only=kw_only
        )

    def same_as(
        name: str,
        *,
        default: _FloatOrMissing = _MISSING,
        default_factory: _FloatDefaultFactoryOrMissing = _MISSING,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None,
        kw_only: _BoolOrMissing = _MISSING
    ) -> _Any:
        """Define a fit parameter as identical to another parameter. It
        is still possible to set a `default` or `default_factory` for
        the field. This will be used when calling the constructor
        without providing a value for this field.

        Args:
            name (str): The name of the parameter this one is identical
                to.
            default (float, optional): The default value for the field.
                If omitted, no dataclass default is set.
            default_factory (Callable[[], float], optional): A function
                to be called to produce the field's default value.
                Must not be combined with `default`.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.
            kw_only (bool, optional): Whether the field will become a
                keyword-only parameter to __init__(). If omitted, the
                `kw_only` setting of the dataclass applies.

        Raises:
            ValueError: If both `default` and `default_factory` are
                given.

        Returns:
            The dataclass field.
        """
        m = SameAsField(name)
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=default,
            default_factory=default_factory,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata,
            kw_only=kw_only
        )

    def regular(
        *,
        default: _FloatOrMissing = _MISSING,
        default_factory: _FloatDefaultFactoryOrMissing = _MISSING,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None,
        kw_only: _BoolOrMissing = _MISSING
    ) -> _Any:
        """Define a regular fit parameter. Unless a `default` or a
        `default_factory` is given, the value must be provided at
        __init__ when constructing instances of the dataclass.

        Args:
            default (float, optional): The default (initial) value for
                the field. If omitted, no dataclass default is set.
            default_factory (Callable[[], float], optional): A function
                to be called to produce the field's default value.
                Must not be combined with `default`.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.
            kw_only (bool, optional): Whether the field will become a
                keyword-only parameter to __init__(). If omitted, the
                `kw_only` setting of the dataclass applies.

        Raises:
            ValueError: If both `default` and `default_factory` are
                given.

        Returns:
            The dataclass field.
        """
        m = RegularField.get_instance()
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=default,
            default_factory=default_factory,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata,
            kw_only=kw_only
        )

else:

    # The "type: ignore [misc]" markers below prevent mypy from
    # complaining about differing signatures of the conditional
    # function definitions. The functions in this branch do not take
    # a `kw_only` parameter.

    def bounded(  # type: ignore [misc]
        min: _Optional[float] = None,
        max: _Optional[float] = None,
        *,
        default: _FloatOrMissing = _MISSING,
        default_factory: _FloatDefaultFactoryOrMissing = _MISSING,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None
    ) -> _Any:
        """Define a fit parameter as bounded. Either one, two or no
        boundaries may be provided. The remaining unset boundaries are
        interpreted to be -inf for the lower and +inf for the upper
        boundary.

        Args:
            min (float, optional): The lower bound. Defaults to -inf.
            max (float, optional): The upper bound. Defaults to +inf.
            default (float, optional): The default (initial) value for
                the field. If omitted, no dataclass default is set;
                `BoundedField.resolve_default` then yields a value
                derived from the bounds.
            default_factory (Callable[[], float], optional): A function
                to be called to produce the field's default value.
                Must not be combined with `default`.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.

        Raises:
            ValueError: If `min` is greater than `max` or if both
                `default` and `default_factory` are given.

        Returns:
            The dataclass field.
        """
        m = BoundedField(min, max)
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=default,
            default_factory=default_factory,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata
        )

    def const(  # type: ignore [misc]
        value: float,
        *,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None
    ) -> float:
        """Define a fit parameter as constant. The parameter is fixed
        and will not be fitted against. The value will also be set as
        the default value for the field at the constructor.

        Args:
            value (float): The value of the field.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.

        Returns:
            The dataclass field.
        """
        m = ConstField(value)
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=value,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata
        )

    def same_as(  # type: ignore [misc]
        name: str,
        *,
        default: _FloatOrMissing = _MISSING,
        default_factory: _FloatDefaultFactoryOrMissing = _MISSING,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None
    ) -> _Any:
        """Define a fit parameter as identical to another parameter. It
        is still possible to set a `default` or `default_factory` for
        the field. This will be used when calling the constructor
        without providing a value for this field.

        Args:
            name (str): The name of the parameter this one is identical
                to.
            default (float, optional): The default value for the field.
                If omitted, no dataclass default is set.
            default_factory (Callable[[], float], optional): A function
                to be called to produce the field's default value.
                Must not be combined with `default`.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.

        Raises:
            ValueError: If both `default` and `default_factory` are
                given.

        Returns:
            The dataclass field.
        """
        m = SameAsField(name)
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=default,
            default_factory=default_factory,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata
        )

    def regular(  # type: ignore [misc]
        *,
        default: _FloatOrMissing = _MISSING,
        default_factory: _FloatDefaultFactoryOrMissing = _MISSING,
        repr: bool = True,
        hash: _Optional[bool] = None,
        compare: bool = True,
        metadata: _Optional[_Mapping[_Any, _Any]] = None
    ) -> _Any:
        """Define a regular fit parameter. Unless a `default` or a
        `default_factory` is given, the value must be provided at
        __init__ when constructing instances of the dataclass.

        Args:
            default (float, optional): The default (initial) value for
                the field. If omitted, no dataclass default is set.
            default_factory (Callable[[], float], optional): A function
                to be called to produce the field's default value.
                Must not be combined with `default`.
            repr (bool, optional): Whether the field should be included
                in the object's repr(). Defaults to True.
            hash (bool, optional): Whether the field should be included
                in the object's hash(). Defaults to None, in which
                case the value of `compare` is used.
            compare (bool, optional): Whether the field should be used
                in the comparison functions. Defaults to True.
            metadata (Mapping[Any, Any], optional): Metadata to add to
                the field. Defaults to None.

        Raises:
            ValueError: If both `default` and `default_factory` are
                given.

        Returns:
            The dataclass field.
        """
        m = RegularField.get_instance()
        metadata = _merge_medatada(m, metadata)
        return _dataclasses.field(  # type: ignore [call-overload]
            default=default,
            default_factory=default_factory,
            repr=repr,
            hash=hash,
            compare=compare,
            metadata=metadata
        )