Nam Fam committed on
Commit 472e1d4 · 1 Parent(s): 0d52457
.dockerignore ADDED
@@ -0,0 +1,28 @@
+ # Exclude Python cache & virtual envs
+ __pycache__/
+ *.py[cod]
+ *.pyo
+ *.pyd
+ .env
+
+ # Exclude Git files
+ .git/
+ .gitignore
+
+ # IDE/editor files
+ .vscode/
+ .idea/
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+
+ # Project-specific
+ notebooks/
+ scripts/
+ tests/
+ eval/
+ mcp_server/
+ mcp_client.py
+ utils/google_api_manager.py
+ output/
.gitignore ADDED
@@ -0,0 +1,67 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *.egg-info/
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ venv/
+ ENV/
+ env/
+ env.bak/
+ venv.bak/
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # IDEs
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *.sublime-workspace
+ *.sublime-project
+
+ # Environment files
+ .env
+ .env.local
+ .env.development
+ .env.test
+ .env.production
+
+ # Local development
+ # *.db
+ *.sqlite
+ data/
+ logs/
+ *.log
+ # *.csv
+ *.parquet
+ # output/
+
+ # Streamlit
+ .streamlit/credentials.toml
+
+ # /notebooks
+ # /tests
+
+
+ *conftest.py
+ *test_plotting.py
+ plots/
+ utils/google_api_manager.py
+ mcp_client.py
+ mcp_server/
+ tests/
+ scripts/
+ notebooks/
+ output/
Dockerfile ADDED
@@ -0,0 +1,29 @@
+ # syntax=docker/dockerfile:1
+
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt ./
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Expose Streamlit default port
+ EXPOSE 8501
+
+ # # Streamlit configuration
+ # ENV STREAMLIT_SERVER_PORT=8501 \
+ #     STREAMLIT_SERVER_ADDRESS=0.0.0.0
+
+ # # Launch the app
+ # CMD ["streamlit", "run", "app.py"]
+
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
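
With this Dockerfile, a typical local build-and-run flow looks like the following (the talktodata image tag is illustrative; --env-file supplies the Gemini API key at runtime, since the .dockerignore above deliberately keeps .env out of the image):

docker build -t talktodata .
docker run -p 8501:8501 --env-file .env talktodata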
agents/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .sql_agent.agent import SQLAgent
+ from .sql_agent.states import SQLAgentState
+ from .sql_agent.nodes import get_db_info, generate_sql, execute_sql, optional_plot, format_response, generate_answer
+ from .tools import PlotSQLTool
+ from .llms import LLM
+
+ __all__ = ['SQLAgent', 'SQLAgentState', 'get_db_info', 'generate_sql', 'execute_sql', 'optional_plot', 'format_response', 'generate_answer', 'PlotSQLTool', 'LLM']
agents/dataframe_agent.py ADDED
@@ -0,0 +1,39 @@
+ import pandas as pd
+ from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
+ from agents.llms import LLM
+ from langchain.agents.agent_types import AgentType
+
+
+ def get_dataframe_agent(
+     df: pd.DataFrame,
+     verbose: bool = True,
+     allow_dangerous_code: bool = True,
+     agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION
+ ):
+     """
+     Create a pandas DataFrame agent using the custom LLM.
+     Args:
+         df (pd.DataFrame): The pandas DataFrame to use.
+         verbose (bool): Whether to enable verbose output. Default is True.
+         allow_dangerous_code (bool): Whether to allow dangerous code execution. Default is True.
+         agent_type: The agent type to use. Default is ZERO_SHOT_REACT_DESCRIPTION.
+     Returns:
+         agent: The created DataFrame agent.
+     """
+     llm = LLM().chat_model
+     agent = create_pandas_dataframe_agent(
+         llm,
+         df,
+         agent_type=agent_type,
+         verbose=verbose,
+         allow_dangerous_code=allow_dangerous_code
+     )
+     return agent
+
+ # Usage example:
+ # import pandas as pd
+ # from agents.dataframe_agent import get_dataframe_agent
+ # df = pd.read_csv('your_file.csv')
+ # agent = get_dataframe_agent(df)
+ # response = agent.invoke('Your question here')
+ # print(response)
agents/llms.py ADDED
@@ -0,0 +1,54 @@
+ from langchain.chat_models import init_chat_model
+ from langchain_core.messages import HumanMessage
+ from dotenv import load_dotenv
+ from typing import List
+ from langchain.tools import BaseTool
+ from langchain.agents import initialize_agent, AgentType
+
+ _ = load_dotenv()
+
+ class LLM:
+     def __init__(
+         self,
+         model: str = "gemini-2.0-flash",
+         model_provider: str = "google_genai",
+         temperature: float = 0.0,
+         max_tokens: int = 1000
+     ):
+         self.chat_model = init_chat_model(
+             model=model,
+             model_provider=model_provider,
+             temperature=temperature,
+             max_tokens=max_tokens,
+         )
+
+     def generate(self, prompt: str) -> str:
+         message = HumanMessage(content=prompt)
+         response = self.chat_model.invoke([message])
+         return response.content
+
+     def bind_tools(self, tools: List[BaseTool], agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION):
+         """
+         Bind LangChain tools to this model and return an AgentExecutor.
+         """
+         return initialize_agent(
+             tools,
+             self.chat_model,
+             agent=agent_type,
+             verbose=False
+         )
+
+     def set_temperature(self, temperature: float):
+         """
+         Set the temperature for the chat model.
+         """
+         self.chat_model.temperature = temperature
+
+     def set_max_tokens(self, max_tokens: int):
+         """
+         Set the maximum number of tokens for the chat model.
+         """
+         self.chat_model.max_tokens = max_tokens
+
+
+
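
A minimal sketch of using the LLM wrapper on its own, assuming a valid Google GenAI API key is available via .env or the environment:

from agents.llms import LLM

llm = LLM()  # defaults: gemini-2.0-flash via google_genai, temperature 0.0
print(llm.generate("Say hello in one short sentence."))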
agents/memories.py ADDED
File without changes
agents/safe_guardrails.py ADDED
@@ -0,0 +1,195 @@
+ from typing import Dict
+ from guardrails.validators import Validator, register_validator
+ import sys
+ import os
+ import logging
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+ from agents.llms import LLM
+
+ def setup_logger(name):
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.INFO)
+     handler = logging.StreamHandler()
+     handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logger.addHandler(handler)
+     return logger
+
+
+ # @register_validator(name="medical_topic", data_type="string")
+ # class MedicalTopicValidator(Validator):
+ #     """Validates medical topics using project's LLM"""
+
+ #     def __init__(self, threshold: float = 0.7, on_fail=None):
+ #         super().__init__(on_fail=on_fail)
+ #         self.threshold = threshold
+ #         self.llm = LLM()
+ #         self.logger = setup_logger('MedicalValidator')  # Use project's logging
+
+ #     def validate(self, value: str, metadata: Dict = {}) -> str:
+ #         prompt = f"""Medical relevance analysis (0-1 score then YES/NO):
+ #         Question: {value}
+ #         Medical context: {metadata.get('context', 'general health')}
+ #         Score then Answer:"""
+
+ #         try:
+ #             response = self.llm.generate(prompt)
+ #             last_line = response.strip().split('\n')[-1]
+ #             parts = last_line.upper().strip().split()
+
+ #             if len(parts) != 2:
+ #                 raise ValueError(f"Malformed LLM response: '{last_line}'")
+
+ #             score_str, answer = parts
+
+ #             # Check that the score part is actually numeric
+ #             if not score_str.replace('.', '', 1).isdigit() and '%' not in score_str:
+ #                 raise ValueError(f"Invalid score format: '{score_str}'")
+
+ #             # Convert score
+ #             score = float(score_str.strip('%')) / 100 if '%' in score_str else float(score_str)
+
+ #             if score < self.threshold or "YES" not in answer:
+ #                 self.logger.warning(f"Rejected question: {value[:60]}")
+ #                 raise ValueError("Non-medical question detected")
+
+ #             return value
+ #         except Exception as e:
+ #             self.logger.error(f"Validation failed: {str(e)}")
+ #             raise
+
+
+ from typing import Callable, Dict, Optional
+ from guardrails.validators import (
+     FailResult,
+     PassResult,
+     register_validator,
+     ValidationResult,
+     Validator,
+ )
+ from agents.llms import LLM
+
+ TOXICITY_CHECK_PROMPT_TEMPLATE = """
+ You are a helpful assistant.
+ Respond with a number between 0 and 100 indicating the level of toxicity for the following text:
+ {text}
+ 0 means no toxic language.
+ 100 means highly toxic language.
+ """
+
+ @register_validator(name="toxic-language", data_type="string")
+ class ToxicLanguage(Validator):
+     def __init__(self, threshold: int = 70, on_fail: Optional[Callable] = None):
+         super().__init__(on_fail=on_fail, threshold=threshold)
+         self._threshold = threshold
+         self.llm = LLM()
+         self.fix_value = "Sorry, I can't assist you with that request."
+
+     def _validate(self, value: str, metadata: Dict) -> ValidationResult:
+         prompt = TOXICITY_CHECK_PROMPT_TEMPLATE.format(text=value)
+         score = int(self.llm.generate(prompt).strip())
+         if score > self._threshold:
+             return FailResult(
+                 error_message=f"Validation failed. Score {score} exceeds threshold of {self._threshold}.",
+                 fix_value=self.fix_value,
+             )
+         else:
+             return PassResult()
+ OFF_TOPIC_CHECK_PROMPT_TEMPLATE = """
+ You are a helpful assistant.
+ Respond with a number between 0 and 100 indicating how off-topic the following text is. Consider the context provided:
+ Topic: '{topic}'
+ Additional Context: '{additional_context}'
+ Text: {text}
+ Do not output prose.
+ 0 means very relevant to the topic.
+ 100 means completely off-topic.
+ Please note that common greetings should not be considered off-topic.
+ """
+
+ @register_validator(name="off-topic", data_type="string")
+ class OffTopicValidator(Validator):
+     def __init__(self, threshold: int = 70, on_fail: Optional[Callable] = None):
+         super().__init__(on_fail=on_fail, threshold=threshold)
+         self._threshold = threshold
+         self.llm = LLM()
+
+
+
+     def _validate(self, value: str, metadata: Dict) -> ValidationResult:
+         topic = metadata.get('topic', 'general')
+         additional_context = metadata.get('additional_context', '')
+
+         if topic == 'general':
+             return PassResult()
+
+         # self.fix_value = f"Sorry, I can only assist you with questions related to the topic '{topic}'."
+         self.fix_value = "OFF_TOPIC"
+
+         prompt = OFF_TOPIC_CHECK_PROMPT_TEMPLATE.format(
+             text=value,
+             topic=topic,
+             additional_context=additional_context
+         )
+
+         score = int(self.llm.generate(prompt).strip())
+
+         print(f"Off-topic score: {score}")
+         if score > self._threshold:
+             return FailResult(
+                 error_message=f"Validation failed. Score {score} exceeds threshold of {self._threshold}.",
+                 fix_value=self.fix_value,
+             )
+         else:
+             return PassResult()
+
+
+
+
+ if __name__ == "__main__":
+     # validator = OffTopicValidator()
+
+     # print("Validating:")
+     # result = validator.validate("What is the capital of France?", metadata={"topic": "Medical"})
+     # print("Validation result:", result)
+
+
+     from guardrails import Guard
+     guard = Guard().use(
+         # ToxicLanguage,
+         OffTopicValidator,
+         # on_fail=lambda value, fail_result: f"Sorry, I can't assist you with that request.",
+         # on_fail="exception"
+
+         on_fail="fix"
+     )
+
+     texts = [
+         "What is the capital of France?",
+         "I want to kill you.",
+         "You are a stupid dog",
+         "Triệu chứng của bệnh viêm dạ dày",  # "Symptoms of gastritis"
+     ]
+
+     metadata = {'topic': 'Medical'}
+     for text in texts:
+         print(f"Validating: {text}")
+         try:
+             validation_result = guard.validate(text, metadata=metadata)
+
+             print("Validation passed")
+             print("Validation result:", validation_result)
+
+             # response = guard.to_runnable().invoke(text)
+             # print("Response:", response)
+
+         except Exception as e:
+             print(f"Validation failed: {e}")
+
+         print('-' * 20)
+
+
+
+
+
+ # Example usage
+ # python agents/safe_guardrails.py
agents/sql_agent/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .agent import SQLAgent
+ from .states import SQLAgentState
+ from .nodes import get_db_info, generate_sql, execute_sql, optional_plot, format_response, generate_answer
+
+ __all__ = [
+     'SQLAgent',
+     'SQLAgentState',
+     'get_db_info',
+     'generate_sql',
+     'execute_sql',
+     'optional_plot',
+     'format_response',
+     'generate_answer'
+ ]
agents/sql_agent/agent.py ADDED
@@ -0,0 +1,81 @@
+ import sys
+ import os
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+ from agents.llms import LLM
+ from dotenv import load_dotenv
+ from langchain_community.utilities import SQLDatabase
+ from utils.consts import DB_PATH
+ from agents.sql_agent.states import SQLAgentState
+
+ # Load environment vars
+ load_dotenv()
+
+ # def get_sql_agent():
+ #     """
+ #     Initializes a LangChain SQLDatabaseChain for SQLite.
+ #     """
+ #     # Load SQLite DB
+ #     db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
+ #     # Patch run to strip Markdown fences and log
+ #     orig_run = db.run
+ #     def clean_run(query: str, **kwargs) -> str:
+ #         lines = query.splitlines()
+ #         if lines and lines[0].strip().startswith("```"):
+ #             lines = lines[1:]
+ #         if lines and lines[-1].strip().startswith("```"):
+ #             lines = lines[:-1]
+ #         cleaned = "\n".join(lines).strip()
+ #         print(f"[SQLDatabaseChain] Running SQL: {cleaned}")
+ #         return orig_run(cleaned, **kwargs)
+ #     db.run = clean_run
+ #     # Initialize LLM
+ #     llm_wrapper = LLM()
+ #     # Create SQLDatabaseChain
+ #     chain = SQLDatabaseChain.from_llm(llm_wrapper.chat_model, db, verbose=True)
+ #     return chain
+
+ class SQLAgent:
+     def __init__(self):
+         self.db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
+         self.llm = LLM()
+         self.graph = self.build_graph()
+
+
+     def build_graph(self):
+         from agents.sql_agent.graph import build_graph
+         return build_graph().compile()
+
+     def run(self, state: SQLAgentState) -> SQLAgentState:
+         """
+         Run the SQL agent with the given query.
+         """
+         return self.graph.invoke(state)
+
+ if __name__ == "__main__":
+     agent = SQLAgent()
+     state = {
+         "question": None,
+         "db_info": {
+             "tables": [],
+             "columns": {},
+             "schema": None
+         },
+         "sql_query": None,
+         "sql_result": None,
+         "error": None
+     }
+     while True:
+         question = input("Enter your query (or 'exit' to quit): ")
+         state['question'] = question
+         if not question or question.lower() in ('exit', 'quit'):
+             print("Goodbye!")
+             break
+         result = agent.run(state)
+         # print(result)
+
+         # answer = result['answer']
+         # print(answer)
+
+         for step in agent.graph.stream(state, stream_mode="updates"):
+             print(step)
+
agents/sql_agent/graph.py ADDED
@@ -0,0 +1,103 @@
+ import sys
+ import os
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+ from agents.sql_agent.states import SQLAgentState
+ from langgraph.graph import StateGraph, START, END
+ from agents.sql_agent.nodes import (
+     get_db_info,
+     generate_sql,
+     execute_sql,
+     generate_answer,
+     detect_off_topic,
+     choose_visualization,
+     format_data_for_visualization,
+     render_visualization,
+     finalize_output
+ )
+
+ def build_graph(visualize: bool = True) -> StateGraph:
+     graph = StateGraph(SQLAgentState)
+
+     # Add nodes
+     graph.add_node("detect_off_topic", detect_off_topic)
+     graph.add_node("generate_sql", generate_sql)
+     graph.add_node("get_db_info", get_db_info)
+     graph.add_node("execute_sql", execute_sql)
+     graph.add_node("generate_answer", generate_answer)
+     graph.add_node("choose_visualization", choose_visualization)
+     graph.add_node("format_data_for_visualization", format_data_for_visualization)
+     graph.add_node("render_visualization", render_visualization)
+     graph.add_node("finalize_output", finalize_output)
+
+
+     # Add edges
+     graph.add_edge(START, "detect_off_topic")
+
+     graph.add_conditional_edges(
+         "detect_off_topic",
+         lambda state: state['error'],
+         path_map={
+             # True: "generate_answer",
+             True: "get_db_info",
+             False: "get_db_info"
+         }
+     )
+
+     graph.add_edge("get_db_info", "generate_sql")
+     graph.add_edge("generate_sql", "execute_sql")
+     graph.add_edge("execute_sql", "choose_visualization")
+     graph.add_edge("choose_visualization", "format_data_for_visualization")
+     graph.add_edge("format_data_for_visualization", "render_visualization")
+     graph.add_edge("render_visualization", "generate_answer")
+     graph.add_edge("generate_answer", "finalize_output")
+     graph.add_edge("finalize_output", END)
+     # graph.add_edge("execute_sql", "generate_answer")
+     # graph.add_edge("generate_answer", "choose_visualization")
+     # graph.add_edge("choose_visualization", END)
+
+     if visualize:
+         # TODO: Implement visualization
+         pass
+     return graph
+
+ def visualize_graph(graph) -> None:
+     graph.visualize()
+
+ if __name__ == "__main__":
+     state = {
+         "question": "top 3 sản phẩm có giá thấp nhất",  # "top 3 products with the lowest prices"
+         "db_info": {
+             "tables": [],
+             "columns": {},
+             "schema": ""
+         },
+         "sql_query": "",
+         "sql_result": None,
+         "error": None,
+         "step": None,
+         "answer": None,
+         "plot_path": None,
+         "response_md": None,
+         "visualization": None,
+         "visualization_reason": None,
+         "formatted_data_for_visualization": None,
+         "visualization_output": None,
+         "off_topic": None
+     }
+
+     graph = build_graph().compile()
+     # visualize_graph(graph)
+
+     result = graph.invoke(state)
+     # print(result)
+
+     answer = result['answer']
+     print(answer)
+
+     for step in graph.stream(
+         state, stream_mode="updates"
+     ):
+         print("-" * 80)
+         # print(step['step'])
+         print(step)
+
agents/sql_agent/nodes.py ADDED
@@ -0,0 +1,496 @@
+ import sqlite3
+ import pandas as pd
+ import re
+ from agents.llms import LLM
+ from agents.tools import PlotSQLTool
+ from .states import SQLAgentState
+ from utils.consts import DB_PATH
+
+
+ def choose_visualization(state: SQLAgentState) -> SQLAgentState:
+     """Use LLM to suggest a suitable chart type for the SQL result."""
+     question = state['question']
+     sql_query = state['sql_query']
+     sql_result = state['sql_result']
+     # Convert sql_result DataFrame to markdown or string preview (or sample rows)
+     if sql_result is not None:
+         if hasattr(sql_result, 'head'):
+             preview = sql_result.head(5).to_markdown(index=False)
+         else:
+             preview = str(sql_result)
+     else:
+         preview = "No results"
+
+     prompt = f'''
+     You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.
+
+     Available chart types and their use cases:
+     - Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2.
+     - Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large.
+     - Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous.
+     - Pie Charts: Ideal for showing proportions or percentages within a whole.
+     - Line Graphs: Best for showing trends and distributions over time. Best used when both x axis and y axis are continuous or time-based.
+
+     Provide your response in the following format:
+     Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
+     Reason: [Brief explanation for your recommendation]
+
+     User question: {question}
+     SQL query: {sql_query}
+     Query results: {preview}
+
+     Recommend a visualization:
+     '''
+     llm = LLM()
+     response = llm.generate(prompt)
+     lines = response.split('\n')
+     visualization = 'none'
+     reason = ''
+     for line in lines:
+         if line.lower().startswith('recommended visualization:'):
+             visualization = line.split(':', 1)[1].strip()
+         elif line.lower().startswith('reason:'):
+             reason = line.split(':', 1)[1].strip()
+     state['visualization'] = visualization
+     state['visualization_reason'] = reason
+     state['step'] = 'choose_visualization'
+     return state
+
+
+ def format_data_for_visualization(state: SQLAgentState) -> SQLAgentState:
+     """
+     Format the data for the chosen visualization type.
+     Supports line, bar, scatter, and grouped bar; falls back to the LLM for other visualization types.
+     """
+     import json
+     import pandas as pd
+     llm = LLM()
+
+     visualization = state.get('visualization', 'none')
+     sql_result = state.get('sql_result')
+     question = state.get('question')
+     sql_query = state.get('sql_query')
+
+     # Convert DataFrame to list of lists for processing
+     if sql_result is not None and hasattr(sql_result, 'values'):
+         data = sql_result.values.tolist()
+         columns = list(sql_result.columns)
+     elif isinstance(sql_result, list):
+         data = sql_result
+         columns = []
+     else:
+         state['formatted_data_for_visualization'] = None
+         return state
+
+     def _format_line_data(data, question):
+         if len(data[0]) == 2:
+             x_values = [str(row[0]) for row in data]
+             y_values = [float(row[1]) for row in data]
+             prompt = f"""
+             You are a data labeling expert. Given a question and some data, provide a concise and relevant label for the data series.
+             Question: {question}
+             Data (first few rows): {data[:2]}
+             Provide a concise label for this y axis.
+             """
+             label = llm.generate(prompt).strip()
+             formatted_data = {
+                 "xValues": x_values,
+                 "yValues": [
+                     {
+                         "data": y_values,
+                         "label": label
+                     }
+                 ]
+             }
+             return formatted_data
+         elif len(data[0]) == 3:
+             data_by_label = {}
+             x_values = []
+             labels = list(set(item2 for item1, item2, item3 in data if isinstance(item2, str) and not item2.replace(".", "").isdigit() and "/" not in item2))
+             if not labels:
+                 labels = list(set(item1 for item1, item2, item3 in data if isinstance(item1, str) and not item1.replace(".", "").isdigit() and "/" not in item1))
+             for item1, item2, item3 in data:
+                 if isinstance(item1, str) and not item1.replace(".", "").isdigit() and "/" not in item1:
+                     label, x, y = item1, item2, item3
+                 else:
+                     x, label, y = item1, item2, item3
+                 if str(x) not in x_values:
+                     x_values.append(str(x))
+                 if label not in data_by_label:
+                     data_by_label[label] = []
+                 data_by_label[label].append(float(y))
+                 for other_label in labels:
+                     if other_label != label:
+                         if other_label not in data_by_label:
+                             data_by_label[other_label] = []
+                         data_by_label[other_label].append(None)
+             y_values = [
+                 {
+                     "data": data,
+                     "label": label
+                 }
+                 for label, data in data_by_label.items()
+             ]
+             formatted_data = {
+                 "xValues": x_values,
+                 "yValues": y_values,
+                 "yAxisLabel": ""
+             }
+             prompt = f"""
+             You are a data labeling expert. Given a question and some data, provide a concise and relevant label for the y-axis.
+             Question: {question}
+             Data (first few rows): {data[:2]}
+             Provide a concise label for the y-axis.
+             """
+             y_axis_label = llm.generate(prompt).strip()
+             formatted_data["yAxisLabel"] = y_axis_label
+             return formatted_data
+         return None
+
+     def _format_scatter_data(data):
+         formatted_data = {"series": []}
+         if len(data[0]) == 2:
+             formatted_data["series"].append({
+                 "data": [
+                     {"x": float(x), "y": float(y), "id": i+1}
+                     for i, (x, y) in enumerate(data)
+                 ],
+                 "label": "Data Points"
+             })
+         elif len(data[0]) == 3:
+             entities = {}
+             for item1, item2, item3 in data:
+                 if isinstance(item1, str) and not item1.replace(".", "").isdigit() and "/" not in item1:
+                     label, x, y = item1, item2, item3
+                 else:
+                     x, label, y = item1, item2, item3
+                 if label not in entities:
+                     entities[label] = []
+                 entities[label].append({"x": float(x), "y": float(y), "id": len(entities[label])+1})
+             for label, d in entities.items():
+                 formatted_data["series"].append({
+                     "data": d,
+                     "label": label
+                 })
+         else:
+             raise ValueError("Unexpected data format in results")
+         return formatted_data
+
+     def _format_bar_data(data, question):
+         if len(data[0]) == 2:
+             labels = [str(row[0]) for row in data]
+             values = [float(row[1]) for row in data]
+             prompt = f"""
+             You are a data labeling expert. Given a question and some data, provide a concise and relevant label for the data series.
+             Question: {question}
+             Data (first few rows): {data[:2]}
+             Provide a concise label for this y axis.
+             """
+             label = llm.generate(prompt).strip()
+             y_values = [{"data": values, "label": label}]
+         elif len(data[0]) == 3:
+             categories = set(row[1] for row in data)
+             labels = list(categories)
+             entities = set(row[0] for row in data)
+             y_values = []
+             for entity in entities:
+                 entity_data = [float(row[2]) for row in data if row[0] == entity]
+                 y_values.append({"data": entity_data, "label": str(entity)})
+         else:
+             raise ValueError("Unexpected data format in results")
+         formatted_data = {
+             "labels": labels,
+             "values": y_values
+         }
+         return formatted_data
+
+     def _format_other_visualizations(visualization, question, sql_query, data):
+         # Fallback: use LLM to format data
+         prompt = f"""
+         You are a Data expert who formats data according to the required needs. You are given the question asked by the user, its sql query, the result of the query and the format you need to format it in.
+         For the given question: {question}\n\nSQL query: {sql_query}\n\nResult: {data}\n\nFormat this data for visualization type: {visualization}. Just give the json string. Do not format it.
+         """
+         response = llm.generate(prompt)
+         try:
+             formatted_data_for_visualization = json.loads(response)
+             return formatted_data_for_visualization
+         except json.JSONDecodeError:
+             return {"error": "Failed to format data for visualization", "raw_response": response}
+
+     visualization_map = {
+         "none": lambda data, question: None,
+         "scatter": lambda data, question: _format_scatter_data(data),
+         "bar": lambda data, question: _format_bar_data(data, question),
+         "horizontal_bar": lambda data, question: _format_bar_data(data, question),
+         "line": lambda data, question: _format_line_data(data, question)
+     }
+     try:
+         state["formatted_data_for_visualization"] = visualization_map[visualization](data, question)
+     except Exception:  # unknown chart type or formatting failure: fall back to the LLM formatter
+         state["formatted_data_for_visualization"] = _format_other_visualizations(visualization, question, sql_query, data)
+     state['step'] = 'format_data_for_visualization'
+     return state
+
+
+
+
+ def render_visualization(state: SQLAgentState) -> SQLAgentState:
+     """
+     Render the visualization from formatted data.
+     Output: path to saved image file.
+     """
+     import matplotlib.pyplot as plt
+     import os
+     from io import BytesIO
+     import uuid
+
+     data = state.get("formatted_data_for_visualization")
+     visualization = state.get("visualization", "none")
+
+     if not data:
+         state["visualization_output"] = None
+         return state
+
+     output_dir = "output/plots"
+     os.makedirs(output_dir, exist_ok=True)
+
+     def save_fig(fig):
+         file_path = os.path.join(output_dir, f"visualization_{uuid.uuid4().hex[:8]}.png")
+         fig.savefig(file_path, format="png", bbox_inches="tight")
+         plt.close(fig)
+         return file_path
+
+     def render_line(data):
+         fig, ax = plt.subplots()
+         x = data["xValues"]
+         for series in data["yValues"]:
+             ax.plot(x, series["data"], label=series["label"])
+         ax.set_xlabel("X")
+         ax.set_ylabel(data.get("yAxisLabel", "Y"))
+         ax.legend()
+         return save_fig(fig)
+
+     def render_bar(data, horizontal=False):
+         fig, ax = plt.subplots()
+         labels = data["labels"]
+         n_series = len(data["values"])
+         width = 0.8 / n_series
+         x_indexes = list(range(len(labels)))
+         for i, series in enumerate(data["values"]):
+             offset = (i - n_series / 2) * width + width / 2
+             if horizontal:
+                 ax.barh(
+                     [x + offset for x in x_indexes],
+                     series["data"],
+                     height=width,
+                     label=series["label"]
+                 )
+                 ax.set_yticks(x_indexes)
+                 ax.set_yticklabels(labels)
+             else:
+                 ax.bar(
+                     [x + offset for x in x_indexes],
+                     series["data"],
+                     width=width,
+                     label=series["label"]
+                 )
+                 ax.set_xticks(x_indexes)
+                 ax.set_xticklabels(labels, rotation=45, ha='right')
+         ax.legend()
+         return save_fig(fig)
+
+     def render_scatter(data):
+         fig, ax = plt.subplots()
+         for series in data["series"]:
+             xs = [point["x"] for point in series["data"]]
+             ys = [point["y"] for point in series["data"]]
+             ax.scatter(xs, ys, label=series["label"])
+         ax.set_xlabel("X")
+         ax.set_ylabel("Y")
+         ax.legend()
+         return save_fig(fig)
+
+     try:
+         if visualization == "line":
+             image_path = render_line(data)
+         elif visualization == "bar":
+             image_path = render_bar(data, horizontal=False)
+         elif visualization == "horizontal_bar":
+             image_path = render_bar(data, horizontal=True)
+         elif visualization == "scatter":
+             image_path = render_scatter(data)
+         else:
+             state["visualization_output"] = None
+             return state
+
+         state["visualization_output"] = image_path
+     except Exception as e:
+         state["visualization_output"] = None
+         state["error"] = f"Failed to render visualization: {str(e)}"
+
+     state["step"] = "render_visualization"
+     return state
+
+
+ def finalize_output(state: SQLAgentState) -> SQLAgentState:
+     """
+     Node that consolidates the final outputs (answer, visualization_output, error, ...).
+     Currently it just returns the state; processing can be extended later.
+     """
+     state['step'] = 'finalize_output'
+     return state
+
+ # def ingest(state: SQLAgentState) -> SQLAgentState:
+ #     """Populate state.tables with list of tables in the DB."""
+ #     db_info = state['db_info']
+ #     conn = sqlite3.connect(DB_PATH)
+ #     try:
+ #         db_info['tables'] = [row[0] for row in conn.execute(
+ #             "SELECT name FROM sqlite_master WHERE type='table';"
+ #         )]
+ #         # Populate columns for each table
+ #         columns = {}
+ #         for table in db_info['tables']:
+ #             col_rows = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
+ #             columns[table] = [r[1] for r in col_rows]
+ #         db_info['columns'] = columns
+ #         state.db_info = db_info
+ #     finally:
+ #         conn.close()
+ #     return state
+
+
+ from agents.safe_guardrails import OffTopicValidator
+ from guardrails import Guard
+
+ def detect_off_topic(state: SQLAgentState) -> SQLAgentState:
+     """Check if the input question is off-topic."""
+     question = state['question']
+     validator = Guard().use(
+         OffTopicValidator,
+         on_fail="fix"
+     )
+     metadata = {
+         "topic": "Database Queries",
+         "additional_context": "The database is about ecommerce products with tables: products, laptops, phones, tablets, promotions, category"
+     }
+
+     validation_result = validator.validate(question, metadata=metadata)
+     if validation_result.validated_output == "OFF_TOPIC":
+         state['error'] = True
+     else:
+         state['error'] = False
+     state['step'] = 'detect_off_topic'
+     state['off_topic'] = validation_result.validated_output
+
+     print(state)
+     return state
+
+
+ def get_db_info(state: SQLAgentState) -> SQLAgentState:
+     """Get database information."""
+     db_info = state['db_info']
+     conn = sqlite3.connect(DB_PATH)
+     try:
+         db_info['tables'] = [row[0] for row in conn.execute(
+             "SELECT name FROM sqlite_master WHERE type='table';"
+         )]
+         # Populate columns for each table
+         columns = {}
+         for table in db_info['tables']:
+             col_rows = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
+             columns[table] = [r[1] for r in col_rows]
+         db_info['columns'] = columns
+         schema = "; ".join(f"{t}({', '.join(db_info['columns'][t])})" for t in db_info['tables'])
+         db_info['schema'] = schema
+     finally:
+         conn.close()
+     state['step'] = 'get_db_info'
+     return state
+
+
+ def generate_sql(state: SQLAgentState) -> SQLAgentState:
+     """Use LLM to translate user_query into SQL."""
+     llm = LLM()
+     # Include detailed schema with columns
+     schema = state['db_info']['schema']
+     prompt = (
+         f"Given this database schema: {schema}, "
+         f"write an SQL query to: {state['question']}. "
+         "Respond with only the SQL enclosed in triple backticks."
+     )
+     raw = llm.generate(prompt)
+     # print('raw', raw)
+     lines = raw.splitlines()
+     if lines and lines[0].strip().startswith("```"):
+         lines = lines[1:]
+     if lines and lines[-1].strip().startswith("```"):
+         lines = lines[:-1]
+     state['sql_query'] = "\n".join(lines).strip()
+     state['step'] = 'generate_sql'
+     return state
+
+
+
+ def execute_sql(state: SQLAgentState) -> SQLAgentState:
+     """Run the SQL in state.sql and store result DataFrame."""
+     sql_query = state['sql_query']
+     conn = sqlite3.connect(DB_PATH)
+     try:
+         state['sql_result'] = pd.read_sql_query(sql_query, conn)
+     except Exception as e:
+         state['error'] = str(e)
+     finally:
+         conn.close()
+     state['step'] = 'execute_sql'
+     return state
+
+ def generate_answer(state: SQLAgentState) -> SQLAgentState:
+     """Generate answer using LLM based on SQL result."""
+     llm = LLM()
+     if state['sql_result'] is not None and not state['sql_result'].empty:
+         result_str = state['sql_result'].to_string(index=False)
+         prompt = (
+             f"Given the question: {state['question']},\n"
+             f"SQL Query: {state['sql_query']},\n"
+             f"and the following SQL query result: {result_str},\n"
+             "provide a concise answer:"
+         )
+         state['answer'] = llm.generate(prompt)
+     else:
+         state['error'] = state['error'] or "No results found."
+     if state["off_topic"] == "OFF_TOPIC":
+         state['error'] = "The question is off-topic."
+         state["answer"] = "Sorry, I can't assist you with that request."
+     state['step'] = 'generate_answer'
+     return state
+
+
+ def optional_plot(state: SQLAgentState) -> SQLAgentState:
+     """If user_query requests plotting, generate plot and set state.plot_path."""
+     if any(k in state['question'].lower() for k in ['plot', 'vẽ', 'biểu đồ']):  # 'vẽ' = "draw", 'biểu đồ' = "chart"
+         tool = PlotSQLTool()
+         md = tool._run(state['sql_query'])
+         m = re.search(r'!\[.*\]\((.*?)\)', md)
+         if m:
+             state['plot_path'] = m.group(1)
+         else:
+             state['error'] = state['error'] or 'Plot generation failed'
+     return state
+
+
+ def format_response(state: SQLAgentState) -> SQLAgentState:
+     """Build markdown response including SQL, table preview, and plot."""
+     parts = []
+     if state['sql_query']:
+         parts.append(f"```sql\n{state['sql_query']}\n```")
+     if state['sql_result'] is not None:
+         parts.append(state['sql_result'].to_markdown(index=False))
+     if state['plot_path']:
+         parts.append(f"![Plot]({state['plot_path']})")
+     if state['error']:
+         parts.append(f"**Error**: {state['error']}")
+     state['response_md'] = "\n\n".join(parts)
+     return state
+
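
For reference, _format_bar_data turns a two-column result (for example, product-name and price rows) into roughly the following shape; the names and numbers here are illustrative, and the series label comes from the LLM, so it varies between runs:

{
    "labels": ["Laptop A", "Laptop B"],              # one entry per row's first column
    "values": [
        {"data": [499.0, 599.0], "label": "Price"}   # second column coerced to floats
    ]
}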
agents/sql_agent/prompts.py ADDED
File without changes
agents/sql_agent/states.py ADDED
@@ -0,0 +1,30 @@
+ from typing import Optional
+ import pandas as pd
+ from typing import TypedDict
+
+
+ class SQLAgentState(TypedDict, total=False):  # TypedDict fields cannot carry defaults; total=False makes them optional
+     """
+     Carries context through the text→SQL and plotting pipeline:
+     - question: original NL input
+     - sql_query: generated SQL
+     - sql_result: raw query result DataFrame
+     - answer: final answer
+     - error: any error messages
+     """
+     question: str
+     db_info: dict
+     sql_query: str
+     sql_result: Optional[pd.DataFrame]
+     answer: str
+     error: Optional[str]
+     plot_path: Optional[str]
+     response_md: str
+     step: Optional[str]
+     visualization: Optional[str]
+     visualization_reason: Optional[str]
+     formatted_data_for_visualization: Optional[dict]
+     visualization_output: Optional[str]
+     off_topic: Optional[str]
+
+
agents/tools.py ADDED
@@ -0,0 +1,63 @@
+ import sqlite3
+ import pandas as pd
+ from langchain.tools import BaseTool
+ import os
+ import matplotlib.pyplot as plt
+ from utils.consts import DB_PATH, PLOTS_DIR
+
+ # Fetch table list
+ _conn = sqlite3.connect(DB_PATH)
+ _TABLES = [row[0] for row in _conn.execute("SELECT name FROM sqlite_master WHERE type='table';")]
+ _conn.close()
+ _TABLES_LIST = ", ".join(_TABLES)
+
+ class SQLiteQueryTool(BaseTool):
+     name: str = "sqlite_query"
+     description: str = f"Executes a SQL query against the ecommerce SQLite database and returns results as CSV. Available tables: {_TABLES_LIST}."
+
+     def _run(self, query: str) -> str:
+         print(f"[SQLiteQueryTool] Executing query: {query}")
+         conn = sqlite3.connect(DB_PATH)
+         try:
+             df = pd.read_sql_query(query, conn)
+             return df.to_csv(index=False)
+         except Exception as e:
+             return f"SQL Error: {e}"
+         finally:
+             conn.close()
+
+     async def _arun(self, query: str) -> str:
+         raise NotImplementedError("Async not supported for SQLiteQueryTool")
+
+
+ class PlotSQLTool(BaseTool):
+     name: str = "plot_sql"
+     description: str = f"Executes a SQL query and generates a plot saved as a PNG; returns markdown image link. Available tables: {_TABLES_LIST}."
+
+     def _run(self, query: str) -> str:
+         print(f"[PlotSQLTool] Executing query: {query}")
+         conn = sqlite3.connect(DB_PATH)
+         try:
+             df = pd.read_sql_query(query, conn)
+             plt.figure()
+             df.plot(kind='bar' if df.shape[1] > 1 else 'line', legend=False)
+             timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
+             filename = f"plot_{timestamp}.png"
+             # Save plot to configured output directory
+             output_dir = PLOTS_DIR
+             os.makedirs(output_dir, exist_ok=True)
+             filepath = os.path.join(output_dir, filename)
+             plt.tight_layout()
+             plt.savefig(filepath)
+             plt.close()
+             return f"![Plot]({filepath})"
+         except Exception as e:
+             return f"Plot Error: {e}"
+         finally:
+             conn.close()
+
+     async def _arun(self, query: str) -> str:
+         raise NotImplementedError("Async not supported for PlotSQLTool")
+
+
+
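
Both tools can also be invoked directly through the standard BaseTool.run interface; a quick sketch, assuming the products table named in the off-topic metadata exists and has name and price columns:

from agents.tools import SQLiteQueryTool, PlotSQLTool

print(SQLiteQueryTool().run("SELECT name, price FROM products LIMIT 5"))  # CSV text
print(PlotSQLTool().run("SELECT name, price FROM products LIMIT 5"))      # markdown image link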
app.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from utils.consts import DB_PATH
4
+ import sqlite3
5
+ import re
6
+ import os
7
+ from agents.sql_agent.agent import SQLAgent
8
+ import time
9
+ from agents.tools import PlotSQLTool
10
+ from agents.dataframe_agent import get_dataframe_agent
11
+ from datetime import datetime
12
+
13
+ db_name = os.path.basename(DB_PATH)
14
+
15
+ st.set_page_config(page_title="🔍 TalkToData", layout="wide", initial_sidebar_state="collapsed")
16
+
17
+ # Loại bỏ title markdown để tránh hiển thị lặp lại
18
+ # Sidebar for settings
19
+ with st.sidebar:
20
+ st.header("ℹ️ About", anchor=None)
21
+ st.markdown("""
22
+ **TalkToData** v0.1.0
23
+ Your personal AI Data Analyst.
24
+ """, unsafe_allow_html=True)
25
+
26
+ # Initialize chat history
27
+ if 'chat_history' not in st.session_state:
28
+ st.session_state.chat_history = []
29
+
30
+ # Initialize SQL agent
31
+ # agent = get_sql_agent()
32
+
33
+ agent = SQLAgent()
34
+ state = {
35
+ "question": None,
36
+ "db_info": {
37
+ "tables": [],
38
+ "columns": {},
39
+ "schema": None
40
+ },
41
+ "sql_query": None,
42
+ "sql_result": None,
43
+ "error": None,
44
+ "step": None,
45
+ "answer": None
46
+ }
47
+ # --- Upload Screen State ---
48
+ if 'files_uploaded' not in st.session_state:
49
+ st.session_state['files_uploaded'] = False
50
+
51
+ # TEMP: Bypass landing page
52
+ st.session_state['files_uploaded'] = True
53
+
54
+ if not st.session_state['files_uploaded']:
55
+ # CSS to center and enlarge only the welcome start button
56
+ st.markdown("""
57
+ <style>
58
+ .welcome .stButton { display: flex; justify-content: center; }
59
+ .welcome .stButton button { font-size:2.5rem !important; padding:1.25rem 2rem !important; }
60
+ </style>
61
+ """, unsafe_allow_html=True)
62
+ # Wrap welcome content to scope styling
63
+ st.markdown("<div class='welcome' style='max-width:600px;margin:auto;text-align:center;'>", unsafe_allow_html=True)
64
+ # Title and subtitle
65
+ st.markdown("""
66
+ <h1 style='text-align:center; margin-bottom:0;'>🔍 TalkToData</h1>
67
+ <h3 style='text-align:center; color:gray;'>Your Personal AI Data Analyst that instantly answers your data questions with clear insights and elegant visualizations.</h3>
68
+ """, unsafe_allow_html=True)
69
+ # Standalone welcome start button
70
+ if st.button("🚀 Explore now", key="start"):
71
+ st.session_state['files_uploaded'] = True
72
+ st.experimental_rerun()
73
+ # Close welcome wrapper
74
+ st.markdown("</div>", unsafe_allow_html=True)
75
+ st.divider()
76
+ # SaaS-style Features section
77
+ st.markdown("## Features")
78
+ feat_cols = st.columns(3)
79
+ feat_cols[0].markdown("### 🗣 Natural-Language Queries\nAsk your data without SQL knowledge.")
80
+ feat_cols[1].markdown("### 📊 Instant Visualizations\nGet charts from one command.")
81
+ feat_cols[2].markdown("### 🔒 Secure & Local\nYour data stays on your machine.")
82
+ st.divider()
83
+ # How It Works section
84
+ st.markdown("## How It Works")
85
+ step_cols = st.columns(3)
86
+ step_cols[0].markdown("#### 1️⃣ Upload\nUpload .db or CSV files.")
87
+ step_cols[1].markdown("#### 2️⃣ Chat\nInteract in natural language.")
88
+ step_cols[2].markdown("#### 3️⃣ Visualize\nSee results as tables or charts.")
89
+ st.divider()
90
+ # Use Cases
91
+ st.markdown("## Use Cases")
92
+ st.markdown("- \"Show me top 5 products by sales\" → Chart")
93
+ st.markdown("- \"List customers from 2020\" → Table")
94
+ st.divider()
95
+ # Testimonials
96
+ st.markdown("## Testimonials")
97
+ testi_cols = st.columns(2)
98
+ testi_cols[0].markdown("> \"TalkToData transformed our data workflow!\" \n— Jane Doe, Data Analyst")
99
+ testi_cols[1].markdown("> \"The AI assistant is incredibly smart and fast.\" \n— John Smith, Product Manager")
100
+ st.divider()
101
+ # Footer
102
+ st.markdown("2025 TalkToData. All rights reserved.")
103
+
104
+ st.markdown("<p style='text-align: center; color: gray;'>TalkToData v0.1.0 - Copyright 2025 by <a href='https://github.com/phamdinhkhanh'>Khanh Pham</a></p>", unsafe_allow_html=True)
105
+ st.html(
106
+ "<p><span style='text-decoration: line-through double red;'>Oops</span>!</p>"
107
+ )
108
+
109
+ st.divider()
110
+
111
+ else:
112
+ # App title and return button
113
+ # st.title("🔍 TalkToData")
114
+ st.markdown("### TalkToData")
115
+ # TEMP: Commented out back-to-home
116
+ # if st.button('⬅️ Back to Home', key='back_to_upload'):
117
+ # st.session_state['files_uploaded'] = False
118
+ # # Xóa dữ liệu cũ
119
+ # if 'uploaded_csvs' in st.session_state:
120
+ # st.session_state['uploaded_csvs'] = []
121
+ # st.experimental_rerun()
122
+ # Layout: Data source selector, main content, and chat
123
+ data_col, left_col, right_col = st.columns([1.5, 3, 2])
124
+ # Data source selection
125
+ with data_col:
126
+ # st.subheader("Data Sources")
127
+ # Upload data
128
+ with st.expander("**Upload Data**", expanded=True):
129
+ st.file_uploader('Select SQLite (.db), CSV or Excel (.xlsx) files',
130
+ type=['db', 'csv', 'xlsx'],
131
+ accept_multiple_files=True,
132
+ key='upload_any_col',
133
+ label_visibility="collapsed")
134
+ gsheet_url = st.text_input('Enter Google Sheets URL (optional)', '', key='gsheet_url')
135
+ upload_status = []
136
+ has_db = False
137
+ has_csv = False
138
+
139
+ # Retrieve uploaded files list safely
140
+ uploaded_files = st.session_state.get('upload_any_col', [])
141
+ # Process Google Sheets if URL provided
142
+ url = st.session_state.get('gsheet_url', '').strip()
143
+ if url:
144
+ try:
145
+ csv_url = url.replace('/edit#gid=', '/export?format=csv&gid=')
146
+ df_gs = pd.read_csv(csv_url)
147
+ if 'uploaded_csvs' not in st.session_state:
148
+ st.session_state['uploaded_csvs'] = []
149
+ st.session_state['uploaded_csvs'].append({'name': 'GoogleSheets', 'df': df_gs})
150
+ upload_status.append('✅ Google Sheets loaded')
151
+ has_csv = True
152
+ except Exception as e:
153
+ upload_status.append(f'❌ Google Sheets error: {e}')
154
+
155
+ # Process files
156
+ for f in uploaded_files:
157
+ if f.name.lower().endswith('.db'):
158
+ try:
159
+ with open(DB_PATH, "wb") as dbf:
160
+ dbf.write(f.read())
161
+ upload_status.append(f"✅ Database: {f.name}")
162
+ has_db = True
163
+ except Exception as e:
164
+ upload_status.append(f"❌ Database error: {e}")
165
+
166
+ # Process CSV and Excel
167
+ name = f.name.lower()
168
+ if name.endswith('.csv') or name.endswith('.xlsx'):
169
+ try:
170
+ if name.endswith('.xlsx'):
171
+ # Process each sheet in Excel
172
+ f.seek(0)
173
+ xls = pd.ExcelFile(f)
174
+ sheets = st.multiselect(f"Select sheet(s) from {f.name}", xls.sheet_names, default=xls.sheet_names)
175
+ for sheet in sheets:
176
+ # Read raw to detect header rows
177
+ raw = xls.parse(sheet, header=None)
178
+ nn = raw.notnull().sum(axis=1)
179
+ hdr = [i for i, cnt in enumerate(nn) if cnt > 1]
180
+ if len(hdr) >= 2:
181
+ header = hdr[:2]
182
+ elif len(hdr) == 1:
183
+ header = [hdr[0]]
184
+ else:
185
+ header = [0]
186
+ df_sheet = xls.parse(sheet, header=header)
187
+ # Flatten MultiIndex if needed
188
+ if isinstance(df_sheet.columns, pd.MultiIndex):
189
+ df_sheet.columns = [" ".join([str(x) for x in col if pd.notna(x)]).strip() for col in df_sheet.columns]
190
+ # Store with sheet label
191
+ sheet_key = f"{f.name}:{sheet}"
192
+ if 'uploaded_csvs' not in st.session_state:
193
+ st.session_state['uploaded_csvs'] = []
194
+ st.session_state['uploaded_csvs'].append({'name': sheet_key, 'df': df_sheet})
195
+ upload_status.append(f"✅ Excel: {sheet_key}")
196
+ else:
197
+ temp_df = pd.read_csv(f)
198
+
199
+ if 'uploaded_csvs' not in st.session_state:
200
+ st.session_state['uploaded_csvs'] = []
201
+
202
+ # Check existing and update
203
+ csv_exists = False
204
+ for i, csv in enumerate(st.session_state['uploaded_csvs']):
205
+ if csv['name'] == f.name:
206
+ st.session_state['uploaded_csvs'][i]['df'] = temp_df
207
+ csv_exists = True
208
+ break
209
+ if not csv_exists:
210
+ st.session_state['uploaded_csvs'].append({'name': f.name, 'df': temp_df})
211
+ upload_status.append(f"✅ CSV/Excel: {f.name}")
212
+ has_csv = True
213
+ except Exception as e:
214
+ upload_status.append(f"❌ CSV/Excel error: {e}")
215
+
216
+ # Hiển thị trạng thái upload
217
+ if upload_status:
218
+ for status in upload_status:
219
+ st.write(status)
220
+ # After upload, select data sources
221
+ ds = []
222
+ if os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0:
223
+ ds.append(db_name)
224
+ if 'uploaded_csvs' in st.session_state:
225
+ ds += [csv['name'] for csv in st.session_state['uploaded_csvs']]
226
+ if ds:
227
+ # Initialize selected_sources session state to default to db_name
228
+ if 'selected_sources' not in st.session_state:
229
+ st.session_state['selected_sources'] = [db_name] if db_name in ds else []
230
+ selected_sources = st.multiselect(
231
+ "**Select sources**", options=ds,
232
+ key='selected_sources'
233
+ )
234
+ else:
235
+ st.info("Upload a database or CSV/Excel file to select a data source.")
236
+
237
+ with left_col:
238
+ # Data Preview: filter sources by user selection
239
+ selected = st.session_state.get('selected_sources', [])
240
+ preview_db = os.path.exists(DB_PATH) and db_name in selected
241
+ # Filter CSV/Excel previews
242
+ preview_csvs = [csv for csv in st.session_state.get('uploaded_csvs', []) if csv['name'] in selected]
243
+ if preview_db or preview_csvs:
244
+ # Display previews
245
+ with st.container(height=415):
246
+ st.markdown("**Data Preview**")
247
+ # Build tab labels
248
+ tab_labels = []
249
+ if preview_db:
250
+ tab_labels.append(db_name)
251
+ for c in preview_csvs:
252
+ tab_labels.append(c['name'])
253
+ tabs = st.tabs(tab_labels)
254
+ idx = 0
255
+ # Database preview
256
+ if preview_db:
257
+ with tabs[idx]:
258
+ conn = sqlite3.connect(DB_PATH)
259
+ tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
260
+ if tables:
261
+ t_tabs = st.tabs([t[0] for t in tables])
262
+ for t, tab in zip(tables, t_tabs):
263
+ with tab:
264
+ st.table(pd.read_sql_query(f"SELECT * FROM {t[0]}", conn))
265
+ else:
266
+ st.info("No tables found.")
267
+ conn.close()
268
+ idx += 1
269
+ # CSV/Excel previews
270
+ for c in preview_csvs:
271
+ with tabs[idx]:
272
+ st.table(c['df'])
273
+ idx += 1
274
+
275
+ # --- Data Exploration Section (Always Visible) ---
276
+ with st.container(height=225):
277
+ # Data Exploration: only support Database source
278
+ selected = st.session_state.get('selected_sources', [])
279
+ if db_name not in selected:
280
+ st.warning(f"⚠️ Data Exploration only supports SQL queries on database .db files. Please select at least a database to continue.")
281
+ else:
282
+ # st.subheader("Data Exploration")
283
+ sql_explore = st.text_area(
284
+ "Enter SQL query to explore:",
285
+ value=st.session_state.get('explore_sql', ''),
286
+ height=100,
287
+ key='explore_sql'
288
+ )
289
+ if st.button("Run Query", key="explore_run"):
290
+ try:
291
+ df_explore = pd.read_sql_query(sql_explore, sqlite3.connect(DB_PATH))
292
+ st.session_state['explore_result'] = df_explore
293
+ # Record exploration history
294
+ if 'explore_history' not in st.session_state:
295
+ st.session_state['explore_history'] = []
296
+ # User query
297
+ st.session_state['explore_history'].append({
298
+ 'source': 'explore', 'role': 'user', 'content': sql_explore, 'timestamp': datetime.now()
299
+ })
300
+ # Assistant result as CSV
301
+ res_str = df_explore.to_csv(index=False)
302
+ st.session_state['explore_history'].append({
303
+ 'source': 'explore', 'role': 'assistant', 'content': res_str, 'timestamp': datetime.now()
304
+ })
305
+ except Exception as e:
306
+ st.error(f"Error: {e}")
307
+ # Wrap tabs in scrollable container
+ with st.container(height=300):
+ tabs = st.tabs(["Results", "History"])
+ # Results tab: show explore_result only
+ with tabs[0]:
+ if 'explore_result' in st.session_state:
+ st.table(st.session_state['explore_result'])
+ else:
+ st.write("No results yet.")
+ # History tab: Query history
+ with tabs[1]:
+ # Build paired history entries
+ combined = []
+ # Exploration history pairs
+ explore_hist = st.session_state.get('explore_history', [])
+ for i in range(0, len(explore_hist), 2):
+ u = explore_hist[i]
+ a = explore_hist[i+1] if i+1 < len(explore_hist) else {}
+ combined.append({
+ 'source': db_name,
+ 'query_type': 'sql',
+ 'query': u.get('content'),
+ 'result': a.get('content'),
+ 'timestamp': u.get('timestamp')
+ })
+ # Chat history pairs for all sources
+ for source, chat_hist in st.session_state.get('chat_histories', {}).items():
+ for idx in range(len(chat_hist)):
+ if chat_hist[idx].get('role') == 'user':
+ q = chat_hist[idx].get('content')
+ r = chat_hist[idx+1].get('content') if idx+1 < len(chat_hist) else None
+ combined.append({
+ 'source': source,
+ 'query_type': 'chat',
+ 'query': q,
+ 'result': r,
+ 'timestamp': chat_hist[idx].get('timestamp')
+ })
+ if combined:
+ df_history = pd.DataFrame(combined)
+ # Ensure timestamp column is datetime
+ if not pd.api.types.is_datetime64_any_dtype(df_history['timestamp']):
+ df_history['timestamp'] = pd.to_datetime(df_history['timestamp'])
+ # Sort latest first
+ df_history = df_history.sort_values('timestamp', ascending=False)
+ st.table(df_history)
+ else:
+ st.write("No history yet.")
+
+ with right_col:
+ 
+ # Use selected_sources from left data selector
+ data_sources = st.session_state.get('selected_sources', [])
+ csv_files = st.session_state.get('uploaded_csvs', [])
+ selected_source = data_sources[0] if data_sources else None
+ 
+ # Chat history per source (only if a source is selected)
+ if 'chat_histories' not in st.session_state:
+ st.session_state['chat_histories'] = {}
+ # Initialize past conversations container
+ if 'all_conversations' not in st.session_state:
+ st.session_state['all_conversations'] = {}
+ 
+ # Only proceed with chat if a data source is selected
+ if selected_source is not None:
+ if selected_source not in st.session_state['chat_histories']:
+ st.session_state['chat_histories'][selected_source] = []
+ if selected_source not in st.session_state['all_conversations']:
+ st.session_state['all_conversations'][selected_source] = []
+ chat_history = st.session_state['chat_histories'][selected_source]
+ 
+ # Only show chat interface if a data source is selected
+ if selected_source is not None:
+ container = st.container(height=700, border=True)
+ # Align New Conversation button top-right
+ with container:
+ cols = st.columns([2, 1])
+ with cols[0]:
+ st.markdown("**Ask TalkToData**")
+ if cols[1].button("New Chat", key=f"new_conv_{selected_source}"):
+ if chat_history:
+ conv = chat_history.copy()
+ ts = conv[0].get('timestamp', datetime.now())
+ st.session_state['all_conversations'][selected_source].append({'messages': conv, 'timestamp': ts})
+ st.session_state['chat_histories'][selected_source] = []
+ st.rerun()
+ 
+ # Display chat messages
+ chat_history = st.session_state['chat_histories'][selected_source]
+ # Welcome message for new chat
+ if not chat_history:
+ container.chat_message("assistant").write("👋 Hello! Welcome to TalkToData. Ask any question about your data to get started.")
+ for turn in chat_history:
+ role = turn.get('role', '')
+ content = turn.get('content', '')
+ if role == 'user':
+ container.chat_message("user").write(content)
+ else:
+ container.chat_message("assistant").write(content)
+ 
+ # Chat input
+ user_input = st.chat_input(f"Ask a question about {selected_source}...")
+ else:
+ # Placeholder to maintain layout
+ st.container(height=700, border=True)
+ user_input = None
+ if user_input:
+ chat_history.append({"role": "user", "content": user_input, "timestamp": datetime.now()})
+ with container.chat_message("user"):
+ st.write(user_input)
+ # Answer logic
+ with container.chat_message("assistant"):
+ with st.spinner("Thinking..."):
+ if selected_source == db_name:
+ # Handle /sql and /plot commands
+ if user_input.strip().lower().startswith('/sql'):
+ sql = user_input[len('/sql'):].strip()
+ try:
+ conn = sqlite3.connect(DB_PATH)
+ df = pd.read_sql_query(sql, conn)
+ conn.close()
+ st.write(f"```sql\n{sql}\n```")
+ st.table(df)
+ chat_history.append({"role": "assistant", "content": f"```sql\n{sql}\n```", "timestamp": datetime.now()})
+ except Exception as e:
+ err = f"SQL Error: {e}"
+ st.error(err)
+ chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
+ elif user_input.strip().lower().startswith('/plot'):
+ sql = user_input[len('/plot'):].strip()
+ try:
+ tool = PlotSQLTool()
+ md = tool._run(sql)
+ st.markdown(md)
+ # Extract the image path from the returned markdown
+ m = re.search(r'!\[.*\]\((.*?)\)', md)
+ if m:
+ st.image(m.group(1))
+ chat_history.append({"role": "assistant", "content": md, "timestamp": datetime.now()})
+ except Exception as e:
+ err = f"Plot Error: {e}"
+ st.error(err)
+ chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
+ else:
+ # Use SQL agent as before
+ state['question'] = user_input
+ try:
+ for step in agent.graph.stream(state, stream_mode="updates"):
+ step_name, step_details = next(iter(step.items()))
+ if step_name == 'generate_sql':
+ with st.expander("SQL Generated", expanded=False):
+ st.markdown(f"```sql\n{step_details.get('sql_query', '')}\n```")
+ elif step_name == 'execute_sql':
+ with st.expander("SQL Result", expanded=False):
+ st.table(step_details.get('sql_result', pd.DataFrame()))
+ elif step_name == 'generate_answer':
+ st.write(step_details.get('answer', ''))
+ chat_history.append({"role": "assistant", "content": step_details.get('answer', ''), "timestamp": datetime.now()})
+ elif step_name == 'render_visualization':
+ st.image(step_details.get('visualization_output', ''))
+ except Exception as e:
+ err = f"SQL Agent Error: {e}"
+ st.error(err)
+ chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
+ else:
+ # Use DataFrame agent for selected CSV
+ csv_file = next((csv for csv in csv_files if csv['name'] == selected_source), None)
+ if csv_file:
+ if 'csv_agents' not in st.session_state:
+ st.session_state['csv_agents'] = {}
+ if selected_source not in st.session_state['csv_agents']:
+ st.session_state['csv_agents'][selected_source] = get_dataframe_agent(csv_file['df'])
+ agent = st.session_state['csv_agents'][selected_source]
+ try:
+ response = agent.invoke(user_input)
+ answer = response["output"] if isinstance(response, dict) and "output" in response else str(response)
+ except Exception as e:
+ answer = f"CSV Agent Error: {e}"
+ st.write(answer)
+ chat_history.append({"role": "assistant", "content": answer, "timestamp": datetime.now()})
+ 
+ # Past Conversations Panel
+ with st.container(height=200):
+ st.markdown("**Recent Conversations**")
+ # Flatten and sort conversations by most recent first
+ entries = []
+ for source, convs in st.session_state.get('all_conversations', {}).items():
+ for conv in convs:
+ entries.append((source, conv))
+ entries = sorted(entries, key=lambda x: x[1]['timestamp'], reverse=True)
+ for source, conv in entries:
+ label = conv['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
+ with st.expander(f"{source} - {label}", expanded=False):
+ for msg in conv['messages']:
+ if msg.get('role') == 'user':
+ st.chat_message('user').write(msg.get('content'))
+ else:
+ st.chat_message('assistant').write(msg.get('content'))
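
The /sql and /plot commands and the agent fallback above all append to the same per-source chat history, and the History tab collapses those alternating user/assistant records into one row per exchange. That pairing step is easy to pull out into a pure function and unit-test outside Streamlit; a minimal sketch follows (the pair_history name and the standalone module are illustrative, not part of this commit):

from datetime import datetime

def pair_history(history: list[dict], source: str, query_type: str) -> list[dict]:
    """Collapse an alternating [user, assistant, user, assistant, ...] message
    list into one row per exchange, mirroring the History tab's combined view."""
    rows = []
    for i in range(0, len(history), 2):
        user_msg = history[i]
        # A trailing user message with no reply yet pairs with an empty dict
        assistant_msg = history[i + 1] if i + 1 < len(history) else {}
        rows.append({
            'source': source,
            'query_type': query_type,
            'query': user_msg.get('content'),
            'result': assistant_msg.get('content'),
            'timestamp': user_msg.get('timestamp'),
        })
    return rows

# Example: one completed exchange plus one still-pending query
hist = [
    {'role': 'user', 'content': 'SELECT COUNT(*) FROM product', 'timestamp': datetime(2024, 1, 1)},
    {'role': 'assistant', 'content': 'count\n10\n', 'timestamp': datetime(2024, 1, 1)},
    {'role': 'user', 'content': 'SELECT * FROM category', 'timestamp': datetime(2024, 1, 2)},
]
rows = pair_history(hist, source='sample_ecommerce.db', query_type='sql')
assert rows[0]['result'] == 'count\n10\n'
assert rows[1]['result'] is None  # no assistant reply recorded yet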
db/csv/category.csv ADDED
@@ -0,0 +1,6 @@
+ id,name
+ 1,Laptop
+ 2,Tablet
+ 3,Smartphone
+ 4,Accessory
+ 5,Wearable
db/csv/laptop.csv ADDED
@@ -0,0 +1,6 @@
+ id,product_id,ram,storage,processor
+ 1,1,8,256,Intel i5
+ 2,2,16,512,Intel i7
+ 3,7,32,1024,Intel i9
+ 4,1,4,128,Intel i3
+ 5,7,64,2048,AMD Ryzen 9
db/csv/product.csv ADDED
@@ -0,0 +1,11 @@
+ id,name,price,promotion_id,category_id
+ 1,Basic Laptop,799.99,1,1
+ 2,High-end Laptop,1999.99,2,1
+ 3,Standard Tablet,499.99,1,2
+ 4,Pro Tablet,999.99,3,2
+ 5,Smartphone Model A,699.99,2,3
+ 6,Smartphone Model B,899.99,3,3
+ 7,Ultra Laptop,2499.99,4,1
+ 8,Mini Tablet,299.99,4,2
+ 9,Smartphone Model C,799.99,5,3
+ 10,Smartphone Model D,999.99,6,3
db/csv/promotion.csv ADDED
@@ -0,0 +1,7 @@
+ id,description,discount
+ 1,Spring Sale,0.1
+ 2,Black Friday,0.25
+ 3,Clearance,0.5
+ 4,Summer Sale,0.15
+ 5,Cyber Monday,0.20
+ 6,Holiday Sale,0.30
db/csv/smartphone.csv ADDED
@@ -0,0 +1,11 @@
+ id,product_id,camera_megapixels,os,battery
+ 1,5,12.0,Android,3000
+ 2,6,48.0,iOS,3500
+ 3,5,16.0,Android,3200
+ 4,6,20.0,iOS,3100
+ 5,5,64.0,Android,3400
+ 6,6,108.0,iOS,3800
+ 7,5,12.0,Android,2900
+ 8,6,48.0,Android,3300
+ 9,5,32.0,Android,3500
+ 10,6,24.0,iOS,3000
db/csv/tablet.csv ADDED
@@ -0,0 +1,5 @@
+ id,product_id,screen_size,battery,support_sim
+ 1,3,10.1,6000,0
+ 2,4,12.9,8000,1
+ 3,8,8.0,4200,1
+ 4,4,13.3,10000,0
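
The seed CSVs form a small relational schema: product carries category_id and promotion_id foreign keys, while laptop, tablet, and smartphone hold per-category attributes keyed by product_id. Assuming sample_ecommerce.db mirrors these tables (an assumption; the database is committed as an opaque binary), a query like the following is the kind of thing the Data Exploration box can run:

import sqlite3
import pandas as pd

# Join products to their category and promotion, and compute the sale price
conn = sqlite3.connect('db/sample_ecommerce.db')
df = pd.read_sql_query("""
    SELECT p.name,
           p.price,
           c.name AS category,
           pr.description AS promotion,
           ROUND(p.price * (1 - pr.discount), 2) AS discounted_price
    FROM product p
    JOIN category c ON c.id = p.category_id
    JOIN promotion pr ON pr.id = p.promotion_id
    ORDER BY discounted_price DESC
""", conn)
conn.close()
print(df.head())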
db/sample_ecommerce.db ADDED
Binary file (32.8 kB)
pytest.ini ADDED
@@ -0,0 +1,8 @@
+ [pytest]
+ # Only run tests in the pytest_tests directory
+ testpaths = tests/pytest_tests
+ python_files = test_*.py
+ log_cli = true
+ log_cli_level = INFO
+ log_cli_format = %(asctime)s [%(levelname)s] %(message)s
+ log_cli_date_format = %H:%M:%S
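
With testpaths pinned to tests/pytest_tests and python_files to test_*.py, pytest only collects files matching that pattern under that directory, and log_cli streams INFO-level log records live during the run. A hypothetical test file that this configuration would pick up (the file name and assertions are illustrative, not part of the commit):

# tests/pytest_tests/test_consts.py
import os
from utils.consts import DB_PATH, PLOTS_DIR

def test_paths_exist():
    # consts.py creates both directories on import, so they should be present
    assert os.path.isdir(os.path.dirname(DB_PATH))
    assert os.path.isdir(PLOTS_DIR)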
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ streamlit==1.36.0
+ pandas==2.2.0
+ numpy==1.26.4
+ langchain==0.3.25
+ langchain_core==0.3.58
+ langchain-google-genai==2.0.4
+ langgraph==0.3.31
+ python-dotenv==1.0.1
+ sqlalchemy==2.0.2
+ guardrails-ai==0.6.6
+ openpyxl==3.1.5
+ pydantic==2.9.2
utils/consts.py ADDED
@@ -0,0 +1,12 @@
+ import os
+ from os.path import abspath, dirname, join
+ 
+ # Root of the project
+ ROOT_DIR = abspath(join(dirname(__file__), '..'))
+ # Database directory and path
+ DB_DIR = join(ROOT_DIR, 'db')
+ os.makedirs(DB_DIR, exist_ok=True)
+ DB_PATH = join(DB_DIR, 'sample_ecommerce.db')
+ # Output directory for plots
+ PLOTS_DIR = join(ROOT_DIR, 'output', 'plots')
+ os.makedirs(PLOTS_DIR, exist_ok=True)
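
Because consts.py calls os.makedirs at import time, importing it both resolves absolute paths and guarantees that db/ and output/plots/ exist, so callers never create directories themselves. A hedged usage sketch (the chart file name is illustrative):

from os.path import join
from utils.consts import DB_PATH, PLOTS_DIR

# Both directories exist as a side effect of the import above
print(DB_PATH)  # <project root>/db/sample_ecommerce.db
chart_path = join(PLOTS_DIR, 'price_by_category.png')  # hypothetical output file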