Source code for pyepri.checks

"""Generic functions for validity checks"""

def _check_(test, **kwargs):
    """Apply a generic boolean test to every non-``None`` keyword argument.

    Parameters
    ----------
    test : callable
        Predicate applied to each keyword value that is not ``None``;
        it is expected to return a ``bool``.
    **kwargs : keyword arguments
        Input keyword arguments to be tested. ``None`` values are skipped
        and never passed to `test`.

    Return
    ------
    ko : tuple
        Names of the keyword arguments whose value is not ``None`` and for
        which ``test(value)`` is falsy, in the order they were given.

    Note
    ----
    The compatibility between the non ``None`` values of `kwargs` and the
    provided `test` function is **not verified** by this function.
    """
    failed = []
    for name, value in kwargs.items():
        if value is None:
            # None means "not provided": nothing to test.
            continue
        if not test(value):
            failed.append(name)
    return tuple(failed)
def _check_same_dtype_(**kwargs):
    """Check that all non-``None`` keyword arguments share the same data type.

    Parameters
    ----------
    **kwargs : keyword arguments
        Input keyword arguments to be tested. Each keyword value must be
        either ``None`` or an array_like exposing a ``dtype`` attribute
        (this is not tested).

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If two non-``None`` values have different ``dtype`` attributes.
    """
    provided = {name: arr for name, arr in kwargs.items() if arr is not None}
    if len(provided) <= 1:
        # Zero or one array: there is nothing to compare against.
        return True
    dtypes = [arr.dtype for arr in provided.values()]
    reference = dtypes[0]
    matches = [current == reference for current in dtypes]
    if not all(matches):
        raise RuntimeError(
            "Parameters %s must have the same data type." % str(list(provided.keys()))
        )
    return True
def _check_type_(t, **kwargs):
    """Check that all non-``None`` keyword arguments are instances of `t`.

    Parameters
    ----------
    t : <class 'type'>
        The expected type; it is passed as the second argument of
        ``isinstance``, so a tuple of types is accepted as well.
    **kwargs : keyword arguments
        Input keyword arguments to be tested. ``None`` values are skipped.

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If at least one non-``None`` value is not an instance of `t`.
    """
    bad = _check_(lambda value: isinstance(value, t), **kwargs)
    if not bad:
        return True
    if len(bad) == 1:
        raise RuntimeError("Inconsistent type for parameter `%s` (expected `%s`)." % (bad[0], t))
    raise RuntimeError("Inconsistent type for parameters %s (expected `%s`)." % (str(bad), t))
def _check_dtype_(dtype, **kwargs):
    """Check that all non-``None`` keyword arguments have the given ``dtype``.

    Parameters
    ----------
    dtype : <class 'type'> or <class 'torch.dtype'>
        The expected array data type, compared with ``==`` against the
        ``dtype`` attribute of each value.
    **kwargs : keyword arguments
        Input keyword arguments to be tested. Each keyword value must be
        either ``None`` or an array_like (this is not tested).

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If at least one non-``None`` value has a different ``dtype``.
    """
    bad = _check_(lambda arr: dtype == arr.dtype, **kwargs)
    if not bad:
        return True
    if len(bad) == 1:
        raise RuntimeError("Inconsistent data type for parameter `%s` (expected `%s`)." % (bad[0], dtype))
    raise RuntimeError("Inconsistent data type for parameters %s (expected `%s`)." % (str(bad), dtype))
def _check_ndim_(ndim, **kwargs):
    """Check that all non-``None`` keyword arguments have `ndim` dimensions.

    Parameters
    ----------
    ndim : int
        The expected number of array dimensions, compared with ``==``
        against the ``ndim`` attribute of each value.
    **kwargs : keyword arguments
        Input keyword arguments to be tested. Each keyword value must be
        either ``None`` or an array_like (this is not tested).

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If at least one non-``None`` value has a different ``ndim``.
    """
    bad = _check_(lambda arr: ndim == arr.ndim, **kwargs)
    if not bad:
        return True
    if len(bad) == 1:
        raise RuntimeError(
            "Inconsistent dimensions for parameter `%s` (expected `ndim=%d`)." % (bad[0], ndim)
        )
    raise RuntimeError(
        "Inconsistent dimensions for parameters %s (expected `ndim=%d`)." % (str(bad), ndim)
    )
def _check_backend_(backend, **kwargs):
    """Check that all non-``None`` keyword arguments match the backend array class.

    Parameters
    ----------
    backend : <class 'pyepri.backends.Backend'>
        A `numpy`, `cupy` or `torch` backend instance. Its ``cls``
        attribute is used in an ``isinstance`` test and its
        ``lib.__name__`` appears in the error message.
    **kwargs : keyword arguments
        Input keyword arguments to be tested. ``None`` values are skipped.

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If at least one non-``None`` value is not an instance of
        ``backend.cls``.
    """
    mismatched = _check_(lambda x: isinstance(x, backend.cls), **kwargs)
    if not mismatched:
        return True
    if len(mismatched) == 1:
        raise RuntimeError(
            "Parameter `%s` is not consistent with the provided backend. Since\n"
            "`backend.lib` is `%s`, `%s` must be a %s instance."
            % (mismatched[0], backend.lib.__name__, mismatched[0], str(backend.cls))
        )
    raise RuntimeError(
        "Parameters %s are not consistent with the provided backend. Since\n"
        "`backend.lib` is `%s`, those parameters must all be %s instances.\n"
        % (str(mismatched), backend.lib.__name__, str(backend.cls))
    )
def _check_seq_(t=None, dtype=None, n=None, ndim=None, **kwargs):
    """Perform consistency checks for sequence keyword arguments.

    Parameters
    ----------
    t : <class 'type'>, optional
        If given, check that each non ``None`` item in `kwargs` is a
        sequence of elements with type `t`.
    dtype : <class 'type'> or <class 'torch.dtype'>, optional
        If given, check that each non ``None`` item in `kwargs` is a
        sequence of array_like with data type `dtype`.
    n : int, optional
        If given, check that each non ``None`` item in `kwargs` is a
        sequence with length `n`.
    ndim : int, optional
        If given, each non ``None`` item in `kwargs` is assumed to be a
        sequence of array_like and we check that those array_like elements
        have a number of dimensions equal to `ndim`.
    **kwargs : keyword arguments
        Sequence parameters to be checked. ``None`` values are skipped;
        any other value must be a ``tuple`` or a ``list``.

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If any of the requested checks fails.
    """
    for key, seq in kwargs.items():
        if seq is None:
            continue
        # check that seq is a sequence
        if not isinstance(seq, (tuple, list)):
            raise RuntimeError(
                "Parameter `%s` must be a sequence (= tuple or list)" % key
            )
        # check length (if given); report the actual length in the message
        if n is not None and n != len(seq):
            raise RuntimeError(
                "The sequence parameter `%s` must contain %d elements (got len(%s) == %d)"
                % (key, n, key, len(seq))
            )
        # check items type (if given)
        if t is not None and not all(isinstance(item, t) for item in seq):
            raise RuntimeError(
                "All elements in `%s` must have type %s" % (key, t)
            )
        # check items dtype (if given)
        if dtype is not None and not all(dtype == item.dtype for item in seq):
            raise RuntimeError(
                "All elements in `%s` must have dtype %s" % (key, dtype)
            )
        # check items ndim (if given); the message needs both key and ndim
        if ndim is not None and not all(ndim == item.ndim for item in seq):
            raise RuntimeError(
                "All elements in `%s` must have a number of dimensions equal to %d"
                % (key, ndim)
            )
    return True
def _check_seq_of_seq_(t=None, dtype=None, len0=None, len1=None, len2=None, ndim=None, tlen0=None, **kwargs):
    """Perform consistency checks for sequence of sequence keyword arguments.

    Parameters
    ----------
    t : <class 'type'>, optional
        If given, check that each non ``None`` item in `kwargs` is a
        sequence of sequence(s) of elements with type `t`.
    dtype : <class 'type'> or <class 'torch.dtype'>, optional
        If given, check that each non ``None`` item in `kwargs` is a
        sequence of sequence(s) of array_like with data type `dtype`.
    len0 : int, optional
        If given, check that each non ``None`` item in `kwargs` has
        length `len0`.
    len1 : int, optional
        If given, check that each non ``None`` item in `kwargs` is a
        sequence made of sequence(s) with length `len1`.
    len2 : int, optional
        If given, check that each non ``None`` leaf in `kwargs` has length
        `len2`.
    ndim : int, optional
        If given, check that each non ``None`` leaf in `kwargs` has a
        number of dimensions equal to `ndim`.
    tlen0 : sequence of int, optional
        If given, check that each non ``None`` item in `kwargs` has a
        length contained in `tlen0`.
    **kwargs : keyword arguments
        Parameters to be checked. ``None`` values are skipped; any other
        value must be a ``tuple``/``list`` whose elements are themselves
        ``tuple``/``list`` instances.

    Return
    ------
    no_error_flag : bool
        True if the test is successful (otherwise an exception is raised).

    Raises
    ------
    RuntimeError
        If any of the requested checks fails.
    """
    for key, seq in kwargs.items():
        if seq is None:
            continue
        # check that seq is a sequence of sequences (no flattening is done)
        if not isinstance(seq, (tuple, list)) or not all(isinstance(s, (tuple, list)) for s in seq):
            raise RuntimeError(
                "Parameter `%s` must be a sequence of sequence(s)" % key
            )
        # check len0 (if given); report the actual length in the message
        if len0 is not None and len0 != len(seq):
            raise RuntimeError(
                "The sequence of sequence(s) parameter `%s` must contain %d elements (got len(%s) == %d)"
                % (key, len0, key, len(seq))
            )
        # check len1 (if given)
        if len1 is not None and not all(len1 == len(s) for s in seq):
            raise RuntimeError(
                "All elements in `%s` must have length %d" % (key, len1)
            )
        # check len2 (if given)
        if len2 is not None and not all(len2 == len(leaf) for subseq in seq for leaf in subseq):
            raise RuntimeError(
                "All leaf elements in `%s` must have length %d" % (key, len2)
            )
        # check tlen0 (if given)
        if tlen0 is not None and len(seq) not in tlen0:
            raise RuntimeError(
                "The sequence of sequence(s) parameter `%s` must have length in %s" % (key, tlen0)
            )
        # check leaves ndim (if given)
        if ndim is not None and not all(ndim == leaf.ndim for subseq in seq for leaf in subseq):
            raise RuntimeError(
                "All leaf elements in `%s` must have a number of dimensions equal to %d" % (key, ndim)
            )
        # check leaves type (if given)
        if t is not None and not all(isinstance(leaf, t) for subseq in seq for leaf in subseq):
            raise RuntimeError(
                "All leaf elements in `%s` must have type %s" % (key, t)
            )
        # check leaves dtype (if given)
        if dtype is not None and not all(dtype == leaf.dtype for subseq in seq for leaf in subseq):
            raise RuntimeError(
                "All leaf elements in `%s` must have dtype %s" % (key, dtype)
            )
    return True
def _max_len_(**kwargs):
    """Compute the maximal length of the sequence values in `kwargs`.

    Only ``tuple`` and ``list`` values are considered sequences; all other
    values (including ``None``) are ignored.

    Parameters
    ----------
    **kwargs : keyword arguments
        Values to inspect.

    Return
    ------
    out : int or None
        The largest ``len`` among the sequence values, or ``None`` when no
        value in `kwargs` is a sequence (including when `kwargs` is empty).
    """
    lengths = [len(value) for value in kwargs.values() if isinstance(value, (tuple, list))]
    # An empty `lengths` (no sequence at all, or no kwargs) maps to None
    # instead of letting max() raise ValueError on an empty list.
    return max(lengths) if lengths else None
def _backend_inference_(**kwargs):
    """Return a backend inferred from the array_like keyword arguments.

    The type of the first keyword value is taken as the reference. All
    other non-``None`` values must be instances of that type, which must
    be named ``ndarray`` or ``Tensor``. The defining module of that type
    (``numpy``, ``cupy`` or ``torch``) selects which ``pyepri.backends``
    factory is called. For ``torch``, all non-``None`` inputs must also
    share the device type of the first input.

    Parameters
    ----------
    **kwargs : keyword arguments
        Input arrays (at least one must be provided).

    Return
    ------
    backend : <class 'pyepri.backends.Backend'>
        The backend instance created by the selected factory.

    Raises
    ------
    RuntimeError
        If no input is given, if the inputs are not type consistent, if
        the reference type is not an ``ndarray``/``Tensor``, if torch
        inputs live on different device types, or if the reference type
        comes from an unsupported module.
    """
    if not kwargs:
        raise RuntimeError(
            "Backend inference failed (at least one input must be provided)"
        )
    names = str(tuple(kwargs.keys()))

    # The first input fixes the reference class for all the others.
    reference = kwargs[next(iter(kwargs))]
    ref_cls = type(reference)

    if _check_(lambda x: isinstance(x, ref_cls), **kwargs):
        raise RuntimeError(
            "Backend inference failed, parameters %s are not type consistent.\n"
            "Those parameter must have the same type (numpy.ndarray, "
            "cupy.ndarray or torch.Tensor)." % names
        )
    if ref_cls.__name__ not in ('ndarray', 'Tensor'):
        raise RuntimeError(
            "Backend inference failed, parameters %s must be all\n"
            "numpy.ndarray or cupy.ndarray or torch.Tensor arrays" % names
        )

    # Imported lazily, only once the inputs have been validated.
    import pyepri.backends as backends

    lib_name = ref_cls.__module__
    if lib_name == 'numpy':
        return backends.create_numpy_backend()
    if lib_name == 'cupy':
        return backends.create_cupy_backend()
    if lib_name == 'torch':
        device = reference.device.type
        mismatched = _check_(lambda x: device == x.device.type, **kwargs)
        if mismatched:
            raise RuntimeError(
                "Backend inference failed, device inconsistency for parameter(s) %s.\n" % (str(mismatched))
            )
        return backends.create_torch_backend(device)
    raise RuntimeError(
        "Backend inference failed, unsupported module for parameter(s) %s." % names
    )