"""
CellTypeDataset: A PyTorch Dataset for multi-class cell type classification.
This dataset treats each individual cell as a sample, creating a cell-level dataset
where each item is a single cell with its features and one-hot encoded cell type label.
Supports label smoothing for specific cell types to handle annotation uncertainty.
"""
import pandas as pd
import torch
import json
import hashlib
from tqdm import tqdm
from pathlib import Path
from typing import List, Literal, Tuple, Union, Optional, Dict, Any
from torch.utils.data import Dataset
from cellmil.interfaces.FeatureExtractorConfig import ExtractorType
from cellmil.interfaces.GraphCreatorConfig import GraphCreatorType
from cellmil.interfaces.CellSegmenterConfig import ModelType, TYPE_NUCLEI_DICT
from cellmil.utils import logger
from ..transforms import (
TransformPipeline,
Transform,
)
from .utils import (
wsl_preprocess,
filter_split,
get_cell_types,
get_cell_features,
)
# Number of cell types (Neoplastic, Inflammatory, Connective, Dead, Epithelial),
# derived from TYPE_NUCLEI_DICT, which maps 1-based type indices to type names.
NUM_CELL_TYPES = len(TYPE_NUCLEI_DICT)
class CellTypeDataset(Dataset[Tuple[torch.Tensor, torch.Tensor]]):
    """
    A PyTorch Dataset for multi-class cell type classification.

    This dataset treats each individual cell as a sample, creating a
    cell-level dataset where each item is a single cell with its features
    and one-hot encoded cell type label.

    The labels are one-hot encoded for the 5 cell types:
        - Type 1: Neoplastic
        - Type 2: Inflammatory
        - Type 3: Connective
        - Type 4: Dead
        - Type 5: Epithelial

    Supports label smoothing for specific cell types to handle annotation
    uncertainty.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (features, label) where:
            - features is a tensor of shape (n_features,) for a single cell
            - label is a one-hot encoded tensor of shape (n_cell_types,)
              with optional label smoothing
    """

    def __init__(
        self,
        root: Union[str, Path],
        folder: Path,
        data: pd.DataFrame,
        extractor: Union[ExtractorType, List[ExtractorType]],
        graph_creator: GraphCreatorType | None = None,
        segmentation_model: ModelType | None = None,
        split: Literal["train", "val", "test", "all"] = "all",
        transforms: Optional[TransformPipeline | Transform] = None,
        cell_types_to_keep: Optional[List[str]] = None,
        label_smoothing: Optional[Union[float, Dict[str, float]]] = None,
        max_workers: int = 8,
        force_reload: bool = False,
    ):
        """
        Initialize the CellTypeDataset.

        Args:
            root: Root directory where the processed dataset will be cached
            folder: Path to the dataset folder containing slide data
            data: DataFrame containing metadata with at least 'FULL_PATH' and 'SPLIT' columns
            extractor: Feature extractor type or list of types to use for feature extraction
            graph_creator: Optional graph creator type, needed for some extractors
            segmentation_model: Segmentation model type ('cellvit' or 'hovernet'), required for cell type info
            split: Dataset split ('train', 'val', 'test', or 'all')
            transforms: Optional TransformPipeline or Transform to apply to features at getitem time
            cell_types_to_keep: Optional list of cell type names to keep (e.g., ["Neoplastic", "Connective"]).
                Valid names: "Neoplastic", "Inflammatory", "Connective", "Dead", "Epithelial" (case-insensitive).
                If provided, only cells of these types will be included.
            label_smoothing: Label smoothing configuration. Can be:
                - None or 0.0: No smoothing applied
                - float (0.0 to 1.0): Same smoothing value applied to all cell types
                - Dict[str, float]: Custom smoothing value for each cell type, e.g.,
                  {"Neoplastic": 0.0, "Inflammatory": 0.1, "Dead": 0.2, "Epithelial": 0.15}
                  Cell types not in the dict will have no smoothing applied.
            max_workers: Maximum number of threads for parallel processing
            force_reload: Whether to force reprocessing even if processed files exist

        Raises:
            ValueError: If segmentation_model is not 'cellvit' or 'hovernet',
                or if a label_smoothing value is outside [0.0, 1.0], or if a
                cell type name is unknown.
            TypeError: If label_smoothing is neither None, float, nor dict.
        """
        self.root = Path(root)
        self.folder = folder
        self.raw_data = data
        self.extractor = extractor
        self.graph_creator = graph_creator
        self.segmentation_model = segmentation_model
        self.split = split
        # NOTE(review): max_workers is stored but not used by any method
        # visible in this class — presumably consumed elsewhere; verify.
        self.max_workers = max_workers
        self.force_reload = force_reload
        self.transforms = transforms
        # Validate segmentation model: per-cell type annotations are only
        # available from these two backends (see class docstring).
        if self.segmentation_model not in ["cellvit", "hovernet"]:
            raise ValueError(
                f"Cell type classification requires 'cellvit' or 'hovernet' segmentation model. "
                f"Got '{self.segmentation_model}'"
            )
        # Process and validate label smoothing
        self.label_smoothing_per_type = self._process_label_smoothing(label_smoothing)
        # Convert cell type names to indices (1-based as in TYPE_NUCLEI_DICT)
        self.cell_types_to_keep_indices = self._cell_type_names_to_indices(
            cell_types_to_keep
        )
        # TODO: Make this configurable
        self.wsl = False
        # Data structures
        self.all_slides: List[str] = []  # All slides with features
        self.slides: List[str] = []  # Slides after filtering
        # Cell-level data (populated after processing)
        self.cell_features: Dict[
            str, torch.Tensor
        ] = {}  # slide_name -> features tensor
        self.cell_types: Dict[
            str, torch.Tensor
        ] = {}  # slide_name -> cell types tensor (1-based indices)
        self.cell_slide_indices: Dict[
            str, List[int]
        ] = {}  # slide_name -> list of global indices
        # Flat list for indexing
        self._global_indices: List[
            Tuple[str, int]
        ] = []  # (slide_name, local_idx) for each cell
        # Create root directory
        self.root.mkdir(parents=True, exist_ok=True)
        # Either build the dataset from scratch or load the cached version;
        # the cache path encodes a hash of the configuration (_get_data_path).
        data_path = self._get_data_path()
        if self.force_reload or not data_path.exists():
            logger.info("Processing dataset from scratch...")
            self._process_dataset()
            self._save_data(data_path)
        else:
            logger.info(f"Loading preprocessed dataset from {data_path}")
            self._load_data(data_path)
[docs] def _process_label_smoothing(
self, label_smoothing: Optional[Union[float, Dict[str, float]]]
) -> Dict[int, float]:
"""
Process label smoothing configuration into a per-type dictionary.
Args:
label_smoothing: Either None, a float, or a dict mapping cell type names to smoothing values
Returns:
Dictionary mapping cell type indices (1-based) to smoothing values
"""
# Default: no smoothing for any type
smoothing_per_type: Dict[int, float] = {
i: 0.0 for i in range(1, NUM_CELL_TYPES + 1)
}
if label_smoothing is None:
return smoothing_per_type
if isinstance(label_smoothing, (int, float)):
# Single value: apply to all types
smoothing_value = float(label_smoothing)
if not 0.0 <= smoothing_value <= 1.0:
raise ValueError(
f"label_smoothing must be between 0.0 and 1.0. Got {smoothing_value}"
)
for i in range(1, NUM_CELL_TYPES + 1):
smoothing_per_type[i] = smoothing_value
elif isinstance(label_smoothing, dict): # type: ignore
# Dict: custom value per type
name_to_idx: Dict[str, int] = {
v.lower(): k for k, v in TYPE_NUCLEI_DICT.items()
}
for type_name, smoothing_value in label_smoothing.items():
type_name_lower = type_name.lower()
if type_name_lower not in name_to_idx:
valid_names = list(TYPE_NUCLEI_DICT.values())
raise ValueError(
f"Invalid cell type name in label_smoothing: '{type_name}'. "
f"Valid names are: {valid_names}"
)
if not 0.0 <= smoothing_value <= 1.0:
raise ValueError(
f"label_smoothing value for '{type_name}' must be between 0.0 and 1.0. "
f"Got {smoothing_value}"
)
type_idx = name_to_idx[type_name_lower]
smoothing_per_type[type_idx] = smoothing_value
else:
raise TypeError(
f"label_smoothing must be None, float, or Dict[str, float]. "
f"Got {type(label_smoothing)}"
)
return smoothing_per_type
[docs] def _cell_type_names_to_indices(
self, cell_type_names: Optional[List[str]]
) -> Optional[List[int]]:
"""
Convert cell type names to their corresponding indices.
Args:
cell_type_names: List of cell type names (case-insensitive)
Returns:
List of cell type indices (1-based as in TYPE_NUCLEI_DICT), or None if input is None
"""
if cell_type_names is None:
return None
# Create reverse mapping from name to index
name_to_idx: Dict[str, int] = {
v.lower(): k for k, v in TYPE_NUCLEI_DICT.items()
}
indices: List[int] = []
for name in cell_type_names:
name_lower = name.lower()
if name_lower not in name_to_idx:
valid_names = list(TYPE_NUCLEI_DICT.values())
raise ValueError(
f"Invalid cell type name: '{name}'. Valid names are: {valid_names}"
)
indices.append(name_to_idx[name_lower])
return indices
[docs] def _get_data_path(self) -> Path:
"""Get the path for the processed dataset file."""
# Create a hash of the configuration to ensure cache invalidation
extractor_str = (
json.dumps(self.extractor, sort_keys=True)
if isinstance(self.extractor, list)
else str(self.extractor)
)
config_dict: Dict[str, Any] = {
"extractor": extractor_str,
"graph_creator": self.graph_creator,
"segmentation_model": self.segmentation_model,
"split": self.split,
"cell_types_to_keep": sorted(self.cell_types_to_keep_indices)
if self.cell_types_to_keep_indices
else None,
}
config_str = json.dumps(config_dict, sort_keys=True)
config_hash = hashlib.md5(config_str.encode("utf-8")).hexdigest()[:8]
logger.info(f"Dataset configuration: {config_dict}")
logger.info(f"Dataset configuration hash: {config_hash}")
return self.root / f"celltype_{self.split}_{config_hash}.pt"
[docs] def _process_dataset(self) -> None:
"""Process the entire dataset and cache results."""
try:
if self.wsl:
processed_data = wsl_preprocess(self.raw_data)
else:
processed_data = self.raw_data.copy()
# Validate required columns
if processed_data.empty:
raise ValueError("Data is empty")
required_columns = (
["FULL_PATH", "SPLIT"] if self.split != "all" else ["FULL_PATH"]
)
for col in required_columns:
if col not in processed_data.columns:
raise ValueError(f"Required column '{col}' not found in data")
# Filter by split type
if self.split != "all":
processed_data = filter_split(processed_data, self.split)
else:
logger.info(f"Using all data: {len(processed_data)} slides")
# Process each slide
logger.info("Processing slides and extracting cells...")
global_idx = 0
for _, row in tqdm(
processed_data.iterrows(),
total=len(processed_data),
desc="Processing slides",
):
slide_name = self._process_slide(row, global_idx)
if slide_name is not None:
self.all_slides.append(slide_name)
# Update global index based on number of cells added
if slide_name in self.cell_slide_indices:
global_idx += len(self.cell_slide_indices[slide_name])
# Filter slides to those with valid features
self.slides = [
s
for s in self.all_slides
if s in self.cell_features and len(self.cell_features[s]) > 0
]
logger.info(
f"Found {len(self.slides)} valid slides out of {len(processed_data)} total slides"
)
logger.info(f"Total cells: {len(self._global_indices)}")
if not self.slides:
raise ValueError("No valid slides found")
# Log class distribution
self._log_class_distribution()
except Exception as e:
logger.error(f"Failed to process dataset: {e}")
raise ValueError(f"Failed to process dataset: {e}")
    def _process_slide(self, row: pd.Series, start_global_idx: int) -> Optional[str]:
        """
        Process a single slide and extract cell-level features and labels.

        Args:
            row: A pandas Series representing a row from the DataFrame;
                must contain a 'FULL_PATH' entry.
            start_global_idx: Starting global index for cells in this slide.
                NOTE(review): currently unused — global positions are derived
                from len(self._global_indices) directly; confirm whether the
                parameter can be retired at the call site.

        Returns:
            slide_name if successful, None otherwise (failures are logged
            and the slide is skipped, never raised)
        """
        try:
            file_path = Path(row["FULL_PATH"])
            slide_name = file_path.stem
            # Get features and cell indices
            features, cell_indices, _ = get_cell_features(
                self.folder,
                slide_name,
                self.extractor,
                self.graph_creator,
                self.segmentation_model,
            )
            if features is None or cell_indices is None or len(cell_indices) == 0:
                logger.debug(f"No features found for slide {slide_name}")
                return None
            # Get cell types
            if self.segmentation_model is not None:
                cell_types = get_cell_types(
                    self.folder, slide_name, self.segmentation_model
                )
            else:
                # Unreachable in practice: __init__ validates the model,
                # but kept as a guard for direct callers.
                logger.warning(f"No segmentation model for slide {slide_name}")
                return None
            if not cell_types:
                logger.debug(f"No cell types found for slide {slide_name}")
                return None
            # Filter and extract cell-level data
            filtered_features: List[torch.Tensor] = []
            filtered_cell_types: List[int] = []
            local_indices: List[int] = []
            for cell_id, tensor_idx in cell_indices.items():
                # Skip cells that have no type annotation.
                if cell_id not in cell_types:
                    continue
                cell_type = cell_types[cell_id]  # 1-based index
                # Filter by cell types to keep if specified
                if self.cell_types_to_keep_indices is not None:
                    if cell_type not in self.cell_types_to_keep_indices:
                        continue
                # Get cell features
                cell_feature = features[tensor_idx]  # Shape: (n_features,)
                filtered_features.append(cell_feature)
                filtered_cell_types.append(cell_type)
                # Record this cell's dataset-wide position (despite the name,
                # `local_indices` stores global positions), then register the
                # (slide, within-slide index) pair for __getitem__ lookup.
                local_indices.append(len(self._global_indices))
                self._global_indices.append((slide_name, len(filtered_features) - 1))
            if not filtered_features:
                logger.debug(f"No cells after filtering for slide {slide_name}")
                return None
            # Store as tensors
            self.cell_features[slide_name] = torch.stack(filtered_features, dim=0)
            self.cell_types[slide_name] = torch.tensor(
                filtered_cell_types, dtype=torch.long
            )
            self.cell_slide_indices[slide_name] = local_indices
            return slide_name
        except Exception as e:
            # Best-effort policy: a failed slide is skipped rather than
            # aborting the whole dataset build.
            logger.warning(f"Error processing slide row: {e}")
            return None
[docs] def _create_label(self, cell_type: int) -> torch.Tensor:
"""
Create a one-hot encoded label with optional label smoothing.
Args:
cell_type: Cell type index (1-based, 1-5)
Returns:
One-hot encoded tensor of shape (NUM_CELL_TYPES,) with optional smoothing
"""
# Convert to 0-based index
idx = cell_type - 1
# Create base one-hot encoding
label = torch.zeros(NUM_CELL_TYPES, dtype=torch.float32)
label[idx] = 1.0
# Apply label smoothing for this specific cell type
smoothing = self.label_smoothing_per_type.get(cell_type, 0.0)
if smoothing > 0:
# Label smoothing: (1 - smoothing) * one_hot + smoothing / num_classes
label = (1.0 - smoothing) * label + smoothing / NUM_CELL_TYPES
return label
[docs] def _save_data(self, path: Path) -> None:
"""Save processed data to disk (slides is derived, not saved)."""
data_dict: Dict[str, Any] = {
"all_slides": self.all_slides,
"cell_features": self.cell_features,
"cell_types": self.cell_types,
"cell_slide_indices": self.cell_slide_indices,
"_global_indices": self._global_indices,
}
torch.save(data_dict, path)
logger.info(f"Saved dataset to {path}")
[docs] def _load_data(self, path: Path) -> None:
"""Load processed data from disk and derive slides list."""
data_dict = torch.load(path, map_location="cpu", weights_only=False)
self.all_slides = data_dict["all_slides"]
self.cell_features = data_dict["cell_features"]
self.cell_types = data_dict["cell_types"]
self.cell_slide_indices = data_dict["cell_slide_indices"]
self._global_indices = data_dict["_global_indices"]
# Derive slides from all_slides (those with valid features)
self.slides = [
s
for s in self.all_slides
if s in self.cell_features and len(self.cell_features[s]) > 0
]
logger.info(
f"Loaded {len(self.slides)} slides with {len(self._global_indices)} total cells"
)
[docs] def _log_class_distribution(self) -> None:
"""Log the distribution of cell types in the dataset."""
type_counts: Dict[int, int] = {i: 0 for i in range(1, NUM_CELL_TYPES + 1)}
for slide_name in self.slides:
cell_types_tensor = self.cell_types[slide_name]
for cell_type in cell_types_tensor:
type_counts[int(cell_type.item())] += 1
total = sum(type_counts.values())
logger.info("Cell type distribution:")
for cell_type, count in type_counts.items():
type_name = TYPE_NUCLEI_DICT[cell_type]
percentage = (count / total * 100) if total > 0 else 0
logger.info(f" {type_name}: {count} ({percentage:.2f}%)")
[docs] def get_class_distribution(self) -> Dict[str, int]:
"""
Get the distribution of cell types in the dataset.
Returns:
Dictionary with cell type names as keys and counts as values
"""
type_counts: Dict[str, int] = {}
for _, name in TYPE_NUCLEI_DICT.items():
type_counts[name] = 0
for slide_name in self.slides:
cell_types_tensor = self.cell_types[slide_name]
for cell_type in cell_types_tensor:
type_name = TYPE_NUCLEI_DICT[int(cell_type.item())]
type_counts[type_name] += 1
type_counts["total"] = sum(type_counts.values())
return type_counts
[docs] def get_num_classes(self) -> int:
"""
Get the number of cell type classes.
Returns:
Number of cell type classes (5 for all types, or fewer if filtered)
"""
if self.cell_types_to_keep_indices is not None:
return len(self.cell_types_to_keep_indices)
return NUM_CELL_TYPES
[docs] def get_weights_for_sampler(self) -> torch.Tensor:
"""
Compute weights for WeightedRandomSampler to handle class imbalance.
Returns:
torch.Tensor: Weights for each cell in the dataset, shape (len(dataset),)
"""
if len(self._global_indices) == 0:
logger.warning("No cells found for weight computation")
return torch.ones(1, dtype=torch.float32)
# Count cells per type (using 1-based indices)
type_counts: Dict[int, int] = {i: 0 for i in range(1, NUM_CELL_TYPES + 1)}
cell_types_list: List[int] = []
for slide_name, local_idx in self._global_indices:
cell_type = int(self.cell_types[slide_name][local_idx].item()) # 1-based
type_counts[cell_type] += 1
cell_types_list.append(cell_type)
# Compute weights (inverse of class frequency)
total_cells = len(cell_types_list)
weights = torch.zeros(total_cells, dtype=torch.float32)
for i, cell_type in enumerate(cell_types_list):
count = type_counts[cell_type]
if count > 0:
weights[i] = 1.0 / count
else:
weights[i] = 1.0
# Normalize weights
weights = weights * len(weights) / weights.sum()
logger.info("Computed sampling weights for cell type classification:")
for cell_type, count in type_counts.items():
type_name = TYPE_NUCLEI_DICT.get(cell_type, f"Type {cell_type}")
logger.info(f" {type_name}: {count} cells")
return weights
    def __len__(self) -> int:
        """Return the total number of cells (not slides) in the dataset."""
        return len(self._global_indices)
[docs] def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get a sample from the dataset.
Args:
idx: Index of the cell to retrieve
Returns:
Tuple of (features, label) where:
- features is a tensor of shape (n_features,) for a single cell
- label is a one-hot encoded tensor of shape (n_cell_types,)
"""
if idx >= len(self._global_indices):
raise IndexError(
f"Index {idx} out of range for dataset of size {len(self._global_indices)}"
)
slide_name, local_idx = self._global_indices[idx]
# Get cell features
features = self.cell_features[slide_name][local_idx].clone()
# Get cell type and create one-hot label with optional smoothing
cell_type = int(self.cell_types[slide_name][local_idx].item()) # 1-based
label = self._create_label(cell_type)
# Apply transforms if provided
if self.transforms is not None:
# Transforms expect (n_instances, n_features), so add batch dimension
features = features.unsqueeze(0)
features = self.transforms.transform(features)
features = features.squeeze(0)
return features, label
    def create_subset(self, indices: List[int]) -> "CellTypeDataset":
        """
        Create a subset of the dataset using the specified indices.

        The subset shares the underlying feature/label tensors with this
        dataset (no data copying); only the index mapping is subset-specific.

        Args:
            indices: List of cell indices to include in the subset

        Returns:
            New CellTypeDataset instance containing only the specified cells

        Raises:
            ValueError: If indices is empty or contains out-of-range values.
        """
        if not indices:
            raise ValueError("Indices list cannot be empty")
        max_idx = len(self._global_indices) - 1
        invalid_indices = [idx for idx in indices if idx < 0 or idx > max_idx]
        if invalid_indices:
            raise ValueError(
                f"Invalid indices {invalid_indices[:5]}... Valid range is 0-{max_idx}"
            )
        # Create a new instance with the same configuration.
        # __new__ bypasses __init__ so no reprocessing or cache I/O happens.
        subset = CellTypeDataset.__new__(CellTypeDataset)
        # Copy configuration
        subset.root = self.root
        subset.folder = self.folder
        subset.raw_data = self.raw_data
        subset.extractor = self.extractor
        subset.graph_creator = self.graph_creator
        subset.segmentation_model = self.segmentation_model
        subset.split = self.split
        subset.max_workers = self.max_workers
        subset.force_reload = self.force_reload
        subset.transforms = self.transforms
        subset.label_smoothing_per_type = self.label_smoothing_per_type
        subset.cell_types_to_keep_indices = self.cell_types_to_keep_indices
        subset.wsl = self.wsl
        # Share feature/label data (read-only, avoid duplication)
        subset.cell_features = self.cell_features
        subset.cell_types = self.cell_types
        subset.cell_slide_indices = self.cell_slide_indices
        subset.all_slides = self.all_slides
        # Create subset indices
        subset._global_indices = [self._global_indices[i] for i in indices]
        # Determine which slides are represented in the subset (order of
        # self.slides is preserved).
        subset_slide_set = set(slide for slide, _ in subset._global_indices)
        subset.slides = [s for s in self.slides if s in subset_slide_set]
        logger.info(
            f"Created subset with {len(subset._global_indices)} cells from {len(subset.slides)} slides"
        )
        return subset
[docs] def create_train_val_datasets(
self,
train_indices: List[int],
val_indices: List[int],
transforms: Optional[Union[TransformPipeline, Transform]] = None,
) -> Tuple["CellTypeDataset", "CellTypeDataset"]:
"""
Create train and validation datasets from specified indices.
Args:
train_indices: List of cell indices for training set
val_indices: List of cell indices for validation set
transforms: Optional pre-fitted transforms to apply to both datasets.
These should already be fitted on training data before calling this method.
Returns:
Tuple of (train_dataset, val_dataset) with the provided transforms
"""
if not train_indices:
raise ValueError("Train indices list cannot be empty")
if not val_indices:
raise ValueError("Val indices list cannot be empty")
# Validate indices
max_idx = len(self._global_indices) - 1
invalid_train = [idx for idx in train_indices if idx < 0 or idx > max_idx]
invalid_val = [idx for idx in val_indices if idx < 0 or idx > max_idx]
if invalid_train:
raise ValueError(
f"Invalid train indices {invalid_train[:5]}... Valid range is 0-{max_idx}"
)
if invalid_val:
raise ValueError(
f"Invalid val indices {invalid_val[:5]}... Valid range is 0-{max_idx}"
)
# Validate that transforms are fitted if provided
if transforms is not None:
if hasattr(transforms, "is_fitted") and not transforms.is_fitted: # type: ignore
raise ValueError(
"Transforms must be fitted before passing to create_train_val_datasets. "
"Fit the transforms on training data first using transforms.fit()."
)
logger.info(
f"Creating train/val datasets with {len(train_indices)} train and {len(val_indices)} val cells"
)
# Create datasets
train_dataset = self.create_subset(train_indices)
val_dataset = self.create_subset(val_indices)
# Apply pre-fitted transforms if provided
if transforms is not None:
train_dataset.transforms = transforms
val_dataset.transforms = transforms
logger.info(f"Train dataset: {len(train_dataset)} cells")
logger.info(f"Val dataset: {len(val_dataset)} cells")
return train_dataset, val_dataset
[docs] def create_train_val_datasets_by_slides(
self,
train_slides: List[str],
val_slides: List[str],
transforms: Optional[Union[TransformPipeline, Transform]] = None,
) -> Tuple["CellTypeDataset", "CellTypeDataset"]:
"""
Create train and validation datasets based on slide names.
This method is useful when you want to split the CellTypeDataset
based on the same slide-level split used for another dataset (e.g., MILDataset).
All cells from slides in train_slides go to training, and all cells from
slides in val_slides go to validation.
Args:
train_slides: List of slide names for training set
val_slides: List of slide names for validation set
transforms: Optional pre-fitted transforms to apply to both datasets.
These should already be fitted on training data before calling this method.
Returns:
Tuple of (train_dataset, val_dataset) with the provided transforms
"""
if not train_slides:
raise ValueError("Train slides list cannot be empty")
if not val_slides:
raise ValueError("Val slides list cannot be empty")
# Convert to sets for faster lookup
train_slides_set = set(train_slides)
val_slides_set = set(val_slides)
# Validate that slides exist in dataset
available_slides = set(self.slides)
invalid_train = [s for s in train_slides if s not in available_slides]
invalid_val = [s for s in val_slides if s not in available_slides]
if invalid_train:
logger.warning(
f"Some train slides not found in dataset: {invalid_train[:5]}..."
)
if invalid_val:
logger.warning(
f"Some val slides not found in dataset: {invalid_val[:5]}..."
)
# Find cell indices for each split based on slide membership
train_indices: List[int] = [idx for idx, (slide_name, _) in enumerate(self._global_indices) if slide_name in train_slides_set]
val_indices: List[int] = [idx for idx, (slide_name, _) in enumerate(self._global_indices) if slide_name in val_slides_set]
if not train_indices:
raise ValueError(
"No cells found for training slides. Check that slide names match."
)
if not val_indices:
raise ValueError(
"No cells found for validation slides. Check that slide names match."
)
logger.info(
f"Split by slides: {len(train_indices)} train cells from {len([s for s in train_slides if s in available_slides])} slides, "
f"{len(val_indices)} val cells from {len([s for s in val_slides if s in available_slides])} slides"
)
# Use the existing method to create the datasets
return self.create_train_val_datasets(
train_indices=train_indices,
val_indices=val_indices,
transforms=transforms,
)