cellmil.models.mil.standard
Classes
| MIL_fc | Multiple Instance Learning model with fully connected layers for binary classification. |
| MIL_fc_mc | Multiple Instance Learning model with fully connected layers for multi-class classification. |
| LitStandard | |
- class cellmil.models.mil.standard.MIL_fc(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)
Bases: torch.nn.Module
Multiple Instance Learning model with fully connected layers for binary classification.
This model processes a bag of instances, applies a feature extractor (FC layers), and performs binary classification. It selects the top k instances based on their probability scores for the positive class.
- Parameters:
size_arg – Size configuration for the network architecture: either the preset ‘small’ (currently the only named option) or a list[int], as the signature allows.
dropout – Dropout rate for regularization.
n_classes – Number of classes (must be 2 for binary classification).
top_k – Number of top instances to select based on positive class probability.
embed_dim – Dimension of the input feature embeddings.
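The top-k selection described above can be illustrated without PyTorch. The sketch below assumes the network has already produced one pair of logits per instance, with column 1 holding the positive class; the helper name `select_top_k_positive` is hypothetical and is not part of cellmil.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_top_k_positive(instance_logits: list[list[float]], top_k: int = 1) -> list[int]:
    """Return the indices of the top_k instances ranked by positive-class probability.

    Hypothetical helper: assumes instance_logits has shape [n_instances, 2]
    and that column 1 is the positive class.
    """
    positive_probs = [softmax(row)[1] for row in instance_logits]
    ranked = sorted(range(len(positive_probs)), key=lambda i: positive_probs[i], reverse=True)
    return ranked[:top_k]


logits = [[2.0, 0.5], [0.1, 1.9], [0.0, 0.3]]
print(select_top_k_positive(logits, top_k=2))  # [1, 2]
```

Ranking on the softmax probability rather than the raw positive logit matters when the negative logit varies across instances, which is why the sketch normalizes each row first.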
- __init__(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(h: Tensor, return_features: bool = False) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]
Forward pass of the MIL_fc model.
- Parameters:
h – Input tensor of shape [n_instances, embed_dim] containing instance embeddings.
return_features – If True, returns the feature representations of top instances.
- Returns:
A tuple containing:
top_instance: Logits of the top instance(s).
Y_prob: Softmax probabilities for the top instance(s).
Y_hat: Predicted class labels for the top instance(s).
y_probs: Softmax probabilities for all instances.
results_dict: Additional results; may contain feature representations if return_features is True.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]
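How the five returned values relate to each other can be sketched in plain Python, under the assumption that top_instance, Y_prob and Y_hat are all derived from the logits of the selected instance(s). The function name `mil_fc_outputs` is hypothetical, and the empty dict stands in for results_dict when return_features is False.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mil_fc_outputs(instance_logits, top_k=1):
    """Hypothetical sketch of the MIL_fc return tuple for a [n_instances, 2] logit list."""
    # y_probs: softmax probabilities for every instance in the bag.
    y_probs = [softmax(row) for row in instance_logits]
    # Rank instances by positive-class (column 1) probability and keep top_k.
    order = sorted(range(len(y_probs)), key=lambda i: y_probs[i][1], reverse=True)
    top_ids = order[:top_k]
    top_instance = [instance_logits[i] for i in top_ids]  # logits of the top instance(s)
    Y_prob = [y_probs[i] for i in top_ids]  # softmax probabilities of the top instance(s)
    Y_hat = [max(range(len(p)), key=p.__getitem__) for p in Y_prob]  # argmax class per top instance
    results_dict = {}  # assumption: features would be added here when return_features=True
    return top_instance, Y_prob, Y_hat, y_probs, results_dict


top_instance, Y_prob, Y_hat, y_probs, results = mil_fc_outputs([[2.0, 0.5], [0.1, 1.9]])
print(top_instance, Y_hat)  # [[0.1, 1.9]] [1]
```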
- class cellmil.models.mil.standard.MIL_fc_mc(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)
Bases: torch.nn.Module
Multiple Instance Learning model with fully connected layers for multi-class classification.
This model processes a bag of instances, applies a feature extractor (FC layers), and performs multi-class classification. It selects the top instance based on the highest probability across all classes.
- Parameters:
size_arg – Size configuration for the network architecture: either the preset ‘small’ (currently the only named option) or a list[int], as the signature allows.
dropout – Dropout rate for regularization.
n_classes – Number of classes (must be > 2 for multi-class classification, so the default of 2 has to be overridden).
top_k – Number of top instances to select (must be 1 for this implementation).
embed_dim – Dimension of the input feature embeddings.
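One way to realize "highest probability across all classes" is to take the argmax over the flattened [n_instances, n_classes] probability matrix and recover its row and column with divmod: the row is the top instance and the column is its predicted class. This is a plain-Python sketch with a hypothetical helper name (`select_top_instance_mc`), not the cellmil implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_top_instance_mc(instance_logits):
    """Return (instance_index, class_index) of the single highest probability.

    Hypothetical helper: instance_logits has shape [n_instances, n_classes].
    """
    probs = [softmax(row) for row in instance_logits]
    n_classes = len(probs[0])
    flat = [p for row in probs for p in row]  # row-major flattening
    best = max(range(len(flat)), key=flat.__getitem__)  # first index of the maximum
    return divmod(best, n_classes)  # (row, column)


logits = [[0.2, 0.1, 0.0], [0.0, 3.0, 0.5], [1.0, 1.2, 0.9]]
print(select_top_instance_mc(logits))  # (1, 1)
```

Because only one maximum is taken over the whole matrix, this selection naturally yields a single instance, which is consistent with the requirement that top_k be 1.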
- __init__(size_arg: Union[Literal['small'], list[int]] = 'small', dropout: float = 0.0, n_classes: int = 2, top_k: int = 1, embed_dim: int = 1024)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(h: Tensor, return_features: bool = False) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
Forward pass of the MIL_fc_mc model for multi-class classification.
- Parameters:
h – Input tensor of shape [n_instances, embed_dim] containing instance embeddings.
return_features – If True, returns the feature representations of top instances.
- Returns:
A tuple containing:
top_instance: Logits of the top instance.
Y_prob: Softmax probabilities for the top instance.
Y_hat: Predicted class label for the top instance.
y_probs: Softmax probabilities for all instances.
results_dict: Additional results; may contain feature representations if return_features is True.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
- class cellmil.models.mil.standard.LitStandard(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, n_classes: int = 2)
Bases: LitGeneral
- forward(x: Tensor)
Same as torch.nn.Module.forward().
- Parameters:
x – Input tensor.
- Returns:
The model's output.
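The default loss is CrossEntropyLoss(). For a single sample, cross-entropy on a logit vector z with target class y equals -log(softmax(z)[y]), which is logsumexp(z) - z[y]. The plain-Python sketch below reproduces that value; which logits LitStandard actually passes to the loss (for example, the top-instance logits returned by MIL_fc) is an assumption not stated in this reference, and `cross_entropy` is a hypothetical helper name.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """Cross-entropy for one sample: logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# Uniform logits over 2 classes give a loss of ln(2), about 0.693.
print(cross_entropy([0.0, 0.0], target=1))
```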