gabrielchua committed on
Commit
0ff9a77
·
1 Parent(s): 6291db1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +265 -134
app.py CHANGED
@@ -1,59 +1,178 @@
1
  import gradio as gr
2
- import openai
3
  import os
4
- import sys
5
  import torch
 
 
 
6
 
7
- # # Add the parent directory to the path to import from final_model
8
- # sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'final_model'))
 
9
 
10
  from safetensors.torch import load_file
11
- from lionguard2 import LionGuard2
12
  from utils import get_embeddings
13
 
14
- # Set up OpenAI client
15
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
16
 
17
- # Load LionGuard2 model
18
- model = LionGuard2()
19
- model.eval()
 
 
 
 
20
 
21
- # Load model weights
22
- model_path = 'LionGuard2.safetensors'
23
- state_dict = load_file(model_path)
24
- model.load_state_dict(state_dict)
25
 
26
- def lionguard_2(message, threshold=0.5):
27
- """
28
- LionGuard 2 function that uses the actual model to determine if content is unsafe.
29
-
30
- Args:
31
- message: The text message to check
32
- threshold: Probability threshold for flagging content as unsafe (default: 0.5)
33
-
34
- Returns:
35
- bool: True if content is flagged as unsafe, False otherwise
36
- """
37
  try:
38
- # Get embeddings for the message
39
- embeddings = get_embeddings([message])
40
-
41
- # Get predictions from the model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  results = model.predict(embeddings)
43
-
44
- # Check the binary classification result (overall safety)
45
- binary_prob = results['binary'][0] # First (and only) message's binary probability
46
-
47
- # Flag as unsafe if probability exceeds threshold
48
- return binary_prob > threshold
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
- print(f"Error in LionGuard 2: {e}")
52
- # In case of error, default to not flagging to avoid blocking legitimate content
53
- return False
 
 
 
 
 
54
 
 
 
 
 
 
 
 
55
  def get_openai_response(message, system_prompt="You are a helpful assistant."):
56
- """Get response from OpenAI API"""
57
  try:
58
  response = client.chat.completions.create(
59
  model="gpt-4.1-nano",
@@ -70,134 +189,146 @@ def get_openai_response(message, system_prompt="You are a helpful assistant."):
70
  return f"Error: {str(e)}. Please check your OpenAI API key."
71
 
72
  def openai_moderation(message):
73
- """
74
- OpenAI moderation function that uses OpenAI's built-in moderation API.
75
-
76
- Args:
77
- message: The text message to check
78
-
79
- Returns:
80
- bool: True if content is flagged as unsafe, False otherwise
81
- """
82
  try:
83
  response = client.moderations.create(input=message)
84
  return response.results[0].flagged
85
  except Exception as e:
86
  print(f"Error in OpenAI moderation: {e}")
87
- # In case of error, default to not flagging
 
 
 
 
 
 
 
 
 
88
  return False
89
 
90
  def process_message(message, history_no_mod, history_openai, history_lg):
91
- """Process message for all three chatbots"""
92
  if not message.strip():
93
  return history_no_mod, history_openai, history_lg, ""
94
-
95
- # Process for gpt-4.1-nano (no moderation)
96
  no_mod_response = get_openai_response(message)
97
  history_no_mod.append({"role": "user", "content": message})
98
  history_no_mod.append({"role": "assistant", "content": no_mod_response})
99
-
100
- # Process for gpt-4.1-nano with OpenAI moderation
101
  openai_flagged = openai_moderation(message)
102
  history_openai.append({"role": "user", "content": message})
103
-
104
  if openai_flagged:
105
  openai_response = "🚫 This message has been flagged by OpenAI moderation"
106
  history_openai.append({"role": "assistant", "content": openai_response})
107
  else:
108
- openai_response = get_openai_response(
109
- message,
110
- )
111
  history_openai.append({"role": "assistant", "content": openai_response})
112
-
113
- # Process for gpt-4.1-nano with LionGuard 2
114
  lg_flagged = lionguard_2(message)
115
  history_lg.append({"role": "user", "content": message})
116
-
117
  if lg_flagged:
118
  lg_response = "🚫 This message has been flagged by LionGuard 2"
119
  history_lg.append({"role": "assistant", "content": lg_response})
120
  else:
121
- lg_response = get_openai_response(
122
- message,
123
- )
124
  history_lg.append({"role": "assistant", "content": lg_response})
125
-
126
  return history_no_mod, history_openai, history_lg, ""
127
 
128
  def clear_all_chats():
129
- """Clear all chat histories"""
130
  return [], [], []
131
 
132
- # Create the Gradio interface
133
- with gr.Blocks(title="LionGuard 2", theme=gr.themes.Soft()) as demo:
134
- gr.Markdown("# EMNLP 2025 System Demonstration: LionGuard 2 🦁")
135
- gr.Markdown("**LionGuard 2 is a content moderator localised to Singapore - use it to detect unsafe LLM inputs and outputs**")
136
-
137
- with gr.Row():
138
- with gr.Column(scale=1):
139
- gr.Markdown("## πŸ”΅ No Moderation")
140
- chatbot_no_mod = gr.Chatbot(
141
- height=800,
142
- label="No Moderation",
143
- show_label=False,
144
- bubble_full_width=False,
145
- type='messages'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
-
148
- with gr.Column(scale=1):
149
- gr.Markdown("## 🟠 OpenAI Moderation")
150
- chatbot_openai = gr.Chatbot(
151
- height=800,
152
- label="OpenAI Moderation",
153
- show_label=False,
154
- bubble_full_width=False,
155
- type='messages'
156
  )
157
-
158
- with gr.Column(scale=1):
159
- gr.Markdown("## πŸ›‘οΈ LionGuard 2")
160
- chatbot_lg = gr.Chatbot(
161
- height=800,
162
- label="LionGuard 2",
163
- show_label=False,
164
- bubble_full_width=False,
165
- type='messages'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  )
167
-
168
- # Single input for all chatbots
169
- gr.Markdown("### πŸ’¬ Send Message to All Models")
170
- with gr.Row():
171
- message_input = gr.Textbox(
172
- placeholder="Type your message to compare responses...",
173
- show_label=False,
174
- scale=4
175
- )
176
- send_btn = gr.Button("Send", variant="primary", scale=1)
177
-
178
- # Control buttons
179
- with gr.Row():
180
- clear_btn = gr.Button("Clear All Chats", variant="stop")
181
-
182
- # Event handlers
183
- send_btn.click(
184
- process_message,
185
- inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
186
- outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
187
- )
188
-
189
- message_input.submit(
190
- process_message,
191
- inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
192
- outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
193
- )
194
-
195
- # Clear button
196
- clear_btn.click(
197
- clear_all_chats,
198
- outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg]
199
- )
200
-
201
- # Launch the app
202
  if __name__ == "__main__":
203
- demo.launch(share=True, debug=True)
 
1
  import gradio as gr
 
2
  import os
3
+ import openai
4
  import torch
5
+ import sys
6
+ import uuid
7
+ from datetime import datetime
8
 
9
+ import json
10
+ import gspread
11
+ from google.oauth2 import service_account
12
 
13
  from safetensors.torch import load_file
14
+ from lionguard2 import LionGuard2, CATEGORIES
15
  from utils import get_embeddings
16
 
17
+ # -- OpenAI Setup --
18
  client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
19
 
20
+ # -- Model Loading --
21
+ def load_lionguard2():
22
+ model = LionGuard2()
23
+ model.eval()
24
+ state_dict = load_file('LionGuard2.safetensors')
25
+ model.load_state_dict(state_dict)
26
+ return model
27
 
28
+ model = load_lionguard2()
 
 
 
29
 
30
+ # -- Google Sheets Config --
31
+ GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
32
+ GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
33
+ RESULTS_SHEET_NAME = "results"
34
+ VOTES_SHEET_NAME = "votes"
35
+
36
+ def save_results_data(row):
 
 
 
 
37
  try:
38
+ credentials = service_account.Credentials.from_service_account_info(
39
+ json.loads(GOOGLE_CREDENTIALS),
40
+ scopes=[
41
+ "https://www.googleapis.com/auth/spreadsheets",
42
+ "https://www.googleapis.com/auth/drive",
43
+ ],
44
+ )
45
+ gc = gspread.authorize(credentials)
46
+ sheet = gc.open_by_url(GOOGLE_SHEET_URL)
47
+ ws = sheet.worksheet(RESULTS_SHEET_NAME)
48
+ ws.append_row(list(row.values()))
49
+ except Exception as e:
50
+ print(f"Error saving results data: {e}")
51
+
52
+ def save_vote_data(text_id, agree):
53
+ try:
54
+ credentials = service_account.Credentials.from_service_account_info(
55
+ json.loads(GOOGLE_CREDENTIALS),
56
+ scopes=[
57
+ "https://www.googleapis.com/auth/spreadsheets",
58
+ "https://www.googleapis.com/auth/drive",
59
+ ],
60
+ )
61
+ gc = gspread.authorize(credentials)
62
+ sheet = gc.open_by_url(GOOGLE_SHEET_URL)
63
+ ws = sheet.worksheet(VOTES_SHEET_NAME)
64
+ vote_row = {
65
+ "datetime": datetime.now().isoformat(),
66
+ "text_id": text_id,
67
+ "agree": agree
68
+ }
69
+ ws.append_row(list(vote_row.values()))
70
+ except Exception as e:
71
+ print(f"Error saving vote data: {e}")
72
+
73
+ # --- Classifier logic ---
74
+
75
+ def format_score_with_style(score_str):
76
+ if score_str == "-":
77
+ return '<span style="color: #9ca3af;">-</span>'
78
+ try:
79
+ score = float(score_str)
80
+ percentage = int(score * 100)
81
+ if score < 0.4:
82
+ return f'<span style="color: #34d399; font-weight:600;">πŸ‘Œ {percentage}%</span>'
83
+ elif 0.4 <= score < 0.7:
84
+ return f'<span style="color: #fbbf24; font-weight:600;">⚠️ {percentage}%</span>'
85
+ else:
86
+ return f'<span style="color: #fca5a5; font-weight:600;">🚨 {percentage}%</span>'
87
+ except:
88
+ return score_str
89
+
90
+ def format_binary_score(score):
91
+ percentage = int(score * 100)
92
+ if score < 0.4:
93
+ return f'<div style="color: #34d399; font-weight:700;">βœ… Pass ({percentage}/100)</div>'
94
+ elif 0.4 <= score < 0.7:
95
+ return f'<div style="color: #fbbf24; font-weight:700;">⚠️ Warning ({percentage}/100)</div>'
96
+ else:
97
+ return f'<div style="color: #fca5a5; font-weight:700;">🚨 Fail ({percentage}/100)</div>'
98
+
99
+ def analyze_text(text):
100
+ if not text.strip():
101
+ empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
102
+ return empty_html, empty_html, "", ""
103
+ try:
104
+ text_id = str(uuid.uuid4())
105
+ embeddings = get_embeddings([text])
106
  results = model.predict(embeddings)
107
+ binary_score = results.get('binary', [0.0])[0]
108
+
109
+ main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
110
+ categories_html = []
111
+ for category in main_categories:
112
+ subcategories = CATEGORIES[category]
113
+ category_name = category.replace('_', ' ').title()
114
+ category_emojis = {
115
+ 'Hateful': '🀬',
116
+ 'Insults': 'πŸ’’',
117
+ 'Sexual': 'πŸ”ž',
118
+ 'Physical Violence': 'βš”οΈ',
119
+ 'Self Harm': '☹️',
120
+ 'All Other Misconduct': 'πŸ™…β€β™€οΈ'
121
+ }
122
+ category_display = f"{category_emojis.get(category_name, 'πŸ“')} {category_name}"
123
+ level_scores = [results.get(subcategory_key, [0.0])[0] for subcategory_key in subcategories]
124
+ max_score = max(level_scores) if level_scores else 0.0
125
+ categories_html.append(f'''
126
+ <tr>
127
+ <td>{category_display}</td>
128
+ <td style="text-align: center;">{format_score_with_style(f"{max_score:.4f}")}</td>
129
+ </tr>
130
+ ''')
131
+
132
+ html_table = f'''
133
+ <table style="width:100%">
134
+ <thead>
135
+ <tr><th>Category</th><th>Score</th></tr>
136
+ </thead>
137
+ <tbody>
138
+ {''.join(categories_html)}
139
+ </tbody>
140
+ </table>
141
+ '''
142
+
143
+ # Save to Google Sheets if enabled
144
+ if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
145
+ results_row = {
146
+ "datetime": datetime.now().isoformat(),
147
+ "text_id": text_id,
148
+ "text": text,
149
+ "binary_score": binary_score,
150
+ # Add all category scores as before...
151
+ }
152
+ save_results_data(results_row)
153
+
154
+ voting_html = '<div>Help improve LionGuard2! Rate the analysis below.</div>'
155
+
156
+ return format_binary_score(binary_score), html_table, text_id, voting_html
157
+
158
  except Exception as e:
159
+ error_msg = f"Error analyzing text: {str(e)}"
160
+ return f'<div style="color: #fca5a5;">❌ {error_msg}</div>', '', '', ''
161
+
162
+ def vote_thumbs_up(text_id):
163
+ if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
164
+ save_vote_data(text_id, True)
165
+ return '<div style="color: #34d399; font-weight:700;">πŸŽ‰ Thank you!</div>'
166
+ return '<div>Voting not available</div>'
167
 
168
+ def vote_thumbs_down(text_id):
169
+ if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
170
+ save_vote_data(text_id, False)
171
+ return '<div style="color: #fca5a5; font-weight:700;">πŸ“ Thanks for the feedback!</div>'
172
+ return '<div>Voting not available</div>'
173
+
174
+ # --- Chatbot guardrail logic ---
175
  def get_openai_response(message, system_prompt="You are a helpful assistant."):
 
176
  try:
177
  response = client.chat.completions.create(
178
  model="gpt-4.1-nano",
 
189
  return f"Error: {str(e)}. Please check your OpenAI API key."
190
 
191
  def openai_moderation(message):
 
 
 
 
 
 
 
 
 
192
  try:
193
  response = client.moderations.create(input=message)
194
  return response.results[0].flagged
195
  except Exception as e:
196
  print(f"Error in OpenAI moderation: {e}")
197
+ return False
198
+
199
+ def lionguard_2(message, threshold=0.5):
200
+ try:
201
+ embeddings = get_embeddings([message])
202
+ results = model.predict(embeddings)
203
+ binary_prob = results['binary'][0]
204
+ return binary_prob > threshold
205
+ except Exception as e:
206
+ print(f"Error in LionGuard 2: {e}")
207
  return False
208
 
209
  def process_message(message, history_no_mod, history_openai, history_lg):
 
210
  if not message.strip():
211
  return history_no_mod, history_openai, history_lg, ""
 
 
212
  no_mod_response = get_openai_response(message)
213
  history_no_mod.append({"role": "user", "content": message})
214
  history_no_mod.append({"role": "assistant", "content": no_mod_response})
215
+
 
216
  openai_flagged = openai_moderation(message)
217
  history_openai.append({"role": "user", "content": message})
 
218
  if openai_flagged:
219
  openai_response = "🚫 This message has been flagged by OpenAI moderation"
220
  history_openai.append({"role": "assistant", "content": openai_response})
221
  else:
222
+ openai_response = get_openai_response(message)
 
 
223
  history_openai.append({"role": "assistant", "content": openai_response})
224
+
 
225
  lg_flagged = lionguard_2(message)
226
  history_lg.append({"role": "user", "content": message})
 
227
  if lg_flagged:
228
  lg_response = "🚫 This message has been flagged by LionGuard 2"
229
  history_lg.append({"role": "assistant", "content": lg_response})
230
  else:
231
+ lg_response = get_openai_response(message)
 
 
232
  history_lg.append({"role": "assistant", "content": lg_response})
233
+
234
  return history_no_mod, history_openai, history_lg, ""
235
 
236
  def clear_all_chats():
 
237
  return [], [], []
238
 
239
+ # ---- MAIN GRADIO UI ----
240
+
241
+ DISCLAIMER = """
242
+ <div style='background: #fbbf24; color: #1e293b; border-radius: 8px; padding: 14px; margin-bottom: 12px; font-size: 15px; font-weight:500;'>
243
+ ⚠️ LionGuard 2 is an experimental ML model and may make mistakes. All entries are logged (anonymised) to improve the model.
244
+ </div>
245
+ """
246
+
247
+ with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:
248
+ gr.HTML("<h1 style='text-align:center'>LionGuard 2 Demo</h1>")
249
+
250
+ with gr.Tabs():
251
+ with gr.Tab("Classifier"):
252
+ gr.HTML(DISCLAIMER)
253
+ with gr.Row():
254
+ with gr.Column(scale=1, min_width=400):
255
+ text_input = gr.Textbox(
256
+ label="Enter text to analyze:",
257
+ placeholder="Type your text here...",
258
+ lines=8,
259
+ max_lines=16,
260
+ container=True
261
+ )
262
+ analyze_btn = gr.Button("Analyze", variant="primary")
263
+ with gr.Column(scale=1, min_width=400):
264
+ binary_output = gr.HTML(
265
+ value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
266
+ )
267
+ category_table = gr.HTML(
268
+ value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
269
+ )
270
+ voting_feedback = gr.HTML(value="")
271
+ current_text_id = gr.Textbox(value="", visible=False)
272
+
273
+ with gr.Row(visible=False) as voting_buttons_row:
274
+ thumbs_up_btn = gr.Button("πŸ‘ Looks Accurate", variant="primary")
275
+ thumbs_down_btn = gr.Button("πŸ‘Ž Looks Wrong", variant="secondary")
276
+
277
+ def analyze_and_show_voting(text):
278
+ binary_score, category_table_val, text_id, voting_html = analyze_text(text)
279
+ show_vote = gr.update(visible=True) if text_id else gr.update(visible=False)
280
+ return binary_score, category_table_val, text_id, show_vote, "", ""
281
+
282
+ analyze_btn.click(
283
+ analyze_and_show_voting,
284
+ inputs=[text_input],
285
+ outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_feedback, voting_feedback]
286
  )
287
+ text_input.submit(
288
+ analyze_and_show_voting,
289
+ inputs=[text_input],
290
+ outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_feedback, voting_feedback]
 
 
 
 
 
291
  )
292
+ thumbs_up_btn.click(vote_thumbs_up, inputs=[current_text_id], outputs=[voting_feedback])
293
+ thumbs_down_btn.click(vote_thumbs_down, inputs=[current_text_id], outputs=[voting_feedback])
294
+
295
+ with gr.Tab("Chatbot Guardrail"):
296
+ gr.HTML(DISCLAIMER)
297
+ with gr.Row():
298
+ with gr.Column(scale=1):
299
+ gr.Markdown("#### πŸ”΅ No Moderation")
300
+ chatbot_no_mod = gr.Chatbot(height=400, label="No Moderation", show_label=False, bubble_full_width=False, type='messages')
301
+ with gr.Column(scale=1):
302
+ gr.Markdown("#### 🟠 OpenAI Moderation")
303
+ chatbot_openai = gr.Chatbot(height=400, label="OpenAI Moderation", show_label=False, bubble_full_width=False, type='messages')
304
+ with gr.Column(scale=1):
305
+ gr.Markdown("#### πŸ›‘οΈ LionGuard 2")
306
+ chatbot_lg = gr.Chatbot(height=400, label="LionGuard 2", show_label=False, bubble_full_width=False, type='messages')
307
+ gr.Markdown("##### πŸ’¬ Send Message to All Models")
308
+ with gr.Row():
309
+ message_input = gr.Textbox(
310
+ placeholder="Type your message to compare responses...",
311
+ show_label=False,
312
+ scale=4
313
+ )
314
+ send_btn = gr.Button("Send", variant="primary", scale=1)
315
+ with gr.Row():
316
+ clear_btn = gr.Button("Clear All Chats", variant="stop")
317
+
318
+ send_btn.click(
319
+ process_message,
320
+ inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
321
+ outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
322
  )
323
+ message_input.submit(
324
+ process_message,
325
+ inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
326
+ outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
327
+ )
328
+ clear_btn.click(
329
+ clear_all_chats,
330
+ outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg]
331
+ )
332
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  if __name__ == "__main__":
334
+ demo.launch()