import logging
import gc
import time
from enum import IntEnum
from typing import Dict, Any, Optional, Callable, List
from dataclasses import dataclass
from threading import Lock

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ModelPriority(IntEnum):
    """
    Model priority levels for memory management.

    Higher priority models are kept loaded longer under memory pressure.
    """
    CRITICAL = 100   # Never unload (e.g., OpenCLIP for analysis)
    HIGH = 80        # Currently active pipeline
    MEDIUM = 50      # Recently used models
    LOW = 20         # Inactive pipelines, can be evicted
    DISPOSABLE = 0   # Temporary models, evict first


@dataclass
class ModelInfo:
    """
    Information about a registered model.

    Attributes:
        name: Unique model identifier
        loader: Callable that returns the loaded model
        is_critical: If True, model won't be unloaded under memory pressure
        priority: ModelPriority level for eviction decisions
        estimated_memory_gb: Estimated GPU memory usage
        model_group: Group name for mutual exclusion (e.g., "pipeline")
        is_loaded: Whether model is currently loaded
        last_used: Timestamp of last use
        model_instance: The actual model object
    """
    name: str
    loader: Callable[[], Any]
    is_critical: bool = False
    priority: int = ModelPriority.MEDIUM
    estimated_memory_gb: float = 0.0
    model_group: str = ""  # For mutual exclusion (e.g., "pipeline")
    is_loaded: bool = False
    last_used: float = 0.0
    model_instance: Any = None

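# Illustrative sketch only (not used by the manager): eviction candidates are
# ordered by (priority, last_used), so lower-priority models are evicted first
# and the least-recently-used model goes first within a tier. The model names
# and timestamps below are hypothetical.
def _eviction_order_example() -> List[str]:
    candidates = [
        ModelInfo(name="controlnet", loader=object,
                  priority=ModelPriority.MEDIUM, is_loaded=True, last_used=50.0),
        ModelInfo(name="upscaler_old", loader=object,
                  priority=ModelPriority.LOW, is_loaded=True, last_used=10.0),
        ModelInfo(name="upscaler_new", loader=object,
                  priority=ModelPriority.LOW, is_loaded=True, last_used=90.0),
    ]
    candidates.sort(key=lambda info: (info.priority, info.last_used))
    # Expected order: upscaler_old, upscaler_new, controlnet
    return [info.name for info in candidates]
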
class ModelManager:
    """
    Singleton model manager for unified model lifecycle management.

    Handles lazy loading, caching, priority-based eviction, and mutual
    exclusion for pipeline models. Designed for memory-constrained
    environments like Google Colab and HuggingFace Spaces.

    Features:
        - Priority-based model eviction under memory pressure
        - Mutual exclusion for pipeline models (only one active at a time)
        - Automatic memory monitoring and cleanup
        - Support for model groups and dependencies

    Example:
        >>> manager = get_model_manager()
        >>> manager.register_model(
        ...     name="sdxl_pipeline",
        ...     loader=load_sdxl,
        ...     priority=ModelPriority.HIGH,
        ...     model_group="pipeline"
        ... )
        >>> pipeline = manager.load_model("sdxl_pipeline")
    """

    _instance = None
    _lock = Lock()

    # Known model groups for mutual exclusion
    PIPELINE_GROUP = "pipeline"  # Only one pipeline can be loaded at a time

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self._models: Dict[str, ModelInfo] = {}
        self._memory_threshold = 0.80  # Trigger cleanup at 80% GPU memory usage
        self._high_memory_threshold = 0.90  # Critical threshold for aggressive cleanup
        self._device = self._detect_device()
        self._active_pipeline: Optional[str] = None  # Track currently active pipeline

        logger.info(f"ModelManager initialized on {self._device}")
        self._initialized = True

    def _detect_device(self) -> str:
        """Detect the best available device."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def register_model(
        self,
        name: str,
        loader: Callable[[], Any],
        is_critical: bool = False,
        priority: int = ModelPriority.MEDIUM,
        estimated_memory_gb: float = 0.0,
        model_group: str = ""
    ) -> None:
        """
        Register a model for managed loading.

        Parameters
        ----------
        name : str
            Unique model identifier
        loader : callable
            Function that returns the loaded model
        is_critical : bool
            If True, model won't be unloaded under memory pressure
        priority : int
            ModelPriority level for eviction decisions
        estimated_memory_gb : float
            Estimated GPU memory usage in GB
        model_group : str
            Group name for mutual exclusion (e.g., "pipeline")
        """
        if name in self._models:
            logger.warning(f"Model '{name}' already registered, updating")

        # Critical models always get the highest priority
        if is_critical:
            priority = ModelPriority.CRITICAL

        self._models[name] = ModelInfo(
            name=name,
            loader=loader,
            is_critical=is_critical,
            priority=priority,
            estimated_memory_gb=estimated_memory_gb,
            model_group=model_group,
            is_loaded=False,
            last_used=0.0,
            model_instance=None
        )
        logger.info(
            f"Registered model: {name} "
            f"(priority={priority}, group={model_group}, ~{estimated_memory_gb:.1f}GB)"
        )

    def load_model(self, name: str, update_priority: Optional[int] = None) -> Any:
        """
        Load a model by name. Returns the cached instance if already loaded.

        Implements mutual exclusion for pipeline models - loading a new
        pipeline will unload any existing pipeline first.

        Parameters
        ----------
        name : str
            Model identifier
        update_priority : int, optional
            If provided, update the model's priority after loading

        Returns
        -------
        Any
            Loaded model instance

        Raises
        ------
        KeyError
            If the model is not registered
        RuntimeError
            If loading fails
        """
        if name not in self._models:
            raise KeyError(f"Model '{name}' not registered")

        model_info = self._models[name]

        # Return cached instance
        if model_info.is_loaded and model_info.model_instance is not None:
            model_info.last_used = time.time()
            if update_priority is not None:
                model_info.priority = update_priority
            logger.debug(f"Using cached model: {name}")
            return model_info.model_instance

        # Handle mutual exclusion for the pipeline group
        if model_info.model_group == self.PIPELINE_GROUP:
            self._ensure_pipeline_exclusion(name)

        # Check memory pressure before loading
        self.check_memory_pressure()

        # Load the model
        try:
            logger.info(f"Loading model: {name}")
            start_time = time.time()

            model_instance = model_info.loader()

            model_info.model_instance = model_instance
            model_info.is_loaded = True
            model_info.last_used = time.time()
            if update_priority is not None:
                model_info.priority = update_priority

            # Track active pipeline
            if model_info.model_group == self.PIPELINE_GROUP:
                self._active_pipeline = name

            load_time = time.time() - start_time
            logger.info(f"Model '{name}' loaded in {load_time:.1f}s")
            return model_instance

        except Exception as e:
            logger.error(f"Failed to load model '{name}': {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _ensure_pipeline_exclusion(self, new_pipeline: str) -> None:
        """
        Ensure only one pipeline is loaded at a time.

        Unloads any existing pipeline before loading a new one.

        Parameters
        ----------
        new_pipeline : str
            Name of the pipeline about to be loaded
        """
        for name, info in self._models.items():
            if (info.model_group == self.PIPELINE_GROUP
                    and info.is_loaded
                    and name != new_pipeline):
                logger.info(f"Unloading {name} to make room for {new_pipeline}")
                self.unload_model(name)
    def unload_model(self, name: str) -> bool:
        """
        Unload a specific model to free memory.

        Parameters
        ----------
        name : str
            Model identifier

        Returns
        -------
        bool
            True if the model was unloaded successfully
        """
        if name not in self._models:
            return False

        model_info = self._models[name]

        if not model_info.is_loaded:
            return True

        try:
            logger.info(f"Unloading model: {name}")

            # Delete model instance
            if model_info.model_instance is not None:
                del model_info.model_instance
                model_info.model_instance = None

            model_info.is_loaded = False

            # Update active pipeline tracking
            if self._active_pipeline == name:
                self._active_pipeline = None

            # Cleanup
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

            logger.info(f"Model '{name}' unloaded")
            return True

        except Exception as e:
            logger.error(f"Error unloading model '{name}': {e}")
            return False

    def check_memory_pressure(self) -> bool:
        """
        Check GPU memory usage and unload low-priority models if needed.

        Uses priority-based eviction: lower-priority models are unloaded
        first, then falls back to least-recently-used within the same
        priority tier.

        Returns
        -------
        bool
            True if cleanup was performed
        """
        if not torch.cuda.is_available():
            return False

        allocated = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        usage_ratio = allocated / total

        if usage_ratio < self._memory_threshold:
            return False

        logger.warning(f"Memory pressure detected: {usage_ratio:.1%} used")

        # Find evictable models (loaded, not critical).
        # Sort by priority (ascending), then by last_used (ascending).
        evictable = [
            (name, info) for name, info in self._models.items()
            if info.is_loaded and info.priority < ModelPriority.CRITICAL
        ]
        evictable.sort(key=lambda x: (x[1].priority, x[1].last_used))

        # Unload models starting from the lowest priority
        cleaned = False
        for name, info in evictable:
            self.unload_model(name)
            cleaned = True

            # Re-check memory
            new_ratio = (
                torch.cuda.memory_allocated()
                / torch.cuda.get_device_properties(0).total_memory
            )
            if new_ratio < self._memory_threshold * 0.7:  # Target 70% of threshold
                break

        return cleaned

    def force_cleanup(self, keep_critical_only: bool = True) -> None:
        """
        Force cleanup of models and clear caches.

        Parameters
        ----------
        keep_critical_only : bool
            If True, only keep CRITICAL priority models loaded
        """
        logger.info("Force cleanup initiated")

        # Unload models based on priority
        threshold = ModelPriority.CRITICAL if keep_critical_only else ModelPriority.HIGH

        for name, info in list(self._models.items()):
            if info.is_loaded and info.priority < threshold:
                self.unload_model(name)

        # Aggressive garbage collection
        for _ in range(5):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.info("Force cleanup completed")

    def update_priority(self, name: str, priority: int) -> bool:
        """
        Update a model's priority level.

        Parameters
        ----------
        name : str
            Model identifier
        priority : int
            New priority level

        Returns
        -------
        bool
            True if the priority was updated
        """
        if name not in self._models:
            return False

        self._models[name].priority = priority
        logger.debug(f"Updated priority for {name} to {priority}")
        return True

    def get_active_pipeline(self) -> Optional[str]:
        """
        Get the name of the currently active pipeline.

        Returns
        -------
        str or None
            Name of the active pipeline, or None if no pipeline is loaded
        """
        return self._active_pipeline

    def switch_to_pipeline(
        self,
        name: str,
        loader: Optional[Callable[[], Any]] = None
    ) -> Any:
        """
        Switch to a different pipeline, unloading the current one.

        This is a convenience method for pipeline switching that handles
        mutual exclusion automatically.

        Parameters
        ----------
        name : str
            Pipeline name to switch to
        loader : callable, optional
            Loader function if the pipeline is not already registered

        Returns
        -------
        Any
            The loaded pipeline instance

        Raises
        ------
        KeyError
            If the pipeline is not registered and no loader is provided
        """
        # Register if needed
        if name not in self._models and loader is not None:
            self.register_model(
                name=name,
                loader=loader,
                priority=ModelPriority.HIGH,
                model_group=self.PIPELINE_GROUP
            )

        # load_model will handle unloading of the current pipeline
        return self.load_model(name, update_priority=ModelPriority.HIGH)

    def get_memory_status(self) -> Dict[str, Any]:
        """
        Get detailed memory status.

        Returns
        -------
        Dict[str, Any]
            Dictionary with memory statistics
        """
        status = {
            "device": self._device,
            "models": {},
            "total_estimated_gb": 0.0
        }

        # Model status
        for name, info in self._models.items():
            status["models"][name] = {
                "loaded": info.is_loaded,
                "critical": info.is_critical,
                "estimated_gb": info.estimated_memory_gb,
                "last_used": info.last_used
            }
            if info.is_loaded:
                status["total_estimated_gb"] += info.estimated_memory_gb

        # GPU memory
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3

            status["gpu"] = {
                "allocated_gb": round(allocated, 2),
                "total_gb": round(total, 2),
                "cached_gb": round(cached, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }

        return status

    def get_loaded_models(self) -> List[str]:
        """Get a list of currently loaded model names."""
        return [name for name, info in self._models.items() if info.is_loaded]

    def is_model_loaded(self, name: str) -> bool:
        """Check whether a specific model is loaded."""
        if name not in self._models:
            return False
        return self._models[name].is_loaded

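# Usage sketch (hypothetical pipelines, not part of the manager): registering
# two pipelines in the same group and switching between them so that only one
# stays resident at a time. The names "sdxl_pipeline" and "flux_pipeline" and
# the lambda loaders are placeholders; real code would pass actual pipeline
# constructors and realistic memory estimates.
def _pipeline_switch_example() -> None:
    manager = ModelManager()

    manager.register_model(
        name="sdxl_pipeline",
        loader=lambda: {"kind": "sdxl"},  # stand-in for a real pipeline factory
        priority=ModelPriority.HIGH,
        estimated_memory_gb=8.0,
        model_group=ModelManager.PIPELINE_GROUP,
    )
    manager.register_model(
        name="flux_pipeline",
        loader=lambda: {"kind": "flux"},  # stand-in for a real pipeline factory
        priority=ModelPriority.MEDIUM,
        estimated_memory_gb=12.0,
        model_group=ModelManager.PIPELINE_GROUP,
    )

    manager.load_model("sdxl_pipeline")
    assert manager.get_active_pipeline() == "sdxl_pipeline"

    # Switching unloads sdxl_pipeline before flux_pipeline is loaded
    manager.switch_to_pipeline("flux_pipeline")
    assert manager.get_loaded_models() == ["flux_pipeline"]
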
# Global singleton instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Get the global ModelManager singleton instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
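
# Minimal smoke test (no GPU required): registers a throwaway model, loads it,
# prints the memory report, and cleans up. The model name "demo_model" and its
# loader are placeholders for illustration only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = get_model_manager()
    manager.register_model(
        name="demo_model",
        loader=lambda: {"kind": "demo"},
        priority=ModelPriority.DISPOSABLE,
        estimated_memory_gb=0.0,
    )
    manager.load_model("demo_model")
    print("Loaded models:", manager.get_loaded_models())
    print("Memory status:", manager.get_memory_status())

    manager.force_cleanup()
    print("After cleanup:", manager.get_loaded_models())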