forked from rc/aircox
218 lines
6.4 KiB
Python
218 lines
6.4 KiB
Python
"""This module provide test utilities."""
|
|
from collections import namedtuple
|
|
import inspect
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ("interface", "Interface")
|
|
|
|
|
|
def interface(obj, funcs):
    """Replace attributes of ``obj`` with call-recording stubs.

    ``funcs`` maps attribute names to the value the stub must return, as
    ``{func_name: return_value}``. Each entry is installed through
    ``interface_wrap``.

    Calls are recorded in ``obj.calls``, a dict shaped as
    ``{func_name: (args, kwargs) | list[(args, kwargs)]}``. An existing
    ``obj.calls`` is kept when it already is a dict; anything else (or a
    missing attribute) is replaced by a fresh empty dict.
    """
    existing = getattr(obj, "calls", None)
    if not isinstance(existing, dict):
        obj.calls = {}

    for name, result in funcs.items():
        interface_wrap(obj, name, result)
|
|
|
|
|
|
def interface_wrap(obj, attr, value):
    """Install on ``obj`` a stub named ``attr`` that always returns ``value``.

    ``obj.calls`` must already be a dict. Its ``attr`` entry is reset to
    ``None``, then filled by the stub: the first call stores a single
    ``(args, kwargs)`` tuple, the second call turns it into a list of such
    tuples, and later calls append to that list.

    :returns: the stub function that was set on ``obj``.
    """
    obj.calls[attr] = None

    def wrapper(*a, **kw):
        record = (a, kw)
        seen = obj.calls.get(attr)
        if seen is None:
            # First call: keep the bare tuple.
            obj.calls[attr] = record
            return value
        if isinstance(seen, tuple):
            # Second call: switch from a single tuple to a list of tuples.
            obj.calls[attr] = [seen, record]
            return value
        seen.append(record)
        return value

    setattr(obj, attr, wrapper)
    return wrapper
|
|
|
|
|
|
# Record with a mandatory ``target`` and optional ``namespace`` and ``key``.
#
# ``namedtuple(defaults=...)`` takes a plain sequence of values applied to the
# rightmost fields, not ``(name, value)`` pairs. Passing pairs made the
# defaults the tuples ``("namespace", None)`` and ``("key", None)`` instead of
# ``None``.
InterfaceTarget = namedtuple(
    "InterfaceTarget",
    ["target", "namespace", "key"],
    defaults=[None, None],
)
|
|
|
|
|
|
class WrapperMixin:
    """Mixin holding a wrapped ``target`` and its injection into a namespace.

    Injection replaces ``getattr(ns, ns_attr)`` with ``self.parent`` and
    keeps the previous value as ``self.target``. The concrete class must set
    ``self.parent`` before ``inject`` runs (``__init__`` may call it).
    """

    # Namespace and attribute name of the current injection. Class-level
    # defaults keep ``ns_target`` and ``release`` usable on an instance that
    # was never injected.
    ns = None
    ns_attr = None

    def __init__(self, target=None, ns=None, ns_attr=None, **kwargs):
        """Store ``target`` and, when ``ns`` is given, inject right away.

        :param target: wrapped object, if already known.
        :param ns: namespace object to inject into.
        :param str ns_attr: attribute name to replace on ``ns``.
        :raises ValueError: ``ns`` provided without ``ns_attr``.
        """
        self.target = target
        if ns:
            self.inject(ns, ns_attr)
        super().__init__(**kwargs)

    @property
    def ns_target(self):
        """Current value of ``ns.ns_attr``, or ``None`` when not injected."""
        if self.ns and self.ns_attr:
            return getattr(self.ns, self.ns_attr, None)
        return None

    def inject(self, ns=None, ns_attr=None):
        """Put ``self.parent`` in place of ``ns.ns_attr``.

        The replaced value becomes ``self.target``. Nothing is done when the
        namespace value already is ``self.target``; this includes the case
        where both are ``None`` (no target set and attribute missing).

        :raises ValueError: ``ns`` or ``ns_attr`` is missing.
        :raises RuntimeError: a different target is already set.
        """
        if not (ns and ns_attr):
            raise ValueError("ns and ns_attr must be provided together")

        ns_target = getattr(ns, ns_attr, None)
        if self.target is ns_target:
            return
        if self.target is not None:
            raise RuntimeError(
                "self target already injected. It must be "
                "`release` before `inject`."
            )

        self.target = ns_target
        setattr(ns, ns_attr, self.parent)
        self.ns = ns
        self.ns_attr = ns_attr

    def release(self):
        """Undo the injection, if still in place, and forget the target.

        The original value is written back only when the namespace still
        holds ``self.parent``, so a value set by someone else in the meantime
        is not overwritten.
        """
        injected = bool(self.ns and self.ns_attr)
        # ``inject`` stores ``self.parent`` in the namespace, so that is the
        # object to look for, and the value is restored where it was taken.
        if injected and self.ns_target is self.parent:
            setattr(self.ns, self.ns_attr, self.target)
        self.target = None
|
|
|
|
|
|
class SpoofMixin:
    """Mixin recording function calls and serving predefined results.

    ``funcs`` maps a function name to either a fixed result or a callable
    producing it. ``traces`` maps a function name to the parameters it was
    called with: a single ``(args, kwargs)`` tuple after one call, a list of
    such tuples after several.
    """

    traces = None

    def __init__(self, funcs=None, **kwargs):
        self.reset(funcs if funcs else {})
        super().__init__(**kwargs)

    def reset(self, funcs=None):
        """Drop all recorded traces; replace ``funcs`` when one is given."""
        self.traces = {}
        if funcs is None:
            return
        self.funcs = funcs

    def get_trace(self, name, args=False, kw=False):
        """Get the parameters of the only call made to a function.

        :param str name: function name
        :param bool args: return positional arguments
        :param kw: return named arguments; when a string, return only the \
            named argument with that key
        :returns either a tuple of args, a dict of kwargs (or one of its \
            values), or a tuple of `(args, kwargs)` when both or neither \
            of `args` and `kw` are set.
        :raises KeyError: the function has not been called.
        :raises ValueError: the function has been called multiple time.
        """
        recorded = self.traces[name]
        if not isinstance(recorded, list):
            return self._get_trace(recorded, args=args, kw=kw)
        raise ValueError(f"{name} called multiple times.")

    def get_traces(self, name, args=False, kw=False):
        """Get a tuple with the parameters of every call to a function.

        Parameters are the same as `get_trace()`; a single recorded call
        yields a one-element tuple.
        """
        recorded = self.traces[name]
        calls = recorded if isinstance(recorded, list) else (recorded,)
        return tuple([self._get_trace(c, args=args, kw=kw) for c in calls])

    def _get_trace(self, trace, args=False, kw=False):
        """Select from one `(args, kwargs)` trace what the flags ask for."""
        want_args, want_kw = bool(args), bool(kw)
        if want_args == want_kw:
            # Both or neither requested: hand back the whole pair.
            return trace
        if want_args:
            return trace[0]
        named = trace[1]
        return named[kw] if isinstance(kw, str) else named

    def call(self, name, args, kw):
        """Record a call to the function `name`, then return its
        predefined result."""
        self.add(name, args, kw)
        result = self.get_result(name, args, kw)
        return result

    def add(self, name, args, kw):
        """Record call parameters in `self.traces` under the function
        `name`."""
        entry = (args, kw)
        previous = self.traces.get(name)
        if previous is None:
            self.traces[name] = entry
            return
        if isinstance(previous, tuple):
            # Second call: promote the single tuple to a list of tuples.
            self.traces[name] = [previous, entry]
            return
        previous.append(entry)

    def get_result(self, name, a, kw):
        """Get the result registered for the function `name`.

        A callable entry is invoked with the provided arguments; any other
        entry is returned as is.

        :raises KeyError: no registered function with this `name`.
        """
        registered = self.funcs[name]
        return registered(*a, **kw) if callable(registered) else registered
|
|
|
|
|
|
class InterfaceMeta(SpoofMixin, WrapperMixin):
    """Combination of ``SpoofMixin`` and ``WrapperMixin`` bound to a
    ``parent`` object.

    Indexing an instance by function name gives the raw recorded trace of
    that function.
    """

    # Calls done.
    calls = None

    def __init__(self, parent, **kwargs):
        # Set before delegating: ``WrapperMixin.inject`` may run during
        # initialisation and reads ``self.parent``.
        self.parent = parent
        super().__init__(**kwargs)

    def __getitem__(self, name):
        recorded = self.traces
        return recorded[name]
|
|
|
|
|
|
class Interface:
    """Stand-in for a target object that spoofs and records selected calls.

    Attributes named in the ``funcs`` mapping are answered by the related
    ``InterfaceMeta`` (recorded, then resolved to their predefined result);
    every other attribute is read from the wrapped target.
    """

    # InterfaceMeta instance related to this Interface. It holds the target,
    # the predefined results and the recorded traces, and is what tests
    # inspect.
    _imeta = None

    def __init__(self, _target=None, _funcs=None, _imeta_kw=None, **kwargs):
        """
        :param _target: object to wrap.
        :param dict _funcs: ``{name: result or callable}`` of spoofed calls.
        :param dict _imeta_kw: extra keyword arguments for ``InterfaceMeta``; \
            its own ``funcs``/``target`` entries win over ``_funcs``/``_target``.
        :param **kwargs: set as attributes on the instance.
        """
        # Work on a copy: the ``setdefault`` calls below must not leak into
        # the caller's dict, which would otherwise carry ``funcs``/``target``
        # over to the next Interface built from the same dict.
        _imeta_kw = {} if _imeta_kw is None else dict(_imeta_kw)
        if _funcs is not None:
            _imeta_kw.setdefault("funcs", _funcs)
        if _target is not None:
            _imeta_kw.setdefault("target", _target)
        self._imeta = InterfaceMeta(self, **_imeta_kw)
        self.__dict__.update(kwargs)

    @property
    def _itarget(self):
        """Wrapped target, as held by ``_imeta``."""
        return self._imeta.target

    @classmethod
    def inject(cls, ns, ns_attr, funcs=None, **kwargs):
        """Create an instance injected in place of ``ns.ns_attr``.

        Any ``_imeta_kw`` passed in ``kwargs`` is overwritten.
        """
        kwargs["_imeta_kw"] = {"ns": ns, "ns_attr": ns_attr, "funcs": funcs}
        return cls(**kwargs)

    def _irelease(self):
        """Release the injection held by ``_imeta``."""
        self._imeta.release()

    def _trace(self, *args, **kw):
        """Shortcut to ``_imeta.get_trace``."""
        return self._imeta.get_trace(*args, **kw)

    def _traces(self, *args, **kw):
        """Shortcut to ``_imeta.get_traces``."""
        return self._imeta.get_traces(*args, **kw)

    def __call__(self, *args, **kwargs):
        """Call the wrapped target.

        When the target is a class, it is instantiated and the instance is
        returned wrapped in a new interface of the same type, sharing the
        same ``funcs``; that call is not traced. Otherwise the call is
        recorded under ``"__call__"`` and forwarded to the target.
        """
        target = self._imeta.target
        if inspect.isclass(target):
            target = target(*args, **kwargs)
            return type(self)(target, _imeta_kw={"funcs": self._imeta.funcs})

        self._imeta.add("__call__", args, kwargs)
        return self._imeta.target(*args, **kwargs)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        if attr in self._imeta.funcs:
            # Spoofed name: record the call and return the predefined result.
            return lambda *args, **kwargs: self._imeta.call(attr, args, kwargs)
        return getattr(self._imeta.target, attr)
|