Source code for cellmil.datamodels.datasets.mil_dataset

import pandas as pd
from typing import Any, Literal
from pathlib import Path
from .cell_mil_dataset import CellMILDataset
from .patch_mil_dataset import PatchMILDataset
from cellmil.interfaces.FeatureExtractorConfig import (
    ExtractorType,
    FeatureExtractionType,
)


def MILDataset(
    root: Path,
    label: str | tuple[str, str],
    folder: Path,
    data: pd.DataFrame,
    extractor: ExtractorType | list[ExtractorType],
    split: Literal["train", "val", "test", "all"] = "all",
    **kwargs: Any,
) -> CellMILDataset | PatchMILDataset:
    """Build the MIL dataset implementation that matches ``extractor``.

    A ``PatchMILDataset`` is returned only when ``extractor`` is a single
    extractor (not a list) that is contained in
    ``FeatureExtractionType.Embedding``. Every other case, including any
    list of extractors, yields a ``CellMILDataset``.

    Args:
        root: Forwarded unchanged to the selected dataset constructor.
        label: Forwarded unchanged to the selected dataset constructor.
        folder: Forwarded unchanged to the selected dataset constructor.
        data: Forwarded unchanged to the selected dataset constructor.
        extractor: Single extractor or list of extractors; decides which
            dataset class is instantiated and is forwarded to it.
        split: One of ``"train"``, ``"val"``, ``"test"`` or ``"all"``;
            forwarded unchanged.
        **kwargs: Optional settings looked up by name. Keys that are not
            listed here are ignored.

            Read for both dataset kinds:
                ``force_reload`` (default ``False``),
                ``label_transforms`` (default ``None``).

            Read only for ``CellMILDataset``:
                ``graph_creator`` (default ``None``),
                ``segmentation_model`` (default ``None``),
                ``cell_type`` (default ``False``),
                ``cell_types_to_keep`` (default ``None``),
                ``roi_folder`` (default ``None``),
                ``transforms`` (default ``None``),
                ``return_cell_types`` (default ``True``).

    Returns:
        The constructed ``CellMILDataset`` or ``PatchMILDataset``.
    """
    option = kwargs.get

    # Arguments both dataset constructors receive in exactly the same form.
    shared: dict[str, Any] = {
        "root": root,
        "label": label,
        "folder": folder,
        "data": data,
        "extractor": extractor,
        "split": split,
        "force_reload": option("force_reload", False),
        "label_transforms": option("label_transforms"),
    }

    # The list check comes first so the membership test only ever runs on a
    # single extractor.
    is_patch_level = (
        not isinstance(extractor, list)
        and extractor in FeatureExtractionType.Embedding
    )
    if is_patch_level:
        # Safety measure: a caller-supplied ``transforms`` is deliberately
        # not forwarded to the patch dataset.
        return PatchMILDataset(transforms=None, **shared)

    return CellMILDataset(
        graph_creator=option("graph_creator"),
        segmentation_model=option("segmentation_model"),
        cell_type=option("cell_type", False),
        cell_types_to_keep=option("cell_types_to_keep"),
        roi_folder=option("roi_folder"),
        transforms=option("transforms"),
        return_cell_types=option("return_cell_types", True),
        **shared,
    )