cellmil.interfaces.AttentionExplainerConfig

Classes

Aggregation(value)

Attention aggregation methods for GraphMIL.

AttentionExplainerConfig(*, output_path[, ...])

Configuration for the attention explainability method.

Normalization(value)

Attention normalization methods.

VisualizationMode(value)

Visualization output modes.

class cellmil.interfaces.AttentionExplainerConfig.Aggregation(value)

Bases: str, Enum

Attention aggregation methods for GraphMIL.

pooling_only = 'pooling_only'
gnn_only = 'gnn_only'
gnn_layer = 'gnn_layer'
random_walk = 'random_walk'
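Because Aggregation mixes in str, each member is also a plain string. A value read from a YAML file or the command line can therefore be converted with Aggregation(value) and compared directly against string literals. The sketch below mirrors the four documented members in a local class (it does not import cellmil) to show this standard str/Enum behavior.

```python
from enum import Enum


class Aggregation(str, Enum):
    """Local mirror of the documented Aggregation members (illustration only)."""

    pooling_only = "pooling_only"
    gnn_only = "gnn_only"
    gnn_layer = "gnn_layer"
    random_walk = "random_walk"


# Look a member up by its string value, e.g. one parsed from a config file.
agg = Aggregation("gnn_layer")
assert agg is Aggregation.gnn_layer

# The str mixin makes members compare equal to their raw values.
assert agg == "gnn_layer"

# Unknown strings are rejected with ValueError.
try:
    Aggregation("mean")
except ValueError as exc:
    print(f"rejected: {exc}")

print([m.value for m in Aggregation])
# → ['pooling_only', 'gnn_only', 'gnn_layer', 'random_walk']
```

When writing a setting back out, use .value: str() on a str/Enum mixin returns the qualified name (for example Aggregation.gnn_layer), not the raw value.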
_generate_next_value_(name, start, count, last_values)

Generate the next value when not given (used by enum.auto()).

name: the name of the member
start: the initial start value or None
count: the number of existing members
last_values: the list of values already assigned
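_generate_next_value_ is the standard-library Enum hook that enum.auto() calls to compute a member's value. The documented enums assign every value explicitly, so the hook does not come into play there. The hypothetical AutoNamed enum below shows how overriding the hook makes auto() reuse each member's name as its string value.

```python
from enum import Enum, auto


class AutoNamed(str, Enum):
    """Hypothetical enum (not part of cellmil): auto() yields the member name."""

    # The hook must be defined before any member that uses auto().
    def _generate_next_value_(name, start, count, last_values):
        # name: member name; start: initial start value (unused here);
        # count: members defined so far; last_values: values already assigned.
        return name

    min_max = auto()
    z_score = auto()


print(AutoNamed.z_score.value)  # → z_score
```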

class cellmil.interfaces.AttentionExplainerConfig.Normalization(value)

Bases: str, Enum

Attention normalization methods.

min_max = 'min_max'
z_score = 'z_score'
robust = 'robust'
softmax = 'softmax'
sigmoid = 'sigmoid'
none = 'none'
_generate_next_value_(name, start, count, last_values)

Generate the next value when not given (used by enum.auto()).

name: the name of the member
start: the initial start value or None
count: the number of existing members
last_values: the list of values already assigned
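The Normalization members name common ways to rescale raw attention scores. The sketch below applies textbook definitions for each name to a plain list of floats. The formulas are assumptions (for instance, "robust" is taken to mean median-centering with interquartile-range scaling), and cellmil's own implementation may differ in such details.

```python
import math
import statistics


def normalize(scores: list[float], method: str = "min_max") -> list[float]:
    """Rescale attention scores using one of the documented method names.

    Textbook formulas, assumed for illustration; not cellmil's code.
    """
    if not scores or method == "none":
        return list(scores)
    if method == "min_max":
        lo, hi = min(scores), max(scores)
        span = hi - lo
        return [(s - lo) / span if span else 0.0 for s in scores]
    if method == "z_score":
        mean = statistics.fmean(scores)
        std = statistics.pstdev(scores)
        return [(s - mean) / std if std else 0.0 for s in scores]
    if method == "robust":
        # Center on the median, scale by the interquartile range.
        med = statistics.median(scores)
        q1, _, q3 = statistics.quantiles(scores, n=4)
        iqr = q3 - q1
        return [(s - med) / iqr if iqr else 0.0 for s in scores]
    if method == "softmax":
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        return [e / total for e in exps]
    if method == "sigmoid":
        # Numerically stable logistic function.
        return [
            1.0 / (1.0 + math.exp(-s)) if s >= 0 else math.exp(s) / (1.0 + math.exp(s))
            for s in scores
        ]
    raise ValueError(f"unknown normalization: {method!r}")


attention = [0.2, 1.0, 3.0, 0.8]
print(normalize(attention, "min_max"))
print(normalize(attention, "softmax"))
```

min_max and sigmoid map scores into [0, 1], and softmax additionally makes them sum to 1. z_score and robust are centered on zero and can be negative.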

class cellmil.interfaces.AttentionExplainerConfig.VisualizationMode(value)

Bases: str, Enum

Visualization output modes.

geojson = 'geojson'
graph = 'graph'
all = 'all'
_generate_next_value_(name, start, count, last_values)

Generate the next value when not given (used by enum.auto()).

name: the name of the member
start: the initial start value or None
count: the number of existing members
last_values: the list of values already assigned
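A consumer of VisualizationMode typically has to turn the selected mode into the concrete outputs to produce. The sketch below assumes that all stands for every other mode (here geojson and graph); requested_outputs is a hypothetical helper, not a cellmil function, so check the explainer itself for the authoritative behavior.

```python
from enum import Enum


class VisualizationMode(str, Enum):
    """Local mirror of the documented members (illustration only)."""

    geojson = "geojson"
    graph = "graph"
    all = "all"


def requested_outputs(mode: VisualizationMode) -> set[VisualizationMode]:
    """Hypothetical helper: expand a mode into the concrete outputs to write.

    Assumes 'all' means every other documented mode.
    """
    if mode is VisualizationMode.all:
        return {m for m in VisualizationMode if m is not VisualizationMode.all}
    return {mode}


print(sorted(m.value for m in requested_outputs(VisualizationMode("all"))))
# → ['geojson', 'graph']
```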

class cellmil.interfaces.AttentionExplainerConfig.AttentionExplainerConfig(*, output_path: Path, attention_aggregation: Aggregation = Aggregation.pooling_only, gnn_layer_index: Optional[int] = None, class_index: Optional[int] = None, attention_head: Optional[int] = None, visualization_mode: VisualizationMode = VisualizationMode.geojson, color_scheme: tuple[list[int], list[int]] = ([35, 92, 236], [255, 0, 0]), normalize_attention: bool = True, normalization: Normalization = Normalization.min_max, visualize_cell_types: bool = False)

Bases: BaseModel

Configuration for the attention explainability method.

output_path: Path
attention_aggregation: Aggregation
gnn_layer_index: Optional[int]
class_index: Optional[int]
attention_head: Optional[int]
visualization_mode: VisualizationMode
color_scheme: tuple[list[int], list[int]]
normalize_attention: bool
normalization: Normalization
visualize_cell_types: bool
classmethod validate_gnn_layer_index(v: Optional[int]) → Optional[int]
model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
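The default color_scheme of AttentionExplainerConfig, ([35, 92, 236], [255, 0, 0]), pairs a blue and a red RGB triple. A plausible reading, assumed here rather than confirmed by the source, is that these are the low- and high-attention endpoints of a color ramp. The hypothetical attention_to_rgb helper below linearly blends between them for a score already normalized to [0, 1].

```python
def attention_to_rgb(
    score: float,
    color_scheme: tuple[list[int], list[int]] = ([35, 92, 236], [255, 0, 0]),
) -> tuple[int, int, int]:
    """Hypothetical: blend linearly from the low color to the high color."""
    t = min(max(score, 0.0), 1.0)  # clamp the score to [0, 1]
    low, high = color_scheme
    r, g, b = (round(lo + t * (hi - lo)) for lo, hi in zip(low, high))
    return (r, g, b)


print(attention_to_rgb(0.5))  # → (145, 46, 118)
```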