cellmil.models.mil.graphmil

Classes

LitGraphMIL(gnn, pooling_classifier, ...[, ...])

Lightning module for Graph-based Multiple Instance Learning.

LitSurvGraphMIL(gnn, pooling_classifier, ...)

Lightning module for Graph-based Multiple Instance Learning with Survival Analysis.

class cellmil.models.mil.graphmil.GNN(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: Module, ABC

Abstract base class for Graph Neural Networks.

This class defines the interface for GNN models but cannot be instantiated directly. Subclasses must implement the layer creation logic in their __init__ method.

__init__(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

get_hyperparameters() dict[str, Any][source]

Get all hyperparameters for this GNN.

_get_specific_hyperparameters() dict[str, Any][source]

Override in subclasses to add specific hyperparameters.

forward(data: Data) Data[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Extract attention weights from GNN layers.

Parameters:

data – Input graph data

Returns:

Dictionary with attention weights from each layer (empty for non-attention GNNs)

class cellmil.models.mil.graphmil.GAT(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, heads: int = 1, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, heads: int = 1, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Extract attention weights from each GAT layer.

Parameters:

data – Input graph data

Returns:

{'gnn_attention_layer_{i}': weights}

Return type:

Dictionary with attention weights

class cellmil.models.mil.graphmil.GATv2(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Extract attention weights from each GATv2 layer.

Parameters:

data – Input graph data

Returns:

{'gnn_attention_layer_{i}': weights}

Return type:

Dictionary with attention weights

class cellmil.models.mil.graphmil.EGNN(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data: Data) Data[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class cellmil.models.mil.graphmil.SAGE(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

class cellmil.models.mil.graphmil.CHIMERA(input_dim: int, dropout: float, heads: int = 1, residual: bool = True, n_layers: int = 3, hidden_dim: list[int] = [128, 256, 512], **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, dropout: float, heads: int = 1, residual: bool = True, n_layers: int = 3, hidden_dim: list[int] = [128, 256, 512], **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data: Data) Data[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class cellmil.models.mil.graphmil.GlobalPooling_Classifier(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, **kwargs: Any)[source]

Bases: Module, ABC

Abstract base class for global pooling classifiers in GraphMIL.

This class defines the interface for pooling classifiers that aggregate node features into graph-level predictions. Each subclass handles its own specific arguments and validates them appropriately.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

get_hyperparameters() dict[str, Any][source]

Get all hyperparameters for this pooling classifier.

_get_specific_hyperparameters() dict[str, Any][source]

Override in subclasses to add specific hyperparameters.

abstract forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) tuple[torch.Tensor, dict[str, torch.Tensor]][source]

Forward pass for the pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor for global pooling (if applicable)

  • **kwargs – Additional arguments specific to each pooling classifier

Returns:

  • logits: Raw model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple containing

get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) torch.Tensor | None[source]

Extract attention weights from the pooling classifier.

Parameters:
  • x – Node features tensor

  • batch – Batch assignment tensor (if applicable)

Returns:

Attention weights tensor or None if not available
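
The core aggregation these pooling classifiers share can be illustrated with a plain-Python sketch (a generic illustration of attention pooling, not the cellmil implementation; the function name is hypothetical): per-node attention scores are softmax-normalized and used to take a weighted sum of node features into a single bag-level embedding.

```python
import math

def attention_pool(features, scores):
    """Softmax-normalize per-node scores, then take the attention-weighted
    sum of node feature vectors to form one bag-level embedding."""
    m = max(scores)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(features[0])
    bag = [sum(w * f[d] for w, f in zip(weights, features)) for d in range(dim)]
    return bag, weights

# Three nodes with 2-d features; the second node gets most of the attention.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bag, weights = attention_pool(features, scores=[0.0, 2.0, 0.0])
```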

class cellmil.models.mil.graphmil.CLAM(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, gate: bool = True, k_sample: int = 8, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, clam_type: str = 'SB', temperature: float = 1.0, **kwargs: Any)[source]

Bases: GlobalPooling_Classifier

CLAM pooling classifier with attention-based multiple instance learning.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, gate: bool = True, k_sample: int = 8, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, clam_type: str = 'SB', temperature: float = 1.0, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, batch: Optional[Tensor] = None, label: Optional[Tensor] = None, instance_eval: bool = False, **kwargs: Any) tuple[torch.Tensor, dict[str, torch.Tensor]][source]

Forward pass for CLAM pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (not used by CLAM, must be single graph)

  • label – Ground truth labels for instance evaluation

  • instance_eval – Whether to perform instance-level evaluation

  • **kwargs – Additional arguments (ignored by CLAM)

Returns:

  • logits: Raw model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple containing

get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) torch.Tensor | None[source]

Extract attention weights from CLAM.

Parameters:
  • x – Node features tensor

  • batch – Batch assignment tensor (should be single graph)

Returns:

Attention weights tensor of shape [1, num_nodes] or [num_classes, num_nodes]
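
The k_sample and instance_loss_fn arguments follow the published CLAM recipe of treating the k highest-attention instances as pseudo-positives and the k lowest as pseudo-negatives for the instance loss. A sketch of that selection (a background illustration with plain lists, not the cellmil code):

```python
def sample_instances(attention, k):
    """Return indices of the k highest-attention (pseudo-positive) and
    k lowest-attention (pseudo-negative) instances."""
    order = sorted(range(len(attention)), key=attention.__getitem__)
    return order[-k:], order[:k]

pos, neg = sample_instances([0.1, 0.9, 0.3, 0.7, 0.05], k=2)
```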

class cellmil.models.mil.graphmil.Standard(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str = 'small', standard_type: str = 'fc', **kwargs: Any)[source]

Bases: GlobalPooling_Classifier

Standard MIL pooling classifier.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str = 'small', standard_type: str = 'fc', **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) tuple[torch.Tensor, dict[str, torch.Tensor]][source]

Forward pass for Standard MIL pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (not used by Standard, must be single graph)

  • **kwargs – Additional arguments (ignored by Standard)

Returns:

  • logits: Raw model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple containing

class cellmil.models.mil.graphmil.Attention(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], attention_branches: int = 1, temperature: float = 1.0, **kwargs: Any)[source]

Bases: GlobalPooling_Classifier

AttentionDeepMIL pooling classifier.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], attention_branches: int = 1, temperature: float = 1.0, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) tuple[torch.Tensor, dict[str, torch.Tensor]][source]

Forward pass for AttentionDeepMIL pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (not used by AttentionDeepMIL, must be single graph)

  • **kwargs – Additional arguments (ignored by AttentionDeepMIL)

Returns:

  • logits: Raw model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information and attention weights

Return type:

tuple containing

get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) torch.Tensor | None[source]

Extract attention weights from AttentionDeepMIL.

Parameters:
  • x – Node features tensor

  • batch – Batch assignment tensor (must be single graph)

Returns:

Attention weights tensor of shape [attention_branches, num_nodes]

class cellmil.models.mil.graphmil.Mean_MLP(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], **kwargs: Any)[source]

Bases: GlobalPooling_Classifier

Mean pooling followed by MLP classifier.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) tuple[torch.Tensor, dict[str, torch.Tensor]][source]

Forward pass for Mean_MLP pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor for global pooling

  • **kwargs – Additional arguments (ignored by Mean_MLP)

Returns:

  • logits: Raw model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple containing

class cellmil.models.mil.graphmil.LitGraphMIL(gnn: GNN, pooling_classifier: GlobalPooling_Classifier, optimizer_cls: type[torch.optim.optimizer.Optimizer], optimizer_kwargs: dict[str, Any], loss_fn: Module = CrossEntropyLoss(), scheduler_cls: Optional[type[torch.optim.lr_scheduler.LRScheduler]] = None, scheduler_kwargs: Optional[dict[str, Any]] = None, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25, subsampling: float = 1.0, **kwargs: Any)[source]

Bases: LightningModule

Lightning module for Graph-based Multiple Instance Learning.

This model is designed to work with the torch_geometric DataLoader and requires:
  • batch_size=1 for MIL tasks

  • Data objects with batch.y containing graph labels

  • GNNMILDataset from cellmil.datamodels.datasets.gnn_mil_dataset

Example usage:

from torch_geometric.loader import DataLoader
from cellmil.datamodels.datasets.gnn_mil_dataset import GNNMILDataset

dataset = GNNMILDataset(...)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

model = LitGraphMIL(gnn=..., pooling_classifier=..., ...)
trainer.fit(model, dataloader)

__init__(gnn: GNN, pooling_classifier: GlobalPooling_Classifier, optimizer_cls: type[torch.optim.optimizer.Optimizer], optimizer_kwargs: dict[str, Any], loss_fn: Module = CrossEntropyLoss(), scheduler_cls: Optional[type[torch.optim.lr_scheduler.LRScheduler]] = None, scheduler_kwargs: Optional[dict[str, Any]] = None, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25, subsampling: float = 1.0, **kwargs: Any)[source]

classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) Self[source]

Load a model from a checkpoint.

Parameters:
  • checkpoint_path (str | Path | IO[bytes]) – Path to the checkpoint file or a file-like object.

  • map_location (optional) – Device mapping for loading the model.

  • hparams_file (optional) – Path to a YAML file containing hyperparameters.

  • strict (optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by the model’s state_dict function.

  • **kwargs – Additional keyword arguments passed to the model’s constructor

Returns:

An instance of LitGraphMIL.

_subsample_graph(data: Data, subsampling: float) Data[source]

Sample subgraph using NeighborLoader to preserve local graph structure.

This method uses k-hop neighborhood sampling which preserves the local connectivity around seed nodes, providing better context for GNN message passing compared to random node sampling.

Note: This method is designed to work on CPU before GPU transfer when called from on_before_batch_transfer hook, saving GPU memory and transfer bandwidth.

Parameters:
  • data (Data) – Input graph data (typically on CPU).

  • subsampling (float) – Fraction of nodes to keep (0 < subsampling < 1.0) or absolute number of nodes (subsampling >= 1.0).

Returns:

Sampled subgraph with k-hop neighborhoods around seed nodes.

Return type:

Data

Note

This method requires either ‘pyg-lib’ or ‘torch-sparse’ to be installed. Install with: pip install pyg-lib torch-sparse -f https://data.pyg.org/whl/torch-{TORCH_VERSION}+{CUDA_VERSION}.html
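
The two interpretations of subsampling described above can be resolved as follows (a sketch mirroring the documented semantics; the helper name is hypothetical):

```python
def resolve_seed_count(num_nodes: int, subsampling: float) -> int:
    """Interpret `subsampling` as a fraction of nodes when < 1.0,
    or as an absolute node count when >= 1.0 (capped at num_nodes)."""
    if subsampling < 1.0:
        return max(1, int(num_nodes * subsampling))
    return min(int(subsampling), num_nodes)

# 10% of a 50,000-node graph vs. a fixed budget of 2,000 seed nodes
frac = resolve_seed_count(50_000, 0.1)    # 5000
fixed = resolve_seed_count(50_000, 2000)  # 2000
```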

forward(data: Data, **kwargs: Any)[source]

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

on_before_batch_transfer(batch: Data, dataloader_idx: int) Data[source]

Hook called before batch is transferred to GPU. Performs subsampling on CPU to reduce memory usage and transfer overhead.

Parameters:
  • batch (Data) – Input graph data on CPU.

  • dataloader_idx (int) – Index of the dataloader.

Returns:

Potentially subsampled graph data (still on CPU).

Return type:

Data

training_step(batch: Data, batch_idx: int)[source]

Here you compute and return the training loss and any additional metrics, e.g. for the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch: Data, batch_idx: int)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

test_step(batch: Data, batch_idx: int)[source]

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

on_train_epoch_end() None[source]

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

on_validation_epoch_end()[source]

Called in the validation loop at the very end of the epoch.

on_test_epoch_end()[source]

Called in the test loop at the very end of the epoch.

_flatten_and_log_metrics(computed: dict[str, torch.Tensor], prefix: str) None[source]

Convert metric dictionary produced by torchmetrics into a flat dict of scalar values and log it with self.log_dict.

  • Vector/tensor metrics (e.g. per-class accuracy) are expanded into keys like {prefix}/class_{i}_acc.

  • Scalar tensors are converted to floats.

  • None values are converted to NaN to satisfy loggers that expect numeric scalars.
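
The flattening rules above can be sketched in plain Python, with lists standing in for torchmetrics tensors (a hypothetical illustration, not the actual method):

```python
import math

def flatten_metrics(computed, prefix):
    """Flatten a metrics dict: expand per-class vectors into indexed keys,
    cast scalars to float, and map None to NaN."""
    flat = {}
    for name, value in computed.items():
        if value is None:
            flat[f"{prefix}/{name}"] = math.nan
        elif isinstance(value, (list, tuple)):  # vector metric, e.g. per-class accuracy
            for i, v in enumerate(value):
                flat[f"{prefix}/class_{i}_{name}"] = float(v)
        else:
            flat[f"{prefix}/{name}"] = float(value)
    return flat

flat = flatten_metrics({"acc": [0.9, 0.7], "loss": 0.42, "auroc": None}, prefix="val")
```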

configure_optimizers()[source]

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning.
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Metrics can be made available to monitor by simply logging it using self.log('metric_to_track', metric_val) in your LightningModule.

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

predict_step(batch: Data, batch_idx: int)[source]

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multi-devices.

To prevent an OOM error, it is possible to use BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.

The BasePredictionWriter should be used while using a spawn based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won’t be returned.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

Predicted output (optional).

Example

class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

static calculate_error(y_hat: Tensor, y: Tensor)[source]

Classification error = 1 - accuracy.
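
As a plain-Python illustration (with lists standing in for tensors):

```python
def calculate_error(y_hat, y):
    """Classification error = 1 - accuracy over paired predictions/labels."""
    correct = sum(p == t for p, t in zip(y_hat, y))
    return 1.0 - correct / len(y)

err = calculate_error([0, 1, 1, 0], [0, 1, 0, 0])  # 3 of 4 correct -> error 0.25
```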

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Get attention weights from both GNN layers and pooling classifier.

This method delegates to the individual component classes for clean separation of concerns and better maintainability.

Parameters:

data (Data) – Input graph data.

Returns:

Dictionary containing attention weights:
  • GNN attention weights (if available): ‘gnn_attention_layer_{i}’

  • Pooling attention weights (if available): ‘pooling_attention’

Return type:

dict[str, torch.Tensor]
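
The merge this method performs can be sketched as follows (a hypothetical illustration, assuming the component methods documented above):

```python
def combine_attention(gnn_weights, pooling_weights):
    """Merge per-layer GNN attention with pooling attention; the pooling
    entry is added only when the classifier actually produces weights."""
    combined = dict(gnn_weights)  # e.g. {'gnn_attention_layer_0': ...}
    if pooling_weights is not None:
        combined["pooling_attention"] = pooling_weights
    return combined

out = combine_attention({"gnn_attention_layer_0": [0.2, 0.8]}, [0.5, 0.5])
```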

class cellmil.models.mil.graphmil.LitSurvGraphMIL(gnn: GNN, pooling_classifier: GlobalPooling_Classifier, optimizer_cls: type[torch.optim.optimizer.Optimizer], optimizer_kwargs: dict[str, Any], loss_fn: Module = NegativeLogLikelihoodSurvLoss(), scheduler_cls: Optional[type[torch.optim.lr_scheduler.LRScheduler]] = None, scheduler_kwargs: Optional[dict[str, Any]] = None, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25, subsampling: float = 1.0, **kwargs: Any)[source]

Bases: LitGraphMIL

Lightning module for Graph-based Multiple Instance Learning with Survival Analysis.

This class extends LitGraphMIL to support survival analysis tasks using discrete-time hazard models. It uses survival-specific loss functions and metrics like C-index and Brier score.

Parameters:
  • gnn (GNN) – Graph Neural Network model for node feature extraction.

  • pooling_classifier (GlobalPooling_Classifier) – Pooling and classification module.

  • optimizer_cls (type[Optimizer]) – Optimizer class.

  • optimizer_kwargs (dict[str, Any]) – Optimizer keyword arguments.

  • loss_fn (nn.Module, optional) – Loss function. Defaults to NegativeLogLikelihoodSurvLoss.

  • scheduler_cls (type[LRScheduler] | None, optional) – Learning rate scheduler class.

  • scheduler_kwargs (dict[str, Any] | None, optional) – Scheduler keyword arguments.

  • use_aem (bool, optional) – Whether to use AEM regularization. Defaults to False.

  • aem_weight_initial (float, optional) – Initial weight for AEM loss. Defaults to 0.0001.

  • aem_weight_final (float, optional) – Final weight for AEM loss. Defaults to 0.0.

  • aem_annealing_epochs (int, optional) – Number of epochs to anneal AEM weight. Defaults to 25.

  • subsampling (float, optional) – Fraction of nodes to keep during training. Defaults to 1.0.

  • **kwargs – Additional keyword arguments.
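
The three aem_* parameters suggest a weight that is annealed from aem_weight_initial to aem_weight_final over the first aem_annealing_epochs epochs. A linear schedule is one plausible reading (an assumption; the actual schedule may differ):

```python
def aem_weight(epoch, initial=1e-4, final=0.0, annealing_epochs=25):
    """Linearly interpolate the AEM loss weight from `initial` to `final`
    over the first `annealing_epochs` epochs, then hold at `final`."""
    t = min(epoch / annealing_epochs, 1.0)
    return initial + (final - initial) * t

w_start = aem_weight(0)   # initial weight
w_end = aem_weight(25)    # fully annealed
```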

__init__(gnn: GNN, pooling_classifier: GlobalPooling_Classifier, optimizer_cls: type[torch.optim.optimizer.Optimizer], optimizer_kwargs: dict[str, Any], loss_fn: Module = NegativeLogLikelihoodSurvLoss(), scheduler_cls: Optional[type[torch.optim.lr_scheduler.LRScheduler]] = None, scheduler_kwargs: Optional[dict[str, Any]] = None, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25, subsampling: float = 1.0, **kwargs: Any)[source]

_setup_metrics()[source]

Setup C-index and Brier score metrics for survival analysis.

predict_step(batch: Data, batch_idx: int)[source]

Prediction step returns logits for discrete-time hazard intervals.
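
In the standard discrete-time hazard formulation (stated here as background, not necessarily the exact cellmil post-processing), each interval logit is mapped to a hazard via a sigmoid, and survival is the running product of (1 - hazard):

```python
import math

def survival_from_logits(logits):
    """Map per-interval logits to hazards h_t = sigmoid(logit_t) and
    survival probabilities S(t) = prod_{k<=t} (1 - h_k)."""
    hazards = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    surv, s = [], 1.0
    for h in hazards:
        s *= 1.0 - h
        surv.append(s)
    return hazards, surv

hazards, surv = survival_from_logits([-2.0, 0.0, 2.0])
```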

class cellmil.models.mil.graphmil.SmallWorld(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, top_k_ratio: float = 0.005, heads: int = 1, layer_type: Literal['GCN', 'GAT', 'SAGE'] = 'SAGE', **kwargs: Any)[source]

Bases: GNN

SmallWorld GNN that creates additional connections between high-attention nodes.

After each layer, SAGPooling identifies the most important nodes (the top fraction by attention score, controlled by top_k_ratio) and additional edges are created between them, forming a small-world-like topology in which important nodes are more densely connected.
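
The densification step can be sketched in plain Python (a conceptual illustration of the technique, not the SAGPooling-based implementation): select the top-k nodes by attention score and connect every pair of them in both directions.

```python
import math
from itertools import combinations

def small_world_edges(scores, top_k_ratio=0.005):
    """Select the top ceil(ratio * n) nodes by attention score and
    return the new edges that fully connect them (both directions)."""
    k = max(2, math.ceil(len(scores) * top_k_ratio))
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    edges = []
    for u, v in combinations(sorted(top), 2):
        edges.append((u, v))
        edges.append((v, u))
    return edges

# 4 nodes; with a ratio that selects 2 nodes, the two highest-scoring
# nodes (indices 1 and 3) get connected in both directions.
edges = small_world_edges([0.1, 0.9, 0.2, 0.8], top_k_ratio=0.5)
```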

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, top_k_ratio: float = 0.005, heads: int = 1, layer_type: Literal['GCN', 'GAT', 'SAGE'] = 'SAGE', **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_add_small_world_edges(x: Tensor, edge_index: Tensor, batch: Tensor, attention_pool: SAGPooling) Tensor[source]

Add edges between top-k nodes based on attention scores.

Parameters:
  • x – Node features

  • edge_index – Current edge index

  • batch – Batch assignment for nodes

  • attention_pool – SAGPooling layer to compute attention scores

Returns:

Updated edge index with additional small-world connections

forward(data: Data) Data[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class cellmil.models.mil.graphmil.SGFormer(input_dim: int, hidden_dim: int, n_layers: int, dropout: float = 0.25, heads: int = 1, alpha: float = 0.5, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float = 0.25, heads: int = 1, alpha: float = 0.5, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data: Data) Data[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Modules