Alikestocode committed
Commit 06b4cf5 · Parent: cdac920

Migrate to AWQ quantization with FlashAttention-2


- Replace BitsAndBytes 8-bit with AWQ 4-bit quantization (primary path)
- Add FlashAttention-2 support for optimized attention
- Enable TF32 math for Ampere+ GPUs (a runtime check for TF32 and FlashAttention-2 availability is sketched after this list)
- Add CUDA kernel warmup on startup to reduce first-token latency
- Update model names to reflect AWQ optimization
- Add graceful fallback chain: AWQ -> BitsAndBytes -> bf16/fp16/fp32
- Update requirements.txt with flash-attn>=2.5.0
- Update README with performance optimizations documentation
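TF32 and FlashAttention-2 only take effect when the hardware and build support them. A minimal, hypothetical check (not part of this commit) that confirms the running GPU is Ampere or newer and that `flash_attn` is importable could look like this:

# Hypothetical environment check, not part of this commit: confirms the
# assumptions behind the TF32 and FlashAttention-2 optimizations.
import importlib.util

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # TF32 matmul kernels require Ampere (compute capability 8.0) or newer.
    print(f"GPU compute capability: {major}.{minor} (TF32 usable: {major >= 8})")
    print(f"TF32 matmul enabled: {torch.backends.cuda.matmul.allow_tf32}")
else:
    print("No CUDA device visible; TF32/FlashAttention-2 will not apply.")

# flash-attn is optional; app.py falls back to the default attention
# implementation when this import is missing.
print(f"flash_attn installed: {importlib.util.find_spec('flash_attn') is not None}")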

Files changed (4)
  1. README.md +1 -0
  2. app.py +149 -24
  3. requirements.txt +1 -0
  4. test_api.py +107 -0
README.md CHANGED
@@ -62,3 +62,4 @@ python app.py
 - The UI enforces single-turn router generations; conversation history and web search are intentionally omitted to match the Milestone 6 deliverable.
 - If you need to re-enable web search or more checkpoints, extend `MODELS` and adjust the prompt builder accordingly.
 - **Benchmarking:** run `python Milestone-6/router-agent/tests/run_router_space_benchmark.py --space Alovestocode/ZeroGPU-LLM-Inference --limit 32` (requires `pip install gradio_client`) to call the Space, dump predictions, and evaluate against the Milestone 5 hard suite + thresholds.
+- Set `ROUTER_PREFETCH_MODEL` (single value) or `ROUTER_PREFETCH_MODELS=Router-Qwen3-32B-8bit,Router-Gemma3-27B-8bit` (comma-separated, `ALL` for every checkpoint) to warm-load weights during startup. Disable background warming by setting `ROUTER_WARM_REMAINING=0`.
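The prefetch variables above are read by app.py at startup. A small illustrative launch helper (hypothetical, not part of this commit; checkpoint names assume the -AWQ keys introduced below, and `HF_TOKEN` must already be set for the private router checkpoints):

# Illustrative local launch helper showing how the prefetch variables are passed
# to the process that runs app.py.
import os
import subprocess

env = dict(os.environ)
env["ROUTER_PREFETCH_MODELS"] = "Router-Qwen3-32B-AWQ,Router-Gemma3-27B-AWQ"  # or "ALL"
env["ROUTER_WARM_REMAINING"] = "0"  # disable background warming of the remaining checkpoints

subprocess.run(["python", "app.py"], env=env, check=True)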
app.py CHANGED
@@ -8,9 +8,37 @@ from typing import Any, Dict, List, Tuple
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, pipeline
+from transformers import AutoTokenizer, TextIteratorStreamer, pipeline
 from threading import Thread
 
+# Enable optimizations
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# Try to import AWQ, fallback to BitsAndBytes if not available
+try:
+    from awq import AutoAWQForCausalLM
+    AWQ_AVAILABLE = True
+except ImportError:
+    AWQ_AVAILABLE = False
+    print("Warning: AutoAWQ not available, falling back to BitsAndBytes")
+
+# Always import BitsAndBytesConfig for fallback
+try:
+    from transformers import BitsAndBytesConfig
+    BITSANDBYTES_AVAILABLE = True
+except ImportError:
+    BITSANDBYTES_AVAILABLE = False
+    BitsAndBytesConfig = None
+    print("Warning: BitsAndBytes not available")
+
+# Try to import FlashAttention-2
+try:
+    import flash_attn
+    FLASH_ATTN_AVAILABLE = True
+except ImportError:
+    FLASH_ATTN_AVAILABLE = False
+    print("Warning: FlashAttention-2 not available")
+
 HF_TOKEN = os.environ.get("HF_TOKEN")
 if not HF_TOKEN:
     raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")
@@ -21,14 +49,14 @@ STOP_SEQUENCES = [PLAN_END_TOKEN, "</json>", "</JSON>"]
 ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit EXACTLY ONE strict JSON object with keys route_plan, route_rationale, expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nRules:\n- No markdown/code fences, no natural-language prologues or epilogues.\n- route_plan must be an ordered list of tool invocations such as /math(...), /code(...), /general-search(...).\n- todo_list must map each checklist item to the responsible tool.\n- metrics must include primary and secondary arrays (add optional *_guidance fields when they exist).\n- After the closing brace of the JSON object, immediately append the sentinel <|end_of_plan|>.\nExample output:\n{\n "route_plan": ["/general-search(...)"],\n "route_rationale": "...",\n ...\n}<|end_of_plan|>\nReturn nothing else."""
 
 MODELS = {
-    "Router-Qwen3-32B-8bit": {
+    "Router-Qwen3-32B-AWQ": {
         "repo_id": "Alovestocode/router-qwen3-32b-merged",
-        "description": "Router checkpoint on Qwen3 32B merged and quantized for 8-bit ZeroGPU inference.",
+        "description": "Router checkpoint on Qwen3 32B merged, optimized with AWQ quantization and FlashAttention-2.",
         "params_b": 32.0,
     },
-    "Router-Gemma3-27B-8bit": {
+    "Router-Gemma3-27B-AWQ": {
         "repo_id": "Alovestocode/router-gemma3-merged",
-        "description": "Router checkpoint on Gemma3 27B merged and quantized for 8-bit ZeroGPU inference.",
+        "description": "Router checkpoint on Gemma3 27B merged, optimized with AWQ quantization and FlashAttention-2.",
         "params_b": 27.0,
     },
 }
@@ -56,7 +84,12 @@ def get_tokenizer(repo: str):
     tok = TOKENIZER_CACHE.get(repo)
     if tok is not None:
         return tok
-    tok = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)
+    tok = AutoTokenizer.from_pretrained(
+        repo,
+        token=HF_TOKEN,
+        use_fast=True,
+        trust_remote_code=True
+    )
     tok.padding_side = "left"
     tok.truncation_side = "left"
     if tok.pad_token_id is None and tok.eos_token_id is not None:
@@ -65,6 +98,35 @@ def get_tokenizer(repo: str):
     return tok
 
 
+def load_awq_pipeline(repo: str, tokenizer):
+    """Load AWQ-quantized model with FlashAttention-2."""
+    model = AutoAWQForCausalLM.from_quantized(
+        repo,
+        fuse_layers=True,
+        trust_remote_code=True,
+        device_map="auto",
+        token=HF_TOKEN,
+    )
+
+    # Prepare model kwargs with FlashAttention-2 if available
+    model_kwargs = {}
+    if FLASH_ATTN_AVAILABLE:
+        model_kwargs["attn_implementation"] = "flash_attention_2"
+
+    pipe = pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        trust_remote_code=True,
+        device_map="auto",
+        model_kwargs=model_kwargs,
+        use_cache=True,
+        torch_dtype=torch.bfloat16,
+    )
+    pipe.model.eval()
+    return pipe
+
+
 def load_pipeline(model_name: str):
     if model_name in PIPELINES:
         return PIPELINES[model_name]
@@ -72,27 +134,52 @@ def load_pipeline(model_name: str):
     repo = MODELS[model_name]["repo_id"]
     tokenizer = get_tokenizer(repo)
 
-    try:
-        quant_config = BitsAndBytesConfig(load_in_8bit=True)
-        pipe = pipeline(
-            task="text-generation",
-            model=repo,
-            tokenizer=tokenizer,
-            trust_remote_code=True,
-            device_map="auto",
-            model_kwargs={"quantization_config": quant_config},
-            use_cache=True,
-            token=HF_TOKEN,
-        )
-        pipe.model.eval()
-        PIPELINES[model_name] = pipe
-        _schedule_background_warm(model_name)
-        return pipe
-    except Exception as exc:
-        print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
+    # Try AWQ first if available
+    if AWQ_AVAILABLE:
+        try:
+            print(f"Loading {repo} with AWQ quantization...")
+            pipe = load_awq_pipeline(repo, tokenizer)
+            PIPELINES[model_name] = pipe
+            _schedule_background_warm(model_name)
+            # Warm kernels immediately after loading
+            Thread(target=lambda: _warm_kernels(model_name), daemon=True).start()
+            return pipe
+        except Exception as exc:
+            print(f"AWQ load failed for {repo}: {exc}. Falling back to BitsAndBytes.")
 
+    # Fallback to BitsAndBytes 8-bit
+    if BITSANDBYTES_AVAILABLE:
+        try:
+            quant_config = BitsAndBytesConfig(load_in_8bit=True)
+            model_kwargs = {"quantization_config": quant_config}
+            if FLASH_ATTN_AVAILABLE:
+                model_kwargs["attn_implementation"] = "flash_attention_2"
+
+            pipe = pipeline(
+                task="text-generation",
+                model=repo,
+                tokenizer=tokenizer,
+                trust_remote_code=True,
+                device_map="auto",
+                model_kwargs=model_kwargs,
+                use_cache=True,
+                token=HF_TOKEN,
+                torch_dtype=torch.bfloat16,
+            )
+            pipe.model.eval()
+            PIPELINES[model_name] = pipe
+            _schedule_background_warm(model_name)
+            return pipe
+        except Exception as exc:
+            print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
+
+    # Fallback to bfloat16/fp16/fp32
     for dtype in (torch.bfloat16, torch.float16, torch.float32):
         try:
+            model_kwargs = {}
+            if FLASH_ATTN_AVAILABLE:
+                model_kwargs["attn_implementation"] = "flash_attention_2"
+
             pipe = pipeline(
                 task="text-generation",
                 model=repo,
@@ -100,6 +187,7 @@ def load_pipeline(model_name: str):
                 trust_remote_code=True,
                 device_map="auto",
                 dtype=dtype,
+                model_kwargs=model_kwargs,
                 use_cache=True,
                 token=HF_TOKEN,
             )
@@ -110,12 +198,18 @@ def load_pipeline(model_name: str):
         except Exception:
             continue
 
+    # Final fallback
+    model_kwargs = {}
+    if FLASH_ATTN_AVAILABLE:
+        model_kwargs["attn_implementation"] = "flash_attention_2"
+
     pipe = pipeline(
         task="text-generation",
         model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
+       model_kwargs=model_kwargs,
        use_cache=True,
        token=HF_TOKEN,
    )
@@ -125,6 +219,35 @@ def load_pipeline(model_name: str):
     return pipe
 
 
+def _warm_kernels(model_name: str) -> None:
+    """Warm up CUDA kernels with a small dummy generation."""
+    try:
+        pipe = PIPELINES.get(model_name)
+        if pipe is None:
+            return
+
+        tokenizer = pipe.tokenizer
+        # Create a minimal prompt for warmup
+        warmup_text = "test"
+        inputs = tokenizer(warmup_text, return_tensors="pt")
+        if hasattr(pipe.model, 'device'):
+            inputs = {k: v.to(pipe.model.device) for k, v in inputs.items()}
+        elif torch.cuda.is_available():
+            inputs = {k: v.cuda() for k, v in inputs.items()}
+
+        # Run a tiny generation to JIT-fuse kernels
+        with torch.inference_mode():
+            _ = pipe.model.generate(
+                **inputs,
+                max_new_tokens=2,
+                do_sample=False,
+                use_cache=True,
+            )
+        print(f"Kernels warmed for {model_name}")
+    except Exception as exc:
+        print(f"Kernel warmup failed for {model_name}: {exc}")
+
+
 def _schedule_background_warm(loaded_model: str) -> None:
     global WARMED_REMAINING
     if WARMED_REMAINING:
@@ -143,6 +266,8 @@ def _schedule_background_warm(loaded_model: str) -> None:
         try:
             print(f"Background warm start for {name}")
             load_pipeline(name)
+            # Warm kernels after loading
+            _warm_kernels(name)
         except Exception as exc:  # pragma: no cover
            print(f"Warm start failed for {name}: {exc}")
    WARMED_REMAINING = True
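A quick way to exercise the new fallback chain locally is a small smoke test. This is a sketch, not part of the commit: it assumes `HF_TOKEN` is exported, app.py is importable from the working directory, and the model key matches the renamed `MODELS` entry above.

# Minimal local smoke test (illustrative). Importing app runs the module-level
# HF_TOKEN check, so export the token first.
import os

assert os.environ.get("HF_TOKEN"), "export HF_TOKEN before running this sketch"

from app import MODELS, load_pipeline

model_name = "Router-Qwen3-32B-AWQ"   # key from the MODELS dict above
pipe = load_pipeline(model_name)       # resolves AWQ -> BitsAndBytes -> bf16/fp16/fp32
out = pipe(
    "Route this task: summarise a CSV of experiment results.",
    max_new_tokens=64,
    do_sample=False,
)
print(out[0]["generated_text"])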
requirements.txt CHANGED
@@ -8,6 +8,7 @@ spaces
 sentencepiece
 accelerate
 autoawq
+flash-attn>=2.5.0
 timm
 compressed-tensors
 bitsandbytes
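Note that flash-attn typically compiles CUDA extensions at install time and needs a compatible torch build and CUDA toolchain; a quick post-install check (illustrative, not part of this commit) is enough, since app.py degrades gracefully when the import fails:

# Verify the optional flash-attn dependency after installing requirements.txt.
try:
    import flash_attn
    print("flash-attn", flash_attn.__version__)
except ImportError as exc:
    print("flash-attn unavailable, default attention will be used:", exc)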
test_api.py ADDED
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+"""
+Test script for ZeroGPU LLM Inference API
+Usage: python test_api.py
+"""
+
+import requests
+import json
+import sys
+
+API_URL = "https://Alovestocode-ZeroGPU-LLM-Inference.hf.space"
+
+def test_api():
+    """Test the API endpoint"""
+    print("=" * 60)
+    print("Testing ZeroGPU LLM Inference API")
+    print("=" * 60)
+
+    # Test 1: Check if space is accessible
+    print("\n1. Checking if space is accessible...")
+    try:
+        response = requests.get(API_URL, timeout=10)
+        if response.status_code == 200:
+            print("   ✅ Space is accessible")
+        else:
+            print(f"   ⚠️ Space returned status {response.status_code}")
+    except Exception as e:
+        print(f"   ❌ Error: {e}")
+        return False
+
+    # Test 2: Check API info
+    print("\n2. Checking API info...")
+    try:
+        response = requests.get(f"{API_URL}/api/info", timeout=10)
+        print(f"   Status: {response.status_code}")
+        if response.status_code == 200:
+            print("   ✅ API info endpoint accessible")
+    except Exception as e:
+        print(f"   ⚠️ Error: {e}")
+
+    # Test 3: Try the API endpoint
+    print("\n3. Testing API endpoint...")
+    payload = {
+        "data": [
+            "Solve a quadratic equation using Python",
+            "",
+            "- Provide step-by-step solution",
+            "",
+            "intermediate",
+            "math, python",
+            "Router-Qwen3-32B-8bit",
+            256,  # Small token count for quick test
+            0.2,
+            0.9
+        ],
+        "fn_index": 0
+    }
+
+    try:
+        print(f"   Sending request to {API_URL}/api/predict...")
+        response = requests.post(
+            f"{API_URL}/api/predict",
+            json=payload,
+            timeout=120  # Longer timeout for model loading
+        )
+
+        print(f"   Status Code: {response.status_code}")
+
+        if response.status_code == 200:
+            print("   ✅ API is working!")
+            result = response.json()
+            print(f"\n   Response structure:")
+            if isinstance(result, dict):
+                print(f"   Keys: {list(result.keys())}")
+                if "data" in result:
+                    print(f"   Data length: {len(result['data'])}")
+                    if len(result['data']) > 0:
+                        print(f"   First output preview: {str(result['data'][0])[:200]}")
+            else:
+                print(f"   Result: {str(result)[:300]}")
+            return True
+        else:
+            print(f"   ❌ API returned status {response.status_code}")
+            print(f"   Response: {response.text[:500]}")
+            return False
+
+    except requests.exceptions.Timeout:
+        print("   ⚠️ Request timed out (this might be normal for first request due to model loading)")
+        return False
+    except Exception as e:
+        print(f"   ❌ Error: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+if __name__ == "__main__":
+    success = test_api()
+    print("\n" + "=" * 60)
+    if success:
+        print("✅ API test completed successfully!")
+    else:
+        print("⚠️ API test had issues. The space might still be building.")
+        print("   Wait a few minutes and try again, or check the space status at:")
+        print(f"   {API_URL}")
+    print("=" * 60)
+    sys.exit(0 if success else 1)
+
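The raw `/api/predict` endpoint and `fn_index` used above can shift between Gradio versions; the README's `gradio_client` path is usually steadier. A hedged alternative check (the positional inputs simply mirror test_api.py's `data` list, and `fn_index=0` is an assumption; confirm the actual endpoint with `client.view_api()`):

# Alternative check via gradio_client (pip install gradio_client).
from gradio_client import Client

client = Client("Alovestocode/ZeroGPU-LLM-Inference")
print(client.view_api())  # lists endpoints and expected parameters

result = client.predict(
    "Solve a quadratic equation using Python",
    "",
    "- Provide step-by-step solution",
    "",
    "intermediate",
    "math, python",
    "Router-Qwen3-32B-AWQ",
    256,
    0.2,
    0.9,
    fn_index=0,
)
print(str(result)[:300])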