arahrooh committed
Commit c9adae0 · 1 Parent(s): 2fed471

Initial deployment: CGT-LLM-Beta RAG Chatbot

.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *.py[cod]
+ *.log
+ results/
+ *.csv
+ .DS_Store
+ *.pyc
README.md CHANGED
@@ -1,12 +1,59 @@
  ---
- title: Cgt Llm Chatbot V2
- emoji: 📉
- colorFrom: green
- colorTo: red
+ title: CGT-LLM-Beta RAG Chatbot
+ emoji: 🧬
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.0.1
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
+ license: mit
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # CGT-LLM-Beta: Genetic Counseling RAG Chatbot
+
+ A Retrieval-Augmented Generation (RAG) chatbot for genetic counseling and cascade genetic testing questions.
+
+ ## Features
+
+ - **Evidence-based answers** from medical literature
+ - **Multiple education levels**: Middle School, High School, College, and Doctoral
+ - **Source document citations** with full chunk text
+ - **Similarity scoring** for transparency
+ - **Flesch-Kincaid readability scores** for all answers
+ - **Multiple LLM models** to choose from
+ - **100+ example questions** for testing
+
+ ## How to Use
+
+ 1. **Select a model** from the dropdown (default: Llama-3.2-3B-Instruct)
+ 2. **Choose your education level** for personalized answers
+ 3. **Enter your question** or select one of the example questions
+ 4. **View the answer** along with its readability score, sources, and similarity scores
+
+ ## Education Levels
+
+ - **Middle School**: Simplified answers for ages 12-14
+ - **High School**: Simplified answers for ages 15-18
+ - **College**: Professional answers at undergraduate level
+ - **Doctoral**: Advanced answers for medical professionals
+
+ ## Models Available
+
+ - Llama-3.2-3B-Instruct
+ - Mistral-7B-Instruct-v0.2
+ - Llama-4-Scout-17B-16E-Instruct
+ - MediPhi-Instruct
+ - MediPhi
+ - Phi-4-reasoning
+
+ ## Important Notes
+
+ ⚠️ **This chatbot provides informational answers based on medical literature. It is not a substitute for professional medical advice, diagnosis, or treatment. Always consult qualified healthcare providers for medical decisions.**
+
+ ## Technical Details
+
+ - **Vector Database**: ChromaDB with sentence-transformers embeddings
+ - **RAG System**: Retrieval-Augmented Generation with semantic search
+ - **Source Attribution**: Full document tracking with chunk-level citations
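
The Technical Details above describe the retrieval layer added in this commit: a persistent ChromaDB collection searched by semantic similarity, with chunk-level metadata used for source citations. The sketch below mirrors the calls used in bot.py; the database path, collection name, and toy documents are illustrative assumptions, not the repository's real Data Resources corpus.

```python
# Minimal sketch of the index-then-retrieve flow, assuming a local ./chroma_db
# directory and a "cgt_documents" collection as in bot.py. Documents and IDs
# below are placeholders rather than real corpus chunks.
import chromadb

client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_or_create_collection(name="cgt_documents")

# Index a couple of toy chunks with the same metadata fields the app stores.
collection.add(
    documents=[
        "Lynch syndrome is a hereditary condition that raises colorectal cancer risk.",
        "BRCA1 carriers are offered earlier and more frequent breast cancer screening.",
    ],
    metadatas=[
        {"filename": "lynch_overview.txt", "chunk_id": 0, "total_chunks": 1},
        {"filename": "brca1_screening.txt", "chunk_id": 0, "total_chunks": 1},
    ],
    ids=["lynch_overview.txt_0", "brca1_screening.txt_0"],
)

# Semantic search: the app converts each returned distance to a similarity
# score (1 - distance) and shows it next to the cited chunk.
results = collection.query(query_texts=["What is Lynch Syndrome?"], n_results=2)
for doc, meta, dist in zip(
    results["documents"][0], results["metadatas"][0], results["distances"][0]
):
    print(f"{meta['filename']} (similarity {1 - dist:.3f}): {doc[:60]}...")
```

In the Space itself, indexing is done once via bot.py, and the Gradio app only runs the query side at request time.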
+
app.py ADDED
@@ -0,0 +1,943 @@
1
+ """
2
+ Gradio Chatbot Interface for CGT-LLM-Beta RAG System
3
+
4
+ This application provides a web interface for the RAG chatbot, allowing users to:
5
+ - Select different LLM models from a dropdown
6
+ - Choose education level for personalized answers (Middle School, High School, College, Doctoral)
7
+ - View answers with Flesch-Kincaid grade level scores
8
+ - See source documents and similarity scores for every answer
9
+
10
+ Usage:
11
+ python app.py
12
+
13
+ IMPORTANT: Before using, update the MODEL_MAP dictionary with correct HuggingFace paths
14
+ for models that currently have placeholder paths (Llama-4-Scout, MediPhi, Phi-4-reasoning).
15
+
16
+ For Hugging Face Spaces:
17
+ - Ensure vector database is built (run bot.py with indexing first)
18
+ - Model will be loaded on startup
19
+ - Access via the Gradio interface
20
+ """
21
+
22
+ import gradio as gr
23
+ import argparse
24
+ import sys
25
+ import os
26
+ from typing import Tuple, Optional, List
27
+ import logging
28
+ import textstat
29
+ import torch
30
+
31
+ # Import from bot.py
32
+ from bot import RAGBot, parse_args, Chunk
33
+
34
+ # Set up logging first (before any logger usage)
35
+ logging.basicConfig(level=logging.INFO)
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # For Hugging Face Inference API
39
+ try:
40
+ from huggingface_hub import InferenceClient
41
+ HF_INFERENCE_AVAILABLE = True
42
+ except ImportError:
43
+ HF_INFERENCE_AVAILABLE = False
44
+ logger.warning("huggingface_hub not available, InferenceClient will not work")
45
+
46
+ # Model mapping: short name -> full HuggingFace path
47
+ MODEL_MAP = {
48
+ "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
49
+ "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
50
+ "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
51
+ "MediPhi-Instruct": "microsoft/MediPhi-Instruct",
52
+ "MediPhi": "microsoft/MediPhi",
53
+ "Phi-4-reasoning": "microsoft/Phi-4-reasoning",
54
+ }
55
+
56
+ # Education level mapping
57
+ EDUCATION_LEVELS = {
58
+ "Middle School": "middle_school",
59
+ "High School": "high_school",
60
+ "College": "college",
61
+ "Doctoral": "doctoral"
62
+ }
63
+
64
+ # Example questions from the results CSV (hardcoded for easy access)
65
+ EXAMPLE_QUESTIONS = [
66
+ "Can a BRCA2 variant skip a generation?",
67
+ "Can a PMS2 variant skip a generation?",
68
+ "Can an EPCAM/MSH2 variant skip a generation?",
69
+ "Can an MLH1 variant skip a generation?",
70
+ "Can an MSH2 variant skip a generation?",
71
+ "Can an MSH6 variant skip a generation?",
72
+ "Can I pass this MSH2 variant to my kids?",
73
+ "Can only women carry a BRCA inherited mutation?",
74
+ "Does GINA cover life or disability insurance?",
75
+ "Does having a BRCA1 mutation mean I will definitely have cancer?",
76
+ "Does having a BRCA2 mutation mean I will definitely have cancer?",
77
+ "Does having a PMS2 mutation mean I will definitely have cancer?",
78
+ "Does having an EPCAM/MSH2 mutation mean I will definitely have cancer?",
79
+ "Does having an MLH1 mutation mean I will definitely have cancer?",
80
+ "Does having an MSH2 mutation mean I will definitely have cancer?",
81
+ "Does having an MSH6 mutation mean I will definitely have cancer?",
82
+ "Does this BRCA1 genetic variant affect my cancer treatment?",
83
+ "Does this BRCA2 genetic variant affect my cancer treatment?",
84
+ "Does this EPCAM/MSH2 genetic variant affect my cancer treatment?",
85
+ "Does this MLH1 genetic variant affect my cancer treatment?",
86
+ "Does this MSH2 genetic variant affect my cancer treatment?",
87
+ "Does this MSH6 genetic variant affect my cancer treatment?",
88
+ "Does this PMS2 genetic variant affect my cancer treatment?",
89
+ "How can I cope with this diagnosis?",
90
+ "How can I get my kids tested?",
91
+ "How can I help others with my condition?",
92
+ "How might my genetic test results change over time?",
93
+ "I don't talk to my family/parents/sister/brother. How can I share this with them?",
94
+ "I have a BRCA pathogenic variant and I want to have children, what are my options?",
95
+ "Is genetic testing for my family members covered by insurance?",
96
+ "Is new research being done on my condition?",
97
+ "Is this BRCA1 variant something I inherited?",
98
+ "Is this BRCA2 variant something I inherited?",
99
+ "Is this EPCAM/MSH2 variant something I inherited?",
100
+ "Is this MLH1 variant something I inherited?",
101
+ "Is this MSH2 variant something I inherited?",
102
+ "Is this MSH6 variant something I inherited?",
103
+ "Is this PMS2 variant something I inherited?",
104
+ "My relative doesn't have insurance. What should they do?",
105
+ "People who test positive for a genetic mutation are they at risk of losing their health insurance?",
106
+ "Should I contact my male and female relatives?",
107
+ "Should my family members get tested?",
108
+ "What are the Risks and Benefits of Risk-Reducing Surgeries for Lynch Syndrome?",
109
+ "What are the recommendations for my family members if I have a BRCA1 mutation?",
110
+ "What are the recommendations for my family members if I have a BRCA2 mutation?",
111
+ "What are the recommendations for my family members if I have a PMS2 mutation?",
112
+ "What are the recommendations for my family members if I have an EPCAM/MSH2 mutation?",
113
+ "What are the recommendations for my family members if I have an MLH1 mutation?",
114
+ "What are the recommendations for my family members if I have an MSH2 mutation?",
115
+ "What are the recommendations for my family members if I have an MSH6 mutation?",
116
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have a BRCA mutation?",
117
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an EPCAM/MSH2 mutation?",
118
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an MSH2 mutation?",
119
+ "What does a BRCA1 genetic variant mean for me?",
120
+ "What does a BRCA2 genetic variant mean for me?",
121
+ "What does a PMS2 genetic variant mean for me?",
122
+ "What does an EPCAM/MSH2 genetic variant mean for me?",
123
+ "What does an MLH1 genetic variant mean for me?",
124
+ "What does an MSH2 genetic variant mean for me?",
125
+ "What does an MSH6 genetic variant mean for me?",
126
+ "What if I feel overwhelmed?",
127
+ "What if I want to have children and have a hereditary cancer gene? What are my reproductive options?",
128
+ "What if a family member doesn't want to get tested?",
129
+ "What is Lynch Syndrome?",
130
+ "What is my cancer risk if I have BRCA1 Hereditary Breast and Ovarian Cancer syndrome?",
131
+ "What is my cancer risk if I have BRCA2 Hereditary Breast and Ovarian Cancer syndrome?",
132
+ "What is my cancer risk if I have MLH1 Lynch syndrome?",
133
+ "What is my cancer risk if I have MSH2 or EPCAM-associated Lynch syndrome?",
134
+ "What is my cancer risk if I have MSH6 Lynch syndrome?",
135
+ "What is my cancer risk if I have PMS2 Lynch syndrome?",
136
+ "What other resources are available to help me?",
137
+ "What screening tests do you recommend for BRCA1 carriers?",
138
+ "What screening tests do you recommend for BRCA2 carriers?",
139
+ "What screening tests do you recommend for EPCAM/MSH2 carriers?",
140
+ "What screening tests do you recommend for MLH1 carriers?",
141
+ "What screening tests do you recommend for MSH2 carriers?",
142
+ "What screening tests do you recommend for MSH6 carriers?",
143
+ "What screening tests do you recommend for PMS2 carriers?",
144
+ "What steps can I take to manage my cancer risk if I have Lynch syndrome?",
145
+ "What types of cancers am I at risk for with a BRCA1 mutation?",
146
+ "What types of cancers am I at risk for with a BRCA2 mutation?",
147
+ "What types of cancers am I at risk for with a PMS2 mutation?",
148
+ "What types of cancers am I at risk for with an EPCAM/MSH2 mutation?",
149
+ "What types of cancers am I at risk for with an MLH1 mutation?",
150
+ "What types of cancers am I at risk for with an MSH2 mutation?",
151
+ "What types of cancers am I at risk for with an MSH6 mutation?",
152
+ "Where can I find a genetic counselor?",
153
+ "Which of my relatives are at risk?",
154
+ "Who are my first-degree relatives?",
155
+ "Who do my family members call to have genetic testing?",
156
+ "Why do some families with Lynch syndrome have more cases of cancer than others?",
157
+ "Why should I share my BRCA1 genetic results with family?",
158
+ "Why should I share my BRCA2 genetic results with family?",
159
+ "Why should I share my EPCAM/MSH2 genetic results with family?",
160
+ "Why should I share my MLH1 genetic results with family?",
161
+ "Why should I share my MSH2 genetic results with family?",
162
+ "Why should I share my MSH6 genetic results with family?",
163
+ "Why should I share my PMS2 genetic results with family?",
164
+ "Why would my relatives want to know if they have this? What can they do about it?",
165
+ "Will my insurance cover testing for my parents/brother/sister?",
166
+ "Will this affect my health insurance?",
167
+ ]
168
+
169
+
170
+ class InferenceAPIBot:
171
+ """Wrapper that uses Hugging Face Inference API instead of loading models locally"""
172
+
173
+ def __init__(self, bot: RAGBot, hf_token: str):
174
+ """Initialize with a RAGBot (for vector DB) and HF token for Inference API"""
175
+ self.bot = bot # Use bot for vector DB and formatting
176
+ self.client = InferenceClient(api_key=hf_token)
177
+ self.current_model = bot.args.model
178
+ # Don't set args as attribute - access via bot.args instead
179
+ logger.info(f"InferenceAPIBot initialized with model: {self.current_model}")
180
+
181
+ @property
182
+ def args(self):
183
+ """Access args from the wrapped bot"""
184
+ return self.bot.args
185
+
186
+ def generate_answer(self, prompt: str, **kwargs) -> str:
187
+ """Generate answer using Inference API"""
188
+ try:
189
+ # Convert prompt to chat format
190
+ messages = [{"role": "user", "content": prompt}]
191
+
192
+ # Call Inference API
193
+ completion = self.client.chat.completions.create(
194
+ model=self.current_model,
195
+ messages=messages,
196
+ max_tokens=kwargs.get('max_new_tokens', 512),
197
+ temperature=kwargs.get('temperature', 0.2),
198
+ top_p=kwargs.get('top_p', 0.9),
199
+ )
200
+
201
+ answer = completion.choices[0].message.content
202
+ return answer
203
+ except Exception as e:
204
+ logger.error(f"Error calling Inference API: {e}", exc_info=True)
205
+ return f"Error generating answer: {str(e)}"
206
+
207
+ def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
208
+ """Enhance readability using Inference API"""
209
+ try:
210
+ # Define prompts for different reading levels (same as bot.py)
211
+ if target_level == "middle_school":
212
+ level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
213
+ instructions = """
214
+ - Use simpler medical terms or explain them
215
+ - Medium-length sentences
216
+ - Clear, structured explanations
217
+ - Keep important medical information accessible"""
218
+ elif target_level == "high_school":
219
+ level_description = "high school reading level (ages 15-18, 9th-12th grade)"
220
+ instructions = """
221
+ - Use appropriate medical terminology with context
222
+ - Varied sentence length
223
+ - Comprehensive yet accessible explanations
224
+ - Maintain technical accuracy while ensuring clarity"""
225
+ elif target_level == "college":
226
+ level_description = "college reading level (undergraduate level, ages 18-22)"
227
+ instructions = """
228
+ - Use standard medical terminology with brief explanations
229
+ - Professional and clear writing style
230
+ - Include relevant clinical context
231
+ - Maintain scientific accuracy and precision
232
+ - Appropriate for undergraduate students in health sciences"""
233
+ elif target_level == "doctoral":
234
+ level_description = "doctoral/professional reading level (graduate level, medical professionals)"
235
+ instructions = """
236
+ - Use advanced medical and scientific terminology
237
+ - Include detailed clinical and research context
238
+ - Reference specific mechanisms, pathways, and evidence
239
+ - Provide comprehensive technical explanations
240
+ - Appropriate for medical professionals, researchers, and graduate students
241
+ - Include nuanced discussions of clinical implications and research findings"""
242
+ else:
243
+ raise ValueError(f"Unknown target_level: {target_level}")
244
+
245
+ # Create messages for chat API
246
+ system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
247
+ {instructions}
248
+ - Keep the same important information but adapt the complexity
249
+ - Provide context for technical terms
250
+ - Ensure the answer is informative yet understandable"""
251
+
252
+ user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
253
+
254
+ messages = [
255
+ {"role": "system", "content": system_message},
256
+ {"role": "user", "content": user_message}
257
+ ]
258
+
259
+ # Call Inference API
260
+ completion = self.client.chat.completions.create(
261
+ model=self.current_model,
262
+ messages=messages,
263
+ max_tokens=512 if target_level in ["college", "doctoral"] else 384,
264
+ temperature=0.4 if target_level in ["college", "doctoral"] else 0.3,
265
+ )
266
+
267
+ enhanced_answer = completion.choices[0].message.content
268
+ # Clean the answer (same as bot.py)
269
+ cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)
270
+
271
+ # Calculate Flesch score
272
+ try:
273
+ flesch_score = textstat.flesch_kincaid_grade(cleaned)
274
+ except:
275
+ flesch_score = 0.0
276
+
277
+ return cleaned, flesch_score
278
+ except Exception as e:
279
+ logger.error(f"Error enhancing readability: {e}", exc_info=True)
280
+ return answer, 0.0
281
+
282
+ # Delegate other methods to bot
283
+ def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
284
+ return self.bot.format_prompt(context_chunks, question)
285
+
286
+ def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
287
+ return self.bot.retrieve_with_scores(query, k)
288
+
289
+ def _categorize_question(self, question: str) -> str:
290
+ return self.bot._categorize_question(question)
291
+
292
+ @property
293
+ def args(self):
294
+ return self.bot.args
295
+
296
+ @property
297
+ def vector_retriever(self):
298
+ return self.bot.vector_retriever
299
+
300
+
301
+ class GradioRAGInterface:
302
+ """Wrapper class to integrate RAGBot with Gradio"""
303
+
304
+ def __init__(self, initial_bot: RAGBot, use_inference_api: bool = False):
305
+ # Check if we should use Inference API (on Spaces)
306
+ if use_inference_api and HF_INFERENCE_AVAILABLE:
307
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
308
+ if hf_token:
309
+ self.bot = InferenceAPIBot(initial_bot, hf_token)
310
+ self.use_inference_api = True
311
+ logger.info("Using Hugging Face Inference API")
312
+ else:
313
+ logger.warning("HF_TOKEN not found, falling back to local model")
314
+ self.bot = initial_bot
315
+ self.use_inference_api = False
316
+ else:
317
+ self.bot = initial_bot
318
+ self.use_inference_api = False
319
+
320
+ # Get current model from bot args (not a direct attribute)
321
+ self.current_model = self.bot.args.model if hasattr(self.bot, 'args') else getattr(self.bot, 'current_model', None)
322
+ if self.current_model is None and hasattr(self.bot, 'bot'):
323
+ # If using InferenceAPIBot, get from the wrapped bot
324
+ self.current_model = self.bot.bot.args.model
325
+ self.data_dir = initial_bot.args.data_dir
326
+ logger.info("GradioRAGInterface initialized")
327
+
328
+ def _find_file_path(self, filename: str) -> str:
329
+ """Find the full file path for a given filename"""
330
+ from pathlib import Path
331
+ data_path = Path(self.data_dir)
332
+
333
+ if not data_path.exists():
334
+ return ""
335
+
336
+ # Search for the file recursively
337
+ for file_path in data_path.rglob(filename):
338
+ return str(file_path)
339
+
340
+ return ""
341
+
342
+ def reload_model(self, model_short_name: str) -> str:
343
+ """Reload the model when user selects a different one"""
344
+ if model_short_name not in MODEL_MAP:
345
+ return f"Error: Unknown model '{model_short_name}'"
346
+
347
+ new_model_path = MODEL_MAP[model_short_name]
348
+
349
+ # If same model, no need to reload
350
+ if new_model_path == self.current_model:
351
+ return f"Model already loaded: {model_short_name}"
352
+
353
+ try:
354
+ logger.info(f"Switching model from {self.current_model} to {new_model_path}")
355
+
356
+ if self.use_inference_api:
357
+ # For Inference API, just update the model name
358
+ self.bot.current_model = new_model_path
359
+ self.current_model = new_model_path
360
+ return f"✓ Model switched to: {model_short_name} (using Inference API)"
361
+ else:
362
+ # For local model, reload it
363
+ # Update args
364
+ self.bot.args.model = new_model_path
365
+
366
+ # Clear old model from memory
367
+ if hasattr(self.bot, 'model') and self.bot.model is not None:
368
+ del self.bot.model
369
+ del self.bot.tokenizer
370
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
371
+
372
+ # Load new model
373
+ self.bot._load_model()
374
+ self.current_model = new_model_path
375
+
376
+ return f"✓ Model loaded: {model_short_name}"
377
+ except Exception as e:
378
+ logger.error(f"Error reloading model: {e}", exc_info=True)
379
+ return f"✗ Error loading model: {str(e)}"
380
+
381
+ def process_question(
382
+ self,
383
+ question: str,
384
+ model_name: str,
385
+ education_level: str,
386
+ k: int,
387
+ temperature: float,
388
+ max_tokens: int
389
+ ) -> Tuple[str, str, str, str, str]:
390
+ """
391
+ Process a single question and return formatted results
392
+
393
+ Returns:
394
+ Tuple of (answer, flesch_score, sources, similarity_scores, question_category)
395
+ """
396
+ import time
397
+
398
+ if not question or not question.strip():
399
+ return "Please enter a question.", "N/A", "", "", ""
400
+
401
+ try:
402
+ start_time = time.time()
403
+ logger.info(f"Processing question: {question[:50]}...")
404
+
405
+ # Reload model if changed (this can take 1-3 minutes)
406
+ if model_name in MODEL_MAP:
407
+ model_path = MODEL_MAP[model_name]
408
+ if model_path != self.current_model:
409
+ logger.info(f"Model changed, reloading from {self.current_model} to {model_path}")
410
+ reload_status = self.reload_model(model_name)
411
+ if reload_status.startswith("✗"):
412
+ return f"Error: {reload_status}", "N/A", "", "", ""
413
+ logger.info(f"Model reloaded in {time.time() - start_time:.1f}s")
414
+
415
+ # Update bot args for this query
416
+ self.bot.args.k = k
417
+ self.bot.args.temperature = temperature
418
+ # Limit max_tokens for faster generation in Gradio
419
+ self.bot.args.max_new_tokens = min(max_tokens, 512) # Cap at 512 for faster responses
420
+
421
+ # Categorize question
422
+ logger.info("Categorizing question...")
423
+ question_group = self.bot._categorize_question(question)
424
+
425
+ # Retrieve relevant chunks with similarity scores
426
+ logger.info("Retrieving relevant documents...")
427
+ retrieve_start = time.time()
428
+ context_chunks, similarity_scores = self.bot.retrieve_with_scores(question, k)
429
+ logger.info(f"Retrieved {len(context_chunks)} chunks in {time.time() - retrieve_start:.2f}s")
430
+
431
+ if not context_chunks:
432
+ return (
433
+ "I don't have enough information to answer this question. Please try rephrasing or asking about a different topic.",
434
+ "N/A",
435
+ "No sources found",
436
+ "No matches found",
437
+ question_group
438
+ )
439
+
440
+ # Format similarity scores
441
+ similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores])
442
+
443
+ # Format sources with chunk text and file paths
444
+ sources_list = []
445
+ for i, (chunk, score) in enumerate(zip(context_chunks, similarity_scores)):
446
+ # Try to find the file path
447
+ file_path = self._find_file_path(chunk.filename)
448
+
449
+ source_info = f"""
450
+ {'='*80}
451
+ SOURCE {i+1} | Similarity: {score:.3f}
452
+ {'='*80}
453
+ 📄 File: {chunk.filename}
454
+ 📍 Path: {file_path if file_path else 'File path not found (search in Data Resources directory)'}
455
+ 📊 Chunk: {chunk.chunk_id + 1}/{chunk.total_chunks} (Position: {chunk.start_pos}-{chunk.end_pos})
456
+
457
+ 📝 Full Chunk Text:
458
+ {chunk.text}
459
+
460
+ """
461
+ sources_list.append(source_info)
462
+
463
+ sources = "\n".join(sources_list)
464
+
465
+ # Generation kwargs
466
+ gen_kwargs = {
467
+ 'max_new_tokens': min(max_tokens, 512), # Cap for faster responses
468
+ 'temperature': temperature,
469
+ 'top_p': self.bot.args.top_p,
470
+ 'repetition_penalty': self.bot.args.repetition_penalty
471
+ }
472
+
473
+ # Generate answer based on education level
474
+ answer = ""
475
+ flesch_score = 0.0
476
+
477
+ # Generate original answer first (needed for all enhancement levels)
478
+ logger.info("Generating original answer...")
479
+ gen_start = time.time()
480
+ prompt = self.bot.format_prompt(context_chunks, question)
481
+ original_answer = self.bot.generate_answer(prompt, **gen_kwargs)
482
+ logger.info(f"Original answer generated in {time.time() - gen_start:.1f}s")
483
+
484
+ # Enhance based on education level
485
+ logger.info(f"Enhancing answer for {education_level} level...")
486
+ enhance_start = time.time()
487
+ if education_level == "middle_school":
488
+ # Simplify to middle school level
489
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="middle_school")
490
+
491
+ elif education_level == "high_school":
492
+ # Simplify to high school level
493
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="high_school")
494
+
495
+ elif education_level == "college":
496
+ # Enhance to college level
497
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="college")
498
+
499
+ elif education_level == "doctoral":
500
+ # Enhance to doctoral/professional level
501
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="doctoral")
502
+ else:
503
+ answer = "Invalid education level selected."
504
+ flesch_score = 0.0
505
+
506
+ logger.info(f"Answer enhanced in {time.time() - enhance_start:.1f}s")
507
+ total_time = time.time() - start_time
508
+ logger.info(f"Total processing time: {total_time:.1f}s")
509
+
510
+ # Clean the answer - remove special tokens and formatting
511
+ import re
512
+ cleaned_answer = answer
513
+
514
+ # Remove special tokens (case-insensitive)
515
+ special_tokens = [
516
+ "<|end|>",
517
+ "<|endoftext|>",
518
+ "<|end_of_text|>",
519
+ "<|eot_id|>",
520
+ "<|start_header_id|>",
521
+ "<|end_header_id|>",
522
+ "<|assistant|>",
523
+ "<|endoftext|>",
524
+ "<|end_of_text|>",
525
+ ]
526
+ for token in special_tokens:
527
+ # Remove case-insensitive
528
+ cleaned_answer = re.sub(re.escape(token), '', cleaned_answer, flags=re.IGNORECASE)
529
+
530
+ # Remove any remaining special token patterns like <|...|>
531
+ cleaned_answer = re.sub(r'<\|[^|]+\|>', '', cleaned_answer)
532
+
533
+ # Remove any markdown-style headers that might have been added
534
+ cleaned_answer = re.sub(r'^\*\*.*?\*\*.*?\n', '', cleaned_answer, flags=re.MULTILINE)
535
+
536
+ # Clean up extra whitespace and newlines
537
+ cleaned_answer = re.sub(r'\n\s*\n\s*\n+', '\n\n', cleaned_answer) # Multiple newlines to double
538
+ cleaned_answer = re.sub(r'^\s+|\s+$', '', cleaned_answer, flags=re.MULTILINE) # Trim lines
539
+ cleaned_answer = cleaned_answer.strip()
540
+
541
+ # Return just the clean answer (no headers or metadata)
542
+ return (
543
+ cleaned_answer,
544
+ f"{flesch_score:.1f}",
545
+ sources,
546
+ similarity_scores_str,
547
+ question_group # Add question category as 5th return value
548
+ )
549
+
550
+ except Exception as e:
551
+ logger.error(f"Error processing question: {e}", exc_info=True)
552
+ return (
553
+ f"An error occurred while processing your question: {str(e)}",
554
+ "N/A",
555
+ "",
556
+ "",
557
+ "Error"
558
+ )
559
+
560
+
561
+ def create_interface(initial_bot: RAGBot, use_inference_api: Optional[bool] = None) -> gr.Blocks:
562
+ """Create and configure the Gradio interface"""
563
+
564
+ # Use Inference API on Spaces, local model otherwise
565
+ if use_inference_api is None:
566
+ use_inference_api = os.getenv("SPACE_ID") is not None or os.getenv("SYSTEM") == "spaces"
567
+
568
+ interface = GradioRAGInterface(initial_bot, use_inference_api=use_inference_api)
569
+
570
+ # Get initial model name from bot
571
+ initial_model_short = None
572
+ for short_name, full_path in MODEL_MAP.items():
573
+ if full_path == initial_bot.args.model:
574
+ initial_model_short = short_name
575
+ break
576
+ if initial_model_short is None:
577
+ initial_model_short = list(MODEL_MAP.keys())[0]
578
+
579
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
580
+ gr.Markdown("""
581
+ # 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot
582
+
583
+ Ask questions about genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.
584
+
585
+ The chatbot uses a Retrieval-Augmented Generation (RAG) system to provide evidence-based answers from medical literature.
586
+ """)
587
+
588
+ with gr.Row():
589
+ with gr.Column(scale=2):
590
+ question_input = gr.Textbox(
591
+ label="Your Question",
592
+ placeholder="e.g., What is Lynch Syndrome? What screening is recommended for BRCA1 carriers?",
593
+ lines=3
594
+ )
595
+
596
+ with gr.Row():
597
+ model_dropdown = gr.Dropdown(
598
+ choices=list(MODEL_MAP.keys()),
599
+ value=initial_model_short,
600
+ label="Select Model",
601
+ info="Choose which LLM model to use for generating answers"
602
+ )
603
+
604
+ education_dropdown = gr.Dropdown(
605
+ choices=list(EDUCATION_LEVELS.keys()),
606
+ value=list(EDUCATION_LEVELS.keys())[0],
607
+ label="Education Level",
608
+ info="Select your education level for personalized answers"
609
+ )
610
+
611
+ with gr.Accordion("Advanced Settings", open=False):
612
+ k_slider = gr.Slider(
613
+ minimum=1,
614
+ maximum=10,
615
+ value=5,
616
+ step=1,
617
+ label="Number of document chunks to retrieve (k)"
618
+ )
619
+ temperature_slider = gr.Slider(
620
+ minimum=0.1,
621
+ maximum=1.0,
622
+ value=0.2,
623
+ step=0.1,
624
+ label="Temperature (lower = more focused)"
625
+ )
626
+ max_tokens_slider = gr.Slider(
627
+ minimum=128,
628
+ maximum=1024,
629
+ value=512,
630
+ step=128,
631
+ label="Max Tokens (lower = faster responses)"
632
+ )
633
+
634
+ submit_btn = gr.Button("Ask Question", variant="primary", size="lg")
635
+
636
+ with gr.Column(scale=3):
637
+ answer_output = gr.Textbox(
638
+ label="Answer",
639
+ lines=20,
640
+ interactive=False,
641
+ elem_classes=["answer-box"]
642
+ )
643
+
644
+ with gr.Row():
645
+ flesch_output = gr.Textbox(
646
+ label="Flesch-Kincaid Grade Level",
647
+ value="N/A",
648
+ interactive=False,
649
+ scale=1
650
+ )
651
+
652
+ similarity_output = gr.Textbox(
653
+ label="Similarity Scores",
654
+ value="",
655
+ interactive=False,
656
+ scale=1
657
+ )
658
+
659
+ category_output = gr.Textbox(
660
+ label="Question Category",
661
+ value="",
662
+ interactive=False,
663
+ scale=1
664
+ )
665
+
666
+ sources_output = gr.Textbox(
667
+ label="Source Documents (with Chunk Text)",
668
+ lines=15,
669
+ interactive=False,
670
+ info="Shows the retrieved document chunks with full text. File paths are shown for easy access."
671
+ )
672
+
673
+ # Example questions - all questions from the results CSV (scrollable)
674
+ gr.Markdown("### 💡 Example Questions")
675
+ gr.Markdown(f"Select a question below to use it in the chatbot ({len(EXAMPLE_QUESTIONS)} questions - scrollable dropdown):")
676
+
677
+ # Use Dropdown which is naturally scrollable with many options
678
+ example_questions_dropdown = gr.Dropdown(
679
+ choices=EXAMPLE_QUESTIONS,
680
+ label="Example Questions",
681
+ value=None,
682
+ info="Open the dropdown and scroll through all questions. Select one to use it.",
683
+ interactive=True,
684
+ container=True,
685
+ scale=1
686
+ )
687
+
688
+ # Update question input when dropdown selection changes
689
+ def update_question_from_dropdown(selected_question):
690
+ return selected_question if selected_question else ""
691
+
692
+ example_questions_dropdown.change(
693
+ fn=update_question_from_dropdown,
694
+ inputs=example_questions_dropdown,
695
+ outputs=question_input
696
+ )
697
+
698
+ # Footer
699
+ gr.Markdown("""
700
+ ---
701
+ **Note:** This chatbot provides informational answers based on medical literature.
702
+ It is not a substitute for professional medical advice, diagnosis, or treatment.
703
+ Always consult with qualified healthcare providers for medical decisions.
704
+ """)
705
+
706
+ # Connect the submit button
707
+ def process_with_education_level(question, model, education, k, temp, max_tok):
708
+ education_key = EDUCATION_LEVELS[education]
709
+ return interface.process_question(question, model, education_key, k, temp, max_tok)
710
+
711
+ submit_btn.click(
712
+ fn=process_with_education_level,
713
+ inputs=[
714
+ question_input,
715
+ model_dropdown,
716
+ education_dropdown,
717
+ k_slider,
718
+ temperature_slider,
719
+ max_tokens_slider
720
+ ],
721
+ outputs=[
722
+ answer_output,
723
+ flesch_output,
724
+ sources_output,
725
+ similarity_output,
726
+ category_output
727
+ ]
728
+ )
729
+
730
+ # Also allow Enter key to submit
731
+ question_input.submit(
732
+ fn=process_with_education_level,
733
+ inputs=[
734
+ question_input,
735
+ model_dropdown,
736
+ education_dropdown,
737
+ k_slider,
738
+ temperature_slider,
739
+ max_tokens_slider
740
+ ],
741
+ outputs=[
742
+ answer_output,
743
+ flesch_output,
744
+ sources_output,
745
+ similarity_output,
746
+ category_output
747
+ ]
748
+ )
749
+
750
+ return demo
751
+
752
+
753
+ def main():
754
+ """Main function to launch the Gradio app"""
755
+ # Parse arguments with defaults suitable for Gradio
756
+ parser = argparse.ArgumentParser(description="Gradio Interface for CGT-LLM-Beta RAG Chatbot")
757
+
758
+ # Model and database settings
759
+ parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
760
+ help='HuggingFace model name')
761
+ parser.add_argument('--vector-db-dir', default='./chroma_db',
762
+ help='Directory for ChromaDB persistence')
763
+ parser.add_argument('--data-dir', default='./Data Resources',
764
+ help='Directory containing documents (for indexing if needed)')
765
+
766
+ # Generation parameters
767
+ parser.add_argument('--max-new-tokens', type=int, default=1024,
768
+ help='Maximum new tokens to generate')
769
+ parser.add_argument('--temperature', type=float, default=0.2,
770
+ help='Generation temperature')
771
+ parser.add_argument('--top-p', type=float, default=0.9,
772
+ help='Top-p sampling parameter')
773
+ parser.add_argument('--repetition-penalty', type=float, default=1.1,
774
+ help='Repetition penalty')
775
+
776
+ # Retrieval parameters
777
+ parser.add_argument('--k', type=int, default=5,
778
+ help='Number of chunks to retrieve per question')
779
+
780
+ # Other settings
781
+ parser.add_argument('--skip-indexing', action='store_true',
782
+ help='Skip document indexing (use existing vector DB)')
783
+ parser.add_argument('--verbose', action='store_true',
784
+ help='Enable verbose logging')
785
+ parser.add_argument('--share', action='store_true',
786
+ help='Create a public Gradio share link')
787
+ parser.add_argument('--server-name', type=str, default='127.0.0.1',
788
+ help='Server name (0.0.0.0 for public access)')
789
+ parser.add_argument('--server-port', type=int, default=7860,
790
+ help='Server port')
791
+
792
+ args = parser.parse_args()
793
+
794
+ # Set logging level
795
+ if args.verbose:
796
+ logging.getLogger().setLevel(logging.DEBUG)
797
+
798
+ logger.info("Initializing RAGBot for Gradio interface...")
799
+ logger.info(f"Model: {args.model}")
800
+ logger.info(f"Vector DB: {args.vector_db_dir}")
801
+
802
+ try:
803
+ # Initialize bot
804
+ bot = RAGBot(args)
805
+
806
+ # Check if vector database exists and has documents
807
+ collection_stats = bot.vector_retriever.get_collection_stats()
808
+ if collection_stats.get('total_chunks', 0) == 0:
809
+ logger.warning("Vector database is empty. You may need to run indexing first:")
810
+ logger.warning(" python bot.py --data-dir './Data Resources' --vector-db-dir './chroma_db'")
811
+ logger.warning("Continuing anyway - the chatbot will work but may not find relevant documents.")
812
+
813
+ # Create and launch Gradio interface
814
+ demo = create_interface(bot)
815
+
816
+ # For local use, launch it
817
+ # (On Spaces, the demo is already created at module level)
818
+ logger.info(f"Launching Gradio interface on http://{args.server_name}:{args.server_port}")
819
+ demo.launch(
820
+ server_name=args.server_name,
821
+ server_port=args.server_port,
822
+ share=args.share
823
+ )
824
+
825
+ except KeyboardInterrupt:
826
+ logger.info("Interrupted by user")
827
+ sys.exit(0)
828
+ except Exception as e:
829
+ logger.error(f"Error launching Gradio app: {e}", exc_info=True)
830
+ sys.exit(1)
831
+
832
+
833
+ # For Hugging Face Spaces: create demo at module level
834
+ # Following the HF Spaces pattern: create the Gradio app directly at module level
835
+ # Spaces will import this module and look for a Gradio Blocks/Interface object
836
+ # Pattern: demo = gr.Interface(...) or demo = gr.Blocks(...)
837
+ # DO NOT call demo.launch() - Spaces handles that automatically
838
+
839
+ # Check if we're on Spaces (be more permissive - check multiple env vars)
840
+ IS_SPACES = (
841
+ os.getenv("SPACE_ID") is not None or
842
+ os.getenv("SYSTEM") == "spaces" or
843
+ os.getenv("HF_SPACE_ID") is not None
844
+ )
845
+
846
+ # Create demo at module level (like HF docs example)
847
+ # Initialize demo variable to None first (safety measure)
848
+ demo = None
849
+
850
+ # Create demo at module level (like HF docs example)
851
+ # This ensures Spaces can always find it when importing the module
852
+ try:
853
+ if IS_SPACES:
854
+ logger.info("Initializing for Hugging Face Spaces...")
855
+ else:
856
+ logger.info("Initializing for local execution...")
857
+
858
+ # Initialize with default args
859
+ parser = argparse.ArgumentParser()
860
+ parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct')
861
+ parser.add_argument('--vector-db-dir', default='./chroma_db')
862
+ parser.add_argument('--data-dir', default='./Data Resources')
863
+ parser.add_argument('--max-new-tokens', type=int, default=1024)
864
+ parser.add_argument('--temperature', type=float, default=0.2)
865
+ parser.add_argument('--top-p', type=float, default=0.9)
866
+ parser.add_argument('--repetition-penalty', type=float, default=1.1)
867
+ parser.add_argument('--k', type=int, default=5)
868
+ parser.add_argument('--skip-indexing', action='store_true', default=True)
869
+ parser.add_argument('--verbose', action='store_true', default=False)
870
+ parser.add_argument('--share', action='store_true', default=False)
871
+ parser.add_argument('--server-name', type=str, default='0.0.0.0')
872
+ parser.add_argument('--server-port', type=int, default=7860)
873
+ parser.add_argument('--seed', type=int, default=42)
874
+
875
+ args = parser.parse_args([]) # Empty args
876
+ args.skip_model_loading = IS_SPACES # Skip model loading on Spaces, use Inference API
877
+
878
+ # Create bot - handle initialization errors gracefully
879
+ try:
880
+ bot = RAGBot(args)
881
+
882
+ if bot.vector_retriever is None:
883
+ raise Exception("Vector database not available")
884
+
885
+ # Check if vector database has documents
886
+ collection_stats = bot.vector_retriever.get_collection_stats()
887
+ if collection_stats.get('total_chunks', 0) == 0:
888
+ logger.warning("Vector database is empty. The chatbot may not find relevant documents.")
889
+ logger.warning("This is OK for initial deployment - documents can be indexed later.")
890
+
891
+ # Create the demo interface directly at module level (like HF docs example)
892
+ demo = create_interface(bot, use_inference_api=IS_SPACES)
893
+ except Exception as bot_error:
894
+ logger.error(f"Error initializing RAGBot: {bot_error}", exc_info=True)
895
+ # Create a demo that shows the error but still allows the interface to load
896
+ with gr.Blocks() as demo:
897
+ gr.Markdown(f"""
898
+ # ⚠️ Initialization Error
899
+
900
+ The chatbot encountered an error during initialization:
901
+
902
+ **Error:** {str(bot_error)}
903
+
904
+ This might be due to:
905
+ - Missing vector database (chroma_db directory)
906
+ - Missing dependencies
907
+ - Configuration issues
908
+
909
+ Please check the logs for more details.
910
+ """)
911
+ raise # Re-raise to be caught by outer try/except
912
+ logger.info(f"Demo created successfully: {type(demo)}")
913
+ # Explicitly verify it's a valid Gradio object
914
+ if not isinstance(demo, (gr.Blocks, gr.Interface)):
915
+ raise TypeError(f"Demo is not a valid Gradio object: {type(demo)}")
916
+ logger.info("Demo validation passed - ready for Spaces")
917
+ except Exception as e:
918
+ logger.error(f"Error creating demo: {e}", exc_info=True)
919
+ import traceback
920
+ logger.error(f"Traceback: {traceback.format_exc()}")
921
+ # Create a fallback error demo so Spaces doesn't show blank
922
+ with gr.Blocks() as demo:
923
+ gr.Markdown(f"# Error Initializing Chatbot\n\nAn error occurred while initializing the chatbot.\n\nError: {str(e)}\n\nPlease check the logs for details.")
924
+ logger.info(f"Error demo created: {type(demo)}")
925
+
926
+ # Final verification - ensure demo exists and is valid
927
+ if demo is None:
928
+ logger.error("CRITICAL: Demo variable is None!")
929
+ with gr.Blocks() as demo:
930
+ gr.Markdown("# Error: Demo was not created properly\n\nPlease check the logs for details.")
931
+ elif not isinstance(demo, (gr.Blocks, gr.Interface)):
932
+ logger.error(f"CRITICAL: Demo is not a valid Gradio object: {type(demo)}")
933
+ with gr.Blocks() as demo:
934
+ gr.Markdown(f"# Error: Invalid demo type\n\nDemo type: {type(demo)}\n\nPlease check the logs for details.")
935
+ else:
936
+ logger.info(f"✅ Final demo check passed: demo type={type(demo)}")
937
+ # Explicit print to ensure demo is accessible (Spaces might check this)
938
+ print(f"DEMO_VARIABLE_SET: {type(demo)}")
939
+
940
+ # For local execution only (not on Spaces)
941
+ if __name__ == "__main__":
942
+ if not IS_SPACES:
943
+ main()
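
On Spaces, the module-level path above skips local model loading and routes generation through InferenceAPIBot, which wraps the huggingface_hub client. The fragment below is a hedged sketch of that call pattern outside the Gradio app; the model name mirrors the default entry in MODEL_MAP, and the prompt string is a stand-in for the retrieval-augmented prompt that format_prompt() actually assembles.

```python
# Sketch of the Inference API path used on Spaces (requires HF_TOKEN to be set
# as a Space secret). The prompt here is a placeholder, not the real RAG prompt
# built from retrieved chunks.
import os

import textstat
from huggingface_hub import InferenceClient

client = InferenceClient(api_key=os.environ["HF_TOKEN"])
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Context: ...\n\nQuestion: What is Lynch Syndrome?"}],
    max_tokens=512,
    temperature=0.2,
    top_p=0.9,
)
answer = completion.choices[0].message.content

# The interface reports a Flesch-Kincaid grade level next to every answer.
print(answer)
print("Flesch-Kincaid grade:", textstat.flesch_kincaid_grade(answer))
```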
bot.py ADDED
@@ -0,0 +1,1777 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ RAG Chatbot Implementation for CGT-LLM-Beta with Vector Database
4
+ Production-ready local RAG system with ChromaDB and MPS acceleration for Apple Silicon
5
+ """
6
+
7
+ import argparse
8
+ import csv
9
+ import json
10
+ import logging
11
+ import os
12
+ import re
13
+ import sys
14
+ import time
15
+ import hashlib
16
+ from pathlib import Path
17
+ from typing import List, Tuple, Dict, Any, Optional, Union
18
+ from dataclasses import dataclass
19
+ from collections import defaultdict
20
+
21
+ import textstat
22
+
23
+ import torch
24
+ import numpy as np
25
+ import pandas as pd
26
+ from tqdm import tqdm
27
+
28
+ # Optional imports with graceful fallbacks
29
+ try:
30
+ import chromadb
31
+ from chromadb.config import Settings
32
+ CHROMADB_AVAILABLE = True
33
+ except ImportError:
34
+ CHROMADB_AVAILABLE = False
35
+ print("Warning: chromadb not available. Install with: pip install chromadb")
36
+
37
+ try:
38
+ from sentence_transformers import SentenceTransformer
39
+ SENTENCE_TRANSFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SENTENCE_TRANSFORMERS_AVAILABLE = False
42
+ print("Warning: sentence-transformers not available. Install with: pip install sentence-transformers")
43
+
44
+ try:
45
+ import pypdf
46
+ PDF_AVAILABLE = True
47
+ except ImportError:
48
+ PDF_AVAILABLE = False
49
+ print("Warning: pypdf not available. PDF files will be skipped.")
50
+
51
+ try:
52
+ from docx import Document
53
+ DOCX_AVAILABLE = True
54
+ except ImportError:
55
+ DOCX_AVAILABLE = False
56
+ print("Warning: python-docx not available. DOCX files will be skipped.")
57
+
58
+ try:
59
+ from rank_bm25 import BM25Okapi
60
+ BM25_AVAILABLE = True
61
+ except ImportError:
62
+ BM25_AVAILABLE = False
63
+ print("Warning: rank-bm25 not available. BM25 retrieval disabled.")
64
+
65
+ # Configure logging
66
+ logging.basicConfig(
67
+ level=logging.INFO,
68
+ format='%(asctime)s - %(levelname)s - %(message)s',
69
+ handlers=[
70
+ logging.StreamHandler(),
71
+ logging.FileHandler('rag_bot.log')
72
+ ]
73
+ )
74
+ logger = logging.getLogger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class Document:
79
+ """Represents a document with metadata"""
80
+ filename: str
81
+ content: str
82
+ filepath: str
83
+ file_type: str
84
+ chunk_count: int = 0
85
+ file_hash: str = ""
86
+
87
+
88
+ @dataclass
89
+ class Chunk:
90
+ """Represents a text chunk with metadata"""
91
+ text: str
92
+ filename: str
93
+ chunk_id: int
94
+ total_chunks: int
95
+ start_pos: int
96
+ end_pos: int
97
+ metadata: Dict[str, Any]
98
+ chunk_hash: str = ""
99
+
100
+
101
+ class VectorRetriever:
102
+ """ChromaDB-based vector retrieval"""
103
+
104
+ def __init__(self, collection_name: str = "cgt_documents", persist_directory: str = "./chroma_db"):
105
+ if not CHROMADB_AVAILABLE:
106
+ raise ImportError("ChromaDB is required for vector retrieval")
107
+
108
+ self.collection_name = collection_name
109
+ self.persist_directory = persist_directory
110
+
111
+ # Initialize ChromaDB client
112
+ self.client = chromadb.PersistentClient(path=persist_directory)
113
+
114
+ # Get or create collection
115
+ try:
116
+ self.collection = self.client.get_collection(name=collection_name)
117
+ logger.info(f"Loaded existing collection '{collection_name}' with {self.collection.count()} documents")
118
+ except:
119
+ self.collection = self.client.create_collection(
120
+ name=collection_name,
121
+ metadata={"description": "CGT-LLM-Beta document collection"}
122
+ )
123
+ logger.info(f"Created new collection '{collection_name}'")
124
+
125
+ # Initialize embedding model
126
+ if SENTENCE_TRANSFORMERS_AVAILABLE:
127
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
128
+ logger.info("Loaded sentence-transformers embedding model")
129
+ else:
130
+ self.embedding_model = None
131
+ logger.warning("Sentence-transformers not available, using ChromaDB default embeddings")
132
+
133
+ def add_documents(self, chunks: List[Chunk]) -> None:
134
+ """Add document chunks to the vector database"""
135
+ if not chunks:
136
+ return
137
+
138
+ logger.info(f"Adding {len(chunks)} chunks to vector database...")
139
+
140
+ # Prepare data for ChromaDB
141
+ documents = []
142
+ metadatas = []
143
+ ids = []
144
+
145
+ for chunk in chunks:
146
+ chunk_id = f"{chunk.filename}_{chunk.chunk_id}"
147
+ documents.append(chunk.text)
148
+
149
+ metadata = {
150
+ "filename": chunk.filename,
151
+ "chunk_id": chunk.chunk_id,
152
+ "total_chunks": chunk.total_chunks,
153
+ "start_pos": chunk.start_pos,
154
+ "end_pos": chunk.end_pos,
155
+ "chunk_hash": chunk.chunk_hash,
156
+ **chunk.metadata
157
+ }
158
+ metadatas.append(metadata)
159
+ ids.append(chunk_id)
160
+
161
+ # Add to collection
162
+ try:
163
+ self.collection.add(
164
+ documents=documents,
165
+ metadatas=metadatas,
166
+ ids=ids
167
+ )
168
+ logger.info(f"Successfully added {len(chunks)} chunks to vector database")
169
+ except Exception as e:
170
+ logger.error(f"Error adding documents to vector database: {e}")
171
+
172
+ def search(self, query: str, k: int = 5) -> List[Tuple[Chunk, float]]:
173
+ """Search for similar chunks using vector similarity"""
174
+ try:
175
+ # Perform vector search
176
+ results = self.collection.query(
177
+ query_texts=[query],
178
+ n_results=k
179
+ )
180
+
181
+ chunks_with_scores = []
182
+ if results['documents'] and results['documents'][0]:
183
+ for i, (doc, metadata, distance) in enumerate(zip(
184
+ results['documents'][0],
185
+ results['metadatas'][0],
186
+ results['distances'][0]
187
+ )):
188
+ # Convert distance to a similarity-style score (higher is better)
189
+ similarity_score = 1 - distance
190
+
191
+ chunk = Chunk(
192
+ text=doc,
193
+ filename=metadata['filename'],
194
+ chunk_id=metadata['chunk_id'],
195
+ total_chunks=metadata['total_chunks'],
196
+ start_pos=metadata['start_pos'],
197
+ end_pos=metadata['end_pos'],
198
+ metadata={k: v for k, v in metadata.items()
199
+ if k not in ['filename', 'chunk_id', 'total_chunks', 'start_pos', 'end_pos', 'chunk_hash']},
200
+ chunk_hash=metadata.get('chunk_hash', '')
201
+ )
202
+ chunks_with_scores.append((chunk, similarity_score))
203
+
204
+ return chunks_with_scores
205
+
206
+ except Exception as e:
207
+ logger.error(f"Error searching vector database: {e}")
208
+ return []
209
+
210
+ def get_collection_stats(self) -> Dict[str, Any]:
211
+ """Get statistics about the collection"""
212
+ try:
213
+ count = self.collection.count()
214
+ return {
215
+ "total_chunks": count,
216
+ "collection_name": self.collection_name,
217
+ "persist_directory": self.persist_directory
218
+ }
219
+ except Exception as e:
220
+ logger.error(f"Error getting collection stats: {e}")
221
+ return {}
222
+
223
+
224
+ class RAGBot:
225
+ """Main RAG chatbot class with vector database"""
226
+
227
+ def __init__(self, args):
228
+ self.args = args
229
+ self.device = self._setup_device()
230
+ self.model = None
231
+ self.tokenizer = None
232
+ self.vector_retriever = None
233
+
234
+ # Load model (unless skipping for Inference API)
235
+ if not hasattr(args, 'skip_model_loading') or not args.skip_model_loading:
236
+ self._load_model()
237
+
238
+ # Initialize vector retriever
239
+ self._setup_vector_retriever()
240
+
241
+ def _setup_device(self) -> str:
242
+ """Setup device with MPS support for Apple Silicon"""
243
+ if torch.backends.mps.is_available():
244
+ device = "mps"
245
+ logger.info("Using device: mps (Apple Silicon)")
246
+ elif torch.cuda.is_available():
247
+ device = "cuda"
248
+ logger.info("Using device: cuda")
249
+ else:
250
+ device = "cpu"
251
+ logger.info("Using device: cpu")
252
+
253
+ return device
254
+
255
+ def _load_model(self):
256
+ """Load the specified LLM model and tokenizer"""
257
+ try:
258
+ model_name = self.args.model
259
+ logger.info(f"Loading model: {model_name}...")
260
+ from transformers import AutoTokenizer, AutoModelForCausalLM
261
+
262
+ # Get Hugging Face token from environment (for gated models)
263
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
264
+
265
+ # Load tokenizer
266
+ tokenizer_kwargs = {
267
+ "trust_remote_code": True
268
+ }
269
+ if hf_token:
270
+ tokenizer_kwargs["token"] = hf_token
271
+ logger.info("Using HF_TOKEN for authentication")
272
+
273
+ self.tokenizer = AutoTokenizer.from_pretrained(
274
+ model_name,
275
+ **tokenizer_kwargs
276
+ )
277
+
278
+ # Determine appropriate torch dtype based on device and model
279
+ # Use float16 for MPS/CUDA, float32 for CPU
280
+ # Some models work better with bfloat16
281
+ if self.device == "mps":
282
+ torch_dtype = torch.float16
283
+ elif self.device == "cuda":
284
+ torch_dtype = torch.float16
285
+ else:
286
+ torch_dtype = torch.float32
287
+
288
+ # Load model with appropriate settings
289
+ model_kwargs = {
290
+ "torch_dtype": torch_dtype,
291
+ "trust_remote_code": True,
292
+ }
293
+
294
+ # Add token if available (for gated models)
295
+ if hf_token:
296
+ model_kwargs["token"] = hf_token
297
+
298
+ # Note: 8-bit quantization is not available on CPU, so fall back to float16
300
+ # weights, which roughly halves memory use compared to float32
301
+ if self.device == "cpu":
302
+ try:
303
+ from transformers import BitsAndBytesConfig
304
+ # load_in_8bit stays disabled; float16 is the memory-saving path on CPU
304
+ model_kwargs["load_in_8bit"] = False # 8-bit not available on CPU
305
+ # Instead, use float16 even on CPU to save memory
306
+ model_kwargs["torch_dtype"] = torch.float16
307
+ logger.info("Using float16 on CPU to reduce memory usage")
308
+ except ImportError:
309
+ # Fallback: use float16 anyway
310
+ model_kwargs["torch_dtype"] = torch.float16
311
+ logger.info("Using float16 on CPU to reduce memory usage (fallback)")
312
+
313
+ # For MPS, use device_map; for CUDA, let it auto-detect
314
+ if self.device == "mps":
315
+ model_kwargs["device_map"] = self.device
316
+ elif self.device == "cuda":
317
+ model_kwargs["device_map"] = "auto"
318
+ # For CPU, don't specify device_map
319
+
320
+ self.model = AutoModelForCausalLM.from_pretrained(
321
+ model_name,
322
+ **model_kwargs
323
+ )
324
+
325
+ # Move to device if not using device_map
326
+ if self.device == "cpu":
327
+ self.model = self.model.to(self.device)
328
+
329
+ # Set pad token if not already set
330
+ if self.tokenizer.pad_token is None:
331
+ if self.tokenizer.eos_token is not None:
332
+ self.tokenizer.pad_token = self.tokenizer.eos_token
333
+ else:
334
+ # Some models might need a different approach
335
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
336
+
337
+ logger.info(f"Model {model_name} loaded successfully on {self.device}")
338
+
339
+ except Exception as e:
340
+ logger.error(f"Failed to load model {self.args.model}: {e}")
341
+ logger.error("Make sure the model name is correct and you have access to it on HuggingFace")
342
+ logger.error("For gated models (like Llama), you need to:")
343
+ logger.error(" 1. Request access at: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct")
344
+ logger.error(" 2. Add HF_TOKEN as a secret in your Hugging Face Space settings")
345
+ logger.error(" 3. Get your token from: https://huggingface.co/settings/tokens")
346
+ logger.error("For local use, ensure you're logged in: huggingface-cli login")
347
+ sys.exit(2)
348
+
349
+ def _setup_vector_retriever(self):
350
+ """Setup the vector retriever"""
351
+ try:
352
+ self.vector_retriever = VectorRetriever(
353
+ collection_name="cgt_documents",
354
+ persist_directory=self.args.vector_db_dir
355
+ )
356
+ logger.info("Vector retriever initialized successfully")
357
+ except Exception as e:
358
+ logger.error(f"Failed to setup vector retriever: {e}")
359
+ sys.exit(2)
360
+
361
+ def _calculate_file_hash(self, filepath: str) -> str:
362
+ """Calculate hash of file for change detection"""
363
+ try:
364
+ with open(filepath, 'rb') as f:
365
+ return hashlib.md5(f.read()).hexdigest()
366
+ except Exception:
367
+ return ""
368
+
369
+ def _calculate_chunk_hash(self, text: str) -> str:
370
+ """Calculate hash of chunk text"""
371
+ return hashlib.md5(text.encode('utf-8')).hexdigest()
372
+
373
+ def load_corpus(self, data_dir: str) -> List[Document]:
374
+ """Load all documents from the data directory"""
375
+ logger.info(f"Loading corpus from {data_dir}")
376
+ documents = []
377
+ data_path = Path(data_dir)
378
+
379
+ if not data_path.exists():
380
+ logger.error(f"Data directory {data_dir} does not exist")
381
+ sys.exit(1)
382
+
383
+ # Supported file extensions
384
+ supported_extensions = {'.txt', '.md', '.json', '.csv'}
385
+ if PDF_AVAILABLE:
386
+ supported_extensions.add('.pdf')
387
+ if DOCX_AVAILABLE:
388
+ supported_extensions.add('.docx')
389
+ supported_extensions.add('.doc')
390
+
391
+ # Find all files recursively
392
+ files = []
393
+ for ext in supported_extensions:
394
+ files.extend(data_path.rglob(f"*{ext}"))
395
+
396
+ logger.info(f"Found {len(files)} files to process")
397
+
398
+ # Process files with progress bar
399
+ for file_path in tqdm(files, desc="Loading documents"):
400
+ try:
401
+ content = self._read_file(file_path)
402
+ if content.strip(): # Only add non-empty documents
403
+ file_hash = self._calculate_file_hash(file_path)
404
+ doc = Document(
405
+ filename=file_path.name,
406
+ content=content,
407
+ filepath=str(file_path),
408
+ file_type=file_path.suffix.lower(),
409
+ file_hash=file_hash
410
+ )
411
+ documents.append(doc)
412
+ logger.debug(f"Loaded {file_path.name} ({len(content)} chars)")
413
+ else:
414
+ logger.warning(f"Skipping empty file: {file_path.name}")
415
+
416
+ except Exception as e:
417
+ logger.error(f"Failed to load {file_path.name}: {e}")
418
+ continue
419
+
420
+ logger.info(f"Successfully loaded {len(documents)} documents")
421
+ return documents
422
+
423
+ def _read_file(self, file_path: Path) -> str:
424
+ """Read content from various file types"""
425
+ suffix = file_path.suffix.lower()
426
+
427
+ try:
428
+ if suffix == '.txt':
429
+ return file_path.read_text(encoding='utf-8')
430
+
431
+ elif suffix == '.md':
432
+ return file_path.read_text(encoding='utf-8')
433
+
434
+ elif suffix == '.json':
435
+ with open(file_path, 'r', encoding='utf-8') as f:
436
+ data = json.load(f)
437
+ if isinstance(data, dict):
438
+ return json.dumps(data, indent=2)
439
+ else:
440
+ return str(data)
441
+
442
+ elif suffix == '.csv':
443
+ df = pd.read_csv(file_path)
444
+ return df.to_string()
445
+
446
+ elif suffix == '.pdf' and PDF_AVAILABLE:
447
+ text = ""
448
+ with open(file_path, 'rb') as f:
449
+ pdf_reader = pypdf.PdfReader(f)
450
+ for page in pdf_reader.pages:
451
+ text += page.extract_text() + "\n"
452
+ return text
453
+
454
+ elif suffix in ['.docx', '.doc'] and DOCX_AVAILABLE:
455
+ import docx  # use python-docx explicitly; the name `Document` refers to the local dataclass here
+ doc = docx.Document(str(file_path))
456
+ text = ""
457
+ for paragraph in doc.paragraphs:
458
+ text += paragraph.text + "\n"
459
+ return text
460
+
461
+ else:
462
+ logger.warning(f"Unsupported file type: {suffix}")
463
+ return ""
464
+
465
+ except Exception as e:
466
+ logger.error(f"Error reading {file_path}: {e}")
467
+ return ""
468
+
469
+ def chunk_documents(self, docs: List[Document], chunk_size: int, overlap: int) -> List[Chunk]:
470
+ """Chunk documents into smaller pieces"""
471
+ logger.info(f"Chunking {len(docs)} documents (size={chunk_size}, overlap={overlap})")
472
+ chunks = []
473
+
474
+ for doc in docs:
475
+ doc_chunks = self._chunk_text(
476
+ doc.content,
477
+ doc.filename,
478
+ chunk_size,
479
+ overlap
480
+ )
481
+ chunks.extend(doc_chunks)
482
+
483
+ # Update document metadata
484
+ doc.chunk_count = len(doc_chunks)
485
+
486
+ logger.info(f"Created {len(chunks)} chunks from {len(docs)} documents")
487
+ return chunks
488
+
489
+ def _chunk_text(self, text: str, filename: str, chunk_size: int, overlap: int) -> List[Chunk]:
490
+ """Split text into overlapping chunks"""
491
+ # Clean text
492
+ text = re.sub(r'\s+', ' ', text.strip())
493
+
494
+ # Simple token-based chunking (approximate)
495
+ words = text.split()
496
+ chunks = []
497
+
498
+ for i in range(0, len(words), chunk_size - overlap):
499
+ chunk_words = words[i:i + chunk_size]
500
+ chunk_text = ' '.join(chunk_words)
501
+
502
+ if chunk_text.strip():
503
+ chunk_hash = self._calculate_chunk_hash(chunk_text)
504
+ chunk = Chunk(
505
+ text=chunk_text,
506
+ filename=filename,
507
+ chunk_id=len(chunks),
508
+ total_chunks=0, # Will be updated later
509
+ start_pos=i,
510
+ end_pos=i + len(chunk_words),
511
+ metadata={
512
+ 'word_count': len(chunk_words),
513
+ 'char_count': len(chunk_text)
514
+ },
515
+ chunk_hash=chunk_hash
516
+ )
517
+ chunks.append(chunk)
518
+
519
+ # Update total_chunks for each chunk
520
+ for chunk in chunks:
521
+ chunk.total_chunks = len(chunks)
522
+
523
+ return chunks
524
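The loop above advances by `chunk_size - overlap` words, so consecutive chunks share `overlap` words. A small worked example with toy numbers (rather than the 500/200 defaults):

```python
# With chunk_size=10 and overlap=4 the stride is 6 words, so chunks cover
# word indices [0:10], [6:16], [12:22], ... and neighbouring chunks share 4 words.
words = [f"w{i}" for i in range(20)]
chunk_size, overlap = 10, 4
for start in range(0, len(words), chunk_size - overlap):
    piece = words[start:start + chunk_size]
    print(start, start + len(piece), " ".join(piece))
```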
+
525
+ def build_or_update_index(self, chunks: List[Chunk], force_rebuild: bool = False) -> None:
526
+ """Build or update the vector index"""
527
+ if not chunks:
528
+ logger.warning("No chunks provided for indexing")
529
+ return
530
+
531
+ # Check if we need to rebuild
532
+ collection_stats = self.vector_retriever.get_collection_stats()
533
+ existing_count = collection_stats.get('total_chunks', 0)
534
+
535
+ if existing_count > 0 and not force_rebuild:
536
+ logger.info(f"Vector database already contains {existing_count} chunks. Use --force-rebuild to rebuild.")
537
+ return
538
+
539
+ if force_rebuild and existing_count > 0:
540
+ logger.info("Force rebuild requested. Clearing existing collection...")
541
+ try:
542
+ # Note: RAGBot holds no ChromaDB client of its own; this assumes VectorRetriever exposes its client as self.client
+ self.vector_retriever.client.delete_collection(self.vector_retriever.collection_name)
543
+ self.vector_retriever.collection = self.vector_retriever.client.create_collection(
544
+ name=self.vector_retriever.collection_name,
545
+ metadata={"description": "CGT-LLM-Beta document collection"}
546
+ )
547
+ except Exception as e:
548
+ logger.error(f"Error clearing collection: {e}")
549
+
550
+ # Add chunks to vector database
551
+ self.vector_retriever.add_documents(chunks)
552
+
553
+ logger.info("Vector index built successfully")
554
+
555
+ def retrieve(self, query: str, k: int) -> List[Chunk]:
556
+ """Retrieve relevant chunks for a query using vector search"""
557
+ results = self.vector_retriever.search(query, k)
558
+ chunks = [chunk for chunk, score in results]
559
+
560
+ if self.args.verbose:
561
+ logger.info(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
562
+ for i, (chunk, score) in enumerate(results):
563
+ logger.info(f" {i+1}. {chunk.filename} (score: {score:.3f})")
564
+
565
+ return chunks
566
+
567
+ def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
568
+ """Retrieve relevant chunks with similarity scores
569
+
570
+ Returns:
571
+ Tuple of (chunks, scores) where scores are similarity scores for each chunk
572
+ """
573
+ results = self.vector_retriever.search(query, k)
574
+ chunks = [chunk for chunk, score in results]
575
+ scores = [score for chunk, score in results]
576
+
577
+ if self.args.verbose:
578
+ logger.info(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
579
+ for i, (chunk, score) in enumerate(results):
580
+ logger.info(f" {i+1}. {chunk.filename} (score: {score:.3f})")
581
+
582
+ return chunks, scores
583
+
584
+ def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
585
+ """Format the prompt with context and question, ensuring it fits within token limits"""
586
+ context_parts = []
587
+ for chunk in context_chunks:
588
+ context_parts.append(f"{chunk.text}")
589
+
590
+ context = "\n".join(context_parts)
591
+
592
+ # Try to use the tokenizer's chat template if available
593
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
594
+ try:
595
+ messages = [
596
+ {"role": "system", "content": "You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative."},
597
+ {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
598
+ ]
599
+ base_prompt = self.tokenizer.apply_chat_template(
600
+ messages,
601
+ tokenize=False,
602
+ add_generation_prompt=True
603
+ )
604
+ except Exception as e:
605
+ logger.warning(f"Failed to use chat template, falling back to manual format: {e}")
606
+ base_prompt = self._format_prompt_manual(context, question)
607
+ else:
608
+ # Fall back to manual formatting (for Llama models)
609
+ base_prompt = self._format_prompt_manual(context, question)
610
+
611
+ # Check if prompt is too long and truncate context if needed
612
+ max_context_tokens = 1200 # Leave room for generation
613
+ try:
614
+ tokenized = self.tokenizer(base_prompt, return_tensors="pt")
615
+ current_tokens = tokenized['input_ids'].shape[1]
616
+ except Exception as e:
617
+ logger.warning(f"Tokenization error, using base prompt as-is: {e}")
618
+ return base_prompt
619
+
620
+ if current_tokens > max_context_tokens:
621
+ # Truncate context to fit within limits
622
+ try:
623
+ context_tokens = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]
624
+ available_tokens = max_context_tokens - (current_tokens - context_tokens)
625
+
626
+ if available_tokens > 0:
627
+ # Truncate context to fit
628
+ truncated_context = self.tokenizer.decode(
629
+ self.tokenizer(context, return_tensors="pt", truncation=True, max_length=available_tokens)['input_ids'][0],
630
+ skip_special_tokens=True
631
+ )
632
+
633
+ # Reformat with truncated context
634
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
635
+ try:
636
+ messages = [
637
+ {"role": "system", "content": "You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative."},
638
+ {"role": "user", "content": f"Context: {truncated_context}\n\nQuestion: {question}"}
639
+ ]
640
+ prompt = self.tokenizer.apply_chat_template(
641
+ messages,
642
+ tokenize=False,
643
+ add_generation_prompt=True
644
+ )
645
+ except Exception:
646
+ prompt = self._format_prompt_manual(truncated_context, question)
647
+ else:
648
+ prompt = self._format_prompt_manual(truncated_context, question)
649
+ else:
650
+ # If even basic prompt is too long, use minimal format
651
+ prompt = self._format_prompt_manual(context[:500] + "...", question)
652
+ except Exception as e:
653
+ logger.warning(f"Error truncating context: {e}, using base prompt")
654
+ prompt = base_prompt
655
+ else:
656
+ prompt = base_prompt
657
+
658
+ return prompt
659
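The truncation logic above boils down to: measure the whole prompt, and if it exceeds the budget, re-encode only the context with `max_length` set to whatever room is left. A stripped-down sketch of that idea (a hypothetical helper, not a function defined in app.py):

```python
def fit_context(tokenizer, context: str, overhead_tokens: int, budget: int = 1200) -> str:
    """Sketch: truncate context so that context + prompt overhead stays within the token budget."""
    available = budget - overhead_tokens
    if available <= 0:
        return context[:500] + "..."  # last-resort fallback, mirroring the code above
    ids = tokenizer(context, return_tensors="pt",
                    truncation=True, max_length=available)["input_ids"][0]
    return tokenizer.decode(ids, skip_special_tokens=True)
```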
+
660
+ def _format_prompt_manual(self, context: str, question: str) -> str:
661
+ """Manual prompt formatting for models without chat templates (e.g., Llama)"""
662
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
663
+
664
+ You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative.<|eot_id|><|start_header_id|>user<|end_header_id|>
665
+
666
+ Context: {context}
667
+
668
+ Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
669
+
670
+ """
671
+
672
+ def format_improved_prompt(self, context_chunks: List[Chunk], question: str) -> Tuple[str, str]:
673
+ """Format an improved prompt with better tone, structure, and medical appropriateness
674
+
675
+ Returns:
676
+ Tuple of (prompt, prompt_text) where prompt_text is the system prompt instructions
677
+ """
678
+ context_parts = []
679
+ for chunk in context_chunks:
680
+ context_parts.append(f"{chunk.text}")
681
+
682
+ context = "\n".join(context_parts)
683
+
684
+ # Improved prompt with all the feedback incorporated
685
+ improved_prompt_text = """Provide a concise, neutral, and informative answer based on the provided medical context.
686
+
687
+ CRITICAL GUIDELINES:
688
+ - Format your response as clear, well-structured sentences and paragraphs
689
+ - Be concise and direct - focus on answering the specific question asked
690
+ - Use neutral, factual language - do NOT tell the questioner how to feel (avoid phrases like 'don't worry', 'the good news is', etc.)
691
+ - Do NOT use leading or coercive language - present information neutrally to preserve patient autonomy
692
+ - Do NOT make specific medical recommendations - instead state that management decisions should be made with a healthcare provider
693
+ - Use third-person voice only - never claim to be a medical professional or assistant
694
+ - Use consistent terminology: use 'children' (not 'offspring') consistently
695
+ - Do NOT include hypothetical examples with specific names (e.g., avoid 'Aunt Jenna' or similar)
696
+ - Include important distinctions when relevant (e.g., somatic vs. germline variants, reproductive risks)
697
+ - When citing sources, be consistent - always specify which guidelines or sources when mentioned
698
+ - Remove any formatting markers like asterisks (*) or bold markers
699
+ - Do NOT include phrases like 'Here's a rewritten version' - just provide the answer directly
700
+
701
+ If the question asks about medical management, screening, or interventions, conclude with: 'Management recommendations are individualized and should be discussed with a healthcare provider or genetic counselor.'"""
702
+
703
+ # Try to use the tokenizer's chat template if available
704
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
705
+ try:
706
+ messages = [
707
+ {"role": "system", "content": improved_prompt_text},
708
+ {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
709
+ ]
710
+ base_prompt = self.tokenizer.apply_chat_template(
711
+ messages,
712
+ tokenize=False,
713
+ add_generation_prompt=True
714
+ )
715
+ except Exception as e:
716
+ logger.warning(f"Failed to use chat template for improved prompt, falling back to manual format: {e}")
717
+ base_prompt = self._format_improved_prompt_manual(context, question, improved_prompt_text)
718
+ else:
719
+ # Fall back to manual formatting (for Llama models)
720
+ base_prompt = self._format_improved_prompt_manual(context, question, improved_prompt_text)
721
+
722
+ # Check if prompt is too long and truncate context if needed
723
+ max_context_tokens = 1200 # Leave room for generation
724
+ try:
725
+ tokenized = self.tokenizer(base_prompt, return_tensors="pt")
726
+ current_tokens = tokenized['input_ids'].shape[1]
727
+ except Exception as e:
728
+ logger.warning(f"Tokenization error for improved prompt, using base prompt as-is: {e}")
729
+ return base_prompt, improved_prompt_text
730
+
731
+ if current_tokens > max_context_tokens:
732
+ # Truncate context to fit within limits
733
+ try:
734
+ context_tokens = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]
735
+ available_tokens = max_context_tokens - (current_tokens - context_tokens)
736
+
737
+ if available_tokens > 0:
738
+ # Truncate context to fit
739
+ truncated_context = self.tokenizer.decode(
740
+ self.tokenizer(context, return_tensors="pt", truncation=True, max_length=available_tokens)['input_ids'][0],
741
+ skip_special_tokens=True
742
+ )
743
+
744
+ # Reformat with truncated context
745
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
746
+ try:
747
+ messages = [
748
+ {"role": "system", "content": improved_prompt_text},
749
+ {"role": "user", "content": f"Context: {truncated_context}\n\nQuestion: {question}"}
750
+ ]
751
+ prompt = self.tokenizer.apply_chat_template(
752
+ messages,
753
+ tokenize=False,
754
+ add_generation_prompt=True
755
+ )
756
+ except Exception:
757
+ prompt = self._format_improved_prompt_manual(truncated_context, question, improved_prompt_text)
758
+ else:
759
+ prompt = self._format_improved_prompt_manual(truncated_context, question, improved_prompt_text)
760
+ else:
761
+ # If even basic prompt is too long, use minimal format
762
+ prompt = self._format_improved_prompt_manual(context[:500] + "...", question, improved_prompt_text)
763
+ except Exception as e:
764
+ logger.warning(f"Error truncating context for improved prompt: {e}, using base prompt")
765
+ prompt = base_prompt
766
+ else:
767
+ prompt = base_prompt
768
+
769
+ return prompt, improved_prompt_text
770
+
771
+ def _format_improved_prompt_manual(self, context: str, question: str, improved_prompt_text: str) -> str:
772
+ """Manual prompt formatting for improved prompts (for models without chat templates)"""
773
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
774
+
775
+ {improved_prompt_text}<|eot_id|><|start_header_id|>user<|end_header_id|>
776
+
777
+ Context: {context}
778
+
779
+ Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
780
+
781
+ """
782
+
783
+ def generate_answer(self, prompt: str, **gen_kwargs) -> str:
784
+ """Generate answer using the language model"""
785
+ try:
786
+ if self.args.verbose:
787
+ logger.info(f"Full prompt (first 500 chars): {prompt[:500]}...")
788
+
789
+ # Tokenize input with more conservative limit to leave room for generation
790
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1500)
791
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
792
+
793
+ if self.args.verbose:
794
+ logger.info(f"Input tokens: {inputs['input_ids'].shape}")
795
+
796
+ # Generate
797
+ with torch.no_grad():
798
+ outputs = self.model.generate(
799
+ **inputs,
800
+ max_new_tokens=gen_kwargs.get('max_new_tokens', 512),
801
+ temperature=gen_kwargs.get('temperature', 0.7),
802
+ top_p=gen_kwargs.get('top_p', 0.95),
803
+ repetition_penalty=gen_kwargs.get('repetition_penalty', 1.05),
804
+ do_sample=True,
805
+ pad_token_id=self.tokenizer.eos_token_id,
806
+ eos_token_id=self.tokenizer.eos_token_id,
807
+ use_cache=True,
808
+ num_beams=1
809
+ )
810
+
811
+ # Decode response without skipping special tokens to preserve full length
812
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
813
+
814
+ if self.args.verbose:
815
+ logger.info(f"Full response (first 1000 chars): {response[:1000]}...")
816
+ logger.info(f"Looking for 'Answer:' in response: {'Answer:' in response}")
817
+ if "Answer:" in response:
818
+ answer_part = response.split("Answer:")[-1]
819
+ logger.info(f"Answer part (first 200 chars): {answer_part[:200]}...")
820
+
821
+ # Debug: Show the full response to understand the structure
822
+ logger.info(f"Full response length: {len(response)}")
823
+ logger.info(f"Prompt length: {len(prompt)}")
824
+ logger.info(f"Response after prompt (first 500 chars): {response[len(prompt):][:500]}...")
825
+
826
+ # Extract the answer more robustly by looking for the end of the prompt
827
+ # Find the actual end of the prompt in the response
828
+ prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>\n\n"
829
+ if prompt_end_marker in response:
830
+ answer = response.split(prompt_end_marker)[-1].strip()
831
+ else:
832
+ # Fallback to character-based extraction
833
+ answer = response[len(prompt):].strip()
834
+
835
+ if self.args.verbose:
836
+ logger.info(f"Full LLM output (first 200 chars): {answer[:200]}...")
837
+ logger.info(f"Full LLM output length: {len(answer)} characters")
838
+ logger.info(f"Full LLM output (last 200 chars): ...{answer[-200:]}")
839
+
840
+ # Only do minimal cleanup to preserve the full response
841
+ # Remove special tokens that might interfere with display, but preserve content
842
+ if "<|start_header_id|>" in answer:
843
+ # Only remove if it's at the very end
844
+ if answer.endswith("<|start_header_id|>"):
845
+ answer = answer[:-len("<|start_header_id|>")].strip()
846
+ if "<|eot_id|>" in answer:
847
+ # Only remove if it's at the very end
848
+ if answer.endswith("<|eot_id|>"):
849
+ answer = answer[:-len("<|eot_id|>")].strip()
850
+ if "<|end_of_text|>" in answer:
851
+ # Only remove if it's at the very end
852
+ if answer.endswith("<|end_of_text|>"):
853
+ answer = answer[:-len("<|end_of_text|>")].strip()
854
+
855
+ # Final validation - only reject if completely empty
856
+ if not answer or len(answer) < 3:
857
+ answer = "I don't know."
858
+
859
+ if self.args.verbose:
860
+ logger.info(f"Final answer: '{answer}'")
861
+
862
+ return answer
863
+
864
+ except Exception as e:
865
+ logger.error(f"Generation error: {e}")
866
+ return "I encountered an error while generating the answer."
867
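End to end, answering one question is retrieve, format, generate. A condensed sketch, assuming `bot` is an initialized `RAGBot`:

```python
question = "What does an MSH2 variant mean for my family?"
chunks, scores = bot.retrieve_with_scores(question, k=5)
if chunks:
    prompt = bot.format_prompt(chunks, question)
    answer = bot.generate_answer(prompt, max_new_tokens=512, temperature=0.2, top_p=0.9)
    print(answer)
    print("Sources:", bot._extract_sources(chunks))
    print("Scores:", ", ".join(f"{s:.3f}" for s in scores))
else:
    print("I don't know.")
```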
+
868
+ def process_questions(self, questions_path: str, **kwargs) -> List[Tuple[str, str, str, str, float, str, float, str, float, str, str]]:
869
+ """Process all questions and generate answers with multiple readability levels
870
+
871
+ Returns:
872
+ List of tuples: (question, answer, sources, question_group, original_flesch,
873
+ middle_school_answer, middle_school_flesch,
874
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores)
875
+ """
876
+ logger.info(f"Processing questions from {questions_path}")
877
+
878
+ # Load questions
879
+ try:
880
+ with open(questions_path, 'r', encoding='utf-8') as f:
881
+ questions = [line.strip() for line in f if line.strip()]
882
+ except Exception as e:
883
+ logger.error(f"Failed to load questions: {e}")
884
+ sys.exit(1)
885
+
886
+ logger.info(f"Found {len(questions)} questions to process")
887
+
888
+ qa_pairs = []
889
+
890
+ # Get the improved prompt text for CSV header by calling format_improved_prompt with empty chunks
891
+ # This will give us the prompt text without actually generating
892
+ _, improved_prompt_text = self.format_improved_prompt([], "")
893
+
894
+ # Initialize CSV file with headers
895
+ self.write_csv([], kwargs.get('output_file', 'results.csv'), append=False, improved_prompt_text=improved_prompt_text)
896
+
897
+ # Process each question
898
+ for i, question in enumerate(tqdm(questions, desc="Processing questions")):
899
+ logger.info(f"Question {i+1}/{len(questions)}: {question[:50]}...")
900
+
901
+ try:
902
+ # Categorize question
903
+ question_group = self._categorize_question(question)
904
+
905
+ # Retrieve relevant chunks with similarity scores
906
+ context_chunks, similarity_scores = self.retrieve_with_scores(question, self.args.k)
907
+
908
+ # Format similarity scores as a string (comma-separated, 3 decimal places)
909
+ similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores]) if similarity_scores else "0.000"
910
+
911
+ if not context_chunks:
912
+ answer = "I don't know."
913
+ sources = "No sources found"
914
+ middle_school_answer = "I don't know."
915
+ high_school_answer = "I don't know."
916
+ improved_answer = "I don't know."
917
+ original_flesch = 0.0
918
+ middle_school_flesch = 0.0
919
+ high_school_flesch = 0.0
920
+ similarity_scores_str = "0.000"
921
+ else:
922
+ # Format original prompt
923
+ prompt = self.format_prompt(context_chunks, question)
924
+
925
+ # Generate original answer
926
+ start_time = time.time()
927
+ answer = self.generate_answer(prompt, **kwargs)
928
+ gen_time = time.time() - start_time
929
+
930
+ # Generate improved answer
931
+ improved_prompt, _ = self.format_improved_prompt(context_chunks, question)
932
+ improved_start = time.time()
933
+ improved_answer = self.generate_answer(improved_prompt, **kwargs)
934
+ improved_time = time.time() - improved_start
935
+
936
+ # Clean up improved answer - remove unwanted phrases and formatting
937
+ improved_answer = self._clean_improved_answer(improved_answer)
938
+ logger.info(f"Improved answer generated in {improved_time:.2f}s")
939
+
940
+ # Extract source documents
941
+ sources = self._extract_sources(context_chunks)
942
+
943
+ # Calculate original answer Flesch score
944
+ try:
945
+ original_flesch = textstat.flesch_kincaid_grade(answer)
946
+ except Exception:
947
+ original_flesch = 0.0
948
+
949
+ # Generate middle school version
950
+ readability_start = time.time()
951
+ middle_school_answer, middle_school_flesch = self.enhance_readability(answer, "middle_school")
952
+ readability_time = time.time() - readability_start
953
+ logger.info(f"Middle school readability in {readability_time:.2f}s")
954
+
955
+ # Generate high school version
956
+ readability_start = time.time()
957
+ high_school_answer, high_school_flesch = self.enhance_readability(answer, "high_school")
958
+ readability_time = time.time() - readability_start
959
+ logger.info(f"High school readability in {readability_time:.2f}s")
960
+
961
+ logger.info(f"Generated answer in {gen_time:.2f}s")
962
+ logger.info(f"Sources: {sources}")
963
+ logger.info(f"Similarity scores: {similarity_scores_str}")
964
+ logger.info(f"Original Flesch: {original_flesch:.1f}, Middle School: {middle_school_flesch:.1f}, High School: {high_school_flesch:.1f}")
965
+
966
+ qa_pairs.append((question, answer, sources, question_group, original_flesch,
967
+ middle_school_answer, middle_school_flesch,
968
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str))
969
+
970
+ # Write incrementally to CSV after each question
971
+ self.write_csv([(question, answer, sources, question_group, original_flesch,
972
+ middle_school_answer, middle_school_flesch,
973
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str)],
974
+ kwargs.get('output_file', 'results.csv'), append=True, improved_prompt_text=improved_prompt_text)
975
+ logger.info(f"Progress saved: {i+1}/{len(questions)} questions completed")
976
+
977
+ except Exception as e:
978
+ logger.error(f"Error processing question {i+1}: {e}")
979
+ error_answer = "I encountered an error processing this question."
980
+ sources = "Error retrieving sources"
981
+ question_group = self._categorize_question(question)
982
+ original_flesch = 0.0
983
+ middle_school_answer = "I encountered an error processing this question."
984
+ high_school_answer = "I encountered an error processing this question."
985
+ improved_answer = "I encountered an error processing this question."
986
+ middle_school_flesch = 0.0
987
+ high_school_flesch = 0.0
988
+ similarity_scores_str = "0.000"
989
+ qa_pairs.append((question, error_answer, sources, question_group, original_flesch,
990
+ middle_school_answer, middle_school_flesch,
991
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str))
992
+
993
+ # Still write the error to CSV
994
+ self.write_csv([(question, error_answer, sources, question_group, original_flesch,
995
+ middle_school_answer, middle_school_flesch,
996
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str)],
997
+ kwargs.get('output_file', 'results.csv'), append=True, improved_prompt_text=improved_prompt_text)
998
+ logger.info(f"Error saved: {i+1}/{len(questions)} questions completed")
999
+
1000
+ return qa_pairs
1001
+
1002
+ def _clean_readability_answer(self, answer: str, target_level: str) -> str:
1003
+ """Clean up readability-enhanced answers to remove unwanted phrases and formatting
1004
+
1005
+ Args:
1006
+ answer: The readability-enhanced answer
1007
+ target_level: Either "middle_school" or "high_school"
1008
+ """
1009
+ cleaned = answer
1010
+
1011
+ # Remove the "Here's a rewritten version" phrases
1012
+ if target_level == "middle_school":
1013
+ unwanted_phrases = [
1014
+ "Here's a rewritten version of the text at a middle school reading level:",
1015
+ "Here's a rewritten version of the text at a middle school reading level",
1016
+ "Here is a rewritten version of the text at a middle school reading level:",
1017
+ "Here is a rewritten version of the text at a middle school reading level",
1018
+ "Here's a rewritten version at a middle school reading level:",
1019
+ "Here's a rewritten version at a middle school reading level",
1020
+ ]
1021
+ elif target_level == "high_school":
1022
+ unwanted_phrases = [
1023
+ "Here's a rewritten version of the text at a high school reading level",
1024
+ "Here's a rewritten version of the text at a high school reading level:",
1025
+ "Here is a rewritten version of the text at a high school reading level",
1026
+ "Here is a rewritten version of the text at a high school reading level:",
1027
+ "Here's a rewritten version at a high school reading level",
1028
+ "Here's a rewritten version at a high school reading level:",
1029
+ ]
1030
+ else:
1031
+ unwanted_phrases = []
1032
+
1033
+ for phrase in unwanted_phrases:
1034
+ if phrase.lower() in cleaned.lower():
1035
+ # Find and remove the phrase (case-insensitive)
1036
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
1037
+ cleaned = pattern.sub("", cleaned).strip()
1038
+ # Remove leading colons, semicolons, or dashes
1039
+ cleaned = re.sub(r'^[:;\-]\s*', '', cleaned).strip()
1040
+
1041
+ # Remove asterisks (but preserve bullet points if they use •)
1042
+ cleaned = re.sub(r'\*\*', '', cleaned) # Remove bold markers
1043
+ cleaned = re.sub(r'\(\*\)', '', cleaned) # Remove (*)
1044
+ cleaned = re.sub(r'\*', '', cleaned) # Remove remaining asterisks
1045
+
1046
+ # Clean up extra whitespace
1047
+ cleaned = ' '.join(cleaned.split())
1048
+
1049
+ return cleaned
1050
+
1051
+ def _clean_improved_answer(self, answer: str) -> str:
1052
+ """Clean up improved answer to remove unwanted phrases and formatting"""
1053
+ # Remove phrases like "Here's a rewritten version" or similar
1054
+ unwanted_phrases = [
1055
+ "Here's a rewritten version",
1056
+ "Here's a version",
1057
+ "Here is a rewritten version",
1058
+ "Here is a version",
1059
+ "Here's the answer",
1060
+ "Here is the answer"
1061
+ ]
1062
+
1063
+ cleaned = answer
1064
+ for phrase in unwanted_phrases:
1065
+ if phrase.lower() in cleaned.lower():
1066
+ # Find and remove the phrase and any following colon/semicolon
1067
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
1068
+ cleaned = pattern.sub("", cleaned).strip()
1069
+ # Remove leading colons, semicolons, or dashes
1070
+ cleaned = re.sub(r'^[:;\-]\s*', '', cleaned).strip()
1071
+
1072
+ # Remove formatting markers like (*) or ** but preserve bullet points
1073
+ cleaned = re.sub(r'\*\*', '', cleaned) # Remove bold markers
1074
+ cleaned = re.sub(r'\(\*\)', '', cleaned) # Remove (*)
1075
+ # Note: Single asterisks are left alone as they might be used for formatting
1076
+ # (the improved prompt already tells the model to avoid asterisk formatting, so stray single asterisks should be rare)
1077
+
1078
+ # Remove "Don't worry" and similar emotional management phrases
1079
+ emotional_phrases = [
1080
+ r"don't worry[^.]*\.\s*",
1081
+ r"Don't worry[^.]*\.\s*",
1082
+ r"the good news is[^.]*\.\s*",
1083
+ r"The good news is[^.]*\.\s*",
1084
+ ]
1085
+ for pattern in emotional_phrases:
1086
+ cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)
1087
+
1088
+ # Clean up extra whitespace
1089
+ cleaned = ' '.join(cleaned.split())
1090
+
1091
+ return cleaned
1092
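A quick before/after sketch of the cleanup above, assuming `bot` is an initialized `RAGBot`:

```python
raw = "Here's a rewritten version: **Cascade testing** is offered to first-degree relatives."
print(bot._clean_improved_answer(raw))
# expected: "Cascade testing is offered to first-degree relatives."
```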
+
1093
+ def diagnose_system(self, sample_questions: List[str] = None) -> Dict[str, Any]:
1094
+ """Diagnose the document loading, chunking, and retrieval system
1095
+
1096
+ Args:
1097
+ sample_questions: Optional list of questions to test retrieval
1098
+
1099
+ Returns:
1100
+ Dictionary with diagnostic information
1101
+ """
1102
+ diagnostics = {
1103
+ 'vector_db_stats': {},
1104
+ 'document_stats': {},
1105
+ 'chunk_stats': {},
1106
+ 'retrieval_tests': []
1107
+ }
1108
+
1109
+ # Check vector database
1110
+ try:
1111
+ stats = self.vector_retriever.get_collection_stats()
1112
+ diagnostics['vector_db_stats'] = {
1113
+ 'total_chunks': stats.get('total_chunks', 0),
1114
+ 'collection_name': stats.get('collection_name', 'unknown'),
1115
+ 'status': 'OK' if stats.get('total_chunks', 0) > 0 else 'EMPTY'
1116
+ }
1117
+ except Exception as e:
1118
+ diagnostics['vector_db_stats'] = {
1119
+ 'status': 'ERROR',
1120
+ 'error': str(e)
1121
+ }
1122
+
1123
+ # Test document loading (without actually loading)
1124
+ try:
1125
+ data_path = Path(self.args.data_dir)
1126
+ if data_path.exists():
1127
+ supported_extensions = {'.txt', '.md', '.json', '.csv'}
1128
+ if PDF_AVAILABLE:
1129
+ supported_extensions.add('.pdf')
1130
+ if DOCX_AVAILABLE:
1131
+ supported_extensions.add('.docx')
1132
+ supported_extensions.add('.doc')
1133
+
1134
+ files = []
1135
+ for ext in supported_extensions:
1136
+ files.extend(data_path.rglob(f"*{ext}"))
1137
+
1138
+ # Sample a few files to check content
1139
+ sample_files = files[:5] if len(files) > 5 else files
1140
+ file_samples = []
1141
+ for file_path in sample_files:
1142
+ try:
1143
+ content = self._read_file(file_path)
1144
+ file_samples.append({
1145
+ 'filename': file_path.name,
1146
+ 'size_chars': len(content),
1147
+ 'size_words': len(content.split()),
1148
+ 'readable': True
1149
+ })
1150
+ except Exception as e:
1151
+ file_samples.append({
1152
+ 'filename': file_path.name,
1153
+ 'readable': False,
1154
+ 'error': str(e)
1155
+ })
1156
+
1157
+ diagnostics['document_stats'] = {
1158
+ 'total_files_found': len(files),
1159
+ 'sample_files': file_samples,
1160
+ 'status': 'OK'
1161
+ }
1162
+ else:
1163
+ diagnostics['document_stats'] = {
1164
+ 'status': 'ERROR',
1165
+ 'error': f'Data directory {self.args.data_dir} does not exist'
1166
+ }
1167
+ except Exception as e:
1168
+ diagnostics['document_stats'] = {
1169
+ 'status': 'ERROR',
1170
+ 'error': str(e)
1171
+ }
1172
+
1173
+ # Test chunking on a sample document
1174
+ try:
1175
+ if diagnostics['document_stats'].get('status') == 'OK':
1176
+ sample_file = None
1177
+ for file_info in diagnostics['document_stats'].get('sample_files', []):
1178
+ if file_info.get('readable', False):
1179
+ # Find the actual file
1180
+ data_path = Path(self.args.data_dir)
1181
+ for ext in ['.txt', '.md', '.pdf', '.docx']:
1182
+ files = list(data_path.rglob(f"*{file_info['filename']}"))
1183
+ if files:
1184
+ sample_file = files[0]
1185
+ break
1186
+ if sample_file:
1187
+ break
1188
+
1189
+ if sample_file:
1190
+ content = self._read_file(sample_file)
1191
+ # Create a dummy document (Document is already imported at top)
1192
+ sample_doc = Document(
1193
+ filename=sample_file.name,
1194
+ content=content,
1195
+ filepath=str(sample_file),
1196
+ file_type=sample_file.suffix.lower(),
1197
+ file_hash=""
1198
+ )
1199
+
1200
+ # Test chunking
1201
+ sample_chunks = self._chunk_text(
1202
+ content,
1203
+ sample_file.name,
1204
+ self.args.chunk_size,
1205
+ self.args.chunk_overlap
1206
+ )
1207
+
1208
+ chunk_lengths = [len(chunk.text.split()) for chunk in sample_chunks]
1209
+
1210
+ diagnostics['chunk_stats'] = {
1211
+ 'sample_document': sample_file.name,
1212
+ 'total_chunks': len(sample_chunks),
1213
+ 'avg_chunk_size_words': sum(chunk_lengths) / len(chunk_lengths) if chunk_lengths else 0,
1214
+ 'min_chunk_size_words': min(chunk_lengths) if chunk_lengths else 0,
1215
+ 'max_chunk_size_words': max(chunk_lengths) if chunk_lengths else 0,
1216
+ 'chunk_size_setting': self.args.chunk_size,
1217
+ 'chunk_overlap_setting': self.args.chunk_overlap,
1218
+ 'status': 'OK'
1219
+ }
1220
+ except Exception as e:
1221
+ diagnostics['chunk_stats'] = {
1222
+ 'status': 'ERROR',
1223
+ 'error': str(e)
1224
+ }
1225
+
1226
+ # Test retrieval with sample questions
1227
+ if sample_questions and diagnostics['vector_db_stats'].get('status') == 'OK':
1228
+ for question in sample_questions:
1229
+ try:
1230
+ context_chunks = self.retrieve(question, self.args.k)
1231
+ sources = self._extract_sources(context_chunks)
1232
+
1233
+ # Get similarity scores
1234
+ results = self.vector_retriever.search(question, self.args.k)
1235
+
1236
+ # Get sample chunk text (first 200 chars of first chunk)
1237
+ sample_chunk_text = context_chunks[0].text[:200] + "..." if context_chunks else "N/A"
1238
+
1239
+ diagnostics['retrieval_tests'].append({
1240
+ 'question': question,
1241
+ 'chunks_retrieved': len(context_chunks),
1242
+ 'sources': sources,
1243
+ 'similarity_scores': [f"{score:.3f}" for _, score in results],
1244
+ 'sample_chunk_preview': sample_chunk_text,
1245
+ 'status': 'OK' if context_chunks else 'NO_RESULTS'
1246
+ })
1247
+ except Exception as e:
1248
+ diagnostics['retrieval_tests'].append({
1249
+ 'question': question,
1250
+ 'status': 'ERROR',
1251
+ 'error': str(e)
1252
+ })
1253
+
1254
+ return diagnostics
1255
+
1256
+ def print_diagnostics(self, diagnostics: Dict[str, Any]) -> None:
1257
+ """Print diagnostic information in a readable format"""
1258
+ print("\n" + "="*80)
1259
+ print("SYSTEM DIAGNOSTICS")
1260
+ print("="*80)
1261
+
1262
+ # Vector DB Stats
1263
+ print("\n📊 VECTOR DATABASE:")
1264
+ vdb = diagnostics.get('vector_db_stats', {})
1265
+ print(f" Status: {vdb.get('status', 'UNKNOWN')}")
1266
+ print(f" Total chunks: {vdb.get('total_chunks', 0)}")
1267
+ print(f" Collection: {vdb.get('collection_name', 'unknown')}")
1268
+ if 'error' in vdb:
1269
+ print(f" Error: {vdb['error']}")
1270
+
1271
+ # Document Stats
1272
+ print("\n📄 DOCUMENT LOADING:")
1273
+ doc_stats = diagnostics.get('document_stats', {})
1274
+ print(f" Status: {doc_stats.get('status', 'UNKNOWN')}")
1275
+ print(f" Total files found: {doc_stats.get('total_files_found', 0)}")
1276
+ if 'sample_files' in doc_stats:
1277
+ print(f" Sample files:")
1278
+ for file_info in doc_stats['sample_files']:
1279
+ if file_info.get('readable', False):
1280
+ print(f" ✓ {file_info['filename']}: {file_info.get('size_chars', 0):,} chars, {file_info.get('size_words', 0):,} words")
1281
+ else:
1282
+ print(f" ✗ {file_info['filename']}: {file_info.get('error', 'unreadable')}")
1283
+ if 'error' in doc_stats:
1284
+ print(f" Error: {doc_stats['error']}")
1285
+
1286
+ # Chunk Stats
1287
+ print("\n✂️ CHUNKING:")
1288
+ chunk_stats = diagnostics.get('chunk_stats', {})
1289
+ print(f" Status: {chunk_stats.get('status', 'UNKNOWN')}")
1290
+ if chunk_stats.get('status') == 'OK':
1291
+ print(f" Sample document: {chunk_stats.get('sample_document', 'N/A')}")
1292
+ print(f" Total chunks from sample: {chunk_stats.get('total_chunks', 0)}")
1293
+ print(f" Average chunk size: {chunk_stats.get('avg_chunk_size_words', 0):.1f} words")
1294
+ print(f" Chunk size range: {chunk_stats.get('min_chunk_size_words', 0)} - {chunk_stats.get('max_chunk_size_words', 0)} words")
1295
+ print(f" Settings: size={chunk_stats.get('chunk_size_setting', 0)}, overlap={chunk_stats.get('chunk_overlap_setting', 0)}")
1296
+ if 'error' in chunk_stats:
1297
+ print(f" Error: {chunk_stats['error']}")
1298
+
1299
+ # Retrieval Tests
1300
+ if diagnostics.get('retrieval_tests'):
1301
+ print("\n🔍 RETRIEVAL TESTS:")
1302
+ for test in diagnostics['retrieval_tests']:
1303
+ print(f"\n Question: {test.get('question', 'N/A')}")
1304
+ print(f" Status: {test.get('status', 'UNKNOWN')}")
1305
+ if test.get('status') == 'OK':
1306
+ print(f" Chunks retrieved: {test.get('chunks_retrieved', 0)}")
1307
+ print(f" Sources: {test.get('sources', 'N/A')}")
1308
+ scores = test.get('similarity_scores', [])
1309
+ if scores:
1310
+ print(f" Similarity scores: {', '.join(scores)}")
1311
+ # Warn if scores are low
1312
+ try:
1313
+ score_values = [float(s) for s in scores]
1314
+ if max(score_values) < 0.3:
1315
+ print(f" ⚠️ WARNING: Low similarity scores - retrieved chunks may not be very relevant")
1316
+ elif max(score_values) < 0.5:
1317
+ print(f" ⚠️ NOTE: Moderate similarity - consider increasing --k or checking chunk quality")
1318
+ except Exception:
1319
+ pass
1320
+ if 'sample_chunk_preview' in test:
1321
+ print(f" Sample chunk preview: {test['sample_chunk_preview']}")
1322
+ elif 'error' in test:
1323
+ print(f" Error: {test['error']}")
1324
+
1325
+ print("\n" + "="*80 + "\n")
1326
+
1327
+ def _extract_sources(self, context_chunks: List[Chunk]) -> str:
1328
+ """Extract source document names from context chunks"""
1329
+ sources = []
1330
+ for chunk in context_chunks:
1331
+ # Debug: Print chunk filename if verbose
1332
+ if self.args.verbose:
1333
+ logger.info(f"Chunk filename: {chunk.filename}")
1334
+
1335
+ # Extract filename from chunk attribute (not metadata)
1336
+ source = chunk.filename if hasattr(chunk, 'filename') and chunk.filename else 'Unknown source'
1337
+ # Clean up the source name
1338
+ if source.endswith('.pdf'):
1339
+ source = source[:-4] # Remove .pdf extension
1340
+ elif source.endswith('.txt'):
1341
+ source = source[:-4] # Remove .txt extension
1342
+ elif source.endswith('.md'):
1343
+ source = source[:-3] # Remove .md extension
1344
+
1345
+ sources.append(source)
1346
+
1347
+ # Remove duplicates while preserving order
1348
+ unique_sources = []
1349
+ for source in sources:
1350
+ if source not in unique_sources:
1351
+ unique_sources.append(source)
1352
+
1353
+ return "; ".join(unique_sources)
1354
+
1355
+ def _categorize_question(self, question: str) -> str:
1356
+ """Categorize a question into one of 5 categories"""
1357
+ question_lower = question.lower()
1358
+
1359
+ # Gene-Specific Recommendations
1360
+ if any(gene in question_lower for gene in ['msh2', 'mlh1', 'msh6', 'pms2', 'epcam', 'brca1', 'brca2']):
1361
+ if any(kw in question_lower for kw in ['screening', 'surveillance', 'prevention', 'recommendation', 'risk', 'cancer risk', 'steps', 'management']):
1362
+ return "Gene-Specific Recommendations"
1363
+
1364
+ # Inheritance Patterns
1365
+ if any(kw in question_lower for kw in ['inherit', 'inherited', 'pass', 'skip a generation', 'generation', 'can i pass']):
1366
+ return "Inheritance Patterns"
1367
+
1368
+ # Family Risk Assessment
1369
+ if any(kw in question_lower for kw in ['family member', 'relative', 'first-degree', 'family risk', 'which relative', 'should my family']):
1370
+ return "Family Risk Assessment"
1371
+
1372
+ # Genetic Variant Interpretation
1373
+ if any(kw in question_lower for kw in ['what does', 'genetic variant mean', 'variant mean', 'mutation mean', 'genetic result']):
1374
+ return "Genetic Variant Interpretation"
1375
+
1376
+ # Support and Resources
1377
+ if any(kw in question_lower for kw in ['cope', 'overwhelmed', 'resource', 'genetic counselor', 'support', 'research', 'help', 'insurance', 'gina']):
1378
+ return "Support and Resources"
1379
+
1380
+ # Default to Genetic Variant Interpretation if unclear
1381
+ return "Genetic Variant Interpretation"
1382
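A quick sanity-check sketch of the keyword routing above, assuming `bot` is an initialized `RAGBot`:

```python
for q in [
    "What screening is recommended for MSH2 carriers?",  # Gene-Specific Recommendations
    "Can Lynch syndrome skip a generation?",              # Inheritance Patterns
    "Which relatives should be tested first?",            # Family Risk Assessment
    "What does my genetic variant mean?",                 # Genetic Variant Interpretation
    "Where can I find a genetic counselor?",              # Support and Resources
]:
    print(bot._categorize_question(q), "|", q)
```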
+
1383
+ def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
1384
+ """Enhance answer readability to different levels and calculate Flesch-Kincaid Grade Level
1385
+
1386
+ Args:
1387
+ answer: The original answer to simplify or enhance
1388
+ target_level: One of "middle_school", "high_school", "college", or "doctoral"
1389
+
1390
+ Returns:
1391
+ Tuple of (enhanced_answer, grade_level)
1392
+ """
1393
+ try:
1394
+ # Define prompts for different reading levels
1395
+ if target_level == "middle_school":
1396
+ level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
1397
+ instructions = """
1398
+ - Use simpler medical terms or explain them
1399
+ - Medium-length sentences
1400
+ - Clear, structured explanations
1401
+ - Keep important medical information accessible"""
1402
+ elif target_level == "high_school":
1403
+ level_description = "high school reading level (ages 15-18, 9th-12th grade)"
1404
+ instructions = """
1405
+ - Use appropriate medical terminology with context
1406
+ - Varied sentence length
1407
+ - Comprehensive yet accessible explanations
1408
+ - Maintain technical accuracy while ensuring clarity"""
1409
+ elif target_level == "college":
1410
+ level_description = "college reading level (undergraduate level, ages 18-22)"
1411
+ instructions = """
1412
+ - Use standard medical terminology with brief explanations
1413
+ - Professional and clear writing style
1414
+ - Include relevant clinical context
1415
+ - Maintain scientific accuracy and precision
1416
+ - Appropriate for undergraduate students in health sciences"""
1417
+ elif target_level == "doctoral":
1418
+ level_description = "doctoral/professional reading level (graduate level, medical professionals)"
1419
+ instructions = """
1420
+ - Use advanced medical and scientific terminology
1421
+ - Include detailed clinical and research context
1422
+ - Reference specific mechanisms, pathways, and evidence
1423
+ - Provide comprehensive technical explanations
1424
+ - Appropriate for medical professionals, researchers, and graduate students
1425
+ - Include nuanced discussions of clinical implications and research findings"""
1426
+ else:
1427
+ raise ValueError(f"Unknown target_level: {target_level}. Must be one of: middle_school, high_school, college, doctoral")
1428
+
1429
+ # Create a prompt to enhance the medical answer for the target level
1430
+ # Try to use chat template if available, otherwise use manual format
1431
+ system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
1432
+ {instructions}
1433
+ - Keep the same important information but adapt the complexity
1434
+ - Provide context for technical terms
1435
+ - Ensure the answer is informative yet understandable"""
1436
+
1437
+ user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
1438
+
1439
+ # Try to use chat template if available
1440
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
1441
+ try:
1442
+ messages = [
1443
+ {"role": "system", "content": system_message},
1444
+ {"role": "user", "content": user_message}
1445
+ ]
1446
+ readability_prompt = self.tokenizer.apply_chat_template(
1447
+ messages,
1448
+ tokenize=False,
1449
+ add_generation_prompt=True
1450
+ )
1451
+ except Exception as e:
1452
+ logger.warning(f"Failed to use chat template for readability, falling back to manual format: {e}")
1453
+ readability_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1454
+
1455
+ {system_message}
1456
+
1457
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
1458
+
1459
+ {user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1460
+
1461
+ """
1462
+ else:
1463
+ # Fall back to manual formatting (for Llama models)
1464
+ readability_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1465
+
1466
+ {system_message}
1467
+
1468
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
1469
+
1470
+ {user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1471
+
1472
+ """
1473
+
1474
+ # Generate simplified answer
1475
+ inputs = self.tokenizer(readability_prompt, return_tensors="pt", truncation=True, max_length=2048)
1476
+ if self.device == "mps":
1477
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
1478
+
1479
+ # Adjust generation parameters based on target level
1480
+ if target_level in ["college", "doctoral"]:
1481
+ max_tokens = 512 # Reduced from 1024 for faster responses
1482
+ temp = 0.4 # Slightly higher temperature for more natural flow
1483
+ else:
1484
+ max_tokens = 384 # Reduced from 512 for faster responses
1485
+ temp = 0.3 # Lower temperature for more consistent simplification
1486
+
1487
+ with torch.no_grad():
1488
+ outputs = self.model.generate(
1489
+ **inputs,
1490
+ max_new_tokens=max_tokens,
1491
+ temperature=temp,
1492
+ top_p=0.9,
1493
+ repetition_penalty=1.05,
1494
+ do_sample=True,
1495
+ pad_token_id=self.tokenizer.eos_token_id,
1496
+ eos_token_id=self.tokenizer.eos_token_id,
1497
+ use_cache=True,
1498
+ num_beams=1
1499
+ )
1500
+
1501
+ # Decode response
1502
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
1503
+
1504
+ # Extract enhanced answer
1505
+ # Try to find the assistant response marker
1506
+ prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>\n\n"
1507
+ if prompt_end_marker in response:
1508
+ simplified_answer = response.split(prompt_end_marker)[-1].strip()
1509
+ elif "<|assistant|>" in response:
1510
+ # Some chat templates use <|assistant|>
1511
+ simplified_answer = response.split("<|assistant|>")[-1].strip()
1512
+ else:
1513
+ # Fallback: extract everything after the prompt
1514
+ simplified_answer = response[len(readability_prompt):].strip()
1515
+
1516
+ # Clean up special tokens
1517
+ if "<|eot_id|>" in simplified_answer:
1518
+ if simplified_answer.endswith("<|eot_id|>"):
1519
+ simplified_answer = simplified_answer[:-len("<|eot_id|>")].strip()
1520
+ if "<|end_of_text|>" in simplified_answer:
1521
+ if simplified_answer.endswith("<|end_of_text|>"):
1522
+ simplified_answer = simplified_answer[:-len("<|end_of_text|>")].strip()
1523
+
1524
+ # Clean up unwanted phrases and formatting
1525
+ simplified_answer = self._clean_readability_answer(simplified_answer, target_level)
1526
+
1527
+ # Calculate Flesch-Kincaid Grade Level
1528
+ try:
1529
+ grade_level = textstat.flesch_kincaid_grade(simplified_answer)
1530
+ except Exception:
1531
+ grade_level = 0.0
1532
+
1533
+ if self.args.verbose:
1534
+ logger.info(f"Simplified answer length: {len(simplified_answer)} characters")
1535
+ logger.info(f"Flesch-Kincaid Grade Level: {grade_level:.1f}")
1536
+
1537
+ return simplified_answer, grade_level
1538
+
1539
+ except Exception as e:
1540
+ logger.error(f"Error enhancing readability: {e}")
1541
+ # Fallback: return original answer with estimated grade level
1542
+ try:
1543
+ grade_level = textstat.flesch_kincaid_grade(answer)
1544
+ except Exception:
1545
+ grade_level = 12.0 # Default to high school level
1546
+ return answer, grade_level
1547
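The grade-level figures reported throughout come straight from `textstat`; for reference, a small sketch:

```python
import textstat

text = ("A pathogenic MSH2 variant increases the lifetime risk of colorectal and "
        "endometrial cancer, so earlier and more frequent screening is typically discussed.")
print(textstat.flesch_kincaid_grade(text))  # U.S. grade level (higher = harder to read)
print(textstat.flesch_reading_ease(text))   # reading-ease score (higher = easier to read)
```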
+
1548
+ def write_csv(self, qa_pairs: List[Tuple[str, str, str, str, float, str, float, str, float, str, str]], output_path: str, append: bool = False, improved_prompt_text: str = "") -> None:
1549
+ """Write Q&A pairs to CSV file in results folder
1550
+
1551
+ Expected tuple format: (question, answer, sources, question_group, original_flesch,
1552
+ middle_school_answer, middle_school_flesch,
1553
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores)
1554
+ """
1555
+ # Ensure results directory exists
1556
+ os.makedirs('results', exist_ok=True)
1557
+
1558
+ # If output_path doesn't already have results/ prefix, add it
1559
+ if not output_path.startswith('results/'):
1560
+ output_path = f'results/{output_path}'
1561
+
1562
+ if append:
1563
+ logger.info(f"Appending results to {output_path}")
1564
+ else:
1565
+ logger.info(f"Writing results to {output_path}")
1566
+
1567
+ # Create output directory if needed
1568
+ output_path = Path(output_path)
1569
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1570
+
1571
+ try:
1572
+ # Check if file exists and if we're appending
1573
+ file_exists = output_path.exists()
1574
+ write_mode = 'a' if append and file_exists else 'w'
1575
+
1576
+ with open(output_path, write_mode, newline='', encoding='utf-8') as f:
1577
+ writer = csv.writer(f)
1578
+
1579
+ # Write header only if creating new file or first append
1580
+ if not append or not file_exists:
1581
+ # Create improved answer header with prompt text
1582
+ improved_header = f'improved_answer (PROMPT: {improved_prompt_text})'
1583
+ writer.writerow(['question', 'question_group', 'answer', 'original_flesch', 'sources',
1584
+ 'similarity_scores', 'middle_school_answer', 'middle_school_flesch',
1585
+ 'high_school_answer', 'high_school_flesch', improved_header])
1586
+
1587
+ for data in qa_pairs:
1588
+ # Unpack the data tuple
1589
+ (question, answer, sources, question_group, original_flesch,
1590
+ middle_school_answer, middle_school_flesch,
1591
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores) = data
1592
+
1593
+ # Clean and escape the answers for CSV
1594
+ def clean_text(text):
1595
+ # Replace newlines with spaces and clean up formatting
1596
+ cleaned = text.replace('\n', ' ').replace('\r', ' ')
1597
+ # Remove extra whitespace but preserve the full content
1598
+ cleaned = ' '.join(cleaned.split())
1599
+ # Quote escaping is left to csv.writer below; doubling quotes here would double-escape them in the output
1600
1601
+ return cleaned
1602
+
1603
+ clean_question = clean_text(question)
1604
+ clean_answer = clean_text(answer)
1605
+ clean_sources = clean_text(sources)
1606
+ clean_middle_school = clean_text(middle_school_answer)
1607
+ clean_high_school = clean_text(high_school_answer)
1608
+ clean_improved = clean_text(improved_answer)
1609
+
1610
+ # Log the full answer length for debugging
1611
+ if self.args.verbose:
1612
+ logger.info(f"Writing answer length: {len(clean_answer)} characters")
1613
+ logger.info(f"Middle school answer length: {len(clean_middle_school)} characters")
1614
+ logger.info(f"High school answer length: {len(clean_high_school)} characters")
1615
+ logger.info(f"Improved answer length: {len(clean_improved)} characters")
1616
+ logger.info(f"Question group: {question_group}")
1617
+
1618
+ # Use proper CSV quoting - let csv.writer handle the quoting
1619
+ writer.writerow([
1620
+ clean_question,
1621
+ question_group,
1622
+ clean_answer,
1623
+ f"{original_flesch:.1f}",
1624
+ clean_sources,
1625
+ similarity_scores, # Similarity scores (comma-separated)
1626
+ clean_middle_school,
1627
+ f"{middle_school_flesch:.1f}",
1628
+ clean_high_school,
1629
+ f"{high_school_flesch:.1f}",
1630
+ clean_improved
1631
+ ])
1632
+
1633
+ if append:
1634
+ logger.info(f"Appended {len(qa_pairs)} Q&A pairs to {output_path}")
1635
+ else:
1636
+ logger.info(f"Successfully wrote {len(qa_pairs)} Q&A pairs to {output_path}")
1637
+
1638
+ except Exception as e:
1639
+ logger.error(f"Failed to write CSV: {e}")
1640
+ sys.exit(4)
1641
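Because results are written incrementally, partial runs can be inspected at any point; a small sketch, assuming the default output path under `results/`:

```python
import pandas as pd

df = pd.read_csv("results/answers.csv")
print(df[["question", "question_group", "original_flesch", "sources"]].head())
```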
+
1642
+
1643
+ def parse_args():
1644
+ """Parse command line arguments"""
1645
+ parser = argparse.ArgumentParser(description="RAG Chatbot for CGT-LLM-Beta with Vector Database")
1646
+
1647
+ # File paths
1648
+ parser.add_argument('--data-dir', default='./Data Resources',
1649
+ help='Directory containing documents to index')
1650
+ parser.add_argument('--questions', default='./questions.txt',
1651
+ help='File containing questions (one per line)')
1652
+ parser.add_argument('--out', default='./answers.csv',
1653
+ help='Output CSV file for answers')
1654
+ parser.add_argument('--vector-db-dir', default='./chroma_db',
1655
+ help='Directory for ChromaDB persistence')
1656
+
1657
+ # Retrieval parameters
1658
+ parser.add_argument('--k', type=int, default=5,
1659
+ help='Number of chunks to retrieve per question')
1660
+
1661
+ # Chunking parameters
1662
+ parser.add_argument('--chunk-size', type=int, default=500,
1663
+ help='Size of text chunks in tokens')
1664
+ parser.add_argument('--chunk-overlap', type=int, default=200,
1665
+ help='Overlap between chunks in tokens')
1666
+
1667
+ # Model selection
1668
+ parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
1669
+ help='HuggingFace model name to use (e.g., meta-llama/Llama-3.2-3B-Instruct, mistralai/Mistral-7B-Instruct-v0.2)')
1670
+
1671
+ # Generation parameters
1672
+ parser.add_argument('--max-new-tokens', type=int, default=1024,
1673
+ help='Maximum new tokens to generate')
1674
+ parser.add_argument('--temperature', type=float, default=0.2,
1675
+ help='Generation temperature')
1676
+ parser.add_argument('--top-p', type=float, default=0.9,
1677
+ help='Top-p sampling parameter')
1678
+ parser.add_argument('--repetition-penalty', type=float, default=1.1,
1679
+ help='Repetition penalty')
1680
+
1681
+ # Database options
1682
+ parser.add_argument('--force-rebuild', action='store_true',
1683
+ help='Force rebuild of vector database')
1684
+ parser.add_argument('--skip-indexing', action='store_true',
1685
+ help='Skip document indexing, use existing database')
1686
+
1687
+ # Other options
1688
+ parser.add_argument('--seed', type=int, default=42,
1689
+ help='Random seed for reproducibility')
1690
+ parser.add_argument('--verbose', action='store_true',
1691
+ help='Enable verbose logging')
1692
+ parser.add_argument('--dry-run', action='store_true',
1693
+ help='Build index and test retrieval without generation')
1694
+ parser.add_argument('--diagnose', action='store_true',
1695
+ help='Run system diagnostics and exit')
1696
+
1697
+ return parser.parse_args()
1698
+
1699
+
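A hypothetical invocation using the flags defined above (paths and values are placeholders that mirror the parser defaults, including the default model name):

```python
# Illustrative only: drive app.py with a few of the CLI flags parsed above.
import subprocess

subprocess.run([
    "python", "app.py",
    "--data-dir", "./Data Resources",
    "--questions", "./questions.txt",
    "--out", "./answers.csv",
    "--k", "5",
    "--chunk-size", "500",
    "--chunk-overlap", "200",
    "--model", "meta-llama/Llama-3.2-3B-Instruct",
    "--verbose",
], check=True)
```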
1700
+ def main():
1701
+ """Main function"""
1702
+ args = parse_args()
1703
+
1704
+ # Set random seed
1705
+ torch.manual_seed(args.seed)
1706
+ np.random.seed(args.seed)
1707
+
1708
+ # Set logging level
1709
+ if args.verbose:
1710
+ logging.getLogger().setLevel(logging.DEBUG)
1711
+
1712
+ logger.info("Starting RAG Chatbot with Vector Database")
1713
+ logger.info(f"Arguments: {vars(args)}")
1714
+
1715
+ try:
1716
+ # Initialize bot
1717
+ bot = RAGBot(args)
1718
+
1719
+ # Check if we should skip indexing
1720
+ if not args.skip_indexing:
1721
+ # Load and process documents
1722
+ documents = bot.load_corpus(args.data_dir)
1723
+ if not documents:
1724
+ logger.error("No documents found to process")
1725
+ sys.exit(3)
1726
+
1727
+ # Chunk documents
1728
+ chunks = bot.chunk_documents(documents, args.chunk_size, args.chunk_overlap)
1729
+ if not chunks:
1730
+ logger.error("No chunks created from documents")
1731
+ sys.exit(3)
1732
+
1733
+ # Build or update index
1734
+ bot.build_or_update_index(chunks, args.force_rebuild)
1735
+ else:
1736
+ logger.info("Skipping document indexing, using existing vector database")
1737
+
1738
+ # Run diagnostics if requested
1739
+ if args.diagnose:
1740
+ sample_questions = [
1741
+ "What is Lynch Syndrome?",
1742
+ "What does a BRCA1 genetic variant mean?",
1743
+ "What screening tests are recommended for MSH2 carriers?"
1744
+ ]
1745
+ diagnostics = bot.diagnose_system(sample_questions=sample_questions)
1746
+ bot.print_diagnostics(diagnostics)
1747
+ return
1748
+
1749
+ if args.dry_run:
1750
+ logger.info("Dry run completed successfully")
1751
+ return
1752
+
1753
+ # Process questions
1754
+ generation_kwargs = {
1755
+ 'max_new_tokens': args.max_new_tokens,
1756
+ 'temperature': args.temperature,
1757
+ 'top_p': args.top_p,
1758
+ 'repetition_penalty': args.repetition_penalty
1759
+ }
1760
+
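As an aside, these generation parameters map directly onto a standard Hugging Face text-generation call; a sketch under that assumption (the RAGBot internals are not shown in this diff, and the model is gated and large, so this is illustrative only):

```python
# Illustrative only: how generation_kwargs above would be passed to a
# transformers text-generation pipeline (requires HF access to the model).
from transformers import pipeline

generator = pipeline("text-generation", model="meta-llama/Llama-3.2-3B-Instruct")
out = generator(
    "What is Lynch Syndrome?",
    max_new_tokens=1024,
    temperature=0.2,
    top_p=0.9,
    repetition_penalty=1.1,
    do_sample=True,
)
print(out[0]["generated_text"])
```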
1761
+ qa_pairs = bot.process_questions(args.questions, output_file=args.out, **generation_kwargs)
1762
+
1763
+ logger.info("RAG Chatbot completed successfully")
1764
+
1765
+ except KeyboardInterrupt:
1766
+ logger.info("Interrupted by user")
1767
+ sys.exit(0)
1768
+ except Exception as e:
1769
+ logger.error(f"Unexpected error: {e}")
1770
+ if args.verbose:
1771
+ import traceback
1772
+ traceback.print_exc()
1773
+ sys.exit(1)
1774
+
1775
+
1776
+ if __name__ == "__main__":
1777
+ main()
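After a run, the CSV written by the code above can be sanity-checked; a small sketch, assuming the default --out path and the 11 values written per row:

```python
# Illustrative only: load and inspect the answers CSV produced by the script.
import pandas as pd

df = pd.read_csv("answers.csv")
print(df.shape)            # expect 11 columns, one row per question
print(list(df.columns))    # header row written before the Q&A rows
print(df.isna().sum())     # quick check for empty fields
```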
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80fe29380be0f587de8c3d0df3bbd891219ebe35d3ab4e007721d322ca704b9f
3
+ size 18888520
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56091853c1c20a1ec97ba4a7935cb7ab95f58b91d1ca56b990bf768f7bd2df88
3
+ size 100
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:754f12ddf66368443039e44c7d3625dbfa54c42604f231054e5c8ab8df162ebb
3
+ size 548379
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e72c9f5fb80c8fa3f488f68172cf32cdaf226d94cb6cff09ff68990b34fbb04c
3
+ size 45080
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0046b8333ff42649a27896a5da1f0fd89ee54954221fde9172dfe284d94262b
3
+ size 99820
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70340ab0d0dddb6b5bcf29c0e09f316b0f695f6645be0231302346d5af463700
3
+ size 294584320
requirements.txt ADDED
@@ -0,0 +1,56 @@
1
+ # =============================================================================
2
+ # RAG Chatbot with Vector Database - Requirements
3
+ # =============================================================================
4
+ # Production-ready dependencies for medical document analysis and Q&A
5
+
6
+ # Core ML/AI Framework
7
+ torch>=2.0.0 # PyTorch for model inference
8
+ transformers>=4.30.0 # Hugging Face transformers
9
+ huggingface_hub>=0.20.0 # Hugging Face Hub API (for Inference API)
10
+ accelerate>=0.20.0 # Model loading optimization
11
+ safetensors>=0.3.0 # Safe model loading
12
+
13
+ # Vector Database & Embeddings
14
+ chromadb>=0.4.0 # Vector database for fast retrieval
15
+ sentence-transformers>=2.2.0 # Semantic embeddings (all-MiniLM-L6-v2)
16
+
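A minimal sketch of the retrieval stack these two packages provide (collection name, sample document, and query are placeholders; ./chroma_db matches the script's default --vector-db-dir, and all-MiniLM-L6-v2 is the embedding model named above):

```python
# Illustrative only: persistent ChromaDB collection with sentence-transformers embeddings.
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

client = chromadb.PersistentClient(path="./chroma_db")
embed_fn = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
collection = client.get_or_create_collection("docs", embedding_function=embed_fn)

collection.add(
    documents=["Lynch syndrome is an inherited cancer-predisposition condition."],
    ids=["chunk-0"],
)
results = collection.query(query_texts=["What is Lynch Syndrome?"], n_results=1)
print(results["documents"])
```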
17
+ # Data Processing
18
+ pandas>=1.3.0 # Data manipulation and CSV handling
19
+ numpy>=1.21.0 # Numerical computing
20
+ scikit-learn>=1.0.0 # ML utilities and TF-IDF
21
+
22
+ # Text Analysis & Readability
23
+ textstat>=0.7.0 # Flesch-Kincaid Grade Level calculation
24
+ nltk>=3.8.0 # Natural language processing utilities
25
+
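To illustrate the readability scoring these packages support (sketch only; the sample sentence is a placeholder):

```python
# Illustrative only: Flesch metrics as computed by textstat.
import textstat

sample = "Lynch syndrome is an inherited condition that raises the risk of some cancers."
print(textstat.flesch_reading_ease(sample))   # higher score = easier to read
print(textstat.flesch_kincaid_grade(sample))  # approximate U.S. grade level
```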
26
+ # Document Processing (Core)
27
+ pypdf>=3.0.0 # PDF document parsing
28
+ python-docx>=0.8.11 # DOCX document parsing
29
+
30
+ # Optional Document Processing
31
+ rank-bm25>=0.2.2 # BM25 retrieval algorithm (alternative to TF-IDF)
32
+
33
+ # Utilities & Progress
34
+ tqdm>=4.65.0 # Progress bars
35
+ pathlib2>=2.3.0 # Backport of pathlib for Python 2; not needed on Python 3
36
+
37
+ # Web Interface
38
+ gradio==4.44.1 # Gradio web interface for chatbot (updated for Spaces compatibility)
39
+
40
+ # Development & Testing (Optional)
41
+ pytest>=7.0.0 # Testing framework
42
+ black>=22.0.0 # Code formatting
43
+ flake8>=4.0.0 # Code linting
44
+
45
+ # Performance Monitoring (Optional)
46
+ psutil>=5.8.0 # System resource monitoring
47
+ memory-profiler>=0.60.0 # Memory usage profiling
48
+
49
+ # =============================================================================
50
+ # Installation Notes:
51
+ # =============================================================================
52
+ # 1. Install with: pip install -r requirements.txt
53
+ # 2. Apple Silicon: recent PyTorch builds support MPS acceleration when available
54
+ # 3. Optional packages can be installed separately if needed
55
+ # 4. Model files (~6GB) will be downloaded on first run
56
+ # =============================================================================
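As a quick follow-up to the Apple Silicon note above, a one-line sanity check (illustrative):

```python
# Illustrative only: confirm PyTorch can see the Apple Silicon MPS backend.
import torch

print(torch.backends.mps.is_available())  # True on Apple Silicon with a recent PyTorch build
```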