Coverage for src/scipy_dataclassfitparams/_depgraph.py: 100%

72 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-01 17:34 +0000

1"""This module contains the `DepGraph` class which is a helper class for 

2keeping track of the dependencies between fitting fields. 

3""" 

4 

5from scipy import sparse as _sparse # type: ignore [import] 

6from scipy.sparse import csgraph as _csgraph # type: ignore [import] 

7 

8from dataclasses import dataclass as _dataclass 

9from typing import Iterable as _Iterable, List as _List, Mapping as _Mapping, \ 

10 Sequence as _Sequence, Set as _Set 

11 

12 

# eq=False: the graph is mutated in place by `add_dependency`, so value
# semantics make no sense here. The dataclass-generated __hash__ would also
# fail on the unhashable `_key_map` dict, and the generated __eq__ would take
# the truth value of an element-wise sparse-matrix comparison, which raises
# for graphs with more than one field. Instances therefore compare and hash
# by identity.
@_dataclass(order=False, frozen=True, eq=False)
class DepGraph:
    """Directed dependency graph over a fixed, ordered set of field names.

    Calling ``add_dependency(which, source)`` records that ``which``
    depends on ``source``, i.e. ``source`` has to be initialized before
    ``which``. Internally the edge is stored as ``_graph[w, s] == -1``,
    where ``w`` and ``s`` are the indices of the two fields in `fields`.
    """

    __slots__ = ("N", "_graph", "fields", "_key_map", "_num_deps")

    N: int
    """The number of elements in the graph."""

    _graph: _sparse.lil_matrix
    """The actual graph used for computations stored as a sparse
    matrix."""

    fields: _Sequence[str]
    """The fields."""

    _key_map: _Mapping[str, int]
    """The mapping from the field names to their index in the graph
    matrix."""

    _num_deps: int
    """The number of dependencies."""

    # It is not possible to use a __post_init__ here since _graph would
    # have to be init=False. But assigning to the field means it is a
    # class variable with the same name as a __slots__ variable, which
    # is not allowed.
    def __init__(self, fields: _Iterable[str]):
        """Construct a new `DepGraph` without any dependencies.

        Args:
            fields (Iterable[str]): The fields for which to track
                dependencies. The iteration order defines the field
                indices and the tie-break order of `get_init_order`.
        """
        fields = tuple(fields)
        N = len(fields)
        # The dataclass is frozen, so attributes have to be set through
        # object.__setattr__.
        object.__setattr__(self, 'N', N)
        g = _sparse.lil_matrix((N, N), dtype=int)
        object.__setattr__(self, "_graph", g)
        km = {s: i for i, s in enumerate(fields)}
        object.__setattr__(self, "_key_map", km)
        object.__setattr__(self, "fields", fields)
        object.__setattr__(self, "_num_deps", 0)

    def add_dependency(self, which: str, source: str) -> bool:
        """Add a dependency to the graph.

        Args:
            which (str): The name of the field that "depends".
            source (str): The name of the field on which it depends.

        Raises:
            KeyError: If `which` or `source` is not one of the fields.
            ValueError: If `which` and `source` are the same field, or if
                a dependency in the opposite direction was already set.

        Returns:
            bool: True, if the dependency was added and False, if it
                was already present.
        """
        if which == source:
            raise ValueError(f"The field {which!r} cannot depend on itself.")
        km = self._key_map
        wi = km[which]
        si = km[source]
        g = self._graph
        # Indexing the sparse matrix yields a NumPy scalar; cast so the
        # method really returns a plain bool as annotated.
        change = bool(g[wi, si] != -1)
        if change:
            # Only direct (two-field) cycles are rejected eagerly here;
            # longer cycles are reported by `has_closed_cycles`.
            if g[si, wi] != 0:
                raise ValueError(
                    f"The dependence of {which!r} on {source!r} is "
                    "invalid since this creates a direct circular "
                    f"dependence ({source!r} already depends directly "
                    f"on {which!r})."
                )
            g[wi, si] = -1
            object.__setattr__(self, "_num_deps", 1 + self._num_deps)
        return change

    @property
    def dependency_count(self) -> int:
        """The number of dependencies ("edges" in the graph)."""
        return self._num_deps

    def _topological_order(self) -> _List[int]:
        """Order field indices so that every field follows its dependencies.

        This is Kahn's algorithm with a min-heap. Among all fields whose
        dependencies are already resolved, the one with the smallest index
        (i.e. appearing first in `fields`) is always taken next. This makes
        the result deterministic.

        Returns:
            List[int]: The resolved field indices. If the graph contains a
                closed cycle, the fields on that cycle (and all fields
                depending on them) are never resolved. The list is then
                shorter than `N`.
        """
        import heapq

        N = self.N
        # pending[i]: number of still unresolved fields that field i
        # depends on.
        pending = [0] * N
        # dependents[s]: indices of the fields that depend on field s.
        dependents: _List[_List[int]] = [[] for _ in range(N)]
        rows, cols = self._graph.nonzero()
        for w, s in zip(rows.tolist(), cols.tolist()):
            pending[w] += 1
            dependents[s].append(w)

        # Built in ascending order, so it already satisfies the heap
        # invariant.
        ready = [i for i in range(N) if pending[i] == 0]
        order: _List[int] = []
        while ready:
            idx = heapq.heappop(ready)
            order.append(idx)
            for w in dependents[idx]:
                pending[w] -= 1
                if pending[w] == 0:
                    heapq.heappush(ready, w)
        return order

    def has_closed_cycles(self) -> bool:
        """Whether the dependency graph contains closed cycles,
        indicating circular dependencies.

        Returns:
            bool: True, if the dependency graph contains closed loops,
                otherwise False.
        """
        # A topological order covers every field iff the graph is acyclic.
        return len(self._topological_order()) != self.N

    def get_init_order(self) -> _Sequence[str]:
        """Construct the order in which arguments must be initialized in
        order to be compatible with the dependencies defined by the
        graph.

        The order is not necessarily unique. Whenever several fields could
        come next, the one appearing first in `fields` is chosen. Fields
        without any dependency relation therefore keep the order of the
        `fields` __init__ parameter.

        Raises:
            ValueError: If the dependency graph contains closed cycles,
                indicating circular dependencies.

        Returns:
            Sequence[str]: The order in which the fields have to be
                initialized so that their dependencies can be satisfied.
        """
        order = self._topological_order()
        if len(order) != self.N:
            raise ValueError("The dependency graph contained closed loops.")

        fields = self.fields
        return tuple(fields[i] for i in order)

80 if g[si, wi] != 0: 

81 raise ValueError( 

82 f"The dependence of {which!r} on {source!r} is " 

83 "invalid since this creates a direct circular " 

84 f"dependence ({source!r} already depends directly " 

85 f"on {which!r})." 

86 ) 

87 g[wi, si] = -1 

88 object.__setattr__(self, "_num_deps", 1 + self._num_deps) 

89 return change 

90 

91 @property 

92 def dependency_count(self) -> int: 

93 """The number of dependencies ("edges" in the graph).""" 

94 return self._num_deps 

95 

96 def has_closed_cycles(self) -> bool: 

97 """Whether the dependency graph contains closed cycles, 

98 indicating circular dependencies. 

99 

100 Returns: 

101 bool: True, if the dependency graph contains closed loops, 

102 otherwise False. 

103 """ 

104 try: 

105 _csgraph.floyd_warshall(self._graph, directed=True) 

106 return False 

107 except _csgraph.NegativeCycleError: 

108 return True 

109 

110 def get_init_order(self) -> _Sequence[str]: 

111 """Construct the order in which arguments must be initialized in 

112 order to be compatible with the dependencies defined by the 

113 graph. Note that this order is not necessarily unique, but for 

114 fields that may be in arbitrary order, the order will follow 

115 the `fields` __init__ parameter/field. 

116 

117 Raises: 

118 ValueError: If the dependency graph contains closed cycles, 

119 indicating circular dependencies. 

120 

121 Returns: 

122 Sequence[str]: The order in which the fields have to be 

123 initialized so that their dependencies can be satisfied. 

124 """ 

125 if self.has_closed_cycles(): 

126 raise ValueError("The dependency graph contained closed loops.") 

127 

128 g = self._graph 

129 

130 idx_in_tree: _Set[int] = set(range(self.N)) 

131 idx_resolved: _List[int] = list() 

132 

133 # Keep track of how many fields depend on a certain field given 

134 # by the index. We are working with opposite signs here. 

135 nnum_deps = g.sum(axis=1) 

136 

137 while len(idx_in_tree) > 0: 

138 

139 # coverage.py is looking for a branch back to the while loop 

140 # above, which is of course impossible. 

141 for idx in idx_in_tree: # pragma: no branch 

142 if nnum_deps[idx, 0] != 0: 

143 continue 

144 # There are no futher fields in the tree depending 

145 # on idx. 

146 

147 # Decrement dependencies 

148 for i in idx_in_tree: 

149 if g[i, idx] == 0: 

150 continue 

151 # We are working with opposite signs 

152 nnum_deps[i, 0] += 1 

153 

154 # Removeindex from tree 

155 idx_in_tree.remove(idx) 

156 idx_resolved.append(idx) 

157 

158 # The tree has been modified 

159 break 

160 

161 fields = self.fields 

162 field_names = (fields[i] for i in idx_resolved) 

163 return tuple(field_names)