# Source code for tempor.datasources.time_to_event.plugin_pbc

"""Module containing the PBC data source."""

import io
import os
from typing import Any

import numpy as np
import pandas as pd
import requests
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder, StandardScaler

from tempor.core import plugins
from tempor.data import dataset, utils
from tempor.datasources import datasource


@plugins.register_plugin(name="pbc", category="time_to_event", plugin_type="datasource")
class PBCDataSource(datasource.TimeToEventAnalysisDataSource):
    def __init__(self, **kwargs: Any) -> None:
        """PBC data source.

        The data is from the PBC2 dataset, as found here:
        https://search.r-project.org/CRAN/refmans/DynForest/html/pbc2.html

        Actual data obtained from:
        https://raw.githubusercontent.com/autonlab/auton-survival/cf583e598ec9ab92fa5d510a0ca72d46dfe0706f/dsm/datasets/pbc2.csv

        Temporal features are:
            - ``"drug"``
            - ``"ascites"``
            - ``"hepatomegaly"``
            - ``"spiders"``
            - ``"edema"``
            - ``"histologic"``
            - ``"serBilir"``
            - ``"serChol"``
            - ``"albumin"``
            - ``"alkaline"``
            - ``"SGOT"``
            - ``"platelets"``
            - ``"prothrombin"``
            - ``"age"``

        Static features are:
            - ``"sex"``

        The target is:
            - ``"status2"`` as found in the ``autonlab/auton-survival`` version.

        Args:
            **kwargs (Any):
                Keyword arguments to be passed to the parent class.
        """
        # The cache path is set before the parent initializer runs.
        self.datafile_path = os.path.join(self.dataset_dir(), "pbc2.csv")
        super().__init__(**kwargs)

    @staticmethod
    def dataset_dir() -> str:
        """Return the directory used for the PBC data: ``<data_root_dir>/pbc/``."""
        return os.path.join(PBCDataSource.data_root_dir, "pbc/")

    @staticmethod
    def url() -> str:
        """Return the URL (pinned to a specific commit) of the ``pbc2.csv`` file."""
        return (
            "https://raw.githubusercontent.com/autonlab/auton-survival/"
            "cf583e598ec9ab92fa5d510a0ca72d46dfe0706f/dsm/datasets/pbc2.csv"
        )

    def load(self, **kwargs: Any) -> dataset.TimeToEventAnalysisDataset:
        """Load the PBC data and assemble it into a time-to-event analysis dataset.

        The CSV is read from ``self.datafile_path`` when that file exists; otherwise it is
        downloaded from :meth:`url` and written to ``self.datafile_path`` for later calls.

        Preprocessing steps:
            1. ``time`` is computed as ``years - year`` and rows are sorted by ``(id, time)``.
            2. Categorical columns are label-encoded (one ``LabelEncoder`` per column).
            3. Missing values are replaced by the column mean.
            4. Numeric columns and ``age`` are standardized with ``StandardScaler``.
            5. Rows are grouped per ``id`` into static features, a temporal data frame indexed
               by ``time``, and one ``(time, event)`` target pair.

        Args:
            **kwargs (Any):
                Not used by this implementation.

        Returns:
            dataset.TimeToEventAnalysisDataset:
                Dataset with the temporal features, the static feature ``"sex"`` and the event
                feature ``"status"``.

        Raises:
            requests.HTTPError:
                If the data has to be downloaded and the server answers with an error status.
            RuntimeError:
                If a static feature is found to vary within one patient.
        """
        if os.path.exists(self.datafile_path):
            data = pd.read_csv(self.datafile_path)
        else:
            response = requests.get(self.url(), timeout=5)
            # Fail loudly on 4xx/5xx responses: otherwise the body of an error page would be
            # parsed as CSV and could be cached on disk, poisoning all subsequent loads.
            response.raise_for_status()
            data = pd.read_csv(io.StringIO(response.content.decode("utf-8")))
            data.to_csv(self.datafile_path, index=False)

        data["time"] = data["years"] - data["year"]
        # `ignore_index=True` resets the index to 0..n-1, which keeps `data` positionally
        # aligned with the freshly constructed data frames below.
        data = data.sort_values(by=["id", "time"], ignore_index=True)
        data["histologic"] = data["histologic"].astype(str)

        dat_cat = data[["drug", "sex", "ascites", "hepatomegaly", "spiders", "edema", "histologic"]].copy()
        dat_num = data[["serBilir", "serChol", "albumin", "alkaline", "SGOT", "platelets", "prothrombin"]].copy()
        # NOTE(review): this adds `years` (the minuend of `time` above) rather than the per-row
        # `year`; confirm that this is the intended definition of "age".
        age = data["age"] + data["years"]

        for col in dat_cat.columns:
            dat_cat[col] = LabelEncoder().fit_transform(dat_cat[col])

        x = pd.concat([dat_cat, dat_num, pd.Series(age, name="age")], axis=1)

        time = data["time"]
        event = data["status2"]

        x = pd.DataFrame(
            SimpleImputer(missing_values=np.nan, strategy="mean").fit_transform(x),
            columns=x.columns,
        )

        # Only the continuous columns are standardized; the label-encoded ones are left as is.
        scaled_cols = list(dat_num.columns) + ["age"]
        x_scaled = x.copy()
        x_scaled[scaled_cols] = pd.DataFrame(
            StandardScaler().fit_transform(x[scaled_cols]),
            columns=scaled_cols,
            index=data.index,
        )

        x_static, x_temporal, t, e = [], [], [], []
        time_indexes, event_feat_full = [], []
        temporal_cols = [
            "drug",
            "ascites",
            "hepatomegaly",
            "spiders",
            "edema",
            "histologic",
            "serBilir",
            "serChol",
            "albumin",
            "alkaline",
            "SGOT",
            "platelets",
            "prothrombin",
            "age",
        ]
        static_cols = ["sex"]

        sample_index = []
        for id_ in sorted(set(data["id"])):
            sample_index.append(id_)

            # Row mask of the current patient, computed once and reused for every selection.
            patient_mask = data["id"] == id_
            patient = x_scaled[patient_mask]
            events = event[patient_mask].values
            times = time[patient_mask].values

            patient_static = patient[static_cols]
            if not (patient_static.iloc[0] == patient_static).all().all():  # pragma: no cover
                # This is a sanity check.
                raise RuntimeError(
                    f"Found patient with static data that are not actually fixed:\nid: {id_}\n{patient_static}"
                )
            x_static.append(patient_static.values[0].tolist())

            patient_temporal = patient[temporal_cols]
            patient_temporal.index = times
            x_temporal.append(patient_temporal)

            evt = np.amax(events)  # pyright: ignore
            if evt == 0:
                pos = np.max(np.where(events == evt))  # Last censored
            else:
                pos = np.min(np.where(events == evt))  # First event

            t.append(times[pos])
            # The lookup doubles as validation: any status other than 0/1 raises ``KeyError``.
            e.append({0: False, 1: True}[evt])

            time_indexes.append(list(times))
            event_feat_full.append(events)

        df_static = pd.DataFrame(x_static, columns=static_cols, index=sample_index)
        df_time_series = utils.list_of_dataframes_to_multiindex_timeseries_dataframe(
            x_temporal,
            sample_index=sample_index,
            time_indexes=time_indexes,
            feature_index=temporal_cols,
        )
        df_event = utils.event_time_value_pairs_to_event_dataframe(
            [(t, e)], sample_index=sample_index, feature_index=["status"]
        )

        return dataset.TimeToEventAnalysisDataset(
            time_series=df_time_series,
            static=df_static,
            targets=df_event,
        )