Javedalam committed on
Commit 9a4d34a · verified · 1 Parent(s): ae0e964

Create app.py

Files changed (1)
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import os, time, threading
+ import gradio as gr
+ import torch, spaces
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ MODEL_ID = "WeiboAI/VibeThinker-1.5B"
+ SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
+ MAX_INPUT_TOKENS = 384
+ MAX_NEW_TOKENS = 128  # keep short so the slice finishes
+ TEMPERATURE = 0.4
+ TOP_P = 0.9
+ NO_TOKEN_TIMEOUT = 8  # seconds with no new token -> stop the stream
+
+ print(f"⏳ Loading {MODEL_ID} …", flush=True)
+ tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     low_cpu_mem_usage=True,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ ).eval()
+ print("✅ Model ready.", flush=True)
+
+ def _apply_template(messages):
+     return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ def _clip_inputs(prompt_text, max_tokens):
+     ids = tok([prompt_text], return_tensors="pt")
+     if ids["input_ids"].shape[-1] > max_tokens:
+         ids = {k: v[:, -max_tokens:] for k, v in ids.items()}
+     return {k: v.to(model.device) for k, v in ids.items()}
+
+ @spaces.GPU(duration=90)  # shorter than max to reduce pre-emption
+ def respond(message, history):
+     history = history or []
+     msgs = [{"role": "system", "content": SYSTEM_PROMPT}, *history,
+             {"role": "user", "content": str(message)}]
+
+     prompt = _apply_template(msgs)
+     inputs = _clip_inputs(prompt, MAX_INPUT_TOKENS)
+     streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
+
+     gen_kwargs = dict(
+         **inputs,
+         streamer=streamer,
+         do_sample=True,
+         temperature=TEMPERATURE,
+         top_p=TOP_P,
+         repetition_penalty=1.18,
+         max_new_tokens=MAX_NEW_TOKENS,
+         pad_token_id=tok.eos_token_id,
+         use_cache=True,
+     )
+
+     # run generate() in a daemon thread so it never blocks future calls
+     th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
+     th.start()
+
+     # start streaming
+     assistant = {"role": "assistant", "content": ""}
+     # include the new user turn so it appears in the Chatbot while streaming
+     out = list(history) + [{"role": "user", "content": str(message)}, assistant]
+
+     last_token_time = time.time()
+     last_heartbeat = 0
+
+     for chunk in streamer:
+         assistant["content"] += chunk
+         last_token_time = time.time()
+         # heartbeat every ~4s so frontend never stalls
+         if int(time.time() - last_heartbeat) >= 4:
+             yield out
+             last_heartbeat = time.time()
+
+     # if the streamer hangs (no more chunks), enforce timeout
+     while th.is_alive() and (time.time() - last_token_time) < NO_TOKEN_TIMEOUT:
+         # keep UI alive while waiting for last tokens
+         time.sleep(0.5)
+         yield out
+
+     # if still alive after timeout, we abort the turn gracefully
+     if th.is_alive():
+         assistant["content"] += "\n\n(Stopped: no tokens for {}s)".format(NO_TOKEN_TIMEOUT)
+
+     yield out  # final
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
+     chat = gr.Chatbot(type="messages", height=520)
+     box = gr.Textbox(placeholder="Ask a question…")
+     send = gr.Button("Send", variant="primary")
+
+     def pipeline(msg, hist):
+         for hist in respond(msg, hist):
+             yield "", hist
+
+     box.submit(pipeline, [box, chat], [box, chat])
+     send.click(pipeline, [box, chat], [box, chat])
+
+ if __name__ == "__main__":
+     # Gradio 4+ renamed queue(concurrency_count=...) to default_concurrency_limit
+     demo.queue(default_concurrency_limit=1, max_size=16).launch()
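
The chat flow above hinges on one pattern: model.generate() runs in a background thread and pushes decoded text into a TextIteratorStreamer, which respond() then iterates over. Below is a minimal standalone sketch of that pattern; the tiny stand-in checkpoint sshleifer/tiny-gpt2 and the fixed prompt are assumptions chosen only to keep the example light, not part of this Space.

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

name = "sshleifer/tiny-gpt2"  # assumed tiny checkpoint, for illustration only
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

inputs = tok("Hello there", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
kwargs = dict(**inputs, streamer=streamer, max_new_tokens=20,
              do_sample=False, pad_token_id=tok.eos_token_id)

# generate() blocks until it finishes, so it runs in a thread while we consume the stream
threading.Thread(target=model.generate, kwargs=kwargs, daemon=True).start()
for chunk in streamer:  # yields decoded text pieces as they are produced
    print(chunk, end="", flush=True)
print()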