Coverage for src/pytest_samples/database/_session.py: 100%
190 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-20 19:47 +0000
1import dataclasses as _dataclasses
2import functools as _functools
3import itertools as _itertools
4import logging as _logging
5import sqlalchemy as _sqlalchemy
6import sqlalchemy.orm as _orm
8from arrow import Arrow as _Arrow
9from contextlib import AbstractContextManager as _AbstractContextManager
10from dataclasses import dataclass as _dataclass
11from sqlalchemy import Engine as _Engine
12from types import TracebackType as _TracebackType
13from typing import Callable as _Callable, Iterable as _Iterable, \
14 Optional as _Optional, Set as _Set, Type as _Type
16from . import _exceptions
17from ._defs import TestFile as _TestFile, TestItem as _TestItem
18from ..types import Location as _Location
19from .. import tools as _tools
_logger = _logging.getLogger(name=__name__)
"""Logger used by this module, named after the module itself."""
@_dataclass(frozen=True)
class BulkUpdateResult:
    """Summary of a bulk update: the counts of added, updated and
    removed test items, plus the number of pruned files when pruning
    was requested.
    """

    __slots__ = ("added", "updated", "removed", "pruned_files")

    added: int
    """How many test items were newly added."""

    updated: int
    """How many already known test items were updated."""

    removed: int
    """How many test items were removed."""

    pruned_files: _Optional[int]
    """How many files were pruned, or None if no pruning was done."""

    def __str__(self) -> str:
        """Render every non-None count as ``"<value> <field name>"``
        (underscores replaced by spaces), joined by ``", "``, for use
        in log messages.
        """
        pieces = []
        for field in _dataclasses.fields(self):
            amount = getattr(self, field.name)
            if amount is None:
                # Optional counts that were not computed are omitted.
                continue
            label = field.name.replace('_', ' ')
            pieces.append(f"{amount} {label}")
        return ", ".join(pieces)
@_dataclass(frozen=True)
class DropAllEntriesResult:
    """Counts of the file and test item entries removed when the
    database is cleared completely.
    """

    __slots__ = ("files_dropped", "tests_dropped")

    files_dropped: int
    """How many file entries were dropped."""

    tests_dropped: int
    """How many test item entries were dropped."""

    def __str__(self) -> str:
        """Render both counts as ``"<value> <field name>"`` pairs
        (underscores replaced by spaces) joined by ``", "``, for use in
        log messages.
        """
        return ", ".join(
            f"{getattr(self, fld.name)} {fld.name.replace('_', ' ')}"
            for fld in _dataclasses.fields(self)
        )
TestFileHashProvider = _Callable[[_TestFile], bytes]
"""Type of a callable that receives a `TestFile` and returns its hash
as `bytes`. The `last_hash` field of the instance it receives is
undefined at call time and must not be relied upon.
"""
class Session(_AbstractContextManager):
    """The session objects that basically act as the database
    connection.

    The underlying sqlalchemy `orm.Session` only exists between
    `__enter__` and `__exit__`; outside of that window the public
    methods raise `InactiveSessionError`.
    """

    __slots__ = ("_engine", "_session")

    def __init__(self, engine: _Engine) -> None:
        """Initialize a new Session using the provided engine. The
        connection will start once `__enter__` has been called.

        Args:
            engine (Engine): The engine for the connection.
        """
        self._engine = engine
        self._session: _Optional[_orm.Session] = None

    def _ensure_session(self) -> _orm.Session:  # pragma: no cover
        """Ensure that the session object has an active connection and
        return the underlying database sqlalchemy `orm.Session`.

        Raises:
            InactiveSessionError: If the session was not started by
                calling `__enter__` or has already been closed by
                `__exit__`.

        Returns:
            orm.Session: The sqlalchemy `orm.Session` object this
                `Session` is based on.
        """
        s = self._session
        if s is not None:
            return s
        raise _exceptions.InactiveSessionError(
            "The session object was not started with __enter__."
        )

    def __enter__(self):
        """Start the connection in a context manager.

        Returns:
            self
        """
        self._session = _orm.Session(self._engine)
        return self

    def __exit__(
        self,
        exc_type: _Optional[_Type[BaseException]],
        exc_value: _Optional[BaseException],
        traceback: _Optional[_TracebackType]
    ) -> None:
        """Close the connection by exiting the context.

        Afterwards the session is inactive again: further calls raise
        `InactiveSessionError` until `__enter__` is called anew.

        Args:
            exc_type (Optional[Type[BaseException]]): The exception type
                or None if no exception occurred.
            exc_value (Optional[BaseException]): The exception value
                or None if no exception occurred.
            traceback (Optional[TracebackType]): The traceback or None
                if no exception occurred.
        """
        dbsession = self._session
        if dbsession is None:  # pragma: no cover
            # I am not sure if this should raise an exception. This
            # means that __exit__ is called without a preceding
            # __enter__ (or twice in a row).
            _logger.warning("__exit__ was called, but _session was None.")
            return
        try:
            dbsession.__exit__(exc_type, exc_value, traceback)
        finally:
            # Forget the closed orm session so that later calls report
            # an inactive session instead of continuing to use it.
            self._session = None

    @classmethod
    def _try_get_file(
        cls, dbsession: _orm.Session, path: str
    ) -> _Optional[_TestFile]:
        """Search the database for a file containing tests.

        Args:
            dbsession (orm.Session): The session object.
            path (str): The known path to the file.

        Raises:
            MultipleResultsFound: If multiple entries are found for
                `path`. This should not happen since the path column
                must be unique.

        Returns:
            Optional[TestFile]: A `TestFile` if the file was found and
                None otherwise.
        """
        stmt = _sqlalchemy.select(_TestFile).where(_TestFile.path == path)
        result = dbsession.execute(stmt)
        return result.scalar_one_or_none()

    def try_get_file(self, path: str) -> _Optional[_TestFile]:
        """Search the database for a file containing tests.

        Args:
            path (str): The known path to the file.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            MultipleResultsFound: If multiple entries are found for
                `path`. This should not happen since the path column
                must be unique.

        Returns:
            Optional[TestFile]: A `TestFile` if the file was found and
                None otherwise.
        """
        session = self._ensure_session()
        return self._try_get_file(session, path)

    def add_file(self, path: str, hash: _Optional[bytes]) -> _TestFile:
        """Add a new file and commit it.

        Args:
            path (str): The path of the file to add.
            hash (Optional[bytes]): The hash of the file if it should be
                added or None otherwise.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            IntegrityError: If the `path` is already stored.

        Returns:
            TestFile: The newly added `TestFile` instance with its id
                set according to the new entry.
        """
        instance = _TestFile(path=path, last_hash=hash)
        dbsession = self._ensure_session()
        dbsession.add(instance)
        dbsession.commit()
        return instance

    @classmethod
    def _try_get_item(
        cls,
        dbsession: _orm.Session,
        file: _TestFile,
        lineno: _Optional[int],
        testname: str
    ) -> _Optional[_TestItem]:
        """Search the database for a test item.

        Args:
            dbsession (orm.Session): The session object.
            file (TestFile): The file the test item should be found in.
            lineno (Optional[int]): The line number of the test.
            testname (str): The name of the test.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to `dbsession`.
            MultipleResultsFound: If multiple entries are found that
                match the item. This should not happen since the three
                identifying columns must be unique together.

        Returns:
            Optional[TestItem]: A `TestItem` if the test item was found
                and None otherwise.
        """
        if file not in dbsession:
            raise _exceptions.DetachedInstanceError(
                "The file was not attached to a session."
            )
        stmt = _sqlalchemy.select(_TestItem).where(
            _TestItem.file == file,
            _TestItem.lineno == lineno,
            _TestItem.testname == testname
        )
        result = dbsession.execute(stmt)
        return result.scalar_one_or_none()

    def try_get_item(
        self, file: _TestFile, lineno: _Optional[int], testname: str
    ) -> _Optional[_TestItem]:
        """Search the database for a test item.

        Args:
            file (TestFile): The file the test item should be found in.
            lineno (Optional[int]): The line number of the test.
            testname (str): The name of the test.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to this session.
            InactiveSessionError: If the session was not started by
                calling __enter__.
            MultipleResultsFound: If multiple entries are found that
                match the item. This should not happen since the three
                identifying columns must be unique together.

        Returns:
            Optional[TestItem]: A `TestItem` if the test item was found
                and None otherwise.
        """
        dbsession = self._ensure_session()
        return self._try_get_item(dbsession, file, lineno, testname)

    def invalidate_hash(self, file: _TestFile, new_hash: bytes) -> int:
        """Replace the hash of a test file and delete all test items
        that belong to this file, then commit.

        Args:
            file (TestFile): The test file to update in-place.
            new_hash (bytes): The new hash to add to the file.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to a `Session`.
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            int: The number of deleted `TestItem`s from the
                corresponding table.
        """
        dbsession = self._ensure_session()
        # Since we are removing items from the database, start a
        # transaction
        with dbsession.begin_nested():
            # May raise DetachedInstanceError:
            predicate = _TestItem.file_id == file.id
            # Interestingly, the assignment below and the call to
            # flush() would not raise the exception. Therefore, the
            # predicate is formed above.
            file.last_hash = new_hash
            dbsession.flush()
            # Delete all test items from that file since they may have
            # changed.
            del_stmt = _sqlalchemy.delete(_TestItem).where(predicate)
            num_del_items = dbsession.execute(del_stmt).rowcount
        dbsession.commit()
        return num_del_items

    @classmethod
    def _prune_files(cls, dbsession: _orm.Session) -> int:
        """Prune files that have no tests from the database. Does not
        perform a commit.

        Args:
            dbsession (orm.Session): The orm session.

        Returns:
            int: The number of pruned files after the following commit.
        """
        stmt = _sqlalchemy.delete(_TestFile).where(~_TestFile.items.any())
        fdeleted = dbsession.execute(stmt).rowcount
        return fdeleted

    def prune_files(self) -> int:
        """Prune files that have no tests from the database and commit.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            int: The number of pruned files.
        """
        dbsession = self._ensure_session()
        with dbsession.begin_nested():
            result = self._prune_files(dbsession)
        dbsession.commit()
        return result

    @classmethod
    def _prune_items(
        cls, dbsession: _orm.Session, known_locations: _Set[_Location]
    ) -> int:
        """Prune items from the database. Does not perform a commit.

        Args:
            dbsession (orm.Session): The session object.
            known_locations (Set[Location]): All known locations where
                items are expected. Items not at these locations will
                be deleted.

        Returns:
            int: The number of deleted items.
        """
        select = _sqlalchemy.select(_TestItem)
        result = dbsession.execute(select)
        scalars = result.scalars()
        # Fetch the items in chunks instead of materializing all rows.
        partitions = scalars.partitions()
        num_del = 0
        for item in _itertools.chain.from_iterable(partitions):
            loc = item.location
            if loc in known_locations:
                continue
            dbsession.delete(item)
            num_del += 1
        return num_del

    def prune_items(self, known_locations: _Set[_Location]) -> int:
        """Prune items from the database and commit.

        Args:
            known_locations (Set[Location]): All known locations where
                items are expected. Items not at these locations will
                be deleted.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            int: The number of deleted `TestItem`s.
        """
        dbsession = self._ensure_session()
        with dbsession.begin_nested():
            result = self._prune_items(dbsession, known_locations)
        dbsession.commit()
        return result

    def add_or_update_item(
        self,
        file: _TestFile,
        lineno: _Optional[int],
        testname: str,
        last_run: _Arrow
    ) -> _TestItem:
        """Add a new item or update its last run time if it already
        exists, then commit.

        Args:
            file (TestFile): The file the test item should be found in.
            lineno (Optional[int]): The line number of the test.
            testname (str): The name of the test.
            last_run (Arrow): The time of the last test run.

        Raises:
            DetachedInstanceError: If the provided `file` is not
                attached to this session.
            InactiveSessionError: If the session was not started by
                calling __enter__.
            IntegrityError: If the added item does not fulfill
                uniqueness constraints, which should never happen since
                the item would already be present in the database and
                would be modified instead and would imply that the
                database is corrupted in some form.
            MultipleResultsFound: If multiple items are found that
                could be updated. This implies that the database is
                corrupted.

        Returns:
            TestItem: The new or updated item.
        """
        dbsession = self._ensure_session()

        # May raise DetachedInstanceError (if `file` is not attached to
        # this session) or MultipleResultsFound:
        db_item = self._try_get_item(dbsession, file, lineno, testname)
        if db_item is None:
            db_item = _TestItem(
                file=file,
                lineno=lineno,
                testname=testname,
                last_run=last_run
            )
            dbsession.add(db_item)
        else:
            db_item.last_run = last_run

        dbsession.commit()

        return db_item

    @classmethod
    def _try_delete_item(
        cls, dbsession: _orm.Session, location: _Location
    ) -> bool:
        """Try to delete an item from the database. Does not perform a
        commit.

        Args:
            dbsession (orm.Session): The session object.
            location (Location): The test item location.

        Raises:
            AssertionError: If the delete statement reports a negative
                row count, i.e. the number of deleted rows is unknown.
            MultipleResultsFound: If multiple files are found with the
                path specified in `location`. If multiple items to
                delete are found. These should not happen due to
                uniqueness constraints in the database and would imply
                that it is corrupted in some way.

        Returns:
            bool: Whether the item was present in, and therefore
                deleted from, the database.
        """
        file, lineno, testname = location

        # May raise MultipleResultsFound:
        db_file = cls._try_get_file(dbsession, file)

        # If the file is not in the database, then the item will not be
        # either.
        if db_file is None:
            return False

        stmt = _sqlalchemy.delete(_TestItem).where(
            _TestItem.file == db_file,
            _TestItem.lineno == lineno,
            _TestItem.testname == testname
        )
        rc = dbsession.execute(stmt).rowcount
        if rc < 0:
            # DBAPI drivers report -1 when the affected row count is
            # unknown (PEP 249); we cannot tell whether anything was
            # deleted, so treat it as an internal error.
            raise AssertionError(
                "Unable to determine the number of deleted items."
            )
        if rc > 1:
            raise _exceptions.MultipleResultsFound(
                "Found multiple matching items."
            )
        return rc == 1

    def try_delete_item(self, location: _Location) -> bool:
        """Try to delete an item from the database. Commits only if an
        item was actually deleted.

        Args:
            location (Location): The test item location.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            MultipleResultsFound: If multiple files are found with the
                path specified in `location`. If multiple items to
                delete are found. These should not happen due to
                uniqueness constraints in the database and would imply
                that it is corrupted in some way.

        Returns:
            bool: Whether the item was present in, and therefore
                deleted from, the database.
        """
        dbsession = self._ensure_session()
        with dbsession.begin_nested():
            # May raise MultipleResultsFound:
            result = self._try_delete_item(dbsession, location)
        if result:
            dbsession.commit()
        return result

    def bulk_add_update_remove(
        self,
        last_run: _Arrow,
        add_update: _Iterable[_Location],
        try_delete: _Iterable[_Location],
        hash_provider: _Optional[TestFileHashProvider],
        prune_files: bool
    ) -> BulkUpdateResult:
        """Perform a bulk update by adding, updating and deleting test
        items in a single transaction, then commit.

        Args:
            last_run (Arrow): The Arrow instance to set as the last run
                time for the test items.
            add_update (Iterable[Location]): The tests to add and update
                with the new `last_run` time. Files that are not yet
                known are added as well.
            try_delete (Iterable[Location]): The items to delete if they
                are present.
            hash_provider (Optional[TestFileHashProvider]): If required,
                a `TestFileHashProvider` for computing the hashes of
                newly added test files. Otherwise None.
            prune_files (bool): Whether to prune orphaned files from
                the database.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.
            IntegrityError: If the added items do not fulfill uniqueness
                constraints. This should not happen since the second
                appearance of a location would simply lead to an update.
            MultipleResultsFound: If multiple files are found with a
                path specified in one of the location items.
                If multiple items to delete are found. If multiple items
                are found that could be updated. These should not
                happen due to uniqueness constraints in the database and
                would imply that it is corrupted in some way.

        Returns:
            BulkUpdateResult: An object containing the amount of added,
                updated and deleted items and the number of pruned
                files, which is None if `prune_files` is False.
        """
        dbsession = self._ensure_session()

        def add_or_update(location: _Location) -> bool:
            """Add a new item or update its last run time if it already
            exists. Does not perform a commit.

            Args:
                location (Location): The test item location.

            Returns:
                bool: Whether the item was already in the database.
            """
            file, lineno, testname = location

            # Get or add the file. May raise MultipleResultsFound:
            db_file = self._try_get_file(dbsession, file)
            if db_file is None:
                db_file = _TestFile(path=file, last_hash=None)
                if hash_provider is not None:
                    db_file.last_hash = hash_provider(db_file)
                dbsession.add(db_file)
                # The item cannot exist yet if the database is
                # consistent
                db_item = None
            else:
                db_item = self._try_get_item(
                    dbsession, db_file, lineno, testname
                )

            if db_item is None:
                db_item = _TestItem(
                    file=db_file,
                    lineno=lineno,
                    testname=testname,
                    last_run=last_run
                )
                dbsession.add(db_item)
                return False

            db_item.last_run = last_run
            return True

        deleter = _functools.partial(self._try_delete_item, dbsession)

        with dbsession.begin_nested():

            results = map(add_or_update, add_update)
            # add_or_update returns True for updates, False for adds.
            updated, added = _tools.count_truefalse(results)
            dbsession.flush()

            removed = sum(map(deleter, try_delete))

            if prune_files:
                # Make the deletions visible before searching for files
                # without items.
                dbsession.flush()
                pruned = self._prune_files(dbsession)
            else:
                pruned = None

        dbsession.commit()

        return BulkUpdateResult(added, updated, removed, pruned)

    def drop_all_entries(self) -> DropAllEntriesResult:
        """Clean the database completely by dropping all `TestItem` and
        `TestFile` entries, then commit.

        Raises:
            InactiveSessionError: If the session was not started by
                calling __enter__.

        Returns:
            DropAllEntriesResult: A result object describing the number
                of dropped entries from each table.
        """
        dbsession = self._ensure_session()

        with dbsession.begin_nested():
            # Items first, since they reference their files.
            deld_items = dbsession.query(_TestItem).delete()
            deld_files = dbsession.query(_TestFile).delete()

        dbsession.commit()

        return DropAllEntriesResult(deld_files, deld_items)