Alikestocode committed
Commit bf2fdae · 1 Parent(s): 4c3d05b

Fix deprecation warnings and improve error handling

- Replace deprecated load_in_8bit with BitsAndBytesConfig
- Fix dropdown value to dynamically use first model
- Increase GPU duration to 600s for model loading
- Add better error handling for GPU task aborted errors
- Add model_choice validation

Files changed (1)
  1. app.py +246 -540
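
The key API change in this commit, shown as a minimal, illustrative sketch (the actual loader in app.py below also caches pipelines, passes HF_TOKEN, sets trust_remote_code, and falls back to bf16/fp16/fp32): the deprecated load_in_8bit model kwarg is replaced by an explicit BitsAndBytesConfig.

    # Minimal sketch of the quantization change, assuming the transformers and
    # bitsandbytes packages used by this Space are installed.
    from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline

    repo = "Alovestocode/router-qwen3-32b-merged"
    tokenizer = AutoTokenizer.from_pretrained(repo)

    # Before (deprecated): model_kwargs={"load_in_8bit": True}
    # After: wrap the flag in a BitsAndBytesConfig and pass it via model_kwargs.
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        device_map="auto",
        model_kwargs={"quantization_config": bnb_config},
    )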
app.py CHANGED
@@ -1,77 +1,73 @@
  import os
- import time
- import gc
- import sys
- import threading
- from itertools import islice
- from datetime import datetime
- import re # for parsing <think> blocks
  import gradio as gr
  import torch
- from transformers import pipeline, TextIteratorStreamer, StoppingCriteria
- from transformers import AutoTokenizer
- from ddgs import DDGS
- import spaces # Import spaces early to enable ZeroGPU support
- from torch.utils._pytree import tree_map
-
- # Global event to signal cancellation from the UI thread to the generation thread
- cancel_event = threading.Event()

- access_token=os.environ['HF_TOKEN']

- # Optional: Disable GPU visibility if you wish to force CPU usage
- # os.environ["CUDA_VISIBLE_DEVICES"] = ""

- # ------------------------------
- # Torch-Compatible Model Definitions with Adjusted Descriptions
- # ------------------------------
  MODELS = {
  "Router-Qwen3-32B-8bit": {
  "repo_id": "Alovestocode/router-qwen3-32b-merged",
- "description": "CourseGPT-Pro router checkpoint built on Qwen3 32B and quantized to 8-bit for ZeroGPU deployment. Optimised for orchestrating math, code, and general-search agents while staying within the ZeroGPU memory envelope.",
  "params_b": 32.0,
  },
  "Router-Gemma3-27B-8bit": {
  "repo_id": "Alovestocode/router-gemma3-merged",
- "description": "CourseGPT-Pro router checkpoint built on Gemma3 27B and quantized to 8-bit. Provides the same JSON routing schema with a Gemma-flavoured inductive bias for math-heavy prompts.",
  "params_b": 27.0,
  },
  }

- # Global cache for pipelines to avoid re-loading.
- PIPELINES = {}

- def load_pipeline(model_name):
- """
- Load and cache a transformers pipeline for text generation.
- Prefers 8-bit loading (bitsandbytes) to stay within ZeroGPU limits,
- falling back to bf16/fp16/fp32 if quantized loading is unavailable.
- """
- global PIPELINES
  if model_name in PIPELINES:
  return PIPELINES[model_name]

  repo = MODELS[model_name]["repo_id"]
- tokenizer = AutoTokenizer.from_pretrained(repo, token=access_token)

- # First try to load in 8-bit to minimise VRAM usage.
  try:
  pipe = pipeline(
  task="text-generation",
  model=repo,
  tokenizer=tokenizer,
  trust_remote_code=True,
  device_map="auto",
- model_kwargs={"load_in_8bit": True},
  use_cache=True,
- token=access_token,
  )
  PIPELINES[model_name] = pipe
  return pipe
  except Exception as exc:
  print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")

- # Fallback ladder when 8-bit is not available.
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
  try:
  pipe = pipeline(
@@ -79,17 +75,16 @@ def load_pipeline(model_name):
  model=repo,
  tokenizer=tokenizer,
  trust_remote_code=True,
- dtype=dtype,
  device_map="auto",
  use_cache=True,
- token=access_token,
  )
  PIPELINES[model_name] = pipe
  return pipe
  except Exception:
  continue

- # Final fallback with framework defaults.
  pipe = pipeline(
  task="text-generation",
  model=repo,
@@ -97,523 +92,234 @@ def load_pipeline(model_name):
  trust_remote_code=True,
  device_map="auto",
  use_cache=True,
- token=access_token,
  )
  PIPELINES[model_name] = pipe
  return pipe


- def retrieve_context(query, max_results=6, max_chars=50):
- """
- Retrieve search snippets from DuckDuckGo (runs in background).
- Returns a list of result strings.
- """
- try:
- with DDGS() as ddgs:
- return [f"{i+1}. {r.get('title','No Title')} - {r.get('body','')[:max_chars]}"
- for i, r in enumerate(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))]
- except Exception:
- return []
-
- def format_conversation(history, system_prompt, tokenizer):
- if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
- messages = [{"role": "system", "content": system_prompt.strip()}] + history
- return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
- else:
- # Fallback for base LMs without chat template
- prompt = system_prompt.strip() + "\n"
- for msg in history:
- if msg['role'] == 'user':
- prompt += "User: " + msg['content'].strip() + "\n"
- elif msg['role'] == 'assistant':
- prompt += "Assistant: " + msg['content'].strip() + "\n"
- if not prompt.strip().endswith("Assistant:"):
- prompt += "Assistant: "
- return prompt
-
- def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
- # Get model size from the MODELS dict (more reliable than string parsing)
- model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
-
- # Only use AOT for models >= 2B parameters
- use_aot = model_size >= 2
-
- # Adjusted for H200 performance: faster inference, quicker compilation
- base_duration = 20 if not use_aot else 40 # Reduced base times
- token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
- search_duration = 10 if enable_search else 0 # Reduced search time
- aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
-
- return base_duration + token_duration + search_duration + aot_compilation_buffer
-
- @spaces.GPU(duration=get_duration)
- def chat_response(user_msg, chat_history, system_prompt,
- enable_search, max_results, max_chars,
- model_name, max_tokens, temperature,
- top_k, top_p, repeat_penalty, search_timeout):
- """
- Generates streaming chat responses, optionally with background web search.
- This version includes cancellation support.
- """
- # Clear the cancellation event at the start of a new generation
- cancel_event.clear()

- history = list(chat_history or [])
- history.append({'role': 'user', 'content': user_msg})
-
- # Launch web search if enabled
- debug = ''
- search_results = []
- if enable_search:
- debug = 'Search task started.'
- thread_search = threading.Thread(
- target=lambda: search_results.extend(
- retrieve_context(user_msg, int(max_results), int(max_chars))
- )
- )
- thread_search.daemon = True
- thread_search.start()
- else:
- debug = 'Web search disabled.'

  try:
- cur_date = datetime.now().strftime('%Y-%m-%d')
- # merge any fetched search results into the system prompt
- if search_results:
-
- enriched = system_prompt.strip() + \
- f'''\n# The following contents are the search results related to the user's message:
- {search_results}
- In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
- When responding, please keep the following points in mind:
- - Today is {cur_date}.
- - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
- - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
- - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
- - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
- - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
- - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
- - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
- - Unless the user requests otherwise, your response should be in the same language as the user's question.
- # The user's message is:
- '''
- else:
- enriched = system_prompt
-
- # wait up to 1s for snippets, then replace debug with them
- if enable_search:
- thread_search.join(timeout=float(search_timeout))
- if search_results:
- debug = "### Search results merged into prompt\n\n" + "\n".join(
- f"- {r}" for r in search_results
- )
- else:
- debug = "*No web search results found.*"
-
- # merge fetched snippets into the system prompt
- if search_results:
- enriched = system_prompt.strip() + \
- f'''\n# The following contents are the search results related to the user's message:
- {search_results}
- In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
- When responding, please keep the following points in mind:
- - Today is {cur_date}.
- - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
- - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
- - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
- - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
- - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
- - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
- - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
- - Unless the user requests otherwise, your response should be in the same language as the user's question.
- # The user's message is:
- '''
- else:
- enriched = system_prompt
-
- pipe = load_pipeline(model_name)
-
- prompt = format_conversation(history, enriched, pipe.tokenizer)
- prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
- streamer = TextIteratorStreamer(pipe.tokenizer,
- skip_prompt=True,
- skip_special_tokens=True)
- gen_thread = threading.Thread(
- target=pipe,
- args=(prompt,),
- kwargs={
- 'max_new_tokens': max_tokens,
- 'temperature': temperature,
- 'top_k': top_k,
- 'top_p': top_p,
- 'repetition_penalty': repeat_penalty,
- 'streamer': streamer,
- 'return_full_text': False,
- }
  )
- gen_thread.start()
-
- # Buffers for thought vs answer
- thought_buf = ''
- answer_buf = ''
- in_thought = False
- assistant_message_started = False
-
- # First yield contains the user message
- yield history, debug
-
- # Stream tokens
- for chunk in streamer:
- # Check for cancellation signal
- if cancel_event.is_set():
- if assistant_message_started and history and history[-1]['role'] == 'assistant':
- history[-1]['content'] += " [Generation Canceled]"
- yield history, debug
- break
-
- text = chunk
-
- # Detect start of thinking
- if not in_thought and '<think>' in text:
- in_thought = True
- history.append({'role': 'assistant', 'content': '', 'metadata': {'title': '💭 Thought'}})
- assistant_message_started = True
- after = text.split('<think>', 1)[1]
- thought_buf += after
- if '</think>' in thought_buf:
- before, after2 = thought_buf.split('</think>', 1)
- history[-1]['content'] = before.strip()
- in_thought = False
- answer_buf = after2
- history.append({'role': 'assistant', 'content': answer_buf})
- else:
- history[-1]['content'] = thought_buf
- yield history, debug
- continue
-
- if in_thought:
- thought_buf += text
- if '</think>' in thought_buf:
- before, after2 = thought_buf.split('</think>', 1)
- history[-1]['content'] = before.strip()
- in_thought = False
- answer_buf = after2
- history.append({'role': 'assistant', 'content': answer_buf})
- else:
- history[-1]['content'] = thought_buf
- yield history, debug
- continue
-
- # Stream answer
- if not assistant_message_started:
- history.append({'role': 'assistant', 'content': ''})
- assistant_message_started = True
-
- answer_buf += text
- history[-1]['content'] = answer_buf.strip()
- yield history, debug
-
- gen_thread.join()
- yield history, debug + prompt_debug
- except GeneratorExit:
- # Handle cancellation gracefully
- print("Chat response cancelled.")
- # Don't yield anything - let the cancellation propagate
- return
- except Exception as e:
- history.append({'role': 'assistant', 'content': f"Error: {e}"})
- yield history, debug
- finally:
- gc.collect()
-
-
- def update_default_prompt(enable_search):
- return f"You are a helpful assistant."
-
- def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
- """Calculate and format the estimated GPU duration for current settings."""
- try:
- dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
- duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
- enable_search, max_results, max_chars, model_name,
- max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
- model_size = MODELS[model_name].get("params_b", 4.0)
- return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
- f"📊 **Model Size:** {model_size:.1f}B parameters\n"
- f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
- except Exception as e:
- return f"⚠️ Error calculating estimate: {e}"
-
- # ------------------------------
- # Gradio UI
- # ------------------------------
- with gr.Blocks(
- title="LLM Inference with ZeroGPU",
- theme=gr.themes.Soft(
- primary_hue="indigo",
- secondary_hue="purple",
- neutral_hue="slate",
- radius_size="lg",
- font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
- ),
- css="""
- .duration-estimate { background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%); border-left: 4px solid #667eea; padding: 12px; border-radius: 8px; margin: 16px 0; }
- .chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
- button.primary { font-weight: 600; }
- .gradio-accordion { margin-bottom: 12px; }
- """
- ) as demo:
- # Header
- gr.Markdown("""
- # 🧠 ZeroGPU LLM Inference
- ### Powered by Hugging Face ZeroGPU with Web Search Integration
- """)
-
- with gr.Row():
- # Left Panel - Configuration
- with gr.Column(scale=3):
- # Core Settings (Always Visible)
- with gr.Group():
- gr.Markdown("### ⚙️ Core Settings")
- model_dd = gr.Dropdown(
- label="🤖 Model",
- choices=list(MODELS.keys()),
- value="Qwen3-1.7B",
- info="Select the language model to use"
  )
- search_chk = gr.Checkbox(
- label="🔍 Enable Web Search",
- value=False,
- info="Augment responses with real-time web data"
  )
- sys_prompt = gr.Textbox(
- label="📝 System Prompt",
  lines=3,
- value=update_default_prompt(search_chk.value),
- placeholder="Define the assistant's behavior and personality..."
- )
-
- # Duration Estimate
- duration_display = gr.Markdown(
- value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0),
- elem_classes="duration-estimate"
- )
-
- # Advanced Settings (Collapsible)
- with gr.Accordion("🎛️ Advanced Generation Parameters", open=False):
- max_tok = gr.Slider(
- 64, 16384, value=1024, step=32,
- label="Max Tokens",
- info="Maximum length of generated response"
  )
- temp = gr.Slider(
- 0.1, 2.0, value=0.7, step=0.1,
- label="Temperature",
- info="Higher = more creative, Lower = more focused"
- )
- with gr.Row():
- k = gr.Slider(
- 1, 100, value=40, step=1,
- label="Top-K",
- info="Number of top tokens to consider"
- )
- p = gr.Slider(
- 0.1, 1.0, value=0.9, step=0.05,
- label="Top-P",
- info="Nucleus sampling threshold"
- )
- rp = gr.Slider(
- 1.0, 2.0, value=1.2, step=0.1,
- label="Repetition Penalty",
- info="Penalize repeated tokens"
- )
-
- # Web Search Settings (Collapsible)
- with gr.Accordion("🌐 Web Search Settings", open=False, visible=False) as search_settings:
- mr = gr.Number(
- value=4, precision=0,
- label="Max Results",
- info="Number of search results to retrieve"
  )
- mc = gr.Number(
- value=50, precision=0,
- label="Max Chars/Result",
- info="Character limit per search result"
  )
- st = gr.Slider(
- minimum=0.0, maximum=30.0, step=0.5, value=5.0,
- label="Search Timeout (s)",
- info="Maximum time to wait for search results"
  )
-
- # Actions
- with gr.Row():
- clr = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)
-
- # Right Panel - Chat Interface
- with gr.Column(scale=7):
- chat = gr.Chatbot(
- type="messages",
- height=600,
- label="💬 Conversation",
- show_copy_button=True,
- avatar_images=(None, "🤖"),
- bubble_full_width=False
- )
-
- # Input Area
- with gr.Row():
- txt = gr.Textbox(
- placeholder="💭 Type your message here... (Press Enter to send)",
- scale=9,
- container=False,
- show_label=False,
- lines=1,
- max_lines=5
  )
- with gr.Column(scale=1, min_width=120):
- submit_btn = gr.Button("📤 Send", variant="primary", size="lg")
- cancel_btn = gr.Button("⏹️ Stop", variant="stop", visible=False, size="lg")
-
- # Example Prompts
- gr.Examples(
- examples=[
- ["Explain quantum computing in simple terms"],
- ["Write a Python function to calculate fibonacci numbers"],
- ["What are the latest developments in AI? (Enable web search)"],
- ["Tell me a creative story about a time traveler"],
- ["Help me debug this code: def add(a,b): return a+b+1"]
- ],
- inputs=txt,
- label="💡 Example Prompts"
- )
-
- # Debug/Status Info (Collapsible)
- with gr.Accordion("🔍 Debug Info", open=False):
- dbg = gr.Markdown()
-
- # Footer
- gr.Markdown("""
- ---
- 💡 **Tips:**
- - Use **Advanced Parameters** to fine-tune creativity and response length
- - Enable **Web Search** for real-time, up-to-date information
- - Try different **models** for various tasks (reasoning, coding, general chat)
- - Click the **Copy** button on responses to save them to your clipboard
- """, elem_classes="footer")
-
- # --- Event Listeners ---
-
- # Group all inputs for cleaner event handling
- chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
- # Group all UI components that can be updated.
- ui_components = [chat, dbg, txt, submit_btn, cancel_btn]
-
- def submit_and_manage_ui(user_msg, chat_history, *args):
- """
- Orchestrator function that manages UI state and calls the backend chat function.
- It uses a try...finally block to ensure the UI is always reset.
- """
- if not user_msg.strip():
- # If the message is empty, do nothing.
- # We yield an empty dict to avoid any state changes.
- yield {}
- return
-
- # 1. Update UI to "generating" state.
- # Crucially, we do NOT update the `chat` component here, as the backend
- # will provide the correctly formatted history in the first response chunk.
- yield {
- txt: gr.update(value="", interactive=False),
- submit_btn: gr.update(interactive=False),
- cancel_btn: gr.update(visible=True),
- }
-
- cancelled = False
- try:
- # 2. Call the backend and stream updates
- backend_args = [user_msg, chat_history] + list(args)
- for response_chunk in chat_response(*backend_args):
- yield {
- chat: response_chunk[0],
- dbg: response_chunk[1],
- }
- except GeneratorExit:
- # Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
- cancelled = True
- print("Generation cancelled by user.")
- raise
- except Exception as e:
- print(f"An error occurred during generation: {e}")
- # If an error happens, add it to the chat history to inform the user.
- error_history = (chat_history or []) + [
- {'role': 'user', 'content': user_msg},
- {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
- ]
- yield {chat: error_history}
- finally:
- # Only reset UI if not cancelled (to avoid "generator ignored GeneratorExit")
- if not cancelled:
- print("Resetting UI state.")
- yield {
- txt: gr.update(interactive=True),
- submit_btn: gr.update(interactive=True),
- cancel_btn: gr.update(visible=False),
- }
-
- def set_cancel_flag():
- """Called by the cancel button, sets the global event."""
- cancel_event.set()
- print("Cancellation signal sent.")
-
- def reset_ui_after_cancel():
- """Reset UI components after cancellation."""
- cancel_event.clear() # Clear the flag for next generation
- print("UI reset after cancellation.")
- return {
- txt: gr.update(interactive=True),
- submit_btn: gr.update(interactive=True),
- cancel_btn: gr.update(visible=False),
- }
-
- # Event for submitting text via Enter key or Submit button
- submit_event = txt.submit(
- fn=submit_and_manage_ui,
- inputs=chat_inputs,
- outputs=ui_components,
- )
- submit_btn.click(
- fn=submit_and_manage_ui,
- inputs=chat_inputs,
- outputs=ui_components,
- )

- # Event for the "Cancel" button.
- # It sets the cancel flag, cancels the submit event, then resets the UI.
- cancel_btn.click(
- fn=set_cancel_flag,
- cancels=[submit_event]
- ).then(
- fn=reset_ui_after_cancel,
- outputs=ui_components
- )

- # Listeners for updating the duration estimate
- duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
- for component in duration_inputs:
- component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)

- # Toggle web search settings visibility
- def toggle_search_settings(enabled):
- return gr.update(visible=enabled)
-
- search_chk.change(
- fn=lambda enabled: (update_default_prompt(enabled), gr.update(visible=enabled)),
- inputs=search_chk,
- outputs=[sys_prompt, search_settings]
- )
-
- # Clear chat action
- clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
-
- demo.launch()
+ from __future__ import annotations
+
+ import json
  import os
+ from typing import Any, Dict, List, Tuple
+
  import gradio as gr
+ import spaces
  import torch
+ from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig

+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ if not HF_TOKEN:
+ raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")

+ ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit ONLY strict JSON with keys route_plan, route_rationale, expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nEach route_plan entry must be a tool call (e.g., /math(...), /code(...), /general-search(...)).\nBe concise but precise. Do not include prose outside of the JSON object."""

  MODELS = {
  "Router-Qwen3-32B-8bit": {
  "repo_id": "Alovestocode/router-qwen3-32b-merged",
+ "description": "Router checkpoint on Qwen3 32B merged and quantized for 8-bit ZeroGPU inference.",
  "params_b": 32.0,
  },
  "Router-Gemma3-27B-8bit": {
  "repo_id": "Alovestocode/router-gemma3-merged",
+ "description": "Router checkpoint on Gemma3 27B merged and quantized for 8-bit ZeroGPU inference.",
  "params_b": 27.0,
  },
  }

+ REQUIRED_KEYS = [
+ "route_plan",
+ "route_rationale",
+ "expected_artifacts",
+ "thinking_outline",
+ "handoff_plan",
+ "todo_list",
+ "difficulty",
+ "tags",
+ "acceptance_criteria",
+ "metrics",
+ ]

+ PIPELINES: Dict[str, Any] = {}
+
+
+ def load_pipeline(model_name: str):
  if model_name in PIPELINES:
  return PIPELINES[model_name]

  repo = MODELS[model_name]["repo_id"]
+ tokenizer = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)

  try:
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
  pipe = pipeline(
  task="text-generation",
  model=repo,
  tokenizer=tokenizer,
  trust_remote_code=True,
  device_map="auto",
+ model_kwargs={"quantization_config": quantization_config},
  use_cache=True,
+ token=HF_TOKEN,
  )
  PIPELINES[model_name] = pipe
  return pipe
  except Exception as exc:
  print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")

  for dtype in (torch.bfloat16, torch.float16, torch.float32):
  try:
  pipe = pipeline(
  model=repo,
  tokenizer=tokenizer,
  trust_remote_code=True,
  device_map="auto",
+ dtype=dtype,
  use_cache=True,
+ token=HF_TOKEN,
  )
  PIPELINES[model_name] = pipe
  return pipe
  except Exception:
  continue

  pipe = pipeline(
  task="text-generation",
  model=repo,
  trust_remote_code=True,
  device_map="auto",
  use_cache=True,
+ token=HF_TOKEN,
  )
  PIPELINES[model_name] = pipe
  return pipe


+ def build_router_prompt(
+ user_task: str,
+ context: str,
+ acceptance: str,
+ extra_guidance: str,
+ difficulty: str,
+ tags: str,
+ ) -> str:
+ prompt_parts = [ROUTER_SYSTEM_PROMPT.strip(), "\n### Router Inputs\n"]
+ prompt_parts.append(f"Difficulty: {difficulty or 'intermediate'}")
+ prompt_parts.append(f"Tags: {tags or 'general'}")
+ if acceptance.strip():
+ prompt_parts.append(f"Acceptance criteria: {acceptance.strip()}")
+ if extra_guidance.strip():
+ prompt_parts.append(f"Additional guidance: {extra_guidance.strip()}")
+ if context.strip():
+ prompt_parts.append("\n### Supporting context\n" + context.strip())
+ prompt_parts.append("\n### User task\n" + user_task.strip())
+ prompt_parts.append("\nReturn only JSON.")
+ return "\n".join(prompt_parts)
+
+
+ def extract_json_from_text(text: str) -> str:
+ start = text.find("{")
+ if start == -1:
+ raise ValueError("Router output did not contain a JSON object.")
+ depth = 0
+ in_string = False
+ escape = False
+ for idx in range(start, len(text)):
+ ch = text[idx]
+ if in_string:
+ if escape:
+ escape = False
+ elif ch == "\\":
+ escape = True
+ elif ch == '"':
+ in_string = False
+ continue
+ if ch == '"':
+ in_string = True
+ continue
+ if ch == '{':
+ depth += 1
+ elif ch == '}':
+ depth -= 1
+ if depth == 0:
+ return text[start : idx + 1]
+ raise ValueError("Router output JSON appears truncated.")
+
+
+ def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
+ issues: List[str] = []
+ for key in REQUIRED_KEYS:
+ if key not in plan:
+ issues.append(f"Missing key: {key}")
+ route_plan = plan.get("route_plan")
+ if not isinstance(route_plan, list) or not route_plan:
+ issues.append("route_plan must be a non-empty list of tool calls")
+ metrics = plan.get("metrics")
+ if not isinstance(metrics, dict):
+ issues.append("metrics must be an object containing primary/secondary entries")
+ todo = plan.get("todo_list")
+ if not isinstance(todo, list) or not todo:
+ issues.append("todo_list must contain at least one checklist item")
+ return len(issues) == 0, issues
+
+
+ def format_validation_message(ok: bool, issues: List[str]) -> str:
+ if ok:
+ return "✅ Router plan includes all required fields."
+ bullets = "\n".join(f"- {issue}" for issue in issues)
+ return f"❌ Issues detected:\n{bullets}"
+
+
+ @spaces.GPU(duration=600)
+ def generate_router_plan(
+ user_task: str,
+ context: str,
+ acceptance: str,
+ extra_guidance: str,
+ difficulty: str,
+ tags: str,
+ model_choice: str,
+ max_new_tokens: int,
+ temperature: float,
+ top_p: float,
+ ) -> Tuple[str, Dict[str, Any], str, str]:
+ if not user_task.strip():
+ raise gr.Error("User task is required.")

+ if model_choice not in MODELS:
+ raise gr.Error(f"Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}")

  try:
+ prompt = build_router_prompt(
+ user_task=user_task,
+ context=context,
+ acceptance=acceptance,
+ extra_guidance=extra_guidance,
+ difficulty=difficulty,
+ tags=tags,
  )
+
+ generator = load_pipeline(model_choice)
+ result = generator(
+ prompt,
+ max_new_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ do_sample=True,
+ )[0]["generated_text"]
+
+ completion = result[len(prompt) :].strip() if result.startswith(prompt) else result.strip()
+
+ try:
+ json_block = extract_json_from_text(completion)
+ plan = json.loads(json_block)
+ ok, issues = validate_router_plan(plan)
+ validation_msg = format_validation_message(ok, issues)
+ except Exception as exc:
+ plan = {}
+ validation_msg = f"❌ JSON parsing failed: {exc}"
+
+ return completion, plan, validation_msg, prompt
+ except Exception as exc:
+ error_msg = f"❌ Generation failed: {str(exc)}"
+ return "", {}, error_msg, ""
+
+
+ def clear_outputs():
+ return "", {}, "Awaiting generation.", ""
+
+
+ def build_ui():
+ description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
+ with gr.Blocks(theme=gr.themes.Soft(), css="""
+ textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
+ .status-ok { color: #0d9488; font-weight: 600; }
+ .status-bad { color: #dc2626; font-weight: 600; }
+ """) as demo:
+ gr.Markdown("# 🛰️ Router Control Room — ZeroGPU" )
+ gr.Markdown(description)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ user_task = gr.Textbox(
+ label="User Task / Problem Statement",
+ placeholder="Describe the homework-style query that needs routing...",
+ lines=8,
+ value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
  )
+ context = gr.Textbox(
+ label="Supporting Context (optional)",
+ placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
+ lines=4,
  )
+ acceptance = gr.Textbox(
+ label="Acceptance Criteria",
+ placeholder="Bullet list of 'definition of done' checks.",
  lines=3,
+ value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
  )
+ extra_guidance = gr.Textbox(
+ label="Additional Guidance",
+ placeholder="Special constraints, tools to avoid, etc.",
+ lines=3,
  )
+ with gr.Column(scale=2):
+ model_choice = gr.Dropdown(
+ label="Router Checkpoint",
+ choices=list(MODELS.keys()),
+ value=list(MODELS.keys())[0] if MODELS else None,
+ allow_custom_value=False,
  )
+ difficulty = gr.Radio(
+ label="Difficulty Tier",
+ choices=["introductory", "intermediate", "advanced"],
+ value="advanced",
+ interactive=True,
  )
+ tags = gr.Textbox(
+ label="Tags",
+ placeholder="Comma-separated e.g. calculus, optimization, python",
+ value="calculus, optimization, python",
  )
+ max_new_tokens = gr.Slider(256, 1024, value=640, step=32, label="Max New Tokens")
+ temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+
+ generate_btn = gr.Button("Generate Router Plan", variant="primary")
+ clear_btn = gr.Button("Clear", variant="secondary")
+
+ with gr.Row():
+ raw_output = gr.Textbox(label="Raw Model Output", lines=12)
+ plan_json = gr.JSON(label="Parsed Router Plan")
+ validation_msg = gr.Markdown("Awaiting generation.")
+ prompt_view = gr.Textbox(label="Full Prompt", lines=10)
+
+ generate_btn.click(
+ generate_router_plan,
+ inputs=[
+ user_task,
+ context,
+ acceptance,
+ extra_guidance,
+ difficulty,
+ tags,
+ model_choice,
+ max_new_tokens,
+ temperature,
+ top_p,
+ ],
+ outputs=[raw_output, plan_json, validation_msg, prompt_view],
+ )

+ clear_btn.click(fn=clear_outputs, outputs=[raw_output, plan_json, validation_msg, prompt_view])

+ return demo

+
+ demo = build_ui()
+
+ if __name__ == "__main__": # pragma: no cover
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))