Coverage for src/scipy_dataclassfitparams/_fields.py: 100%

135 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-01 17:34 +0000

1"""This module defines the functions that can be used for the field 

2definitions and the corresponding classes that encapsulate the dataclass 

3metadata values. 

4""" 

5 

6import cmath as _cmath 

7import dataclasses as _dataclasses 

8import sys as _sys 

9import warnings as _warnings 

10 

11from dataclasses import dataclass as _dataclass, Field as _Field 

12from typing import Any as _Any, Callable as _Callable, ClassVar as _ClassVar, \ 

13 Mapping as _Mapping, Optional as _Optional, Union as _Union 

14 

15from . import _warning_types 

16 

17 

18_METADATA_KEY = object() 

19"""The object used as the key for the fields' metadata mapping.""" 

20 

21 

22@_dataclass(frozen=True) 

23class FittingField: 

24 """Base class for the field metadata values.""" 

25 pass 

26 

27 

28def get_special_fitting_field(f: _Field) -> _Union[FittingField, None]: 

29 """Get the `FittingField` instance associated with a dataclass 

30 `Field`. 

31 

32 Args: 

33 f (Field): The dataclass field. 

34 

35 Raises: 

36 AssertionError: If the `_METADATA_KEY` for a fitting field is 

37 not associated with a `FittingField` instance. 

38 

39 Returns: 

40 FittingField | None: The found `FittingField` or None, if none 

41 was found. 

42 """ 

43 md = f.metadata 

44 try: 

45 val = md[_METADATA_KEY] 

46 except KeyError: 

47 return None 

48 if not isinstance(val, FittingField): 

49 raise AssertionError( 

50 f"Found invalid object {val!r} where a FittingField was " 

51 f"expected. Dataclass field: {f.name!r}" 

52 ) from TypeError 

53 return val 

54 

55 

56_MISSING = _dataclasses.MISSING 

57"""The 'missing' values of the dataclasses module.""" 

58 

59 

60_MissingType = type(_MISSING) 

61"""The type of `dataclasses.MISSING` and `_MISSING`.""" 

62 

63 

64_FloatOrMissing = _Union[ 

65 float, _MissingType # type: ignore [valid-type] 

66] 

67"""The type representing a float or `dataclasses.MISSING`.""" 

68 

69 

70@_dataclass(frozen=True) 

71class BoundedField(FittingField): 

72 

73 __slots__ = ("min", "max") 

74 

75 min: float 

76 """The min value.""" 

77 

78 max: float 

79 """The max value.""" 

80 

81 POS_INF: _ClassVar[float] = float("inf") 

82 """Positive infinity.""" 

83 

84 NEG_INF: _ClassVar[float] = float("-inf") 

85 """Negative infinity.""" 

86 

87 def __init__(self, min: _Optional[float], max: _Optional[float]): 

88 if max is None: 

89 max = self.POS_INF 

90 if min is None: 

91 min = self.NEG_INF 

92 if min > max: 

93 raise ValueError( 

94 "The provided max value was smaller than the provided " 

95 "min value." 

96 ) 

97 object.__setattr__(self, "min", min) 

98 object.__setattr__(self, "max", max) 

99 

100 @property 

101 def actually_bounded(self) -> bool: 

102 """Whether this instance actually imposes boundaries on the fit 

103 parameter. If both `min` and `max` are set to None, this is 

104 not the case. 

105 """ 

106 return self.min_finite or self.max_finite 

107 

108 @property 

109 def min_finite(self) -> bool: 

110 """Whether the min value is finite. 

111 

112 Returns: 

113 bool: Whether the min value is finite. 

114 """ 

115 return _cmath.isfinite(self.min) 

116 

117 @property 

118 def max_finite(self) -> bool: 

119 """Whether the max value is finite. 

120 

121 Returns: 

122 bool: Whether the max value is finite. 

123 """ 

124 return _cmath.isfinite(self.max) 

125 

126 def resolve_default(self, field_default: _FloatOrMissing) -> float: 

127 """Resolves the default value to be assigned to a field 

128 described by the `BoundedField` instance. 

129 

130 Args: 

131 field_default (float or MISSING): The default value passed 

132 to the dataclass `Field` or the missing value, if none 

133 was passed. 

134 

135 Returns: 

136 float: The default value for the field. This will be 

137 `field_default` if it was provided and is within the 

138 bounds described by this instance or the result of 

139 `get_bounds_default()` otherwise. 

140 """ 

141 has_default = field_default is not _dataclasses.MISSING 

142 if (has_default and self.contains(field_default)): # type: ignore 

143 return field_default 

144 return self.get_bounds_default() 

145 

146 def get_bounds_default(self) -> float: 

147 """Compute a valid default value within the bounds. 

148 

149 Returns: 

150 float: The computed default value. If both bounds are 

151 finite, this is their average value. If both are 

152 non-finite, the value is 0. If one bound is finite, the 

153 value lies with a distance of 1 to the finite bound 

154 within the bounded region. 

155 """ 

156 min_fin = self.min_finite 

157 max_fin = self.max_finite 

158 if min_fin and max_fin: 

159 min = self.min 

160 max = self.max 

161 return 0.5 * (min + max) # type: ignore [operator] 

162 if min_fin: 

163 return self.min + 1 # type: ignore [operator,return-value] 

164 elif max_fin: 

165 return self.max - 1 # type: ignore [operator,return-value] 

166 else: 

167 return 0.0 # type: ignore [return-value] 

168 

169 def contains(self, value: float) -> bool: 

170 """Checks whether a given value lies in the bounded region 

171 described by this `BoundedField` instance. 

172 

173 Args: 

174 value (float): The value to check. 

175 

176 Returns: 

177 bool: Whether the value lies between the lower (self.min) 

178 and upper bound (self.max). 

179 """ 

180 res = True 

181 if self.min_finite: 

182 res &= value >= self.min 

183 if self.max_finite: 

184 res &= value <= self.max 

185 return res 

186 

187 

188@_dataclass(frozen=True) 

189class ConstField(FittingField): 

190 

191 __slots__ = ("value",) 

192 

193 value: float 

194 """The value of the field.""" 

195 

196 

197@_dataclass(frozen=True) 

198class SameAsField(FittingField): 

199 

200 __slots__ = ("name",) 

201 

202 name: str 

203 """The name of the field the field depends on.""" 

204 

205 

206@_dataclass(frozen=True) 

207class RegularField(FittingField): 

208 

209 _instance: _ClassVar[_Optional["RegularField"]] = None 

210 """An instance of this class.""" 

211 

212 @classmethod 

213 def get_instance(cls) -> "RegularField": # pragma: no cover 

214 """Get a cached instance of this class. 

215 

216 Returns: 

217 RegularField: A (possible cached) instance of the 

218 `RegularField` class. 

219 """ 

220 instance = cls._instance 

221 if instance is None: 

222 instance = cls() 

223 cls._instance = instance 

224 return instance 

225 

226 

227_BoolOrMissing = _Union[bool, _MissingType] # type: ignore [valid-type] 

228"""The type representing a `bool` or `dataclasses.MISSING`.""" 

229 

230 

231_FloatDefaultFactoryOrMissing = _Union[ 

232 _Callable[[], float], _MissingType # type: ignore [valid-type] 

233] 

234"""The type of the default factory returning a float for dataclass 

235fields or `dataclasses.MISSING`.""" 

236 

237 

238def _merge_medatada( 

239 field_metadata_instance: FittingField, 

240 provided: _Optional[_Mapping[_Any, _Any]] 

241) -> _Mapping[_Any, _Any]: 

242 """Merge the user-provided metadata with the metadata key for the 

243 fit description fields. 

244 

245 Args: 

246 field_metadata_instance (FittingField): The field to add. 

247 provided (Mapping[Any, Any] | None): The user-provided metadata 

248 or None. 

249 

250 Returns: 

251 Mapping[Any, Any]: The metadata to store in the field. 

252 """ 

253 if provided is not None: # pragma: no cover 

254 res = dict(provided) 

255 if _METADATA_KEY in provided: 

256 _warnings.warn( 

257 "The metadata key for the fit field definitions was " 

258 "already contained in the provided 'metadata' mapping. " 

259 "The corresponding value will be overwritten.", 

260 stacklevel=3, 

261 category=_warning_types.MetadataKeyOverwrittenWarning 

262 ) 

263 else: 

264 res = dict() 

265 res[_METADATA_KEY] = field_metadata_instance 

266 return res 

267 

268 

269if _sys.version_info >= (3, 10): 

270 

271 def bounded( 

272 min: _Optional[float] = None, 

273 max: _Optional[float] = None, 

274 *, 

275 default: _FloatOrMissing = _MISSING, 

276 default_factory: _FloatDefaultFactoryOrMissing = _MISSING, 

277 repr: bool = True, 

278 hash: _Optional[bool] = None, 

279 compare: bool = True, 

280 metadata: _Optional[_Mapping[_Any, _Any]] = None, 

281 kw_only: _BoolOrMissing = _MISSING 

282 ) -> _Any: 

283 """Define a fit parameter as bounded. Either one, two or no 

284 boundaries may be provided. The remaining unset boundaries are 

285 interpreted to be -inf for the lower and +inf for the upper 

286 boundary. 

287 

288 Args: 

289 min (float, optional): The lower bound. Defaults to -inf. 

290 max (float, optional): The upper bound. Defaults to +inf. 

291 default (float, optional): The default (initial) value for 

292 the field. Defaults to the default value used by the fit 

293 function for a bounded field. 

294 default_factory (Callable[[], float], optional): A function 

295 to be called to produce the fields default value. If 

296 unset, the `default` will be used instead. 

297 repr (bool, optional): Whether the field should be included 

298 in the object's repr(). Defaults to True. 

299 hash (bool, optional): Whether the field should be included 

300 in the object's hash(). Defaults to None. 

301 compare (bool, optional): Whether the field should be used 

302 in comparison function. Defaults to True. 

303 metadata (Mapping[Any, Any], optional): Metadata to add to 

304 the field. Defaults to None. 

305 kw_only (bool, optional): Whether the field will become a 

306 keyword-only parameter to __init__(). Defaults to False. 

307 

308 Returns: 

309 The dataclass field. 

310 """ 

311 m = BoundedField(min, max) 

312 metadata = _merge_medatada(m, metadata) 

313 return _dataclasses.field( # type: ignore [call-overload] 

314 default=default, 

315 default_factory=default_factory, 

316 repr=repr, 

317 hash=hash, 

318 compare=compare, 

319 metadata=metadata, 

320 kw_only=kw_only 

321 ) 

322 

323 def const( 

324 value: float, 

325 *, 

326 repr: bool = True, 

327 hash: _Optional[bool] = None, 

328 compare: bool = True, 

329 metadata: _Optional[_Mapping[_Any, _Any]] = None, 

330 kw_only: _BoolOrMissing = _MISSING 

331 ) -> float: 

332 """Define a fit parameter as constant. The parameter is fixed 

333 and will not be fitted against. The value will also be set as 

334 the default value for the field at the constructor. 

335 

336 Args: 

337 value (float): The value of the field. 

338 repr (bool, optional): Whether the field should be included 

339 in the object's repr(). Defaults to True. 

340 hash (bool, optional): Whether the field should be included 

341 in the object's hash(). Defaults to None. 

342 compare (bool, optional): Whether the field should be used 

343 in comparison function. Defaults to True. 

344 metadata (Mapping[Any, Any], optional): Metadata to add to 

345 the field. Defaults to None. 

346 kw_only (bool, optional): Whether the field will become a 

347 keyword-only parameter to __init__(). Defaults to False. 

348 

349 Returns: 

350 The dataclass field. 

351 """ 

352 m = ConstField(value) 

353 metadata = _merge_medatada(m, metadata) 

354 return _dataclasses.field( # type: ignore [call-overload] 

355 default=value, 

356 repr=repr, 

357 hash=hash, 

358 compare=compare, 

359 metadata=metadata, 

360 kw_only=kw_only 

361 ) 

362 

363 def same_as( 

364 name: str, 

365 *, 

366 default: _FloatOrMissing = _MISSING, 

367 default_factory: _FloatDefaultFactoryOrMissing = _MISSING, 

368 repr: bool = True, 

369 hash: _Optional[bool] = None, 

370 compare: bool = True, 

371 metadata: _Optional[_Mapping[_Any, _Any]] = None, 

372 kw_only: _BoolOrMissing = _MISSING 

373 ) -> _Any: 

374 """Define a fit parameter as identical to another parameter. It 

375 is still possible to set a `default` or `default_factory` for 

376 the field. This will be used when calling the constructor 

377 without providing a value for this field. 

378 

379 Args: 

380 name (str): The name of the parameter this one is identical 

381 to. 

382 default (float, optional): The default (initial) value for 

383 the field. Defaults to the default value used by the fit 

384 function for a bounded field. 

385 default_factory (Callable[[], float], optional): A function 

386 to be called to produce the fields default value. If 

387 unset, the `default` will be used instead. 

388 repr (bool, optional): Whether the field should be included 

389 in the object's repr(). Defaults to True. 

390 hash (bool, optional): Whether the field should be included 

391 in the object's hash(). Defaults to None. 

392 compare (bool, optional): Whether the field should be used 

393 in comparison function. Defaults to True. 

394 metadata (Mapping[Any, Any], optional): Metadata to add to 

395 the field. Defaults to None. 

396 kw_only (bool, optional): Whether the field will become a 

397 keyword-only parameter to __init__(). Defaults to False. 

398 

399 Returns: 

400 The dataclass field. 

401 """ 

402 m = SameAsField(name) 

403 metadata = _merge_medatada(m, metadata) 

404 return _dataclasses.field( # type: ignore [call-overload] 

405 default=default, 

406 default_factory=default_factory, 

407 repr=repr, 

408 hash=hash, 

409 compare=compare, 

410 metadata=metadata, 

411 kw_only=kw_only 

412 ) 

413 

414 def regular( 

415 *, 

416 default: _FloatOrMissing = _MISSING, 

417 default_factory: _FloatDefaultFactoryOrMissing = _MISSING, 

418 repr: bool = True, 

419 hash: _Optional[bool] = None, 

420 compare: bool = True, 

421 metadata: _Optional[_Mapping[_Any, _Any]] = None, 

422 kw_only: _BoolOrMissing = _MISSING 

423 ) -> _Any: 

424 """Define regular fit parameter. The value must be provided at 

425 __init__ when constructing instances of the dataclass. 

426 

427 Args: 

428 default (float, optional): The default (initial) value for 

429 the field. Defaults to the default value used by the fit 

430 function for a bounded field. 

431 default_factory (Callable[[], float], optional): A function 

432 to be called to produce the fields default value. If 

433 unset, the `default` will be used instead. 

434 repr (bool, optional): Whether the field should be included 

435 in the object's repr(). Defaults to True. 

436 hash (bool, optional): Whether the field should be included 

437 in the object's hash(). Defaults to None. 

438 compare (bool, optional): Whether the field should be used 

439 in comparison function. Defaults to True. 

440 metadata (Mapping[Any, Any], optional): Metadata to add to 

441 the field. Defaults to None. 

442 kw_only (bool, optional): Whether the field will become a 

443 keyword-only parameter to __init__(). Defaults to False. 

444 

445 Returns: 

446 The dataclass field. 

447 """ 

448 m = RegularField.get_instance() 

449 metadata = _merge_medatada(m, metadata) 

450 return _dataclasses.field( # type: ignore [call-overload] 

451 default=default, 

452 default_factory=default_factory, 

453 repr=repr, 

454 hash=hash, 

455 compare=compare, 

456 metadata=metadata, 

457 kw_only=kw_only 

458 ) 

459 

460else: 

461 

462 # The "type: ignore [misc]" markers below prevent mypy from 

463 # complaining about differing signatures of the conditional 

464 # function defintions. 

465 

466 def bounded( # type: ignore [misc] 

467 min: _Optional[float] = None, 

468 max: _Optional[float] = None, 

469 *, 

470 default: _FloatOrMissing = _MISSING, 

471 default_factory: _FloatDefaultFactoryOrMissing = _MISSING, 

472 repr: bool = True, 

473 hash: _Optional[bool] = None, 

474 compare: bool = True, 

475 metadata: _Optional[_Mapping[_Any, _Any]] = None 

476 ) -> _Any: 

477 """Define a fit parameter as bounded. Either one, two or no 

478 boundaries may be provided. The remaining unset boundaries are 

479 interpreted to be -inf for the lower and +inf for the upper 

480 boundary. 

481 

482 Args: 

483 min (float, optional): The lower bound. Defaults to -inf. 

484 max (float, optional): The upper bound. Defaults to +inf. 

485 default (float, optional): The default (initial) value for 

486 the field. Defaults to the default value used by the fit 

487 function for a bounded field. 

488 default_factory (Callable[[], float], optional): A function 

489 to be called to produce the fields default value. If 

490 unset, the `default` will be used instead. 

491 repr (bool, optional): Whether the field should be included 

492 in the object's repr(). Defaults to True. 

493 hash (bool, optional): Whether the field should be included 

494 in the object's hash(). Defaults to None. 

495 compare (bool, optional): Whether the field should be used 

496 in comparison function. Defaults to True. 

497 metadata (Mapping[Any, Any], optional): Metadata to add to 

498 the field. Defaults to None. 

499 

500 Returns: 

501 The dataclass field. 

502 """ 

503 m = BoundedField(min, max) 

504 metadata = _merge_medatada(m, metadata) 

505 return _dataclasses.field( # type: ignore [call-overload] 

506 default=default, 

507 default_factory=default_factory, 

508 repr=repr, 

509 hash=hash, 

510 compare=compare, 

511 metadata=metadata 

512 ) 

513 

514 def const( # type: ignore [misc] 

515 value: float, 

516 *, 

517 repr: bool = True, 

518 hash: _Optional[bool] = None, 

519 compare: bool = True, 

520 metadata: _Optional[_Mapping[_Any, _Any]] = None 

521 ) -> float: 

522 """Define a fit parameter as constant. The parameter is fixed 

523 and will not be fitted against. The value will also be set as 

524 the default value for the field at the constructor. 

525 

526 Args: 

527 value (float): The value of the field. 

528 repr (bool, optional): Whether the field should be included 

529 in the object's repr(). Defaults to True. 

530 hash (bool, optional): Whether the field should be included 

531 in the object's hash(). Defaults to None. 

532 compare (bool, optional): Whether the field should be used 

533 in comparison function. Defaults to True. 

534 metadata (Mapping[Any, Any], optional): Metadata to add to 

535 the field. Defaults to None. 

536 

537 Returns: 

538 The dataclass field. 

539 """ 

540 m = ConstField(value) 

541 metadata = _merge_medatada(m, metadata) 

542 return _dataclasses.field( # type: ignore [call-overload] 

543 default=value, 

544 repr=repr, 

545 hash=hash, 

546 compare=compare, 

547 metadata=metadata 

548 ) 

549 

550 def same_as( # type: ignore [misc] 

551 name: str, 

552 *, 

553 default: _FloatOrMissing = _MISSING, 

554 default_factory: _FloatDefaultFactoryOrMissing = _MISSING, 

555 repr: bool = True, 

556 hash: _Optional[bool] = None, 

557 compare: bool = True, 

558 metadata: _Optional[_Mapping[_Any, _Any]] = None 

559 ) -> _Any: 

560 """Define a fit parameter as identical to another parameter. It 

561 is still possible to set a `default` or `default_factory` for 

562 the field. This will be used when calling the constructor 

563 without providing a value for this field. 

564 

565 Args: 

566 name (str): The name of the parameter this one is identical 

567 to. 

568 default (float, optional): The default (initial) value for 

569 the field. Defaults to the default value used by the fit 

570 function for a bounded field. 

571 default_factory (Callable[[], float], optional): A function 

572 to be called to produce the fields default value. If 

573 unset, the `default` will be used instead. 

574 repr (bool, optional): Whether the field should be included 

575 in the object's repr(). Defaults to True. 

576 hash (bool, optional): Whether the field should be included 

577 in the object's hash(). Defaults to None. 

578 compare (bool, optional): Whether the field should be used 

579 in comparison function. Defaults to True. 

580 metadata (Mapping[Any, Any], optional): Metadata to add to 

581 the field. Defaults to None. 

582 kw_only (bool, optional): Whether the field will become a 

583 keyword-only parameter to __init__(). Defaults to False. 

584 

585 Returns: 

586 The dataclass field. 

587 """ 

588 m = SameAsField(name) 

589 metadata = _merge_medatada(m, metadata) 

590 return _dataclasses.field( # type: ignore [call-overload] 

591 default=default, 

592 default_factory=default_factory, 

593 repr=repr, 

594 hash=hash, 

595 compare=compare, 

596 metadata=metadata 

597 ) 

598 

599 def regular( # type: ignore [misc] 

600 *, 

601 default: _FloatOrMissing = _MISSING, 

602 default_factory: _FloatDefaultFactoryOrMissing = _MISSING, 

603 repr: bool = True, 

604 hash: _Optional[bool] = None, 

605 compare: bool = True, 

606 metadata: _Optional[_Mapping[_Any, _Any]] = None 

607 ) -> _Any: 

608 """Define regular fit parameter. The value must be provided at 

609 __init__ when constructing instances of the dataclass. 

610 

611 Args: 

612 default (float, optional): The default (initial) value for 

613 the field. Defaults to the default value used by the fit 

614 function for a bounded field. 

615 default_factory (Callable[[], float], optional): A function 

616 to be called to produce the fields default value. If 

617 unset, the `default` will be used instead. 

618 repr (bool, optional): Whether the field should be included 

619 in the object's repr(). Defaults to True. 

620 hash (bool, optional): Whether the field should be included 

621 in the object's hash(). Defaults to None. 

622 compare (bool, optional): Whether the field should be used 

623 in comparison function. Defaults to True. 

624 metadata (Mapping[Any, Any], optional): Metadata to add to 

625 the field. Defaults to None. 

626 

627 Returns: 

628 The dataclass field. 

629 """ 

630 m = RegularField.get_instance() 

631 metadata = _merge_medatada(m, metadata) 

632 return _dataclasses.field( # type: ignore [call-overload] 

633 default=default, 

634 default_factory=default_factory, 

635 repr=repr, 

636 hash=hash, 

637 compare=compare, 

638 metadata=metadata 

639 )