import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer  # TextStreamer is unused here; the commented streaming variant below uses TextIteratorStreamer
import torch
import onnxruntime as ort
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

MODEL_ID = "yuhueng/qwen3-4b-singlish-base-v3"  # replace with your model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
)

# Safety classifier: LionGuard binary model (ONNX) over BGE embeddings
REPO_ID = "govtech/lionguard-v1"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
FILENAME = "models/lionguard-binary.onnx"

embedder = SentenceTransformer(EMBEDDING_MODEL)
model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
session = ort.InferenceSession(model_path)


def check_safety(text):
    # Embed the text, run the ONNX classifier, and map the binary label to a string
    embedding = embedder.encode([text], normalize_embeddings=True)
    input_name = session.get_inputs()[0].name
    pred = session.run(None, {input_name: embedding.astype(np.float32)})[0]
    return "Unsafe" if pred[0] == 1 else "Safe"


@spaces.GPU(duration=60)
def inference(prompt: str, max_tokens: int = 256) -> dict:
    model.to("cuda")  # Move to GPU inside the decorated function

    SYSTEM_PROMPT = """
    You are having a casual conversation with a user in Singapore.
    Keep responses helpful and friendly.
    Avoid sensitive topics like politics, religion, or race.
    If asked about harmful activities, politely decline.
    """

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,  # Must add for generation
    )
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=0.7,
        top_p=0.8,
        top_k=20,
    )
    # Decode only the newly generated tokens (skip the prompt)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    safety = check_safety(response)
    result = {"response": response, "safety": safety}
    return result

    # # Use TextIteratorStreamer instead of TextStreamer
    # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generation_kwargs = dict(
    #     **inputs,
    #     max_new_tokens=max_tokens,
    #     temperature=0.7,
    #     top_p=0.8,
    #     top_k=20,
    #     streamer=streamer,
    # )
    # # Run generation in a separate thread
    # thread = Thread(target=model.generate, kwargs=generation_kwargs)
    # thread.start()
    # # Yield tokens as they come
    # generated_text = ""
    # for new_text in streamer:
    #     generated_text += new_text
    #     yield generated_text  # yield cumulative text for Gradio


demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Number(value=256, label="max_tokens"),
    ],
    outputs=gr.Textbox(label="response"),
    api_name="inference",  # explicit endpoint name: /inference
)

demo.launch()
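
# A minimal client-side sketch (not executed here) of calling the /inference
# endpoint exposed via api_name above, assuming the Space is deployed under a
# hypothetical id "your-username/your-space":
#
#   from gradio_client import Client
#
#   client = Client("your-username/your-space")
#   result = client.predict(
#       "Recommend a hawker centre near Bugis",  # prompt
#       256,                                     # max_tokens
#       api_name="/inference",
#   )
#   print(result)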