Coverage for src/scipy_dataclassfitparams/_dump_result.py: 100%

174 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-01 17:34 +0000

1import io as _io 

2import math as _math 

3import os as _os 

4 

5from dataclasses import dataclass as _dataclass 

6from enum import Enum as _Enum, unique as _unique 

7from io import TextIOBase as _TextIOBase 

8from typing import Generic as _Generic, List as _List, Literal as _Literal, \ 

9 Optional as _Optional, overload as _overload, \ 

10 SupportsFloat as _SupportsFloat, Type as _Type, TypeVar as _TypeVar, \ 

11 Union as _Union 

12 

13from . import _spec_registry as _spec_registry 

14 

15from ._make_fit import FitResult as _FitResult 

16from ._fit_spec import FitSpec as _FitSpec 

17from ._fields import BoundedField as _BoundedField, \ 

18 ConstField as _ConstField, FittingField as _FittingField, \ 

19 RegularField as _RegularField, SameAsField as _SameAsField 

20 

21 

# Type variable for the user-defined class whose instances are formatted;
# it parametrizes `InstanceFormatter` and the `dump_result` overloads.
_T = _TypeVar("_T")

23 

24 

@_unique
class WhichFloat(_Enum):
    """Tag telling a float formatter what role a number plays.

    A member is passed along with the number so that a formatter can
    distinguish a field's value from its bounds or its initial value.
    """

    FieldValue = 1
    """The number is the value of a field."""

    LowerBound = 2
    """The number is the lower bound of a field."""

    UpperBound = 3
    """The number is the upper bound of a field."""

    InitialValue = 4
    """The number is the initial value of a field."""

40 

41 

@_unique
class PrintBounds(_Enum):
    """Policy deciding for which fitting fields bounds are printed."""

    always = "always"
    """Print bound information for every field."""

    never = "never"
    """Do not print any bound information."""

    bounded = "bounded"
    """Print bound information for `BoundedField`s only."""

    @classmethod
    def sanitize(cls, value: "_PrintBoundsArg") -> "PrintBounds":
        """Convert a `_PrintBoundsArg` into a `PrintBounds` member.

        Members are passed through unchanged; `str` values are looked
        up by member name.

        Args:
            value (_PrintBoundsArg): A `PrintBounds` member or the name
                of one.

        Raises:
            TypeError: If `value` is neither a `PrintBounds` member nor
                a `str`.
            ValueError: If `value` is a `str` that does not name a
                `PrintBounds` member. The failed lookup's `KeyError` is
                attached as the cause.

        Returns:
            PrintBounds: The corresponding `PrintBounds` member.
        """
        # Already a member: nothing to convert.
        if isinstance(value, PrintBounds):
            return value

        # Anything that is not a str cannot name a member at all.
        if not isinstance(value, str):
            raise TypeError(
                f"Invalid 'PrintBounds' value {value!r}."
            ) from None

        lookup_error = None
        try:
            member = cls[value]
        except KeyError as key_error:
            # Keep a reference; the `as` name is unbound once the
            # handler is left.
            lookup_error = key_error
        else:
            return member

        # Raised outside the handler so that only the explicit cause is
        # attached to the new exception.
        raise ValueError(
            f"Invalid 'PrintBounds' value {value!r}."
        ) from lookup_error

89 

90 

# Each literal below matches the name (and value) of one `PrintBounds`
# member, so these strings are what `PrintBounds.sanitize` can resolve.
_PrintBoundsLiteral = _Literal["always", "never", "bounded"]  # noqa: F821
"""The values for the `PrintBounds` enumeration as `str` values."""

# Public functions annotate their `print_bounds` argument with this alias so
# that callers may pass either the enum member or its plain string form.
_PrintBoundsArg = _Union[PrintBounds, _PrintBoundsLiteral]
"""Either a value of the `PrintBounds` enumeration or its `str`
value."""

97 

98 

@_dataclass(frozen=True)
class InstanceFormatter(_Generic[_T]):
    """Default formatter that renders a fitted instance as text.

    The output consists of an optional header line, one line per fit
    parameter and, optionally, one line for each remaining field of the
    `FitSpec`. Instances are frozen: every attribute is assigned exactly
    once in `__init__`.
    """

    __slots__ = (
        "instance",
        "print_header",
        "print_extra_params",
        "linesep",
        "print_bounds",
        "print_initial_values",
        "clss",
        "fitspec",
        "initial_instance",
        "p0_provided"
    )

    instance: _T
    """The instance to format."""

    print_header: bool
    """Whether to print the header indicating the class name. Defaults
    to True."""

    print_extra_params: bool
    """Whether to print parameters that did not partake in the fitting
    procedure, such as dependent fields. Defaults to True."""

    linesep: str
    """The line separator. Defaults to `os.linesep`."""

    print_bounds: PrintBounds
    """The mode for printing information about the bounds of a fitting
    field. Defaults to 'always'."""

    print_initial_values: bool
    """Whether to print initial values for regular and bound fields.
    Defaults to True."""

    clss: _Type[_T]
    """The type of the instance to format."""

    fitspec: _FitSpec[_T]
    """The fitspec for the instance."""

    initial_instance: _T
    """The default instance (p0) which starts the fit corresponding to
    the instance. This may either be the provided value or the default
    value constructed from the `FitSpec`, depending on `p0_provided`."""

    p0_provided: bool
    """Whether a p0 instance was provided."""

    @_overload
    def __init__(
        self,
        instance: _T,
        *,
        print_header: _Optional[bool] = None,
        print_extra_params: _Optional[bool] = None,
        print_bounds: _Optional[_PrintBoundsArg] = None,
        print_initial_values: _Optional[bool] = None,
        p0: _Optional[_T] = None,
        linesep: _Optional[str] = None
    ):
        pass

    @_overload
    def __init__(
        self,
        instance: _FitResult[_T],
        *,
        print_header: _Optional[bool] = None,
        print_extra_params: _Optional[bool] = None,
        print_bounds: _Optional[_PrintBoundsArg] = None,
        print_initial_values: _Optional[bool] = None,
        linesep: _Optional[str] = None
    ):
        pass

    # Init has to be provided manually because of incompatibilities
    # between __slots__ and default values.
    def __init__(
        self,
        instance: _Union[_T, _FitResult[_T]],
        *,
        print_header: _Optional[bool] = None,
        print_extra_params: _Optional[bool] = None,
        print_bounds: _Optional[_PrintBoundsArg] = None,
        print_initial_values: _Optional[bool] = None,
        p0: _Optional[_T] = None,
        linesep: _Optional[str] = None
    ):
        """Create a new instance of the default `InstanceFormatter`.

        Args:
            instance (T | FitResult[T]): Either the instance or the
                corresponding FitResult which should be formatted.
            print_header (bool, optional): Whether to print the header
                mentioning the class name. Defaults to True.
            print_extra_params (bool, optional): Whether to print the
                extra parameters available on the class which were not
                fit against, such as constant values. Defaults to True.
            print_bounds (PrintBounds | Literal["always", "never",
                "bounded"], optional): A `PrintBounds` value or the
                corresponding `str` value indicating how to print the
                bounds for fields. Defaults to "always".
            print_initial_values (bool, optional): Whether to print
                the initial values. Defaults to True. If no `p0` is
                provided, the initial instance will be generated
                from the `FitSpec`.
            p0 (T, optional): The initial instance. If not provided, it
                will be generated from the `FitSpec`. Must not be
                provided if `instance` is a `FitResult` instance, in
                which case p0 will be inferred from the instance.
            linesep (str, optional): The newline separator to use.
                Defaults to `os.linesep`.

        Raises:
            ValueError: If `print_bounds` is a `str` that does not name
                a `PrintBounds` member. If `instance` is a `FitResult`
                instance and `p0` is provided.
            TypeError: If `print_bounds` is neither a `PrintBounds`
                member nor a `str`.
        """
        if isinstance(instance, _FitResult):
            if p0 is not None:
                raise ValueError(
                    "p0 must not be provided when 'instance' is a "
                    "FitResult."
                )
            # Unpack the fit result: its p0 becomes the initial instance
            # and its optimal instance is what actually gets formatted.
            p0 = instance.p0
            instance = instance.opt_instance

        # The dataclass is frozen, so regular attribute assignment is
        # blocked; object.__setattr__ bypasses that for initialization.
        object.__setattr__(self, "instance", instance)
        clss = instance.__class__
        object.__setattr__(self, "clss", clss)
        fitspec = _spec_registry.get_fit_spec(clss)
        object.__setattr__(self, "fitspec", fitspec)

        if p0 is None:
            p0 = fitspec.create_default_fit_instance()
            object.__setattr__(self, "p0_provided", False)
        else:
            object.__setattr__(self, "p0_provided", True)
        object.__setattr__(self, "initial_instance", p0)

        # `None` means "not specified" for all remaining options and is
        # replaced by the documented default.
        if print_header is None:
            print_header = True
        object.__setattr__(self, "print_header", print_header)
        if print_extra_params is None:
            print_extra_params = True
        object.__setattr__(self, "print_extra_params", print_extra_params)
        if linesep is None:
            linesep = _os.linesep
        object.__setattr__(self, "linesep", linesep)
        if print_initial_values is None:
            print_initial_values = True
        object.__setattr__(self, "print_initial_values", print_initial_values)

        if print_bounds is None:
            print_bounds = PrintBounds.always
        print_bounds = PrintBounds.sanitize(print_bounds)
        object.__setattr__(self, "print_bounds", print_bounds)

    def _write_line(self, line: str, stream: _TextIOBase) -> int:
        """Write a line to a text stream. It will be followed by the
        `self.linesep` value.

        Args:
            line (str): The line to write.
            stream (TextIOBase): The stream to write to.

        Returns:
            int: The sum of the values returned by the two `write`
                calls, i.e. the number of characters written for the
                line and for the line separator.
        """
        return stream.write(line) + stream.write(self.linesep)

    def format_instance(self) -> str:
        """Format the instance and return the result as a `str`.

        Lines are terminated by exactly `self.linesep`, whatever that
        string is.

        Returns:
            str: The formatted instance.
        """
        # newline="" disables newline translation in the buffer. Passing
        # self.linesep here instead would raise ValueError for any
        # separator StringIO does not know, and for "\r\n" it would
        # rewrite the "\n" of every written "\r\n" into "\r\n" again,
        # yielding "\r\r\n".
        buffer = _io.StringIO(newline="")
        self.process_instance(buffer)
        return buffer.getvalue()

    def process_instance(self, output: _TextIOBase):
        """Process the instance and write the output information to the
        provided stream. Note that if an exception occurs during this
        method, the provided stream may contain incomplete/corrupted
        output.

        Args:
            output (TextIOBase): The stream to write to.
        """

        # Header
        if self.print_header:
            header = self._format_header()
            self._write_line(header, output)

        instance = self.instance

        # Fitting fields
        fs = self.fitspec
        # Work on a copy: every handled field is popped, so whatever is
        # left afterwards are the fields that were not fitted.
        descriptors = dict(fs.fitting_fields)
        for fpfn in fs.fitting_params:
            field = descriptors.pop(fpfn)
            value = getattr(instance, fpfn)
            ff = self._format_field(fpfn, True, field, value)
            self._write_line(ff, output)

        # Extra fields
        if self.print_extra_params and descriptors:
            self._write_line("Additional parameters (not fitted):", output)

            # Iterate in init order for a stable output order; names
            # that are not among the remaining descriptors are skipped.
            for p in fs.init_order:
                try:
                    field = descriptors.pop(p)
                except KeyError:
                    continue
                value = getattr(instance, p)
                ff = self._format_field(p, False, field, value)
                self._write_line(ff, output)

    def _format_header(self) -> str:
        """Format the header that contains the class name.

        Returns:
            str: The header to write to the output.
        """
        clss = self.clss
        return f"Fit performed with type {clss.__qualname__!r}:"

    def _format_float(
        self,
        value: _SupportsFloat,
        field: str,
        which: WhichFloat
    ) -> str:
        """Format a float to be written to the output.

        This default implementation ignores `field` and `which` and
        always uses scientific notation with 15 decimal places.

        Args:
            value (SupportsFloat): The value of the float. This may also
                be inf or -inf.
            field (str): The field this value belongs to.
            which (WhichFloat): Information about in which way the
                provided value belongs to the field.

        Returns:
            str: The formatted float.
        """
        return f"{float(value):.15e}"

    def _format_bounds_extra(
        self,
        name: str,
        field: _Union[_FittingField, None]
    ) -> str:
        """Format the extra information about the bounds.

        Args:
            name (str): The name of the field.
            field (FittingField or None): The fit field description.

        Returns:
            str: The extra information about the field bounds. This is
                "unbounded" unless `field` is a `BoundedField` whose
                `actually_bounded` is true.
        """
        if (
            (not isinstance(field, _BoundedField))
            or (not field.actually_bounded)
        ):
            return "unbounded"

        def format_bound(
            v: float,
            b_inf: str,
            b_fin: str,
            which: WhichFloat,
            fstr: str
        ) -> str:
            # An infinite bound gets the bracket that faces away from
            # the value (open end), a finite one the bracket facing it.
            b = b_inf if _math.isinf(v) else b_fin
            v_s = self._format_float(v, name, which)
            return fstr.format(value=v_s, bracket=b)

        max_s = format_bound(
            field.max, '[', ']', WhichFloat.UpperBound, "{value}{bracket}"
        )
        min_s = format_bound(
            field.min, ']', '[', WhichFloat.LowerBound, "{bracket}{value}"
        )
        return f"bounded: {min_s};{max_s}"

    def _format_initial_value_extra(
        self,
        name: str,
        field: _Union[_FittingField, None]
    ) -> str:
        """Format the extra information about the field's initial value.

        The initial value is read from `self.initial_instance`; `field`
        is not used by this default implementation.

        Args:
            name (str): The name of the field.
            field (FittingField or None): The fit field description.

        Returns:
            str: The extra information about the initial value of the
                field.
        """
        default = getattr(self.initial_instance, name)
        dfvs = self._format_float(default, name, WhichFloat.InitialValue)
        return f"initial: {dfvs}"

    def _format_field(
        self,
        name: str,
        is_fitparam: bool,
        field: _Union[_FittingField, None],
        value: float
    ) -> str:
        """Format a field.

        Args:
            name (str): The name of the field.
            is_fitparam (bool): Whether the field is a fit parameter.
                If False, the field is an "extra".
            field (FittingField or None): The fit field description.
            value (float): The value of the field in the optimized
                instance.

        Raises:
            NotImplementedError: If `field` is not a known
                `FittingField`.
            TypeError: If `field` is not a `FittingField`.

        Returns:
            str: The formatted information regarding the `field`.
        """
        value_str = self._format_float(value, name, WhichFloat.FieldValue)

        extras: _List[str] = []

        if isinstance(field, _ConstField):
            extras.append("const.")
        elif isinstance(field, _SameAsField):
            extras.append(f"=!= {field.name!r}")
        elif isinstance(field, _BoundedField):
            # Bounded fields show their bounds in both the "always" and
            # the "bounded" mode.
            if self.print_bounds != PrintBounds.never:
                extras.append(self._format_bounds_extra(name, field))
            if self.print_initial_values:
                extras.append(self._format_initial_value_extra(name, field))
        elif (field is None) or isinstance(field, _RegularField):
            # Fields without bounds only mention them in "always" mode.
            if self.print_bounds == PrintBounds.always:
                extras.append(self._format_bounds_extra(name, field))
            if self.print_initial_values:
                extras.append(self._format_initial_value_extra(name, field))
        else:  # pragma: no cover
            if isinstance(field, _FittingField):
                exception_type = NotImplementedError
            else:
                exception_type = TypeError
            raise exception_type(
                f"Invalid field description encountered: {field!r}"
            )

        extra_str = ''
        if extras:
            fextras = ", ".join(extras)
            extra_str = f" ({fextras})"

        return f"{name}: {value_str}{extra_str}"

470 

471 

@_overload
def dump_result(
    instance: _FitResult[_T],
    *,
    print_header: _Optional[bool] = None,
    print_extra_params: _Optional[bool] = None,
    print_bounds: _Optional[_PrintBoundsArg] = None,
    print_initial_values: _Optional[bool] = None,
    linesep: _Optional[str] = None
) -> str:
    ...


@_overload
def dump_result(
    instance: _T,
    *,
    print_header: _Optional[bool] = None,
    print_extra_params: _Optional[bool] = None,
    print_bounds: _Optional[_PrintBoundsArg] = None,
    print_initial_values: _Optional[bool] = None,
    p0: _Optional[_T] = None,
    linesep: _Optional[str] = None
) -> str:
    ...


def dump_result(
    instance: _Union[_T, _FitResult[_T], _FitResult],
    *,
    print_header: _Optional[bool] = None,
    print_extra_params: _Optional[bool] = None,
    print_bounds: _Optional[_PrintBoundsArg] = None,
    print_initial_values: _Optional[bool] = None,
    p0: _Optional[_T] = None,
    linesep: _Optional[str] = None
) -> str:
    """Describe the optimal instance of a fit as a multiline `str`.

    This is a convenience wrapper which builds the default
    `InstanceFormatter` from the given arguments and returns what its
    `format_instance` method produces. Passing the full `FitResult` is
    recommended, because `p0` is then taken from it automatically.

    Args:
        instance (T, FitResult[T]): The optimal instance to format,
            given either directly or wrapped in its `FitResult`.
        print_header (bool, optional): Whether to print the header
            mentioning the class name. Defaults to True.
        print_extra_params (bool, optional): Whether to print the
            parameters of the class which were not fit against, such as
            constant values. Defaults to True.
        print_bounds (PrintBounds | Literal["always", "never",
            "bounded"], optional): A `PrintBounds` value or the
            corresponding `str` value selecting how the bounds of
            fields are printed. Defaults to "always".
        print_initial_values (bool, optional): Whether to print the
            initial values. Defaults to True. Without a `p0`, the
            initial instance is generated from the `FitSpec`.
        p0 (T, optional): The initial instance. If omitted, it is
            generated from the `FitSpec`. Must not be given when
            `instance` is a `FitResult`, as `p0` is then inferred from
            it.
        linesep (str, optional): The newline separator to use.
            Defaults to `os.linesep`.

    Raises:
        ValueError: If `instance` is a `FitResult` instance and `p0` is
            provided.

    Returns:
        str: The text describing the parameters of the optimal
            instance.
    """
    # All options are forwarded untouched; `None` values are resolved
    # to their defaults by the formatter itself.
    return InstanceFormatter(
        instance,
        print_header=print_header,
        print_extra_params=print_extra_params,
        print_bounds=print_bounds,
        print_initial_values=print_initial_values,
        p0=p0,
        linesep=linesep
    ).format_instance()