"""This module provides test utilities.

- :func:`interface` replaces attributes of an existing object by stubs that
  record their calls in ``obj.calls``;
- :class:`Interface` wraps a target (optionally injected in place of an
  attribute of a namespace, e.g. a module), records calls made through it
  and can stub some of its functions.
"""
from collections import namedtuple
import inspect

__all__ = ("interface", "Interface")


def interface(obj, funcs):
    """Override provided object's functions using dict of funcs, as
    ``{func_name: return_value}``.

    Attribute ``obj.calls`` is a dict with all calls done using those
    methods, as ``{func_name: (args, kwargs) | list[(args, kwargs)]}``.

    :param obj: object whose attributes are overridden.
    :param dict funcs: ``{attribute_name: value_returned_by_the_stub}``.
    """
    if not isinstance(getattr(obj, "calls", None), dict):
        obj.calls = {}
    for attr, value in funcs.items():
        interface_wrap(obj, attr, value)


def interface_wrap(obj, attr, value):
    """Replace ``obj.<attr>`` by a stub returning ``value``.

    ``obj.calls`` must already be a dict. Each call of the stub is recorded
    in ``obj.calls[attr]``: a single ``(args, kwargs)`` tuple after the first
    call, then a list of such tuples.

    :returns: the stub function (also set on ``obj``).
    """
    obj.calls[attr] = None

    def wrapper(*a, **kw):
        call = obj.calls.get(attr)
        if call is None:
            obj.calls[attr] = (a, kw)
        elif isinstance(call, tuple):
            # Second call: switch from a single record to a list of records.
            obj.calls[attr] = [call, (a, kw)]
        else:
            call.append((a, kw))
        return value

    setattr(obj, attr, wrapper)
    return wrapper


# Location of a target; every field but ``target`` defaults to ``None``.
# (``namedtuple`` defaults are plain values applied to the rightmost fields.)
InterfaceTarget = namedtuple(
    "InterfaceTarget",
    ["target", "namespace", "key"],
    defaults=[None, None],
)

# Sentinel telling a missing attribute apart from one set to ``None``.
_MISSING = object()


class WrapperMixin:
    """Replace an attribute of a namespace until :meth:`release` is called.

    The replacement object is ``self.parent`` when defined (as on
    :class:`InterfaceMeta`), otherwise ``self``. The replaced value is kept
    in ``self.target``.
    """

    target = None
    ns = None
    ns_attr = None
    # True when ``ns`` had no ``ns_attr`` attribute before injection, so that
    # ``release()`` deletes it instead of restoring a value.
    _ns_attr_missing = False

    def __init__(self, target=None, ns=None, ns_attr=None, **kwargs):
        self.target = target
        if ns is not None:
            self.inject(ns, ns_attr)
        super().__init__(**kwargs)

    @property
    def _replacement(self):
        """Object set in place of the namespace attribute."""
        return getattr(self, "parent", self)

    @property
    def ns_target(self):
        """Current value of the injected namespace attribute, or ``None``."""
        if self.ns is not None and self.ns_attr:
            return getattr(self.ns, self.ns_attr, None)
        return None

    def inject(self, ns=None, ns_attr=None):
        """Replace ``ns.<ns_attr>`` by the replacement object.

        The previous value (``None`` if the attribute did not exist) is kept
        in ``self.target``.

        :raises ValueError: ``ns`` or ``ns_attr`` is missing.
        :raises RuntimeError: already injected elsewhere, or ``self.target``
            is set to another object than the current attribute value.
        """
        if ns is None or not ns_attr:
            raise ValueError("ns and ns_attr must be provided together")

        replacement = self._replacement
        current = getattr(ns, ns_attr, _MISSING)
        if current is replacement:
            # Already injected at this location: nothing to do.
            return

        missing = current is _MISSING
        if missing:
            current = None
        if self.ns is not None or (
            self.target is not None and self.target is not current
        ):
            raise RuntimeError(
                "self target already injected. It must be "
                "`release` before `inject`."
            )

        self.target = current
        self._ns_attr_missing = missing
        setattr(ns, ns_attr, replacement)
        self.ns = ns
        self.ns_attr = ns_attr

    def release(self):
        """Restore the namespace attribute replaced by :meth:`inject`.

        The attribute is restored only if it still holds the replacement
        object; it is deleted if it did not exist before injection. Calling
        this when nothing is injected is harmless. ``self.target`` is reset
        to ``None``.
        """
        if self.ns is not None and self.ns_target is self._replacement:
            if self._ns_attr_missing:
                delattr(self.ns, self.ns_attr)
            else:
                setattr(self.ns, self.ns_attr, self.target)
        self.target = None
        self.ns = None
        self.ns_attr = None
        self._ns_attr_missing = False


class SpoofMixin:
    """Record function calls and return predefined results."""

    # ``{name: (args, kwargs) | list[(args, kwargs)]}``
    traces = None
    # ``{name: result | callable}``: a callable is called with the call args.
    funcs = None

    def __init__(self, funcs=None, **kwargs):
        self.reset(funcs or {})
        super().__init__(**kwargs)

    def reset(self, funcs=None):
        """Clear traces; replace ``funcs`` when provided."""
        self.traces = {}
        if funcs is not None:
            self.funcs = funcs

    def get_trace(self, name, args=False, kw=False):
        """Get a function call parameters.

        :param str name: function name
        :param bool args: return positional arguments
        :param bool|str kw: return named arguments. If a string, get the \
            named argument at this key.
        :returns: either a tuple of args, a dict of kwargs, a single named \
            argument, or a tuple of `(args, kwargs)`.
        :raises KeyError: the function has never been called.
        :raises ValueError: the function has been called multiple times.
        """
        trace = self.traces[name]
        if isinstance(trace, list):
            raise ValueError(f"{name} called multiple times.")
        return self._get_trace(trace, args=args, kw=kw)

    def get_traces(self, name, args=False, kw=False):
        """Get a tuple of all call parameters.

        Parameters are the same as :meth:`get_trace`.

        :raises KeyError: the function has never been called.
        """
        traces = self.traces[name]
        if not isinstance(traces, list):
            traces = (traces,)
        return tuple(
            self._get_trace(trace, args=args, kw=kw) for trace in traces
        )

    def _get_trace(self, trace, args=False, kw=False):
        """Select the requested part of one ``(args, kwargs)`` trace."""
        if (args and kw) or (not args and not kw):
            return trace
        elif args:
            return trace[0]
        elif isinstance(kw, str):
            return trace[1][kw]
        return trace[1]

    def call(self, name, args, kw):
        """Add call for function of provided name, and return predefined
        result."""
        self.add(name, args, kw)
        return self.get_result(name, args, kw)

    def add(self, name, args, kw):
        """Add call parameters to `self.traces` for the function with the
        provided `name`."""
        trace = self.traces.get(name)
        if trace is None:
            self.traces[name] = (args, kw)
        elif isinstance(trace, tuple):
            # Second call: switch from a single record to a list of records.
            self.traces[name] = [trace, (args, kw)]
        else:
            trace.append((args, kw))

    def get_result(self, name, a, kw):
        """Get result for the function of the provided `name`.

        Callable entries of ``funcs`` are called with the call arguments;
        other entries are returned as is.

        :raises KeyError: no registered function with this `name`.
        """
        func = self.funcs[name]
        if callable(func):
            return func(*a, **kw)
        return func


class InterfaceMeta(SpoofMixin, WrapperMixin):
    """Bookkeeping of an :class:`Interface`: stubs, traces and injection.

    Stored on the interface as ``_imeta``; ``parent`` is that interface and
    is what gets injected into the namespace.
    """

    # Not used in this module; kept as a public attribute.
    calls = None

    def __init__(self, parent, **kwargs):
        # Must be set before ``super().__init__`` which may inject ``parent``.
        self.parent = parent
        super().__init__(**kwargs)

    def __getitem__(self, name):
        return self.traces[name]


class Interface:
    """Proxy recording calls made to a target and stubbing functions.

    - Calling the interface calls the target and records the call under
      ``"__call__"``; if the target is a class, the new instance is returned
      wrapped in a new interface sharing the same ``funcs``.
    - Attributes listed in ``funcs`` are functions recording their calls and
      returning the stub value (or the result of calling it if callable).
    - Any other attribute is read from the target.

    Extra keyword arguments are set as instance attributes.
    """

    # ``InterfaceMeta`` instance holding this interface's state; used by
    # tests to check traces etc.
    _imeta = None

    def __init__(self, _target=None, _funcs=None, _imeta_kw=None, **kwargs):
        # Copy so that the caller's dict is not mutated by ``setdefault``.
        _imeta_kw = dict(_imeta_kw or {})
        if _funcs is not None:
            _imeta_kw.setdefault("funcs", _funcs)
        if _target is not None:
            _imeta_kw.setdefault("target", _target)
        self._imeta = InterfaceMeta(self, **_imeta_kw)
        self.__dict__.update(kwargs)

    @property
    def _itarget(self):
        """Wrapped target (``None`` once released)."""
        return self._imeta.target

    @classmethod
    def inject(cls, ns, ns_attr, funcs=None, **kwargs):
        """Create an interface replacing ``ns.<ns_attr>`` and wrapping its
        current value."""
        imeta_kw = dict(kwargs.pop("_imeta_kw", None) or {})
        imeta_kw.update(ns=ns, ns_attr=ns_attr)
        if funcs is not None:
            imeta_kw["funcs"] = funcs
        return cls(_imeta_kw=imeta_kw, **kwargs)

    def _irelease(self):
        """Restore the namespace attribute replaced by this interface."""
        self._imeta.release()

    def _trace(self, *args, **kw):
        """Shortcut for :meth:`SpoofMixin.get_trace`."""
        return self._imeta.get_trace(*args, **kw)

    def _traces(self, *args, **kw):
        """Shortcut for :meth:`SpoofMixin.get_traces`."""
        return self._imeta.get_traces(*args, **kw)

    def __call__(self, *args, **kwargs):
        target = self._imeta.target
        if inspect.isclass(target):
            # Instantiation: wrap the instance, sharing the same stubs.
            target = target(*args, **kwargs)
            return type(self)(target, _imeta_kw={"funcs": self._imeta.funcs})
        self._imeta.add("__call__", args, kwargs)
        return self._imeta.target(*args, **kwargs)

    def __getattr__(self, attr):
        imeta = self._imeta
        if imeta is None:
            # Instance not initialized (e.g. created through ``__new__`` by
            # copy/pickle): fail with a clear AttributeError.
            raise AttributeError(attr)
        if attr in imeta.funcs:
            return lambda *args, **kwargs: imeta.call(attr, args, kwargs)
        return getattr(imeta.target, attr)