Coverage for src/scipy_dataclassfitparams/_spec_registry.py: 100%
28 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
1"""This module contains the caching mechanism for the `FitSpec`s."""
3import dataclasses as _dataclasses
4import sys as _sys
5import weakref as _weakref
7from dataclasses import dataclass as _dataclass
8from typing import Generic as _Generic, Type as _Type, TypeVar as _TypeVar
10from ._fit_spec import FitSpec as _FitSpec, FitSpecBase as _FitSpecBase
_T = _TypeVar("_T")


@_dataclass(frozen=True)
class _StoredFitSpec(_FitSpecBase, _Generic[_T]):
    """Registry entry holding only the `_FitSpecBase` data of a fit.

    The class that defines the fit is intentionally *not* referenced here.
    The registry keeps these entries as values of a weak-keyed mapping whose
    keys are those very classes, and a strong back-reference from a value to
    its key would keep the key alive. The class is supplied again through
    `pin` whenever a complete `FitSpec` is required.
    """

    def pin(self, cls: _Type[_T]) -> "_FitSpec[_T]":
        """Attach the fit-defining class and build a full `FitSpec`.

        Args:
            cls (Type[T]): The dataclass that defines the fit.

        Returns:
            FitSpec[T]: A newly constructed `FitSpec` that carries every
            `_FitSpecBase` field value of this entry, plus ``cls`` as
            its ``clss``.
        """
        # Only the base-class fields are forwarded; `clss` is added separately.
        base_values = {
            base_field.name: getattr(self, base_field.name)
            for base_field in _dataclasses.fields(_FitSpecBase)
        }
        return _FitSpec(clss=cls, **base_values)
# `weakref.WeakKeyDictionary` only supports subscripting at runtime from
# Python 3.9 on (PEP 585). Older interpreters therefore get an unparameterised
# instance, with the element types declared through a separate annotation.
if _sys.version_info >= (3, 9):
    _reg = _weakref.WeakKeyDictionary[  # type: ignore [assignment]
        type, _StoredFitSpec
    ]()
else:
    from typing import Dict as _Dict

    _reg: _Dict[type, _StoredFitSpec]
    _reg = _weakref.WeakKeyDictionary()  # type: ignore [assignment]


_REGISTRY = _reg
"""Weak-keyed cache that maps each class defining a fit to its
`_StoredFitSpec`; an entry disappears once its class is garbage collected."""
def get_fit_spec(t: _Type[_T]) -> _FitSpec[_T]:
    """Return the fit spec for a class that defines a fit.

    The class-independent part of the spec is generated on the first request
    for ``t`` and cached in `_REGISTRY`. Every call then pins ``t`` onto the
    cached entry, which constructs a new `FitSpec` each time.

    Args:
        t (Type[T]): The class whose fit spec is requested.

    Returns:
        FitSpec[T]: The additional information for the dataclass
        defining the fit.
    """
    stored = _REGISTRY.get(t)
    if stored is None:
        # Cache miss: derive the spec once and remember it for this class.
        stored = _StoredFitSpec.generate(t)
        _REGISTRY[t] = stored
    return stored.pin(t)
def clear_cache() -> int:  # pragma: no cover
    """Drop every cached fit specification from the registry.

    Running a garbage collection afterwards is recommended.

    Returns:
        int: How many cached specs were held (and released) at the moment
        of the call.
    """
    released_count = len(_REGISTRY)
    _REGISTRY.clear()
    return released_count