import gradio as gr
import os
import openai
import torch
import sys
import uuid
from datetime import datetime
import json
import gspread
from google.oauth2 import service_account
from safetensors.torch import load_file
from lionguard2 import LionGuard2, CATEGORIES
from utils import get_embeddings
# -- OpenAI Setup --
# Shared OpenAI client used by the chat-completion and moderation helpers below.
# The key is taken from the OPENAI_API_KEY environment variable (None when unset).
client = openai.OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# -- Model Loading --
def load_lionguard2(weights_path="LionGuard2.safetensors"):
    """Build a LionGuard2 model in eval mode and load its weights.

    Args:
        weights_path: Path to the safetensors checkpoint to load. Defaults to
            ``LionGuard2.safetensors``, the file name this app has always used.

    Returns:
        The LionGuard2 instance with the checkpoint's state dict loaded.
    """
    model = LionGuard2()
    model.eval()
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    return model
# Load the classifier once at import time; the analysis and guardrail helpers share it.
model = load_lionguard2()

# -- Google Sheets Config --
# Logging to Google Sheets is only attempted when both environment values are set.
GOOGLE_SHEET_URL = os.getenv("GOOGLE_SHEET_URL")
GOOGLE_CREDENTIALS = os.getenv("GCP_SERVICE_ACCOUNT")
# Worksheet (tab) names inside the spreadsheet.
RESULTS_SHEET_NAME, VOTES_SHEET_NAME = "results", "votes"
def _open_worksheet(sheet_name):
    """Authorize with the service account and return the named worksheet.

    Credentials are parsed from the JSON held in GOOGLE_CREDENTIALS and the
    spreadsheet is located via GOOGLE_SHEET_URL. Any failure propagates to the
    caller, which decides how to report it.
    """
    credentials = service_account.Credentials.from_service_account_info(
        json.loads(GOOGLE_CREDENTIALS),
        scopes=[
            "https://www.googleapis.com/auth/spreadsheets",
            "https://www.googleapis.com/auth/drive",
        ],
    )
    gc = gspread.authorize(credentials)
    return gc.open_by_url(GOOGLE_SHEET_URL).worksheet(sheet_name)


def save_results_data(row):
    """Append one analysis record to the results worksheet.

    Args:
        row: Mapping of column name to value; the values are written in the
            mapping's iteration order.

    Logging is best-effort: any error is printed and swallowed so a Sheets
    outage never breaks the analysis itself.
    """
    try:
        ws = _open_worksheet(RESULTS_SHEET_NAME)
        ws.append_row(list(row.values()))
    except Exception as e:
        print(f"Error saving results data: {e}")


def save_vote_data(text_id, agree):
    """Append one user vote to the votes worksheet.

    Args:
        text_id: Identifier of the analysis being voted on.
        agree: True when the user agreed with the analysis, False otherwise.

    The row written is (ISO timestamp, text_id, agree). Like results logging,
    this is best-effort: errors are printed and swallowed.
    """
    try:
        ws = _open_worksheet(VOTES_SHEET_NAME)
        vote_row = {
            "datetime": datetime.now().isoformat(),
            "text_id": text_id,
            "agree": agree,
        }
        ws.append_row(list(vote_row.values()))
    except Exception as e:
        print(f"Error saving vote data: {e}")
# --- Classifier logic ---
def format_score_with_style(score_str):
    """Render a score string as an icon-prefixed percentage.

    Args:
        score_str: The score as text (e.g. ``"0.1234"``), or ``"-"`` as a
            placeholder meaning "no score".

    Returns:
        ``'-'`` for the placeholder. Otherwise ``"<icon> <pct>%"``, where the
        percentage is the score times 100 truncated to an int and the icon is
        chosen by band: below 0.4, from 0.4 up to 0.7, and 0.7 or above.
        Input that cannot be converted to a finite number is returned
        unchanged.
    """
    if score_str == "-":
        return '-'
    try:
        score = float(score_str)
        percentage = int(score * 100)
    except (TypeError, ValueError, OverflowError):
        # TypeError: non-string/non-number input; ValueError: unparsable text
        # or NaN; OverflowError: infinity. Fall back to echoing the input.
        return score_str
    # NOTE(review): the leading glyphs below look like mis-decoded emoji;
    # confirm the file's encoding before changing them.
    if score < 0.4:
        return f'đ {percentage}%'
    if score < 0.7:
        return f'â ī¸ {percentage}%'
    return f'đ¨ {percentage}%'
def format_binary_score(score):
    """Render the overall (binary) score as a Pass / Warning / Fail banner.

    Args:
        score: Numeric score; it is compared against 0.4 and 0.7.

    Returns:
        A string containing the verdict and the score times 100 truncated to
        an int, e.g. ``"... Pass (12/100)"``. Below 0.4 is Pass, from 0.4 up to
        0.7 is Warning, and 0.7 or above is Fail.
    """
    percentage = int(score * 100)
    # The line breaks inside these literals are written as "\n" so each literal
    # is a valid single-line string.
    # NOTE(review): the leading glyphs look like mis-decoded emoji; confirm the
    # file's encoding before changing them.
    if score < 0.4:
        return f'\nâ\nPass ({percentage}/100)\n'
    if score < 0.7:
        return f'â ī¸ Warning ({percentage}/100)\n'
    return f'đ¨ Fail ({percentage}/100)\n'
def analyze_text(text):
    """Score *text* with LionGuard 2 and build the classifier tab's outputs.

    Args:
        text: The user-entered text to classify. ``None`` or whitespace-only
            input is treated as "nothing to analyze".

    Returns:
        A 4-tuple ``(binary_html, table_html, text_id, voting_html)``.
        For blank input the first two items are the same prompt string and the
        last two are empty strings. If anything raises during analysis, the
        first item is an error message and the other three are empty strings.
    """
    if not text or not text.strip():
        empty_html = 'Enter text to analyze\n'
        return empty_html, empty_html, "", ""

    main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
    # Built once per call instead of once per category.
    # NOTE(review): the glyphs look like mis-decoded emoji; confirm the file's
    # encoding before changing them.
    category_emojis = {
        'Hateful': 'đ¤Ŧ',
        'Insults': 'đĸ',
        'Sexual': 'đ',
        'Physical Violence': 'âī¸',
        'Self Harm': 'âšī¸',
        'All Other Misconduct': 'đ\nââī¸',
    }

    try:
        text_id = str(uuid.uuid4())
        embeddings = get_embeddings([text])
        results = model.predict(embeddings)
        # Each entry of `results` is indexed with [0] to get the score for the
        # single text submitted; missing keys default to 0.0.
        binary_score = results.get('binary', [0.0])[0]

        categories_html = []
        for category in main_categories:
            subcategories = CATEGORIES[category]
            category_name = category.replace('_', ' ').title()
            category_display = f"{category_emojis.get(category_name, 'đ')} {category_name}"
            # A category's score is the highest score among its subcategory keys.
            level_scores = [results.get(subcategory_key, [0.0])[0] for subcategory_key in subcategories]
            max_score = max(level_scores) if level_scores else 0.0
            score_cell = format_score_with_style(f"{max_score:.4f}")
            categories_html.append(f'\n| {category_display} |\n{score_cell} |\n')
        html_table = f"\n| Category | Score |\n{''.join(categories_html)}\n"

        # Save to Google Sheets if enabled
        if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
            results_row = {
                "datetime": datetime.now().isoformat(),
                "text_id": text_id,
                "text": text,
                "binary_score": binary_score,
                # Add all category scores as before...
            }
            save_results_data(results_row)

        voting_html = 'Help improve LionGuard2! Rate the analysis below.\n'
        return format_binary_score(binary_score), html_table, text_id, voting_html
    except Exception as e:
        # Top-level UI boundary: report the failure in the page instead of raising.
        error_msg = f"Error analyzing text: {str(e)}"
        return f'â {error_msg}\n', '', '', ''
def vote_thumbs_up(text_id):
    """Record that the user agreed with an analysis.

    Args:
        text_id: Identifier of the analysis being rated; may be empty.

    Returns:
        A thank-you message when the vote was handed to ``save_vote_data``, or
        a "not available" message when there is no text id or Google Sheets
        logging is not configured.
    """
    if not (text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS):
        return 'Voting not available\n'
    save_vote_data(text_id, True)
    # NOTE(review): the leading glyph looks like a mis-decoded emoji; confirm
    # the file's encoding before changing it.
    return 'đ Thank you!\n'
def vote_thumbs_down(text_id):
    """Record that the user disagreed with an analysis.

    Args:
        text_id: Identifier of the analysis being rated; may be empty.

    Returns:
        A thank-you message when the vote was handed to ``save_vote_data``, or
        a "not available" message when there is no text id or Google Sheets
        logging is not configured.
    """
    if not (text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS):
        return 'Voting not available\n'
    save_vote_data(text_id, False)
    # NOTE(review): the leading glyph looks like a mis-decoded emoji; confirm
    # the file's encoding before changing it.
    return 'đ Thanks for the feedback!\n'
# --- Chatbot guardrail logic ---
def get_openai_response(message, system_prompt="You are a helpful assistant."):
    """Ask the chat model for a single-turn reply to *message*.

    Args:
        message: The user's message.
        system_prompt: System instruction sent ahead of the user message.

    Returns:
        The content of the first choice returned by the API. If the call
        raises, an ``"Error: ..."`` string describing the failure is returned
        instead of propagating the exception.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": message},
    ]
    try:
        # temperature=0 and a fixed seed are requested on every call.
        completion = client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=conversation,
            max_tokens=500,
            temperature=0,
            seed=42,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        return f"Error: {exc!s}. Please check your OpenAI API key."
def openai_moderation(message):
    """Check *message* with the OpenAI moderation endpoint.

    Returns:
        The ``flagged`` value of the first moderation result. If the call
        fails, the error is printed and False is returned, i.e. the message is
        treated as not flagged.
    """
    try:
        moderation = client.moderations.create(input=message)
        verdict = moderation.results[0]
        return verdict.flagged
    except Exception as exc:
        print(f"Error in OpenAI moderation: {exc}")
        return False
def lionguard_2(message, threshold=0.5):
    """Decide whether LionGuard 2 flags *message*.

    Args:
        message: Text to classify.
        threshold: The message is flagged when its binary score is strictly
            greater than this value.

    Returns:
        The result of comparing the binary score to *threshold*. If embedding
        or prediction fails, the error is printed and False is returned, i.e.
        the message is treated as not flagged.
    """
    try:
        predictions = model.predict(get_embeddings([message]))
        # [0] selects the score for the single message submitted.
        return predictions['binary'][0] > threshold
    except Exception as exc:
        print(f"Error in LionGuard 2: {exc}")
        return False
def process_message(message, history_no_mod, history_openai, history_lg):
    """Send *message* to the three chat columns and update their histories.

    Each history is a list of ``{"role": ..., "content": ...}`` dicts and is
    extended in place with the user turn followed by that column's reply:

    * no moderation: always the model's reply;
    * OpenAI moderation: a "flagged" notice if the moderation check flags the
      message, otherwise the model's reply;
    * LionGuard 2: a "flagged" notice if LionGuard 2 flags the message,
      otherwise the model's reply.

    Returns:
        ``(history_no_mod, history_openai, history_lg, "")``; the trailing empty
        string clears the input box. Blank messages leave the histories
        untouched.
    """
    if not message.strip():
        return history_no_mod, history_openai, history_lg, ""

    def user_turn():
        return {"role": "user", "content": message}

    def assistant_turn(reply):
        return {"role": "assistant", "content": reply}

    # Column 1: unmoderated reply.
    unmoderated_reply = get_openai_response(message)
    history_no_mod.extend([user_turn(), assistant_turn(unmoderated_reply)])

    # Column 2: gate the reply behind OpenAI moderation.
    flagged_by_openai = openai_moderation(message)
    history_openai.append(user_turn())
    openai_reply = (
        "đĢ This message has been flagged by OpenAI moderation"
        if flagged_by_openai
        else get_openai_response(message)
    )
    history_openai.append(assistant_turn(openai_reply))

    # Column 3: gate the reply behind LionGuard 2.
    flagged_by_lionguard = lionguard_2(message)
    history_lg.append(user_turn())
    lionguard_reply = (
        "đĢ This message has been flagged by LionGuard 2"
        if flagged_by_lionguard
        else get_openai_response(message)
    )
    history_lg.append(assistant_turn(lionguard_reply))

    return history_no_mod, history_openai, history_lg, ""
def clear_all_chats():
    """Return three fresh, independent empty histories, one per chatbot."""
    no_mod_history, openai_history, lg_history = [], [], []
    return no_mod_history, openai_history, lg_history
# ---- MAIN GRADIO UI ----
# Notice rendered at the top of each tab.
DISCLAIMER = (
    "\n"
    "â ī¸ LionGuard 2 is an experimental ML model and may make mistakes. "
    "All entries are logged (anonymised) to improve the model.\n"
)
# NOTE(review): the nesting of components below was reconstructed from the
# statement order; confirm the layout matches the intended page.
with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:
    gr.HTML("LionGuard 2 Demo\n")
    with gr.Tabs():
        # Tab 1: score one piece of text and collect feedback on the result.
        with gr.Tab("Classifier"):
            gr.HTML(DISCLAIMER)
            with gr.Row():
                with gr.Column(scale=1, min_width=400):
                    text_input = gr.Textbox(
                        label="Enter text to analyze:",
                        placeholder="Type your text here...",
                        lines=8,
                        max_lines=16,
                        container=True,
                    )
                    analyze_btn = gr.Button("Analyze", variant="primary")
                with gr.Column(scale=1, min_width=400):
                    binary_output = gr.HTML(
                        value='Enter text to analyze\n'
                    )
                    category_table = gr.HTML(
                        value='Category scores will appear here after analysis\n'
                    )
                    voting_feedback = gr.HTML(value="")
                    # Hidden field that carries the id of the latest analysis
                    # to the vote handlers.
                    current_text_id = gr.Textbox(value="", visible=False)
                    with gr.Row(visible=False) as voting_buttons_row:
                        thumbs_up_btn = gr.Button("đ Looks Accurate", variant="primary")
                        thumbs_down_btn = gr.Button("đ Looks Wrong", variant="secondary")

            def analyze_and_show_voting(text):
                """Run the analysis, show the vote buttons on success, clear old feedback."""
                binary_score, category_table_val, text_id, _voting_html = analyze_text(text)
                # The vote buttons only make sense when an analysis id exists.
                show_vote = gr.update(visible=True) if text_id else gr.update(visible=False)
                return binary_score, category_table_val, text_id, show_vote, ""

            # One value per component, in the order returned above. Each
            # component is listed once; the button click and pressing Enter
            # share the same wiring.
            analysis_outputs = [
                binary_output,
                category_table,
                current_text_id,
                voting_buttons_row,
                voting_feedback,
            ]
            analyze_btn.click(
                analyze_and_show_voting,
                inputs=[text_input],
                outputs=analysis_outputs,
            )
            text_input.submit(
                analyze_and_show_voting,
                inputs=[text_input],
                outputs=analysis_outputs,
            )
            thumbs_up_btn.click(vote_thumbs_up, inputs=[current_text_id], outputs=[voting_feedback])
            thumbs_down_btn.click(vote_thumbs_down, inputs=[current_text_id], outputs=[voting_feedback])

        # Tab 2: the same message sent to three chat columns side by side.
        with gr.Tab("Chatbot Guardrail"):
            gr.HTML(DISCLAIMER)
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("#### đĩ No Moderation")
                    chatbot_no_mod = gr.Chatbot(height=400, label="No Moderation", show_label=False, bubble_full_width=False, type='messages')
                with gr.Column(scale=1):
                    gr.Markdown("#### đ OpenAI Moderation")
                    chatbot_openai = gr.Chatbot(height=400, label="OpenAI Moderation", show_label=False, bubble_full_width=False, type='messages')
                with gr.Column(scale=1):
                    gr.Markdown("#### đĄī¸ LionGuard 2")
                    chatbot_lg = gr.Chatbot(height=400, label="LionGuard 2", show_label=False, bubble_full_width=False, type='messages')
            gr.Markdown("##### đŦ Send Message to All Models")
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message to compare responses...",
                    show_label=False,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear All Chats", variant="stop")

            # process_message returns the three histories plus "" to clear the
            # input box; the button and pressing Enter share the same wiring.
            chat_components = [chatbot_no_mod, chatbot_openai, chatbot_lg]
            chat_inputs = [message_input] + chat_components
            chat_outputs = chat_components + [message_input]
            send_btn.click(
                process_message,
                inputs=chat_inputs,
                outputs=chat_outputs,
            )
            message_input.submit(
                process_message,
                inputs=chat_inputs,
                outputs=chat_outputs,
            )
            clear_btn.click(
                clear_all_chats,
                outputs=chat_components,
            )

if __name__ == "__main__":
    demo.launch()