# Source code for tempor.models.clairvoyance2.datasets.dataset_retriever
# mypy: ignore-errors
import os
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, TypeVar
from ..data import Dataset
from .download import download_file
# Type variable for a download URL; restricted to `str` (or a subclass of it) via `bound`.
TUrl = TypeVar("TUrl", bound=str)
# Definition of one downloadable dataset file: a pair of the source URL and the
# file name it is stored under locally.
TDatasetFileDef = Tuple[TUrl, str] # ("URL", "local_file_name")
# Root directory for datasets used when no explicit data home is given:
# `.clairvoyance/datasets/` under the current user's home directory.
DATASET_ROOT_DIR = os.path.join(os.path.expanduser("~"), ".clairvoyance/datasets/")
# TODO: Unit test.
class DatasetRetriever(ABC):
    """Abstract base class for downloading, preparing and caching a dataset.

    Concrete subclasses provide the class attributes ``dataset_subdir`` and
    ``dataset_files`` and implement ``is_cached``, ``get_cache``, ``cache`` and
    ``prepare``. Callers normally only use ``retrieve``.

    Directory layout (all paths derived with ``os.path.join``)::

        <dataset_root_dir>/<dataset_subdir>/                  -> ``dataset_dir``
        <dataset_root_dir>/<dataset_subdir>/<cache_subdir>/   -> ``dataset_cache_dir``
    """

    # Name of this dataset's directory under the root directory. No value is
    # assigned here, so concrete subclasses must define it.
    dataset_subdir: str
    # ("URL", "local_file_name") pairs to download into ``dataset_dir``, or
    # ``None`` when there is nothing to download. No value is assigned here, so
    # concrete subclasses must define it.
    dataset_files: Optional[Sequence[TDatasetFileDef]]
    # Name of the cache directory inside ``dataset_dir``.
    cache_subdir: str = "cache"

    def __init__(self, data_home: Optional[str] = None) -> None:
        """Set the dataset root directory and create the dataset directories.

        Args:
            data_home: Root directory under which the dataset directory is
                placed. When ``None``, ``DATASET_ROOT_DIR`` is used.
        """
        self.dataset_root_dir = DATASET_ROOT_DIR if data_home is None else data_home
        # Create both directories up front; `exist_ok=True` makes this a no-op
        # when they are already present.
        os.makedirs(self.dataset_dir, exist_ok=True)
        os.makedirs(self.dataset_cache_dir, exist_ok=True)

    @property
    def dataset_dir(self) -> str:
        """Directory of this dataset: ``dataset_root_dir`` joined with ``dataset_subdir``."""
        return os.path.join(self.dataset_root_dir, self.dataset_subdir)

    @property
    def dataset_cache_dir(self) -> str:
        """Cache directory of this dataset: ``dataset_dir`` joined with ``cache_subdir``."""
        return os.path.join(self.dataset_dir, self.cache_subdir)

    def download_dataset(self) -> None:
        """Call ``download_file`` for every entry of ``dataset_files``.

        Each file is requested from its URL with ``dataset_dir/<local_file_name>``
        as the target path. Does nothing when ``dataset_files`` is ``None``.
        """
        if self.dataset_files is None:
            return
        for url, file_name in self.dataset_files:
            download_file(url, os.path.join(self.dataset_dir, file_name))

    @abstractmethod
    def is_cached(self) -> bool:
        """Check if the dataset has been cached."""
        ...

    @abstractmethod
    def get_cache(self) -> Dataset:
        """Retrieve dataset from cache."""
        ...

    @abstractmethod
    def cache(self, data: Dataset) -> None:
        """Cache the dataset for faster opening."""
        ...

    @abstractmethod
    def prepare(self) -> Dataset:
        """Prepare the dataset and return it."""
        ...

    def retrieve(self, refresh_cache: bool = False, redownload: bool = False) -> Dataset:
        """Download the dataset files if required, then return the dataset.

        Args:
            refresh_cache: When ``True``, call ``prepare`` and ``cache`` again
                even if ``is_cached`` reports an existing cache.
            redownload: When ``True``, call ``download_dataset`` even if every
                file listed in ``dataset_files`` already exists locally.

        Returns:
            The dataset as returned by ``get_cache``.
        """
        # Download dataset files (if required).
        if self.dataset_files is not None:
            # Check the flag first and feed `any` a generator, so the file
            # system is only probed when needed and probing stops at the first
            # missing file.
            needs_download = redownload or any(
                not os.path.exists(os.path.join(self.dataset_dir, file_name))
                for _, file_name in self.dataset_files
            )
            if needs_download:
                self.download_dataset()

        # Prepare and retrieve dataset.
        if self.is_cached() and not refresh_cache:
            return self.get_cache()

        data = self.prepare()
        self.cache(data)
        # The freshly cached dataset is read back through `get_cache` rather
        # than returning `data` directly, so both paths return the same thing.
        return self.get_cache()