# Source code for cellmil.interfaces.MILTrainerConfig

from pathlib import Path

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator

from .CellSegmenterConfig import ModelType
from .FeatureExtractorConfig import ExtractorType
from .GraphCreatorConfig import GraphCreatorType
from .MIL import MILType

# Module-level side effect: importing this module sets PyTorch's float32
# matmul precision to 'medium'.
# NOTE(review): this is a global torch setting, so it presumably affects every
# consumer in the process, not just MIL code — confirm that a config module is
# the intended place for it.
torch.set_float32_matmul_precision('medium')

class MILTrainerConfig(BaseModel):
    """Configuration for MIL prediction using CLAM or standard MIL models.

    Fields declared with ``Field(...)`` are required; the remaining fields
    fall back to the defaults shown below. The meaning of each field is given
    by its ``description``.

    Validation:
        ``model`` is additionally checked against ``MILType.values()``; a
        value outside that collection raises ``ValueError`` inside the
        validator (pydantic reports it as a validation error on the model).
    """

    # Pydantic v2 style configuration; replaces the deprecated nested
    # ``class Config`` while keeping the same setting.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # --- Dataset location and labels -------------------------------------
    root: Path = Field(..., description="Root directory for the dataset")
    folder: Path = Field(..., description="Path to the dataset directory")
    excel_path: Path = Field(
        ..., description="Path to the Excel file with target labels"
    )
    label: str = Field(..., description="Label for the dataset (e.g., 'dcr')")

    # --- Model and hardware ----------------------------------------------
    model: MILType = Field(..., description="Type of MIL model to use")
    gpu: int = Field(default=0, description="GPU index to use for prediction")

    # --- Upstream pipeline components ------------------------------------
    extractor: ExtractorType | list[ExtractorType] = Field(
        ..., description="Feature extractor to use for patch extraction"
    )
    segmentation_model: ModelType = Field(
        ..., description="Segmentation model used to extract cells"
    )
    graph_creator: GraphCreatorType = Field(
        ..., description="Graph creation method to use for GNN+MIL models"
    )
    ckpt_path: Path = Field(
        ...,
        description="Path to the checkpoint directory where the model will be saved",
    )

    # --- Optional feature processing -------------------------------------
    normalization: bool = Field(
        default=False,
        description="Whether to apply normalization to the input data",
    )
    correlation_filter: float = Field(
        default=0.0,
        description="Correlation filter threshold for feature selection",
    )
    cell_type: bool = Field(
        default=False,
        description="Whether to include cell type information in the features",
    )
    n_bins: int = Field(
        default=4,
        description="Number of time bins for discrete-time survival models",
    )

    @field_validator('model')
    @classmethod
    def validate_model(cls, v: str) -> str:
        """Reject ``model`` values that are not listed in ``MILType.values()``.

        Args:
            v: The value supplied for the ``model`` field.

        Returns:
            ``v`` unchanged when it is a supported MIL model type.

        Raises:
            ValueError: If ``v`` is not contained in ``MILType.values()``.
        """
        if v not in MILType.values():
            raise ValueError(f"Unsupported MIL model type: {v}. Supported models are: {MILType.values()}")
        return v