Coverage for src/scipy_dataclassfitparams/_spec_registry.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-01 17:34 +0000

1"""This module contains the caching mechanism for the `FitSpec`s.""" 

2 

3import dataclasses as _dataclasses 

4import sys as _sys 

5import weakref as _weakref 

6 

7from dataclasses import dataclass as _dataclass 

8from typing import Generic as _Generic, Type as _Type, TypeVar as _TypeVar 

9 

10from ._fit_spec import FitSpec as _FitSpec, FitSpecBase as _FitSpecBase 

11 

12 

_T = _TypeVar("_T")


@_dataclass(frozen=True)
class _StoredFitSpec(_FitSpecBase, _Generic[_T]):
    """Registry entry carrying the data of a `_FitSpecBase`.

    The class that defines the fit is intentionally not stored in the
    entry; it is supplied again through `pin` whenever a full `FitSpec`
    is needed (presumably so an entry never keeps its key class alive
    in the weak-keyed registry — confirm).
    """

    def pin(self, cls: _Type[_T]) -> "_FitSpec[_T]":
        """Combine this entry with the class defining the fit.

        Args:
            cls (Type[T]): The class defining the fit.

        Returns:
            FitSpec[T]: A new `FitSpec` whose `clss` is `cls` and whose
                remaining fields are copied from this entry's
                `_FitSpecBase` fields.
        """
        # Only the fields declared on the base class are forwarded;
        # `clss` is the one extra argument `FitSpec` receives.
        base_values = {
            fld.name: getattr(self, fld.name)
            for fld in _dataclasses.fields(_FitSpecBase)
        }
        return _FitSpec(clss=cls, **base_values)

36 

37 

# Weak keys: once a class defining a fit is garbage collected, its
# cached entry is dropped from the registry automatically.
if _sys.version_info >= (3, 9):
    _reg = _weakref.WeakKeyDictionary[  # type: ignore [assignment]
        type, _StoredFitSpec
    ]()
else:
    # `WeakKeyDictionary` is not subscriptable at runtime before 3.9,
    # so the key/value types are declared via a separate annotation.
    from typing import Dict as _Dict

    _reg: _Dict[type, _StoredFitSpec]
    _reg = _weakref.WeakKeyDictionary()  # type: ignore [assignment]


#: The registry mapping the classes defining fits to the corresponding
#: `_StoredFitSpec`.
_REGISTRY = _reg

52 

53 

def get_fit_spec(t: _Type[_T]) -> _FitSpec[_T]:
    """Return the `FitSpec` describing the fit defined by `t`.

    The class-independent part of the spec is generated on first use
    and cached in `_REGISTRY`; later calls reuse the cached entry.

    Args:
        t (Type[T]): The class for which to get the fit spec.

    Returns:
        FitSpec[T]: The additional information for the dataclass
            defining the fit, pinned to `t`.
    """
    stored = _REGISTRY.get(t)
    if stored is None:
        # Cache miss: build the entry once and remember it.
        stored = _StoredFitSpec.generate(t)
        _REGISTRY[t] = stored
    return stored.pin(t)

71 

72 

def clear_cache() -> int:  # pragma: no cover
    """Drop every cached fit specification from the registry.

    Running a garbage collection afterwards is recommended.

    Returns:
        int: How many cached specs were released.
    """
    released = len(_REGISTRY)
    _REGISTRY.clear()
    return released