cellmil.models.mil.standard

Classes

LitStandard(model, optimizer[, loss, ...])

MIL_fc([size_arg, dropout, n_classes, ...])

Multiple Instance Learning model with fully connected layers for binary classification.

MIL_fc_mc([size_arg, dropout, n_classes, ...])

Multiple Instance Learning model with fully connected layers for multi-class classification.

class cellmil.models.mil.standard.MIL_fc(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)[source]

Bases: Module

Multiple Instance Learning model with fully connected layers for binary classification.

This model processes a bag of instances, applies a feature extractor (FC layers), and performs binary classification. It selects the top k instances based on their probability scores for the positive class.

Parameters:
  • size_arg – Size configuration for the network architecture (‘small’ is the only option currently).

  • dropout – Dropout rate for regularization.

  • n_classes – Number of classes (must be 2 for binary classification).

  • top_k – Number of top instances to select based on positive class probability.

  • embed_dim – Dimension of the input feature embeddings.

__init__(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(h: Tensor, return_features: bool = False) tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]][source]

Forward pass of the MIL_fc model.

Parameters:
  • h – Input tensor of shape [n_instances, embed_dim] containing instance embeddings.

  • return_features – If True, returns the feature representations of top instances.

Returns:

  • top_instance: Logits of the top instance(s).

  • Y_prob: Softmax probabilities for the top instance(s).

  • Y_hat: Predicted class labels for the top instance(s).

  • y_probs: Softmax probabilities for all instances.

  • results_dict: Additional results; contains the feature representations of the top instances if return_features is True.

Return type:

A tuple (top_instance, Y_prob, Y_hat, y_probs, results_dict).
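The top-k selection that MIL_fc.forward performs can be sketched in plain NumPy. This is a minimal illustration of the mechanism only, not the cellmil implementation: the classifier weights W, b and the helper name mil_fc_select are assumptions, and the real model first passes h through its FC feature extractor.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mil_fc_select(h, W, b, top_k=1):
    """Sketch of MIL_fc's binary top-k selection.

    h: [n_instances, embed_dim] instance embeddings.
    W, b: hypothetical classifier weights mapping embeddings -> 2 logits.
    """
    logits = h @ W + b                      # [n_instances, 2]
    y_probs = softmax(logits, axis=1)       # per-instance class probabilities
    # rank instances by probability of the positive class (index 1)
    top_ids = np.argsort(-y_probs[:, 1])[:top_k]
    top_instance = logits[top_ids]          # logits of the top-k instances
    Y_prob = softmax(top_instance, axis=1)  # their softmax probabilities
    Y_hat = Y_prob.argmax(axis=1)           # their predicted labels
    return top_instance, Y_prob, Y_hat, y_probs

rng = np.random.default_rng(0)
h = rng.standard_normal((8, 4))
W = rng.standard_normal((4, 2))
b = np.zeros(2)
top_instance, Y_prob, Y_hat, y_probs = mil_fc_select(h, W, b, top_k=2)
print(top_instance.shape, Y_prob.shape, y_probs.shape)  # (2, 2) (2, 2) (8, 2)
```

The bag-level prediction is thus taken from the instances the classifier itself scores highest for the positive class.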

class cellmil.models.mil.standard.MIL_fc_mc(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)[source]

Bases: Module

Multiple Instance Learning model with fully connected layers for multi-class classification.

This model processes a bag of instances, applies a feature extractor (FC layers), and performs multi-class classification. It selects the top instance based on the highest probability across all classes.

Parameters:
  • size_arg – Size configuration for the network architecture (‘small’ is the only option currently).

  • dropout – Dropout rate for regularization.

  • n_classes – Number of classes (must be > 2 for multi-class classification).

  • top_k – Number of top instances to select (must be 1 for this implementation).

  • embed_dim – Dimension of the input feature embeddings.

__init__(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(h: Tensor, return_features: bool = False) tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]][source]

Forward pass of the MIL_fc_mc model for multi-class classification.

Parameters:
  • h – Input tensor of shape [n_instances, embed_dim] containing instance embeddings.

  • return_features – If True, returns the feature representations of top instances.

Returns:

  • top_instance: Logits of the top instance.

  • Y_prob: Softmax probabilities for the top instance.

  • Y_hat: Predicted class label for the top instance.

  • y_probs: Softmax probabilities for all instances.

  • results_dict: Additional results; contains the feature representations of the top instance if return_features is True.

Return type:

A tuple (top_instance, Y_prob, Y_hat, y_probs, results_dict).
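The multi-class variant differs from MIL_fc in how it picks the representative instance: it takes the single instance holding the highest probability anywhere in the instance-by-class probability matrix. A minimal NumPy sketch of that selection step (the weights W, b and the helper name mil_fc_mc_select are illustrative assumptions, not the cellmil API):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mil_fc_mc_select(h, W, b):
    """Sketch of MIL_fc_mc's top-instance selection (top_k fixed at 1).

    h: [n_instances, embed_dim] instance embeddings.
    W, b: hypothetical classifier weights mapping embeddings -> n_classes logits.
    """
    logits = h @ W + b                 # [n_instances, n_classes]
    y_probs = softmax(logits, axis=1)  # [n_instances, n_classes]
    # row index of the single largest entry in the probability matrix
    idx = np.unravel_index(np.argmax(y_probs), y_probs.shape)[0]
    top_instance = logits[idx:idx + 1]            # [1, n_classes]
    Y_prob = softmax(top_instance, axis=1)
    Y_hat = Y_prob.argmax(axis=1)
    return top_instance, Y_prob, Y_hat, y_probs

rng = np.random.default_rng(1)
h = rng.standard_normal((6, 4))
W = rng.standard_normal((4, 3))
b = np.zeros(3)
top_instance, Y_prob, Y_hat, y_probs = mil_fc_mc_select(h, W, b)
```

Because the chosen row is the one containing the global maximum, the returned Y_prob always contains the largest probability found across all instances and classes.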

class cellmil.models.mil.standard.LitStandard(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, n_classes: int = 2)[source]

Bases: LitGeneral

forward(x: Tensor)[source]

Forward pass delegating to the wrapped model.

Parameters:
  • x – Input tensor passed to the wrapped model’s forward method.

Returns:

The wrapped model’s output.
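LitStandard wraps a backbone model (such as MIL_fc) and forwards inputs straight through to it. A dependency-free toy sketch of that delegation pattern (the class name LitWrapper is illustrative, not the cellmil implementation):

```python
class LitWrapper:
    """Toy stand-in for a Lightning-style module that wraps a model."""

    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        # LitStandard-style forward: delegate directly to the wrapped model
        return self.model(x)

double = LitWrapper(lambda x: [v * 2 for v in x])
out = double([1, 2, 3])
print(out)  # [2, 4, 6]
```

In the real class the wrapped model is an nn.Module, and training behavior (optimizer, loss, scheduler) comes from the LitGeneral base rather than from forward itself.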