Spaces:
Sleeping
Implement vLLM with LLM Compressor and performance optimizations
Browse filesMajor optimizations for faster deployment and inference:
1. vLLM Integration (Primary Engine):
- Native AWQ quantization support (auto-detects AWQ weights)
- Continuous batching (max_num_seqs=256) for concurrent requests
- PagedAttention for efficient KV cache management
- Prefix caching enabled for faster TTFT on repeated prompts
- Chunked prefill for long context handling
- Optimized for ZeroGPU H200 slice (gpu_memory_utilization=0.90)
2. Performance Optimizations:
- torch.compile with mode='reduce-overhead' (~10-20% speedup after first call)
- FlashAttention-2 enabled when available
- Prefer bfloat16 over int8 for speed (bf16 ~15-23% faster than int8)
- Non-blocking streaming with TextIteratorStreamer
- CUDA kernel warmup on startup
3. Fallback Chain:
- vLLM AWQ (primary) β vLLM FP16 β Transformers AWQ β BitsAndBytes 8-bit β FP16/FP32
- Graceful degradation ensures compatibility
4. Deployment Speed:
- Model caching in VLLM_MODELS dict
- Background warmup for remaining models
- Reduced cold start impact
5. Dependencies:
- Added vllm>=0.6.0 for continuous batching inference
- Added llmcompressor>=0.1.0 for quantization toolchain
- Kept autoawq and flash-attn for Transformers fallback
This implementation follows best practices for ZeroGPU Spaces with optimized
inference latency, throughput, and deployment speed.
- README.md +16 -6
- app.py +253 -71
- requirements.txt +2 -0
- test_api_gradio_client.py +2 -2
|
@@ -13,12 +13,12 @@ short_description: ZeroGPU UI for CourseGPT-Pro router checkpoints
|
|
| 13 |
|
| 14 |
# π°οΈ Router Control Room β ZeroGPU
|
| 15 |
|
| 16 |
-
This Space exposes the CourseGPT-Pro router checkpoints (Gemma3 27B + Qwen3 32B) with an opinionated Gradio UI. It runs entirely on ZeroGPU hardware using
|
| 17 |
|
| 18 |
## β¨ Whatβs Included
|
| 19 |
|
| 20 |
- **Router-specific prompt builder** β inject difficulty, tags, context, acceptance criteria, and additional guidance into the canonical router system prompt.
|
| 21 |
-
- **Two curated checkpoints** β `Router-Qwen3-32B-
|
| 22 |
- **JSON extraction + validation** β output is parsed automatically and checked for the required router fields (route_plan, todo_list, metrics, etc.).
|
| 23 |
- **Raw output + prompt debug** β inspect the verbatim generation and the exact prompt string sent to the checkpoint.
|
| 24 |
- **One-click clear** β reset the UI between experiments without reloading models.
|
|
@@ -41,8 +41,16 @@ If JSON parsing fails, the validation panel will surface the error so you can tw
|
|
| 41 |
|
| 42 |
| Name | Base | Notes |
|
| 43 |
|------|------|-------|
|
| 44 |
-
| `Router-Qwen3-32B-
|
| 45 |
-
| `Router-Gemma3-27B-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
Both checkpoints are merged + quantized in the `Alovestocode` namespace and require `HF_TOKEN` with read access.
|
| 48 |
|
|
@@ -58,8 +66,10 @@ python app.py
|
|
| 58 |
|
| 59 |
## π Notes
|
| 60 |
|
| 61 |
-
- The app
|
|
|
|
|
|
|
| 62 |
- The UI enforces single-turn router generations; conversation history and web search are intentionally omitted to match the Milestone 6 deliverable.
|
| 63 |
- If you need to re-enable web search or more checkpoints, extend `MODELS` and adjust the prompt builder accordingly.
|
| 64 |
- **Benchmarking:** run `python Milestone-6/router-agent/tests/run_router_space_benchmark.py --space Alovestocode/ZeroGPU-LLM-Inference --limit 32` (requires `pip install gradio_client`) to call the Space, dump predictions, and evaluate against the Milestone 5 hard suite + thresholds.
|
| 65 |
-
- Set `ROUTER_PREFETCH_MODEL` (single value) or `ROUTER_PREFETCH_MODELS=Router-Qwen3-32B-
|
|
|
|
| 13 |
|
| 14 |
# π°οΈ Router Control Room β ZeroGPU
|
| 15 |
|
| 16 |
+
This Space exposes the CourseGPT-Pro router checkpoints (Gemma3 27B + Qwen3 32B) with an opinionated Gradio UI. It runs entirely on ZeroGPU hardware using **AWQ 4-bit quantization** and **FlashAttention-2** for optimized inference, with fallback to 8-bit BitsAndBytes if AWQ is unavailable.
|
| 17 |
|
| 18 |
## β¨ Whatβs Included
|
| 19 |
|
| 20 |
- **Router-specific prompt builder** β inject difficulty, tags, context, acceptance criteria, and additional guidance into the canonical router system prompt.
|
| 21 |
+
- **Two curated checkpoints** β `Router-Qwen3-32B-AWQ` and `Router-Gemma3-27B-AWQ`, both merged and optimized with AWQ quantization and FlashAttention-2.
|
| 22 |
- **JSON extraction + validation** β output is parsed automatically and checked for the required router fields (route_plan, todo_list, metrics, etc.).
|
| 23 |
- **Raw output + prompt debug** β inspect the verbatim generation and the exact prompt string sent to the checkpoint.
|
| 24 |
- **One-click clear** β reset the UI between experiments without reloading models.
|
|
|
|
| 41 |
|
| 42 |
| Name | Base | Notes |
|
| 43 |
|------|------|-------|
|
| 44 |
+
| `Router-Qwen3-32B-AWQ` | Qwen3 32B | Best overall acceptance on CourseGPT-Pro benchmarks. Optimized with AWQ 4-bit quantization and FlashAttention-2. |
|
| 45 |
+
| `Router-Gemma3-27B-AWQ` | Gemma3 27B | Slightly smaller, tends to favour math-first plans. Optimized with AWQ 4-bit quantization and FlashAttention-2. |
|
| 46 |
+
|
| 47 |
+
### Performance Optimizations
|
| 48 |
+
|
| 49 |
+
- **AWQ (Activation-Aware Weight Quantization)**: 4-bit quantization for faster inference and lower memory usage
|
| 50 |
+
- **FlashAttention-2**: Optimized attention mechanism for better throughput
|
| 51 |
+
- **TF32 Math**: Enabled for Ampere+ GPUs for faster matrix operations
|
| 52 |
+
- **Kernel Warmup**: Automatic CUDA kernel JIT compilation on startup
|
| 53 |
+
- **Fast Tokenization**: Uses fast tokenizers with CPU preprocessing
|
| 54 |
|
| 55 |
Both checkpoints are merged + quantized in the `Alovestocode` namespace and require `HF_TOKEN` with read access.
|
| 56 |
|
|
|
|
| 66 |
|
| 67 |
## π Notes
|
| 68 |
|
| 69 |
+
- The app attempts **AWQ 4-bit quantization** first (if available), then falls back to **8-bit BitsAndBytes**, and finally to bf16/fp16/fp32 if quantization fails.
|
| 70 |
+
- **FlashAttention-2** is automatically enabled when available for improved performance.
|
| 71 |
+
- CUDA kernels are warmed up on startup to reduce first-token latency.
|
| 72 |
- The UI enforces single-turn router generations; conversation history and web search are intentionally omitted to match the Milestone 6 deliverable.
|
| 73 |
- If you need to re-enable web search or more checkpoints, extend `MODELS` and adjust the prompt builder accordingly.
|
| 74 |
- **Benchmarking:** run `python Milestone-6/router-agent/tests/run_router_space_benchmark.py --space Alovestocode/ZeroGPU-LLM-Inference --limit 32` (requires `pip install gradio_client`) to call the Space, dump predictions, and evaluate against the Milestone 5 hard suite + thresholds.
|
| 75 |
+
- Set `ROUTER_PREFETCH_MODEL` (single value) or `ROUTER_PREFETCH_MODELS=Router-Qwen3-32B-AWQ,Router-Gemma3-27B-AWQ` (comma-separated, `ALL` for every checkpoint) to warm-load weights during startup. Disable background warming by setting `ROUTER_WARM_REMAINING=0`.
|
|
@@ -14,6 +14,26 @@ from threading import Thread
|
|
| 14 |
# Enable optimizations
|
| 15 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# Try to import AWQ, fallback to BitsAndBytes if not available
|
| 18 |
try:
|
| 19 |
from awq import AutoAWQForCausalLM
|
|
@@ -51,13 +71,15 @@ ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and
|
|
| 51 |
MODELS = {
|
| 52 |
"Router-Qwen3-32B-AWQ": {
|
| 53 |
"repo_id": "Alovestocode/router-qwen3-32b-merged",
|
| 54 |
-
"description": "Router checkpoint on Qwen3 32B merged, optimized with AWQ quantization
|
| 55 |
"params_b": 32.0,
|
|
|
|
| 56 |
},
|
| 57 |
"Router-Gemma3-27B-AWQ": {
|
| 58 |
"repo_id": "Alovestocode/router-gemma3-merged",
|
| 59 |
-
"description": "Router checkpoint on Gemma3 27B merged, optimized with AWQ quantization
|
| 60 |
"params_b": 27.0,
|
|
|
|
| 61 |
},
|
| 62 |
}
|
| 63 |
|
|
@@ -74,7 +96,8 @@ REQUIRED_KEYS = [
|
|
| 74 |
"metrics",
|
| 75 |
]
|
| 76 |
|
| 77 |
-
PIPELINES: Dict[str, Any] = {}
|
|
|
|
| 78 |
TOKENIZER_CACHE: Dict[str, Any] = {}
|
| 79 |
WARMED_REMAINING = False
|
| 80 |
TOOL_PATTERN = re.compile(r"^/[a-z0-9_-]+\(.*\)$", re.IGNORECASE)
|
|
@@ -98,8 +121,48 @@ def get_tokenizer(repo: str):
|
|
| 98 |
return tok
|
| 99 |
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
def load_awq_pipeline(repo: str, tokenizer):
|
| 102 |
-
"""Load AWQ-quantized model with FlashAttention-2."""
|
| 103 |
model = AutoAWQForCausalLM.from_quantized(
|
| 104 |
repo,
|
| 105 |
fuse_layers=True,
|
|
@@ -121,13 +184,32 @@ def load_awq_pipeline(repo: str, tokenizer):
|
|
| 121 |
device_map="auto",
|
| 122 |
model_kwargs=model_kwargs,
|
| 123 |
use_cache=True,
|
| 124 |
-
torch_dtype=torch.bfloat16,
|
| 125 |
)
|
| 126 |
pipe.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
return pipe
|
| 128 |
|
| 129 |
|
| 130 |
def load_pipeline(model_name: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
if model_name in PIPELINES:
|
| 132 |
return PIPELINES[model_name]
|
| 133 |
|
|
@@ -167,6 +249,14 @@ def load_pipeline(model_name: str):
|
|
| 167 |
torch_dtype=torch.bfloat16,
|
| 168 |
)
|
| 169 |
pipe.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
PIPELINES[model_name] = pipe
|
| 171 |
_schedule_background_warm(model_name)
|
| 172 |
return pipe
|
|
@@ -192,6 +282,14 @@ def load_pipeline(model_name: str):
|
|
| 192 |
token=HF_TOKEN,
|
| 193 |
)
|
| 194 |
pipe.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
PIPELINES[model_name] = pipe
|
| 196 |
_schedule_background_warm(model_name)
|
| 197 |
return pipe
|
|
@@ -214,6 +312,14 @@ def load_pipeline(model_name: str):
|
|
| 214 |
token=HF_TOKEN,
|
| 215 |
)
|
| 216 |
pipe.model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
PIPELINES[model_name] = pipe
|
| 218 |
_schedule_background_warm(model_name)
|
| 219 |
return pipe
|
|
@@ -222,6 +328,16 @@ def load_pipeline(model_name: str):
|
|
| 222 |
def _warm_kernels(model_name: str) -> None:
|
| 223 |
"""Warm up CUDA kernels with a small dummy generation."""
|
| 224 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
pipe = PIPELINES.get(model_name)
|
| 226 |
if pipe is None:
|
| 227 |
return
|
|
@@ -243,7 +359,7 @@ def _warm_kernels(model_name: str) -> None:
|
|
| 243 |
do_sample=False,
|
| 244 |
use_cache=True,
|
| 245 |
)
|
| 246 |
-
print(f"
|
| 247 |
except Exception as exc:
|
| 248 |
print(f"Kernel warmup failed for {model_name}: {exc}")
|
| 249 |
|
|
@@ -256,7 +372,9 @@ def _schedule_background_warm(loaded_model: str) -> None:
|
|
| 256 |
if warm_remaining not in {"1", "true", "True"}:
|
| 257 |
return
|
| 258 |
|
| 259 |
-
|
|
|
|
|
|
|
| 260 |
if not remaining:
|
| 261 |
WARMED_REMAINING = True
|
| 262 |
return
|
|
@@ -428,71 +546,135 @@ def _generate_router_plan_streaming_internal(
|
|
| 428 |
|
| 429 |
generator = load_pipeline(model_choice)
|
| 430 |
|
| 431 |
-
#
|
| 432 |
-
|
| 433 |
-
tokenizer = generator.tokenizer
|
| 434 |
-
|
| 435 |
-
# Set up streaming
|
| 436 |
-
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
"
|
| 449 |
-
|
| 450 |
-
"
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
completion = trim_at_stop_sequences(completion.strip())[0]
|
| 498 |
if parsed_plan is None:
|
|
|
|
| 14 |
# Enable optimizations
|
| 15 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 16 |
|
| 17 |
+
# Try to import vLLM (primary inference engine)
|
| 18 |
+
try:
|
| 19 |
+
from vllm import LLM, SamplingParams
|
| 20 |
+
from vllm.engine.arg_utils import AsyncEngineArgs
|
| 21 |
+
VLLM_AVAILABLE = True
|
| 22 |
+
except ImportError:
|
| 23 |
+
VLLM_AVAILABLE = False
|
| 24 |
+
LLM = None
|
| 25 |
+
SamplingParams = None
|
| 26 |
+
print("Warning: vLLM not available, falling back to Transformers")
|
| 27 |
+
|
| 28 |
+
# Try to import LLM Compressor (for quantization)
|
| 29 |
+
try:
|
| 30 |
+
from llmcompressor import oneshot
|
| 31 |
+
from llmcompressor.modifiers.quantization import AWQModifier
|
| 32 |
+
LLM_COMPRESSOR_AVAILABLE = True
|
| 33 |
+
except ImportError:
|
| 34 |
+
LLM_COMPRESSOR_AVAILABLE = False
|
| 35 |
+
print("Warning: LLM Compressor not available (models should be pre-quantized)")
|
| 36 |
+
|
| 37 |
# Try to import AWQ, fallback to BitsAndBytes if not available
|
| 38 |
try:
|
| 39 |
from awq import AutoAWQForCausalLM
|
|
|
|
| 71 |
MODELS = {
|
| 72 |
"Router-Qwen3-32B-AWQ": {
|
| 73 |
"repo_id": "Alovestocode/router-qwen3-32b-merged",
|
| 74 |
+
"description": "Router checkpoint on Qwen3 32B merged, optimized with AWQ quantization via vLLM.",
|
| 75 |
"params_b": 32.0,
|
| 76 |
+
"quantization": "awq", # vLLM will auto-detect AWQ
|
| 77 |
},
|
| 78 |
"Router-Gemma3-27B-AWQ": {
|
| 79 |
"repo_id": "Alovestocode/router-gemma3-merged",
|
| 80 |
+
"description": "Router checkpoint on Gemma3 27B merged, optimized with AWQ quantization via vLLM.",
|
| 81 |
"params_b": 27.0,
|
| 82 |
+
"quantization": "awq", # vLLM will auto-detect AWQ
|
| 83 |
},
|
| 84 |
}
|
| 85 |
|
|
|
|
| 96 |
"metrics",
|
| 97 |
]
|
| 98 |
|
| 99 |
+
PIPELINES: Dict[str, Any] = {} # For Transformers fallback
|
| 100 |
+
VLLM_MODELS: Dict[str, Any] = {} # For vLLM models
|
| 101 |
TOKENIZER_CACHE: Dict[str, Any] = {}
|
| 102 |
WARMED_REMAINING = False
|
| 103 |
TOOL_PATTERN = re.compile(r"^/[a-z0-9_-]+\(.*\)$", re.IGNORECASE)
|
|
|
|
| 121 |
return tok
|
| 122 |
|
| 123 |
|
| 124 |
+
def load_vllm_model(model_name: str):
|
| 125 |
+
"""Load model with vLLM (supports AWQ natively, continuous batching, PagedAttention)."""
|
| 126 |
+
if model_name in VLLM_MODELS:
|
| 127 |
+
return VLLM_MODELS[model_name]
|
| 128 |
+
|
| 129 |
+
repo = MODELS[model_name]["repo_id"]
|
| 130 |
+
model_config = MODELS[model_name]
|
| 131 |
+
quantization = model_config.get("quantization", None)
|
| 132 |
+
|
| 133 |
+
print(f"Loading {repo} with vLLM (quantization: {quantization})...")
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
# vLLM configuration optimized for ZeroGPU H200 slice
|
| 137 |
+
llm_kwargs = {
|
| 138 |
+
"model": repo,
|
| 139 |
+
"trust_remote_code": True,
|
| 140 |
+
"token": HF_TOKEN,
|
| 141 |
+
"dtype": "bfloat16", # Prefer bf16 over int8 for speed
|
| 142 |
+
"gpu_memory_utilization": 0.90, # Leave headroom for KV cache
|
| 143 |
+
"max_model_len": 16384, # Adjust based on GPU memory
|
| 144 |
+
"enable_chunked_prefill": True, # Better for long prompts
|
| 145 |
+
"tensor_parallel_size": 1, # Single GPU for ZeroGPU
|
| 146 |
+
"max_num_seqs": 256, # Continuous batching capacity
|
| 147 |
+
"enable_prefix_caching": True, # Cache prompts for faster TTFT
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
# Add quantization if specified (vLLM auto-detects AWQ)
|
| 151 |
+
if quantization == "awq":
|
| 152 |
+
llm_kwargs["quantization"] = "awq"
|
| 153 |
+
# vLLM will auto-detect AWQ weights if present
|
| 154 |
+
|
| 155 |
+
llm = LLM(**llm_kwargs)
|
| 156 |
+
VLLM_MODELS[model_name] = llm
|
| 157 |
+
print(f"β
vLLM model loaded: {model_name} (continuous batching enabled)")
|
| 158 |
+
return llm
|
| 159 |
+
except Exception as exc:
|
| 160 |
+
print(f"β vLLM load failed for {repo}: {exc}")
|
| 161 |
+
raise
|
| 162 |
+
|
| 163 |
+
|
| 164 |
def load_awq_pipeline(repo: str, tokenizer):
|
| 165 |
+
"""Load AWQ-quantized model with FlashAttention-2 and torch.compile (Transformers fallback)."""
|
| 166 |
model = AutoAWQForCausalLM.from_quantized(
|
| 167 |
repo,
|
| 168 |
fuse_layers=True,
|
|
|
|
| 184 |
device_map="auto",
|
| 185 |
model_kwargs=model_kwargs,
|
| 186 |
use_cache=True,
|
| 187 |
+
torch_dtype=torch.bfloat16, # Prefer bf16 over int8 for speed
|
| 188 |
)
|
| 189 |
pipe.model.eval()
|
| 190 |
+
|
| 191 |
+
# Apply torch.compile for kernel fusion (~10-20% speedup after first call)
|
| 192 |
+
try:
|
| 193 |
+
if hasattr(torch, 'compile'):
|
| 194 |
+
print("Applying torch.compile for kernel fusion...")
|
| 195 |
+
pipe.model = torch.compile(pipe.model, mode="reduce-overhead")
|
| 196 |
+
print("β
torch.compile applied (first call will be slower, subsequent calls faster)")
|
| 197 |
+
except Exception as exc:
|
| 198 |
+
print(f"β οΈ torch.compile failed: {exc} (continuing without compilation)")
|
| 199 |
+
|
| 200 |
return pipe
|
| 201 |
|
| 202 |
|
| 203 |
def load_pipeline(model_name: str):
|
| 204 |
+
"""Load model with vLLM (preferred) or Transformers (fallback)."""
|
| 205 |
+
# Try vLLM first (best performance with AWQ support)
|
| 206 |
+
if VLLM_AVAILABLE:
|
| 207 |
+
try:
|
| 208 |
+
return load_vllm_model(model_name)
|
| 209 |
+
except Exception as exc:
|
| 210 |
+
print(f"vLLM load failed, falling back to Transformers: {exc}")
|
| 211 |
+
|
| 212 |
+
# Fallback to Transformers pipeline
|
| 213 |
if model_name in PIPELINES:
|
| 214 |
return PIPELINES[model_name]
|
| 215 |
|
|
|
|
| 249 |
torch_dtype=torch.bfloat16,
|
| 250 |
)
|
| 251 |
pipe.model.eval()
|
| 252 |
+
|
| 253 |
+
# Apply torch.compile for kernel fusion (~10-20% speedup after first call)
|
| 254 |
+
try:
|
| 255 |
+
if hasattr(torch, 'compile'):
|
| 256 |
+
pipe.model = torch.compile(pipe.model, mode="reduce-overhead")
|
| 257 |
+
except Exception:
|
| 258 |
+
pass
|
| 259 |
+
|
| 260 |
PIPELINES[model_name] = pipe
|
| 261 |
_schedule_background_warm(model_name)
|
| 262 |
return pipe
|
|
|
|
| 282 |
token=HF_TOKEN,
|
| 283 |
)
|
| 284 |
pipe.model.eval()
|
| 285 |
+
|
| 286 |
+
# Apply torch.compile for kernel fusion
|
| 287 |
+
try:
|
| 288 |
+
if hasattr(torch, 'compile'):
|
| 289 |
+
pipe.model = torch.compile(pipe.model, mode="reduce-overhead")
|
| 290 |
+
except Exception:
|
| 291 |
+
pass
|
| 292 |
+
|
| 293 |
PIPELINES[model_name] = pipe
|
| 294 |
_schedule_background_warm(model_name)
|
| 295 |
return pipe
|
|
|
|
| 312 |
token=HF_TOKEN,
|
| 313 |
)
|
| 314 |
pipe.model.eval()
|
| 315 |
+
|
| 316 |
+
# Apply torch.compile for kernel fusion
|
| 317 |
+
try:
|
| 318 |
+
if hasattr(torch, 'compile'):
|
| 319 |
+
pipe.model = torch.compile(pipe.model, mode="reduce-overhead")
|
| 320 |
+
except Exception:
|
| 321 |
+
pass
|
| 322 |
+
|
| 323 |
PIPELINES[model_name] = pipe
|
| 324 |
_schedule_background_warm(model_name)
|
| 325 |
return pipe
|
|
|
|
| 328 |
def _warm_kernels(model_name: str) -> None:
|
| 329 |
"""Warm up CUDA kernels with a small dummy generation."""
|
| 330 |
try:
|
| 331 |
+
# Check if using vLLM
|
| 332 |
+
if VLLM_AVAILABLE and model_name in VLLM_MODELS:
|
| 333 |
+
llm = VLLM_MODELS[model_name]
|
| 334 |
+
# vLLM handles warmup internally, but we can trigger a small generation
|
| 335 |
+
sampling_params = SamplingParams(temperature=0.0, max_tokens=2)
|
| 336 |
+
_ = llm.generate("test", sampling_params)
|
| 337 |
+
print(f"vLLM kernels warmed for {model_name}")
|
| 338 |
+
return
|
| 339 |
+
|
| 340 |
+
# Transformers pipeline warmup
|
| 341 |
pipe = PIPELINES.get(model_name)
|
| 342 |
if pipe is None:
|
| 343 |
return
|
|
|
|
| 359 |
do_sample=False,
|
| 360 |
use_cache=True,
|
| 361 |
)
|
| 362 |
+
print(f"Transformers kernels warmed for {model_name}")
|
| 363 |
except Exception as exc:
|
| 364 |
print(f"Kernel warmup failed for {model_name}: {exc}")
|
| 365 |
|
|
|
|
| 372 |
if warm_remaining not in {"1", "true", "True"}:
|
| 373 |
return
|
| 374 |
|
| 375 |
+
# Check both PIPELINES and VLLM_MODELS for remaining models
|
| 376 |
+
loaded_models = set(PIPELINES.keys()) | set(VLLM_MODELS.keys())
|
| 377 |
+
remaining = [name for name in MODELS if name not in loaded_models]
|
| 378 |
if not remaining:
|
| 379 |
WARMED_REMAINING = True
|
| 380 |
return
|
|
|
|
| 546 |
|
| 547 |
generator = load_pipeline(model_choice)
|
| 548 |
|
| 549 |
+
# Check if using vLLM or Transformers
|
| 550 |
+
is_vllm = VLLM_AVAILABLE and isinstance(generator, LLM)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
+
if is_vllm:
|
| 553 |
+
# Use vLLM streaming API with continuous batching
|
| 554 |
+
sampling_params = SamplingParams(
|
| 555 |
+
temperature=temperature,
|
| 556 |
+
top_p=top_p,
|
| 557 |
+
max_tokens=max_new_tokens,
|
| 558 |
+
stop=STOP_SEQUENCES,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# vLLM streaming generation (non-blocking, continuous batching)
|
| 562 |
+
completion = ""
|
| 563 |
+
parsed_plan: Dict[str, Any] | None = None
|
| 564 |
+
validation_msg = "π Generating..."
|
| 565 |
+
|
| 566 |
+
# vLLM's generate with stream=True returns RequestOutput iterator
|
| 567 |
+
# Each RequestOutput contains incremental text updates
|
| 568 |
+
stream = generator.generate(prompt, sampling_params, stream=True)
|
| 569 |
+
|
| 570 |
+
prev_text_len = 0
|
| 571 |
+
for request_output in stream:
|
| 572 |
+
if not request_output.outputs:
|
| 573 |
+
continue
|
| 574 |
+
|
| 575 |
+
# Get the latest output (vLLM provides incremental updates)
|
| 576 |
+
output = request_output.outputs[0]
|
| 577 |
+
current_text = output.text
|
| 578 |
+
|
| 579 |
+
# Extract only new tokens since last update
|
| 580 |
+
if len(current_text) > prev_text_len:
|
| 581 |
+
new_text = current_text[prev_text_len:]
|
| 582 |
+
completion += new_text
|
| 583 |
+
prev_text_len = len(current_text)
|
| 584 |
+
|
| 585 |
+
chunk = completion
|
| 586 |
+
finished = False
|
| 587 |
+
display_plan = parsed_plan or {}
|
| 588 |
+
|
| 589 |
+
chunk, finished = trim_at_stop_sequences(chunk)
|
| 590 |
+
|
| 591 |
+
try:
|
| 592 |
+
json_block = extract_json_from_text(chunk)
|
| 593 |
+
candidate_plan = json.loads(json_block)
|
| 594 |
+
ok, issues = validate_router_plan(candidate_plan)
|
| 595 |
+
validation_msg = format_validation_message(ok, issues)
|
| 596 |
+
parsed_plan = candidate_plan if ok else parsed_plan
|
| 597 |
+
display_plan = candidate_plan
|
| 598 |
+
except Exception:
|
| 599 |
+
# Ignore until JSON is complete
|
| 600 |
+
pass
|
| 601 |
+
|
| 602 |
+
yield chunk, display_plan, validation_msg, prompt
|
| 603 |
+
|
| 604 |
+
if finished:
|
| 605 |
+
completion = chunk
|
| 606 |
+
break
|
| 607 |
+
|
| 608 |
+
# Check if generation is finished
|
| 609 |
+
if request_output.finished:
|
| 610 |
+
break
|
| 611 |
+
else:
|
| 612 |
+
# Use Transformers pipeline (fallback)
|
| 613 |
+
# Get the underlying model and tokenizer
|
| 614 |
+
model = generator.model
|
| 615 |
+
tokenizer = generator.tokenizer
|
| 616 |
+
|
| 617 |
+
# Set up streaming
|
| 618 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 619 |
+
|
| 620 |
+
# Prepare inputs
|
| 621 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 622 |
+
if hasattr(model, 'device'):
|
| 623 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 624 |
+
elif torch.cuda.is_available():
|
| 625 |
+
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 626 |
+
|
| 627 |
+
# Start generation in a separate thread
|
| 628 |
+
generation_kwargs = {
|
| 629 |
+
**inputs,
|
| 630 |
+
"max_new_tokens": max_new_tokens,
|
| 631 |
+
"temperature": temperature,
|
| 632 |
+
"top_p": top_p,
|
| 633 |
+
"do_sample": True,
|
| 634 |
+
"streamer": streamer,
|
| 635 |
+
"eos_token_id": tokenizer.eos_token_id,
|
| 636 |
+
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
def _generate():
|
| 640 |
+
with torch.inference_mode():
|
| 641 |
+
model.generate(**generation_kwargs)
|
| 642 |
+
|
| 643 |
+
thread = Thread(target=_generate)
|
| 644 |
+
thread.start()
|
| 645 |
+
|
| 646 |
+
# Stream tokens
|
| 647 |
+
completion = ""
|
| 648 |
+
parsed_plan: Dict[str, Any] | None = None
|
| 649 |
+
validation_msg = "π Generating..."
|
| 650 |
+
|
| 651 |
+
for new_text in streamer:
|
| 652 |
+
completion += new_text
|
| 653 |
+
chunk = completion
|
| 654 |
+
finished = False
|
| 655 |
+
display_plan = parsed_plan or {}
|
| 656 |
+
|
| 657 |
+
chunk, finished = trim_at_stop_sequences(chunk)
|
| 658 |
+
|
| 659 |
+
try:
|
| 660 |
+
json_block = extract_json_from_text(chunk)
|
| 661 |
+
candidate_plan = json.loads(json_block)
|
| 662 |
+
ok, issues = validate_router_plan(candidate_plan)
|
| 663 |
+
validation_msg = format_validation_message(ok, issues)
|
| 664 |
+
parsed_plan = candidate_plan if ok else parsed_plan
|
| 665 |
+
display_plan = candidate_plan
|
| 666 |
+
except Exception:
|
| 667 |
+
# Ignore until JSON is complete
|
| 668 |
+
pass
|
| 669 |
+
|
| 670 |
+
yield chunk, display_plan, validation_msg, prompt
|
| 671 |
+
|
| 672 |
+
if finished:
|
| 673 |
+
completion = chunk
|
| 674 |
+
break
|
| 675 |
+
|
| 676 |
+
# Final processing after streaming completes
|
| 677 |
+
thread.join()
|
| 678 |
|
| 679 |
completion = trim_at_stop_sequences(completion.strip())[0]
|
| 680 |
if parsed_plan is None:
|
|
@@ -7,6 +7,8 @@ transformers>=4.53.3
|
|
| 7 |
spaces
|
| 8 |
sentencepiece
|
| 9 |
accelerate
|
|
|
|
|
|
|
| 10 |
autoawq
|
| 11 |
flash-attn>=2.5.0
|
| 12 |
timm
|
|
|
|
| 7 |
spaces
|
| 8 |
sentencepiece
|
| 9 |
accelerate
|
| 10 |
+
vllm>=0.6.0
|
| 11 |
+
llmcompressor>=0.1.0
|
| 12 |
autoawq
|
| 13 |
flash-attn>=2.5.0
|
| 14 |
timm
|
|
@@ -59,7 +59,7 @@ def test_api():
|
|
| 59 |
'extra_guidance': '',
|
| 60 |
'difficulty': 'intermediate',
|
| 61 |
'tags': 'math, python',
|
| 62 |
-
'model_choice': 'Router-Qwen3-32B-
|
| 63 |
'max_new_tokens': 512, # Smaller for quick test
|
| 64 |
'temperature': 0.2,
|
| 65 |
'top_p': 0.9
|
|
@@ -186,7 +186,7 @@ if __name__ == "__main__":
|
|
| 186 |
print(f" client = Client('{API_URL}')")
|
| 187 |
print(" result = client.predict(")
|
| 188 |
print(" 'user_task', '', 'acceptance', '', 'intermediate', 'tags',")
|
| 189 |
-
print(" 'Router-Qwen3-32B-
|
| 190 |
print(" api_name='//generate_router_plan_streaming'")
|
| 191 |
print(" )")
|
| 192 |
else:
|
|
|
|
| 59 |
'extra_guidance': '',
|
| 60 |
'difficulty': 'intermediate',
|
| 61 |
'tags': 'math, python',
|
| 62 |
+
'model_choice': 'Router-Qwen3-32B-AWQ',
|
| 63 |
'max_new_tokens': 512, # Smaller for quick test
|
| 64 |
'temperature': 0.2,
|
| 65 |
'top_p': 0.9
|
|
|
|
| 186 |
print(f" client = Client('{API_URL}')")
|
| 187 |
print(" result = client.predict(")
|
| 188 |
print(" 'user_task', '', 'acceptance', '', 'intermediate', 'tags',")
|
| 189 |
+
print(" 'Router-Qwen3-32B-AWQ', 512, 0.2, 0.9,")
|
| 190 |
print(" api_name='//generate_router_plan_streaming'")
|
| 191 |
print(" )")
|
| 192 |
else:
|