cellmil.datamodels.datasets.celltype_dataset
CellTypeDataset: A PyTorch Dataset for multi-class cell type classification.
This dataset treats each individual cell as a sample, creating a cell-level dataset where each item is a single cell with its features and a one-hot encoded cell type label.
Supports label smoothing for specific cell types to handle annotation uncertainty.
Classes
- CellTypeDataset: A PyTorch Dataset for multi-class cell type classification.
- class cellmil.datamodels.datasets.celltype_dataset.CellTypeDataset(root: Union[str, Path], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, cell_types_to_keep: Optional[List[str]] = None, label_smoothing: Optional[Union[float, Dict[str, float]]] = None, max_workers: int = 8, force_reload: bool = False)
Bases: Dataset[Tuple[Tensor, Tensor]]
A PyTorch Dataset for multi-class cell type classification.
This dataset treats each individual cell as a sample, creating a cell-level dataset where each item is a single cell with its features and a one-hot encoded cell type label.
The labels are one-hot encoded for the 5 cell types:
- Type 1: Neoplastic
- Type 2: Inflammatory
- Type 3: Connective
- Type 4: Dead
- Type 5: Epithelial
Supports label smoothing for specific cell types to handle annotation uncertainty.
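A minimal sketch of the one-hot layout described above, assuming that 1-based type index k occupies position k - 1 of the label vector. Plain Python lists stand in for tensors, and `CELL_TYPES` and `one_hot` are hypothetical names, not part of cellmil.

```python
from typing import List

# Hypothetical mapping mirroring the 1-based type indices listed above.
CELL_TYPES = {
    1: "Neoplastic",
    2: "Inflammatory",
    3: "Connective",
    4: "Dead",
    5: "Epithelial",
}
NUM_CELL_TYPES = len(CELL_TYPES)


def one_hot(cell_type: int) -> List[float]:
    """Return a one-hot vector for a 1-based cell type index (assumed layout)."""
    if cell_type not in CELL_TYPES:
        raise ValueError(f"unknown cell type index: {cell_type}")
    label = [0.0] * NUM_CELL_TYPES
    # Shift from the 1-based type index to a 0-based vector position.
    label[cell_type - 1] = 1.0
    return label


print(one_hot(3))  # [0.0, 0.0, 1.0, 0.0, 0.0]
```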
- Returns:
- (features, label) where:
features is a tensor of shape (n_features,) for a single cell
label is a one-hot encoded tensor of shape (n_cell_types,) with optional label smoothing
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- __init__(root: Union[str, Path], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, cell_types_to_keep: Optional[List[str]] = None, label_smoothing: Optional[Union[float, Dict[str, float]]] = None, max_workers: int = 8, force_reload: bool = False)
Initialize the CellTypeDataset.
- Parameters:
root – Root directory where the processed dataset will be cached
folder – Path to the dataset folder containing slide data
data – DataFrame containing metadata with at least ‘FULL_PATH’ and ‘SPLIT’ columns
extractor – Feature extractor type or list of types to use for feature extraction
graph_creator – Optional graph creator type, needed for some extractors
segmentation_model – Segmentation model type (‘cellvit’ or ‘hovernet’); required to obtain the cell type labels
split – Dataset split (‘train’, ‘val’, ‘test’, or ‘all’)
transforms – Optional TransformPipeline or Transform applied to the features each time a sample is retrieved in __getitem__
cell_types_to_keep – Optional list of cell type names to keep (e.g., [“Neoplastic”, “Connective”]). Valid names: “Neoplastic”, “Inflammatory”, “Connective”, “Dead”, “Epithelial” (case-insensitive). If provided, only cells of these types will be included.
label_smoothing – Label smoothing configuration. Can be:
- None or 0.0: no smoothing is applied
- float (0.0 to 1.0): the same smoothing value is applied to all cell types
- Dict[str, float]: a custom smoothing value for each cell type, e.g., {“Neoplastic”: 0.0, “Inflammatory”: 0.1, “Dead”: 0.2, “Epithelial”: 0.15}. Cell types not in the dict have no smoothing applied.
max_workers – Maximum number of threads for parallel processing
force_reload – Whether to force reprocessing even if processed files exist
- _process_label_smoothing(label_smoothing: Optional[Union[float, Dict[str, float]]]) → Dict[int, float]
Process label smoothing configuration into a per-type dictionary.
- Parameters:
label_smoothing – Either None, a float, or a dict mapping cell type names to smoothing values
- Returns:
Dictionary mapping cell type indices (1-based) to smoothing values
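The three accepted forms of label_smoothing can be normalized into one per-type dictionary, as sketched below. This is an illustrative sketch, not the cellmil implementation: `NAME_TO_INDEX` and `process_label_smoothing` are hypothetical, and the case-insensitive name lookup, the 0.0 default for unlisted types, and raising `ValueError` on bad input are all assumptions.

```python
from typing import Dict, Optional, Union

# Hypothetical lower-cased name -> 1-based index table, following the type list above.
NAME_TO_INDEX = {
    "neoplastic": 1,
    "inflammatory": 2,
    "connective": 3,
    "dead": 4,
    "epithelial": 5,
}


def process_label_smoothing(
    label_smoothing: Optional[Union[float, Dict[str, float]]],
) -> Dict[int, float]:
    """Normalize a smoothing config into {1-based type index: epsilon}."""
    # Default: no smoothing for any type (assumed representation).
    per_type = {idx: 0.0 for idx in NAME_TO_INDEX.values()}
    if label_smoothing is None:
        return per_type
    if isinstance(label_smoothing, (int, float)):
        # A single value applies to every cell type.
        items = {name: float(label_smoothing) for name in NAME_TO_INDEX}
    else:
        items = label_smoothing
    for name, eps in items.items():
        if not 0.0 <= eps <= 1.0:
            raise ValueError(f"smoothing for {name!r} must be in [0, 1], got {eps}")
        key = name.lower()  # case-insensitive lookup is an assumption
        if key not in NAME_TO_INDEX:
            raise ValueError(f"unknown cell type name: {name!r}")
        per_type[NAME_TO_INDEX[key]] = float(eps)
    return per_type


print(process_label_smoothing({"Dead": 0.2, "Epithelial": 0.15}))
# {1: 0.0, 2: 0.0, 3: 0.0, 4: 0.2, 5: 0.15}
```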
- _cell_type_names_to_indices(cell_type_names: Optional[List[str]]) → Optional[List[int]]
Convert cell type names to their corresponding indices.
- Parameters:
cell_type_names – List of cell type names (case-insensitive)
- Returns:
List of cell type indices (1-based as in TYPE_NUCLEI_DICT), or None if input is None
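A sketch of the case-insensitive name-to-index conversion. `TYPE_NUCLEI` is a hypothetical stand-in for TYPE_NUCLEI_DICT, whose actual contents are not shown in this reference; raising on an unknown name is also an assumption.

```python
from typing import List, Optional

# Hypothetical stand-in for TYPE_NUCLEI_DICT: 1-based index -> display name.
TYPE_NUCLEI = {1: "Neoplastic", 2: "Inflammatory", 3: "Connective", 4: "Dead", 5: "Epithelial"}


def cell_type_names_to_indices(names: Optional[List[str]]) -> Optional[List[int]]:
    """Map case-insensitive cell type names to their 1-based indices."""
    if names is None:
        return None
    lookup = {name.lower(): idx for idx, name in TYPE_NUCLEI.items()}
    indices = []
    for name in names:
        idx = lookup.get(name.strip().lower())
        if idx is None:
            # Raising on unknown names is an assumption; the source does not say.
            raise ValueError(f"unknown cell type {name!r}; valid: {list(TYPE_NUCLEI.values())}")
        indices.append(idx)
    return indices


print(cell_type_names_to_indices(["neoplastic", "CONNECTIVE"]))  # [1, 3]
```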
- _process_slide(row: Series, start_global_idx: int) → Optional[str]
Process a single slide and extract cell-level features and labels.
- Parameters:
row – A pandas Series representing a row from the DataFrame
start_global_idx – Starting global index for cells in this slide
- Returns:
slide_name if successful, None otherwise
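The start_global_idx parameter suggests that every cell receives a dataset-wide index. One plausible way to combine this with max_workers threads is to precompute each slide's starting offset as a prefix sum of per-slide cell counts, so slides can be processed concurrently without index collisions. The sketch below assumes cell counts are known up front; `compute_start_indices`, `process_slide`, and `process_all` are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import accumulate
from typing import Dict, List, Tuple


def compute_start_indices(cell_counts: List[int]) -> List[int]:
    """Prefix sums: the first global index of each slide's cells."""
    # accumulate(..., initial=0) yields [0, c0, c0 + c1, ...]; drop the final total.
    return list(accumulate(cell_counts, initial=0))[:-1]


def process_slide(slide_name: str, start_global_idx: int, n_cells: int) -> Tuple[str, List[int]]:
    """Hypothetical stand-in: assign contiguous global indices to one slide's cells."""
    return slide_name, list(range(start_global_idx, start_global_idx + n_cells))


def process_all(slides: Dict[str, int], max_workers: int = 8) -> Dict[int, str]:
    """Process slides in parallel; return a map from global cell index to slide name."""
    names = list(slides)
    starts = compute_start_indices([slides[n] for n in names])
    cell_to_slide: Dict[int, str] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(process_slide, n, s, slides[n]) for n, s in zip(names, starts)]
        for fut in futures:
            name, indices = fut.result()
            for idx in indices:
                cell_to_slide[idx] = name
    return cell_to_slide


print(sorted(process_all({"slide_a": 3, "slide_b": 2}).items()))
# [(0, 'slide_a'), (1, 'slide_a'), (2, 'slide_a'), (3, 'slide_b'), (4, 'slide_b')]
```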
- _create_label(cell_type: int) → Tensor
Create a one-hot encoded label with optional label smoothing.
- Parameters:
cell_type – Cell type index (1-based, 1-5)
- Returns:
One-hot encoded tensor of shape (NUM_CELL_TYPES,) with optional smoothing
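A sketch of smoothed label construction. The source does not state how the smoothing mass is distributed; this sketch assumes the true class keeps 1 - ε and the remaining ε is split evenly over the other four classes. Another common convention mixes the one-hot target with a uniform distribution, (1 - ε)·y + ε/K, as PyTorch's CrossEntropyLoss does with its label_smoothing argument. `create_label` here is hypothetical.

```python
from typing import Dict, List

NUM_CELL_TYPES = 5


def create_label(cell_type: int, smoothing: Dict[int, float]) -> List[float]:
    """One-hot label for a 1-based cell type, with optional per-type smoothing.

    Assumed scheme: the true class keeps 1 - eps and the remaining eps is
    split evenly across the other NUM_CELL_TYPES - 1 classes.
    """
    if not 1 <= cell_type <= NUM_CELL_TYPES:
        raise ValueError(f"cell_type must be in 1..{NUM_CELL_TYPES}, got {cell_type}")
    eps = smoothing.get(cell_type, 0.0)  # types absent from the dict are not smoothed
    off_value = eps / (NUM_CELL_TYPES - 1)
    label = [off_value] * NUM_CELL_TYPES
    label[cell_type - 1] = 1.0 - eps
    return label


print([round(v, 3) for v in create_label(4, {4: 0.2})])  # [0.05, 0.05, 0.05, 0.8, 0.05]
```

Under this scheme each label still sums to 1, so it remains a valid target distribution for a soft-label cross-entropy loss.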
- get_class_distribution() → Dict[str, int]
Get the distribution of cell types in the dataset.
- Returns:
Dictionary with cell type names as keys and counts as values
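Counting cells per type is a natural fit for collections.Counter. The sketch assumes the dataset keeps a flat list of 1-based type indices, one per cell; `class_distribution` is hypothetical, and whether the real method reports zero-count types is not stated in the source (here they are omitted).

```python
from collections import Counter
from typing import Dict, List

# Hypothetical index -> name table matching the five types above.
TYPE_NAMES = {1: "Neoplastic", 2: "Inflammatory", 3: "Connective", 4: "Dead", 5: "Epithelial"}


def class_distribution(cell_types: List[int]) -> Dict[str, int]:
    """Count cells per type, keyed by display name (zero counts omitted)."""
    counts = Counter(cell_types)
    # Iterate in index order so the result has a stable, readable ordering.
    return {TYPE_NAMES[idx]: counts[idx] for idx in sorted(counts)}


print(class_distribution([1, 1, 3, 2, 1, 3]))
# {'Neoplastic': 3, 'Inflammatory': 1, 'Connective': 2}
```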
- get_num_classes() → int
Get the number of cell type classes.
- Returns:
Number of cell type classes (5 when all types are used, or fewer if cell_types_to_keep restricts them)
- get_weights_for_sampler() → Tensor
Compute weights for WeightedRandomSampler to handle class imbalance.
- Returns:
Weights for each cell in the dataset, shape (len(dataset),)
- Return type:
torch.Tensor
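A common weighting for WeightedRandomSampler is inverse class frequency, where each cell is weighted by 1 / (number of cells of its type). Whether get_weights_for_sampler uses exactly this scheme is an assumption; `sampler_weights` below is hypothetical and returns plain floats instead of a tensor.

```python
from collections import Counter
from typing import List


def sampler_weights(cell_types: List[int]) -> List[float]:
    """Inverse-frequency weight per cell: rare types are drawn more often."""
    counts = Counter(cell_types)
    # Each cell gets 1 / (number of cells sharing its type), so every type
    # contributes the same total weight to the sampler.
    return [1.0 / counts[t] for t in cell_types]


print(sampler_weights([1, 1, 1, 2]))
# [0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 1.0]
```

In PyTorch, weights like these would be passed as torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True), which in expectation draws each cell type equally often.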
- __getitem__(idx: int) → Tuple[Tensor, Tensor]
Get a sample from the dataset.
- Parameters:
idx – Index of the cell to retrieve
- Returns:
Tuple of (features, label) where:
features is a tensor of shape (n_features,) for a single cell
label is a one-hot encoded tensor of shape (n_cell_types,)
- Return type:
Tuple[torch.Tensor, torch.Tensor]
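A toy map-style dataset illustrating the (features, label) contract and per-item transforms. `ToyCellDataset` is hypothetical: it stores plain lists rather than tensors and applies the transform inside __getitem__, mirroring the documented behavior of the transforms parameter rather than the real caching logic.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


class ToyCellDataset:
    """Hypothetical map-style dataset mirroring the (features, label) contract."""

    def __init__(
        self,
        features: Sequence[Vector],
        labels: Sequence[Vector],
        transform: Optional[Callable[[Vector], Vector]] = None,
    ) -> None:
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> Tuple[Vector, Vector]:
        x = list(self.features[idx])  # copy so the transform cannot mutate stored data
        if self.transform is not None:
            # Transforms run per item at access time, not when the data is cached.
            x = self.transform(x)
        return x, self.labels[idx]


ds = ToyCellDataset(
    [[1.0, 2.0], [3.0, 4.0]],
    [[1.0, 0.0], [0.0, 1.0]],
    transform=lambda v: [2 * a for a in v],
)
print(ds[1])  # ([6.0, 8.0], [0.0, 1.0])
```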
- create_subset(indices: List[int]) → CellTypeDataset
Create a subset of the dataset using the specified indices.
- Parameters:
indices – List of cell indices to include in the subset
- Returns:
New CellTypeDataset instance containing only the specified cells
- create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None) → Tuple[CellTypeDataset, CellTypeDataset]
Create train and validation datasets from specified indices.
- Parameters:
train_indices – List of cell indices for training set
val_indices – List of cell indices for validation set
transforms – Optional pre-fitted transforms to apply to both datasets. These should already be fitted on training data before calling this method.
- Returns:
Tuple of (train_dataset, val_dataset) with the provided transforms
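Requiring pre-fitted transforms prevents statistics computed on validation cells from leaking into training. The sketch fits a hypothetical per-feature z-score `Standardizer` on training rows only and reuses it for both splits; it is not cellmil's TransformPipeline or Transform API.

```python
from statistics import fmean, pstdev
from typing import List

Matrix = List[List[float]]


class Standardizer:
    """Hypothetical per-feature z-score transform, fit on training rows only."""

    def fit(self, rows: Matrix) -> "Standardizer":
        cols = list(zip(*rows))
        self.mean = [fmean(c) for c in cols]
        # Fall back to 1.0 for zero-variance features to avoid dividing by zero.
        self.std = [pstdev(c) or 1.0 for c in cols]
        return self

    def __call__(self, row: List[float]) -> List[float]:
        return [(x - m) / s for x, m, s in zip(row, self.mean, self.std)]


features = [[1.0, 10.0], [3.0, 10.0], [5.0, 20.0], [100.0, 0.0]]
train_idx, val_idx = [0, 1, 2], [3]

# Fit on training cells only, then share the fitted transform with both splits.
scaler = Standardizer().fit([features[i] for i in train_idx])
train_rows = [scaler(features[i]) for i in train_idx]
val_rows = [scaler(features[i]) for i in val_idx]
```

The outlying validation cell (100.0) does not shift the fitted statistics: the mean of the first feature stays at 3.0, computed from the training cells alone.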
- create_train_val_datasets_by_slides(train_slides: List[str], val_slides: List[str], transforms: Optional[Union[TransformPipeline, Transform]] = None) → Tuple[CellTypeDataset, CellTypeDataset]
Create train and validation datasets based on slide names.
This method is useful when you want to split the CellTypeDataset based on the same slide-level split used for another dataset (e.g., MILDataset). All cells from slides in train_slides go to training, and all cells from slides in val_slides go to validation.
- Parameters:
train_slides – List of slide names for training set
val_slides – List of slide names for validation set
transforms – Optional pre-fitted transforms to apply to both datasets. These should already be fitted on training data before calling this method.
- Returns:
Tuple of (train_dataset, val_dataset) with the provided transforms
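Splitting by slide reduces to mapping slide names to cell indices. The sketch assumes the dataset records the source slide of every cell; `indices_by_slides` is hypothetical, and rejecting overlapping slide lists is an added safeguard, not documented behavior. Cells from slides in neither list are left out.

```python
from typing import List, Sequence, Tuple


def indices_by_slides(
    cell_slides: Sequence[str],
    train_slides: List[str],
    val_slides: List[str],
) -> Tuple[List[int], List[int]]:
    """Return (train_indices, val_indices) for cells grouped by slide name."""
    train_set, val_set = set(train_slides), set(val_slides)
    overlap = train_set & val_set
    if overlap:
        # A slide in both splits would leak cells across train and validation.
        raise ValueError(f"slides in both splits: {sorted(overlap)}")
    train_idx = [i for i, s in enumerate(cell_slides) if s in train_set]
    val_idx = [i for i, s in enumerate(cell_slides) if s in val_set]
    return train_idx, val_idx


cell_slides = ["s1", "s1", "s2", "s3", "s2"]
print(indices_by_slides(cell_slides, ["s1", "s2"], ["s3"]))  # ([0, 1, 2, 4], [3])
```

The two index lists could then be passed to an index-based method such as create_train_val_datasets.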