Coverage for src/scipy_dataclassfitparams/_dump_result.py: 100%
174 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-01 17:34 +0000
1import io as _io
2import math as _math
3import os as _os
5from dataclasses import dataclass as _dataclass
6from enum import Enum as _Enum, unique as _unique
7from io import TextIOBase as _TextIOBase
8from typing import Generic as _Generic, List as _List, Literal as _Literal, \
9 Optional as _Optional, overload as _overload, \
10 SupportsFloat as _SupportsFloat, Type as _Type, TypeVar as _TypeVar, \
11 Union as _Union
13from . import _spec_registry as _spec_registry
15from ._make_fit import FitResult as _FitResult
16from ._fit_spec import FitSpec as _FitSpec
17from ._fields import BoundedField as _BoundedField, \
18 ConstField as _ConstField, FittingField as _FittingField, \
19 RegularField as _RegularField, SameAsField as _SameAsField
# Type variable standing for the user's fit dataclass type.
_T = _TypeVar("_T")


@_unique
class WhichFloat(_Enum):
    """Tells a float formatter which role the float it receives plays
    for a field.
    """

    FieldValue = 1
    """The float is the field's (optimized) value."""

    LowerBound = 2
    """The float is the field's lower bound."""

    UpperBound = 3
    """The float is the field's upper bound."""

    InitialValue = 4
    """The float is the field's initial value (taken from p0)."""
@_unique
class PrintBounds(_Enum):
    """Controls how much information about the bounds of a fitting
    field gets printed.
    """

    always = "always"
    """Print the bounds of every fitting field."""

    never = "never"
    """Never print any bounds."""

    bounded = "bounded"
    """Print bounds only for `BoundedField`s."""

    @classmethod
    def sanitize(cls, value: "_PrintBoundsArg") -> "PrintBounds":
        """Turn a `_PrintBoundsArg` into the matching `PrintBounds`
        member.

        Args:
            value (_PrintBoundsArg): Either a `PrintBounds` member,
                which is returned unchanged, or one of its `str`
                spellings.

        Raises:
            TypeError: If `value` is neither a `PrintBounds` member
                nor a `str`.
            ValueError: If `value` is a `str` that does not name a
                `PrintBounds` member. The lookup's `KeyError` is
                attached as the cause.

        Returns:
            PrintBounds: The corresponding member.
        """
        if isinstance(value, PrintBounds):
            return value
        message = f"Invalid 'PrintBounds' value {value!r}."
        if not isinstance(value, str):
            raise TypeError(message) from None
        try:
            member = cls[value]
        except KeyError as lookup_error:
            raise ValueError(message) from lookup_error
        return member


_PrintBoundsLiteral = _Literal["always", "never", "bounded"]  # noqa: F821
"""The `PrintBounds` members spelled as plain `str` literals."""

_PrintBoundsArg = _Union[PrintBounds, _PrintBoundsLiteral]
"""Anything `PrintBounds.sanitize` accepts: a `PrintBounds` member or
its `str` value."""
@_dataclass(frozen=True)
class InstanceFormatter(_Generic[_T]):
    """Default formatter that renders a fitted instance as multiline
    text.

    The output is produced in this order:
    - an optional header naming the class;
    - one line per fitting parameter;
    - optionally, the remaining fields that did not take part in the
      fit.

    Each line shows the field's value and, depending on the options,
    its bounds and its initial value. All configuration is fixed in
    `__init__`; the instance is frozen afterwards.
    """

    __slots__ = (
        "instance",
        "print_header",
        "print_extra_params",
        "linesep",
        "print_bounds",
        "print_initial_values",
        "clss",
        "fitspec",
        "initial_instance",
        "p0_provided",
    )

    instance: _T
    """The instance to format."""

    print_header: bool
    """Whether to print the header indicating the class name. Defaults
    to True."""

    print_extra_params: bool
    """Whether to print parameters that did not partake in the fitting
    procedure, such as dependent fields. Defaults to True."""

    linesep: str
    """The line separator. Defaults to `os.linesep`."""

    print_bounds: PrintBounds
    """The mode for printing information about the bounds of a fitting
    field. Defaults to 'always'."""

    print_initial_values: bool
    """Whether to print initial values for regular and bounded fields.
    Defaults to True."""

    clss: _Type[_T]
    """The type of the instance to format."""

    fitspec: _FitSpec[_T]
    """The fit spec registered for `clss`."""

    initial_instance: _T
    """The default instance (p0) which starts the fit corresponding to
    the instance. This may either be the provided value or the default
    value constructed from the `FitSpec`, depending on `p0_provided`."""

    p0_provided: bool
    """Whether a p0 instance was provided."""

    @_overload
    def __init__(
        self,
        instance: _T,
        *,
        print_header: _Optional[bool] = None,
        print_extra_params: _Optional[bool] = None,
        print_bounds: _Optional[_PrintBoundsArg] = None,
        print_initial_values: _Optional[bool] = None,
        p0: _Optional[_T] = None,
        linesep: _Optional[str] = None
    ):
        pass

    @_overload
    def __init__(
        self,
        instance: _FitResult[_T],
        *,
        print_header: _Optional[bool] = None,
        print_extra_params: _Optional[bool] = None,
        print_bounds: _Optional[_PrintBoundsArg] = None,
        print_initial_values: _Optional[bool] = None,
        linesep: _Optional[str] = None
    ):
        pass

    # `__init__` has to be written by hand because dataclass field
    # defaults are incompatible with `__slots__`.
    def __init__(
        self,
        instance: _Union[_T, _FitResult[_T]],
        *,
        print_header: _Optional[bool] = None,
        print_extra_params: _Optional[bool] = None,
        print_bounds: _Optional[_PrintBoundsArg] = None,
        print_initial_values: _Optional[bool] = None,
        p0: _Optional[_T] = None,
        linesep: _Optional[str] = None
    ):
        """Create a new instance of the default `InstanceFormatter`.

        Args:
            instance (T | FitResult[T]): Either the instance or the
                corresponding FitResult which should be formatted.
            print_header (bool, optional): Whether to print the header
                mentioning the class name. Defaults to True.
            print_extra_params (bool, optional): Whether to print the
                extra parameters available on the class which were not
                fit against, such as constant values. Defaults to True.
            print_bounds (PrintBounds | Literal["always", "never",
                "bounded"], optional): A `PrintBounds` value or the
                corresponding `str` value indicating how to print the
                bounds for fields. Defaults to "always".
            print_initial_values (bool, optional): Whether to print
                the initial values. Defaults to True. If no `p0` is
                provided, the initial instance will be generated
                from the `FitSpec`.
            p0 (T, optional): The initial instance. If not provided, it
                will be generated from the `FitSpec`. Must not be
                provided if `instance` is a `FitResult` instance, in
                which case p0 will be inferred from the instance.
            linesep (str, optional): The newline separator to use.
                Defaults to `os.linesep`.

        Raises:
            TypeError: If `print_bounds` is neither a `PrintBounds`
                value nor a `str`.
            ValueError: If `print_bounds` is a `str` that does not name
                a `PrintBounds` value, or if `instance` is a
                `FitResult` instance and `p0` is provided.
        """
        if isinstance(instance, _FitResult):
            if p0 is not None:
                raise ValueError(
                    "p0 must not be provided when 'instance' is a "
                    "FitResult."
                )
            p0 = instance.p0
            instance = instance.opt_instance

        # The dataclass is frozen, so attributes are assigned through
        # `object.__setattr__` to bypass the generated `__setattr__`.
        object.__setattr__(self, "instance", instance)
        clss = instance.__class__
        object.__setattr__(self, "clss", clss)
        fitspec = _spec_registry.get_fit_spec(clss)
        object.__setattr__(self, "fitspec", fitspec)

        if p0 is None:
            p0 = fitspec.create_default_fit_instance()
            object.__setattr__(self, "p0_provided", False)
        else:
            object.__setattr__(self, "p0_provided", True)
        object.__setattr__(self, "initial_instance", p0)

        if print_header is None:
            print_header = True
        object.__setattr__(self, "print_header", print_header)
        if print_extra_params is None:
            print_extra_params = True
        object.__setattr__(self, "print_extra_params", print_extra_params)
        if linesep is None:
            linesep = _os.linesep
        object.__setattr__(self, "linesep", linesep)
        if print_initial_values is None:
            print_initial_values = True
        object.__setattr__(self, "print_initial_values", print_initial_values)

        if print_bounds is None:
            print_bounds = PrintBounds.always
        print_bounds = PrintBounds.sanitize(print_bounds)
        object.__setattr__(self, "print_bounds", print_bounds)

    def _write_line(self, line: str, stream: _TextIOBase) -> int:
        """Write a line to a text stream, followed by `self.linesep`.

        Args:
            line (str): The line to write.
            stream (TextIOBase): The stream to write to.

        Returns:
            int: The total number of written characters, including the
                line separator.
        """
        return stream.write(line) + stream.write(self.linesep)

    def format_instance(self) -> str:
        """Format the instance and return the result as a `str`.

        Returns:
            str: The formatted instance.
        """
        # `_write_line` already emits `self.linesep` verbatim, so the
        # buffer must not translate newlines. Passing `newline=linesep`
        # here would turn every written "\n" of a "\r\n" separator into
        # "\r\n" (yielding "\r\r\n"). It would also reject any
        # separator other than "\n", "\r" or "\r\n" with a ValueError.
        buffer = _io.StringIO(newline="")
        self.process_instance(buffer)
        return buffer.getvalue()

    def process_instance(self, output: _TextIOBase):
        """Process the instance and write the output information to the
        provided stream. Note that if an exception occurs during this
        method, the provided stream may contain incomplete/corrupted
        output.

        Args:
            output (TextIOBase): The stream to write to.
        """
        # Header
        if self.print_header:
            header = self._format_header()
            self._write_line(header, output)

        instance = self.instance

        # Fitting fields, in fit-parameter order. Each one is removed
        # from `descriptors`, so only the non-fitted fields remain.
        fs = self.fitspec
        descriptors = dict(fs.fitting_fields)
        for fpfn in fs.fitting_params:
            field = descriptors.pop(fpfn)
            value = getattr(instance, fpfn)
            ff = self._format_field(fpfn, True, field, value)
            self._write_line(ff, output)

        # Extra fields, in the class's init order.
        if self.print_extra_params and (len(descriptors) > 0):
            self._write_line("Additional parameters (not fitted):", output)

            for p in fs.init_order:
                try:
                    field = descriptors.pop(p)
                except KeyError:
                    # Already written as a fit parameter, or not
                    # described by the fit spec at all.
                    continue
                value = getattr(instance, p)
                ff = self._format_field(p, False, field, value)
                self._write_line(ff, output)

    def _format_header(self) -> str:
        """Format the header that contains the class name.

        Returns:
            str: The header to write to the output.
        """
        clss = self.clss
        return f"Fit performed with type {clss.__qualname__!r}:"

    def _format_float(
        self,
        value: _SupportsFloat,
        field: str,
        which: WhichFloat
    ) -> str:
        """Format a float to be written to the output.

        The default implementation ignores `field` and `which`. They are
        passed along presumably so that overriding implementations can
        format per field or per role.

        Args:
            value (SupportsFloat): The value of the float. This may also
                be inf or -inf.
            field (str): The field this value belongs to.
            which (WhichFloat): Information about in which way the
                provided value belongs to the field.

        Returns:
            str: The float in scientific notation with 15 decimals.
        """
        return f"{float(value):.15e}"

    def _format_bounds_extra(
        self,
        name: str,
        field: _Union[_FittingField, None]
    ) -> str:
        """Format the extra information about the bounds.

        Args:
            name (str): The name of the field.
            field (FittingField or None): The fit field description.

        Returns:
            str: "unbounded" unless `field` is an actually bounded
                `BoundedField`. Otherwise "bounded: <min>;<max>" in
                interval notation: an infinite bound gets an outward
                facing (open) bracket, a finite one an inward facing
                (closed) bracket.
        """
        if (
            (not isinstance(field, _BoundedField))
            or (not field.actually_bounded)
        ):
            return "unbounded"

        def format_bound(
            v: float,
            b_inf: str,
            b_fin: str,
            which: WhichFloat,
            fstr: str
        ) -> str:
            b = b_inf if _math.isinf(v) else b_fin
            v_s = self._format_float(v, name, which)
            return fstr.format(value=v_s, bracket=b)

        max_s = format_bound(
            field.max, '[', ']', WhichFloat.UpperBound, "{value}{bracket}"
        )
        min_s = format_bound(
            field.min, ']', '[', WhichFloat.LowerBound, "{bracket}{value}"
        )
        return f"bounded: {min_s};{max_s}"

    def _format_initial_value_extra(
        self,
        name: str,
        field: _Union[_FittingField, None]
    ) -> str:
        """Format the extra information about the field's initial value.

        Args:
            name (str): The name of the field.
            field (FittingField or None): The fit field description.

        Returns:
            str: The extra information about the initial value of the
                field, read from `self.initial_instance`.
        """
        default = getattr(self.initial_instance, name)
        dfvs = self._format_float(default, name, WhichFloat.InitialValue)
        return f"initial: {dfvs}"

    def _format_field(
        self,
        name: str,
        is_fitparam: bool,
        field: _Union[_FittingField, None],
        value: float
    ) -> str:
        """Format a field.

        Args:
            name (str): The name of the field.
            is_fitparam (bool): Whether the field is a fit parameter.
                If False, the field is an "extra".
            field (FittingField or None): The fit field description.
            value (float): The value of the field in the optimized
                instance.

        Raises:
            NotImplementedError: If `field` is not a known
                `FittingField`.
            TypeError: If `field` is not a `FittingField`.

        Returns:
            str: The formatted information regarding the `field`.
        """
        value_str = self._format_float(value, name, WhichFloat.FieldValue)

        extras: _List[str] = list()

        if isinstance(field, _ConstField):
            extras.append("const.")
        elif isinstance(field, _SameAsField):
            extras.append(f"=!= {field.name!r}")
        elif isinstance(field, _BoundedField):
            # Bounded fields show their bounds unless bounds are
            # suppressed entirely.
            if self.print_bounds != PrintBounds.never:
                extras.append(self._format_bounds_extra(name, field))
            if self.print_initial_values:
                extras.append(self._format_initial_value_extra(name, field))
        elif (field is None) or isinstance(field, _RegularField):
            # Regular fields only show ("unbounded") bounds in "always"
            # mode.
            if self.print_bounds == PrintBounds.always:
                extras.append(self._format_bounds_extra(name, field))
            if self.print_initial_values:
                extras.append(self._format_initial_value_extra(name, field))
        else:  # pragma: no cover
            if isinstance(field, _FittingField):
                exception_type = NotImplementedError
            else:
                exception_type = TypeError
            raise exception_type(
                f"Invalid field description encountered: {field!r}"
            )

        extra_str = ''
        if len(extras) > 0:
            fextras = ", ".join(extras)
            extra_str = f" ({fextras})"

        return f"{name}: {value_str}{extra_str}"
@_overload
def dump_result(
    instance: _FitResult[_T],
    *,
    print_header: _Optional[bool] = None,
    print_extra_params: _Optional[bool] = None,
    print_bounds: _Optional[_PrintBoundsArg] = None,
    print_initial_values: _Optional[bool] = None,
    linesep: _Optional[str] = None
) -> str:
    pass


@_overload
def dump_result(
    instance: _T,
    *,
    print_header: _Optional[bool] = None,
    print_extra_params: _Optional[bool] = None,
    print_bounds: _Optional[_PrintBoundsArg] = None,
    print_initial_values: _Optional[bool] = None,
    p0: _Optional[_T] = None,
    linesep: _Optional[str] = None
) -> str:
    pass


def dump_result(
    instance: _Union[_T, _FitResult[_T], _FitResult],
    *,
    print_header: _Optional[bool] = None,
    print_extra_params: _Optional[bool] = None,
    print_bounds: _Optional[_PrintBoundsArg] = None,
    print_initial_values: _Optional[bool] = None,
    p0: _Optional[_T] = None,
    linesep: _Optional[str] = None
) -> str:
    """Format the optimal instance of a fit into a multiline `str` with
    details using the default `InstanceFormatter`. It is recommended to
    provide the full `FitResult` instance so that `p0` is included
    automatically.

    Args:
        instance (T | FitResult[T]): The optimal instance to format,
            either provided directly or via the full `FitResult`.
        print_header (bool, optional): Whether to print the header
            mentioning the class name. Defaults to True.
        print_extra_params (bool, optional): Whether to print the
            extra parameters available on the class which were not
            fit against, such as constant values. Defaults to True.
        print_bounds (PrintBounds | Literal["always", "never",
            "bounded"], optional): A `PrintBounds` value or the
            corresponding `str` value indicating how to print the
            bounds for fields. Defaults to "always".
        print_initial_values (bool, optional): Whether to print
            the initial values. Defaults to True. If no `p0` is
            provided, the initial instance will be generated
            from the `FitSpec`.
        p0 (T, optional): The initial instance. If not provided, it
            will be generated from the `FitSpec`. Must not be
            provided if `instance` is a `FitResult` instance, in
            which case p0 will be inferred from the instance.
        linesep (str, optional): The newline separator to use.
            Defaults to `os.linesep`.

    Raises:
        TypeError: If `print_bounds` is neither a `PrintBounds` value
            nor a `str`.
        ValueError: If `print_bounds` is a `str` that does not name a
            `PrintBounds` value, or if `instance` is a `FitResult`
            instance and `p0` is provided.

    Returns:
        str: The str describing the parameters of the optimal instance.
    """
    # All validation and defaulting is delegated to the formatter.
    formatter = InstanceFormatter(
        instance,
        print_header=print_header,
        print_extra_params=print_extra_params,
        print_bounds=print_bounds,
        print_initial_values=print_initial_values,
        p0=p0,
        linesep=linesep
    )
    return formatter.format_instance()