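# Gradio Space: chat with a Singlish-tuned Qwen3-4B model, then screen each
# reply with GovTech's LionGuard binary safety classifier before returning it.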
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import onnxruntime as ort
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
MODEL_ID = "yuhueng/qwen3-4b-singlish-base-v3"  # replace with your model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
)
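# LionGuard v1 safety classifier: text is embedded with bge-large-en-v1.5 and
# the embedding is scored by a small ONNX binary head (0 = safe, 1 = unsafe).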
REPO_ID = "govtech/lionguard-v1"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
FILENAME = "models/lionguard-binary.onnx"

embedder = SentenceTransformer(EMBEDDING_MODEL)
model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
session = ort.InferenceSession(model_path)
def check_safety(text):
    embedding = embedder.encode([text], normalize_embeddings=True)
    input_name = session.get_inputs()[0].name
    pred = session.run(None, {input_name: embedding.astype(np.float32)})[0]
    return "Unsafe" if pred[0] == 1 else "Safe"
@spaces.GPU  # ZeroGPU: a GPU is attached only while this function runs
def inference(prompt: str, max_tokens: int = 256) -> dict:
    model.to("cuda")  # Move to GPU inside the decorated function

    SYSTEM_PROMPT = """
You are having a casual conversation with a user in Singapore.
Keep responses helpful and friendly. Avoid sensitive topics like politics, religion, or race.
If asked about harmful activities, politely decline.
"""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,  # must add for generation
    )
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),  # gr.Number may pass a float
        do_sample=True,  # required for temperature/top_p/top_k to take effect
        temperature=0.7,
        top_p=0.8,
        top_k=20,
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    safety = check_safety(response)
    return {"response": response, "safety": safety}
# Streaming variant (unused): uses TextIteratorStreamer instead of TextStreamer.
# It requires `from transformers import TextIteratorStreamer` and
# `from threading import Thread`, and `inference` would have to become a
# generator that yields partial text.
#     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#     generation_kwargs = dict(
#         **inputs,
#         max_new_tokens=max_tokens,
#         temperature=0.7,
#         top_p=0.8,
#         top_k=20,
#         streamer=streamer,
#     )
#     # Run generation in a separate thread
#     thread = Thread(target=model.generate, kwargs=generation_kwargs)
#     thread.start()
#     # Yield tokens as they arrive
#     generated_text = ""
#     for new_text in streamer:
#         generated_text += new_text
#         yield generated_text  # yield cumulative text for Gradio
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="prompt"),
        gr.Number(value=256, precision=0, label="max_tokens"),  # precision=0 -> int
    ],
    outputs=gr.JSON(label="response"),  # inference returns {"response": ..., "safety": ...}
    api_name="inference",  # explicit endpoint name: /inference
)

demo.launch()
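# Example client call against the /inference endpoint (a sketch; replace
# <user>/<space> with the actual Space id):
#     from gradio_client import Client
#     client = Client("<user>/<space>")
#     result = client.predict("eh, recommend me some hawker food", 256, api_name="/inference")
#     print(result)  # {"response": "...", "safety": "Safe"}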