Coverage for src/pytest_samples/database/_session.py: 100%

190 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-20 19:47 +0000

1import dataclasses as _dataclasses 

2import functools as _functools 

3import itertools as _itertools 

4import logging as _logging 

5import sqlalchemy as _sqlalchemy 

6import sqlalchemy.orm as _orm 

7 

8from arrow import Arrow as _Arrow 

9from contextlib import AbstractContextManager as _AbstractContextManager 

10from dataclasses import dataclass as _dataclass 

11from sqlalchemy import Engine as _Engine 

12from types import TracebackType as _TracebackType 

13from typing import Callable as _Callable, Iterable as _Iterable, \ 

14 Optional as _Optional, Set as _Set, Type as _Type 

15 

16from . import _exceptions 

17from ._defs import TestFile as _TestFile, TestItem as _TestItem 

18from ..types import Location as _Location 

19from .. import tools as _tools 

20 

21 

# Logger used throughout this module; it is named after the module so
# its records fit into the standard dotted logging hierarchy.
_logger: _logging.Logger = _logging.getLogger(__name__)

24 

25 

@_dataclass(frozen=True)
class BulkUpdateResult:
    """Summary of a bulk add/update/remove run on the test database."""

    __slots__ = ("added", "updated", "removed", "pruned_files")

    added: int
    """How many test items were newly inserted."""

    updated: int
    """How many already stored test items were updated."""

    removed: int
    """How many test items were deleted."""

    pruned_files: _Optional[int]
    """How many orphaned files were pruned, or None when pruning was not
    requested.
    """

    def __str__(self) -> str:
        """Render the counters as a comma-separated phrase for logging.

        Every field contributes ``"<value> <field name>"`` where the
        underscores of the field name are replaced by spaces. Fields
        whose value is None are omitted.
        """
        pieces = []
        for field in _dataclasses.fields(self):
            value = getattr(self, field.name)
            if value is None:
                continue
            label = field.name.replace('_', ' ')
            pieces.append(f"{value} {label}")
        return ", ".join(pieces)

53 

54 

@_dataclass(frozen=True)
class DropAllEntriesResult:
    """Counts of the entries removed when the whole database is wiped."""

    __slots__ = ("files_dropped", "tests_dropped")

    files_dropped: int
    """How many `TestFile` entries were deleted."""

    tests_dropped: int
    """How many `TestItem` entries were deleted."""

    def __str__(self) -> str:
        """Render both counters as a comma-separated phrase for logging,
        e.g. ``"2 files dropped, 5 tests dropped"``.
        """
        return ", ".join(
            "{} {}".format(getattr(self, fld.name), fld.name.replace('_', ' '))
            for fld in _dataclasses.fields(self)
        )

77 

78 

# Type of the `hash_provider` callback accepted by
# `Session.bulk_add_update_remove`: it is called with a freshly created
# `TestFile` (before that file is added to the database session) and the
# returned bytes are stored as the file's `last_hash`.
TestFileHashProvider = _Callable[[_TestFile], bytes]
"""A function providing the hash for a `TestFile`. The `last_hash` field
of the passed instance is undefined and must not be relied upon.
"""

83 

84 

class Session(_AbstractContextManager):
    """The session objects that basically act as the database
    connection.

    The underlying sqlalchemy `orm.Session` only exists between
    `__enter__` and `__exit__`, so instances are meant to be used as
    context managers::

        with Session(engine) as session:
            session.try_get_file("tests/test_example.py")
    """

    __slots__ = ("_engine", "_session")

    def __init__(self, engine: _Engine) -> None:
        """Initialize a new Session using the provided engine. The
        connection will start once `__enter__` has been called.

        Args:
            engine (Engine): The engine for the connection.
        """
        self._engine = engine
        self._session: _Optional[_orm.Session] = None

    def _ensure_session(self) -> _orm.Session:  # pragma: no cover
        """Ensure that the session object has an active connection and
        return the underlying database sqlalchemy `orm.Session`.

        Raises:
            InactiveSessionError: If the session was not started by
                calling `__enter__` or has already been exited.

        Returns:
            orm.Session: The sqlalchemy `orm.Session` object this
                `Session` is based on.
        """
        s = self._session
        if s is not None:
            return s
        raise _exceptions.InactiveSessionError(
            "The session object was not started with __enter__."
        )

    def __enter__(self) -> "Session":
        """Start the connection in a context manager.

        Returns:
            Session: This instance.
        """
        self._session = _orm.Session(self._engine)
        return self

    def __exit__(
        self,
        exc_type: _Optional[_Type[BaseException]],
        exc_value: _Optional[BaseException],
        traceback: _Optional[_TracebackType]
    ) -> None:
        """Close the connection by exiting the context.

        Afterwards the instance is inactive again: methods that need the
        database raise `InactiveSessionError` until `__enter__` is
        called anew.

        Args:
            exc_type (Optional[Type[BaseException]]): The exception type
                or None if no exception occurred.
            exc_value (Optional[BaseException]): The exception value
                or None if no exception occurred.
            traceback (Optional[TracebackType]): The traceback or None
                if no exception occurred.
        """
        dbsession = self._session
        if dbsession is None:  # pragma: no cover
            # I am not sure if this should raise an exception. This
            # means that __exit__ is called but __enter__ was never
            # called.
            _logger.warning("__exit__ was called, but _session was None.")
            return
        # Forget the orm session before closing it. Otherwise
        # `_ensure_session` would keep handing out the closed session
        # (sqlalchemy allows reusing a session after close) instead of
        # raising InactiveSessionError.
        self._session = None
        dbsession.__exit__(exc_type, exc_value, traceback)

    @classmethod
    def _try_get_file(
        cls, dbsession: _orm.Session, path: str
    ) -> _Optional[_TestFile]:
        """Search the database for a file containing tests.

        Args:
            dbsession (orm.Session): The session object.
            path (str): The known path to the file.

        Raises:
            MultipleResultsFound: If multiple entries are found for
                `path`. This should not happen since the path column
                must be unique.

        Returns:
            Optional[TestFile]: A `TestFile` if the file was found and
                None otherwise.
        """
        stmt = _sqlalchemy.select(_TestFile).where(_TestFile.path == path)
        result = dbsession.execute(stmt)
        return result.scalar_one_or_none()

    def try_get_file(self, path: str) -> _Optional[_TestFile]:
        """Search the database for a file containing tests.

        Args:
            path (str): The known path to the file.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            MultipleResultsFound: If multiple entries are found for
                `path`. This should not happen since the path column
                must be unique.

        Returns:
            Optional[TestFile]: A `TestFile` if the file was found and
                None otherwise.
        """
        session = self._ensure_session()
        return self._try_get_file(session, path)

    def add_file(self, path: str, hash: _Optional[bytes]) -> _TestFile:
        """Add a new file and commit it.

        Args:
            path (str): The path of the file to add.
            hash (Optional[bytes]): The hash of the file if it should be
                added or None otherwise.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            IntegrityError: If the `path` is already stored.

        Returns:
            TestFile: The newly added `TestFile` instance with its id
                set according to the new entry.
        """
        instance = _TestFile(path=path, last_hash=hash)
        dbsession = self._ensure_session()
        dbsession.add(instance)
        dbsession.commit()
        return instance

    @classmethod
    def _try_get_item(
        cls,
        dbsession: _orm.Session,
        file: _TestFile,
        lineno: _Optional[int],
        testname: str
    ) -> _Optional[_TestItem]:
        """Search the database for a test item.

        Args:
            dbsession (orm.Session): The session object.
            file (TestFile): The file the test item should be found in.
            lineno (Optional[int]): The line number of the test.
            testname (str): The name of the test.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to `dbsession`.
            MultipleResultsFound: If multiple entries are found that
                match the item. This should not happen since the three
                identifying columns must be unique together.

        Returns:
            Optional[TestItem]: A `TestItem` if the test item was found
                and None otherwise.
        """
        if file not in dbsession:
            raise _exceptions.DetachedInstanceError(
                "The file was not attached to a session."
            )
        stmt = _sqlalchemy.select(_TestItem).where(
            _TestItem.file == file,
            _TestItem.lineno == lineno,
            _TestItem.testname == testname
        )
        result = dbsession.execute(stmt)
        return result.scalar_one_or_none()

    def try_get_item(
        self, file: _TestFile, lineno: _Optional[int], testname: str
    ) -> _Optional[_TestItem]:
        """Search the database for a test item.

        Args:
            file (TestFile): The file the test item should be found in.
            lineno (Optional[int]): The line number of the test.
            testname (str): The name of the test.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to a `Session`.
            InactiveSessionError: If the session was not started by
                calling __enter__.
            MultipleResultsFound: If multiple entries are found that
                match the item. This should not happen since the three
                identifying columns must be unique together.

        Returns:
            Optional[TestItem]: A `TestItem` if the test item was found
                and None otherwise.
        """
        dbsession = self._ensure_session()
        return self._try_get_item(dbsession, file, lineno, testname)

    def invalidate_hash(self, file: _TestFile, new_hash: bytes) -> int:
        """Replace the hash of a test file and delete all test items
        that belong to this file.

        Args:
            file (TestFile): The test file to update in-place.
            new_hash (bytes): The new hash to add to the file.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to a `Session`.
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            int: The number of deleted `TestItem`s from the
                corresponding table.
        """
        dbsession = self._ensure_session()
        # Since we are removing items from the database, start a
        # transaction
        with dbsession.begin_nested():
            # May raise DetachedInstanceError:
            predicate = _TestItem.file_id == file.id
            # Interestingly, the assignment below and the call to
            # flush() would not raise the exception. Therefore, the
            # predicate is formed above.
            file.last_hash = new_hash
            dbsession.flush()
            # Delete all test items from that file since they may have
            # changed.
            del_stmt = _sqlalchemy.delete(_TestItem).where(predicate)
            num_del_items = dbsession.execute(del_stmt).rowcount
        dbsession.commit()
        return num_del_items

    @classmethod
    def _prune_files(cls, dbsession: _orm.Session) -> int:
        """Prune files that have no tests from the database. Does not
        perform a commit.

        Args:
            dbsession (orm.Session): The orm session.

        Returns:
            int: The number of pruned files after the following commit.
        """
        stmt = _sqlalchemy.delete(_TestFile).where(~_TestFile.items.any())
        fdeleted = dbsession.execute(stmt).rowcount
        return fdeleted

    def prune_files(self) -> int:
        """Prune files that have no tests from the database and commit.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            int: The number of pruned files.
        """
        dbsession = self._ensure_session()
        with dbsession.begin_nested():
            result = self._prune_files(dbsession)
        dbsession.commit()
        return result

    @classmethod
    def _prune_items(
        cls, dbsession: _orm.Session, known_locations: _Set[_Location]
    ) -> int:
        """Prune items from the database. Does not perform a commit.

        Args:
            dbsession (orm.Session): The session object.
            known_locations (Set[Location]): All known locations where
                items are expected. Items not at these locations will
                be deleted.

        Returns:
            int: The number of deleted items.
        """
        select = _sqlalchemy.select(_TestItem)
        result = dbsession.execute(select)
        scalars = result.scalars()
        partitions = scalars.partitions()
        num_del = 0
        for item in _itertools.chain.from_iterable(partitions):
            loc = item.location
            if loc in known_locations:
                continue
            # Only marks the item for deletion; the DELETE is emitted
            # on the next flush/commit.
            dbsession.delete(item)
            num_del += 1
        return num_del

    def prune_items(self, known_locations: _Set[_Location]) -> int:
        """Prune items from the database and commit.

        Args:
            known_locations (Set[Location]): All known locations where
                items are expected. Items not at these locations will
                be deleted.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            int: The number of deleted `TestItem`s.
        """
        dbsession = self._ensure_session()
        with dbsession.begin_nested():
            result = self._prune_items(dbsession, known_locations)
        dbsession.commit()
        return result

    def add_or_update_item(
        self,
        file: _TestFile,
        lineno: _Optional[int],
        testname: str,
        last_run: _Arrow
    ) -> _TestItem:
        """Add a new item or update its last run time if it already
        exists, then commit.

        Args:
            file (TestFile): The file the test item should be found in.
            lineno (Optional[int]): The line number of the test.
            testname (str): The name of the test.
            last_run (Arrow): The time of the last test run.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to a `Session`.
            InactiveSessionError: If the session was not started by
                calling __enter__.
            IntegrityError: If the added item does not fulfill
                uniqueness constraints, which should never happen since
                the item would already be present in the database and
                would be modified instead and would imply that the
                database is corrupted in some form.
            MultipleResultsFound: If multiple stored items match
                `file`, `lineno` and `testname`. This implies that the
                database is corrupted.

        Returns:
            TestItem: The new or updated item.
        """
        dbsession = self._ensure_session()

        # May raise DetachedInstanceError (if `file` is not attached to
        # this session) or MultipleResultsFound:
        db_item = self._try_get_item(dbsession, file, lineno, testname)
        if db_item is None:
            db_item = _TestItem(
                file=file,
                lineno=lineno,
                testname=testname,
                last_run=last_run
            )
            dbsession.add(db_item)
        else:
            db_item.last_run = last_run

        dbsession.commit()

        return db_item

    @classmethod
    def _try_delete_item(
        cls, dbsession: _orm.Session, location: _Location
    ) -> bool:
        """Try to delete an item from the database. Does not perform a
        commit.

        Args:
            dbsession (orm.Session): The session object.
            location (Location): The test item location.

        Raises:
            MultipleResultsFound: If multiple files are found with the
                path specified in `location`. If multiple items to
                delete are found. These should not happen due to
                uniqueness constraints in the database and would imply
                that it is corrupted in some way.

        Returns:
            bool: Whether the item was present in the database and has
                been deleted.
        """
        file, lineno, testname = location

        # May raise MultipleResultsFound:
        db_file = cls._try_get_file(dbsession, file)

        # If the file is not in the database, then the item will not be
        # either.
        if db_file is None:
            return False

        stmt = _sqlalchemy.delete(_TestItem).where(
            _TestItem.file == db_file,
            _TestItem.lineno == lineno,
            _TestItem.testname == testname
        )
        rc = dbsession.execute(stmt).rowcount
        if rc < 0:
            # NOTE(review): DBAPI drivers report -1 when the affected
            # row count cannot be determined.
            raise AssertionError(
                f"Unexpected row count {rc} while deleting a test item."
            )
        if rc > 1:
            raise _exceptions.MultipleResultsFound(
                "Found multiple matching items."
            ) from AssertionError
        return rc == 1

    def try_delete_item(self, location: _Location) -> bool:
        """Try to delete an item from the database and commit if an
        item was deleted.

        Args:
            location (Location): The test item location.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            MultipleResultsFound: If multiple files are found with the
                path specified in `location`. If multiple items to
                delete are found. These should not happen due to
                uniqueness constraints in the database and would imply
                that it is corrupted in some way.

        Returns:
            bool: Whether the item was present in the database and has
                been deleted.
        """
        dbsession = self._ensure_session()
        with dbsession.begin_nested():
            # May raise MultipleResultsFound:
            result = self._try_delete_item(dbsession, location)
        if result:
            dbsession.commit()
        return result

    def bulk_add_update_remove(
        self,
        last_run: _Arrow,
        add_update: _Iterable[_Location],
        try_delete: _Iterable[_Location],
        hash_provider: _Optional[TestFileHashProvider],
        prune_files: bool
    ) -> BulkUpdateResult:
        """Perform a bulk update by adding, updating and deleting test
        items and commit the result.

        Args:
            last_run (Arrow): The Arrow instance to set as the last run
                time for the test items.
            add_update (Iterable[Location]): The tests to add and update
                with the new `last_run` time.
            try_delete (Iterable[Location]): The items to delete if they
                are present.
            hash_provider (Optional[TestFileHashProvider]): If required,
                a `TestFileHashProvider` for computing the hashes of
                test files that are not stored yet. Otherwise None, in
                which case new files are stored without a hash.
            prune_files (bool): Whether to prune orphaned files from
                the database.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            IntegrityError: If the added items do not fulfill uniqueness
                constraints. This should not happen since the second
                appearance of a location would simply lead to an update.
            MultipleResultsFound: If multiple files are found with a
                path specified in one of the location items.
                If multiple items to delete are found. If multiple items
                are found that could be updated. These should not
                happen due to uniqueness constraints in the database and
                would imply that it is corrupted in some way.

        Returns:
            BulkUpdateResult: An object containing the amount of added,
                updated and deleted items and, if `prune_files` is set,
                the amount of pruned files.
        """

        dbsession = self._ensure_session()

        def add_or_update(location: _Location) -> bool:
            """Add a new item or update its last run time if it already
            exists. Does not perform a commit.

            Args:
                location (Location): The test item location.

            Returns:
                bool: Whether the item was already in the database.
            """
            file, lineno, testname = location

            # Get or add the file
            db_file = self._try_get_file(dbsession, file)
            if db_file is None:
                db_file = _TestFile(path=file, last_hash=None)
                if hash_provider is not None:
                    file_hash = hash_provider(db_file)
                    db_file.last_hash = file_hash
                dbsession.add(db_file)
                # The item cannot exist yet if the database is
                # consistent
                db_item = None
            else:
                db_item = self._try_get_item(
                    dbsession, db_file, lineno, testname
                )

            if db_item is None:
                db_item = _TestItem(
                    file=db_file,
                    lineno=lineno,
                    testname=testname,
                    last_run=last_run
                )
                dbsession.add(db_item)
                return False

            db_item.last_run = last_run
            return True

        deleter = _functools.partial(self._try_delete_item, dbsession)

        with dbsession.begin_nested():

            results = map(add_or_update, add_update)
            updated, added = _tools.count_truefalse(results)
            dbsession.flush()

            removed = sum(map(deleter, try_delete))

            if prune_files:
                # Write all pending changes before looking for files
                # without items.
                dbsession.flush()
                pruned = self._prune_files(dbsession)
            else:
                pruned = None

        dbsession.commit()

        return BulkUpdateResult(added, updated, removed, pruned)

    def drop_all_entries(self) -> DropAllEntriesResult:
        """Clean the database completely by dropping all `TestItem` and
        `TestFile` entries, then commit.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            DropAllEntriesResult: A result object describing the number
                of dropped entries from each table.
        """

        dbsession = self._ensure_session()

        with dbsession.begin_nested():
            # Items are removed first since every item refers to a file.
            deld_items = dbsession.execute(
                _sqlalchemy.delete(_TestItem)
            ).rowcount
            deld_files = dbsession.execute(
                _sqlalchemy.delete(_TestFile)
            ).rowcount

        dbsession.commit()

        return DropAllEntriesResult(deld_files, deld_items)