147 lines
4.2 KiB
Python
147 lines
4.2 KiB
Python
# coding: utf-8
|
|
"""
:class:`asynctest.TestCase` decorator which controls the checks performed
after tests.

This module is separated from :mod:`asynctest.case` to avoid circular imports
in modules registering new checks.

To implement a new check:

* its name must be added to the ``DEFAULTS`` dict,

* a static method of the same name must be added to the :class:`_fail_on`
  class,

* an optional static method named ``before_test_[name of the check]`` can be
  added to :class:`_fail_on` to implement some set-up before the test runs.

A check may only be available on some platforms, enabled by a conditional
import. In this case, ``DEFAULTS`` and :class:`_fail_on` can be updated from
that module. There is an example in the :mod:`asynctest.selector` module.
"""
|
|
from asyncio import TimerHandle
|
|
|
|
|
|
# Name of the attribute under which a _fail_on checker is stored on the object
# it decorates (a test function, or presumably a test case class).
_FAIL_ON_ATTR = "_asynctest_fail_on"


#: Default value of each argument of @fail_on. Every key is also the name of
#: the static method of :class:`_fail_on` performing that check, and every
#: value tells whether the check is enabled by default (True) or not (False).
DEFAULTS = dict(
    unused_loop=False,
    active_handles=False,
)
|
|
|
|
|
|
class _fail_on:
    """
    Set of post-test checks attached to a test function or test case class.

    ``checks`` maps check names (keys of ``DEFAULTS``) to a boolean telling
    whether the check is enabled. Calling an instance on an object attaches a
    copy of it (merged with any checker already attached) to that object, so
    instances can be used as decorators.
    """

    def __init__(self, checks=None):
        # Copy the mapping so that update() never mutates the caller's dict.
        self.checks = dict(checks) if checks else {}
        # Memoized result of get_checks() and the case it was computed for.
        self._computed_checks = None
        self._computed_for = None

    def __call__(self, func):
        """
        Attach this checker to ``func`` and return ``func`` itself.

        If ``func`` already carries a checker (stacked decorators), a copy of
        that checker is updated with this checker's values, which take
        precedence.
        """
        checker = getattr(func, _FAIL_ON_ATTR, None)
        if checker is not None:
            checker = checker.copy()
            checker.update(self.checks)
        else:
            checker = self.copy()

        setattr(func, _FAIL_ON_ATTR, checker)
        return func

    def update(self, checks, override=True):
        """
        Merge the ``checks`` mapping into this checker.

        When ``override`` is false, checks which already have a value keep it.
        """
        if override:
            self.checks.update(checks)
        else:
            for check, value in checks.items():
                self.checks.setdefault(check, value)
        # The configured checks changed: any memoized result is now stale.
        self._computed_checks = None
        self._computed_for = None

    def copy(self):
        """Return an independent checker configured with the same checks."""
        # __init__ copies the mapping, so the two checkers share no state.
        return _fail_on(self.checks)

    def get_checks(self, case):
        """
        Return the effective ``{check name: enabled}`` mapping for ``case``.

        Values are layered, later layers winning: ``DEFAULTS``, then the
        checker attached to ``case`` (if any), then this checker's values.
        """
        # Cache the result so it's consistent across calls made for the same
        # case (before_test() then check_test()). The cache is keyed on the
        # case because a checker attached to a test function is shared by
        # every case running that function (e.g. through subclasses that may
        # be decorated differently). Only the last case is referenced.
        if self._computed_checks is None or self._computed_for is not case:
            checks = DEFAULTS.copy()

            case_checker = getattr(case, _FAIL_ON_ATTR, None)
            case_checks = getattr(case_checker, "checks", None)
            if case_checks is not None:
                checks.update(case_checks)

            checks.update(self.checks)
            self._computed_checks = checks
            self._computed_for = case

        return self._computed_checks

    def before_test(self, case):
        """
        Run the optional ``before_test_<check>`` hook of each enabled check.

        Enabled checks without such a hook are skipped; exceptions raised by
        a hook propagate to the caller.
        """
        checks = self.get_checks(case)
        for check in filter(checks.get, checks):
            hook = getattr(self, "before_test_" + check, None)
            # Only a missing hook is expected. Wrapping the call itself in
            # try/except would also hide failures raised inside the hook.
            if callable(hook):
                hook(case)

    def check_test(self, case):
        """
        Run every enabled check against ``case``.

        Each check is the method named after it; the built-in checks below
        report a failure through ``case.fail()``.
        """
        checks = self.get_checks(case)
        for check in filter(checks.get, checks):
            getattr(self, check)(case)

    # checks

    @staticmethod
    def unused_loop(case):
        """Fail ``case`` if its loop did not run during the test."""
        # NOTE(review): _asynctest_ran is set outside this module, presumably
        # by asynctest's loop instrumentation — not visible here.
        if not case.loop._asynctest_ran:
            case.fail("Loop did not run during the test")

    @staticmethod
    def _is_live_timer_handle(handle):
        """Return True if ``handle`` is a TimerHandle not yet cancelled."""
        # The private _cancelled flag is read instead of cancelled(),
        # presumably for Python versions older than 3.7 lacking that method.
        return isinstance(handle, TimerHandle) and not handle._cancelled

    @classmethod
    def _live_timer_handles(cls, loop):
        """Iterate over the live timer handles scheduled on ``loop``."""
        # Relies on the private _scheduled attribute of asyncio's base loop.
        return filter(cls._is_live_timer_handle, loop._scheduled)

    @classmethod
    def active_handles(cls, case):
        """Fail ``case`` if live timer handles remain scheduled on its loop."""
        handles = tuple(cls._live_timer_handles(case.loop))
        if handles:
            case.fail("Loop contained unfinished work {!r}".format(handles))
|
|
|
|
|
|
def fail_on(**kwargs):
    """
    Enable checks on the loop state after a test ran, to help testers
    identify common mistakes.

    Each keyword argument names a check (a key of ``DEFAULTS``) and its value
    enables or disables that check. A :exc:`TypeError` is raised for the
    first argument which does not name a known check.
    """
    # documented in asynctest.case.rst
    unknown = [name for name in kwargs if name not in DEFAULTS]
    if unknown:
        raise TypeError("fail_on() got an unexpected keyword argument "
                        "'{}'".format(unknown[0]))

    return _fail_on(kwargs)
|
|
|
|
|
|
def _fail_on_all(flag, func):
    """
    Build a checker which sets every check listed in ``DEFAULTS`` to ``flag``.

    When ``func`` is None the checker itself is returned, so that it can be
    applied later; otherwise the checker is applied to ``func`` and the
    result of that call (``func``) is returned.
    """
    checker = _fail_on(dict.fromkeys(DEFAULTS, flag))
    return checker if func is None else checker(func)
|
|
|
|
|
|
def strict(func=None):
    """
    Activate strict checking of the state of the loop after a test ran.

    Every check listed in ``DEFAULTS`` is enabled. Usable either bare
    (``@strict``) or called (``@strict()``).
    """
    # documented in asynctest.case.rst
    return _fail_on_all(flag=True, func=func)
|
|
|
|
|
|
def lenient(func=None):
    """
    Deactivate all checks after a test ran.

    Every check listed in ``DEFAULTS`` is disabled. Usable either bare
    (``@lenient``) or called (``@lenient()``).
    """
    # documented in asynctest.case.rst
    return _fail_on_all(flag=False, func=func)
|