Coverage for src/scipy_dataclassfitparams/_depgraph.py: 100%
72 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
1"""This module contains the `DepGraph` class which is a helper class for
2keeping track of the dependencies between fitting fields.
3"""
5from scipy import sparse as _sparse # type: ignore [import]
6from scipy.sparse import csgraph as _csgraph # type: ignore [import]
8from dataclasses import dataclass as _dataclass
9from typing import Iterable as _Iterable, List as _List, Mapping as _Mapping, \
10 Sequence as _Sequence, Set as _Set
# NOTE(review): the dataclass-generated __eq__/__hash__ include the
# lil_matrix `_graph`; comparing two distinct instances or hashing one
# presumably raises (sparse matrices have no scalar truth value / are
# unhashable) -- confirm before relying on them.
@_dataclass(order=False, frozen=True)
class DepGraph:
    """Track directed dependencies between a fixed set of named fields.

    A dependency "``which`` depends on ``source``" is stored as ``-1`` at
    ``_graph[index(which), index(source)]``. The negative weight lets
    `has_closed_cycles` detect cycles as negative cycles.

    ``frozen=True`` only prevents rebinding attributes; `add_dependency`
    still mutates the sparse matrix in place and bumps ``_num_deps`` via
    ``object.__setattr__``.
    """

    __slots__ = ("N", "_graph", "fields", "_key_map", "_num_deps")

    N: int
    """The number of elements in the graph."""

    _graph: _sparse.lil_matrix
    """The actual graph used for computations stored as a sparse
    matrix. Entry ``[i, j] == -1`` means field ``i`` depends on field
    ``j``; every other entry is zero."""

    fields: _Sequence[str]
    """The fields, in the order given to ``__init__``."""

    _key_map: _Mapping[str, int]
    """The mapping from the field names to their index in the graph
    matrix."""

    _num_deps: int
    """The number of dependencies."""

    # It is not possible to use a __post_init__ here since _graph would
    # have to be init=False. But assigning to the field means it is a
    # class variable with the same name as a __slots__ variable, which
    # is not allowed.
    def __init__(self, fields: _Iterable[str]):
        """Construct a new `DepGraph`.

        Args:
            fields (Iterable[str]): The fields for which to track
                dependencies.
        """
        fields = tuple(fields)
        N = len(fields)
        object.__setattr__(self, 'N', N)
        g = _sparse.lil_matrix((N, N), dtype=int)
        object.__setattr__(self, "_graph", g)
        # NOTE(review): duplicate names are not rejected; the last
        # occurrence wins in this mapping -- confirm callers never pass
        # duplicates.
        km = {s: i for i, s in enumerate(fields)}
        object.__setattr__(self, "_key_map", km)
        object.__setattr__(self, "fields", fields)
        object.__setattr__(self, "_num_deps", 0)

    def add_dependency(self, which: str, source: str) -> bool:
        """Add a dependency to the graph.

        Args:
            which (str): The name of the field that "depends".
            source (str): The name of the field on which it depends.

        Raises:
            ValueError: If ``which`` and ``source`` are the same field,
                or if a dependency in the opposite direction was
                already set.
            KeyError: If ``which`` or ``source`` is not one of the
                fields the graph was constructed with.

        Returns:
            bool: True, if the dependency was added and False, if it
                was already present.
        """
        if which == source:
            raise ValueError(f"The field {which!r} cannot depend on itself.")
        km = self._key_map
        wi = km[which]
        si = km[source]
        g = self._graph
        # Element access may yield a NumPy scalar; convert explicitly so
        # callers receive a genuine ``bool`` as documented.
        change = bool(g[wi, si] != -1)
        if change:
            # Only the direct (two-field) cycle is rejected here; longer
            # cycles are detected by `has_closed_cycles` and
            # `get_init_order`.
            if g[si, wi] != 0:
                raise ValueError(
                    f"The dependence of {which!r} on {source!r} is "
                    "invalid since this creates a direct circular "
                    f"dependence ({source!r} already depends directly "
                    f"on {which!r})."
                )
            g[wi, si] = -1
            object.__setattr__(self, "_num_deps", 1 + self._num_deps)
        return change

    @property
    def dependency_count(self) -> int:
        """The number of dependencies ("edges" in the graph)."""
        return self._num_deps

    def has_closed_cycles(self) -> bool:
        """Whether the dependency graph contains closed cycles,
        indicating circular dependencies.

        Returns:
            bool: True, if the dependency graph contains closed loops,
                otherwise False.
        """
        try:
            # Every stored edge has weight -1, so any directed cycle is a
            # negative cycle, which Floyd-Warshall reports by raising.
            _csgraph.floyd_warshall(self._graph, directed=True)
        except _csgraph.NegativeCycleError:
            return True
        return False

    def get_init_order(self) -> _Sequence[str]:
        """Construct the order in which arguments must be initialized in
        order to be compatible with the dependencies defined by the
        graph. Note that this order is not necessarily unique; whenever
        several fields could come next, the one listed first in
        `fields` (the __init__ parameter/field) is taken first.

        Raises:
            ValueError: If the dependency graph contains closed cycles,
                indicating circular dependencies.

        Returns:
            Sequence[str]: The order in which the fields have to be
                initialized so that their dependencies can be satisfied.
        """
        import heapq as _heapq

        n = self.N
        # lil_matrix.rows[i] lists the column indices stored in row i,
        # i.e. the fields that field i depends on.
        rows = self._graph.rows

        # Number of still-unresolved fields each field depends on.
        pending = [len(rows[i]) for i in range(n)]
        # dependents[j] lists the fields that depend on field j.
        dependents: _List[_List[int]] = [[] for _ in range(n)]
        for i in range(n):
            for j in rows[i]:
                dependents[j].append(i)

        # Kahn's algorithm with a min-heap of fields whose dependencies
        # are all resolved: always taking the smallest index makes ties
        # follow the order of `fields` deterministically.
        ready = [i for i in range(n) if pending[i] == 0]
        _heapq.heapify(ready)
        resolved: _List[int] = []
        while ready:
            idx = _heapq.heappop(ready)
            resolved.append(idx)
            for i in dependents[idx]:
                pending[i] -= 1
                if pending[i] == 0:
                    _heapq.heappush(ready, i)

        # Fields on (or downstream of) a cycle never reach zero pending
        # dependencies, so an incomplete order means a closed loop.
        if len(resolved) != n:
            raise ValueError("The dependency graph contained closed loops.")

        fields = self.fields
        return tuple(fields[i] for i in resolved)