"""Utilities for `pandera` validation."""
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NoReturn, Optional, Tuple, Type, Union, cast
import numpy as np
import pandas as pd
import pandera as pa
import pandera.dtypes as pa_dtypes
import pandera.engines.pandas_engine as pd_engine
from packaging.version import Version
import tempor.core.utils
from . import data_typing
# Attribute names that `update_schema` reads off an existing `pandera.DataFrameSchema` (via `_get_pa_init_args`)
# and passes back to ``pa.DataFrameSchema(...)`` as keyword arguments.
_PA_DF_SCHEMA_INIT_PARAMETERS = [
    "columns",
    "checks",
    "index",
    "dtype",
    "coerce",
    "strict",
    "name",
    "ordered",
    "unique",
    "report_duplicates",
    "unique_column_names",
    "title",
    "description",
]
# Same idea for `update_index`: names passed back to ``pa.Index(...)`` as keyword arguments.
_PA_INDEX_INIT_PARAMETERS = [
    "dtype",
    "checks",
    "nullable",
    "unique",
    "report_duplicates",
    "coerce",
    "name",
    "title",
    "description",
]
# Same idea for `update_multiindex`: names passed back to ``pa.MultiIndex(...)`` as keyword arguments.
_PA_MULTI_INDEX_INIT_PARAMETERS = [
    "indexes",
    "coerce",
    "strict",
    "name",
    "ordered",
    "unique",
]
# The multi-index list is extended at import time depending on the installed pandera version.
if Version(pa.__version__) < Version("0.14"):  # pragma: no cover
    # Before v0.14, pandera API had an extra parameter `report_duplicates`.
    _PA_MULTI_INDEX_INIT_PARAMETERS.append("report_duplicates")
def _get_pa_init_args(pa_object: Any, param_names: List[str]) -> Dict[str, Any]:
    """Collect the current values of a `pandera` object's ``__init__`` parameters.

    The values are read from ``pa_object.__dict__``. ``param_names`` should contain the names (`str`) of the
    ``__init__`` parameters of ``pa_object``.

    Lookup rules:
        - An attribute whose name equals a ``param_names`` item is taken as is.
        - For each item not found that way, the attribute named with a leading ``_`` is tried instead; the
          result is stored under the name *without* the underscore.
        - A ``param_names`` item found by neither rule is silently left out of the result.

    Args:
        pa_object (Any): `pandera` object.
        param_names (List[str]): list of `pandera` object's ``__init__`` parameters.

    Returns:
        Dict[str, Any]: dictionary mapping ``pa_object`` 's ``__init__`` parameter names to their current values.
    """
    attributes = pa_object.__dict__
    wanted = set(param_names)

    # First pass: attributes stored under exactly the parameter name.
    found: Dict[str, Any] = {}
    for attr_name, attr_value in attributes.items():
        if attr_name in wanted:
            found[attr_name] = attr_value

    # Second pass: for whatever is still missing, look for the "private" (`_`-prefixed) spelling.
    private_names = {f"_{missing}" for missing in wanted.difference(found)}
    for attr_name, attr_value in attributes.items():
        if attr_name in private_names:
            found[attr_name[1:]] = attr_value  # Strip the leading underscore for the result key.

    return found
def update_schema(schema: pa.DataFrameSchema, **kwargs: Any) -> pa.DataFrameSchema:
    """Create a `pandera` dataframe schema from ``schema``, overridden by ``kwargs``.

    The current settings of ``schema`` are collected with `_get_pa_init_args`; any key also present in ``kwargs``
    takes the value from ``kwargs``. A new schema object is constructed from the merged settings.

    Args:
        schema (pa.DataFrameSchema): `pandera` dataframe schema.
        **kwargs (Any): keyword arguments to update ``schema`` with.

    Returns:
        pa.DataFrameSchema: `pandera` dataframe schema.
    """
    current_settings = _get_pa_init_args(schema, param_names=_PA_DF_SCHEMA_INIT_PARAMETERS)
    return pa.DataFrameSchema(**{**current_settings, **kwargs})
def update_index(index: pa.Index, **kwargs: Any) -> pa.Index:
    """Create a `pandera` index from ``index``, overridden by ``kwargs``.

    The current settings of ``index`` are collected with `_get_pa_init_args`; any key also present in ``kwargs``
    takes the value from ``kwargs``. A new index object is constructed from the merged settings.

    Args:
        index (pa.Index): `pandera` index.
        **kwargs (Any): keyword arguments to update ``index`` with.

    Returns:
        pa.Index: `pandera` index.
    """
    current_settings = _get_pa_init_args(index, param_names=_PA_INDEX_INIT_PARAMETERS)
    return pa.Index(**{**current_settings, **kwargs})
def update_multiindex(multi_index: pa.MultiIndex, **kwargs: Any) -> pa.MultiIndex:
    """Create a `pandera` multiindex from ``multi_index``, overridden by ``kwargs``.

    The current settings of ``multi_index`` are collected with `_get_pa_init_args`; any key also present in
    ``kwargs`` takes the value from ``kwargs``. A new multiindex object is constructed from the merged settings.

    Args:
        multi_index (pa.MultiIndex): `pandera` multiindex.
        **kwargs (Any): keyword arguments to update ``multi_index`` with.

    Returns:
        pa.MultiIndex: `pandera` multiindex.
    """
    current_settings = _get_pa_init_args(multi_index, param_names=_PA_MULTI_INDEX_INIT_PARAMETERS)
    return pa.MultiIndex(**{**current_settings, **kwargs})
# Lookup table used by `get_pa_dtypes` as the fallback for items that are neither a `pandera.DataType` instance
# nor a `pandera.DataType` class: Python builtin types and the strings "category" / "datetime" are the keys.
PA_DTYPE_MAP: Dict[data_typing.Dtype, pa.DataType] = {
    bool: pa.Bool(),  # type: ignore [no-untyped-call]
    int: pa.Int(),  # type: ignore [no-untyped-call]
    float: pa.Float(),  # type: ignore [no-untyped-call]
    str: pa.String(),  # type: ignore [no-untyped-call]
    "category": pa.Category(),
    "datetime": pa.DateTime(),  # type: ignore [no-untyped-call]
}
"""A mapping from dtype specified as :obj:`~tempor.data.data_typing.Dtype` to a `pandera.DataType`.
"""
def get_pa_dtypes(dtypes: Iterable[Union[data_typing.Dtype, pa.DataType]]) -> List[pa.DataType]:
    """Translate ``dtypes`` into a de-duplicated list of `pandera.DataType` instances.

    Each item may be a `pandera.DataType` instance, a `pandera.DataType` class, or a key of `PA_DTYPE_MAP`.
    Whenever an item resolves to the default `pandera` ``Int`` (``Float``), all the explicit-width integer
    (float) dtypes are added as well. The order of the returned list is not defined (it goes through a `set`).

    Args:
        dtypes (Iterable[Union[data_typing.Dtype, pa.DataType]]): The dtype specifications to translate.

    Returns:
        List[pa.DataType]: The unique `pandera.DataType` instances.

    Raises:
        KeyError: If an item is not a `pandera.DataType` and is not found in `PA_DTYPE_MAP`.
    """

    def _resolve(spec: Any) -> pa.DataType:
        # Already a `pandera.DataType` instance: use as is.
        if isinstance(spec, pa.DataType):
            return spec
        # A `pandera.DataType` class: instantiate it.
        if hasattr(spec, "__mro__") and issubclass(spec, pa.DataType):  # type: ignore
            return spec()  # type: ignore
        # Anything else must be a key of the mapping.
        try:
            return PA_DTYPE_MAP[spec]
        except KeyError as ex:
            raise KeyError(f"Mapping from `{spec}` to a pandera DataType not found") from ex

    collected = []
    for spec in dtypes:
        resolved = _resolve(spec)
        collected.append(resolved)
        # Allow for different bits of Int and Float, rather than just what pandera defaults to.
        # Note: especially important on Windows, as pandera ends up with a different bit datatype on Windows than
        # on other systems. See: https://github.com/unionai-oss/pandera/issues/726.
        if resolved == pa.Int():  # type: ignore [no-untyped-call]
            collected.extend([pa.Int8(), pa.Int16(), pa.Int32(), pa.Int64()])  # type: ignore [no-untyped-call]
        if resolved == pa.Float():  # type: ignore [no-untyped-call]
            collected.extend([pa.Float16(), pa.Float32(), pa.Float64()])  # type: ignore [no-untyped-call]
    return list(set(collected))
class UnionDtype(pd_engine.DataType):
    """A custom `pandera` ``DataType`` that behaves similarly to ``Union``: it is satisfied by any of its members.

    See the `pandera` ``DataType`` guide (https://pandera.readthedocs.io/en/stable/dtypes.html) for details.

    Rather than decorating this class with ``register_dtype`` and ``immutable``, those are applied to the class
    that ``__class_getitem__`` creates dynamically for each requested union. In this way, `pandera`'s ``pandas``
    engine registers each new kind of union as a different dtype.
    """

    union_dtypes: List
    """The list of types in the union."""
    type: Any
    """The string representation of the data type, which will be, e.g., shown in exceptions."""
    name: str
    """The string representation of the data type used for `repr`."""

    @classmethod
    def __class_getitem__(cls: Type, item: Iterable[Union[data_typing.Dtype, pa.DataType]]) -> Callable:
        """Create the union dtype for ``UnionDtype[dtype, ...]``.

        Acceptable ``dtype`` s are: `pandera.DataType` (as a class or instance) or the keys of
        `~tempor.data.pandera_utils.PA_DTYPE_MAP`. A single non-iterable ``item`` is treated as a one-item union.
        """
        if not tempor.core.utils.is_iterable(item):
            item = [item]  # type: ignore [list-item]
        # Sort by string form for consistency: `item` can get captured in random order.
        members = sorted(get_pa_dtypes(item), key=str)
        members_label = str([str(member) for member in members]).replace("'", "")
        union_name = f"{cls.__name__}{members_label}"

        # Create a fresh subclass for this particular union and attach its metadata *before* registering it.
        union_cls = type(union_name, (UnionDtype,), dict())
        union_cls.union_dtypes = members  # type: ignore
        union_cls.type = union_name  # type: ignore
        union_cls.name = union_name  # type: ignore
        return pd_engine.Engine.register_dtype(pa_dtypes.immutable(union_cls))  # type: ignore

    def __repr__(self) -> str:
        """The `repr()` representation of the class.

        Returns:
            str: The representation (the ``name`` attribute).
        """
        return self.name

    def check(
        self,
        pandera_dtype: pa_dtypes.DataType,
        data_container: Any = None,
    ) -> Union[bool, Iterable[bool]]:
        """Check whether ``pandera_dtype`` (and optionally ``data_container``) satisfies at least one of the
        union's ``union_dtypes``.

        Args:
            pandera_dtype (pa_dtypes.DataType):
                The data type received as part of the check/validation.
            data_container (Any):
                The data container received as part of the check/validation. Defaults to `None`.

        Returns:
            Union[bool, Iterable[bool]]:
                A single `bool` when ``data_container`` is `None`; otherwise a boolean array shaped like
                ``data_container``, filled entirely with the outcome.
        """

        def _outcome(satisfied: bool) -> Union[bool, Iterable[bool]]:
            # A bare bool for a pure dtype comparison, otherwise one flag per item of the container.
            if data_container is None:
                return satisfied
            return np.full_like(data_container, satisfied, dtype=bool)

        for member in self.union_dtypes:
            satisfied = pd_engine.Engine.dtype(member).check(pandera_dtype, data_container)
            if data_container is None:
                # Only in case of direct type comparison, we also need to check via the member.check(pandera_dtype)
                # method, making sure that pandera_dtype is a DataType instance, not a class.
                if hasattr(pandera_dtype, "__mro__"):  # pragma: no cover
                    pandera_dtype = pandera_dtype()  # type: ignore
                satisfied = satisfied or member.check(pandera_dtype)
            if tempor.core.utils.is_iterable(satisfied):
                # A per-item result only counts if every item passes.
                satisfied = all(satisfied)  # type: ignore
            if satisfied:
                return _outcome(True)
        return _outcome(False)

    def coerce(self, data_container: Any) -> NoReturn:
        """The ``coerce`` method is not supported and will throw a `NotImplementedError`."""
        raise NotImplementedError(f"`coerce` not supported by {self.__class__.__name__}")
def init_schema(data: pd.DataFrame, **kwargs: Any) -> pa.DataFrameSchema:
    """Infer a `pandera.DataFrameSchema` from ``data`` with `pandera.infer_schema`, then apply ``kwargs`` to it.

    Args:
        data (pd.DataFrame): Input dataframe.
        **kwargs (Any): Keyword arguments to update the schema with after initialization.

    Returns:
        pa.DataFrameSchema: `pandera.DataFrameSchema` initialized from ``data``.
    """
    inferred = cast(pa.DataFrameSchema, pa.infer_schema(data))
    return update_schema(inferred, **kwargs)
def add_df_checks(schema: pa.DataFrameSchema, *, checks_list: List[pa.Check]) -> pa.DataFrameSchema:
    """Return a schema based on ``schema`` whose dataframe-level ``checks`` are set to ``checks_list``.

    Note that the ``checks`` setting is passed to `update_schema` as an override, so ``checks_list`` takes the
    place of whatever ``checks`` value ``schema`` carried before.

    Args:
        schema (pa.DataFrameSchema): DataFrameSchema to add checks to.
        checks_list (List[pa.Check]): The list of checks.

    Returns:
        pa.DataFrameSchema: DataFrameSchema with checks added.
    """
    return update_schema(schema, checks=checks_list)
def add_regex_column_checks(
    schema: pa.DataFrameSchema,
    *,
    regex: str = ".*",
    dtype: Any,
    nullable: bool,
    checks_list: Optional[List[pa.Check]] = None,
) -> pa.DataFrameSchema:
    """Add a regex-matched `pandera.Column` to ``schema``.

    The column is keyed by ``regex`` (with ``regex=True``) and carries ``dtype``, ``nullable`` and ``checks_list``,
    so these settings apply to every dataframe column whose name matches the pattern.

    Args:
        schema (pa.DataFrameSchema): The schema to extend.
        regex (str): The column name pattern. Defaults to ``".*"``.
        dtype (Any): The dtype required of the matched columns.
        nullable (bool): Whether the matched columns may contain null values.
        checks_list (Optional[List[pa.Check]]): Checks to apply to the matched columns. Defaults to `None`.

    Returns:
        pa.DataFrameSchema: The schema returned by ``schema.add_columns``.
    """
    regex_column = pa.Column(
        dtype=dtype,
        nullable=nullable,
        regex=True,
        checks=checks_list,
    )
    extended = schema.add_columns({regex: regex_column})
    if TYPE_CHECKING:  # pragma: no cover
        assert isinstance(extended, pa.DataFrameSchema)  # nosec B101
    return extended
def set_up_index(
    schema: pa.DataFrameSchema,
    data: pd.DataFrame,
    *,
    dtype: Any,
    name: str,
    nullable: bool,
    unique: bool,
    coerce: bool,
    checks_list: Optional[List[pa.Check]] = None,
) -> Tuple[pa.DataFrameSchema, pd.DataFrame]:
    """Configure ``schema.index`` (`pandera.Index`) and name the index of ``data`` accordingly.

    The schema index is rebuilt with ``dtype``, ``name``, ``nullable``, ``unique``, ``coerce`` and ``checks_list``.
    The index of ``data`` (`pandas.DataFrame`) is renamed to ``name`` **in place**.

    Args:
        schema (pa.DataFrameSchema): The schema whose index to configure.
        data (pd.DataFrame): The dataframe whose index to name.
        dtype (Any): The index dtype.
        name (str): The index name.
        nullable (bool): The ``nullable`` setting of the index.
        unique (bool): The ``unique`` setting of the index.
        coerce (bool): The ``coerce`` setting of the index.
        checks_list (Optional[List[pa.Check]]): Checks for the index. Defaults to `None`.

    Returns:
        Tuple[pa.DataFrameSchema, pd.DataFrame]: The updated schema and the (same) dataframe.

    Raises:
        ValueError: If ``schema.index`` is `None`.
    """
    if schema.index is None:
        raise ValueError("Expected DataFrameSchema Index to not be None")

    index_overrides = dict(
        dtype=dtype,
        nullable=nullable,
        unique=unique,
        name=name,
        checks=checks_list,
        coerce=coerce,
    )
    configured_index = update_index(schema.index, **index_overrides)
    schema = update_schema(schema, index=configured_index)

    # Name the index.
    data.index.set_names(name, inplace=True)
    return schema, data
def set_up_2level_multiindex(
    schema: pa.DataFrameSchema,
    data: pd.DataFrame,
    *,
    dtypes: Tuple[Any, Any],
    names: Tuple[str, str],
    nullable: Tuple[bool, bool],
    coerce: bool,
    unique: Tuple[str, ...],
    checks_list: Optional[Tuple[List[pa.Check], List[pa.Check]]] = None,
) -> Tuple[pa.DataFrameSchema, pd.DataFrame]:
    """Update ``schema.index`` (`pandera.MultiIndex`), which is expected to have 2 levels, with ``dtypes``,
    ``names``, ``nullable``, ... schema settings.

    In addition, set the index level names of ``data`` (`pandas.DataFrame`) to ``names`` (done **in place**).

    Args:
        schema (pa.DataFrameSchema): The schema whose 2-level multiindex to configure.
        data (pd.DataFrame): The dataframe whose index levels to name.
        dtypes (Tuple[Any, Any]): The dtype of each index level.
        names (Tuple[str, str]): The name of each index level.
        nullable (Tuple[bool, bool]): The ``nullable`` setting of each index level.
        coerce (bool): The ``coerce`` setting, applied to both index levels.
        unique (Tuple[str, ...]): The ``unique`` setting of the multiindex.
        checks_list (Optional[Tuple[List[pa.Check], List[pa.Check]]]):
            The checks of each index level. Defaults to `None` (no checks on either level).

    Returns:
        Tuple[pa.DataFrameSchema, pd.DataFrame]: The updated schema and the (same) dataframe.

    Raises:
        ValueError: If ``schema.index`` is `None`, is not a `pandera.MultiIndex`, or does not have 2 levels.
    """
    if schema.index is None:
        raise ValueError("Expected DataFrameSchema Index to not be None")
    if not isinstance(schema.index, pa.MultiIndex):
        raise ValueError("Expected DataFrameSchema Index to be MultiIndex")
    if len(schema.index.indexes) != 2:
        raise ValueError("Expected DataFrameSchema Index to have 2 levels")

    level_checks = checks_list if checks_list is not None else (None, None)
    # Configure both levels through the same code path so that every per-level setting (notably ``nullable``)
    # is applied to each level.
    levels = [
        update_index(
            level,
            dtype=dtypes[pos],
            name=names[pos],
            nullable=nullable[pos],
            coerce=coerce,
            checks=level_checks[pos],
        )
        for pos, level in enumerate(schema.index.indexes)
    ]
    index = update_multiindex(schema.index, indexes=levels, unique=unique)
    schema = update_schema(schema, index=index)
    data.index.set_names(names, inplace=True)  # Name the index.
    return schema, data
class checks:
    """Namespace of ready-made `pandera.Check` objects for reuse across schemas."""

    # --- Dataframe-level checks on the shape of the index / columns. ---
    forbid_multiindex_index = pa.Check(
        lambda df: df.index.nlevels == 1,
        error="MultiIndex Index not allowed",
    )
    forbid_multiindex_columns = pa.Check(
        lambda df: df.columns.nlevels == 1,
        error="MultiIndex Columns not allowed",
    )
    require_2level_multiindex_index = pa.Check(
        lambda df: df.index.nlevels == 2,
        error="Index must be a MultiIndex with 2 levels",
    )

    # --- Element-wise check. ---
    require_element_len_2 = pa.Check(
        lambda x: len(x) == 2,
        element_wise=True,
        error="Each item must contain a sequence of length 2",
    )

    # NOTE: Currently disabled check, kept for reference:
    # require_2level_multiindex_one_to_one = pa.Check(
    #     lambda df: (df.groupby(level=0).size() == 1).all(),
    #     error="MultiIndex Index must one-to-one correspondence for between the two levels",
    # )

    class configurable:
        """Namespace of factory functions that build parametrized `pandera.Check` objects."""

        @staticmethod
        def column_index_satisfies_dtype(dtype: Any, *, nullable: bool) -> pa.Check:
            """Build a `pandera.Check` verifying that a dataframe's column index satisfies ``dtype``.

            The check wraps the dataframe's columns in a `pandas.Series` and validates it against a
            `pandera.SeriesSchema` (with ``coerce=False``) built from ``dtype`` and ``nullable``. The check
            function itself returns `True` whenever that validation call returns.

            Args:
                dtype (Any): The dtype to check against.
                nullable (bool): The nullable attribute to set.

            Returns:
                pa.Check: The `pandera.Check` defined.
            """
            series_name = "Column Index"

            def _check(df: pd.DataFrame) -> bool:
                column_index_schema = pa.SeriesSchema(
                    dtype,
                    name=series_name,
                    nullable=nullable,
                    coerce=False,
                )
                column_index = pd.Series(df.columns, name=series_name)
                column_index_schema.validate(column_index)
                return True

            return pa.Check(
                _check,
                error=f"DataFrame {series_name} dtype validation failed, must be of type: {dtype}",
            )