"""Model and dataset loading, inference, and label extraction functions.""" from __future__ import annotations import json import os from functools import lru_cache from typing import Any, Dict, Optional import numpy as np import torch from datasets import DatasetDict, load_dataset from PIL import Image from torchvision import transforms from torchvision.transforms import functional as TF from transformers import ( AutoImageProcessor, AutoModelForImageClassification, ) HF_REPO_ID = "raidium/curia" HF_DATASET_ID = "raidium/CuriaBench" class _NumpyToTensor: """Convert numpy arrays to tensors while preserving tensors/images.""" def __call__(self, value: Any) -> torch.Tensor: if isinstance(value, (torch.Tensor, Image.Image)): return value # type: ignore[return-value] return torch.tensor(value).unsqueeze(0) class AdaptativeResizeMask(torch.nn.Module): """Resize binary masks with a fallback threshold to avoid empty masks.""" def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None: super().__init__() self.target_size = target_size self.initial_threshold = initial_threshold def forward(self, mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] mask = mask.to(dtype=torch.float32) resized = TF.resize( mask, (self.target_size, self.target_size), interpolation=TF.InterpolationMode.BILINEAR, antialias=True, ) binary = resized > self.initial_threshold if binary.sum() == 0: new_threshold = torch.max(resized) * 0.5 binary = resized > new_threshold return binary.to(dtype=torch.float32) @lru_cache(maxsize=1) def make_mask_transform(crop_size: int = 512) -> transforms.Compose: """Return the resize transform used during training/inference.""" return transforms.Compose( [ _NumpyToTensor(), AdaptativeResizeMask(target_size=crop_size), ] ) def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]: """Apply Curia's mask preprocessing so heads get the ROI they expect.""" if mask is None: return None mask_transform = make_mask_transform() try: mask_arr = np.array(mask) except Exception: return None if mask_arr.size == 0: return None if mask_arr.ndim == 3: # (H, W, slices) tensor = mask_transform(mask_arr.transpose(2, 0, 1)) # (1, slices, H, W) tensor = tensor.transpose(1, 3).transpose(1, 2) # else: tensor = mask_transform(torch.tensor([mask_arr])) tensor = tensor.unsqueeze(0) if isinstance(tensor, np.ndarray): tensor = torch.from_numpy(tensor) return tensor @lru_cache(maxsize=1) def load_id_to_labels() -> Dict[str, Dict[str, str]]: """Load the id_to_labels.json mapping file.""" json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json") with open(json_path, "r") as f: data = json.load(f) # convert string keys to integers for head in data: data[head] = {int(k): v for k, v in data[head].items()} return data @lru_cache(maxsize=1) def load_processor() -> AutoImageProcessor: token = os.environ.get("HF_TOKEN") return AutoImageProcessor.from_pretrained( HF_REPO_ID, trust_remote_code=True, token=token ) @lru_cache(maxsize=None) def load_model(head: str) -> AutoModelForImageClassification: token = os.environ.get("HF_TOKEN") model = AutoModelForImageClassification.from_pretrained( HF_REPO_ID, trust_remote_code=True, subfolder=head, token=token, ) model.eval() return model @lru_cache(maxsize=None) def load_curia_dataset(subset: str) -> Any: token = os.environ.get("HF_TOKEN") ds = load_dataset( HF_DATASET_ID, subset, split="test", token=token, ) if isinstance(ds, DatasetDict): return ds["test"] return ds def infer_image( image: np.ndarray, head: str, mask: Any | None = None, return_probs: bool = True, ) -> torch.Tensor: processor = load_processor() model = load_model(head) with torch.no_grad(): processed = processor(images=image, return_tensors="pt") mask_tensor = prepare_mask_for_model(mask) if mask_tensor is not None: processed["mask"] = mask_tensor outputs = model(**processed) logits = outputs["logits"] if return_probs: probs = torch.nn.functional.softmax(logits[0], dim=-1) return probs else: return logits[0].squeeze()