Commit 1b16b00 · Parent: 9592189
Enable API in Gradio launch configuration

app.py CHANGED
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 import json
 import os
+import re
 from typing import Any, Dict, List, Tuple
 
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, pipeline
 from threading import Thread
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -15,6 +16,7 @@ if not HF_TOKEN:
     raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")
 
 PLAN_END_TOKEN = "<|end_of_plan|>"
+STOP_SEQUENCES = [PLAN_END_TOKEN, "</json>", "</JSON>"]
 
 ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit EXACTLY ONE strict JSON object with keys route_plan, route_rationale, expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nRules:\n- No markdown/code fences, no natural-language prologues or epilogues.\n- route_plan must be an ordered list of tool invocations such as /math(...), /code(...), /general-search(...).\n- todo_list must map each checklist item to the responsible tool.\n- metrics must include primary and secondary arrays (add optional *_guidance fields when they exist).\n- After the closing brace of the JSON object, immediately append the sentinel <|end_of_plan|>.\nExample output:\n{\n "route_plan": ["/general-search(...)"],\n "route_rationale": "...",\n ...\n}<|end_of_plan|>\nReturn nothing else."""
 
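For illustration, a plan that satisfies the schema above end to end (every value hypothetical):

{
  "route_plan": ["/general-search(survey LLM routing papers)", "/math(check latency budget arithmetic)"],
  "route_rationale": "Search grounds the survey; math verifies the budget numbers.",
  "expected_artifacts": ["annotated source list", "latency comparison table"],
  "thinking_outline": ["gather sources", "verify numbers", "summarise findings"],
  "handoff_plan": "general-search passes citations to math for the budget check.",
  "todo_list": ["- [ ] collect sources (/general-search)", "- [ ] verify budget (/math)"],
  "difficulty": "medium",
  "tags": ["routing", "survey"],
  "acceptance_criteria": ["every claim cited", "budget arithmetic shown"],
  "metrics": {"primary": ["citation_coverage"], "secondary": ["latency_ms"]}
}<|end_of_plan|>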
@@ -45,6 +47,22 @@ REQUIRED_KEYS = [
 ]
 
 PIPELINES: Dict[str, Any] = {}
+TOKENIZER_CACHE: Dict[str, Any] = {}
+WARMED_REMAINING = False
+TOOL_PATTERN = re.compile(r"^/[a-z0-9_-]+\(.*\)$", re.IGNORECASE)
+
+
+def get_tokenizer(repo: str):
+    tok = TOKENIZER_CACHE.get(repo)
+    if tok is not None:
+        return tok
+    tok = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)
+    tok.padding_side = "left"
+    tok.truncation_side = "left"
+    if tok.pad_token_id is None and tok.eos_token_id is not None:
+        tok.pad_token_id = tok.eos_token_id
+    TOKENIZER_CACHE[repo] = tok
+    return tok
 
 
 def load_pipeline(model_name: str):
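A quick usage sketch for the tokenizer cache added above ("gpt2" is a stand-in repo id; the app itself passes MODELS[...]["repo_id"]):

tok_a = get_tokenizer("gpt2")
tok_b = get_tokenizer("gpt2")
assert tok_a is tok_b                            # second call is served from TOKENIZER_CACHE
assert tok_a.padding_side == "left"              # set once at load time
assert tok_a.pad_token_id == tok_a.eos_token_id  # gpt2 has no pad token, so eos is reused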
@@ -52,21 +70,23 @@ def load_pipeline(model_name: str):
         return PIPELINES[model_name]
 
     repo = MODELS[model_name]["repo_id"]
-    tokenizer =
+    tokenizer = get_tokenizer(repo)
 
     try:
-
+        quant_config = BitsAndBytesConfig(load_in_8bit=True)
         pipe = pipeline(
             task="text-generation",
             model=repo,
             tokenizer=tokenizer,
             trust_remote_code=True,
             device_map="auto",
-            model_kwargs={"quantization_config":
+            model_kwargs={"quantization_config": quant_config},
             use_cache=True,
             token=HF_TOKEN,
         )
+        pipe.model.eval()
         PIPELINES[model_name] = pipe
+        _schedule_background_warm(model_name)
         return pipe
     except Exception as exc:
         print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
@@ -83,7 +103,9 @@ def load_pipeline(model_name: str):
                 use_cache=True,
                 token=HF_TOKEN,
             )
+            pipe.model.eval()
             PIPELINES[model_name] = pipe
+            _schedule_background_warm(model_name)
             return pipe
         except Exception:
             continue
@@ -97,10 +119,37 @@ def load_pipeline(model_name: str):
         use_cache=True,
         token=HF_TOKEN,
     )
+    pipe.model.eval()
     PIPELINES[model_name] = pipe
+    _schedule_background_warm(model_name)
     return pipe
 
 
+def _schedule_background_warm(loaded_model: str) -> None:
+    global WARMED_REMAINING
+    if WARMED_REMAINING:
+        return
+    warm_remaining = os.environ.get("ROUTER_WARM_REMAINING", "1")
+    if warm_remaining not in {"1", "true", "True"}:
+        return
+
+    remaining = [name for name in MODELS if name not in PIPELINES]
+    if not remaining:
+        WARMED_REMAINING = True
+        return
+
+    def _warm_all():
+        for name in remaining:
+            try:
+                print(f"Background warm start for {name}")
+                load_pipeline(name)
+            except Exception as exc:  # pragma: no cover
+                print(f"Warm start failed for {name}: {exc}")
+        WARMED_REMAINING = True
+
+    Thread(target=_warm_all, daemon=True).start()
+
+
 def build_router_prompt(
     user_task: str,
     context: str,
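The warm start above is gated purely by an environment variable; a minimal sketch of disabling it (variable name from the diff, model key hypothetical):

import os

os.environ["ROUTER_WARM_REMAINING"] = "0"  # anything outside {"1", "true", "True"} skips warming
pipe = load_pipeline("router-small")       # loads only the requested pipeline, no background thread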
@@ -152,20 +201,52 @@ def extract_json_from_text(text: str) -> str:
     raise ValueError("Router output JSON appears truncated.")
 
 
+def is_function_call(text: str) -> bool:
+    return bool(TOOL_PATTERN.match(text.strip()))
+
+
 def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
     issues: List[str] = []
     for key in REQUIRED_KEYS:
         if key not in plan:
             issues.append(f"Missing key: {key}")
+
     route_plan = plan.get("route_plan")
+    if isinstance(route_plan, str) and is_function_call(route_plan):
+        plan["route_plan"] = [route_plan]
+        route_plan = plan["route_plan"]
     if not isinstance(route_plan, list) or not route_plan:
         issues.append("route_plan must be a non-empty list of tool calls")
+    else:
+        cleaned: List[str] = []
+        for entry in route_plan:
+            if isinstance(entry, str) and is_function_call(entry.strip().strip("'\"")):
+                cleaned.append(entry.strip().strip("'\""))
+            else:
+                issues.append(f"route_plan entry is not a tool call: {entry}")
+        if cleaned:
+            plan["route_plan"] = cleaned
+
     metrics = plan.get("metrics")
     if not isinstance(metrics, dict):
         issues.append("metrics must be an object containing primary/secondary entries")
     todo = plan.get("todo_list")
     if not isinstance(todo, list) or not todo:
         issues.append("todo_list must contain at least one checklist item")
+    else:
+        cleaned_todo: List[str] = []
+        for entry in todo:
+            if isinstance(entry, str):
+                text = entry.strip()
+                if not text.startswith("- ["):
+                    text = text.lstrip("- ")
+                    text = f"- [ ] {text}"
+                cleaned_todo.append(text)
+            else:
+                issues.append("todo_list entry must be a string")
+        if cleaned_todo:
+            plan["todo_list"] = cleaned_todo
+
     return len(issues) == 0, issues
 
 
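A sketch of the new normalisation on a partial plan (contents hypothetical; note that validate_router_plan mutates the dict in place):

plan = {
    "route_plan": "/math(prove AM-GM)",
    "todo_list": ["derive the inequality", "check the equality case"],
    "metrics": {"primary": ["correctness"], "secondary": []},
}
ok, issues = validate_router_plan(plan)
# plan["route_plan"] -> ["/math(prove AM-GM)"]  (bare tool-call string coerced to a list)
# plan["todo_list"]  -> ["- [ ] derive the inequality", "- [ ] check the equality case"]
# ok is False: issues still reports the other missing REQUIRED_KEYS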
@@ -232,9 +313,15 @@ def generate_router_plan_streaming(
         "top_p": top_p,
         "do_sample": True,
         "streamer": streamer,
+        "eos_token_id": tokenizer.eos_token_id,
+        "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
     }
-
-
+
+    def _generate():
+        with torch.inference_mode():
+            model.generate(**generation_kwargs)
+
+    thread = Thread(target=_generate)
     thread.start()
 
     # Stream tokens
@@ -246,21 +333,22 @@ def generate_router_plan_streaming(
         completion += new_text
         chunk = completion
         finished = False
-
-
-
+        display_plan = parsed_plan or {}
+
+        chunk, finished = trim_at_stop_sequences(chunk)
 
         try:
             json_block = extract_json_from_text(chunk)
             candidate_plan = json.loads(json_block)
             ok, issues = validate_router_plan(candidate_plan)
             validation_msg = format_validation_message(ok, issues)
-            parsed_plan = candidate_plan if ok else
+            parsed_plan = candidate_plan if ok else parsed_plan
+            display_plan = candidate_plan
         except Exception:
             # Ignore until JSON is complete
             pass
 
-        yield chunk,
+        yield chunk, display_plan, validation_msg, prompt
 
         if finished:
             completion = chunk
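For reference, the producer/consumer pattern these streaming hunks build on, as a self-contained sketch (tiny public model used as a stand-in):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tok("The router plan is", return_tensors="pt")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(**inputs, max_new_tokens=16, streamer=streamer)

# generate() blocks until finished, so it runs on a worker thread while
# the main thread drains decoded text chunks from the streamer's queue.
thread = Thread(target=lambda: model.generate(**gen_kwargs))
thread.start()
completion = ""
for new_text in streamer:
    completion += new_text
thread.join()
print(completion)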
@@ -269,7 +357,7 @@ def generate_router_plan_streaming(
     # Final processing after streaming completes
     thread.join()
 
-    completion = completion.strip()
+    completion = trim_at_stop_sequences(completion.strip())[0]
     if parsed_plan is None:
         try:
             json_block = extract_json_from_text(completion)
@@ -384,7 +472,47 @@ def build_ui():
     return demo
 
 
+
+def trim_at_stop_sequences(text: str) -> Tuple[str, bool]:
+    earliest = None
+    for stop in STOP_SEQUENCES:
+        idx = text.find(stop)
+        if idx != -1 and (earliest is None or idx < earliest):
+            earliest = idx
+    if earliest is not None:
+        return text[:earliest], True
+    return text, False
+
+
+def _prefetch_from_env() -> None:
+    entries = os.environ.get("ROUTER_PREFETCH_MODELS")
+    if entries:
+        names = [item.strip() for item in entries.split(",") if item.strip()]
+    else:
+        single = os.environ.get("ROUTER_PREFETCH_MODEL")
+        names = [single] if single else []
+
+    if names == ["ALL"] or names == ["all"]:
+        names = list(MODELS.keys())
+
+    for name in names:
+        if name not in MODELS:
+            print(f"Prefetch skipped, unknown model: {name}")
+            continue
+        try:
+            load_pipeline(name)
+            print(f"Prefetched router model: {name}")
+        except Exception as exc:  # pragma: no cover
+            print(f"Prefetch failed for {name}: {exc}")
+
+
+_prefetch_from_env()
+
 demo = build_ui()
 
 if __name__ == "__main__":  # pragma: no cover
-    demo.launch(
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=int(os.environ.get("PORT", 7860)),
+        show_api=True,
+        api_name="/generate_router_plan_streaming"
+    )
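With show_api=True the plan generator becomes callable outside the UI. A minimal client sketch, assuming a hypothetical Space id and that the positional inputs mirror the Gradio controls:

from gradio_client import Client

client = Client("user/router-planner")  # hypothetical Space id
job = client.submit(
    "Plan a route for: prove the AM-GM inequality",
    api_name="/generate_router_plan_streaming",
)
for update in job:       # streamed intermediate outputs
    print(update)
print(job.result())      # final (completion, plan, validation message, prompt) tuple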