# Source code for tempor.metrics.metric

"""Module containing the base class for metrics."""

# pylint: disable=unnecessary-ellipsis

import abc
from typing import Any, Generator, List

import numpy as np
import rich.pretty

import tempor.core.utils
from tempor.core import plugins
from tempor.data import data_typing

from . import metric_typing


class Metric(plugins.Plugin, abc.ABC):
    """Metric abstract base class, defines the required methods."""

    @property
    @abc.abstractmethod
    def direction(self) -> metric_typing.MetricDirection:  # pragma: no cover
        """The metric's direction, expressed as a `metric_typing.MetricDirection`."""
        ...

    def evaluate(self, actual: Any, predicted: Any, *args: Any, **kwargs: Any) -> Any:  # pragma: no cover
        """Validate the inputs, then compute the metric.

        This is the public entry point: it runs `_validate` on ``actual`` and ``predicted`` first and
        afterwards delegates the actual computation to `_evaluate`.

        Args:
            actual (Any): Actual value(s).
            predicted (Any): Predicted value(s).
            *args (Any): Extra positional arguments, forwarded to `_evaluate`.
            **kwargs (Any): Extra keyword arguments, forwarded to `_evaluate`.

        Returns:
            Any: The metric value(s) produced by `_evaluate`.
        """
        self._validate(actual, predicted)
        metric_value = self._evaluate(actual, predicted, *args, **kwargs)
        return metric_value

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Shorthand for `evaluate`: all arguments are passed through unchanged."""
        return self.evaluate(*args, **kwargs)

    @abc.abstractmethod
    def _evaluate(self, actual: Any, predicted: Any, *args: Any, **kwargs: Any) -> Any:  # pragma: no cover
        """The metric computation itself; derived classes must override this.

        Called by `evaluate` after the inputs have passed `_validate`.

        Args:
            actual (Any): Actual value(s).
            predicted (Any): Predicted value(s).
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            Any: Evaluated metric value(s).
        """
        ...

    def _validate(self, actual: Any, predicted: Any) -> None:
        """Check the inputs before the metric is computed.

        Derived classes may extend this check, but should keep a call to `super()._validate(...)`.

        Args:
            actual (Any): Actual value(s).
            predicted (Any): Predicted value(s).

        Raises:
            ValueError: If ``actual`` or ``predicted`` is `None`.
        """
        both_provided = actual is not None and predicted is not None
        if not both_provided:
            raise ValueError("The actual values and the predicted values must not be `None`.")

    def __rich_repr__(self) -> Generator:
        """Describe the instance for `rich`.

        Two fields are emitted, in this order: ``name`` and ``description``. The ``description`` field is
        auto-generated from the class and init docstrings.

        Yields:
            Generator: ``(field, value)`` pairs fed to `rich`.
        """
        yield "name", self.name
        # The description is only built once the consumer asks for the second field.
        description = tempor.core.utils.make_description_from_doc(self)
        yield "description", description

    def __repr__(self) -> str:
        """Return the `repr()` string, rendered by `rich.pretty.pretty_repr`.

        Returns:
            str: The representation.
        """
        return rich.pretty.pretty_repr(self)
# TODO: Multi-feature cases.
# TODO: Typing of arguments may change.
# TODO: Update the abstract methods for each case properly.
class OneOffPredictionMetric(Metric):
    """Metric abstract base class for the one-off prediction task."""

    # Re-declared only to narrow the type hints: numpy arrays in, a single float out.
    @abc.abstractmethod
    def _evaluate(
        self, actual: np.ndarray, predicted: np.ndarray, *args: Any, **kwargs: Any
    ) -> float:  # pragma: no cover
        """The metric computation, to be implemented by derived classes.

        Args:
            actual (np.ndarray): Actual value(s).
            predicted (np.ndarray): Predicted value(s).
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            float: Evaluated metric value.
        """
        ...
# Overrides for type hinting and docstrings.
class OneOffClassificationMetric(OneOffPredictionMetric):
    """Metric abstract base class for the one-off prediction task, classification case."""

    def evaluate(self, actual: np.ndarray, predicted: np.ndarray, *args: Any, **kwargs: Any) -> float:
        """Evaluate the classification metric.

        Both ``actual`` and ``predicted`` are expected to be numpy arrays (sample in the 0th dimension).
        ``predicted`` must be the probabilities in this case.

        The call is forwarded unchanged to the parent class' ``evaluate``.

        Args:
            actual (np.ndarray): Actual value(s).
            predicted (np.ndarray): Predicted value(s), given as probabilities.
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            float: Evaluated metric value.
        """
        score = super().evaluate(actual, predicted, *args, **kwargs)
        return score

    # Re-declared so that derived classification metrics must supply their own implementation.
    @abc.abstractmethod
    def _evaluate(
        self, actual: np.ndarray, predicted: np.ndarray, *args: Any, **kwargs: Any
    ) -> float:  # pragma: no cover
        """The metric computation, to be implemented by derived classes.

        Args:
            actual (np.ndarray): Actual value(s).
            predicted (np.ndarray): Predicted value(s).
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            float: Evaluated metric value.
        """
        ...
# Register the classification base class as the "prediction.one_off.classification" category of "metric" plugins.
# NOTE(review): presumably this is what lets concrete metrics be looked up by category - confirm in
# `tempor.core.plugins`.
plugins.register_plugin_category("prediction.one_off.classification", OneOffClassificationMetric, plugin_type="metric")


# Overrides for type hinting and docstrings.
class OneOffRegressionMetric(OneOffPredictionMetric):
    """Metric abstract base class for the one-off prediction task, regression case."""

    def evaluate(self, actual: np.ndarray, predicted: np.ndarray, *args: Any, **kwargs: Any) -> float:
        """Evaluate the regression metric.

        Both ``actual`` and ``predicted`` are expected to be numpy arrays (sample in the 0th dimension).

        The call is forwarded unchanged to the parent class' ``evaluate``.

        Args:
            actual (np.ndarray): Actual value(s).
            predicted (np.ndarray): Predicted value(s).
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            float: Evaluated metric value.
        """
        score = super().evaluate(actual, predicted, *args, **kwargs)
        return score

    # Re-declared so that derived regression metrics must supply their own implementation.
    @abc.abstractmethod
    def _evaluate(
        self, actual: np.ndarray, predicted: np.ndarray, *args: Any, **kwargs: Any
    ) -> float:  # pragma: no cover
        """The metric computation, to be implemented by derived classes.

        Args:
            actual (np.ndarray): Actual value(s).
            predicted (np.ndarray): Predicted value(s).
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            float: Evaluated metric value.
        """
        ...
# Register the regression base class as the "prediction.one_off.regression" category of "metric" plugins.
# NOTE(review): presumably this is what lets concrete metrics be looked up by category - confirm in
# `tempor.core.plugins`.
plugins.register_plugin_category("prediction.one_off.regression", OneOffRegressionMetric, plugin_type="metric")


# Overrides for type hinting and docstrings.
class TimeToEventMetric(Metric):
    """Metric abstract base class for the time-to-event (survival) task."""

    def evaluate(  # pylint: disable=arguments-differ
        self,
        actual: metric_typing.EventArrayTimeArray,
        predicted: np.ndarray,
        horizons: data_typing.TimeIndex,
        *args: Any,
        **kwargs: Any,
    ) -> List[float]:
        """Evaluate the time-to-event metric at the given horizons.

        The call is forwarded to the parent class' ``evaluate``, with ``horizons`` passed on as the first
        additional positional argument.

        Args:
            actual (metric_typing.EventArrayTimeArray):
                A tuple of two numpy arrays: the event values array and the event times array, for the
                actual event values.
            predicted (np.ndarray):
                A numpy array of shape ``(n_samples, n_horizons_timesteps, n_features)`` with the predicted
                risk estimates.
            horizons (data_typing.TimeIndex): List of horizons time points.
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            List[float]: The metric values for each horizon time point.
        """
        per_horizon_values = super().evaluate(actual, predicted, horizons, *args, **kwargs)
        return per_horizon_values

    # Re-declared with the extra ``horizons`` parameter; derived classes must implement it.
    @abc.abstractmethod
    def _evaluate(  # pylint: disable=arguments-differ
        self,
        actual: metric_typing.EventArrayTimeArray,
        predicted: np.ndarray,
        horizons: data_typing.TimeIndex,
        *args: Any,
        **kwargs: Any,
    ) -> List[float]:  # pragma: no cover
        """The metric computation, to be implemented by derived classes.

        Args:
            actual (metric_typing.EventArrayTimeArray): The actual event values and event times.
            predicted (np.ndarray): The predicted risk estimates.
            horizons (data_typing.TimeIndex): List of horizons time points.
            *args (Any): Additional positional arguments.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            List[float]: The metric values for each horizon time point.
        """
        ...
# Register the time-to-event base class as the "time_to_event" category of "metric" plugins.
# NOTE(review): presumably this is what lets concrete metrics be looked up by category - confirm in
# `tempor.core.plugins`.
plugins.register_plugin_category("time_to_event", TimeToEventMetric, plugin_type="metric")