Alikestocode committed
Commit 1b16b00 · 1 Parent(s): 9592189

Enable API in Gradio launch configuration

Files changed (1): app.py (+141, -13)
app.py CHANGED
@@ -2,12 +2,13 @@ from __future__ import annotations
 
 import json
 import os
+import re
 from typing import Any, Dict, List, Tuple
 
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig, TextIteratorStreamer
+from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, pipeline
 from threading import Thread
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -15,6 +16,7 @@ if not HF_TOKEN:
     raise RuntimeError("HF_TOKEN environment variable must be set for private router checkpoints.")
 
 PLAN_END_TOKEN = "<|end_of_plan|>"
+STOP_SEQUENCES = [PLAN_END_TOKEN, "</json>", "</JSON>"]
 
 ROUTER_SYSTEM_PROMPT = """You are the Router Agent coordinating Math, Code, and General-Search specialists.\nEmit EXACTLY ONE strict JSON object with keys route_plan, route_rationale, expected_artifacts,\nthinking_outline, handoff_plan, todo_list, difficulty, tags, acceptance_criteria, metrics.\nRules:\n- No markdown/code fences, no natural-language prologues or epilogues.\n- route_plan must be an ordered list of tool invocations such as /math(...), /code(...), /general-search(...).\n- todo_list must map each checklist item to the responsible tool.\n- metrics must include primary and secondary arrays (add optional *_guidance fields when they exist).\n- After the closing brace of the JSON object, immediately append the sentinel <|end_of_plan|>.\nExample output:\n{\n "route_plan": ["/general-search(...)"],\n "route_rationale": "...",\n ...\n}<|end_of_plan|>\nReturn nothing else."""
 
@@ -45,6 +47,22 @@ REQUIRED_KEYS = [
 ]
 
 PIPELINES: Dict[str, Any] = {}
+TOKENIZER_CACHE: Dict[str, Any] = {}
+WARMED_REMAINING = False
+TOOL_PATTERN = re.compile(r"^/[a-z0-9_-]+\(.*\)$", re.IGNORECASE)
+
+
+def get_tokenizer(repo: str):
+    tok = TOKENIZER_CACHE.get(repo)
+    if tok is not None:
+        return tok
+    tok = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)
+    tok.padding_side = "left"
+    tok.truncation_side = "left"
+    if tok.pad_token_id is None and tok.eos_token_id is not None:
+        tok.pad_token_id = tok.eos_token_id
+    TOKENIZER_CACHE[repo] = tok
+    return tok
 
 
 def load_pipeline(model_name: str):
@@ -52,21 +70,23 @@ def load_pipeline(model_name: str):
         return PIPELINES[model_name]
 
     repo = MODELS[model_name]["repo_id"]
-    tokenizer = AutoTokenizer.from_pretrained(repo, token=HF_TOKEN)
+    tokenizer = get_tokenizer(repo)
 
     try:
-        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        quant_config = BitsAndBytesConfig(load_in_8bit=True)
         pipe = pipeline(
             task="text-generation",
            model=repo,
            tokenizer=tokenizer,
            trust_remote_code=True,
            device_map="auto",
-           model_kwargs={"quantization_config": quantization_config},
+           model_kwargs={"quantization_config": quant_config},
            use_cache=True,
            token=HF_TOKEN,
        )
+        pipe.model.eval()
         PIPELINES[model_name] = pipe
+        _schedule_background_warm(model_name)
         return pipe
     except Exception as exc:
         print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
@@ -83,7 +103,9 @@ def load_pipeline(model_name: str):
                 use_cache=True,
                 token=HF_TOKEN,
             )
+            pipe.model.eval()
             PIPELINES[model_name] = pipe
+            _schedule_background_warm(model_name)
             return pipe
         except Exception:
             continue
@@ -97,10 +119,37 @@ def load_pipeline(model_name: str):
         use_cache=True,
         token=HF_TOKEN,
     )
+    pipe.model.eval()
     PIPELINES[model_name] = pipe
+    _schedule_background_warm(model_name)
    return pipe
 
 
+def _schedule_background_warm(loaded_model: str) -> None:
+    global WARMED_REMAINING
+    if WARMED_REMAINING:
+        return
+    warm_remaining = os.environ.get("ROUTER_WARM_REMAINING", "1")
+    if warm_remaining not in {"1", "true", "True"}:
+        return
+
+    remaining = [name for name in MODELS if name not in PIPELINES]
+    if not remaining:
+        WARMED_REMAINING = True
+        return
+
+    def _warm_all():
+        for name in remaining:
+            try:
+                print(f"Background warm start for {name}")
+                load_pipeline(name)
+            except Exception as exc:  # pragma: no cover
+                print(f"Warm start failed for {name}: {exc}")
+        WARMED_REMAINING = True
+
+    Thread(target=_warm_all, daemon=True).start()
+
+
 def build_router_prompt(
     user_task: str,
     context: str,
@@ -152,20 +201,52 @@ def extract_json_from_text(text: str) -> str:
     raise ValueError("Router output JSON appears truncated.")
 
 
+def is_function_call(text: str) -> bool:
+    return bool(TOOL_PATTERN.match(text.strip()))
+
+
 def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
     issues: List[str] = []
     for key in REQUIRED_KEYS:
         if key not in plan:
             issues.append(f"Missing key: {key}")
+
     route_plan = plan.get("route_plan")
+    if isinstance(route_plan, str) and is_function_call(route_plan):
+        plan["route_plan"] = [route_plan]
+        route_plan = plan["route_plan"]
     if not isinstance(route_plan, list) or not route_plan:
         issues.append("route_plan must be a non-empty list of tool calls")
+    else:
+        cleaned: List[str] = []
+        for entry in route_plan:
+            if isinstance(entry, str) and is_function_call(entry.strip().strip("'\"")):
+                cleaned.append(entry.strip().strip("'\""))
+            else:
+                issues.append(f"route_plan entry is not a tool call: {entry}")
+        if cleaned:
+            plan["route_plan"] = cleaned
+
     metrics = plan.get("metrics")
     if not isinstance(metrics, dict):
         issues.append("metrics must be an object containing primary/secondary entries")
     todo = plan.get("todo_list")
     if not isinstance(todo, list) or not todo:
         issues.append("todo_list must contain at least one checklist item")
+    else:
+        cleaned_todo: List[str] = []
+        for entry in todo:
+            if isinstance(entry, str):
+                text = entry.strip()
+                if not text.startswith("- ["):
+                    text = text.lstrip("- ")
+                    text = f"- [ ] {text}"
+                cleaned_todo.append(text)
+            else:
+                issues.append("todo_list entry must be a string")
+        if cleaned_todo:
+            plan["todo_list"] = cleaned_todo
+
     return len(issues) == 0, issues
 
 
@@ -232,9 +313,15 @@ def generate_router_plan_streaming(
         "top_p": top_p,
         "do_sample": True,
         "streamer": streamer,
+        "eos_token_id": tokenizer.eos_token_id,
+        "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
     }
-
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+
+    def _generate():
+        with torch.inference_mode():
+            model.generate(**generation_kwargs)
+
+    thread = Thread(target=_generate)
     thread.start()
 
     # Stream tokens
@@ -246,21 +333,22 @@ def generate_router_plan_streaming(
         completion += new_text
         chunk = completion
        finished = False
-        if PLAN_END_TOKEN in chunk:
-            chunk = chunk.split(PLAN_END_TOKEN, 1)[0]
-            finished = True
+        display_plan = parsed_plan or {}
+
+        chunk, finished = trim_at_stop_sequences(chunk)
 
        try:
            json_block = extract_json_from_text(chunk)
            candidate_plan = json.loads(json_block)
            ok, issues = validate_router_plan(candidate_plan)
            validation_msg = format_validation_message(ok, issues)
-            parsed_plan = candidate_plan if ok else candidate_plan
+            parsed_plan = candidate_plan if ok else parsed_plan
+            display_plan = candidate_plan
        except Exception:
            # Ignore until JSON is complete
            pass
 
-        yield chunk, parsed_plan or {}, validation_msg, prompt
+        yield chunk, display_plan, validation_msg, prompt
 
        if finished:
            completion = chunk
@@ -269,7 +357,7 @@ def generate_router_plan_streaming(
     # Final processing after streaming completes
     thread.join()
 
-    completion = completion.strip()
+    completion = trim_at_stop_sequences(completion.strip())[0]
     if parsed_plan is None:
         try:
             json_block = extract_json_from_text(completion)
@@ -384,7 +472,46 @@ def build_ui():
     return demo
 
 
+def trim_at_stop_sequences(text: str) -> Tuple[str, bool]:
+    earliest = None
+    for stop in STOP_SEQUENCES:
+        idx = text.find(stop)
+        if idx != -1 and (earliest is None or idx < earliest):
+            earliest = idx
+    if earliest is not None:
+        return text[:earliest], True
+    return text, False
+
+
+def _prefetch_from_env() -> None:
+    entries = os.environ.get("ROUTER_PREFETCH_MODELS")
+    if entries:
+        names = [item.strip() for item in entries.split(",") if item.strip()]
+    else:
+        single = os.environ.get("ROUTER_PREFETCH_MODEL")
+        names = [single] if single else []
+
+    if names == ["ALL"] or names == ["all"]:
+        names = list(MODELS.keys())
+
+    for name in names:
+        if name not in MODELS:
+            print(f"Prefetch skipped, unknown model: {name}")
+            continue
+        try:
+            load_pipeline(name)
+            print(f"Prefetched router model: {name}")
+        except Exception as exc:  # pragma: no cover
+            print(f"Prefetch failed for {name}: {exc}")
+
+
+_prefetch_from_env()
+
 demo = build_ui()
 
 if __name__ == "__main__":  # pragma: no cover
-    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=int(os.environ.get("PORT", 7860)),
+        show_api=True,
+    )
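
A quick sanity check of the new validation path. The plan values below are invented for illustration; the snippet assumes REQUIRED_KEYS mirrors the ten keys named in ROUTER_SYSTEM_PROMPT, and that HF_TOKEN is set, since app.py raises at import time without it.

    from app import validate_router_plan  # requires HF_TOKEN in the environment

    sample_plan = {
        "route_plan": ["/general-search(query='gradio streaming api')"],
        "route_rationale": "A single search covers this documentation lookup.",
        "expected_artifacts": ["summary of relevant docs"],
        "thinking_outline": ["identify docs", "extract launch flags"],
        "handoff_plan": "Return findings to the user.",
        "todo_list": ["- [ ] run the search (/general-search)"],
        "difficulty": "easy",
        "tags": ["gradio", "api"],
        "acceptance_criteria": ["launch flags named"],
        "metrics": {"primary": ["answer_correctness"], "secondary": ["latency"]},
    }

    ok, issues = validate_router_plan(sample_plan)
    assert ok and not issues  # a bare string route_plan would also be coerced to a one-item list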
 
67
 
68
  def load_pipeline(model_name: str):
 
70
  return PIPELINES[model_name]
71
 
72
  repo = MODELS[model_name]["repo_id"]
73
+ tokenizer = get_tokenizer(repo)
74
 
75
  try:
76
+ quant_config = BitsAndBytesConfig(load_in_8bit=True)
77
  pipe = pipeline(
78
  task="text-generation",
79
  model=repo,
80
  tokenizer=tokenizer,
81
  trust_remote_code=True,
82
  device_map="auto",
83
+ model_kwargs={"quantization_config": quant_config},
84
  use_cache=True,
85
  token=HF_TOKEN,
86
  )
87
+ pipe.model.eval()
88
  PIPELINES[model_name] = pipe
89
+ _schedule_background_warm(model_name)
90
  return pipe
91
  except Exception as exc:
92
  print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
 
103
  use_cache=True,
104
  token=HF_TOKEN,
105
  )
106
+ pipe.model.eval()
107
  PIPELINES[model_name] = pipe
108
+ _schedule_background_warm(model_name)
109
  return pipe
110
  except Exception:
111
  continue
 
119
  use_cache=True,
120
  token=HF_TOKEN,
121
  )
122
+ pipe.model.eval()
123
  PIPELINES[model_name] = pipe
124
+ _schedule_background_warm(model_name)
125
  return pipe
126
 
127
 
128
+ def _schedule_background_warm(loaded_model: str) -> None:
129
+ global WARMED_REMAINING
130
+ if WARMED_REMAINING:
131
+ return
132
+ warm_remaining = os.environ.get("ROUTER_WARM_REMAINING", "1")
133
+ if warm_remaining not in {"1", "true", "True"}:
134
+ return
135
+
136
+ remaining = [name for name in MODELS if name not in PIPELINES]
137
+ if not remaining:
138
+ WARMED_REMAINING = True
139
+ return
140
+
141
+ def _warm_all():
142
+ for name in remaining:
143
+ try:
144
+ print(f"Background warm start for {name}")
145
+ load_pipeline(name)
146
+ except Exception as exc: # pragma: no cover
147
+ print(f"Warm start failed for {name}: {exc}")
148
+ WARMED_REMAINING = True
149
+
150
+ Thread(target=_warm_all, daemon=True).start()
151
+
152
+
153
  def build_router_prompt(
154
  user_task: str,
155
  context: str,
 
201
  raise ValueError("Router output JSON appears truncated.")
202
 
203
 
204
+ def is_function_call(text: str) -> bool:
205
+ return bool(TOOL_PATTERN.match(text.strip()))
206
+
207
+
208
  def validate_router_plan(plan: Dict[str, Any]) -> Tuple[bool, List[str]]:
209
  issues: List[str] = []
210
  for key in REQUIRED_KEYS:
211
  if key not in plan:
212
  issues.append(f"Missing key: {key}")
213
+
214
  route_plan = plan.get("route_plan")
215
+ if isinstance(route_plan, str) and is_function_call(route_plan):
216
+ plan["route_plan"] = [route_plan]
217
+ route_plan = plan["route_plan"]
218
  if not isinstance(route_plan, list) or not route_plan:
219
  issues.append("route_plan must be a non-empty list of tool calls")
220
+ else:
221
+ cleaned: List[str] = []
222
+ for entry in route_plan:
223
+ if isinstance(entry, str) and is_function_call(entry.strip().strip("'\"")):
224
+ cleaned.append(entry.strip().strip("'\""))
225
+ else:
226
+ issues.append(f"route_plan entry is not a tool call: {entry}")
227
+ if cleaned:
228
+ plan["route_plan"] = cleaned
229
+
230
  metrics = plan.get("metrics")
231
  if not isinstance(metrics, dict):
232
  issues.append("metrics must be an object containing primary/secondary entries")
233
  todo = plan.get("todo_list")
234
  if not isinstance(todo, list) or not todo:
235
  issues.append("todo_list must contain at least one checklist item")
236
+ else:
237
+ cleaned_todo: List[str] = []
238
+ for entry in todo:
239
+ if isinstance(entry, str):
240
+ text = entry.strip()
241
+ if not text.startswith("- ["):
242
+ text = text.lstrip("- ")
243
+ text = f"- [ ] {text}"
244
+ cleaned_todo.append(text)
245
+ else:
246
+ issues.append("todo_list entry must be a string")
247
+ if cleaned_todo:
248
+ plan["todo_list"] = cleaned_todo
249
+
250
  return len(issues) == 0, issues
251
 
252
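
The prefetch and warm-start paths are driven entirely by environment variables. A minimal local-run sketch (the model names in the comment are placeholders; real names must match keys of the MODELS dict in app.py):

    import os

    os.environ["ROUTER_PREFETCH_MODELS"] = "ALL"  # or a comma-separated list of MODELS keys
    os.environ["ROUTER_WARM_REMAINING"] = "1"     # default; set "0" to skip background warm-up

    import app  # _prefetch_from_env() runs at import time and loads the listed pipelines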
 
 
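
On endpoint naming: api_name is not a demo.launch() argument in Gradio; it is set on the event listener inside build_ui() (e.g. btn.click(..., api_name="generate_router_plan_streaming")), and launch(show_api=True) then exposes it. Assuming the handler is registered under that name, a client call might look like the sketch below; the argument list depends on the handler's actual inputs, which this diff does not show.

    from gradio_client import Client

    client = Client("http://localhost:7860")  # or the Space URL
    job = client.submit(
        "Plan a benchmark for a math-tutoring agent.",  # user_task (assumed first input)
        "",                                             # context (assumed second input)
        api_name="/generate_router_plan_streaming",
    )
    # result() blocks until the generator finishes and returns its final yield
    completion, parsed_plan, validation_msg, prompt = job.result()
    print(validation_msg)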