cellmil.datamodels.datasets.celltype_dataset

CellTypeDataset: A PyTorch Dataset for multi-class cell type classification.

This dataset treats each individual cell as a sample, creating a cell-level dataset where each item is a single cell with its features and one-hot encoded cell type label.

Supports label smoothing for specific cell types to handle annotation uncertainty.

Classes

CellTypeDataset(root, folder, data, extractor)

A PyTorch Dataset for multi-class cell type classification.

class cellmil.datamodels.datasets.celltype_dataset.CellTypeDataset(root: Union[str, Path], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, cell_types_to_keep: Optional[List[str]] = None, label_smoothing: Optional[Union[float, Dict[str, float]]] = None, max_workers: int = 8, force_reload: bool = False)

Bases: Dataset[Tuple[Tensor, Tensor]]

A PyTorch Dataset for multi-class cell type classification.

This dataset treats each individual cell as a sample, creating a cell-level dataset where each item is a single cell with its features and one-hot encoded cell type label.

The labels are one-hot encoded over the 5 cell types:

  • Type 1: Neoplastic

  • Type 2: Inflammatory

  • Type 3: Connective

  • Type 4: Dead

  • Type 5: Epithelial
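The snippet below is a minimal sketch of this one-hot scheme. It uses plain Python lists instead of torch tensors. The `CELL_TYPES` mapping is an assumption modelled on the list above; the module refers to a `TYPE_NUCLEI_DICT` whose exact contents are not shown here. Type index 1 maps to position 0 of the label vector.

```python
from typing import List

# Hypothetical 1-based index -> name mapping mirroring the list above (assumption).
CELL_TYPES = {1: "Neoplastic", 2: "Inflammatory", 3: "Connective", 4: "Dead", 5: "Epithelial"}
NUM_CELL_TYPES = len(CELL_TYPES)


def one_hot(cell_type: int) -> List[float]:
    """Return a one-hot list of length NUM_CELL_TYPES for a 1-based cell type index."""
    if cell_type not in CELL_TYPES:
        raise ValueError(f"unknown cell type index: {cell_type}")
    label = [0.0] * NUM_CELL_TYPES
    label[cell_type - 1] = 1.0  # 1-based type index -> 0-based vector position
    return label


print(one_hot(3))  # [0.0, 0.0, 1.0, 0.0, 0.0]
```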

Supports label smoothing for specific cell types to handle annotation uncertainty.

Returns:

(features, label) where:
  • features is a tensor of shape (n_features,) for a single cell

  • label is a one-hot encoded tensor of shape (n_cell_types,) with optional label smoothing

Return type:

Tuple[torch.Tensor, torch.Tensor]

__init__(root: Union[str, Path], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, cell_types_to_keep: Optional[List[str]] = None, label_smoothing: Optional[Union[float, Dict[str, float]]] = None, max_workers: int = 8, force_reload: bool = False)

Initialize the CellTypeDataset.

Parameters:
  • root – Root directory where the processed dataset will be cached

  • folder – Path to the dataset folder containing slide data

  • data – DataFrame containing metadata with at least ‘FULL_PATH’ and ‘SPLIT’ columns

  • extractor – Feature extractor type or list of types to use for feature extraction

  • graph_creator – Optional graph creator type, needed for some extractors

  • segmentation_model – Segmentation model type (‘cellvit’ or ‘hovernet’). Although it defaults to None, it is required to obtain the cell type labels

  • split – Dataset split (‘train’, ‘val’, ‘test’, or ‘all’)

  • transforms – Optional TransformPipeline or Transform applied to the features each time an item is retrieved (in __getitem__)

  • cell_types_to_keep – Optional list of cell type names to keep (e.g., [“Neoplastic”, “Connective”]). Valid names: “Neoplastic”, “Inflammatory”, “Connective”, “Dead”, “Epithelial” (case-insensitive). If provided, only cells of these types will be included.

  • label_smoothing – Label smoothing configuration. Can be:

      - None or 0.0: no smoothing is applied
      - float (0.0 to 1.0): the same smoothing value is applied to all cell types
      - Dict[str, float]: a custom smoothing value per cell type, e.g.,
        {“Neoplastic”: 0.0, “Inflammatory”: 0.1, “Dead”: 0.2, “Epithelial”: 0.15}

    Cell types not in the dict will have no smoothing applied.

  • max_workers – Maximum number of threads for parallel processing

  • force_reload – Whether to force reprocessing even if processed files exist
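The `split` and `data` parameters suggest that slides are selected by their ‘SPLIT’ column. The sketch below shows one plausible selection rule: keep rows whose ‘SPLIT’ equals `split`, and keep every row for ‘all’. This rule is an assumption, not confirmed by the source. The real class takes a pandas DataFrame, where the equivalent would be `data[data["SPLIT"] == split]`. Plain dicts are used here so the example runs without pandas, and the paths are hypothetical.

```python
from typing import Dict, List

SPLITS = {"train", "val", "test", "all"}


def select_rows(rows: List[Dict[str, str]], split: str) -> List[Dict[str, str]]:
    """Keep rows whose 'SPLIT' equals `split`; 'all' keeps every row (assumed rule)."""
    if split not in SPLITS:
        raise ValueError(f"split must be one of {sorted(SPLITS)}, got {split!r}")
    if split == "all":
        return list(rows)
    return [row for row in rows if row["SPLIT"] == split]


# Hypothetical metadata rows with the two required columns.
metadata = [
    {"FULL_PATH": "/data/slide_a", "SPLIT": "train"},
    {"FULL_PATH": "/data/slide_b", "SPLIT": "val"},
    {"FULL_PATH": "/data/slide_c", "SPLIT": "train"},
]
print([row["FULL_PATH"] for row in select_rows(metadata, "train")])
# ['/data/slide_a', '/data/slide_c']
```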

_process_label_smoothing(label_smoothing: Optional[Union[float, Dict[str, float]]]) → Dict[int, float]

Process label smoothing configuration into a per-type dictionary.

Parameters:

label_smoothing – Either None, a float, or a dict mapping cell type names to smoothing values

Returns:

Dictionary mapping cell type indices (1-based) to smoothing values
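The sketch below shows how the three accepted forms (None, a float, or a name-to-value dict) might be normalized into a per-index dictionary. Several details are assumptions, not the library's confirmed behaviour:

  • every index gets a default of 0.0;

  • dict keys are matched case-insensitively;

  • unknown names or out-of-range values raise ValueError.

The `CELL_TYPES` mapping is also hypothetical.

```python
from typing import Dict, Optional, Union

# Hypothetical 1-based index -> name mapping (assumption).
CELL_TYPES = {1: "Neoplastic", 2: "Inflammatory", 3: "Connective", 4: "Dead", 5: "Epithelial"}


def process_label_smoothing(
    label_smoothing: Optional[Union[float, Dict[str, float]]],
) -> Dict[int, float]:
    """Normalize the label_smoothing argument into {type_index: smoothing}."""
    # Default: no smoothing for any type (covers None and 0.0).
    smoothing = {idx: 0.0 for idx in CELL_TYPES}
    if label_smoothing is None:
        return smoothing
    if isinstance(label_smoothing, (int, float)):
        values = {idx: float(label_smoothing) for idx in CELL_TYPES}
    else:
        # Case-insensitive name lookup; unknown names are rejected (assumption).
        lookup = {name.lower(): idx for idx, name in CELL_TYPES.items()}
        values = {}
        for name, value in label_smoothing.items():
            if name.lower() not in lookup:
                raise ValueError(f"unknown cell type name: {name!r}")
            values[lookup[name.lower()]] = float(value)
    for idx, value in values.items():
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"smoothing for type {idx} must be in [0, 1], got {value}")
        smoothing[idx] = value
    return smoothing


print(process_label_smoothing({"Dead": 0.2, "epithelial": 0.15}))
# {1: 0.0, 2: 0.0, 3: 0.0, 4: 0.2, 5: 0.15}
```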

_cell_type_names_to_indices(cell_type_names: Optional[List[str]]) → Optional[List[int]]

Convert cell type names to their corresponding indices.

Parameters:

cell_type_names – List of cell type names (case-insensitive)

Returns:

List of cell type indices (1-based as in TYPE_NUCLEI_DICT), or None if input is None
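The sketch below converts names to indices using the case-insensitive matching described for cell_types_to_keep. The `CELL_TYPES` mapping stands in for `TYPE_NUCLEI_DICT` and is hypothetical. Raising on an unknown name is also an assumption; the real method may handle it differently.

```python
from typing import List, Optional

# Hypothetical stand-in for TYPE_NUCLEI_DICT (1-based index -> name).
CELL_TYPES = {1: "Neoplastic", 2: "Inflammatory", 3: "Connective", 4: "Dead", 5: "Epithelial"}


def cell_type_names_to_indices(names: Optional[List[str]]) -> Optional[List[int]]:
    """Map cell type names (any case) to their 1-based indices; None passes through."""
    if names is None:
        return None
    lookup = {name.lower(): idx for idx, name in CELL_TYPES.items()}
    indices = []
    for name in names:
        key = name.strip().lower()
        if key not in lookup:
            raise ValueError(f"unknown cell type {name!r}; valid: {list(CELL_TYPES.values())}")
        indices.append(lookup[key])
    return indices


print(cell_type_names_to_indices(["neoplastic", "CONNECTIVE"]))  # [1, 3]
```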

_get_data_path() → Path

Get the path for the processed dataset file.

_process_dataset() → None

Process the entire dataset and cache results.

_process_slide(row: Series, start_global_idx: int) → Optional[str]

Process a single slide and extract cell-level features and labels.

Parameters:
  • row – A pandas Series representing a row from the DataFrame

  • start_global_idx – Starting global index for cells in this slide

Returns:

slide_name if successful, None otherwise
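Passing a start_global_idx per slide would let slides be processed in parallel (see max_workers) while their cells still receive contiguous global indices. That is an inference from the parameter name, not something the source states. The sketch assumes the per-slide cell counts are known in advance and computes each slide's starting index as an exclusive prefix sum.

```python
from itertools import accumulate
from typing import List


def slide_start_indices(cells_per_slide: List[int]) -> List[int]:
    """Exclusive prefix sum: the first global cell index of each slide."""
    if not cells_per_slide:
        return []
    # accumulate(..., initial=0) yields 0, n0, n0+n1, ...; the last slide's count is not needed.
    return list(accumulate(cells_per_slide[:-1], initial=0))


print(slide_start_indices([120, 45, 300]))  # [0, 120, 165]
```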

_create_label(cell_type: int) → Tensor

Create a one-hot encoded label with optional label smoothing.

Parameters:

cell_type – Cell type index (1-based, 1-5)

Returns:

One-hot encoded tensor of shape (NUM_CELL_TYPES,) with optional smoothing
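The source does not give the exact smoothing formula. The sketch below assumes the common uniform variant, (1 − ε) · one_hot + ε / K, where ε is the smoothing value of the cell's own type and K = 5. Other variants exist, such as spreading ε over the K − 1 other classes, so treat this as illustrative only. Plain lists replace torch tensors.

```python
from typing import Dict, List

NUM_CELL_TYPES = 5  # matches the five types listed in the class docstring


def create_label(cell_type: int, smoothing: Dict[int, float]) -> List[float]:
    """One-hot label for a 1-based cell type, smoothed with that type's epsilon (assumed formula)."""
    eps = smoothing.get(cell_type, 0.0)  # types without an entry get no smoothing
    off_value = eps / NUM_CELL_TYPES
    label = [off_value] * NUM_CELL_TYPES
    # Uniform smoothing: (1 - eps) * one_hot + eps / K
    label[cell_type - 1] = 1.0 - eps + off_value
    return label


print([round(x, 2) for x in create_label(4, {4: 0.2})])
# [0.04, 0.04, 0.04, 0.84, 0.04]
```

With this formula each smoothed label still sums to 1, so it remains a valid target distribution for a cross-entropy loss.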

_save_data(path: Path) → None

Save processed data to disk. The slides list is not saved; it is derived from the saved data when loading.

_load_data(path: Path) → None

Load processed data from disk and derive the slides list.
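The sketch below illustrates the "derived, not saved" design: only per-cell records are written, and the unique slide list is rebuilt on load in first-seen order. JSON and the field names `cell_slide_names` and `cell_types` are hypothetical choices made so the example runs without torch. The actual cache format is not described in the source.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Tuple


def save_data(path: Path, cell_slide_names: List[str], cell_types: List[int]) -> None:
    """Persist per-cell data only; the slide list is not stored."""
    path.write_text(json.dumps({"cell_slide_names": cell_slide_names, "cell_types": cell_types}))


def load_data(path: Path) -> Tuple[List[str], List[int], List[str]]:
    """Load per-cell data and derive the unique slides in first-seen order."""
    payload = json.loads(path.read_text())
    names = payload["cell_slide_names"]
    slides = list(dict.fromkeys(names))  # dicts preserve insertion order
    return names, payload["cell_types"], slides


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "celltype_cache.json"
    save_data(cache, ["s1", "s1", "s2"], [1, 3, 2])
    _, _, slides = load_data(cache)
    print(slides)  # ['s1', 's2']
```

Deriving the list on load avoids storing the same information twice, so the per-cell records and the slide list cannot disagree.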

_log_class_distribution() → None

Log the distribution of cell types in the dataset.

get_class_distribution() → Dict[str, int]

Get the distribution of cell types in the dataset.

Returns:

Dictionary with cell type names as keys and counts as values
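A minimal sketch of computing this distribution with `collections.Counter` from a list of 1-based cell type indices. The `CELL_TYPES` mapping is hypothetical. Reporting absent types with a count of 0 is an assumption; the real method may omit them instead.

```python
from collections import Counter
from typing import Dict, List

# Hypothetical 1-based index -> name mapping (assumption).
CELL_TYPES = {1: "Neoplastic", 2: "Inflammatory", 3: "Connective", 4: "Dead", 5: "Epithelial"}


def class_distribution(cell_types: List[int]) -> Dict[str, int]:
    """Count cells per type name from a list of 1-based type indices."""
    counts = Counter(cell_types)
    # Counter returns 0 for missing keys, so absent types are reported as 0 (assumption).
    return {name: counts[idx] for idx, name in CELL_TYPES.items()}


print(class_distribution([1, 1, 3, 5, 1]))
# {'Neoplastic': 3, 'Inflammatory': 0, 'Connective': 1, 'Dead': 0, 'Epithelial': 1}
```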

get_num_classes() → int

Get the number of cell type classes.

Returns:

Number of cell type classes (5 when all types are kept, or fewer if cell_types_to_keep is set)

get_weights_for_sampler() → Tensor

Compute weights for WeightedRandomSampler to handle class imbalance.

Returns:

Weights for each cell in the dataset, shape (len(dataset),)

Return type:

torch.Tensor
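A common way to build such weights is inverse class frequency: each cell gets 1 / (number of cells of its type), so every class carries the same total weight. The source does not say which weighting `get_weights_for_sampler` uses, so this sketch is an assumption. It uses plain Python lists.

```python
from collections import Counter
from typing import List


def weights_for_sampler(cell_types: List[int]) -> List[float]:
    """Per-cell weight = 1 / (number of cells sharing its type)."""
    counts = Counter(cell_types)
    return [1.0 / counts[t] for t in cell_types]


weights = weights_for_sampler([1, 1, 1, 2])
print(weights)  # [0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 1.0]
```

In PyTorch these weights would typically be passed as `torch.utils.data.WeightedRandomSampler(torch.tensor(weights, dtype=torch.double), num_samples=len(weights), replacement=True)`. That sampler is then given to a `DataLoader` through its `sampler` argument.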

__len__() → int

Return the number of cells in the dataset.

__getitem__(idx: int) → Tuple[Tensor, Tensor]

Get a sample from the dataset.

Parameters:

idx – Index of the cell to retrieve

Returns:

Tuple of (features, label) where:
  • features is a tensor of shape (n_features,) for a single cell

  • label is a one-hot encoded tensor of shape (n_cell_types,)

Return type:

Tuple[torch.Tensor, torch.Tensor]
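The sketch below shows the lazy-transform pattern implied by the transforms parameter. Stored features are never modified; the transform is applied each time an item is fetched. `ToyCellDataset` is a hypothetical plain-Python stand-in. It treats the transform as any callable on a feature list, since the actual TransformPipeline/Transform interface is not shown here.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Features = List[float]


class ToyCellDataset:
    """Hypothetical stand-in: one cell per item, transform applied lazily."""

    def __init__(
        self,
        features: Sequence[Features],
        labels: Sequence[Features],
        transforms: Optional[Callable[[Features], Features]] = None,
    ) -> None:
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = list(features)
        self.labels = list(labels)
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> Tuple[Features, Features]:
        feats = self.features[idx]
        if self.transforms is not None:
            feats = self.transforms(feats)  # returns a new list; stored features stay untouched
        return feats, self.labels[idx]


ds = ToyCellDataset([[1.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]], transforms=lambda f: [2 * x for x in f])
print(ds[1])  # ([6.0, 8.0], [0.0, 1.0])
```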

create_subset(indices: List[int]) → CellTypeDataset

Create a subset of the dataset using the specified indices.

Parameters:

indices – List of cell indices to include in the subset

Returns:

New CellTypeDataset instance containing only the specified cells
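The sketch below shows index-based subsetting over parallel per-cell lists. `CellRecords` and its fields are hypothetical. Preserving the order of `indices` (rather than sorting) is an assumption about create_subset.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class CellRecords:
    """Hypothetical per-cell storage: parallel lists indexed by global cell index."""

    features: List[List[float]]
    cell_types: List[int]
    slide_names: List[str]

    def __len__(self) -> int:
        return len(self.cell_types)

    def create_subset(self, indices: List[int]) -> "CellRecords":
        """New object holding only the requested cells, in the given order."""
        return CellRecords(
            features=[self.features[i] for i in indices],
            cell_types=[self.cell_types[i] for i in indices],
            slide_names=[self.slide_names[i] for i in indices],
        )


records = CellRecords([[0.1], [0.2], [0.3]], [1, 2, 1], ["a", "a", "b"])
sub = records.create_subset([2, 0])
print(sub.cell_types, sub.slide_names)  # [1, 1] ['b', 'a']
```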

create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None) → Tuple[CellTypeDataset, CellTypeDataset]

Create train and validation datasets from specified indices.

Parameters:
  • train_indices – List of cell indices for training set

  • val_indices – List of cell indices for validation set

  • transforms – Optional pre-fitted transforms to apply to both datasets. These should already be fitted on training data before calling this method.

Returns:

Tuple of (train_dataset, val_dataset) with the provided transforms
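The note that transforms "should already be fitted on training data" exists to avoid leakage: statistics must come from training cells only. The sketch illustrates this with per-feature standardization, a stand-in chosen for illustration rather than the library's TransformPipeline. The statistics are fitted on the training indices and the same transform is then applied to validation rows.

```python
from statistics import fmean, pstdev
from typing import Callable, List, Sequence

Row = List[float]


def fit_standardizer(train_rows: Sequence[Row]) -> Callable[[Row], Row]:
    """Fit per-feature mean/std on training rows only; return a transform."""
    columns = list(zip(*train_rows))
    means = [fmean(col) for col in columns]
    # Guard constant features so we never divide by zero.
    stds = [pstdev(col) or 1.0 for col in columns]

    def transform(row: Row) -> Row:
        return [(x - m) / s for x, m, s in zip(row, means, stds)]

    return transform


features = [[1.0, 10.0], [3.0, 10.0], [5.0, 40.0], [7.0, 20.0]]
train_indices, val_indices = [0, 1, 2], [3]
assert set(train_indices).isdisjoint(val_indices), "cells must not leak across splits"

transform = fit_standardizer([features[i] for i in train_indices])
print(transform([3.0, 20.0]))  # [0.0, 0.0]
```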

create_train_val_datasets_by_slides(train_slides: List[str], val_slides: List[str], transforms: Optional[Union[TransformPipeline, Transform]] = None) → Tuple[CellTypeDataset, CellTypeDataset]

Create train and validation datasets based on slide names.

This method is useful when you want to split the CellTypeDataset based on the same slide-level split used for another dataset (e.g., MILDataset). All cells from slides in train_slides go to training, and all cells from slides in val_slides go to validation.

Parameters:
  • train_slides – List of slide names for training set

  • val_slides – List of slide names for validation set

  • transforms – Optional pre-fitted transforms to apply to both datasets. These should already be fitted on training data before calling this method.

Returns:

Tuple of (train_dataset, val_dataset) with the provided transforms
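A slide-level split reduces to mapping each cell's slide name to a split, as sketched below. Rejecting slides that appear in both lists, and silently dropping cells whose slide is in neither list, are assumptions about this method's behaviour. The resulting index lists could then feed an index-based split such as create_train_val_datasets.

```python
from typing import List, Sequence, Tuple


def split_indices_by_slides(
    cell_slide_names: Sequence[str],
    train_slides: Sequence[str],
    val_slides: Sequence[str],
) -> Tuple[List[int], List[int]]:
    """Return (train_indices, val_indices) of cells grouped by their slide."""
    train_set, val_set = set(train_slides), set(val_slides)
    overlap = train_set & val_set
    if overlap:
        raise ValueError(f"slides in both splits: {sorted(overlap)}")
    # Cells from slides listed in neither split are left out (assumption).
    train_idx = [i for i, s in enumerate(cell_slide_names) if s in train_set]
    val_idx = [i for i, s in enumerate(cell_slide_names) if s in val_set]
    return train_idx, val_idx


print(split_indices_by_slides(["a", "a", "b", "c", "b"], ["a", "c"], ["b"]))
# ([0, 1, 3], [2, 4])
```

Splitting by slide keeps all cells of one slide on the same side, so cells from the same patient or slide cannot appear in both training and validation.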