Coverage for src/pytest_samples/plugin/_broker_stateful.py: 100%
303 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-20 19:47 +0000
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-20 19:47 +0000
1import arrow as _arrow
2import logging as _logging
3import os.path as _ospath
4import pytest as _pytest
5import warnings as _warnings
7from abc import abstractmethod as _abstractmethod
8from dataclasses import dataclass as _dataclass
9from datetime import timedelta as _timedelta
10from enum import IntEnum as _IntEnum
11from pluggy import Result as _Result
12from pytest import hookimpl as _hookimpl, Item as _Item, \
13 TestReport as _TestReport
14from typing import Callable as _Callable, Dict as _Dict, \
15 Iterator as _Iterator, List as _List, Literal as _Literal, \
16 Optional as _Optional, Set as _Set, TYPE_CHECKING as _TYPE_CHECKING
18from . import _meta
19from ._broker_base import SamplesBrokerBase as _SamplesBrokerBase
20from .. import tools as _tools
21from ..types import Location as _Location
23if _TYPE_CHECKING: # pragma: no cover
24 # These are only used for type hints.
25 from ..database import Engine as _Engine, Session as _Session, \
26 TestFile as _TestFile, TestFileHashProvider as _TestFileHashProvider
_logger = _logging.getLogger(__name__)
"""The logger for this module."""


# The result state strings this module accepts from the
# `pytest_report_teststatus` hook. "error" is part of the set because
# `_decide_result` handles it explicitly alongside "failed".
_TestResultState = _Literal[
    "passed", "skipped", "failed", "error",  # noqa: F821
    "xfailed", "xpassed"  # noqa: F821
]
class _TestResultAction(_IntEnum):
    """The reduced set of outcomes used to decide what happens to a test
    in the database: write it, remove it, or leave it untouched.
    """

    # Write the test to the database, it counts as "passed".
    WRITE = 0

    # Remove the test from the database, it counts as "failed".
    DROP = 1

    # Ignore the test, for example when it was skipped.
    IGNORE = 2
class TestResultStateNotImplementedWarning(UserWarning):
    """Warning category used when a pytest test result state is
    encountered that this plugin does not implement.
    """
@_dataclass(frozen=True)
class DatabaseItemFilterResult:
    """Immutable result of comparing the collected test items against
    the database.
    """

    __slots__ = ("known_test_indices", "last_run_map")

    known_test_indices: _List[int]
    """Positions (in the collected item list) of the items that are
    known from the database.
    """

    last_run_map: _Dict[_Item, float]
    """Maps each known test item to the time in seconds that has passed
    since it was last run.
    """
class StatefulSamplesBroker(_SamplesBrokerBase):
    """ABC for samples brokers in "stateful" mode.

    The broker creates the database engine once the tests have been
    collected, moves tests that are already known to the database to
    the end of the run order, forwards every reported test result to
    `_process_result` and disposes the engine when the session
    finishes. Derived classes decide how results are persisted.
    """

    __slots__ = (
        "_rootpath",
        "_db_path",
        "_hash_testfiles",
        "_randomize",
        "_no_pruning",
        "_engine",
        "_reset_on_saturation",
        "_num_tests",
        "_overwrite_broken_db",
        "_session_finished"
    )

    def __init__(
        self,
        rootpath: str,
        soft_timeout: _timedelta,
        seed: _Optional[str],
        db_path: str,
        hash_testfiles: bool,
        randomize: bool,
        no_pruning: bool,
        reset_on_saturation: bool,
        overwrite_broken_db: bool
    ) -> None:
        """Initialize a new `StatefulSamplesBroker`.

        Args:
            rootpath (str): The path to the pytest root.
            soft_timeout (timedelta): The time after which the timeout
                should occur.
            seed (Optional[str]): The seed for the RNG.
            db_path (str): The path to the database to use for the state
                information.
            hash_testfiles (bool): Whether to hash the test files to
                check when tests may have changed.
            randomize (bool): Whether to randomize the "new" tests.
            no_pruning (bool): Whether to keep old test and file entries
                in the database.
            reset_on_saturation (bool): Whether to drop all entries once
                all tests have passed once.
            overwrite_broken_db (bool): Whether to overwrite broken
                database files.
        """
        super().__init__(soft_timeout, seed)
        self._rootpath = rootpath
        """The path to the pytest root."""
        self._db_path = db_path
        """The path to the database holding the state information."""
        self._hash_testfiles = hash_testfiles
        """Whether to hash test files."""
        self._randomize = randomize
        """Whether to run the tests in state mode in a random order.
        Known tests will still be moved to the end of the chain.
        """
        self._no_pruning = no_pruning
        """Whether pruning of remnant files and tests is disabled."""
        self._engine: "_Optional[_Engine]" = None
        """Stores the database engine created after all tests are
        collected. Will be instantiated in
        `pytest_collection_modifyitems`.
        """
        self._reset_on_saturation = reset_on_saturation
        """Whether to drop all entries once all tests have passed
        once.
        """
        self._num_tests: _Optional[int] = None
        """The number of tests found. Set in
        `pytest_collection_modifyitems`.
        """
        self._overwrite_broken_db = overwrite_broken_db
        """Whether to overwrite broken database files."""

        self._session_finished: bool = False
        """Whether the `pytest_sessionfinish` hook was called."""

        self._post_init()

    def _make_pytest_abspath(self, path: str) -> str:
        """Convert a path relative to the pytest root to an absolute
        path. This is necessary whenever the file system is accessed
        not via pytest because only the paths of test files relative
        to the rootpath are known.

        Args:
            path (str): The path to convert.

        Returns:
            str: The normalized path joined onto the pytest root.
        """
        return _ospath.normpath(_ospath.join(self._rootpath, path))

    def _post_init(self) -> None:  # pragma: no cover
        """Called at the end of this class's `__init__` after all
        fields have been assigned. Derived classes can use this to set
        their own fields.
        """
        pass

    def _hash_file(self, path: str) -> bytes:
        """Hash a file.

        Args:
            path (str): The path to the file to hash, relative to the
                pytest root.

        Returns:
            bytes: The hash of the file.
        """
        from .._hashing import hash_file
        abspath = self._make_pytest_abspath(path)
        return hash_file(abspath)

    @classmethod
    def _setup_tables(cls, engine: "_Engine", overwrite: bool) -> None:
        """Set up the database tables.

        If the setup fails and `overwrite` is set, the database file is
        truncated once and the setup is attempted a second time.

        Args:
            engine (Engine): The engine.
            overwrite (bool): Whether to overwrite broken database
                files.

        Raises:
            UsageError: If an error occurs during setup.
        """

        second_attempt = False

        while True:
            exception = None

            try:
                engine.setup_tables()
                return
            except Exception as de:
                _logger.exception(
                    "An exception occurred when setup_tables() was called."
                )
                # Keep the error so it can be chained to the UsageError
                # raised outside of this handler.
                exception = de

            if not overwrite:
                raise _pytest.UsageError(
                    "The provided database file is invalid. The detailed "
                    "exception has been written to the logger."
                ) from exception

            if second_attempt:
                break  # pragma: no cover

            second_attempt = True

            try:
                engine.truncate_database_file()
            except IsADirectoryError as iade:  # pragma: no cover
                raise _pytest.UsageError(
                    "The provided database file path points to a directory."
                ) from iade

        # Chain the last setup error so the cause is not lost.
        raise _pytest.UsageError(  # pragma: no cover
            "The provided database file is invalid and truncating it did "
            "not resolve the issue. The detailed exception has been "
            "written to the logger."
        ) from exception

    @classmethod
    def _setup_engine(cls, path: str, overwrite: bool) -> "_Engine":
        """Set up the database engine.

        Args:
            path (str): The path to the database file.
            overwrite (bool): Whether to overwrite broken database
                files.

        Raises:
            UsageError: If an error occurs during setup.
            RelativePathError: Re-raised unchanged if the engine
                constructor raises it.

        Returns:
            Engine: The database engine.
        """
        from ..database import Engine, RelativePathError

        try:
            engine = Engine(path)
        except RelativePathError:  # pragma: no cover
            _logger.exception(
                "A relative path was provided to the Engine constructor."
            )
            # This is an internal error which the user cannot fix
            raise
        except Exception as e:  # pragma: no cover
            _logger.exception("Error with Engine initialization.")
            raise _pytest.UsageError(str(e)) from e

        try:
            cls._setup_tables(engine, overwrite)
        except BaseException:
            # Do not leave a half-initialized engine behind.
            engine.dispose()
            raise

        return engine

    @classmethod
    def _compare_against_database(  # noqa: C901
        cls,
        session: "_Session",
        items: _List[_Item],
        hash_func: _Optional[_Callable[[str], bytes]]
    ) -> DatabaseItemFilterResult:
        """Compare the found test items against the database and return
        information regarding the found items and their last run time.
        This may modify the database if a `hash_func` is provided.

        Args:
            session (Session): The database session.
            items (List[Item]): The test items collected by pytest.
            hash_func (Optional[Callable[[str], bytes]]): A function
                providing hashes for the (relative) file paths of test
                items if hashing is requested.

        Returns:
            DatabaseItemFilterResult: An object containing the indices
                of test items found in the database and information
                regarding their last run time.
        """

        @_dataclass(frozen=True)
        class ItemWithFile:
            """A collected test item paired with its database file."""

            __slots__ = ("item_index", "pytest_item", "file")

            item_index: int
            """The index of the test item."""

            pytest_item: _Item
            """The test item."""

            file: "_TestFile"
            """The database file entry in which the test item
            appears.
            """

        def filter_with_files() -> _Iterator[ItemWithFile]:
            """Filter out test items which do not have an associated
            file. Iterates over the `items` of the enclosing call.

            Returns:
                Iterator[ItemWithFile]: An iterable yielding all test
                    items which have an associated file in the database.
            """
            for i, item in enumerate(items):
                file = item.location[0]
                db_file = session.try_get_file(file)
                if db_file is None:
                    # Since the file is not known, the test cannot be
                    # known.
                    continue
                yield ItemWithFile(i, item, db_file)

        items_source = filter_with_files()

        def filter_hashed_and_update(
            items: _Iterator[ItemWithFile], hash_func: _Callable[[str], bytes]
        ) -> _Iterator[ItemWithFile]:
            """Filter out invalidated items from the provided iterable
            and update the file entries' hashes if necessary. Note that
            this may modify the underlying database.

            Args:
                items (Iterator[ItemWithFile]): An iterator iterating
                    over available test items and their files.
                hash_func (Callable[[str], bytes]): Provides the
                    current hash for a (relative) test file path.

            Returns:
                Iterator[ItemWithFile]: An iterable yielding all test
                    items which have an associated file in the database
                    and have not been invalidated due to a changed file
                    hash.
            """
            # Keep track of files whose hash has been updated. The tests
            # will still be considered out of date. This is a shortcut
            # since the test would have been removed from the database
            # anyway and will not be found further down the pipeline.
            ok_file_ids: _Set[int] = set()
            updated_file_ids: _Set[int] = set()

            for item in items:
                file = item.file
                file_id = file.id
                if file_id in updated_file_ids:
                    # The file hash has been updated and the test
                    # will not be found below. This is a shortcut.
                    continue
                if file_id not in ok_file_ids:
                    file_path = item.pytest_item.location[0]
                    current_hash = hash_func(file_path)
                    if file.last_hash != current_hash:
                        num_del = session.invalidate_hash(
                            file, current_hash
                        )
                        _logger.info(
                            "Removed %s test items when updating hash.",
                            num_del
                        )
                        updated_file_ids.add(file_id)
                    else:
                        ok_file_ids.add(file_id)
                yield item

        if hash_func is not None:
            items_source = filter_hashed_and_update(items_source, hash_func)

        known_test_indices: _List[int] = []
        last_run_map: _Dict[_Item, float] = {}

        now = _arrow.utcnow()

        for item in items_source:
            pytest_test_item = item.pytest_item
            _, lineno, name = pytest_test_item.location
            db_item = session.try_get_item(item.file, lineno, name)
            if db_item is None:
                # No known successful run for this test item.
                continue
            time = (now - db_item.last_run).total_seconds()
            last_run_map[pytest_test_item] = time
            known_test_indices.append(item.item_index)

        return DatabaseItemFilterResult(known_test_indices, last_run_map)

    @_hookimpl(trylast=True)
    def pytest_collection_modifyitems(self, items: _List[_Item]) -> None:
        """The function called for the pytest "collection_modifyitems"
        hook. Sets up the database engine, records the number of
        collected tests and moves the tests known to the database to
        the end of `items`.

        Args:
            items (List[Item]): The list of collected items. It is
                reordered in place.
        """

        if self._randomize:
            # XXX: This could be changed by moving the shuffle operation
            # to after the move_idx_to_end-call below. Here, all
            # elements will be shuffled, even those that will be moved
            # to the end of the list anyway. However, the builtin
            # Random does not support shuffle on a sublist
            # out-of-the-box.
            self._shuffle_items(items)

        engine = self._setup_engine(
            self._db_path,
            self._overwrite_broken_db
        )
        self._engine = engine

        num_tests = len(items)
        self._num_tests = num_tests

        if num_tests == 0:
            return

        prune = not self._no_pruning

        with engine.new_session() as session:  # pragma: no branch

            # Keep track of items that we did not find in the database.
            filter_res = self._compare_against_database(
                session,
                items,
                (self._hash_file if self._hash_testfiles else None)
            )

            known_test_indices = filter_res.known_test_indices

            if len(known_test_indices) == num_tests:
                _logger.info("The database is saturated.")
                if self._reset_on_saturation:
                    drop_res = session.drop_all_entries()
                    _logger.info("Saturated: %s", drop_res)
                    # NOTE(review): block nesting was reconstructed from
                    # a listing without indentation; confirm that this
                    # return belongs to the reset branch.
                    return

            if prune:
                # Prune orphaned entries.
                known_locations = {it.location for it in items}
                num = session.prune_items(known_locations)
                _logger.info("Pruned %s disappeared tests.", num)

            _logger.info(
                "Moving %s items to end of list.", len(known_test_indices)
            )
            _tools.move_idx_to_end(
                items,
                known_test_indices,
                sorting_key=filter_res.last_run_map.__getitem__
            )

    def _ensure_engine(self) -> "_Engine":
        """Obtain the database engine or raise an exception.

        Raises:
            RuntimeError: If the engine was not yet assigned.

        Returns:
            Engine: The database engine.
        """
        engine = self._engine
        if engine is None:
            raise RuntimeError("The engine was never assigned.")
        return engine

    @_hookimpl(tryfirst=True, hookwrapper=True)
    def pytest_report_teststatus(self, report: _TestReport):
        """Hook wrapper for the pytest "report_teststatus" hook. The
        first element of the wrapped hook's result is taken as the
        test's state and passed to `_process_result` together with the
        report's location. An empty state is not processed.

        Args:
            report (TestReport): The report to obtain the status for.
        """

        # For failed items, this hook is called again after
        # pytest_sessionfinish. But the database will be destroyed after
        # that call. Therefore, do nothing here.
        # XXX: Is this intended behavior of pytest or a bug?
        if self._session_finished:
            yield
            return

        # Either way, this can be called multiple times, for example if
        # an error occurs in the teardown of a fixture that a test may
        # use.

        outcome: _Result = yield
        state, *_ = outcome.get_result()
        if state == '':
            return

        self._process_result(state, report.location)

    @_abstractmethod
    def _process_result(
        self, state: _TestResultState, location: _Location
    ) -> None:
        """Process the result of a test.

        Args:
            state (_TestResultState): The test result.
            location (_Location): The location of the test.
        """
        pass

    def pytest_sessionfinish(self, exitstatus: int) -> None:
        """Pytest hook called after all tests have finished. Runs
        `_sessionfinish` and disposes the engine afterwards, even if
        `_sessionfinish` raises.

        Args:
            exitstatus (int): The exit code.
        """
        try:
            self._sessionfinish(exitstatus)
        finally:
            engine = self._engine
            if engine is not None:
                engine.dispose()
            self._session_finished = True

    def _sessionfinish(self, exitstatus: int) -> None:
        """Called after all tests have finished. After the call the
        engine will be disposed.

        Args:
            exitstatus (int): The exit code.
        """
        pass  # pragma: no cover

    def _ensure_num_tests(self) -> int:
        """Obtain the stored number of tests or raise an exception if it
        is not available.

        Raises:
            AssertionError: If the number of tests was not set.

        Returns:
            int: The total number of tests.
        """
        num_tests = self._num_tests
        if num_tests is None:
            raise AssertionError("'_num_tests' was not set.")
        return num_tests

    def _check_test_nums(self, passed: int, failed: int) -> None:
        """Check that the number of passed and failed tests does not
        exceed the total number of tests. Otherwise, write this
        information to the logger as an error.

        Args:
            passed (int): Counted number of passed tests.
            failed (int): Counted number of failed tests.
        """
        num_tests = self._ensure_num_tests()
        if (passed + failed) <= num_tests:
            return
        _logger.error(  # pragma: no cover
            "The number of passed (%s) and failed (%s) tests is larger "
            "than the number of found tests (%s).",
            passed, failed, num_tests
        )

    @classmethod
    def _is_error_exitstatus(cls, exitstatus: int) -> bool:
        """Check whether the exit status indicates an error (but not
        necessarily failed tests).

        Args:
            exitstatus (int): The exit code to check.

        Returns:
            bool: True, if the exit code is not 0 (all tests passed),
                1 (some tests failed), 5 (no tests collected).
        """
        return exitstatus not in (0, 1, 5)

    def _decide_result(
        self, state: _TestResultState
    ) -> _TestResultAction:
        """Decide on whether to store, drop or ignore a test.

        An unexpected state triggers a
        `TestResultStateNotImplementedWarning` and is ignored.

        Args:
            state (_TestResultState): The result string from pytest.

        Returns:
            _TestResultAction: The action.
        """
        if state == "skipped":
            return _TestResultAction.IGNORE
        if state in ("passed", "xpassed", "xfailed"):
            return _TestResultAction.WRITE
        if state in ("failed", "error"):
            return _TestResultAction.DROP
        _warnings.warn(  # pragma: no cover # noqa: G010
            f"Unexpected test result state: {state!r}. "  # noqa: G004
            "Test will be ignored. This is an error "
            f"in {_meta.PLUGIN_FULL_NAME}.",
            category=TestResultStateNotImplementedWarning
        )
        return _TestResultAction.IGNORE  # pragma: no cover
class ImmediateStatefulSamplesBroker(StatefulSamplesBroker):
    """Stateful broker that persists each test result to the database
    as soon as it is reported.
    """

    __slots__ = (
        "_num_failed_tests",
        "_num_passed_tests"
    )

    _num_failed_tests: int
    """Counter of results that were removed from the database."""

    _num_passed_tests: int
    """Counter of results that were written to the database."""

    def _post_init(self) -> None:
        """Reset both result counters. Invoked at the end of the base
        class's `__init__`.
        """
        self._num_passed_tests = 0
        self._num_failed_tests = 0

    def _process_result(
        self, state: _TestResultState, location: _Location
    ) -> None:
        """Apply a single test result to the database right away.

        Args:
            state (_TestResultState): The test result.
            location (_Location): The location of the test.
        """
        action = self._decide_result(state)
        if action == _TestResultAction.IGNORE:
            return

        with self._ensure_engine().new_session() as session:
            if action == _TestResultAction.DROP:
                session.try_delete_item(location)
                self._num_failed_tests += 1
            elif action == _TestResultAction.WRITE:
                timestamp = _arrow.utcnow()
                path, lineno, testname = location
                db_file = session.try_get_file(path)
                if db_file is None:
                    # Unknown file: register it first, with its hash if
                    # hashing is enabled.
                    digest = (
                        self._hash_file(path)
                        if self._hash_testfiles else None
                    )
                    db_file = session.add_file(path, digest)
                session.add_or_update_item(
                    db_file, lineno, testname, timestamp
                )
                self._num_passed_tests += 1
            else:
                raise AssertionError(f"Invalid action {action!r}.")

    def _sessionfinish(self, exitstatus: int) -> None:
        """Finalize the database once all tests have finished. The
        engine is disposed by the caller afterwards.

        Args:
            exitstatus (int): The exit code.
        """
        if self._is_error_exitstatus(exitstatus):
            return  # pragma: no cover

        passed_count = self._num_passed_tests
        self._check_test_nums(passed_count, self._num_failed_tests)

        engine = self._ensure_engine()

        if passed_count == self._ensure_num_tests():
            _logger.info("All tests have passed.")
            if self._reset_on_saturation:
                with engine.new_session() as dbsession:
                    drop_res = dbsession.drop_all_entries()
                    _logger.info("Saturated: %s", drop_res)
                # Pruning below will not change that the database is
                # empty.
                # NOTE(review): block nesting was reconstructed from a
                # listing without indentation; confirm that this return
                # belongs to the reset branch.
                return

        if not self._no_pruning:
            with engine.new_session() as session:
                pfiles = session.prune_files()
                _logger.info("Pruned %s files from the database.", pfiles)
class LazyStatefulSamplesBroker(StatefulSamplesBroker):
    """Stateful broker that only collects test results while the tests
    run and writes them to the database in one go after all tests have
    finished.
    """

    __slots__ = (
        "_failed_tests",
        "_passed_tests"
    )

    _failed_tests: _List[_Location]
    """Locations of the tests that are to be dropped."""

    _passed_tests: _List[_Location]
    """Locations of the tests that are to be written."""

    def _post_init(self) -> None:
        """Create the empty result lists. Invoked at the end of the base
        class's `__init__`.
        """
        self._passed_tests = []
        self._failed_tests = []

    def _process_result(
        self, state: _TestResultState, location: _Location
    ) -> None:
        """Remember a single test result for the bulk write at session
        end.

        Args:
            state (_TestResultState): The test result.
            location (_Location): The location of the test.
        """
        action = self._decide_result(state)

        if action == _TestResultAction.IGNORE:
            return

        if action == _TestResultAction.DROP:
            self._failed_tests.append(location)
        elif action == _TestResultAction.WRITE:
            self._passed_tests.append(location)
        else:
            raise AssertionError(f"Invalid action {action!r}.")

    def _sessionfinish(self, exitstatus: int) -> None:
        """Write the collected results to the database once all tests
        have finished. The engine is disposed by the caller afterwards.

        Args:
            exitstatus (int): The exit code.
        """
        if self._is_error_exitstatus(exitstatus):
            return  # pragma: no cover

        def _hash_of(test_file: "_TestFile") -> bytes:
            """Provide the hash for a `TestFile`.

            Args:
                test_file (TestFile): The instance to hash.

            Returns:
                bytes: The hash of the file
            """
            return self._hash_file(test_file.path)

        hash_provider: "_Optional[_TestFileHashProvider]" = (
            _hash_of if self._hash_testfiles else None
        )

        failed_tests = set(self._failed_tests)

        # If for example a test passes, but a fixture it depends on
        # fails on teardown, the item will have been first added to
        # _passed_tests and later again to _failed_tests. Hence, all
        # duplicates have to be removed from the passed tests.
        passed_tests = set(self._passed_tests) - failed_tests

        num_passed_tests = len(passed_tests)
        self._check_test_nums(num_passed_tests, len(failed_tests))

        engine = self._ensure_engine()

        if num_passed_tests == self._ensure_num_tests():
            _logger.info("All tests have passed.")
            if self._reset_on_saturation:
                with engine.new_session() as dbsession:
                    drop_res = dbsession.drop_all_entries()
                    _logger.info("Saturated: %s", drop_res)
                # NOTE(review): block nesting was reconstructed from a
                # listing without indentation; confirm that this return
                # belongs to the reset branch.
                return

        with engine.new_session() as dbsession:
            result = dbsession.bulk_add_update_remove(
                _arrow.utcnow(),
                passed_tests,
                failed_tests,
                hash_provider,
                prune_files=(not self._no_pruning)
            )

            _logger.info("Updated items test run: %s", result)