import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import onnxruntime as ort
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
import json
MODEL_ID = "yuhueng/qwen3-4b-singlish-base-v3"  # replace with your model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
)
# LionGuard (govtech/lionguard-v1): a binary safety classifier that runs
# on BGE sentence embeddings via ONNX Runtime.
REPO_ID = "govtech/lionguard-v1"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
FILENAME = "models/lionguard-binary.onnx"

embedder = SentenceTransformer(EMBEDDING_MODEL)
model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
session = ort.InferenceSession(model_path)
def check_safety(text):
    """Classify `text` with LionGuard; returns "Unsafe" or "Safe"."""
    embedding = embedder.encode([text], normalize_embeddings=True)
    input_name = session.get_inputs()[0].name
    pred = session.run(None, {input_name: embedding.astype(np.float32)})[0]
    return "Unsafe" if pred[0] == 1 else "Safe"
@spaces.GPU(duration=60)
def inference(prompt: str, max_tokens: int = 256) -> str:
    model.to("cuda")  # ZeroGPU: move to GPU inside the decorated function

    SYSTEM_PROMPT = """
You are having a casual conversation with a user in Singapore.
Keep responses helpful and friendly. Avoid sensitive topics like politics, religion, or race.
If asked about harmful activities, politely decline.
"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,  # must add for generation
    )
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,  # required for temperature/top_p/top_k to take effect
        temperature=0.7,
        top_p=0.8,
        top_k=20,
    )
    # Decode only the newly generated tokens (skip the prompt).
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    safety = check_safety(response)
    # Don't shadow the `json` module; return a JSON string for the Textbox.
    result = {"response": response, "safety": safety}
    return json.dumps(result)
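
# `inference` returns a JSON string, e.g. (illustrative shape only):
#   '{"response": "...", "safety": "Safe"}'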
# # Streaming variant: use TextIteratorStreamer instead of TextStreamer.
# # Requires: from threading import Thread
# #           from transformers import TextIteratorStreamer
# streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# generation_kwargs = dict(
#     **inputs,
#     max_new_tokens=max_tokens,
#     do_sample=True,
#     temperature=0.7,
#     top_p=0.8,
#     top_k=20,
#     streamer=streamer,
# )
# # Run generation in a separate thread so tokens can be consumed as they arrive.
# thread = Thread(target=model.generate, kwargs=generation_kwargs)
# thread.start()
# # Yield cumulative text for Gradio as tokens stream in.
# generated_text = ""
# for new_text in streamer:
#     generated_text += new_text
#     yield generated_text
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Number(value=256, label="max_tokens", precision=0),  # coerce to int
    ],
    outputs=gr.Textbox(label="response"),
    api_name="inference",  # explicit endpoint name: /inference
)

demo.launch()
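
# Example client call (a minimal sketch; the Space id "yuhueng/SinglishTest"
# is assumed from this repo's name):
#   from gradio_client import Client
#   client = Client("yuhueng/SinglishTest")
#   out = client.predict("what to eat at the hawker centre ah?", 256,
#                        api_name="/inference")
#   print(out)  # JSON string: {"response": ..., "safety": ...}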