import os
import time

import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI
from together import Together

from feedback_utils import FeedbackManager
from Rag import depression_assistant, launch_depression_assistant

load_dotenv()


@st.cache_resource
def load_embedding_model_cached(embedder_name):
    """Load the embedding model once per embedder name and cache it across reruns."""
    print(f"🔄 Loading cached embedding model: {embedder_name}")
    launch_depression_assistant(embedder_name=embedder_name, designated_client=None)
    print(f"✅ Cached embedding model loaded successfully: {embedder_name}")
    return {
        "embedder_name": embedder_name,
        "status": "loaded",
        "timestamp": time.time(),
    }


def get_llm_client(client_type, api_key):
    """Build the generation client for the selected backend."""
    if client_type == "together":
        return Together(api_key=api_key)
    elif client_type == "nvidia":
        return OpenAI(
            base_url="https://integrate.api.nvidia.com/v1",
            api_key=api_key,
        )
    return None


def get_current_model_info():
    """Return metadata for the currently loaded embedder, or None if none is loaded."""
    if "cached_model_info" in st.session_state and st.session_state.cached_model_info:
        return st.session_state.cached_model_info
    return None


def force_reload_model():
    """Clear the resource cache and session state so the embedder reloads."""
    st.cache_resource.clear()
    if "cached_model_info" in st.session_state:
        del st.session_state.cached_model_info
    if "user_llm_client" in st.session_state:
        del st.session_state.user_llm_client


# Initialize feedback manager
if "feedback_manager" not in st.session_state:
    st.session_state.feedback_manager = FeedbackManager()

st.set_page_config(
    page_title="Depression Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="expanded",
)

model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    # "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    # "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    # "Other"
]

# --- Sidebar ---
st.sidebar.title("Settings")

with st.sidebar:
    st.subheader("Model Selection")
    embedder_name = st.selectbox("Select embedder model", model_options, index=0)
    if embedder_name == "Other":
        embedder_name = st.text_input("Enter the embedder model name")

    current_info = get_current_model_info()
    if current_info and current_info["embedder_name"] == embedder_name:
        st.success(f"✅ Current model: {embedder_name}")
        st.caption(f"Loaded at: {time.strftime('%H:%M:%S', time.localtime(current_info['timestamp']))}")
    elif embedder_name:
        with st.spinner(f"Loading embedding model: {embedder_name}..."):
            try:
                model_info = load_embedding_model_cached(embedder_name=embedder_name)
                st.session_state.cached_model_info = model_info
                st.success(f"✅ Model {embedder_name} loaded successfully!")
                st.rerun()
            except Exception as e:
                st.error(f"❌ Failed to load model: {str(e)}")
                st.session_state.cached_model_info = None

    if st.button("🔄 Force Reload Model", help="Clear cache and reload the model"):
        force_reload_model()
        st.rerun()
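# The generation backends below authenticate via environment variables loaded
# by load_dotenv() above. A minimal .env for this script would contain the two
# keys read with os.getenv() below (any Google Sheets credentials used by
# FeedbackManager are assumed to be configured separately in feedback_utils):
#
#     TOGETHER_API_KEY=...
#     NVIDIA_API_KEY=...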
selected_model = st.sidebar.selectbox(
    "Choose a model for generation",
    [
        "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        "deepseek-ai/deepseek-r1",
        "meta/llama-3.3-70b-instruct",
    ],
    key="selected_model",
)

if selected_model in ["deepseek-ai/deepseek-r1", "meta/llama-3.3-70b-instruct"]:
    max_length_default = 1000
    client_type = "nvidia"
    api_key = os.getenv("NVIDIA_API_KEY")
else:
    max_length_default = 500
    client_type = "together"
    api_key = os.getenv("TOGETHER_API_KEY")

# Recreate the LLM client only when the backend/model combination changes
client_key = f"{client_type}_{selected_model}"
if "user_llm_client" not in st.session_state or st.session_state.get("client_key") != client_key:
    if api_key:
        st.session_state.user_llm_client = get_llm_client(client_type, api_key)
        st.session_state.client_key = client_key
        st.sidebar.success(f"✅ LLM Client: {client_type.upper()}")
    else:
        st.session_state.user_llm_client = None
        st.sidebar.error(f"❌ Missing API key for {client_type.upper()}")
else:
    st.sidebar.info(f"📱 LLM Client: {client_type.upper()} (Ready)")

temperature = st.sidebar.slider("Temperature", min_value=0.01, max_value=1.0, value=0.05, step=0.01)
top_p = st.sidebar.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1000, value=max_length_default, step=10)

st.sidebar.markdown("---")
st.sidebar.markdown("**Current Configuration:**")
st.sidebar.caption(f"Embedder: {embedder_name}")
st.sidebar.caption(f"LLM: {selected_model}")
st.sidebar.caption(f"Client: {client_type.upper()}")
st.sidebar.caption(f"Session ID: {st.session_state.get('client_key', 'None')[:20]}...")

# Google Sheets status
if st.session_state.feedback_manager.is_connected():
    st.sidebar.success("📊 Google Sheets: Connected")
    # Check if conversation logging is available
    if getattr(st.session_state.feedback_manager, "conversation_worksheet", None):
        st.sidebar.success("📝 Conversation Logging: Active")
    else:
        st.sidebar.warning("📝 Conversation Logging: Not Available")
else:
    st.sidebar.error("📊 Google Sheets: Not Connected")

# Show title and description
st.title("💬 Depression Assistant Chatbot")

if not get_current_model_info():
    st.warning("⚠️ Please select and load an embedding model from the sidebar first.")
    st.stop()

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = [{
        "role": "assistant",
        "content": (
            "Welcome to a prototype of the open-source and open-weight CANMAT/MDD 2023 "
            "depression guideline chatbot. Please try asking it questions that can be "
            "answered by the guidelines. Improvements are ongoing - the visual aspect "
            "will change substantially soon. Please send any feedback to John-Jose at "
            "[johnjose.nunez@ubc.ca](mailto:johnjose.nunez@ubc.ca). Thanks!"
        ),
    }]

# Initialize sources tracking
if "message_sources" not in st.session_state:
    st.session_state.message_sources = {}

# Initialize feedback tracking
if "feedback_submitted" not in st.session_state:
    st.session_state.feedback_submitted = set()
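# st.session_state.message_sources maps an assistant message's index in
# st.session_state.messages to the retrieval results behind that answer.
# Judging from how the sources expander below reads each entry, a result is
# a dict shaped roughly like this (a sketch inferred from usage here, not a
# schema guaranteed by Rag.py):
#
#     {"similarity": 0.87, "text": "...retrieved chunk...", "section": "..."}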
# Display chat messages from history on app rerun
for idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        # Add feedback section for assistant messages (except the first welcome message)
        if message["role"] == "assistant" and idx > 0:
            # Add sources expander for this message
            if idx in st.session_state.message_sources:
                with st.expander("📚 See Sources"):
                    results = st.session_state.message_sources[idx]
                    if results:
                        for i, result in enumerate(results):
                            st.markdown(f"**Source {i + 1}:** **Similarity:** {result.get('similarity', 'N/A')}")
                            st.write(f"**TEXT:** {result['text']}")
                            st.markdown(f"**Section:** {result['section']}")
                            st.markdown("---")
                    else:
                        st.markdown("No relevant sources found.")

            # Check if feedback was already submitted for this message
            feedback_key = f"feedback_submitted_{idx}"
            if feedback_key not in st.session_state.feedback_submitted:
                # Put feedback in an expander
                with st.expander("📝 Provide Feedback"):
                    col1, col2 = st.columns(2)
                    with col1:
                        st.markdown("**⭐ Rating Questions:**")
                        source_rating = st.selectbox(
                            "Please rate the quality of the data source provided - is it sufficient to answer the question? Higher ratings indicate a better source.",
                            options=[None, 1, 2, 3, 4, 5],
                            format_func=lambda x: "Select rating..." if x is None else f"{x} {'⭐' * x}",
                            key=f"source_rating_{idx}",
                        )
                        answer_rating = st.selectbox(
                            "Please rate the answer provided. Higher ratings indicate a better answer.",
                            options=[None, 1, 2, 3, 4, 5],
                            format_func=lambda x: "Select rating..." if x is None else f"{x} {'⭐' * x}",
                            key=f"answer_rating_{idx}",
                        )
                    with col2:
                        # Text feedback question
                        st.markdown("**📝 Detailed Feedback Questions:**")
                        feedback_q1 = st.text_area(
                            "Why is the answer wrong? Does it miss any key information?",
                            placeholder="Please describe any mistakes or missing information in the response...",
                            key=f"feedback_q1_{idx}",
                            height=80,
                        )

                    # Submit feedback button
                    if st.button("Submit Feedback", key=f"submit_{idx}"):
                        if source_rating is not None or answer_rating is not None:
                            # Get current model parameters
                            current_embedder = get_current_model_info()["embedder_name"] if get_current_model_info() else "Unknown"
                            # Get the corresponding user query (the previous message)
                            user_query = st.session_state.messages[idx - 1]["content"] if idx > 0 else "Unknown"

                            # Save feedback
                            success = st.session_state.feedback_manager.save_feedback(
                                user_query=user_query,
                                ai_response=message["content"],
                                source_rating=source_rating,
                                answer_rating=answer_rating,
                                feedback_q1=feedback_q1 or "",
                                embedder_model=current_embedder,
                                llm_model=st.session_state.get("last_model_used", "Unknown"),
                                temperature=st.session_state.get("last_temperature_used", 0),
                                top_p=st.session_state.get("last_top_p_used", 0),
                                max_length=st.session_state.get("last_max_length_used", 0),
                            )
                            if success:
                                st.success("✅ Thank you for your feedback!")
                                st.session_state.feedback_submitted.add(feedback_key)
                                st.rerun()
                            else:
                                st.error("❌ Failed to save feedback. Please check your Google Sheets configuration.")
                        else:
                            st.warning("⚠️ Please select at least one rating before submitting feedback.")
            else:
                st.success("✅ Feedback already submitted for this response")
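# Handling a new user turn: guard that an embedder and an LLM client are
# loaded, echo the user message, retrieve guideline passages and stream the
# generated answer through depression_assistant, then log the exchange to
# Google Sheets and rerun so the sources/feedback widgets above render for
# the new answer.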
# User input
if user_input := st.chat_input("Ask me questions about the CANMAT depression guideline!"):
    if not get_current_model_info():
        st.error("❌ Please load an embedding model first from the sidebar.")
        st.stop()
    if not st.session_state.get("user_llm_client"):
        st.error("❌ LLM client not available. Please check your API keys.")
        st.stop()

    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Store current model parameters for feedback
    st.session_state.last_model_used = selected_model
    st.session_state.last_temperature_used = temperature
    st.session_state.last_top_p_used = top_p
    st.session_state.last_max_length_used = max_length

    # Keep the latest 10 messages (5 user/assistant rounds) as chat history,
    # excluding the query that was just appended
    history = st.session_state.messages[:-1][-10:]

    placeholder = st.chat_message("assistant").empty()
    collected = ""
    try:
        t0 = time.perf_counter()

        # Temporarily point the Rag module at this session's LLM client,
        # restoring the original afterwards
        import Rag
        original_client = Rag.llm_client
        Rag.llm_client = st.session_state.user_llm_client
        try:
            results, response = depression_assistant(
                user_input,
                model_name=selected_model,
                max_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                stream_flag=True,
                chat_history=history,
            )
            for chunk in response:
                collected += chunk
                placeholder.markdown(collected)
        finally:
            Rag.llm_client = original_client

        t1 = time.perf_counter()
        response_time = t1 - t0
        print(f"[Time] Retriever + Generator takes: {response_time:.2f} seconds in total.")
        print(f"============== Finish R-A-Generation for Current Query {user_input} ==============")

        # Save the response and sources
        st.session_state.messages.append({"role": "assistant", "content": collected})
        # Save sources for this message (stored at index len(messages) - 1)
        message_idx = len(st.session_state.messages) - 1
        st.session_state.message_sources[message_idx] = results

        # Log conversation to Google Sheets
        try:
            current_embedder = get_current_model_info()["embedder_name"] if get_current_model_info() else "Unknown"
            session_id = st.session_state.get("client_key", "Unknown")
            st.session_state.feedback_manager.log_conversation(
                session_id=session_id,
                user_query=user_input,
                ai_response=collected,
                embedder_model=current_embedder,
                llm_model=selected_model,
                temperature=temperature,
                top_p=top_p,
                max_length=max_length,
                response_time=response_time,
            )
        except Exception as log_error:
            print(f"Warning: Failed to log conversation to Google Sheets: {log_error}")

        st.rerun()
    except Exception as e:
        st.error(f"❌ Error generating response: {str(e)}")
        print(f"Error in main loop: {e}")
        import traceback
        traceback.print_exc()
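# To launch locally (assuming this script is saved as app.py; the filename is
# not fixed anywhere above):
#
#     streamlit run app.py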