"""
Cell 4: Multi-Scale Geometric Extraction (Fully Batched)
=========================================================
Optimizations:
- Multi-image batching: N images → single mega classify call
- Fused raw + deviance extraction per image
- GPU-only channel clustering (no numpy round-trip)
- torch.kthvalue replaces torch.quantile
- No torch.cuda.empty_cache() in hot path
- All GPU-resident until annotation construction
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import math


@dataclass
class ExtractionConfig:
    """Tunable parameters for multi-scale patch extraction and classification."""

    # Shape every binarized patch is resized to before classification.
    canonical_shape: Tuple[int, int, int] = (8, 16, 16)
    # Patch sizes to extract, one entry per scale level (index == level).
    scales: List[Tuple[int, int, int]] = field(default_factory=lambda: [
        (16, 64, 64),
        (8, 32, 32),
        (8, 16, 16),
        (4, 8, 8),
    ])
    # Fractional overlap between neighbouring patches (stride = size * (1 - overlap)).
    overlap: float = 0.5
    # Minimum top-1 minus top-2 probability margin for a patch to be annotated.
    confidence_threshold: float = 0.7
    # Patches whose binarized occupancy is below this fraction are dropped.
    min_occupancy: float = 0.01
    # Per-patch percentile of |value| used as the binarization threshold.
    binarize_percentile: float = 90.0
    # NOTE(review): not read by any code visible in this file; presumably the
    # `n_groups` a caller passes to cluster_channels_gpu — confirm.
    n_channel_groups: int = 8
    # Upper bound on the number of patches per classifier forward pass.
    max_classify_batch: int = 16384
    # process N images simultaneously
    # NOTE(review): only printed in this file; extract_batch processes whatever
    # list it is given — presumably the caller chunks by this value.
    image_batch_size: int = 32
    # NOTE(review): not read in this file; the extractor takes its device from
    # the classifier's parameters.
    device: str = 'cuda'


@dataclass
class GeometricAnnotation:
    """One classified patch: what was detected, where, and at which scale."""

    class_name: str
    class_idx: int
    # Top-1 minus top-2 softmax probability (see classify_batch).
    confidence: float
    # Index into ExtractionConfig.scales.
    scale_level: int
    # Patch origin within the volume the patch was cut from.
    location: Tuple[int, int, int]
    patch_size: Tuple[int, int, int]
    dimension: int = -1
    is_curved: bool = False
    curvature_type: str = "none"
    # "raw" or "deviance", as set by MultiScaleExtractor.extract_batch.
    source: str = "raw"
    # Only set for source == "deviance": the channel-group pair at the patch's
    # first index along the pair axis.
    channel_group_pair: Optional[Tuple[int, int]] = None


# === GPU Primitives ===========================================================

def extract_patches_gpu(volume, patch_size, overlap=0.5):
    """Extract all sliding-window patches from a 3-D volume in one gather.

    Args:
        volume: 3-D tensor. If any axis is shorter than the patch it is
            zero-padded at the end of that axis first.
        patch_size: ``(pz, py, px)`` window size.
        overlap: Fractional overlap between neighbouring windows; the stride
            along each axis is ``max(1, int(size * (1 - overlap)))``.

    Returns:
        ``(patches, locs)``: ``patches`` has shape ``(N, pz, py, px)`` and
        ``locs`` has shape ``(N, 3)`` holding each patch's origin.
    """
    D, H, W = volume.shape
    pz, py, px = patch_size
    dev = volume.device

    if D < pz or H < py or W < px:
        # F.pad takes pads for the last axis first: (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi).
        volume = F.pad(
            volume,
            (0, max(px - W, 0), 0, max(py - H, 0), 0, max(pz - D, 0)),
        )
        D, H, W = volume.shape

    sz = max(1, int(pz * (1 - overlap)))
    sy = max(1, int(py * (1 - overlap)))
    sx = max(1, int(px * (1 - overlap)))

    # The range end is clamped to >= 1 and the step to >= 1, so each axis always
    # has at least the origin 0; no empty-range fallback is needed.
    z_s = torch.arange(0, max(1, D - pz + 1), sz, device=dev)
    y_s = torch.arange(0, max(1, H - py + 1), sy, device=dev)
    x_s = torch.arange(0, max(1, W - px + 1), sx, device=dev)

    gz, gy, gx = torch.meshgrid(z_s, y_s, x_s, indexing='ij')
    locs = torch.stack([gz.flatten(), gy.flatten(), gx.flatten()], dim=1)
    N = locs.shape[0]

    oz = torch.arange(pz, device=dev)
    oy = torch.arange(py, device=dev)
    ox = torch.arange(px, device=dev)

    # Per-axis absolute indices (origin + in-patch offset), broadcast to the
    # full (N, pz, py, px) shape so one advanced-indexing call gathers everything.
    z_idx = (locs[:, 0:1] + oz.unsqueeze(0))[:, :, None, None].expand(N, pz, py, px)
    y_idx = (locs[:, 1:2] + oy.unsqueeze(0))[:, None, :, None].expand(N, pz, py, px)
    x_idx = (locs[:, 2:3] + ox.unsqueeze(0))[:, None, None, :].expand(N, pz, py, px)

    return volume[z_idx, y_idx, x_idx], locs


def binarize_fast(patches, percentile=90.0, min_occ=0.01):
    """Binarize each patch at its own |value| percentile and drop bad patches.

    Uses ``kthvalue`` (not ``quantile``) to find the per-patch threshold.

    Args:
        patches: Tensor of shape ``(N, ...)``; each ``patches[i]`` is one patch.
        percentile: Per-patch percentile of ``|value|`` used as threshold.
        min_occ: Minimum fraction of set voxels for a patch to be kept.
            Patches with occupancy above 0.95 are dropped as well.

    Returns:
        ``(binary, keep_idx)``: the kept patches as float 0/1 tensors, and the
        indices (into the input's first axis) of the patches that were kept.
    """
    N = patches.shape[0]
    if N == 0:
        # Nothing to threshold; patches[0] below would raise IndexError.
        empty = torch.zeros(
            (0, *patches.shape[1:]), dtype=torch.float32, device=patches.device)
        return empty, torch.zeros(0, dtype=torch.long, device=patches.device)

    V = patches[0].numel()
    flat = patches.reshape(N, V).abs()

    # k = number of voxels that should end up "on" (at least one).
    k = max(1, int(V * (1.0 - percentile / 100.0)))
    # kthvalue is 1-based and ascending, so V - k + 1 is the k-th largest value.
    thresholds = flat.kthvalue(V - k + 1, dim=1, keepdim=True).values
    binary = (flat >= thresholds).float()

    occ = binary.mean(dim=1)
    # Ties at the threshold can push occupancy far above k / V (e.g. a constant
    # patch becomes all ones); the upper bound removes those.
    keep = (occ >= min_occ) & (occ <= 0.95)
    keep_idx = keep.nonzero(as_tuple=True)[0]

    return binary.reshape(N, *patches.shape[1:])[keep_idx], keep_idx


def extract_and_prepare_volume(volume, config):
    """Extract patches at ALL scales from a single volume.

    For every scale in ``config.scales`` that fits inside ``volume``, patches
    are extracted, binarized, filtered by occupancy and resized to
    ``config.canonical_shape``.

    Args:
        volume: 3-D tensor.
        config: An ``ExtractionConfig``.

    Returns:
        ``(canonical_patches, meta_list)``. ``canonical_patches`` is the
        concatenation over scales along the first axis, or ``None`` if no
        patch survived. ``meta_list`` holds one
        ``(level, kept_locs, scale, count)`` tuple per contributing scale, in
        concatenation order. Tensors stay on ``volume``'s device.
    """
    canonical = config.canonical_shape
    all_canonical = []
    all_meta = []  # (level, kept_locs, scale, count)

    for level, scale in enumerate(config.scales):
        pz, py, px = scale
        D, H, W = volume.shape
        if D < pz or H < py or W < px:
            # Scale does not fit; skip it rather than relying on padding.
            continue

        patches, locations = extract_patches_gpu(volume, scale, config.overlap)
        binary, keep_idx = binarize_fast(
            patches, config.binarize_percentile, config.min_occupancy)
        if binary.shape[0] == 0:
            continue

        kept_locs = locations[keep_idx]

        if binary.shape[1:] != tuple(canonical):
            # interpolate expects a channel axis: (N, 1, pz, py, px).
            resized = F.interpolate(
                binary.unsqueeze(1), size=canonical,
                mode='trilinear', align_corners=False).squeeze(1)
        else:
            resized = binary

        all_canonical.append(resized)
        all_meta.append((level, kept_locs, scale, resized.shape[0]))

    if not all_canonical:
        return None, []
    return torch.cat(all_canonical, dim=0), all_meta


# === GPU Channel Clustering ===================================================

def cluster_channels_gpu(latents, n_groups=8):
    """Greedily group channels by absolute correlation, without numpy.

    Channel correlations are computed per sample over the flattened spatial
    axes and averaged over the batch. Groups are then grown greedily: each new
    seed is the remaining channel farthest from all existing groups, and it
    takes its nearest remaining channels until the group has about
    ``C // n_groups`` members. Leftover channels join their nearest group.

    Args:
        latents: 4-D tensor ``(N, C, H, W)``.
        n_groups: Maximum number of groups to form.

    Returns:
        ``(groups, corr)``: ``groups`` is a list of lists of channel indices
        and ``corr`` is the ``(C, C)`` batch-averaged correlation matrix.
    """
    N, C, H, W = latents.shape
    device = latents.device

    flat = latents.reshape(N, C, -1)
    flat = flat - flat.mean(dim=-1, keepdim=True)
    flat = F.normalize(flat, dim=-1)
    corr = torch.bmm(flat, flat.transpose(1, 2)).mean(dim=0)
    # Strongly correlated and strongly anti-correlated channels both count as close.
    dist = 1.0 - corr.abs()

    remaining = torch.ones(C, dtype=torch.bool, device=device)
    target_size = max(1, C // n_groups)
    groups = []

    for g in range(n_groups):
        if not remaining.any():
            break

        if g == 0:
            avail = remaining.nonzero(as_tuple=True)[0]
            seed = avail[0].item()
        else:
            # Farthest from existing groups
            min_dists = torch.full((C,), float('inf'), device=device)
            for grp in groups:
                grp_t = torch.tensor(grp, device=device)
                d = dist[:, grp_t].min(dim=1).values
                min_dists = torch.min(min_dists, d)
            # Already-assigned channels must never win the argmax.
            min_dists[~remaining] = -1
            seed = min_dists.argmax().item()

        group = [seed]
        remaining[seed] = False

        dists_from_seed = dist[seed].clone()
        dists_from_seed[~remaining] = float('inf')
        _, nearest = dists_from_seed.topk(
            min(target_size - 1, remaining.sum().item()), largest=False)
        # Drop any pick that is an already-assigned (masked) channel.
        nearest = nearest[dists_from_seed[nearest] < float('inf')]
        for c in nearest.tolist():
            group.append(c)
            remaining[c] = False

        groups.append(group)

    # Assign stragglers
    for c in remaining.nonzero(as_tuple=True)[0].tolist():
        grp_dists = []
        for grp in groups:
            grp_t = torch.tensor(grp, device=device)
            grp_dists.append(dist[c, grp_t].min().item())
        groups[min(range(len(groups)), key=lambda i: grp_dists[i])].append(c)

    return groups, corr


def compute_deviance_volume(latent, groups):
    """Compute inter-group deviance for one latent.

    Args:
        latent: Tensor indexed by channel on its first axis.
        groups: List of lists of channel indices (as returned by
            ``cluster_channels_gpu``).

    Returns:
        ``(deviances, pairs)``: ``deviances[p]`` is the absolute difference of
        the channel-mean maps of group pair ``pairs[p] == (i, j)`` with
        ``i < j``; its shape is ``(n_pairs, H, W)`` for a ``(C, H, W)`` latent.
    """
    group_means = torch.stack([latent[grp].mean(dim=0) for grp in groups])
    n = len(groups)
    # Strict upper triangle: every unordered pair exactly once.
    i_idx, j_idx = torch.triu_indices(n, n, offset=1, device=latent.device)
    deviances = (group_means[i_idx] - group_means[j_idx]).abs()
    pairs = list(zip(i_idx.cpu().tolist(), j_idx.cpu().tolist()))
    return deviances, pairs


# === Batched Multi-Image Extractor ============================================

class MultiScaleExtractor:
    """Runs multi-scale patch extraction and one batched classification pass."""

    def __init__(self, classifier, config=None):
        """Store the classifier (switched to eval mode) and the configuration.

        Args:
            classifier: Module whose forward returns a dict with the keys
                read in ``classify_batch``.
            config: Optional ``ExtractionConfig``; defaults are used if omitted.
        """
        self.classifier = classifier
        self.config = config or ExtractionConfig()
        self.classifier.eval()
        self.device = next(classifier.parameters()).device
        # Query bf16 support only when CUDA is present: on CUDA-less hosts
        # some PyTorch versions raise from is_bf16_supported().
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        self.amp_dtype = torch.bfloat16 if bf16_ok else torch.float16

    @staticmethod
    def _empty_result():
        """Return the per-image result dict for an image with no annotations."""
        return {
            'raw_annotations': [],
            'deviance_annotations': [],
            'n_raw': 0,
            'n_deviance': 0,
        }

    @torch.no_grad()
    def classify_batch(self, patches):
        """Classify a mega-batch of patches under autocast.

        The batch is processed in chunks of ``config.max_classify_batch``.

        Args:
            patches: Tensor whose first axis indexes patches.

        Returns:
            ``None`` if there are no patches, otherwise a dict of 1-D tensors
            with keys ``pred_class``, ``confidence`` (top-1 minus top-2
            probability), ``dim_pred``, ``curved_pred`` and ``curv_type_pred``.
        """
        N = patches.shape[0]
        if N == 0:
            return None

        max_b = self.config.max_classify_batch
        all_cls, all_conf, all_dim, all_curved, all_curv_type = [], [], [], [], []

        for start in range(0, N, max_b):
            chunk = patches[start:start + max_b]
            with torch.amp.autocast('cuda', dtype=self.amp_dtype):
                out = self.classifier(chunk)

            # Softmax in float32 regardless of the autocast dtype.
            probs = F.softmax(out["class_logits"].float(), dim=-1)
            top2 = probs.topk(2, dim=-1).values

            all_cls.append(probs.argmax(dim=-1))
            all_conf.append(top2[:, 0] - top2[:, 1])
            all_dim.append(out["dim_logits"].argmax(dim=-1))
            # Thresholded at 0.0, i.e. treated as a logit.
            all_curved.append(out["is_curved_pred"].squeeze(-1) > 0.0)
            all_curv_type.append(out["curv_type_logits"].argmax(dim=-1))

        return {
            "pred_class": torch.cat(all_cls),
            "confidence": torch.cat(all_conf),
            "dim_pred": torch.cat(all_dim),
            "curved_pred": torch.cat(all_curved),
            "curv_type_pred": torch.cat(all_curv_type),
        }

    def _results_to_annotations(self, results, meta, conf_thresh,
                                source="raw", pair_indices=None):
        """Convert batched results + meta into an annotation list.

        Args:
            results: Dict as returned by ``classify_batch``, already sliced to
                the patches described by ``meta``.
            meta: List of ``(level, kept_locs, scale, count)`` tuples in the
                same order as the patches in ``results``.
            conf_thresh: Minimum confidence for a patch to be annotated.
            source: Value stored in each annotation's ``source`` field.
            pair_indices: For ``source == "deviance"``, the list of group
                pairs indexed by position along the deviance volume's first
                axis.

        Returns:
            List of ``GeometricAnnotation``.

        Note:
            Reads ``CLASS_NAMES`` and ``CURVATURE_NAMES``, which are not
            defined in this file view and must exist at module scope when
            this method runs.
        """
        annotations = []
        offset = 0

        for level, kept_locs, scale, count in meta:
            chunk_conf = results["confidence"][offset:offset + count]
            local_idx = (chunk_conf >= conf_thresh).nonzero(as_tuple=True)[0]

            if local_idx.numel() > 0:
                gi = local_idx + offset
                # One bulk transfer per field instead of a sync per element.
                cls = results["pred_class"][gi].tolist()
                conf = results["confidence"][gi].tolist()
                dim = results["dim_pred"][gi].tolist()
                curved = results["curved_pred"][gi].tolist()
                curv = results["curv_type_pred"][gi].tolist()
                locs = kept_locs[local_idx].tolist()

                for c, p, d, is_c, ct, loc in zip(cls, conf, dim, curved, curv, locs):
                    ann = GeometricAnnotation(
                        class_name=CLASS_NAMES[c],
                        class_idx=c,
                        confidence=p,
                        scale_level=level,
                        location=tuple(int(x) for x in loc),
                        patch_size=scale,
                        dimension=d,
                        is_curved=bool(is_c),
                        curvature_type=CURVATURE_NAMES[ct],
                        source=source,
                    )
                    if source == "deviance" and pair_indices is not None:
                        # First axis of a deviance volume indexes group pairs;
                        # the patch origin on that axis picks the pair.
                        pair_idx = loc[0]
                        if pair_idx < len(pair_indices):
                            ann.channel_group_pair = pair_indices[pair_idx]
                    annotations.append(ann)

            # Advance past this scale's patches whether or not any qualified.
            offset += count

        return annotations

    def extract_batch(self, latents, channel_groups):
        """Process multiple latents simultaneously.

        Args:
            latents: List of 3-D tensors, described by the original author as
                ``(C, H, W)`` on GPU; the channel axis is treated as depth.
            channel_groups: List of lists of channel indices used for the
                deviance volumes, or ``None`` to skip deviance extraction.

        Returns:
            One dict per input latent with keys ``raw_annotations``,
            ``deviance_annotations``, ``n_raw`` and ``n_deviance``.
        """
        conf_thresh = self.config.confidence_threshold

        # Phase 1: extract patches from ALL images, both raw + deviance
        all_patches = []
        image_segments = []  # (img_idx, source, meta, pair_indices, patch_count)

        for img_idx, latent in enumerate(latents):
            # Raw volume: channels as depth
            raw_patches, raw_meta = extract_and_prepare_volume(latent, self.config)
            if raw_patches is not None:
                all_patches.append(raw_patches)
                image_segments.append(
                    (img_idx, "raw", raw_meta, None, raw_patches.shape[0]))

            # Deviance volume
            if channel_groups is not None:
                dev_vol, pair_indices = compute_deviance_volume(latent, channel_groups)
                dev_patches, dev_meta = extract_and_prepare_volume(dev_vol, self.config)
                if dev_patches is not None:
                    all_patches.append(dev_patches)
                    image_segments.append(
                        (img_idx, "deviance", dev_meta, pair_indices,
                         dev_patches.shape[0]))

        if not all_patches:
            return [self._empty_result() for _ in latents]

        # Phase 2: SINGLE classify call for ALL images × ALL scales × raw+deviance
        mega_batch = torch.cat(all_patches, dim=0)
        del all_patches
        results = self.classify_batch(mega_batch)
        del mega_batch

        if results is None:
            return [self._empty_result() for _ in latents]

        # Phase 3: distribute results back to per-image annotations
        per_image = {i: {'raw': [], 'deviance': []} for i in range(len(latents))}
        global_offset = 0

        for img_idx, source, meta, pair_indices, total_count in image_segments:
            # Slice this segment's results; segments were concatenated in this order.
            seg_results = {
                k: v[global_offset:global_offset + total_count]
                for k, v in results.items()
            }
            anns = self._results_to_annotations(
                seg_results, meta, conf_thresh, source, pair_indices)
            per_image[img_idx][source].extend(anns)
            global_offset += total_count

        del results

        # Build output
        output = []
        for i in range(len(latents)):
            raw = per_image[i]['raw']
            dev = per_image[i]['deviance']
            output.append({
                'raw_annotations': raw,
                'deviance_annotations': dev,
                'n_raw': len(raw),
                'n_deviance': len(dev),
            })
        return output

    # Single-image compat
    def extract_from_latent(self, latent, channel_groups=None):
        """Run ``extract_batch`` on a single latent and return its result dict."""
        return self.extract_batch([latent], channel_groups)[0]


print("✓ Cell 4: Fully batched multi-image extraction")
print(f" Scales: {ExtractionConfig().scales}")
print(f" Canonical: {ExtractionConfig().canonical_shape}")
print(f" Image batch: {ExtractionConfig().image_batch_size}")
print(f" Classify batch: {ExtractionConfig().max_classify_batch}")
print(f" Percentile: {ExtractionConfig().binarize_percentile}th (kthvalue)")