| import gradio as gr |
| from transformers import pipeline |
| import torch |
| import spaces |
|
|
| |
model_id = "facebook/MobileLLM-R1-950M"

# Text-generation pipeline built once at import time; respond() reuses its
# tokenizer and model for every chat request.
pipe = pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)
|
|
# Speaker labels used in the plain-text transcript fed to the model.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}


def _iter_history_turns(history):
    """Yield ``(role, content)`` pairs from Gradio chat history.

    Accepts either a list of ``[user_msg, assistant_msg]`` pairs or a list of
    ``{"role": ..., "content": ...}`` dicts, so the prompt is correct whichever
    history format ChatInterface is configured to pass. Turns with empty
    content, or with a role other than user/assistant, are skipped.
    """
    for entry in history or ():
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            user_msg, assistant_msg = entry
            turns = [("user", user_msg), ("assistant", assistant_msg)]
        for role, content in turns:
            if content and role in _ROLE_LABELS:
                yield role, content


def _build_prompt(message, history):
    """Render prior turns plus ``message`` as a "User:/Assistant:" transcript.

    The result ends with ``"Assistant: "`` so the model continues the text as
    the assistant's reply.
    """
    parts = [
        f"{_ROLE_LABELS[role]}: {content}\n"
        for role, content in _iter_history_turns(history)
    ]
    parts.append(f"User: {message}\nAssistant: ")
    return "".join(parts)


@spaces.GPU(duration=120)
def respond(message, history):
    """Stream the model's reply to ``message`` given the chat ``history``.

    Args:
        message: The new user message.
        history: Prior turns, as pairs or role/content dicts.

    Yields:
        The reply text accumulated so far, growing as tokens are generated.

    Raises:
        Any exception raised by ``generate()`` is re-raised here after the
        stream has been closed.
    """
    from threading import Thread

    from transformers import TextIteratorStreamer

    prompt = _build_prompt(message, history)
    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)

    # The streamer decodes tokens incrementally against the running sequence.
    # This keeps word-boundary spaces and multi-byte characters intact.
    # skip_prompt keeps the echoed input out of the reply.
    streamer = TextIteratorStreamer(
        pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    errors = []

    def _generate():
        # generate() blocks until finished, so it runs on a worker thread
        # while this generator drains the streamer as tokens arrive.
        try:
            pipe.model.generate(
                **inputs,
                streamer=streamer,
                max_new_tokens=10000,
                temperature=0.7,
                do_sample=True,
                pad_token_id=pipe.tokenizer.eos_token_id,
            )
        except Exception as exc:  # re-raised in the consumer after join()
            errors.append(exc)
            # Push the stop signal so the loop below cannot block forever.
            streamer.end()

    worker = Thread(target=_generate, daemon=True)
    worker.start()

    response_text = ""
    for chunk in streamer:
        if chunk:
            response_text += chunk
            yield response_text

    worker.join()
    if errors:
        raise errors[0]
|
|
| |
# One-click example prompts shown beneath the chat box.
_EXAMPLE_PROMPTS = [
    "Write a Python function that returns the square of a number.",
    "Compute: 1-2+3-4+5- ... +99-100.",
    "Write a C++ program that prints 'Hello, World!'.",
    "Explain how recursion works in programming.",
    "What is the difference between a list and a tuple in Python?",
]

# Chat UI wired to respond(), styled with Gradio's Soft theme.
demo = gr.ChatInterface(
    fn=respond,
    theme=gr.themes.Soft(),
    title="MobileLLM Chat",
    description="Chat with Meta MobileLLM-R1-950M",
    examples=_EXAMPLE_PROMPTS,
)


if __name__ == "__main__":
    # Launch the app when run as a script (share=True requests a share link).
    demo.launch(share=True)