# Source code for cellmil.graph.debug_visualizer
"""
Interactive debug visualizer for graph creation hyperparameter tuning.
This module provides an interactive Dash application for visualizing and tuning
graph creation hyperparameters in real-time with a sampled subset of cells.
"""
import torch
import numpy as np
from typing import Any, Dict, List, cast
import plotly.graph_objs as go # type: ignore
from dash import Dash, dcc, html, Input, Output, State
import dash_bootstrap_components as dbc # type: ignore
from cellmil.utils import logger
from cellmil.interfaces.GraphCreatorConfig import GraphCreatorType
from .creator import (
KNNEdgeCreator,
RadiusEdgeCreator,
DelaunayEdgeCreator,
DilateEdgeCreator,
SimilarityEdgeCreator,
)
class GraphDebugVisualizer:
    """Interactive visualizer for debugging graph creation with different hyperparameters.

    Builds a Dash app whose sidebar exposes the hyperparameters of one graph
    creation method (``method``) and whose main panel plots the resulting cell
    graph. The graph is (re)built when the "Update Graph" button is pressed
    (and once on page load).
    """

    # Initial value of every slider / number-input pair, keyed by the shared
    # component-id prefix ("<prefix>-slider" / "<prefix>-input"). The same value
    # is used as the fallback when the number input is cleared.
    _DEFAULTS: Dict[str, float] = {
        "k": 5,
        "radius": 100,
        "limit-radius": 150,
        "dilation": 10,
        "similarity-threshold": 0.5,
        "distance-sigma": 200,
        "alpha": 0.5,
        "feature-sigma": 1.0,
    }

    # Upper bound on the number of edges drawn in one figure; larger graphs are
    # truncated (with a logged warning) to keep the browser responsive.
    MAX_PLOTTED_EDGES: int = 50_000

    def __init__(
        self,
        cells: List[Dict[str, Any]],
        method: GraphCreatorType,
        device: str = "cpu",
    ):
        """
        Initialize the debug visualizer.

        Args:
            cells: List of cell dictionaries; each must have a "centroid" entry
                (other entries such as "features" are passed through to the
                edge creators).
            method: Graph creation method to use
            device: Computing device ('cpu' or 'cuda:X')
        """
        self.cells = cells
        self.method = method
        self.device = device
        self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

        # Stack all centroids into one float tensor on the target device.
        centroids = np.array([cell["centroid"] for cell in cells])
        self.positions = cast(torch.Tensor, torch.from_numpy(centroids).float().to(device))  # type: ignore

        logger.info(
            f"Debug visualizer initialized with {len(cells)} cells using {method} method"
        )
        self._setup_layout()
        self._setup_callbacks()

    def _setup_layout(self):
        """Setup the Dash application layout with controls and visualization."""
        # Create control panel based on method
        controls = self._create_controls()
        self.app.layout = dbc.Container(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            html.H1(
                                f"Graph Creator Debug Mode - {self.method.upper()}",
                                className="text-center mb-4",
                            )
                        )
                    ]
                ),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                html.H4("Hyperparameters", className="mb-3"),
                                html.Div(controls),
                                html.Hr(),
                                html.Div(id="graph-stats", className="mt-3"),
                            ],
                            width=3,
                        ),
                        dbc.Col(
                            [
                                dcc.Loading(
                                    id="loading",
                                    type="default",
                                    children=dcc.Graph(
                                        id="graph-plot",
                                        style={"height": "85vh"},
                                        config={"displayModeBar": True},
                                    ),
                                )
                            ],
                            width=9,
                        ),
                    ]
                ),
            ],
            fluid=True,
            style={"padding": "20px"},
        )

    def _slider_input_row(
        self,
        base_id: str,
        *,
        min_value: float,
        max_value: float,
        step: float,
        marks: Dict[Any, str],
        row_class: str = "mb-3",
    ) -> Any:
        """Build a slider paired with a numeric input for one hyperparameter.

        Args:
            base_id: Id prefix; the components get ids ``f"{base_id}-slider"``
                and ``f"{base_id}-input"`` and start at ``self._DEFAULTS[base_id]``.
            min_value: Lower bound shared by the slider and the input.
            max_value: Upper bound shared by the slider and the input.
            step: Step size shared by the slider and the input.
            marks: Slider tick labels.
            row_class: CSS class of the surrounding row (spacing).

        Returns:
            A ``dbc.Row`` holding the slider (8/12 width) and the input (4/12 width).
        """
        value = self._DEFAULTS[base_id]
        return dbc.Row(
            [
                dbc.Col(
                    dcc.Slider(
                        id=f"{base_id}-slider",
                        min=min_value,
                        max=max_value,
                        step=step,
                        value=value,
                        marks=marks,
                        tooltip={"placement": "bottom", "always_visible": True},
                    ),
                    width=8,
                ),
                dbc.Col(
                    dcc.Input(
                        id=f"{base_id}-input",
                        type="number",
                        min=min_value,
                        max=max_value,
                        step=step,
                        value=value,
                        style={"width": "100%"},
                    ),
                    width=4,
                ),
            ],
            className=row_class,
        )

    def _create_controls(self) -> List[Any]:
        """Create control widgets based on the graph creation method.

        Returns:
            Sidebar components (labels, slider/input rows, dropdowns, hints),
            always ending with the "Update Graph" button (id ``update-button``).
        """
        controls: List[Any] = []
        if self.method == GraphCreatorType.knn:
            controls.extend(
                [
                    html.Label("K (Number of Neighbors):", className="fw-bold"),
                    self._slider_input_row(
                        "k",
                        min_value=1,
                        max_value=50,
                        step=1,
                        marks={i: str(i) for i in [1, 5, 10, 20, 30, 40, 50]},
                    ),
                ]
            )
        elif self.method == GraphCreatorType.radius:
            controls.extend(
                [
                    html.Label("Radius (pixels):", className="fw-bold"),
                    self._slider_input_row(
                        "radius",
                        min_value=10,
                        max_value=500,
                        step=10,
                        marks={i: str(i) for i in [10, 100, 200, 300, 400, 500]},
                    ),
                ]
            )
        elif self.method == GraphCreatorType.delaunay_radius:
            controls.extend(
                [
                    html.Label("Limit Radius (pixels):", className="fw-bold"),
                    self._slider_input_row(
                        "limit-radius",
                        min_value=10,
                        max_value=500,
                        step=10,
                        marks={i: str(i) for i in [10, 100, 200, 300, 400, 500]},
                    ),
                ]
            )
        elif self.method == GraphCreatorType.dilate:
            controls.extend(
                [
                    html.Label("Dilation (pixels):", className="fw-bold"),
                    self._slider_input_row(
                        "dilation",
                        min_value=1,
                        max_value=50,
                        step=1,
                        marks={i: str(i) for i in [1, 5, 10, 20, 30, 40, 50]},
                    ),
                ]
            )
        elif self.method == GraphCreatorType.similarity:
            # Warn up front if the cells carry no feature vectors.
            has_features = "features" in self.cells[0] if self.cells else False
            if not has_features:
                controls.append(
                    dbc.Alert(
                        "Warning: Cells don't have morphological features loaded. "
                        "Similarity method requires features.",
                        color="warning",
                    )
                )
            controls.extend(
                [
                    html.Label("Similarity Threshold / K:", className="fw-bold"),
                    self._slider_input_row(
                        "similarity-threshold",
                        min_value=0.1,
                        max_value=20,
                        step=0.01,
                        marks={
                            0.1: "0.1",
                            0.5: "0.5",
                            1.0: "1",
                            5: "5",
                            10: "10",
                            15: "15",
                            20: "20",
                        },
                        row_class="mb-2",
                    ),
                    html.Small(
                        "Values < 1: threshold mode, Values ≥ 1: KNN mode",
                        className="text-muted",
                    ),
                    html.Br(),
                    html.Br(),
                    html.Label("Distance Sigma (for spatial distance):", className="fw-bold"),
                    self._slider_input_row(
                        "distance-sigma",
                        min_value=10,
                        max_value=2000,
                        step=1,
                        marks={i: str(i) for i in [20, 100, 200, 300, 400, 500, 1000, 1500, 2000]},
                    ),
                    html.Label("Alpha (similarity vs distance):", className="fw-bold"),
                    self._slider_input_row(
                        "alpha",
                        min_value=0.0,
                        max_value=1.0,
                        step=0.01,
                        marks={i / 10: str(i / 10) for i in range(0, 11, 2)},
                        row_class="mb-2",
                    ),
                    html.Small(
                        "0: distance only, 1: similarity only", className="text-muted"
                    ),
                    html.Br(),
                    html.Br(),
                    html.Label("Combination Method:", className="fw-bold"),
                    dcc.RadioItems(
                        id="combination-method-radio",
                        options=[
                            {"label": " Additive", "value": "additive"},
                            {"label": " Multiplicative", "value": "multiplicative"},
                        ],
                        value="additive",
                        className="mb-2",
                        labelStyle={"display": "block", "margin": "5px 0"},
                    ),
                    html.Small(
                        "Additive: α·sim + (1-α)·dist | Multiplicative: sim^α · dist^(1-α)",
                        className="text-muted",
                    ),
                    html.Br(),
                    html.Br(),
                    html.Label("Distance Similarity Metric:", className="fw-bold"),
                    dcc.Dropdown(
                        id="distance-metric-dropdown",
                        options=[
                            {"label": "Gaussian (exp(-d²/(2σ²)))", "value": "gaussian"},
                            {"label": "Laplacian (exp(-d/σ))", "value": "laplacian"},
                            {"label": "Inverse (1/(1+d))", "value": "inverse"},
                            {"label": "Inverse Square (1/(1+d²))", "value": "inverse_square"},
                        ],
                        value="gaussian",
                        className="mb-2",
                        clearable=False,
                    ),
                    html.Br(),
                    html.Label("Feature Similarity Metric:", className="fw-bold"),
                    dcc.Dropdown(
                        id="feature-metric-dropdown",
                        options=[
                            {"label": "Cosine Similarity", "value": "cosine"},
                            {"label": "Pearson Correlation", "value": "correlation"},
                            {"label": "Euclidean (exp(-d))", "value": "euclidean"},
                            {"label": "Gaussian (exp(-d²/(2σ_f²)))", "value": "gaussian"},
                        ],
                        value="cosine",
                        className="mb-2",
                        clearable=False,
                    ),
                    html.Br(),
                    html.Label("Feature Sigma (for Gaussian):", className="fw-bold"),
                    self._slider_input_row(
                        "feature-sigma",
                        min_value=0.1,
                        max_value=10,
                        step=0.1,
                        marks={i: str(i) for i in [0.1, 1, 2, 5, 10]},  # type: ignore
                    ),
                    html.Small(
                        "Only used when Feature Metric is Gaussian", className="text-muted"
                    ),
                    html.Br(),
                    html.Br(),
                ]
            )
        # Add update button
        controls.append(
            dbc.Button(
                "Update Graph",
                id="update-button",
                color="primary",
                className="w-100",
                size="lg",
            )
        )
        return controls

    def _setup_callbacks(self):
        """Setup Dash callbacks for interactive updates.

        Registers the slider/input sync callbacks and the main callback that
        rebuilds the figure and statistics when "Update Graph" is clicked
        (``prevent_initial_call=False`` also runs it once on page load).
        """
        # Setup sync callbacks between sliders and inputs
        self._setup_sync_callbacks()

        # Parameters are read from the sliders (kept in sync with the inputs) and
        # passed positionally to _create_edge_creator / _create_stats, so the
        # order of `states` must match the indexing used there.
        inputs = [Input("update-button", "n_clicks")]
        states: List[State] = []
        if self.method == GraphCreatorType.knn:
            states.append(State("k-slider", "value"))
        elif self.method == GraphCreatorType.radius:
            states.append(State("radius-slider", "value"))
        elif self.method == GraphCreatorType.delaunay_radius:
            states.append(State("limit-radius-slider", "value"))
        elif self.method == GraphCreatorType.dilate:
            states.append(State("dilation-slider", "value"))
        elif self.method == GraphCreatorType.similarity:
            states.extend(
                [
                    State("similarity-threshold-slider", "value"),
                    State("distance-sigma-slider", "value"),
                    State("alpha-slider", "value"),
                    State("combination-method-radio", "value"),
                    State("distance-metric-dropdown", "value"),
                    State("feature-metric-dropdown", "value"),
                    State("feature-sigma-slider", "value"),
                ]
            )

        @self.app.callback(  # type: ignore
            [Output("graph-plot", "figure"), Output("graph-stats", "children")],
            inputs,
            states,
            prevent_initial_call=False,
        )
        def update_graph(n_clicks: int | None, *params: Any):  # type: ignore
            """Update the graph visualization based on parameters."""
            try:
                # Create edge creator with current parameters
                edge_creator = self._create_edge_creator(*params)
                # Create edges
                edge_indices, edge_features = edge_creator.create_edges(
                    self.positions, self.cells
                )
                # Create visualization
                fig = self._create_figure(edge_indices)
                # Create statistics
                stats = self._create_stats(edge_indices, edge_features, *params)
                return fig, stats
            except Exception as e:
                # UI boundary: report the failure in the page instead of
                # crashing the callback, but keep the traceback in the console.
                logger.error(f"Error updating graph: {e}")
                import traceback

                traceback.print_exc()
                # Return empty figure and error message
                return go.Figure(), html.Div(
                    [
                        dbc.Alert(
                            f"Error creating graph: {str(e)}", color="danger"
                        )
                    ]
                )

    def _register_slider_input_sync(
        self, slider_id: str, input_id: str, default: float
    ) -> None:
        """Keep a slider and its numeric input showing the same value.

        Dash does not allow a dependency cycle spread over several callbacks
        (slider -> input -> slider), so both directions are handled by ONE
        callback that inspects which component fired.

        Args:
            slider_id: Component id of the ``dcc.Slider``.
            input_id: Component id of the ``dcc.Input``.
            default: Slider value used when the input is empty/invalid (None).
        """
        from dash import callback_context, no_update

        @self.app.callback(  # type: ignore
            [Output(slider_id, "value"), Output(input_id, "value")],
            [Input(slider_id, "value"), Input(input_id, "value")],
            prevent_initial_call=True,
        )
        def _sync(slider_value: Any, input_value: Any):  # type: ignore
            triggered = callback_context.triggered
            # prop_id looks like "<component-id>.value"
            trigger_id = triggered[0]["prop_id"].rsplit(".", 1)[0] if triggered else ""
            if trigger_id == input_id:
                # Leave the input untouched so the user can keep typing; an
                # empty/invalid number reports None, so fall back to the default.
                return (default if input_value is None else input_value), no_update
            return no_update, slider_value

    def _setup_sync_callbacks(self):
        """Setup callbacks to sync sliders and input fields for ``self.method``."""
        if self.method == GraphCreatorType.knn:
            base_ids = ["k"]
        elif self.method == GraphCreatorType.radius:
            base_ids = ["radius"]
        elif self.method == GraphCreatorType.delaunay_radius:
            base_ids = ["limit-radius"]
        elif self.method == GraphCreatorType.dilate:
            base_ids = ["dilation"]
        elif self.method == GraphCreatorType.similarity:
            base_ids = ["similarity-threshold", "distance-sigma", "alpha", "feature-sigma"]
        else:
            base_ids = []
        for base_id in base_ids:
            self._register_slider_input_sync(
                f"{base_id}-slider", f"{base_id}-input", self._DEFAULTS[base_id]
            )

    def _create_edge_creator(self, *params: Any):
        """Create an edge creator instance with the given parameters.

        Args:
            *params: Control values in the order of the ``State`` list built in
                ``_setup_callbacks`` for ``self.method``.

        Raises:
            ValueError: If ``self.method`` is not supported by this visualizer.
        """
        if self.method == GraphCreatorType.knn:
            k = params[0]
            return KNNEdgeCreator(self.device, k=k)
        elif self.method == GraphCreatorType.radius:
            radius = params[0]
            return RadiusEdgeCreator(self.device, radius=radius)
        elif self.method == GraphCreatorType.delaunay_radius:
            limit_radius = params[0]
            return DelaunayEdgeCreator(
                self.device, limit_radius=limit_radius
            )
        elif self.method == GraphCreatorType.dilate:
            dilation = params[0]
            return DilateEdgeCreator(
                self.device, dilation=dilation
            )
        elif self.method == GraphCreatorType.similarity:
            similarity_threshold = params[0]
            distance_sigma = params[1]
            alpha = params[2]
            combination_method = params[3] if len(params) > 3 else "additive"
            distance_metric = params[4] if len(params) > 4 else "gaussian"
            feature_metric = params[5] if len(params) > 5 else "cosine"
            feature_sigma = params[6] if len(params) > 6 else 1.0
            return SimilarityEdgeCreator(
                self.device,
                similarity_threshold=similarity_threshold,
                distance_sigma=distance_sigma,
                alpha=alpha,
                combination_method=combination_method,
                distance_metric=distance_metric,
                feature_metric=feature_metric,
                feature_sigma=feature_sigma,
            )
        else:
            raise ValueError(f"Unsupported method: {self.method}")

    @staticmethod
    def _segment_coords(
        starts: np.ndarray[Any, Any], ends: np.ndarray[Any, Any]
    ) -> List[float | None]:
        """Interleave segment endpoints as ``[s0, e0, None, s1, e1, None, ...]``.

        Plotly draws a single line trace and breaks it at each ``None``, which
        renders many independent segments in one trace.
        """
        coords = np.empty((len(starts), 3), dtype=object)
        coords[:, 0] = starts
        coords[:, 1] = ends
        coords[:, 2] = None
        return cast(List[float | None], coords.ravel().tolist())

    def _create_figure(self, edge_indices: torch.Tensor) -> go.Figure:
        """Create a Plotly figure for the graph visualization.

        Args:
            edge_indices: Edge index tensor indexed as ``edge_indices[0, i]``
                (source) / ``edge_indices[1, i]`` (target) into ``self.positions``.

        Returns:
            Figure with edges as line segments and cells as markers. At most
            ``MAX_PLOTTED_EDGES`` edges are drawn.
        """
        # .cpu() is a no-op for CPU tensors and works for any accelerator.
        pos_np = cast(np.ndarray[Any, Any], self.positions.detach().cpu().numpy())  # type: ignore
        edges_np = cast(np.ndarray[Any, Any], edge_indices.detach().cpu().numpy())  # type: ignore

        n_edges = int(edges_np.shape[1]) if edges_np.size else 0
        n_shown = min(n_edges, self.MAX_PLOTTED_EDGES)
        if n_shown < n_edges:
            logger.warning(
                f"Plotting only the first {n_shown:,} of {n_edges:,} edges"
            )

        # Create edge traces (vectorized gather of both endpoints).
        if n_shown:
            src = edges_np[0, :n_shown].astype(np.intp)
            dst = edges_np[1, :n_shown].astype(np.intp)
            edge_x = self._segment_coords(pos_np[src, 0], pos_np[dst, 0])
            edge_y = self._segment_coords(pos_np[src, 1], pos_np[dst, 1])
        else:
            edge_x, edge_y = [], []

        edge_trace = go.Scatter(
            x=edge_x,
            y=edge_y,
            line=dict(width=1, color="rgba(70, 130, 180, 0.4)"),
            hoverinfo="none",
            mode="lines",
            showlegend=False,
        )

        # Create node trace
        node_x = pos_np[:, 0].tolist()
        node_y = pos_np[:, 1].tolist()
        node_trace = go.Scatter(
            x=node_x,
            y=node_y,
            mode="markers",
            hoverinfo="text",
            text=[f"Cell {i}" for i in range(len(pos_np))],
            marker=dict(
                size=5,
                color="rgba(255, 100, 100, 0.8)",
                line=dict(width=0.5, color="rgba(255, 255, 255, 0.5)"),
            ),
            showlegend=False,
        )

        # Create figure; the y axis is reversed and locked 1:1 to x so the
        # layout matches image coordinates.
        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title=dict(
                    text=f"Cell Graph - {self.method.upper()}",
                    font=dict(size=18, color="black"),
                    x=0.5,
                ),
                showlegend=False,
                hovermode="closest",
                margin=dict(b=20, l=20, r=20, t=50),
                xaxis=dict(
                    showgrid=True,
                    gridcolor="rgba(200, 200, 200, 0.2)",
                    zeroline=False,
                    showticklabels=True,
                    title="X coordinate",
                ),
                yaxis=dict(
                    showgrid=True,
                    gridcolor="rgba(200, 200, 200, 0.2)",
                    zeroline=False,
                    showticklabels=True,
                    title="Y coordinate",
                    autorange="reversed",
                    scaleanchor="x",
                    scaleratio=1,
                ),
                plot_bgcolor="white",
                paper_bgcolor="white",
            ),
        )
        return fig

    def _create_stats(
        self, edge_indices: torch.Tensor, edge_features: torch.Tensor, *params: Any
    ) -> html.Div:
        """Create statistics display for the current graph.

        Args:
            edge_indices: Edge index tensor; ``shape[1]`` is the edge count.
            edge_features: Per-edge features; column 0 is reported as the edge
                distance in pixels (assumption - TODO confirm against the creators).
            *params: Control values, forwarded to ``_format_params``.
        """
        n_nodes = len(self.cells)
        n_edges = edge_indices.shape[1]
        # NOTE(review): 2*E/N assumes each undirected edge is listed once in
        # edge_indices; if both directions are stored this doubles the degree.
        avg_degree = (2 * n_edges / n_nodes) if n_nodes > 0 else 0

        # Calculate edge distance statistics
        if edge_features.shape[0] > 0:
            distances = cast(np.ndarray[Any, Any], edge_features[:, 0].detach().cpu().numpy())  # type: ignore
            avg_distance = float(np.mean(distances))
            std_distance = float(np.std(distances))
            min_distance = float(np.min(distances))
            max_distance = float(np.max(distances))
        else:
            avg_distance = std_distance = min_distance = max_distance = 0.0

        # Format parameters
        param_text = self._format_params(*params)

        stats_content: List[Any] = [
            html.H5("Graph Statistics", className="mb-3"),
            html.P([html.Strong("Nodes: "), f"{n_nodes:,}"]),
            html.P([html.Strong("Edges: "), f"{n_edges:,}"]),
            html.P([html.Strong("Avg Degree: "), f"{avg_degree:.2f}"]),
            html.Hr(),
            html.H6("Edge Distances", className="mb-2"),
            html.P([html.Strong("Mean: "), f"{avg_distance:.2f} px"]),
            html.P([html.Strong("Std: "), f"{std_distance:.2f} px"]),
            html.P([html.Strong("Min: "), f"{min_distance:.2f} px"]),
            html.P([html.Strong("Max: "), f"{max_distance:.2f} px"]),
            html.Hr(),
            html.H6("Parameters", className="mb-2"),
            html.Pre(param_text, style={"fontSize": "12px"}),
        ]
        return html.Div(stats_content)

    def _format_params(self, *params: Any) -> str:
        """Format parameters for display, one "Name: value" per line."""
        lines: List[str] = []
        if self.method == GraphCreatorType.knn:
            lines.append(f"K: {params[0]}")
        elif self.method == GraphCreatorType.radius:
            lines.append(f"Radius: {params[0]} px")
        elif self.method == GraphCreatorType.delaunay_radius:
            lines.append(f"Limit Radius: {params[0]} px")
        elif self.method == GraphCreatorType.dilate:
            lines.append(f"Dilation: {params[0]} px")
        elif self.method == GraphCreatorType.similarity:
            combination_method = params[3] if len(params) > 3 else "additive"
            distance_metric = params[4] if len(params) > 4 else "gaussian"
            feature_metric = params[5] if len(params) > 5 else "cosine"
            feature_sigma = params[6] if len(params) > 6 else 1.0
            param_lines = [
                f"Threshold/K: {params[0]}",  # type: ignore
                f"Distance Sigma: {params[1]}",  # type: ignore
                f"Alpha: {params[2]}",  # type: ignore
                f"Combination: {combination_method}",
                f"Distance Metric: {distance_metric}",
                f"Feature Metric: {feature_metric}",
            ]
            # Feature sigma only matters for the gaussian feature metric
            if feature_metric == "gaussian":
                param_lines.append(f"Feature Sigma: {feature_sigma}")
            lines.extend(param_lines)
        return "\n".join(lines)

    def run(self, host: str = "127.0.0.1", port: int = 8050, debug: bool = True):
        """
        Run the Dash application.

        Args:
            host: Interface address to bind the server to
            port: Port to run the application on
            debug: Enable debug mode for Dash
        """
        print(f"Starting Graph Debug Visualizer at http://{host}:{port}")
        self.app.run(host=host, port=port, debug=debug)  # type: ignore