Alikestocode committed
Commit 9a4d6d3 · 1 Parent(s): 597f1a9

Add user-configurable GPU duration slider (60-1800 seconds)


- Add GPU Duration slider in UI (default: 600 seconds)
- Refactor to use spaces.GPU context manager with dynamic duration
- Allow users to set GPU time allocation per request
- Maintain backward compatibility with default 600s wrapper
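
The last two bullets describe the core of the refactor, which the diff below implements: a thin wrapper keeps the original @spaces.GPU(duration=600) decorator so existing callers keep working, while the internal generator wraps its work in spaces.GPU(duration=gpu_duration) using the value from the new slider. A minimal sketch of that pattern (illustrative names only; it assumes the Hugging Face spaces package on a ZeroGPU Space, and the commit's actual implementation is in app.py below):

import spaces

def _run_internal(prompt: str, gpu_duration: int):
    # Hold the GPU only for the number of seconds chosen in the UI slider.
    with spaces.GPU(duration=gpu_duration):
        yield f"result for: {prompt}"

@spaces.GPU(duration=600)  # default 600 s allocation, keeps the old entry point working
def run(prompt: str, gpu_duration: int = 600):
    yield from _run_internal(prompt, gpu_duration)

Keeping the decorated wrapper means callers that never pass gpu_duration still get the old 600-second behaviour.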

Files changed (1)
  1. app.py +122 -96
app.py CHANGED
@@ -269,8 +269,7 @@ def format_validation_message(ok: bool, issues: List[str]) -> str:
     return f"❌ Issues detected:\n{bullets}"
 
 
-@spaces.GPU(duration=600)
-def generate_router_plan_streaming(
+def _generate_router_plan_streaming_internal(
     user_task: str,
     context: str,
     acceptance: str,
@@ -281,8 +280,9 @@ def generate_router_plan_streaming(
     max_new_tokens: int,
     temperature: float,
     top_p: float,
+    gpu_duration: int,
 ):
-    """Generator function for streaming token output."""
+    """Internal generator function for streaming token output."""
     if not user_task.strip():
         yield "", {}, "❌ User task is required.", ""
         return
@@ -291,100 +291,124 @@ def generate_router_plan_streaming(
         yield "", {}, f"❌ Invalid model choice: {model_choice}. Available: {list(MODELS.keys())}", ""
         return
 
-    try:
-        prompt = build_router_prompt(
-            user_task=user_task,
-            context=context,
-            acceptance=acceptance,
-            extra_guidance=extra_guidance,
-            difficulty=difficulty,
-            tags=tags,
-        )
-
-        generator = load_pipeline(model_choice)
-
-        # Get the underlying model and tokenizer
-        model = generator.model
-        tokenizer = generator.tokenizer
-
-        # Set up streaming
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-        # Prepare inputs
-        inputs = tokenizer(prompt, return_tensors="pt")
-        if hasattr(model, 'device'):
-            inputs = {k: v.to(model.device) for k, v in inputs.items()}
-        elif torch.cuda.is_available():
-            inputs = {k: v.cuda() for k, v in inputs.items()}
-
-        # Start generation in a separate thread
-        generation_kwargs = {
-            **inputs,
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
-            "do_sample": True,
-            "streamer": streamer,
-            "eos_token_id": tokenizer.eos_token_id,
-            "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
-        }
-
-        def _generate():
-            with torch.inference_mode():
-                model.generate(**generation_kwargs)
-
-        thread = Thread(target=_generate)
-        thread.start()
-
-        # Stream tokens
-        completion = ""
-        parsed_plan: Dict[str, Any] | None = None
-        validation_msg = "🔄 Generating..."
-
-        for new_text in streamer:
-            completion += new_text
-            chunk = completion
-            finished = False
-            display_plan = parsed_plan or {}
-
-            chunk, finished = trim_at_stop_sequences(chunk)
+    # Use GPU context manager with user-specified duration
+    with spaces.GPU(duration=gpu_duration):
+        try:
+            prompt = build_router_prompt(
+                user_task=user_task,
+                context=context,
+                acceptance=acceptance,
+                extra_guidance=extra_guidance,
+                difficulty=difficulty,
+                tags=tags,
+            )
+
+            generator = load_pipeline(model_choice)
+
+            # Get the underlying model and tokenizer
+            model = generator.model
+            tokenizer = generator.tokenizer
+
+            # Set up streaming
+            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+            # Prepare inputs
+            inputs = tokenizer(prompt, return_tensors="pt")
+            if hasattr(model, 'device'):
+                inputs = {k: v.to(model.device) for k, v in inputs.items()}
+            elif torch.cuda.is_available():
+                inputs = {k: v.cuda() for k, v in inputs.items()}
+
+            # Start generation in a separate thread
+            generation_kwargs = {
+                **inputs,
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": True,
+                "streamer": streamer,
+                "eos_token_id": tokenizer.eos_token_id,
+                "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
+            }
+
+            def _generate():
+                with torch.inference_mode():
+                    model.generate(**generation_kwargs)
+
+            thread = Thread(target=_generate)
+            thread.start()
+
+            # Stream tokens
+            completion = ""
+            parsed_plan: Dict[str, Any] | None = None
+            validation_msg = "🔄 Generating..."
+
+            for new_text in streamer:
+                completion += new_text
+                chunk = completion
+                finished = False
+                display_plan = parsed_plan or {}
+
+                chunk, finished = trim_at_stop_sequences(chunk)
 
-            try:
-                json_block = extract_json_from_text(chunk)
-                candidate_plan = json.loads(json_block)
-                ok, issues = validate_router_plan(candidate_plan)
-                validation_msg = format_validation_message(ok, issues)
-                parsed_plan = candidate_plan if ok else parsed_plan
-                display_plan = candidate_plan
-            except Exception:
-                # Ignore until JSON is complete
-                pass
-
-            yield chunk, display_plan, validation_msg, prompt
-
-            if finished:
-                completion = chunk
-                break
-
-        # Final processing after streaming completes
-        thread.join()
-
-        completion = trim_at_stop_sequences(completion.strip())[0]
-        if parsed_plan is None:
-            try:
-                json_block = extract_json_from_text(completion)
-                parsed_plan = json.loads(json_block)
-                ok, issues = validate_router_plan(parsed_plan)
-                validation_msg = format_validation_message(ok, issues)
-            except Exception as exc:
-                parsed_plan = {}
-                validation_msg = f"❌ JSON parsing failed: {exc}"
-
-        yield completion, parsed_plan, validation_msg, prompt
-
-    except Exception as exc:
-        error_msg = f"❌ Generation failed: {str(exc)}"
-        yield "", {}, error_msg, ""
+                try:
+                    json_block = extract_json_from_text(chunk)
+                    candidate_plan = json.loads(json_block)
+                    ok, issues = validate_router_plan(candidate_plan)
+                    validation_msg = format_validation_message(ok, issues)
+                    parsed_plan = candidate_plan if ok else parsed_plan
+                    display_plan = candidate_plan
+                except Exception:
+                    # Ignore until JSON is complete
+                    pass
+
+                yield chunk, display_plan, validation_msg, prompt
+
+                if finished:
+                    completion = chunk
+                    break
+
+            # Final processing after streaming completes
+            thread.join()
+
+            completion = trim_at_stop_sequences(completion.strip())[0]
+            if parsed_plan is None:
+                try:
+                    json_block = extract_json_from_text(completion)
+                    parsed_plan = json.loads(json_block)
+                    ok, issues = validate_router_plan(parsed_plan)
+                    validation_msg = format_validation_message(ok, issues)
+                except Exception as exc:
+                    parsed_plan = {}
+                    validation_msg = f"❌ JSON parsing failed: {exc}"
+
+            yield completion, parsed_plan, validation_msg, prompt
+
+        except Exception as exc:
+            error_msg = f"❌ Generation failed: {str(exc)}"
+            yield "", {}, error_msg, ""
+
+
+@spaces.GPU(duration=600)  # Default wrapper for backward compatibility
+def generate_router_plan_streaming(
+    user_task: str,
+    context: str,
+    acceptance: str,
+    extra_guidance: str,
+    difficulty: str,
+    tags: str,
+    model_choice: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    gpu_duration: int = 600,
+):
+    """Wrapper function that calls internal generator with GPU duration."""
+    yield from _generate_router_plan_streaming_internal(
+        user_task, context, acceptance, extra_guidance,
        difficulty, tags, model_choice, max_new_tokens,
+        temperature, top_p, gpu_duration
+    )
 
 
 def clear_outputs():
@@ -446,6 +470,7 @@ def build_ui():
             max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
             temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
             top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
+            gpu_duration = gr.Slider(60, 1800, value=600, step=60, label="GPU Duration (seconds)", info="Maximum GPU time allocation for this request")
 
             generate_btn = gr.Button("Generate Router Plan", variant="primary")
             clear_btn = gr.Button("Clear", variant="secondary")
@@ -469,6 +494,7 @@ def build_ui():
                 max_new_tokens,
                 temperature,
                 top_p,
+                gpu_duration,
             ],
             outputs=[raw_output, plan_json, validation_msg, prompt_view],
            show_progress="full",