DSDUDEd commited on
Commit
a0cf557
·
verified ·
1 Parent(s): 825139d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint served by this app.
model_name = "DSDUDEd/firebase"

# Load the tokenizer and causal-LM weights once at startup.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prefer GPU when one is available; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
15
# Function to generate responses
def chat_with_model(user_input, chat_history=None):
    """Generate one assistant reply and return the updated chat state.

    Args:
        user_input: The latest message typed by the user.
        chat_history: Prior turns as dicts with keys "role" ("user"/"ai")
            and "content". Defaults to a fresh list per call.

    Returns:
        A tuple ``(chat_for_gradio, chat_history)``: the ``(user, ai)``
        message pairs for ``gr.Chatbot``, and the raw history list that is
        round-tripped through ``gr.State``.
    """
    # Bug fix: the original used a mutable default (chat_history=[]), which
    # is created once and shared across every call — conversations from
    # different sessions would bleed into each other.
    if chat_history is None:
        chat_history = []

    chat_history.append({"role": "user", "content": user_input})

    # Rebuild the whole conversation as a plain-text prompt.
    prompt = ""
    for turn in chat_history:
        if turn["role"] == "user":
            prompt += f"User: {turn['content']}\n"
        else:
            prompt += f"AI: {turn['content']}\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the AI reply: everything after the last "AI:" marker.
    # Bug fix: if the model never emitted "AI:", the original returned the
    # entire decoded conversation; fall back to only the newly generated
    # text beyond the prompt instead.
    if "AI:" in response:
        response_text = response.split("AI:")[-1].strip()
    else:
        response_text = response[len(prompt):].strip()

    chat_history.append({"role": "ai", "content": response_text})

    # Bug fix: gr.Chatbot's tuple format expects (user, ai) PAIRS. The
    # original emitted one half-empty tuple per turn, rendering alternating
    # rows with a blank side. Pair each AI turn with the preceding user turn.
    chat_for_gradio = []
    for turn in chat_history:
        if turn["role"] == "user":
            chat_for_gradio.append((turn["content"], ""))
        elif chat_for_gradio and chat_for_gradio[-1][1] == "":
            chat_for_gradio[-1] = (chat_for_gradio[-1][0], turn["content"])
        else:
            chat_for_gradio.append(("", turn["content"]))

    return chat_for_gradio, chat_history
47
+
48
# Build Gradio interface
with gr.Blocks() as demo:
    # Per-session conversation state: a list of {"role", "content"} dicts
    # threaded through chat_with_model on every click.
    chat_history_state = gr.State([])
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter your message")
    submit = gr.Button("Send")

    # Wire the Send button: feed the textbox and stored history in,
    # write the rendered pairs and updated history back out.
    submit.click(
        chat_with_model,
        inputs=[msg, chat_history_state],
        outputs=[chatbot, chat_history_state],
    )

demo.launch()