# Source code for cellmil.utils.train.dataset

import pandas as pd
from typing import cast
from sklearn.model_selection import train_test_split  # type: ignore
from cellmil.utils import logger


def split_dataset(
    df: pd.DataFrame,
    train_size: float = 0.8,
    test: bool = False,
    stratify: bool = True,
    label: str = "label",
    random_state: int = 42,
) -> pd.DataFrame:
    """Annotate *df* in place with a ``"SPLIT"`` column (``train``/``val``/``test``).

    With ``test=False`` the data is split ``train_size`` / ``1 - train_size``
    into train/val. With ``test=True`` the remaining ``1 - train_size`` is
    shared equally between val and test.

    Args:
        df: Frame to split; mutated in place (a ``SPLIT`` column is added)
            and also returned for convenience.
        train_size: Fraction of rows assigned to the training split;
            must satisfy ``0 < train_size <= 1``.
        test: If True, carve out a held-out test split as well.
        stratify: If True and ``label`` is a column of *df*, stratify each
            split on that column.
        label: Name of the column used for stratification.
        random_state: Seed forwarded to ``train_test_split`` for
            reproducibility.

    Returns:
        The same *df*, with its ``SPLIT`` column populated.

    Raises:
        ValueError: If ``train_size`` is outside ``(0, 1]``.
    """
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be between 0 and 1")

    # Target fractions of the FULL dataset.
    val_size = (1 - train_size) / 2 if test else 1 - train_size
    test_size = 1 - train_size - val_size

    df["SPLIT"] = "train"
    if test:
        remaining, test_df = cast(
            tuple[pd.DataFrame, pd.DataFrame],
            train_test_split(
                df,
                test_size=test_size,
                random_state=random_state,
                stratify=df[label] if stratify and label in df.columns else None,
            ),
        )
        df.loc[test_df.index, "SPLIT"] = "test"
        logger.info(f"Test size: {len(test_df)}")
    else:
        remaining = df

    # Rescale the validation fraction to the rows left after the test split,
    # so the val share of the FULL dataset equals val_size. (Previously
    # val_size was applied to the remainder directly, which made the
    # validation split slightly smaller than requested when test=True,
    # e.g. 81/9/10 instead of 80/10/10.)
    val_fraction = val_size / (1 - test_size) if test_size else val_size
    train_df, val_df = cast(
        tuple[pd.DataFrame, pd.DataFrame],
        train_test_split(
            remaining,
            test_size=val_fraction,
            random_state=random_state,
            stratify=remaining[label]
            if stratify and label in remaining.columns
            else None,
        ),
    )
    df.loc[val_df.index, "SPLIT"] = "val"
    logger.info(f"Validation size: {len(val_df)}")
    logger.info(f"Train size: {len(train_df)}")
    return df
def complementary_frequencies(
    df: pd.DataFrame, label: str = "label"
) -> dict[int, float]:
    """Return, for each class in the *label* column, the fraction of rows
    that do NOT belong to that class.

    Rarer classes get values closer to 1, so the result is usable directly
    as per-class weights (e.g. for a weighted loss).

    Args:
        df: Frame containing the class column.
        label: Name of the class column.

    Returns:
        Mapping from class value to ``(total - count) / total``, ordered by
        sorted class value. Empty dict for an empty frame.
    """
    label_counts = df[label].value_counts().sort_index()  # type: ignore
    total_samples = len(df)
    # Dict comprehension replaces the original manual accumulation loop.
    return {
        label_val: (total_samples - cast(int, count)) / total_samples
        for label_val, count in label_counts.items()
    }