# -*- coding: utf-8 -*-
# Utility functions
#
# This module provides utility functions for the CellMIL framework, with some functions adapted from the CellViT project.
#
# References:
# CellViT: Vision Transformers for precise cell segmentation and classification
# Fabian Hörst et al., Medical Image Analysis, 2024
# DOI: https://doi.org/10.1016/j.media.2024.103143
import importlib
import logging
import sys
import os
import torch
import types
import matplotlib.pyplot as plt
from datetime import timedelta
from torch_geometric.data import Data # type: ignore
from timeit import default_timer as timer
from typing import Dict, List, Optional, Tuple, Union, Any, cast
from types import ModuleType
import plotly.graph_objects as go # type: ignore
import numpy as np
from cellmil.utils import logger
def to_float(x: torch.Tensor) -> torch.Tensor:
    """Cast a tensor to ``torch.float32``.

    Args:
        x (torch.Tensor): Tensor of any dtype.

    Returns:
        torch.Tensor: ``x`` converted to float32; values are not rescaled.
    """
    return x.to(torch.float32)
def to_float_normalized(x: torch.Tensor) -> torch.Tensor:
    """Cast a tensor to ``torch.float32`` and divide it by 255.

    Inputs in the range [0, 255] therefore end up in [0, 1]; the values are
    not clamped, so anything outside that range stays outside [0, 1].

    Args:
        x (torch.Tensor): Tensor of any dtype.

    Returns:
        torch.Tensor: Float32 tensor holding ``x / 255.0``.
    """
    as_float = x.to(torch.float32)
    return as_float / 255.0
# Helper timing functions
def start_timer() -> float:
    """Return the current reading of ``timeit.default_timer``.

    The value is only meaningful relative to a later reading of the same
    clock (see ``end_timer``); it is a performance-counter value, not a
    wall-clock timestamp.

    Returns:
        float: Current timer reading in seconds.
    """
    started_at: float = timer()
    return started_at
def end_timer(start_time: float, timed_event: str = "Time usage") -> None:
    """Log the time elapsed since ``start_time`` at INFO level.

    Args:
        start_time (float): Timer reading taken when the activity started,
            as returned by ``start_timer``.
        timed_event (str, optional): Label describing the monitored activity;
            it prefixes the logged duration. Defaults to "Time usage".
    """
    elapsed = timedelta(seconds=timer() - start_time)
    logger.info(f"{timed_event}: {elapsed}")
[docs]def module_exists(
*names: Union[List[str], str], # type: ignore
error: str = "ignore",
warn_every_time: bool = False,
__INSTALLED_OPTIONAL_MODULES: Dict[str, bool] = {},
) -> Optional[Union[Tuple[types.ModuleType | None, ...], types.ModuleType]]:
"""Try to import optional dependencies.
Ref: https://stackoverflow.com/a/73838546/4900327
Args:
names (Union(List(str), str)): The module name(s) to import. Str or list of strings.
error (str, optional): What to do when a dependency is not found:
* raise : Raise an ImportError.
* warn: print a warning.
* ignore: If any module is not installed, return None, otherwise, return the module(s).
Defaults to "ignore".
warn_every_time (bool, optional): Whether to warn every time an import is tried. Only applies when error="warn".
Setting this to True will result in multiple warnings if you try to import the same library multiple times.
Defaults to False.
Raises:
ImportError: ImportError of Module
Returns:
Optional[ModuleType, Tuple[ModuleType...]]: The imported module(s), if all are found.
None is returned if any module is not found and `error!="raise"`.
"""
assert error in {"raise", "warn", "ignore"}
if isinstance(names, (list, tuple, set)): # type: ignore
names: List[str] = list(names) # type: ignore
else:
print(type(names))
print(names)
assert isinstance(names, str)
names: List[str] = [names]
modules: list[ModuleType | None] = []
for name in names:
try:
module = importlib.import_module(name)
modules.append(module)
__INSTALLED_OPTIONAL_MODULES[name] = True
except ImportError:
modules.append(None)
def error_msg(missing: Union[str, List[str]]):
if not isinstance(missing, (list, tuple)):
missing = [missing]
missing_str: str = " ".join([f'"{name}"' for name in missing])
dep_str = "dependencies"
if len(missing) == 1:
dep_str = "dependency"
msg = f"Missing optional {dep_str} {missing_str}. Use pip or conda to install."
return msg
missing_modules: List[str] = [
name for name, module in zip(names, modules) if module is None
]
if len(missing_modules) > 0:
if error == "raise":
raise ImportError(error_msg(missing_modules))
if error == "warn":
for name in missing_modules:
# Ensures warning is printed only once
if warn_every_time is True or name not in __INSTALLED_OPTIONAL_MODULES:
logger.warning(f"Warning: {error_msg(name)}")
__INSTALLED_OPTIONAL_MODULES[name] = False
return None
if len(modules) == 1:
return modules[0]
return tuple(modules)
def close_logger(logger: logging.Logger) -> None:
    """Close a logger safely.

    Every handler attached to ``logger`` is detached and closed, after which
    ``logging.shutdown()`` is called.

    Args:
        logger (logging.Logger): Logger to close.
    """
    # Iterate over a copy because removeHandler mutates logger.handlers.
    for attached in list(logger.handlers):
        logger.removeHandler(attached)
        attached.close()
    logger.handlers.clear()
    logging.shutdown()
def get_size_of_dict(d: dict[str, Any]) -> int:
    """Return a shallow size estimate of a dictionary in bytes.

    The result is ``sys.getsizeof`` of the dict itself plus ``sys.getsizeof``
    of every key and every value. Objects referenced from inside the values
    (e.g. the contents of nested containers) are not followed.

    Args:
        d (dict[str, Any]): Dictionary to measure.

    Returns:
        int: Summed size in bytes.
    """
    entries_size = sum(
        sys.getsizeof(key) + sys.getsizeof(value) for key, value in d.items()
    )
    return sys.getsizeof(d) + entries_size
def unflatten_dict(d: dict[str, Any], sep: str = ".") -> dict[str, Any]:
    """Turn a flattened dictionary into a nested one.

    Each key is split on ``sep``; all parts but the last name the chain of
    nested dictionaries, and the last part is the key under which the value
    is stored.

    Args:
        d (dict): Flat dictionary whose keys encode the nesting path.
        sep (str, optional): Separator used in the flattened keys.
            Defaults to '.'.

    Returns:
        dict: Nested dictionary.
    """
    nested: dict[str, Any] = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        # Walk down from the root, creating intermediate levels on demand.
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested
def get_cpu_count() -> int:
    """Return the CPU count reported by ``os.cpu_count()``, or 1 as a fallback.

    The fallback is used when ``os.cpu_count()`` returns ``None`` or raises;
    in the latter case the error is logged.

    Returns:
        int: CPU count, at least 1.
    """
    try:
        detected = os.cpu_count()
    except Exception as e:
        logger.error(f"Error getting CPU count: {e}")
        return 1
    return detected if detected else 1
def create_color_gradient(
    bins: int, initial_color: list[int], final_color: list[int]
) -> list[list[int]]:
    """Create a color gradient from initial_color to final_color.

    Colors are linearly interpolated per channel and truncated to ``int``.
    The first entry equals ``initial_color`` and, when ``bins >= 2``, the last
    entry equals ``final_color``.

    Args:
        bins (int): Number of bins in the gradient. ``bins <= 0`` yields an
            empty list; ``bins == 1`` yields only the initial color.
        initial_color (list[int]): RGB values for the initial color.
        final_color (list[int]): RGB values for the final color.

    Raises:
        ValueError: If either color does not have exactly three components.

    Returns:
        list[list[int]]: List of RGB colors representing the gradient.
    """
    if len(initial_color) != 3 or len(final_color) != 3:
        raise ValueError(
            "Initial and final colors must be RGB values (lists of length 3)."
        )

    # With a single bin there is no span to interpolate over; clamping the
    # divisor to 1 keeps the ratio at 0 instead of dividing by zero.
    steps = max(bins - 1, 1)
    return [
        [
            int(start + (i / steps) * (end - start))
            for start, end in zip(initial_color, final_color)
        ]
        for i in range(bins)
    ]
def plot_vector(vector: torch.Tensor, title: str = "Vector Plot") -> None:
    """Show a line plot of a 1D tensor with matplotlib.

    Values are plotted against their index with a circular marker at each
    point, and the figure is displayed with ``plt.show()``.

    Args:
        vector (torch.Tensor): 1D tensor to plot.
        title (str, optional): Title of the plot. Defaults to "Vector Plot".

    Raises:
        ValueError: If ``vector`` is not 1-dimensional.
    """
    if vector.ndim != 1:
        raise ValueError("Input tensor must be 1-dimensional.")

    values = vector.cpu().numpy()

    plt.figure(figsize=(10, 4))  # type: ignore
    plt.plot(values, marker="o")  # type: ignore
    plt.title(title)  # type: ignore
    plt.xlabel("Index")  # type: ignore
    plt.ylabel("Value")  # type: ignore
    plt.grid(True)  # type: ignore
    plt.show()  # type: ignore
def plot_graph(data: Data, title: str = "Graph Visualization") -> None:
    """Render a PyTorch Geometric graph with Plotly and save it as HTML.

    Nodes are placed at the first two coordinates of ``data.pos`` (1D
    positions are laid out on the line ``y = 0``), or evenly on the unit
    circle when no positions are available. Each node is colored by the
    number of edge endpoints that reference it, and its hover text shows that
    count plus up to five feature values from ``data.x``. The figure is
    written to ``graph_{title}.html`` in the current working directory.

    Args:
        data (Data): PyTorch Geometric Data object containing graph structure.
            A missing ``edge_index`` is treated as a graph without edges.
        title (str, optional): Title of the plot. It is also inserted verbatim
            into the output file name, so it must be usable as a file name.
            Defaults to "Graph Visualization".
    """
    # --- Node coordinates -------------------------------------------------
    if hasattr(data, "pos") and data.pos is not None:
        pos = cast(np.ndarray[Any, Any], data.pos.detach().cpu().numpy())  # type: ignore
        x = pos[:, 0]
        # Only the first two coordinates are drawn; extra dimensions are ignored.
        y = pos[:, 1] if pos.shape[1] >= 2 else np.zeros_like(x)
    else:
        # No positions available: fall back to a simple circular layout.
        angles = cast(
            np.ndarray[Any, Any],
            np.linspace(0, 2 * np.pi, data.num_nodes, endpoint=False),  # type: ignore
        )
        x = np.cos(angles)
        y = np.sin(angles)
    num_plotted = len(x)

    # --- Edges ------------------------------------------------------------
    raw_edge_index = getattr(data, "edge_index", None)
    if raw_edge_index is None:
        edge_index = np.empty((2, 0), dtype=np.int64)
    else:
        edge_index = cast(np.ndarray[Any, Any], raw_edge_index.detach().cpu().numpy())  # type: ignore
    sources, targets = edge_index[0], edge_index[1]

    # Plotly draws one polyline per trace; a None entry breaks the line so
    # that consecutive edges are not joined to each other.
    edge_x: list[float | np.ndarray[Any, Any] | None] = []
    edge_y: list[float | np.ndarray[Any, Any] | None] = []
    for src, dst in zip(sources, targets):
        edge_x.extend([x[src], x[dst], None])
        edge_y.extend([y[src], y[dst], None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="none",
        mode="lines",
        name="Edges",
    )

    # --- Per-node connection counts ----------------------------------------
    # Count incoming + outgoing endpoints for all nodes in one pass instead of
    # scanning the whole edge list once per node. Endpoints that do not map to
    # a plotted node index are ignored, as an equality scan would ignore them.
    endpoints = np.concatenate([sources, targets])
    endpoints = endpoints[(endpoints >= 0) & (endpoints < num_plotted)]
    node_adjacencies: list[int] = np.bincount(
        endpoints, minlength=num_plotted
    ).tolist()

    # --- Hover text ---------------------------------------------------------
    # Convert the feature matrix once rather than once per node.
    node_features = None
    if hasattr(data, "x") and data.x is not None:
        node_features = cast(np.ndarray[Any, Any], data.x.detach().cpu().numpy())  # type: ignore

    node_text: list[str] = []
    for i, num_connections in enumerate(node_adjacencies):
        node_info = f"Node {i}<br>Connections: {num_connections}"
        if node_features is not None and i < node_features.shape[0]:
            # Flatten so scalar or multi-dimensional per-node features format too.
            features = np.atleast_1d(node_features[i]).ravel()
            # Show only the first 5 features; "..." marks a truncated list.
            feature_str = ", ".join([f"{f:.3f}" for f in features[:5]])
            if len(features) > 5:
                feature_str += "..."
            node_info += f"<br>Features: [{feature_str}]"
        node_text.append(node_info)

    # --- Node trace ---------------------------------------------------------
    node_trace = go.Scatter(
        x=x,
        y=y,
        mode="markers",
        hoverinfo="text",
        name="Nodes",
        marker=dict(size=10, color="lightblue", line=dict(width=2, color="black")),
    )
    node_trace.text = node_text

    # Color nodes by number of connections
    if node_adjacencies:
        node_trace.marker.color = node_adjacencies  # type: ignore
        node_trace.marker.colorscale = "Viridis"  # type: ignore
        node_trace.marker.showscale = True  # type: ignore
        node_trace.marker.colorbar = dict(title="Number of connections")  # type: ignore

    # --- Figure -------------------------------------------------------------
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(text=title, font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text=f"Nodes: {data.num_nodes}, Edges: {data.num_edges}",
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0.005,
                    y=-0.002,
                    xanchor="left",
                    yanchor="bottom",
                    font=dict(size=12),
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # NOTE(review): the reversed y axis presumably matches image-style
            # coordinates (origin at the top-left) - confirm with callers.
            yaxis=dict(
                showgrid=False,
                zeroline=False,
                showticklabels=False,
                autorange="reversed",
            ),
        ),
    )
    fig.write_html(f"graph_{title}.html")  # type: ignore