"""Router-plan generation Space.

Loads router checkpoints (vLLM preferred, Transformers pipelines as a
fallback), optionally augments prompts with DuckDuckGo search snippets, and
streams the generated JSON route plan back to a Gradio front end.
"""

from __future__ import annotations

import json
import math
import os
import re
import threading
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)

try:
    from huggingface_hub import snapshot_download

    HF_HUB_AVAILABLE = True
except ImportError:  # pragma: no cover
    HF_HUB_AVAILABLE = False

try:
    from ddgs import DDGS

    DDGS_AVAILABLE = True
except ImportError:
    DDGS_AVAILABLE = False

# Environment-variable values (after strip + lower-casing) treated as "enabled".
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})

# Allow TF32 tensor-core matmuls (trades a little float32 precision for speed).
torch.backends.cuda.matmul.allow_tf32 = True

# ZeroGPU often exposes MIG UUIDs; keep them unless the variable is empty.
MIG_VISIBLE = False
if torch.cuda.is_available():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if not cuda_visible:
        # NOTE(review): torch.cuda.is_available() has already run, so this
        # presumably only affects child processes (e.g. vLLM workers) — confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        cuda_visible = "0"
        print("CUDA_VISIBLE_DEVICES was empty -> set to 0")
    elif any(device.strip().startswith("MIG") for device in cuda_visible.split(",")):
        # Any MIG slice in the (comma-separated) device list counts.
        MIG_VISIBLE = True
    print(f"CUDA detected: {torch.cuda.get_device_name(0)}")
    # cuda_visible is guaranteed non-empty here, so no fallback is needed.
    print(f"CUDA_VISIBLE_DEVICES: {cuda_visible}")
else:
    print("WARNING: CUDA not available - vLLM will not work")

# Try to import vLLM (primary inference engine).
try:
    from vllm import LLM, SamplingParams

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
    LLM = None
    SamplingParams = None
    print("Warning: vLLM not available, falling back to Transformers")

# AsyncEngineArgs is imported for optional use only; the synchronous ``LLM``
# engine does not need it, so a failure to import it must not mark vLLM as
# unavailable.
AsyncEngineArgs = None
if VLLM_AVAILABLE:
    try:
        from vllm.engine.arg_utils import AsyncEngineArgs
    except ImportError:
        AsyncEngineArgs = None

# Process-wide cancellation flag, polled by the generation loops and cleared at
# the start of each new generation request.
cancel_event = threading.Event()

# Optional flag to disable vLLM. It defaults to enabled on MIG devices because
# device detection is unstable there. Accepts 1/true/yes/on (case-insensitive).
DISABLE_VLLM = (
    os.environ.get("DISABLE_VLLM", "1" if MIG_VISIBLE else "0").strip().lower()
    in _TRUTHY_ENV_VALUES
)

# ---------------------------------------------------------------------------
# Parallel prefetch of model weights/tokenizers to reduce first-load latency
#
--------------------------------------------------------------------------- PREFETCH_DISABLED = os.environ.get("DISABLE_PREFETCH", "0") == "1" PREFETCH_THREADS = int(os.environ.get("PREFETCH_THREADS", "4")) PREFETCH_EXECUTOR = None LOCAL_REPO_CACHE: Dict[str, str] = {} def _prefetch_repo(repo_id: str) -> None: if not HF_HUB_AVAILABLE: return try: snapshot_download( repo_id=repo_id, etag_timeout=10, resume_download=True, local_files_only=False, ) print(f"Prefetched repo: {repo_id}") except Exception as exc: # pragma: no cover print(f"Prefetch skipped for {repo_id}: {exc}") def _ensure_local_repo(repo_id: str) -> Optional[str]: if not HF_HUB_AVAILABLE: return None cached = LOCAL_REPO_CACHE.get(repo_id) if cached and os.path.isdir(cached): return cached try: local_path = snapshot_download( repo_id=repo_id, etag_timeout=10, resume_download=True, local_files_only=False, ) LOCAL_REPO_CACHE[repo_id] = local_path return local_path except Exception as exc: # pragma: no cover print(f"Local snapshot failed for {repo_id}: {exc}") return None def _retrieve_search_results(query: str, max_results: int, max_chars: int) -> List[str]: if not DDGS_AVAILABLE: return [] results: List[str] = [] try: with DDGS() as ddgs: for idx, item in enumerate( islice( ddgs.text( query, region="wt-wt", safesearch="moderate", timelimit="y", ), max_results, ) ): title = (item.get("title") or "Untitled").strip() body = (item.get("body") or "").strip() url = (item.get("href") or "").strip() snippet = body[: max_chars].replace("\n", " ") formatted = f"[{idx+1}] {title} — {snippet}" if url: formatted += f" ({url})" results.append(formatted) except Exception as exc: # pragma: no cover print(f"[DEBUG] DDG search failed: {exc}") return results class CancelStoppingCriteria(StoppingCriteria): def __call__(self, input_ids, scores, **kwargs) -> bool: return cancel_event.is_set() def estimate_gpu_seconds( model_name: str, max_new_tokens: int, enable_search: bool, ) -> float: params_b = MODELS.get(model_name, 
{}).get("params_b", 4.0) base = 12.0 + params_b * 3.0 tokens_per_sec = max(40.0, 320.0 / (1.0 + params_b / 6.0)) generation_time = max_new_tokens / tokens_per_sec search_time = 8.0 if enable_search else 0.0 return base + generation_time + search_time def format_gpu_estimate_message( model_name: str, max_new_tokens: int, enable_search: bool, ) -> Tuple[str, int]: est_seconds = estimate_gpu_seconds(model_name, max_new_tokens, enable_search) rounded = int(math.ceil(est_seconds)) recommended = int(math.ceil(max(60, rounded) / 60.0) * 60) recommended = max(60, min(1800, recommended)) model_size = MODELS.get(model_name, {}).get("params_b", 4.0) message = ( f"âąī¸ **Estimated GPU Time:** ~{rounded} seconds\n\n" f"📊 **Model Size:** {model_size:.1f}B parameters\n" f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n" f"✅ **Suggested GPU Duration slider:** {recommended} seconds" ) return message, recommended def update_gpu_controls( model_name: str, max_new_tokens: int, enable_search: bool, current_duration: int, ): message, recommended = format_gpu_estimate_message( model_name, max_new_tokens, enable_search, ) updated_value = current_duration if current_duration >= recommended else recommended return message, gr.update(value=updated_value) def _start_prefetch_workers(model_names: list[str]): global PREFETCH_EXECUTOR if PREFETCH_DISABLED or not HF_HUB_AVAILABLE: return if PREFETCH_EXECUTOR is not None: return if not model_names: return worker_count = max(1, min(PREFETCH_THREADS, len(model_names) * 2)) PREFETCH_EXECUTOR = ThreadPoolExecutor(max_workers=worker_count, thread_name_prefix="prefetch") submitted = set() for model_name in model_names: repos = {MODELS[model_name]["repo_id"]} tokenizer_repo = MODELS[model_name].get("tokenizer_repo") if tokenizer_repo: repos.add(tokenizer_repo) for repo in repos: if repo in submitted: continue submitted.add(repo) PREFETCH_EXECUTOR.submit(_prefetch_repo, repo) MODELS = { "Router-Gemma3-27B-AWQ": { "repo_id": 
"Alovestocode/router-gemma3-merged-awq", # AWQ quantized model "tokenizer_repo": "Alovestocode/router-gemma3-merged", # Tokenizer from original repo "description": "Router checkpoint on Gemma3 27B merged, optimized with AWQ quantization via vLLM.", "params_b": 27.0, "quantization": "awq", # vLLM will auto-detect AWQ }, "Router-Qwen3-32B-AWQ": { "repo_id": "Alovestocode/router-qwen3-32b-merged-awq", # AWQ quantized model "tokenizer_repo": "Alovestocode/router-qwen3-32b-merged", # Tokenizer from original repo "description": "Router checkpoint on Qwen3 32B merged, optimized with AWQ quantization via vLLM.", "params_b": 32.0, "quantization": "awq", # vLLM will auto-detect AWQ }, } DEFAULT_MODEL = os.environ.get("DEFAULT_ROUTER_MODEL", "Router-Gemma3-27B-AWQ") if DEFAULT_MODEL not in MODELS: DEFAULT_MODEL = next(iter(MODELS)) if MODELS else None def _resolve_prefetch_model_names(include_default: bool) -> list[str]: entries = os.environ.get("ROUTER_PREFETCH_MODELS") if entries: names = [item.strip() for item in entries.split(",") if item.strip()] else: single = os.environ.get("ROUTER_PREFETCH_MODEL") names = [single] if single else [] if names == ["ALL"] or names == ["all"]: names = list(MODELS.keys()) valid = [name for name in names if name in MODELS] if not valid and include_default and DEFAULT_MODEL: valid = [DEFAULT_MODEL] return valid _start_prefetch_workers(_resolve_prefetch_model_names(include_default=True)) # Try to import LLM Compressor (for quantization - optional, vLLM has native AWQ support) # Note: llm-compressor is only needed for quantizing models, not for loading pre-quantized AWQ models # vLLM can load AWQ models natively without llm-compressor try: # Try both package names (llm-compressor and llmcompressor) try: from llmcompressor import oneshot # Correct import path: AWQModifier is in modifiers.awq, not modifiers.quantization from llmcompressor.modifiers.awq import AWQModifier except ImportError: # Try alternative package name import sys import 
subprocess # Package might be named llm-compressor (with hyphen) try: import importlib.util spec = importlib.util.find_spec("llm_compressor") if spec is None: raise ImportError("llm-compressor not found") from llm_compressor import oneshot from llm_compressor.modifiers.awq import AWQModifier except ImportError: raise ImportError("Neither llmcompressor nor llm-compressor found") LLM_COMPRESSOR_AVAILABLE = True print("Info: LLM Compressor available (for quantizing models)") except ImportError: LLM_COMPRESSOR_AVAILABLE = False # This is fine - vLLM has native AWQ support, so we don't need llm-compressor for loading print("Info: LLM Compressor not available (not needed - vLLM has native AWQ support for pre-quantized models)") # Try to import AWQ (deprecated, but kept for fallback compatibility) # Note: AutoAWQ is deprecated; vLLM handles AWQ natively via llm-compressor try: from awq import AutoAWQForCausalLM AWQ_AVAILABLE = True import warnings warnings.filterwarnings("ignore", category=DeprecationWarning, module="awq") except ImportError: AWQ_AVAILABLE = False print("Info: AutoAWQ not available (using vLLM native AWQ support instead)") # Always import BitsAndBytesConfig for fallback try: from transformers import BitsAndBytesConfig BITSANDBYTES_AVAILABLE = True except ImportError: BITSANDBYTES_AVAILABLE = False BitsAndBytesConfig = None print("Warning: BitsAndBytes not available") # Try to import FlashAttention-2 try: import flash_attn FLASH_ATTN_AVAILABLE = True except ImportError: FLASH_ATTN_AVAILABLE = False print("Warning: FlashAttention-2 not available") HF_TOKEN = os.environ.get("HF_TOKEN") if not HF_TOKEN: raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.") PLAN_END_TOKEN = "<|end_of_plan|>" STOP_SEQUENCES = [PLAN_END_TOKEN, "", ""] ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit EXACTLY ONE strict JSON object with keys route_plan, route_rationale, 
expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nRules:\n- No markdown/code fences, no natural-language prologues or epilogues.\n- route_plan must be an ordered list of tool invocations such as /math(...), /code(...), /general-search(...).\n- todo_list must map each checklist item to the responsible tool.\n- metrics must include primary and secondary arrays (add optional *_guidance fields when they exist).\n- After the closing brace of the JSON object, immediately append the sentinel <|end_of_plan|>.\nExample output:\n{\n "route_plan": ["/general-search(...)"],\n "route_rationale": "...",\n ...\n}<|end_of_plan|>\nReturn nothing else.""" REQUIRED_KEYS = [ "route_plan", "route_rationale", "expected_artifacts", "thinking_outline", "handoff_plan", "todo_list", "difficulty", "tags", "acceptance_criteria", "metrics", ] PIPELINES: Dict[str, Any] = {} # For Transformers fallback VLLM_MODELS: Dict[str, Any] = {} # For vLLM models TOKENIZER_CACHE: Dict[str, Any] = {} WARMED_REMAINING = False TOOL_PATTERN = re.compile(r"^/[a-z0-9_-]+\(.*\)$", re.IGNORECASE) def get_tokenizer(repo: str, tokenizer_repo: str = None): """Get tokenizer, preferring tokenizer_repo if provided (for AWQ models).""" # Use tokenizer_repo if provided (for AWQ models where tokenizer is in original repo) actual_repo = tokenizer_repo if tokenizer_repo else repo tok = TOKENIZER_CACHE.get(actual_repo) if tok is not None: return tok tok = AutoTokenizer.from_pretrained( actual_repo, token=HF_TOKEN, use_fast=True, trust_remote_code=True ) tok.padding_side = "left" tok.truncation_side = "left" if tok.pad_token_id is None and tok.eos_token_id is not None: tok.pad_token_id = tok.eos_token_id TOKENIZER_CACHE[actual_repo] = tok return tok def load_vllm_model(model_name: str): """Load model with vLLM (supports AWQ natively, continuous batching, PagedAttention).""" if model_name in VLLM_MODELS: return VLLM_MODELS[model_name] model_config = 
MODELS[model_name] repo = model_config["repo_id"] quantization = model_config.get("quantization", None) # For AWQ models, vLLM should point to repo root (not default/ subfolder) # If repo is stored with AWQ artifacts inside a default/ directory, fall back to local snapshot if quantization == "awq": model_path = repo local_repo = _ensure_local_repo(repo) if local_repo: default_dir = os.path.join(local_repo, "default") model_path = default_dir if os.path.isdir(default_dir) else local_repo print(f"Loading {model_path} (local snapshot) with vLLM (AWQ quantization)...") else: print(f"Loading {model_path} with vLLM (AWQ quantization, vLLM will find files in default/ via quantization_config.json)...") else: model_path = repo print(f"Loading {model_path} with vLLM (quantization: {quantization})...") try: # Detect device explicitly for vLLM # vLLM needs explicit device configuration on ZeroGPU if not torch.cuda.is_available(): raise RuntimeError("CUDA not available - vLLM requires GPU. Falling back to Transformers pipeline.") print(f" → CUDA available: {torch.cuda.get_device_name(0)}") print(f" → CUDA device count: {torch.cuda.device_count()}") print(f" → CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')}") # vLLM configuration optimized for ZeroGPU H200 slice # vLLM natively supports AWQ via llm-compressor (replaces deprecated AutoAWQ) # Note: HF_TOKEN is passed via environment variable, not as a parameter # vLLM auto-detects CUDA from torch.cuda.is_available() and CUDA_VISIBLE_DEVICES # For AWQ models with files in default/ subfolder, vLLM should auto-detect via quantization_config.json llm_kwargs = { "model": model_path, # Use model_path which may point to default/ subfolder "trust_remote_code": True, "dtype": "bfloat16", # Prefer bf16 over int8 for speed "gpu_memory_utilization": 0.90, # Leave headroom for KV cache "max_model_len": 16384, # Adjust based on GPU memory "enable_chunked_prefill": True, # Better for long prompts "tensor_parallel_size": 
1, # Single GPU for ZeroGPU "max_num_seqs": 256, # Continuous batching capacity "enable_prefix_caching": True, # Cache prompts for faster TTFT } # Ensure CUDA_VISIBLE_DEVICES is set correctly for vLLM device detection # ZeroGPU exposes MIG UUIDs; keep them unless the variable is empty cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") if not cuda_visible: os.environ["CUDA_VISIBLE_DEVICES"] = "0" cuda_visible = "0" print(" → CUDA_VISIBLE_DEVICES was empty, set to 0") try: if hasattr(torch.cuda, '_lazy_init'): torch.cuda._lazy_init() except Exception: pass else: print(f" → CUDA_VISIBLE_DEVICES retained: {cuda_visible}") # Force torch to see the correct device after ensuring CUDA_VISIBLE_DEVICES if torch.cuda.is_available(): device_name = torch.cuda.get_device_name(0) print(f" → Verified CUDA device accessible: {device_name}") torch.cuda.set_device(0) print(" → Set torch.cuda default device to 0") # Disable Ray executor on ZeroGPU to simplify device handling os.environ.setdefault("VLLM_USE_RAY", "0") os.environ.setdefault("VLLM_WORKER_USE_RAY", "0") # Add quantization if specified (vLLM auto-detects AWQ via llm-compressor) if quantization == "awq": llm_kwargs["quantization"] = "awq" # AWQ model files are in the 'default' subfolder # vLLM should auto-detect this via quantization_config.json at repo root # If auto-detection fails, we can explicitly point to default/ subfolder # Enable FP8 KV cache for 50% memory reduction (allows longer contexts) # FP8 KV cache is compatible with AWQ quantization try: llm_kwargs["kv_cache_dtype"] = "fp8" print(f" → AWQ quantization + FP8 KV cache enabled (vLLM native support)") print(f" → FP8 KV cache reduces memory by ~50%, enabling longer contexts") print(f" → Loading AWQ model from: {model_path} (files in default/ subfolder)") except Exception: # Fallback if FP8 KV cache not supported print(f" → AWQ quantization enabled (FP8 KV cache not available)") print(f" → Loading AWQ model from: {model_path} (files in default/ 
subfolder)") elif quantization == "fp8": # Try FP8 quantization if available (faster than AWQ) try: llm_kwargs["quantization"] = "fp8" llm_kwargs["dtype"] = "float8_e5m2" print(f" → FP8 quantization enabled (~2x faster than AWQ)") except Exception: print(f" → FP8 quantization not available, falling back to bf16") # vLLM will now detect the CUDA device via torch / environment settings above print(f" → Loading with vLLM (continuous batching, PagedAttention)...") llm = LLM(**llm_kwargs) VLLM_MODELS[model_name] = llm print(f"✅ vLLM model loaded: {model_name}") print(f" - Continuous batching: enabled (max {llm_kwargs['max_num_seqs']} concurrent)") print(f" - Prefix caching: enabled") print(f" - Quantization: {quantization or 'none (bf16)'}") return llm except Exception as exc: print(f"❌ vLLM load failed for {repo}: {exc}") import traceback traceback.print_exc() raise def load_awq_pipeline(repo: str, tokenizer): """Load AWQ-quantized model with FlashAttention-2 and torch.compile (Transformers fallback).""" model = AutoAWQForCausalLM.from_quantized( repo, fuse_layers=True, trust_remote_code=True, device_map="auto", token=HF_TOKEN, ) # Prepare model kwargs with FlashAttention-2 if available model_kwargs = {} if FLASH_ATTN_AVAILABLE: model_kwargs["attn_implementation"] = "flash_attention_2" pipe = pipeline( task="text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True, device_map="auto", model_kwargs=model_kwargs, use_cache=True, torch_dtype=torch.bfloat16, # Prefer bf16 over int8 for speed ) pipe.model.eval() # Apply torch.compile for kernel fusion (~10-20% speedup after first call) try: if hasattr(torch, 'compile'): print("Applying torch.compile for kernel fusion...") pipe.model = torch.compile(pipe.model, mode="reduce-overhead") print("✅ torch.compile applied (first call will be slower, subsequent calls faster)") except Exception as exc: print(f"âš ī¸ torch.compile failed: {exc} (continuing without compilation)") return pipe def 
load_pipeline(model_name: str): """Load model with vLLM (preferred) or Transformers (fallback). Fallback chain: 1. vLLM with AWQ (best performance, continuous batching) 2. vLLM with FP16 (if AWQ not available) 3. Transformers with AWQ (via AutoAWQ - deprecated but functional) 4. Transformers with BitsAndBytes 8-bit 5. Transformers with FP16/FP32 """ # Try vLLM first (best performance with native AWQ support via llm-compressor) # vLLM handles AWQ natively, so AutoAWQ deprecation doesn't affect us if VLLM_AVAILABLE and not DISABLE_VLLM: try: print(f"🔄 Attempting to load {model_name} with vLLM (native AWQ support)...") return load_vllm_model(model_name) except Exception as exc: print(f"âš ī¸ vLLM load failed: {exc}") print(f" → Falling back to Transformers pipeline...") import traceback traceback.print_exc() # Fallback to Transformers pipeline if model_name in PIPELINES: print(f"✅ Using cached Transformers pipeline for {model_name}") return PIPELINES[model_name] if DISABLE_VLLM and VLLM_AVAILABLE: print("âš ī¸ vLLM disabled for this deployment (DISABLE_VLLM=1 or MIG device detected)") model_config = MODELS[model_name] repo = model_config["repo_id"] tokenizer_repo = model_config.get("tokenizer_repo", None) quantization = model_config.get("quantization", None) # For AWQ models, the AWQ repo doesn't have standard model files (they're in default/) # Use the original repo for Transformers fallback, not the AWQ repo if quantization == "awq" and tokenizer_repo: # AWQ repos have files in default/ subfolder which Transformers can't load directly # Use the original repo for Transformers fallback transformers_repo = tokenizer_repo # Use original repo for Transformers print(f"âš ī¸ AWQ model detected - Transformers fallback will use original repo: {transformers_repo}") else: transformers_repo = repo tokenizer = get_tokenizer(repo, tokenizer_repo=tokenizer_repo) # Try AWQ first if available (Transformers fallback path) if AWQ_AVAILABLE: try: print(f"🔄 Loading 
{transformers_repo} with Transformers + AutoAWQ (fallback path)...") pipe = load_awq_pipeline(transformers_repo, tokenizer) PIPELINES[model_name] = pipe _schedule_background_warm(model_name) # Warm kernels immediately after loading Thread(target=lambda: _warm_kernels(model_name), daemon=True).start() print(f"✅ Transformers + AutoAWQ pipeline loaded: {model_name}") return pipe except Exception as exc: print(f"âš ī¸ AutoAWQ load failed for {transformers_repo}: {exc}") print(f" → Falling back to BitsAndBytes 8-bit...") # Fallback to BitsAndBytes 8-bit if BITSANDBYTES_AVAILABLE: try: print(f"🔄 Loading {transformers_repo} with BitsAndBytes 8-bit quantization...") quant_config = BitsAndBytesConfig(load_in_8bit=True) model_kwargs = {"quantization_config": quant_config} if FLASH_ATTN_AVAILABLE: model_kwargs["attn_implementation"] = "flash_attention_2" pipe = pipeline( task="text-generation", model=transformers_repo, tokenizer=tokenizer, trust_remote_code=True, device_map="auto", model_kwargs=model_kwargs, use_cache=True, token=HF_TOKEN, torch_dtype=torch.bfloat16, ) pipe.model.eval() # Apply torch.compile for kernel fusion (~10-20% speedup after first call) try: if hasattr(torch, 'compile'): pipe.model = torch.compile(pipe.model, mode="reduce-overhead") except Exception: pass PIPELINES[model_name] = pipe _schedule_background_warm(model_name) print(f"✅ BitsAndBytes 8-bit pipeline loaded: {model_name}") return pipe except Exception as exc: print(f"âš ī¸ BitsAndBytes 8-bit load failed for {transformers_repo}: {exc}") print(f" → Falling back to FP16/FP32...") # Fallback to bfloat16/fp16/fp32 (unquantized) for dtype in (torch.bfloat16, torch.float16, torch.float32): dtype_name = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"}[dtype] try: print(f"🔄 Loading {transformers_repo} with {dtype_name} precision...") model_kwargs = {} if FLASH_ATTN_AVAILABLE: model_kwargs["attn_implementation"] = "flash_attention_2" pipe = pipeline( 
task="text-generation", model=transformers_repo, tokenizer=tokenizer, trust_remote_code=True, device_map="auto", dtype=dtype, model_kwargs=model_kwargs, use_cache=True, token=HF_TOKEN, ) pipe.model.eval() # Apply torch.compile for kernel fusion try: if hasattr(torch, 'compile'): pipe.model = torch.compile(pipe.model, mode="reduce-overhead") except Exception: pass PIPELINES[model_name] = pipe _schedule_background_warm(model_name) print(f"✅ {dtype_name} pipeline loaded: {model_name}") return pipe except Exception as exc: print(f"âš ī¸ {dtype_name} load failed: {exc}") continue # Final fallback (no quantization, no FlashAttention) print(f"âš ī¸ All quantization methods failed, using basic pipeline...") model_kwargs = {} if FLASH_ATTN_AVAILABLE: model_kwargs["attn_implementation"] = "flash_attention_2" pipe = pipeline( task="text-generation", model=transformers_repo, tokenizer=tokenizer, trust_remote_code=True, device_map="auto", model_kwargs=model_kwargs, use_cache=True, token=HF_TOKEN, ) pipe.model.eval() # Apply torch.compile for kernel fusion try: if hasattr(torch, 'compile'): pipe.model = torch.compile(pipe.model, mode="reduce-overhead") except Exception: pass PIPELINES[model_name] = pipe _schedule_background_warm(model_name) print(f"✅ Basic pipeline loaded: {model_name}") return pipe def _warm_kernels(model_name: str) -> None: """Warm up CUDA kernels with a small dummy generation.""" try: # Check if using vLLM if VLLM_AVAILABLE and model_name in VLLM_MODELS: llm = VLLM_MODELS[model_name] # vLLM handles warmup internally, but we can trigger a small generation sampling_params = SamplingParams(temperature=0.0, max_tokens=2) _ = llm.generate("test", sampling_params) print(f"vLLM kernels warmed for {model_name}") return # Transformers pipeline warmup pipe = PIPELINES.get(model_name) if pipe is None: return tokenizer = pipe.tokenizer # Create a minimal prompt for warmup warmup_text = "test" inputs = tokenizer(warmup_text, return_tensors="pt") if hasattr(pipe.model, 
'device'): inputs = {k: v.to(pipe.model.device) for k, v in inputs.items()} elif torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} # Run a tiny generation to JIT-fuse kernels with torch.inference_mode(): _ = pipe.model.generate( **inputs, max_new_tokens=2, do_sample=False, use_cache=True, ) print(f"Transformers kernels warmed for {model_name}") except Exception as exc: print(f"Kernel warmup failed for {model_name}: {exc}") def _schedule_background_warm(loaded_model: str) -> None: global WARMED_REMAINING if WARMED_REMAINING: return warm_remaining = os.environ.get("ROUTER_WARM_REMAINING", "1") if warm_remaining not in {"1", "true", "True"}: return # Check both PIPELINES and VLLM_MODELS for remaining models loaded_models = set(PIPELINES.keys()) | set(VLLM_MODELS.keys()) remaining = [name for name in MODELS if name not in loaded_models] if not remaining: WARMED_REMAINING = True return def _warm_all(): for name in remaining: try: print(f"Background warm start for {name}") load_pipeline(name) # Warm kernels after loading _warm_kernels(name) except Exception as exc: # pragma: no cover print(f"Warm start failed for {name}: {exc}") WARMED_REMAINING = True Thread(target=_warm_all, daemon=True).start() def build_router_prompt( user_task: str, context: str, acceptance: str, extra_guidance: str, difficulty: str, tags: str, ) -> str: prompt_parts = [ROUTER_SYSTEM_PROMPT.strip(), "\n### Router Inputs\n"] prompt_parts.append(f"Difficulty: {difficulty or 'intermediate'}") prompt_parts.append(f"Tags: {tags or 'general'}") if acceptance.strip(): prompt_parts.append(f"Acceptance criteria: {acceptance.strip()}") if extra_guidance.strip(): prompt_parts.append(f"Additional guidance: {extra_guidance.strip()}") if context.strip(): prompt_parts.append("\n### Supporting context\n" + context.strip()) prompt_parts.append("\n### User task\n" + user_task.strip()) prompt_parts.append("\nReturn only JSON.") return "\n".join(prompt_parts) def extract_json_from_text(text: 
str) -> str: start = text.find("{") if start == -1: raise ValueError("Router output did not contain a JSON object.") depth = 0 in_string = False escape = False for idx in range(start, len(text)): ch = text[idx] if in_string: if escape: escape = False elif ch == "\\": escape = True elif ch == '"': in_string = False continue if ch == '"': in_string = True continue if ch == '{': depth += 1 elif ch == '}': depth -= 1 if depth == 0: return text[start : idx + 1] raise ValueError("Router output JSON appears truncated.") def trim_at_stop_sequences(text: str) -> Tuple[str, bool]: """Trim text at stop sequences and return trimmed text and whether a stop was found.""" earliest = None for stop in STOP_SEQUENCES: idx = text.find(stop) if idx != -1 and (earliest is None or idx < earliest): earliest = idx if earliest is not None: return text[:earliest], True return text, False def is_function_call(text: str) -> bool: return bool(TOOL_PATTERN.match(text.strip())) def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]: issues: List[str] = [] for key in REQUIRED_KEYS: if key not in plan: issues.append(f"Missing key: {key}") route_plan = plan.get("route_plan") if isinstance(route_plan, str) and is_function_call(route_plan): plan["route_plan"] = [route_plan] route_plan = plan["route_plan"] if not isinstance(route_plan, list) or not route_plan: issues.append("route_plan must be a non-empty list of tool calls") else: cleaned: List[str] = [] for entry in route_plan: if isinstance(entry, str) and is_function_call(entry.strip().strip("'\"")): cleaned.append(entry.strip().strip("'\"")) else: issues.append(f"route_plan entry is not a tool call: {entry}") if cleaned: plan["route_plan"] = cleaned metrics = plan.get("metrics") if not isinstance(metrics, dict): issues.append("metrics must be an object containing primary/secondary entries") todo = plan.get("todo_list") if not isinstance(todo, list) or not todo: issues.append("todo_list must contain at least one checklist item") 
else: cleaned_todo: List[str] = [] for entry in todo: if isinstance(entry, str): text = entry.strip() if not text.startswith("- ["): text = text.lstrip("- ") text = f"- [ ] {text}" cleaned_todo.append(text) else: issues.append("todo_list entry must be a string") if cleaned_todo: plan["todo_list"] = cleaned_todo return len(issues) == 0, issues def format_validation_message(ok: bool, issues: List[str]) -> str: if ok: return "✅ Router plan includes all required fields." bullets = "\n".join(f"- {issue}" for issue in issues) return f"❌ Issues detected:\n{bullets}" def _generate_router_plan_streaming_internal( user_task: str, context: str, acceptance: str, extra_guidance: str, difficulty: str, tags: str, model_choice: str, max_new_tokens: int, temperature: float, top_p: float, gpu_duration: int, enable_search: bool, search_max_results: int, search_max_chars: int, search_timeout: float, ): """Internal generator function for streaming token output.""" if not user_task.strip(): yield "", {}, "❌ User task is required.", "" return if model_choice not in MODELS: yield "", {}, f"❌ Invalid model choice: {model_choice}. 
Available: {list(MODELS.keys())}", "" return cancel_event.clear() cancelled = False try: search_snippets: List[str] = [] if enable_search and DDGS_AVAILABLE and user_task.strip(): search_snippets_holder: List[str] = [] search_error: Optional[Exception] = None def _fetch_search(): nonlocal search_error try: results = _retrieve_search_results( user_task, max(1, int(search_max_results)), max(30, int(search_max_chars)), ) search_snippets_holder.extend(results) except Exception as exc: # pragma: no cover search_error = exc search_thread = Thread(target=_fetch_search, daemon=True) search_thread.start() search_thread.join(timeout=float(max(0.5, search_timeout))) if search_thread.is_alive(): print("[DEBUG] Search thread timed out; continuing without results.") if search_error: print(f"[DEBUG] Search error: {search_error}") search_snippets = search_snippets_holder context_for_prompt = context if search_snippets: search_block = "\n".join(f"- {snippet}" for snippet in search_snippets) addendum = ( "\n\n# Web Search Findings\n" "Use the following snippets as supplementary evidence. 
" "Cite them as needed in the generated plan.\n" f"{search_block}" ) context_for_prompt = (context_for_prompt or "").rstrip() + addendum prompt = build_router_prompt( user_task=user_task, context=context_for_prompt, acceptance=acceptance, extra_guidance=extra_guidance, difficulty=difficulty, tags=tags, ) print(f"[DEBUG] Loading model: {model_choice}") generator = load_pipeline(model_choice) print(f"[DEBUG] Model loaded successfully: {type(generator)}") # Check if using vLLM or Transformers is_vllm = VLLM_AVAILABLE and isinstance(generator, LLM) if is_vllm: # Use vLLM streaming API with continuous batching # Optimized sampling parameters for router plan generation sampling_params = SamplingParams( temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, stop=STOP_SEQUENCES, skip_special_tokens=False, # Keep special tokens for parsing spaces_between_special_tokens=False, # Don't add spaces around special tokens include_stop_str_in_output=False, # Don't include stop sequences in output ) # vLLM streaming generation (non-blocking, continuous batching) completion = "" parsed_plan: Dict[str, Any] | None = None validation_msg = "🔄 Generating..." 
# vLLM's generate with stream=True returns RequestOutput iterator # Each RequestOutput contains incremental text updates stream = generator.generate(prompt, sampling_params, stream=True) prev_text_len = 0 for request_output in stream: if cancel_event.is_set(): cancelled = True try: if hasattr(generator, "abort_request"): generator.abort_request(request_output.request_id) except Exception: pass break if not request_output.outputs: continue # Get the latest output (vLLM provides incremental updates) output = request_output.outputs[0] current_text = output.text # Extract only new tokens since last update if len(current_text) > prev_text_len: new_text = current_text[prev_text_len:] completion += new_text prev_text_len = len(current_text) chunk = completion finished = False display_plan = parsed_plan or {} chunk, finished = trim_at_stop_sequences(chunk) try: json_block = extract_json_from_text(chunk) candidate_plan = json.loads(json_block) ok, issues = validate_router_plan(candidate_plan) validation_msg = format_validation_message(ok, issues) parsed_plan = candidate_plan if ok else parsed_plan display_plan = candidate_plan except Exception: # Ignore until JSON is complete pass yield chunk, display_plan, validation_msg, prompt if finished: completion = chunk break # Check if generation is finished if request_output.finished: break else: # Use Transformers pipeline (fallback) # Get the underlying model and tokenizer model = generator.model tokenizer = generator.tokenizer # Set up streaming streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # Prepare inputs inputs = tokenizer(prompt, return_tensors="pt") if hasattr(model, 'device'): inputs = {k: v.to(model.device) for k, v in inputs.items()} elif torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} # Start generation in a separate thread generation_kwargs = { **inputs, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "do_sample": True, 
"streamer": streamer, "eos_token_id": tokenizer.eos_token_id, "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id, "stopping_criteria": StoppingCriteriaList([CancelStoppingCriteria()]), } generation_error = None def _generate(): nonlocal generation_error try: with torch.inference_mode(): model.generate(**generation_kwargs) except Exception as e: generation_error = e print(f"[DEBUG] Generation thread error: {e}") import traceback traceback.print_exc() thread = Thread(target=_generate) thread.start() # Stream tokens completion = "" parsed_plan: Dict[str, Any] | None = None validation_msg = "🔄 Generating..." print(f"[DEBUG] Starting to consume streamer...") token_count = 0 try: for new_text in streamer: if cancel_event.is_set(): cancelled = True break if generation_error: raise generation_error if new_text: token_count += 1 completion += new_text chunk = completion finished = False display_plan = parsed_plan or {} chunk, finished = trim_at_stop_sequences(chunk) try: json_block = extract_json_from_text(chunk) candidate_plan = json.loads(json_block) ok, issues = validate_router_plan(candidate_plan) validation_msg = format_validation_message(ok, issues) parsed_plan = candidate_plan if ok else parsed_plan display_plan = candidate_plan except Exception: # Ignore until JSON is complete pass yield chunk, display_plan, validation_msg, prompt if finished: completion = chunk break print(f"[DEBUG] Streamer finished. 
Received {token_count} tokens.") except Exception as stream_error: print(f"[DEBUG] Streamer error: {stream_error}") import traceback traceback.print_exc() # Wait for thread to finish thread.join(timeout=5.0) if generation_error: raise generation_error raise stream_error # Final processing after streaming completes thread.join(timeout=30.0) if thread.is_alive(): print("[DEBUG] WARNING: Generation thread still running after timeout") if generation_error: raise generation_error completion = trim_at_stop_sequences(completion.strip())[0] print(f"[DEBUG] Final completion length: {len(completion)}") if cancelled: validation_msg = "âšī¸ Generation cancelled by user." elif not completion: print("[DEBUG] WARNING: Completion is empty - model may not have generated output") validation_msg = "âš ī¸ Model generated empty output. Check GPU allocation and model loading." elif parsed_plan is None: try: json_block = extract_json_from_text(completion) parsed_plan = json.loads(json_block) ok, issues = validate_router_plan(parsed_plan) validation_msg = format_validation_message(ok, issues) except Exception as exc: parsed_plan = {} validation_msg = f"❌ JSON parsing failed: {exc}" print(f"[DEBUG] JSON parsing error: {exc}") yield completion, parsed_plan, validation_msg, prompt except Exception as exc: import traceback print(f"[DEBUG] Exception in generation: {exc}") print(f"[DEBUG] Traceback: {traceback.format_exc()}") error_msg = f"❌ Generation failed: {str(exc)}" yield "", {}, error_msg, "" # Pre-create GPU wrappers for common durations at module load time # This ensures spaces.GPU decorators are detected during startup _GPU_WRAPPERS: Dict[int, Any] = {} # Create wrappers for durations: 60, 120, 180, 240, 300, 360, 420, 480, 540, 600, # 720, 840, 960, 1080, 1200, 1320, 1440, 1560, 1680, 1800 (every 60s from 60 to 1800) def _make_gpu_wrapper(duration: int): """Factory function to create GPU-decorated wrapper with closure over duration.""" @spaces.GPU(duration=duration) def wrapper( 
user_task: str, context: str, acceptance: str, extra_guidance: str, difficulty: str, tags: str, model_choice: str, max_new_tokens: int, temperature: float, top_p: float, gpu_duration: int, enable_search: bool, search_max_results: int, search_max_chars: int, search_timeout: float, ): yield from _generate_router_plan_streaming_internal( user_task, context, acceptance, extra_guidance, difficulty, tags, model_choice, max_new_tokens, temperature, top_p, duration, enable_search, search_max_results, search_max_chars, search_timeout, ) return wrapper # Pre-create all wrappers at module load time for duration in range(60, 1801, 60): _GPU_WRAPPERS[duration] = _make_gpu_wrapper(duration) def generate_router_plan_streaming( user_task: str, context: str, acceptance: str, extra_guidance: str, difficulty: str, tags: str, model_choice: str, max_new_tokens: int, temperature: float, top_p: float, gpu_duration: int = 600, enable_search: bool = False, search_max_results: int = 4, search_max_chars: int = 120, search_timeout: float = 5.0, ): """ Generate router plan with streaming output. Uses user-specified gpu_duration to select the appropriate GPU wrapper. """ # Round to nearest 60 seconds and clamp between 60 and 1800 rounded_duration = ((gpu_duration + 30) // 60) * 60 rounded_duration = max(60, min(1800, rounded_duration)) # Get the pre-created wrapper with this duration wrapper = _GPU_WRAPPERS[rounded_duration] yield from wrapper( user_task, context, acceptance, extra_guidance, difficulty, tags, model_choice, max_new_tokens, temperature, top_p, rounded_duration, enable_search, int(search_max_results), int(search_max_chars), float(search_timeout), ) def clear_outputs(): return "", {}, "Awaiting generation.", "" def cancel_generation(): cancel_event.set() return "âšī¸ Cancel request sent. Finishing current step..." def build_ui(): description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans." 
initial_estimate_text, initial_recommended_duration = format_gpu_estimate_message( DEFAULT_MODEL, 16000, False, ) with gr.Blocks(theme=gr.themes.Soft(), css=""" textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; } .status-ok { color: #0d9488; font-weight: 600; } .status-bad { color: #dc2626; font-weight: 600; } """) as demo: gr.Markdown("# đŸ›°ī¸ Router Control Room — ZeroGPU" ) gr.Markdown(description) with gr.Row(): with gr.Column(scale=3): user_task = gr.Textbox( label="User Task / Problem Statement", placeholder="Describe the homework-style query that needs routing...", lines=8, value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.", ) context = gr.Textbox( label="Supporting Context (optional)", placeholder="Paste any retrieved evidence, PDFs, or rubric notes.", lines=4, ) acceptance = gr.Textbox( label="Acceptance Criteria", placeholder="Bullet list of 'definition of done' checks.", lines=3, value="- Provide citations for every claim.\n- Ensure /math verifies /code output.", ) extra_guidance = gr.Textbox( label="Additional Guidance", placeholder="Special constraints, tools to avoid, etc.", lines=3, ) with gr.Column(scale=2): model_choice = gr.Dropdown( label="Router Checkpoint", choices=list(MODELS.keys()), value=DEFAULT_MODEL, allow_custom_value=False, ) difficulty = gr.Radio( label="Difficulty Tier", choices=["introductory", "intermediate", "advanced"], value="advanced", interactive=True, ) tags = gr.Textbox( label="Tags", placeholder="Comma-separated e.g. 
calculus, optimization, python", value="calculus, optimization, python", ) max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens") temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p") enable_search = gr.Checkbox( label="Enable DuckDuckGo Web Search", value=False, interactive=DDGS_AVAILABLE, info="Augment context with live snippets." if DDGS_AVAILABLE else "Install 'ddgs' package to enable search.", ) with gr.Accordion("Web Search Settings", open=False, visible=DDGS_AVAILABLE) as search_settings: search_max_results = gr.Slider( minimum=1, maximum=10, value=4, step=1, label="Search Results", interactive=DDGS_AVAILABLE, ) search_max_chars = gr.Slider( minimum=50, maximum=400, value=160, step=10, label="Max Characters per Result", interactive=DDGS_AVAILABLE, ) search_timeout = gr.Slider( minimum=1.0, maximum=20.0, value=5.0, step=0.5, label="Search Timeout (seconds)", interactive=DDGS_AVAILABLE, ) gpu_estimate_display = gr.Markdown( value=initial_estimate_text, elem_classes="status-ok", ) gpu_duration = gr.Slider( 60, 1800, value=initial_recommended_duration, step=60, label="GPU Duration (seconds)", info="Maximum GPU time allocation for this request", ) with gr.Row(): generate_btn = gr.Button("Generate Router Plan", variant="primary", scale=1) clear_btn = gr.Button("Clear", variant="secondary", scale=1) cancel_btn = gr.Button("Cancel", variant="stop", scale=1) with gr.Row(): raw_output = gr.Textbox(label="Raw Model Output", lines=12) plan_json = gr.JSON(label="Parsed Router Plan") validation_msg = gr.Markdown("Awaiting generation.") prompt_view = gr.Textbox(label="Full Prompt", lines=10) generate_btn.click( generate_router_plan_streaming, inputs=[ user_task, context, acceptance, extra_guidance, difficulty, tags, model_choice, max_new_tokens, temperature, top_p, gpu_duration, enable_search, search_max_results, search_max_chars, search_timeout, ], 
outputs=[raw_output, plan_json, validation_msg, prompt_view], show_progress="full", api_name="/generate_router_plan_streaming", ) clear_btn.click( fn=clear_outputs, outputs=[raw_output, plan_json, validation_msg, prompt_view], api_name="/clear_outputs", ) cancel_btn.click( fn=cancel_generation, outputs=[validation_msg], ) model_choice.change( fn=update_gpu_controls, inputs=[model_choice, max_new_tokens, enable_search, gpu_duration], outputs=[gpu_estimate_display, gpu_duration], ) max_new_tokens.change( fn=update_gpu_controls, inputs=[model_choice, max_new_tokens, enable_search, gpu_duration], outputs=[gpu_estimate_display, gpu_duration], ) enable_search.change( fn=update_gpu_controls, inputs=[model_choice, max_new_tokens, enable_search, gpu_duration], outputs=[gpu_estimate_display, gpu_duration], ) return demo def _prefetch_from_env() -> None: names = _resolve_prefetch_model_names(include_default=False) for name in names: if name not in MODELS: print(f"Prefetch skipped, unknown model: {name}") continue try: load_pipeline(name) print(f"Prefetched router model: {name}") except Exception as exc: # pragma: no cover print(f"Prefetch failed for {name}: {exc}") _prefetch_from_env() demo = build_ui() if __name__ == "__main__": # pragma: no cover # Support both Hugging Face Spaces and Google Cloud Run # Cloud Run uses PORT, Hugging Face Spaces uses GRADIO_SERVER_PORT port = int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", 7860))) server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") demo.launch( server_name=server_name, server_port=port, show_api=True )