"""Dataset generation orchestration for multi-slide tiling workflows."""
import copy
import logging
from datetime import datetime
from pathlib import Path
from typing import Sequence
import cv2
import numpy as np
from PIL import Image
from glasscut.slides import Slide
from glasscut.tile import Tile
from glasscut.tiler import Tiler
from glasscut.tissue_detectors import OtsuTissueDetector
from glasscut.storage import (
DatasetMetadata,
SlideMetadata,
StorageOrganizer,
TileMetadata,
)
from glasscut.storage.structures import JsonValue
class DatasetGenerator:
    """Generate a tile dataset from one or more slide files.

    For every slide the generator extracts tiles with a copy of the
    configured tiler, writes each tile as a PNG, optionally stores a slide
    thumbnail and a thumbnail-level tissue mask, and persists per-slide and
    dataset-level metadata through a ``StorageOrganizer``. When
    ``save_processed_json`` is enabled, progress is checkpointed after each
    slide so that a later run can skip slides that were already completed.
    """

    def __init__(
        self,
        dataset_id: str,
        output_dir: str | Path,
        *,
        tiler: Tiler,
        n_workers: int = 4,
        batch_size: int = 128,
        save_thumbnails: bool = True,
        save_masks: bool = True,
        save_processed_json: bool = True,
        show_progress: bool = True,
        verbose: bool = True,
    ) -> None:
        """Initialize generator from direct parameters.

        Parameters
        ----------
        dataset_id : str
            Dataset identifier.
        output_dir : str | Path
            Output root directory. It is resolved to an absolute path.
        tiler : Tiler
            Preconfigured tiler instance used for extraction.
        n_workers : int, optional
            Number of workers for batched tile extraction. Default is ``4``.
        batch_size : int, optional
            Number of tiles per extraction batch. Default is ``128``.
        save_thumbnails : bool, optional
            Whether to save slide thumbnail artifacts. Default is ``True``.
        save_masks : bool, optional
            Whether to save tissue mask artifacts. Default is ``True``.
        save_processed_json : bool, optional
            Whether to save ``processed.json`` at dataset root. Default is
            ``True``.
        show_progress : bool, optional
            Whether to print per-slide progress. Default is ``True``.
        verbose : bool, optional
            Whether to enable info-level logs. Default is ``True``.

        Raises
        ------
        ValueError
            If ``dataset_id`` or ``output_dir`` is empty, or if ``n_workers``
            or ``batch_size`` is smaller than 1.
        """
        self._validate_parameters(
            dataset_id=dataset_id,
            output_dir=output_dir,
            n_workers=n_workers,
            batch_size=batch_size,
        )
        self.dataset_id = dataset_id
        self.output_dir = str(Path(output_dir).resolve())
        self.tiler = tiler
        self.n_workers = n_workers
        self.batch_size = batch_size
        self.save_thumbnails = save_thumbnails
        self.save_masks = save_masks
        self.save_processed_json = save_processed_json
        self.show_progress = show_progress
        self.verbose = verbose
        self.storage = StorageOrganizer(self.output_dir)
        self.logger = self._setup_logger()

    def process_dataset(self, slide_paths: Sequence[str | Path]) -> DatasetMetadata:
        """Process all provided slides and persist tiles, artifacts, and metadata.

        Slide IDs are derived from the position of each path in
        ``slide_paths``. When ``save_processed_json`` is enabled, slides whose
        IDs are listed in the stored checkpoint and whose metadata can still
        be loaded are not processed again.

        Parameters
        ----------
        slide_paths : Sequence[str | Path]
            Paths of the slides to process.

        Returns
        -------
        DatasetMetadata
            Dataset-level metadata covering resumed and newly processed
            slides, sorted by slide index.

        Raises
        ------
        ValueError
            If ``slide_paths`` is empty.
        """
        if not slide_paths:
            raise ValueError("No slide paths were provided")

        normalized_paths = [str(Path(path)) for path in slide_paths]
        self.storage.init_dataset(self.dataset_id)

        all_tasks = list(enumerate(normalized_paths))
        all_slide_ids = {self._slide_id_from_index(index) for index, _ in all_tasks}

        resumed_metadata, processed_slide_ids = self._load_resumed_slides(
            all_slide_ids
        )
        processed_set = set(processed_slide_ids)
        pending_tasks = [
            (index, slide_path)
            for index, slide_path in all_tasks
            if self._slide_id_from_index(index) not in processed_set
        ]

        self.logger.info(
            "Starting dataset generation for %s (%d slides, %d remaining)",
            self.dataset_id,
            len(normalized_paths),
            len(pending_tasks),
        )

        new_slides_metadata: list[SlideMetadata] = []
        total_pending = len(pending_tasks)
        for pending_index, (index, slide_path) in enumerate(pending_tasks, start=1):
            if self.show_progress:
                print(f"Processing slide {pending_index}/{total_pending}")
            slide_meta = self._process_single_slide(slide_path, index)
            new_slides_metadata.append(slide_meta)
            # Checkpoint right away so an interrupted run can resume here.
            self._checkpoint_processed_slide(processed_slide_ids, slide_meta.slide_id)

        slides_metadata = resumed_metadata + new_slides_metadata
        slides_metadata.sort(
            key=lambda metadata: self._slide_index_from_id(metadata.slide_id)
        )
        total_tiles = sum(slide_meta.total_tiles for slide_meta in slides_metadata)

        dataset_metadata = DatasetMetadata(
            dataset_id=self.dataset_id,
            created_at=datetime.now().isoformat(),
            total_slides=len(slides_metadata),
            total_tiles=total_tiles,
            config=self._config_dict(),
            slides=slides_metadata,
        )
        self.storage.save_dataset_metadata(self.dataset_id, dataset_metadata)
        if self.save_processed_json and pending_tasks:
            self.storage.save_processed_json(self.dataset_id, processed_slide_ids)

        self.logger.info(
            "Dataset generation complete: %d slides, %d tiles",
            len(slides_metadata),
            total_tiles,
        )
        return dataset_metadata

    def _load_resumed_slides(
        self,
        all_slide_ids: set[str],
    ) -> tuple[list[SlideMetadata], list[str]]:
        """Load metadata of slides already completed by a previous run.

        Parameters
        ----------
        all_slide_ids : set[str]
            Slide IDs that belong to the current run; checkpoint entries
            outside this set are ignored.

        Returns
        -------
        tuple[list[SlideMetadata], list[str]]
            The loaded slide metadata and the matching slide IDs, in
            checkpoint order. Both are empty when ``save_processed_json`` is
            disabled.
        """
        resumed_metadata: list[SlideMetadata] = []
        processed_slide_ids: list[str] = []
        if not self.save_processed_json:
            return resumed_metadata, processed_slide_ids

        seen: set[str] = set()
        previously_processed = self.storage.load_processed_json(self.dataset_id)
        for slide_id in previously_processed:
            # Ignore IDs outside this run, and repeated IDs: a duplicated
            # checkpoint entry must not be counted as two slides.
            if slide_id not in all_slide_ids or slide_id in seen:
                continue
            try:
                slide_meta = self.storage.load_slide_metadata(
                    self.dataset_id, slide_id
                )
            except (FileNotFoundError, ValueError):
                self.logger.warning(
                    "Skipping stale checkpoint entry for %s (missing/invalid metadata)",
                    slide_id,
                )
                continue
            resumed_metadata.append(slide_meta)
            processed_slide_ids.append(slide_id)
            seen.add(slide_id)
        return resumed_metadata, processed_slide_ids

    def _checkpoint_processed_slide(
        self,
        processed_slide_ids: list[str],
        slide_id: str,
    ) -> None:
        """Persist progress checkpoint after each completed slide.

        ``processed_slide_ids`` is updated in place; the checkpoint file is
        only written when ``save_processed_json`` is enabled.
        """
        if slide_id not in processed_slide_ids:
            processed_slide_ids.append(slide_id)
        if self.save_processed_json:
            self.storage.save_processed_json(self.dataset_id, processed_slide_ids)

    @staticmethod
    def _slide_id_from_index(slide_index: int) -> str:
        """Format slide ID from zero-based index (e.g. ``3`` -> ``slide_003``)."""
        return f"slide_{slide_index:03d}"

    @staticmethod
    def _slide_index_from_id(slide_id: str) -> int:
        """Extract numeric index from slide ID for sorting.

        IDs that do not follow the ``<prefix>_<number>`` pattern map to
        ``10**9`` so that they sort after every regular slide.
        """
        try:
            return int(slide_id.split("_")[1])
        except (IndexError, ValueError):
            return 10**9

    def _process_single_slide(self, slide_path: str, slide_index: int) -> SlideMetadata:
        """Process a single slide end-to-end.

        Extracts and saves the tiles, optionally saves the thumbnail and the
        tissue mask, then builds and persists the slide metadata.

        Parameters
        ----------
        slide_path : str
            Path of the slide file to open.
        slide_index : int
            Zero-based position of the slide, used to derive its slide ID.

        Returns
        -------
        SlideMetadata
            Metadata describing the slide and all tiles written for it.
        """
        slide_id = self._slide_id_from_index(slide_index)
        directories = self.storage.init_slide(self.dataset_id, slide_id)
        tiler = self._build_tiler()

        with Slide(slide_path) as slide:
            tile_metadata = self._extract_and_save_tiles(
                slide=slide,
                tiler=tiler,
                tiles_dir=directories["tiles"],
            )

            if self.save_thumbnails or self.save_masks:
                # Read the thumbnail once and reuse it for both artifacts.
                thumbnail = slide.thumbnail
                if self.save_thumbnails:
                    thumbnail.save(directories["thumbnails"] / "slide_thumbnail.png")
                if self.save_masks:
                    self._save_tissue_mask(thumbnail, directories["masks"])

            slide_metadata = SlideMetadata(
                slide_id=slide_id,
                slide_name=slide.name,
                slide_path=str(Path(slide_path).resolve()),
                total_tiles=len(tile_metadata),
                dimensions=slide.dimensions,
                mpp=slide.mpp,
                available_magnifications=[
                    float(magnification) for magnification in slide.magnifications
                ],
                tile_size=self._resolve_slide_tile_size(tile_metadata),
                tiler_name=tiler.__class__.__name__,
                timestamp=datetime.now().isoformat(),
                tiles=tile_metadata,
            )

        self.storage.save_slide_metadata(self.dataset_id, slide_id, slide_metadata)
        self.logger.info(
            "Processed %s with %d tiles", Path(slide_path).name, len(tile_metadata)
        )
        return slide_metadata

    def _extract_and_save_tiles(
        self,
        slide: Slide,
        tiler: Tiler,
        tiles_dir: Path,
    ) -> list[TileMetadata]:
        """Extract tiles and persist them to disk while collecting metadata.

        Parameters
        ----------
        slide : Slide
            Open slide to extract tiles from.
        tiler : Tiler
            Tiler used for extraction.
        tiles_dir : Path
            Directory receiving the tile PNG files. It must be located under
            ``self.output_dir`` because stored file paths are made relative
            to it.

        Returns
        -------
        list[TileMetadata]
            One entry per written tile, in extraction order.
        """
        output_root = Path(self.output_dir)
        metadata: list[TileMetadata] = []
        extracted_tiles = tiler.extract(
            slide,
            n_workers=self.n_workers,
            batch_size=self.batch_size,
        )
        for tile_index, tile in enumerate(extracted_tiles):
            tile_id = f"tile_{tile_index:07d}"
            tile_path = tiles_dir / f"{tile_id}.png"
            self._save_tile_png(tile.image, tile_path)

            # Tiles without coordinates are recorded at the origin.
            x, y = tile.coords if tile.coords is not None else (0, 0)
            width, height = tile.image.size
            metadata.append(
                TileMetadata(
                    tile_id=tile_id,
                    x=x,
                    y=y,
                    width=width,
                    height=height,
                    level=0,
                    magnification=self._resolve_tile_magnification(tile),
                    tissue_ratio=self._safe_tissue_ratio(tile),
                    file_path=str(tile_path.relative_to(output_root)),
                )
            )
        return metadata

    def _build_tiler(self) -> Tiler:
        """Build a configured tiler instance for extraction."""
        # Use an independent tiler instance per slide to avoid shared mutable state.
        return copy.deepcopy(self.tiler)

    @staticmethod
    def _resolve_tile_magnification(tile: Tile) -> float:
        """Return tile magnification, requiring tilers to emit it explicitly.

        Raises
        ------
        ValueError
            If ``tile.magnification`` is ``None``.
        """
        if tile.magnification is None:
            raise ValueError(
                "Tile magnification is missing. Custom tilers must emit tiles "
                "with a valid magnification value."
            )
        return float(tile.magnification)

    @staticmethod
    def _resolve_slide_tile_size(tile_metadata: list[TileMetadata]) -> tuple[int, int]:
        """Return representative slide tile size.

        For mixed-size tilers this reflects the first produced tile. When no
        tile was produced, ``(0, 0)`` is returned.
        """
        if not tile_metadata:
            return (0, 0)
        first_tile = tile_metadata[0]
        return (first_tile.width, first_tile.height)

    @staticmethod
    def _safe_tissue_ratio(tile: Tile) -> float:
        """Compute tissue ratio and return 0.0 on detector failures."""
        try:
            return float(tile.tissue_ratio)
        except Exception:
            # Deliberate best-effort: one failing tile must not abort the
            # slide. Leave a debug trace so systematic failures are visible.
            logging.getLogger(__name__).debug(
                "Tissue ratio computation failed; recording 0.0", exc_info=True
            )
            return 0.0

    @staticmethod
    def _binarize_mask(mask) -> np.ndarray:
        """Convert a tissue mask into a ``uint8`` array holding only 0 and 255.

        The mask is cast to ``uint8`` first; every non-zero value then becomes
        255. Thresholding instead of multiplying by 255 avoids ``uint8``
        wrap-around when the mask already uses 255 for tissue.
        """
        mask_array = np.asarray(mask)
        if mask_array.dtype != np.uint8:
            mask_array = mask_array.astype(np.uint8)
        return (mask_array > 0).astype(np.uint8) * np.uint8(255)

    @staticmethod
    def _save_tissue_mask(thumbnail: Image.Image, masks_dir: Path) -> None:
        """Generate and persist a thumbnail-level tissue mask.

        The mask is computed with ``OtsuTissueDetector`` and written to
        ``tissue_mask.png`` inside ``masks_dir``.
        """
        detector = OtsuTissueDetector()
        mask = DatasetGenerator._binarize_mask(detector.detect(thumbnail))
        Image.fromarray(mask).save(masks_dir / "tissue_mask.png")

    @staticmethod
    def _save_tile_png(image: Image.Image, path: Path) -> None:
        """Save a tile as lossless PNG using OpenCV for faster encoding.

        Single-channel, 3-channel and 4-channel arrays are written with
        OpenCV; any other array layout is saved through PIL instead.

        Raises
        ------
        RuntimeError
            If OpenCV reports that the file could not be written.
        """
        image_np = np.asarray(image)
        if image_np.dtype != np.uint8:
            image_np = image_np.astype(np.uint8)

        if image_np.ndim == 2:
            encoded = image_np
        elif image_np.ndim == 3 and image_np.shape[2] == 3:
            # The array is converted from RGB to OpenCV's BGR channel order.
            encoded = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        elif image_np.ndim == 3 and image_np.shape[2] == 4:
            encoded = cv2.cvtColor(image_np, cv2.COLOR_RGBA2BGRA)
        else:
            # Keep behavior robust for unexpected modes by falling back to PIL.
            image.save(path)
            return

        success = cv2.imwrite(
            str(path),
            encoded,
            [cv2.IMWRITE_PNG_COMPRESSION, 1],
        )
        if not success:
            raise RuntimeError(f"Failed to write PNG tile with OpenCV: {path}")

    def _setup_logger(self) -> logging.Logger:
        """Create the dataset logger with a single stream handler.

        Existing handlers on the logger are dropped first so that building
        several generators for the same dataset ID does not duplicate output.
        Propagation is disabled, so records are emitted only by this handler.
        """
        logger = logging.getLogger(f"glasscut.dataset.{self.dataset_id}")
        logger.handlers.clear()
        logger.setLevel(logging.INFO if self.verbose else logging.WARNING)
        logger.propagate = False

        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
            )
        )
        logger.addHandler(handler)
        return logger

    @staticmethod
    def _validate_parameters(
        dataset_id: str,
        output_dir: str | Path,
        n_workers: int,
        batch_size: int,
    ) -> None:
        """Validate constructor arguments.

        Raises
        ------
        ValueError
            If ``dataset_id`` or ``output_dir`` is empty, or if ``n_workers``
            or ``batch_size`` is smaller than 1.
        """
        if not dataset_id:
            raise ValueError("dataset_id is required")
        if not output_dir:
            raise ValueError("output_dir is required")
        if n_workers < 1:
            raise ValueError("n_workers must be >= 1")
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")

    def _config_dict(self) -> dict[str, JsonValue]:
        """Build JSON-serializable config payload for metadata.json."""
        return {
            "dataset_id": self.dataset_id,
            "output_dir": self.output_dir,
            "tiler_name": self.tiler.__class__.__name__,
            "show_progress": self.show_progress,
            "n_workers": self.n_workers,
            "batch_size": self.batch_size,
            "save_thumbnails": self.save_thumbnails,
            "save_masks": self.save_masks,
            "save_processed_json": self.save_processed_json,
            "verbose": self.verbose,
        }