Spaces: Commit · bf2fdae
Parent(s): 4c3d05b
Fix deprecation warnings and improve error handling

- Replace deprecated load_in_8bit with BitsAndBytesConfig (see the sketch after this list)
- Fix the dropdown value to dynamically use the first available model
- Increase GPU duration to 600s for model loading
- Add better error handling for GPU "task aborted" errors
- Add model_choice validation
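A minimal sketch of the quantization change, assuming an `HF_TOKEN` environment variable, an installed `bitsandbytes`, and a gated router checkpoint as in the diff below; the `load_8bit_pipeline` helper is illustrative and not part of the commit:

```python
import os
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline

HF_TOKEN = os.environ.get("HF_TOKEN")

def load_8bit_pipeline(repo: str):
    # Before this commit the pipeline was built with model_kwargs={"load_in_8bit": True},
    # which transformers now flags as deprecated; pass a BitsAndBytesConfig instead.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    return pipeline(
        task="text-generation",
        model=repo,
        tokenizer=AutoTokenizer.from_pretrained(repo, token=HF_TOKEN),
        device_map="auto",
        model_kwargs={"quantization_config": quantization_config},
        token=HF_TOKEN,
    )
```

The dropdown fix referenced above is the one-liner `value=list(MODELS.keys())[0] if MODELS else None` in the new `build_ui`, replacing the hard-coded `value="Qwen3-1.7B"` from the previous UI.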
app.py CHANGED
@@ -1,77 +1,73 @@
 import os
-import
-import sys
-import threading
-from itertools import islice
-from datetime import datetime
-import re  # for parsing <think> blocks
 import gradio as gr
 import torch
-from transformers import
-from transformers import AutoTokenizer
-from ddgs import DDGS
-import spaces  # Import spaces early to enable ZeroGPU support
-from torch.utils._pytree import tree_map
-
-# Global event to signal cancellation from the UI thread to the generation thread
-cancel_event = threading.Event()

-
-# os.environ["CUDA_VISIBLE_DEVICES"] = ""

-# ------------------------------
-# Torch-Compatible Model Definitions with Adjusted Descriptions
-# ------------------------------
 MODELS = {
     "Router-Qwen3-32B-8bit": {
         "repo_id": "Alovestocode/router-qwen3-32b-merged",
-        "description": "
         "params_b": 32.0,
     },
     "Router-Gemma3-27B-8bit": {
         "repo_id": "Alovestocode/router-gemma3-merged",
-        "description": "
         "params_b": 27.0,
     },
 }

-
-    falling back to bf16/fp16/fp32 if quantized loading is unavailable.
-    """
-    global PIPELINES
     if model_name in PIPELINES:
         return PIPELINES[model_name]

     repo = MODELS[model_name]["repo_id"]
-    tokenizer = AutoTokenizer.from_pretrained(repo, token=

-    # First try to load in 8-bit to minimise VRAM usage.
     try:
         pipe = pipeline(
             task="text-generation",
             model=repo,
             tokenizer=tokenizer,
             trust_remote_code=True,
             device_map="auto",
-            model_kwargs={"
             use_cache=True,
-            token=
         )
         PIPELINES[model_name] = pipe
         return pipe
     except Exception as exc:
         print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")

-    # Fallback ladder when 8-bit is not available.
     for dtype in (torch.bfloat16, torch.float16, torch.float32):
         try:
             pipe = pipeline(
@@ -79,17 +75,16 @@ def load_pipeline(model_name):
                 model=repo,
                 tokenizer=tokenizer,
                 trust_remote_code=True,
-                dtype=dtype,
                 device_map="auto",
                 use_cache=True,
-                token=
             )
             PIPELINES[model_name] = pipe
             return pipe
         except Exception:
             continue

-    # Final fallback with framework defaults.
     pipe = pipeline(
         task="text-generation",
         model=repo,
@@ -97,523 +92,234 @@ def load_pipeline(model_name):
         trust_remote_code=True,
         device_map="auto",
         use_cache=True,
-        token=
     )
     PIPELINES[model_name] = pipe
     return pipe


-def
-    if
-    ""

-    # Launch web search if enabled
-    debug = ''
-    search_results = []
-    if enable_search:
-        debug = 'Search task started.'
-        thread_search = threading.Thread(
-            target=lambda: search_results.extend(
-                retrieve_context(user_msg, int(max_results), int(max_chars))
-            )
-        )
-        thread_search.daemon = True
-        thread_search.start()
-    else:
-        debug = 'Web search disabled.'

     try:
-        In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
-        When responding, please keep the following points in mind:
-        - Today is {cur_date}.
-        - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
-        - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
-        - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
-        - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
-        - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
-        - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
-        - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
-        - Unless the user requests otherwise, your response should be in the same language as the user's question.
-        # The user's message is:
-        '''
-        else:
-            enriched = system_prompt
-
-        # wait up to 1s for snippets, then replace debug with them
-        if enable_search:
-            thread_search.join(timeout=float(search_timeout))
-            if search_results:
-                debug = "### Search results merged into prompt\n\n" + "\n".join(
-                    f"- {r}" for r in search_results
-                )
-            else:
-                debug = "*No web search results found.*"
-
-        # merge fetched snippets into the system prompt
-        if search_results:
-            enriched = system_prompt.strip() + \
-                f'''\n# The following contents are the search results related to the user's message:
-                {search_results}
-        In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
-        When responding, please keep the following points in mind:
-        - Today is {cur_date}.
-        - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
-        - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
-        - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
-        - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
-        - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
-        - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
-        - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
-        - Unless the user requests otherwise, your response should be in the same language as the user's question.
-        # The user's message is:
-        '''
-        else:
-            enriched = system_prompt
-
-        pipe = load_pipeline(model_name)
-
-        prompt = format_conversation(history, enriched, pipe.tokenizer)
-        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
-        streamer = TextIteratorStreamer(pipe.tokenizer,
-                                        skip_prompt=True,
-                                        skip_special_tokens=True)
-        gen_thread = threading.Thread(
-            target=pipe,
-            args=(prompt,),
-            kwargs={
-                'max_new_tokens': max_tokens,
-                'temperature': temperature,
-                'top_k': top_k,
-                'top_p': top_p,
-                'repetition_penalty': repeat_penalty,
-                'streamer': streamer,
-                'return_full_text': False,
-            }
         )
-            else:
-                history[-1]['content'] = thought_buf
-                yield history, debug
-                continue
-
-            # Stream answer
-            if not assistant_message_started:
-                history.append({'role': 'assistant', 'content': ''})
-                assistant_message_started = True
-
-            answer_buf += text
-            history[-1]['content'] = answer_buf.strip()
-            yield history, debug
-
-        gen_thread.join()
-        yield history, debug + prompt_debug
-    except GeneratorExit:
-        # Handle cancellation gracefully
-        print("Chat response cancelled.")
-        # Don't yield anything - let the cancellation propagate
-        return
-    except Exception as e:
-        history.append({'role': 'assistant', 'content': f"Error: {e}"})
-        yield history, debug
-    finally:
-        gc.collect()
-
-
-def update_default_prompt(enable_search):
-    return f"You are a helpful assistant."
-
-def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
-    """Calculate and format the estimated GPU duration for current settings."""
-    try:
-        dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
-        duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
-                                enable_search, max_results, max_chars, model_name,
-                                max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
-        model_size = MODELS[model_name].get("params_b", 4.0)
-        return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
-                f"📊 **Model Size:** {model_size:.1f}B parameters\n"
-                f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
-    except Exception as e:
-        return f"⚠️ Error calculating estimate: {e}"
-
-# ------------------------------
-# Gradio UI
-# ------------------------------
-with gr.Blocks(
-    title="LLM Inference with ZeroGPU",
-    theme=gr.themes.Soft(
-        primary_hue="indigo",
-        secondary_hue="purple",
-        neutral_hue="slate",
-        radius_size="lg",
-        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
-    ),
-    css="""
-    .duration-estimate { background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%); border-left: 4px solid #667eea; padding: 12px; border-radius: 8px; margin: 16px 0; }
-    .chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
-    button.primary { font-weight: 600; }
-    .gradio-accordion { margin-bottom: 12px; }
-    """
-) as demo:
-    # Header
-    gr.Markdown("""
-    # 🧠 ZeroGPU LLM Inference
-    ### Powered by Hugging Face ZeroGPU with Web Search Integration
-    """)
-
-    with gr.Row():
-        # Left Panel - Configuration
-        with gr.Column(scale=3):
-            # Core Settings (Always Visible)
-            with gr.Group():
-                gr.Markdown("### ⚙️ Core Settings")
-                model_dd = gr.Dropdown(
-                    label="🤖 Model",
-                    choices=list(MODELS.keys()),
-                    value="Qwen3-1.7B",
-                    info="Select the language model to use"
                 )
-                    label="
                 )
-                    label="
                     lines=3,
-                    value=
-                    placeholder="Define the assistant's behavior and personality..."
-                )
-
-                # Duration Estimate
-                duration_display = gr.Markdown(
-                    value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0),
-                    elem_classes="duration-estimate"
-                )
-
-            # Advanced Settings (Collapsible)
-            with gr.Accordion("🎛️ Advanced Generation Parameters", open=False):
-                max_tok = gr.Slider(
-                    64, 16384, value=1024, step=32,
-                    label="Max Tokens",
-                    info="Maximum length of generated response"
                 )
-                )
-                with gr.Row():
-                    k = gr.Slider(
-                        1, 100, value=40, step=1,
-                        label="Top-K",
-                        info="Number of top tokens to consider"
-                    )
-                    p = gr.Slider(
-                        0.1, 1.0, value=0.9, step=0.05,
-                        label="Top-P",
-                        info="Nucleus sampling threshold"
-                    )
-                    rp = gr.Slider(
-                        1.0, 2.0, value=1.2, step=0.1,
-                        label="Repetition Penalty",
-                        info="Penalize repeated tokens"
-                    )
-
-            # Web Search Settings (Collapsible)
-            with gr.Accordion("🌐 Web Search Settings", open=False, visible=False) as search_settings:
-                mr = gr.Number(
-                    value=4, precision=0,
-                    label="Max Results",
-                    info="Number of search results to retrieve"
                 )
-                    label="
                 )
                 )
-
-        # Right Panel - Chat Interface
-        with gr.Column(scale=7):
-            chat = gr.Chatbot(
-                type="messages",
-                height=600,
-                label="💬 Conversation",
-                show_copy_button=True,
-                avatar_images=(None, "🤖"),
-                bubble_full_width=False
-            )
-
-            # Input Area
-            with gr.Row():
-                txt = gr.Textbox(
-                    placeholder="💭 Type your message here... (Press Enter to send)",
-                    scale=9,
-                    container=False,
-                    show_label=False,
-                    lines=1,
-                    max_lines=5
                 )
-
-    """, elem_classes="footer")
-
-    # --- Event Listeners ---
-
-    # Group all inputs for cleaner event handling
-    chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
-    # Group all UI components that can be updated.
-    ui_components = [chat, dbg, txt, submit_btn, cancel_btn]
-
-    def submit_and_manage_ui(user_msg, chat_history, *args):
-        """
-        Orchestrator function that manages UI state and calls the backend chat function.
-        It uses a try...finally block to ensure the UI is always reset.
-        """
-        if not user_msg.strip():
-            # If the message is empty, do nothing.
-            # We yield an empty dict to avoid any state changes.
-            yield {}
-            return
-
-        # 1. Update UI to "generating" state.
-        # Crucially, we do NOT update the `chat` component here, as the backend
-        # will provide the correctly formatted history in the first response chunk.
-        yield {
-            txt: gr.update(value="", interactive=False),
-            submit_btn: gr.update(interactive=False),
-            cancel_btn: gr.update(visible=True),
-        }
-
-        cancelled = False
-        try:
-            # 2. Call the backend and stream updates
-            backend_args = [user_msg, chat_history] + list(args)
-            for response_chunk in chat_response(*backend_args):
-                yield {
-                    chat: response_chunk[0],
-                    dbg: response_chunk[1],
-                }
-        except GeneratorExit:
-            # Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
-            cancelled = True
-            print("Generation cancelled by user.")
-            raise
-        except Exception as e:
-            print(f"An error occurred during generation: {e}")
-            # If an error happens, add it to the chat history to inform the user.
-            error_history = (chat_history or []) + [
-                {'role': 'user', 'content': user_msg},
-                {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
-            ]
-            yield {chat: error_history}
-        finally:
-            # Only reset UI if not cancelled (to avoid "generator ignored GeneratorExit")
-            if not cancelled:
-                print("Resetting UI state.")
-                yield {
-                    txt: gr.update(interactive=True),
-                    submit_btn: gr.update(interactive=True),
-                    cancel_btn: gr.update(visible=False),
-                }
-
-    def set_cancel_flag():
-        """Called by the cancel button, sets the global event."""
-        cancel_event.set()
-        print("Cancellation signal sent.")
-
-    def reset_ui_after_cancel():
-        """Reset UI components after cancellation."""
-        cancel_event.clear()  # Clear the flag for next generation
-        print("UI reset after cancellation.")
-        return {
-            txt: gr.update(interactive=True),
-            submit_btn: gr.update(interactive=True),
-            cancel_btn: gr.update(visible=False),
-        }
-
-    # Event for submitting text via Enter key or Submit button
-    submit_event = txt.submit(
-        fn=submit_and_manage_ui,
-        inputs=chat_inputs,
-        outputs=ui_components,
-    )
-    submit_btn.click(
-        fn=submit_and_manage_ui,
-        inputs=chat_inputs,
-        outputs=ui_components,
-    )

-    # It sets the cancel flag, cancels the submit event, then resets the UI.
-    cancel_btn.click(
-        fn=set_cancel_flag,
-        cancels=[submit_event]
-    ).then(
-        fn=reset_ui_after_cancel,
-        outputs=ui_components
-    )

-    duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
-    for component in duration_inputs:
-        component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)

-        fn=lambda enabled: (update_default_prompt(enabled), gr.update(visible=enabled)),
-        inputs=search_chk,
-        outputs=[sys_prompt, search_settings]
-    )
-
-    # Clear chat action
-    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
-
-demo.launch()
+from __future__ import annotations
+
+import json
 import os
+from typing import Any, Dict, List, Tuple
+
 import gradio as gr
+import spaces
 import torch
+from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig

+HF_TOKEN = os.environ.get("HF_TOKEN")
+if not HF_TOKEN:
+    raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")

+ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit ONLY strict JSON with keys route_plan, route_rationale, expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nEach route_plan entry must be a tool call (e.g., /math(...), /code(...), /general-search(...)).\nBe concise but precise. Do not include prose outside of the JSON object."""

 MODELS = {
     "Router-Qwen3-32B-8bit": {
         "repo_id": "Alovestocode/router-qwen3-32b-merged",
+        "description": "Router checkpoint on Qwen3 32B merged and quantized for 8-bit ZeroGPU inference.",
         "params_b": 32.0,
     },
     "Router-Gemma3-27B-8bit": {
         "repo_id": "Alovestocode/router-gemma3-merged",
+        "description": "Router checkpoint on Gemma3 27B merged and quantized for 8-bit ZeroGPU inference.",
         "params_b": 27.0,
     },
 }

+REQUIRED_KEYS = [
+    "route_plan",
+    "route_rationale",
+    "expected_artifacts",
+    "thinking_outline",
+    "handoff_plan",
+    "todo_list",
+    "difficulty",
+    "tags",
+    "acceptance_criteria",
+    "metrics",
+]

+PIPELINES: Dict[str, Any] = {}
+
+
+def load_pipeline(model_name: str):
     if model_name in PIPELINES:
         return PIPELINES[model_name]

     repo = MODELS[model_name]["repo_id"]
+    tokenizer = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)

     try:
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
         pipe = pipeline(
             task="text-generation",
             model=repo,
             tokenizer=tokenizer,
             trust_remote_code=True,
             device_map="auto",
+            model_kwargs={"quantization_config": quantization_config},
             use_cache=True,
+            token=HF_TOKEN,
         )
         PIPELINES[model_name] = pipe
         return pipe
     except Exception as exc:
         print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")

     for dtype in (torch.bfloat16, torch.float16, torch.float32):
         try:
             pipe = pipeline(
                 model=repo,
                 tokenizer=tokenizer,
                 trust_remote_code=True,
                 device_map="auto",
+                dtype=dtype,
                 use_cache=True,
+                token=HF_TOKEN,
             )
             PIPELINES[model_name] = pipe
             return pipe
         except Exception:
             continue

     pipe = pipeline(
         task="text-generation",
         model=repo,
         trust_remote_code=True,
         device_map="auto",
         use_cache=True,
+        token=HF_TOKEN,
     )
     PIPELINES[model_name] = pipe
     return pipe


+def build_router_prompt(
+    user_task: str,
+    context: str,
+    acceptance: str,
+    extra_guidance: str,
+    difficulty: str,
+    tags: str,
+) -> str:
+    prompt_parts = [ROUTER_SYSTEM_PROMPT.strip(), "\n### Router Inputs\n"]
+    prompt_parts.append(f"Difficulty: {difficulty or 'intermediate'}")
+    prompt_parts.append(f"Tags: {tags or 'general'}")
+    if acceptance.strip():
+        prompt_parts.append(f"Acceptance criteria: {acceptance.strip()}")
+    if extra_guidance.strip():
+        prompt_parts.append(f"Additional guidance: {extra_guidance.strip()}")
+    if context.strip():
+        prompt_parts.append("\n### Supporting context\n" + context.strip())
+    prompt_parts.append("\n### User task\n" + user_task.strip())
+    prompt_parts.append("\nReturn only JSON.")
+    return "\n".join(prompt_parts)
+
+
+def extract_json_from_text(text: str) -> str:
+    start = text.find("{")
+    if start == -1:
+        raise ValueError("Router output did not contain a JSON object.")
+    depth = 0
+    in_string = False
+    escape = False
+    for idx in range(start, len(text)):
+        ch = text[idx]
+        if in_string:
+            if escape:
+                escape = False
+            elif ch == "\\":
+                escape = True
+            elif ch == '"':
+                in_string = False
+            continue
+        if ch == '"':
+            in_string = True
+            continue
+        if ch == '{':
+            depth += 1
+        elif ch == '}':
+            depth -= 1
+            if depth == 0:
+                return text[start : idx + 1]
+    raise ValueError("Router output JSON appears truncated.")
+
+
+def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
+    issues: List[str] = []
+    for key in REQUIRED_KEYS:
+        if key not in plan:
+            issues.append(f"Missing key: {key}")
+    route_plan = plan.get("route_plan")
+    if not isinstance(route_plan, list) or not route_plan:
+        issues.append("route_plan must be a non-empty list of tool calls")
+    metrics = plan.get("metrics")
+    if not isinstance(metrics, dict):
+        issues.append("metrics must be an object containing primary/secondary entries")
+    todo = plan.get("todo_list")
+    if not isinstance(todo, list) or not todo:
+        issues.append("todo_list must contain at least one checklist item")
+    return len(issues) == 0, issues
+
+
+def format_validation_message(ok: bool, issues: List[str]) -> str:
+    if ok:
+        return "✅ Router plan includes all required fields."
+    bullets = "\n".join(f"- {issue}" for issue in issues)
+    return f"❌ Issues detected:\n{bullets}"
+
+
+@spaces.GPU(duration=600)
+def generate_router_plan(
+    user_task: str,
+    context: str,
+    acceptance: str,
+    extra_guidance: str,
+    difficulty: str,
+    tags: str,
+    model_choice: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+) -> Tuple[str, Dict[str, Any], str, str]:
+    if not user_task.strip():
+        raise gr.Error("User task is required.")

+    if model_choice not in MODELS:
+        raise gr.Error(f"Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}")

     try:
+        prompt = build_router_prompt(
+            user_task=user_task,
+            context=context,
+            acceptance=acceptance,
+            extra_guidance=extra_guidance,
+            difficulty=difficulty,
+            tags=tags,
         )
+
+        generator = load_pipeline(model_choice)
+        result = generator(
+            prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=True,
+        )[0]["generated_text"]
+
+        completion = result[len(prompt) :].strip() if result.startswith(prompt) else result.strip()
+
+        try:
+            json_block = extract_json_from_text(completion)
+            plan = json.loads(json_block)
+            ok, issues = validate_router_plan(plan)
+            validation_msg = format_validation_message(ok, issues)
+        except Exception as exc:
+            plan = {}
+            validation_msg = f"❌ JSON parsing failed: {exc}"
+
+        return completion, plan, validation_msg, prompt
+    except Exception as exc:
+        error_msg = f"❌ Generation failed: {str(exc)}"
+        return "", {}, error_msg, ""
+
+
+def clear_outputs():
+    return "", {}, "Awaiting generation.", ""
+
+
+def build_ui():
+    description = "Use the CourseGPT-Pro router checkpoints (Gemma3/Qwen3) hosted on ZeroGPU to generate structured routing plans."
+    with gr.Blocks(theme=gr.themes.Soft(), css="""
+        textarea { font-family: 'JetBrains Mono', 'Fira Code', monospace; }
+        .status-ok { color: #0d9488; font-weight: 600; }
+        .status-bad { color: #dc2626; font-weight: 600; }
+    """) as demo:
+        gr.Markdown("# 🛰️ Router Control Room — ZeroGPU")
+        gr.Markdown(description)
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                user_task = gr.Textbox(
+                    label="User Task / Problem Statement",
+                    placeholder="Describe the homework-style query that needs routing...",
+                    lines=8,
+                    value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
                 )
+                context = gr.Textbox(
+                    label="Supporting Context (optional)",
+                    placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
+                    lines=4,
                 )
+                acceptance = gr.Textbox(
+                    label="Acceptance Criteria",
+                    placeholder="Bullet list of 'definition of done' checks.",
                     lines=3,
+                    value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
                 )
+                extra_guidance = gr.Textbox(
+                    label="Additional Guidance",
+                    placeholder="Special constraints, tools to avoid, etc.",
+                    lines=3,
                 )
+            with gr.Column(scale=2):
+                model_choice = gr.Dropdown(
+                    label="Router Checkpoint",
+                    choices=list(MODELS.keys()),
+                    value=list(MODELS.keys())[0] if MODELS else None,
+                    allow_custom_value=False,
                 )
+                difficulty = gr.Radio(
+                    label="Difficulty Tier",
+                    choices=["introductory", "intermediate", "advanced"],
+                    value="advanced",
+                    interactive=True,
                 )
+                tags = gr.Textbox(
+                    label="Tags",
+                    placeholder="Comma-separated e.g. calculus, optimization, python",
+                    value="calculus, optimization, python",
                 )
+                max_new_tokens = gr.Slider(256, 1024, value=640, step=32, label="Max New Tokens")
+                temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
+                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+
+        generate_btn = gr.Button("Generate Router Plan", variant="primary")
+        clear_btn = gr.Button("Clear", variant="secondary")
+
+        with gr.Row():
+            raw_output = gr.Textbox(label="Raw Model Output", lines=12)
+            plan_json = gr.JSON(label="Parsed Router Plan")
+        validation_msg = gr.Markdown("Awaiting generation.")
+        prompt_view = gr.Textbox(label="Full Prompt", lines=10)
+
+        generate_btn.click(
+            generate_router_plan,
+            inputs=[
+                user_task,
+                context,
+                acceptance,
+                extra_guidance,
+                difficulty,
+                tags,
+                model_choice,
+                max_new_tokens,
+                temperature,
+                top_p,
+            ],
+            outputs=[raw_output, plan_json, validation_msg, prompt_view],
+        )

+        clear_btn.click(fn=clear_outputs, outputs=[raw_output, plan_json, validation_msg, prompt_view])

+    return demo

+
+demo = build_ui()
+
+if __name__ == "__main__":  # pragma: no cover
+    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
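For reference, a minimal sketch (not part of the commit) of the JSON contract that `validate_router_plan` enforces. Only the key names and container types come from `REQUIRED_KEYS` and the checks above; every field value below is a made-up example, and the import assumes the Space module is importable as `app`:

```python
from app import validate_router_plan  # assumption: app.py is importable as a module

# Hypothetical plan that satisfies the validator: all REQUIRED_KEYS present,
# route_plan and todo_list are non-empty lists, metrics is a dict.
example_plan = {
    "route_plan": ["/math(derive KKT conditions)", "/code(verify optimum numerically)"],
    "route_rationale": "Symbolic derivation first, then numeric verification.",
    "expected_artifacts": ["worked derivation", "python script"],
    "thinking_outline": ["restate constraints", "set up Lagrangian", "solve", "check"],
    "handoff_plan": "Math output feeds the code agent as ground truth.",
    "todo_list": ["derive gradient", "implement solver", "compare results"],
    "difficulty": "advanced",
    "tags": ["calculus", "optimization", "python"],
    "acceptance_criteria": ["citations present", "/math verifies /code output"],
    "metrics": {"primary": "correctness", "secondary": "clarity"},
}

ok, issues = validate_router_plan(example_plan)
print(ok, issues)  # -> True []
```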
|