Spaces:
Sleeping
Nam Fam committed
Commit · 472e1d4
Parent(s): 0d52457
add files
Browse files
- .dockerignore +28 -0
- .gitignore +67 -0
- Dockerfile +29 -0
- agents/__init__.py +7 -0
- agents/dataframe_agent.py +39 -0
- agents/llms.py +54 -0
- agents/memories.py +0 -0
- agents/safe_guardrails.py +195 -0
- agents/sql_agent/__init__.py +14 -0
- agents/sql_agent/agent.py +98 -0
- agents/sql_agent/graph.py +103 -0
- agents/sql_agent/nodes.py +496 -0
- agents/sql_agent/prompts.py +0 -0
- agents/sql_agent/states.py +30 -0
- agents/tools.py +63 -0
- app.py +508 -0
- db/csv/category.csv +6 -0
- db/csv/laptop.csv +6 -0
- db/csv/product.csv +11 -0
- db/csv/promotion.csv +7 -0
- db/csv/smartphone.csv +11 -0
- db/csv/tablet.csv +5 -0
- db/sample_ecommerce.db +0 -0
- pytest.ini +8 -0
- requirements.txt +12 -0
- utils/consts.py +12 -0
.dockerignore
ADDED
@@ -0,0 +1,28 @@
+# Exclude Python cache & virtual envs
+__pycache__/
+*.py[cod]
+*.pyo
+*.pyd
+.env
+
+# Exclude Git files
+.git/
+.gitignore
+
+# IDE/editor files
+.vscode/
+.idea/
+
+# OS files
+.DS_Store
+Thumbs.db
+
+# Project-specific
+notebooks/
+scripts/
+tests/
+eval/
+mcp_server/
+mcp_client.py
+utils/google_api_manager.py
+output/
.gitignore
ADDED
@@ -0,0 +1,67 @@
+# Python
+__pycache__/
+*.py[cod]
+*.egg-info/
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+venv/
+ENV/
+env/
+env.bak/
+venv.bak/
+pip-log.txt
+pip-delete-this-directory.txt
+
+# IDEs
+.vscode/
+.idea/
+*.swp
+*.swo
+*.sublime-workspace
+*.sublime-project
+
+# Environment files
+.env
+.env.local
+.env.development
+.env.test
+.env.production
+
+# Local development
+# *.db
+*.sqlite
+data/
+logs/
+*.log
+# *.csv
+*.parquet
+# output/
+
+# Streamlit
+.streamlit/credentials.toml
+
+# /notebooks
+# /tests
+
+
+*conftest.py
+*test_plotting.py
+plots/
+utils/google_api_manager.py
+mcp_client.py
+mcp_server/
+tests/
+scripts/
+notebooks/
+output/
Dockerfile
ADDED
@@ -0,0 +1,29 @@
+# syntax=docker/dockerfile:1
+
+FROM python:3.10-slim
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements and install Python dependencies
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Expose Streamlit default port
+EXPOSE 8501
+
+# # Streamlit configuration
+# ENV STREAMLIT_SERVER_PORT=8501 \
+#     STREAMLIT_SERVER_ADDRESS=0.0.0.0
+
+# # Launch the app
+# CMD ["streamlit", "run", "app.py"]
+
+ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
agents/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .sql_agent.agent import SQLAgent
+from .sql_agent.states import SQLAgentState
+from .sql_agent.nodes import get_db_info, generate_sql, execute_sql, optional_plot, format_response, generate_answer
+from .tools import PlotSQLTool
+from .llms import LLM
+
+__all__ = ['SQLAgent', 'SQLAgentState', 'get_db_info', 'generate_sql', 'execute_sql', 'optional_plot', 'format_response', 'generate_answer', 'PlotSQLTool', 'LLM']
agents/dataframe_agent.py
ADDED
@@ -0,0 +1,39 @@
+import pandas as pd
+from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
+from agents.llms import LLM
+from langchain.agents.agent_types import AgentType
+
+
+def get_dataframe_agent(
+    df: pd.DataFrame,
+    verbose: bool = True,
+    allow_dangerous_code: bool = True,
+    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION
+):
+    """
+    Create a pandas DataFrame agent using the custom LLM.
+    Args:
+        df (pd.DataFrame): The pandas DataFrame to use.
+        verbose (bool): Whether to enable verbose output. Default is True.
+        allow_dangerous_code (bool): Whether to allow dangerous code execution. Default is True.
+        agent_type: The agent type to use. Default is ZERO_SHOT_REACT_DESCRIPTION.
+    Returns:
+        agent: The created DataFrame agent.
+    """
+    llm = LLM().chat_model
+    agent = create_pandas_dataframe_agent(
+        llm,
+        df,
+        agent_type=agent_type,
+        verbose=verbose,
+        allow_dangerous_code=allow_dangerous_code
+    )
+    return agent
+
+# Usage example:
+# import pandas as pd
+# from agents.dataframe_agent import get_dataframe_agent
+# df = pd.read_csv('your_file.csv')
+# agent = get_dataframe_agent(df)
+# response = agent.invoke('Your question here')
+# print(response)
agents/llms.py
ADDED
@@ -0,0 +1,54 @@
+from langchain.chat_models import init_chat_model
+from langchain_core.messages import HumanMessage
+from dotenv import load_dotenv
+from typing import List
+from langchain.tools import BaseTool
+from langchain.agents import initialize_agent, AgentType
+
+_ = load_dotenv()
+
+class LLM:
+    def __init__(
+        self,
+        model: str = "gemini-2.0-flash",
+        model_provider: str = "google_genai",
+        temperature: float = 0.0,
+        max_tokens: int = 1000
+    ):
+        self.chat_model = init_chat_model(
+            model=model,
+            model_provider=model_provider,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+    def generate(self, prompt: str) -> str:
+        message = HumanMessage(content=prompt)
+        response = self.chat_model.invoke([message])
+        return response.content
+
+    def bind_tools(self, tools: List[BaseTool], agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION):
+        """
+        Bind LangChain tools to this model and return an AgentExecutor.
+        """
+        return initialize_agent(
+            tools,
+            self.chat_model,
+            agent=agent_type,
+            verbose=False
+        )
+
+    def set_temperature(self, temperature: float):
+        """
+        Set the temperature for the chat model.
+        """
+        self.chat_model.temperature = temperature
+
+    def set_max_tokens(self, max_tokens: int):
+        """
+        Set the maximum number of tokens for the chat model.
+        """
+        self.chat_model.max_tokens = max_tokens
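
A minimal usage sketch of the `LLM` wrapper above, assuming a valid `GOOGLE_API_KEY` is present in `.env` for the `google_genai` provider (the wrapper calls `load_dotenv()` at import time):

```python
# Sketch only: exercises the LLM wrapper defined in agents/llms.py.
# Assumes GOOGLE_API_KEY is set in .env; model defaults are taken from the class.
from agents.llms import LLM

llm = LLM(temperature=0.2, max_tokens=256)
print(llm.generate("List three SQL aggregate functions."))

# The setters mutate the underlying chat model in place:
llm.set_temperature(0.0)
llm.set_max_tokens(500)
```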
agents/memories.py
ADDED
File without changes
agents/safe_guardrails.py
ADDED
@@ -0,0 +1,195 @@
+from typing import Dict
+from guardrails.validators import Validator, register_validator
+import sys
+import os
+import logging
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from agents.llms import LLM
+
+def setup_logger(name):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    handler = logging.StreamHandler()
+    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+    logger.addHandler(handler)
+    return logger
+
+
+# @register_validator(name="medical_topic", data_type="string")
+# class MedicalTopicValidator(Validator):
+#     """Validates medical topics using project's LLM"""
+
+#     def __init__(self, threshold: float = 0.7, on_fail=None):
+#         super().__init__(on_fail=on_fail)
+#         self.threshold = threshold
+#         self.llm = LLM()
+#         self.logger = setup_logger('MedicalValidator')  # Use project's logging
+
+#     def validate(self, value: str, metadata: Dict = {}) -> str:
+#         prompt = f"""Medical relevance analysis (0-1 score then YES/NO):
+#         Question: {value}
+#         Medical context: {metadata.get('context', 'general health')}
+#         Score then Answer:"""
+
+#         try:
+#             response = self.llm.generate(prompt)
+#             last_line = response.strip().split('\n')[-1]
+#             parts = last_line.upper().strip().split()
+
+#             if len(parts) != 2:
+#                 raise ValueError(f"Malformed LLM response: '{last_line}'")
+
+#             score_str, answer = parts
+
+#             # Check whether the score part is not a number
+#             if not score_str.replace('.', '', 1).isdigit() and '%' not in score_str:
+#                 raise ValueError(f"Invalid score format: '{score_str}'")
+
+#             # Convert score
+#             score = float(score_str.strip('%')) / 100 if '%' in score_str else float(score_str)
+
+#             if score < self.threshold or "YES" not in answer:
+#                 self.logger.warning(f"Rejected question: {value[:60]}")
+#                 raise ValueError("Non-medical question detected")
+
+#             return value
+#         except Exception as e:
+#             self.logger.error(f"Validation failed: {str(e)}")
+#             raise
+
+
+from typing import Callable, Dict, Optional
+from guardrails.validators import (
+    FailResult,
+    PassResult,
+    register_validator,
+    ValidationResult,
+    Validator,
+)
+from agents.llms import LLM
+
+TOXICITY_CHECK_PROMPT_TEMPLATE = """
+You are a helpful assistant.
+Respond with a number between 0 and 100 indicating the level of toxicity for the following text:
+{text}
+0 means no toxic language.
+100 means highly toxic language.
+"""
+
+@register_validator(name="toxic-language", data_type="string")
+class ToxicLanguage(Validator):
+    def __init__(self, threshold: int = 70, on_fail: Optional[Callable] = None):
+        super().__init__(on_fail=on_fail, threshold=threshold)
+        self._threshold = threshold
+        self.llm = LLM()
+        self.fix_value = "Sorry, I can't assist you with that request."
+
+    def _validate(self, value: str, metadata: Dict) -> ValidationResult:
+        prompt = TOXICITY_CHECK_PROMPT_TEMPLATE.format(text=value)
+        score = int(self.llm.generate(prompt).strip())
+        if score > self._threshold:
+            return FailResult(
+                error_message=f"Validation failed. Score {score} exceeds threshold of {self._threshold}.",
+                fix_value=self.fix_value,
+            )
+        else:
+            return PassResult()
+
+OFF_TOPIC_CHECK_PROMPT_TEMPLATE = """
+You are a helpful assistant.
+Respond with a number between 0 and 100 indicating how off-topic the following text is. Consider the context provided:
+Topic: '{topic}'
+Additional Context: '{additional_context}'
+Text: {text}
+Do not output prose.
+0 means very relevant to the topic.
+100 means completely off-topic.
+Please note that common greetings should not be considered off-topic.
+"""
+
+@register_validator(name="off-topic", data_type="string")
+class OffTopicValidator(Validator):
+    def __init__(self, threshold: int = 70, on_fail: Optional[Callable] = None):
+        super().__init__(on_fail=on_fail, threshold=threshold)
+        self._threshold = threshold
+        self.llm = LLM()
+
+    def _validate(self, value: str, metadata: Dict) -> ValidationResult:
+        topic = metadata.get('topic', 'general')
+        additional_context = metadata.get('additional_context', '')
+
+        if topic == 'general':
+            return PassResult()
+
+        # self.fix_value = f"Sorry, I can only assist you with questions related to the topic '{topic}'."
+        self.fix_value = "OFF_TOPIC"
+
+        prompt = OFF_TOPIC_CHECK_PROMPT_TEMPLATE.format(
+            text=value,
+            topic=topic,
+            additional_context=additional_context
+        )
+
+        score = int(self.llm.generate(prompt).strip())
+
+        print(f"Off-topic score: {score}")
+        if score > self._threshold:
+            return FailResult(
+                error_message=f"Validation failed. Score {score} exceeds threshold of {self._threshold}.",
+                fix_value=self.fix_value,
+            )
+        else:
+            return PassResult()
+
+
+if __name__ == "__main__":
+    # validator = OffTopicValidator()
+
+    # print("Validating:")
+    # result = validator.validate("What is the capital of France?", metadata={"topic": "Medical"})
+    # print("Validation result:", result)
+
+    from guardrails import Guard
+    guard = Guard().use(
+        # ToxicLanguage,
+        OffTopicValidator,
+        # on_fail=lambda value, fail_result: f"Sorry, I can't assist you with that request.",
+        # on_fail="exception"
+        on_fail="fix"
+    )
+
+    texts = [
+        "What is the capital of France?",
+        "I want to kill you.",
+        "You are a stupid dog",
+        "Triệu chứng của bệnh viêm dạ dày",  # Vietnamese test input: "Symptoms of gastritis"
+    ]
+
+    metadata = {'topic': 'Medical'}
+    for text in texts:
+        print(f"Validating: {text}")
+        try:
+            validation_result = guard.validate(text, metadata=metadata)
+
+            print("Validation passed")
+            print("Validation result:", validation_result)
+
+            # response = guard.to_runnable().invoke(text)
+            # print("Response:", response)
+
+        except Exception as e:
+            print(f"Validation failed: {e}")
+
+        print('-' * 20)
+
+
+# Example usage
+# python agents/safe_guardrails.py
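
Both validators call `int(self.llm.generate(prompt).strip())`, which raises `ValueError` whenever the model wraps the score in prose. A more defensive parse is sketched below; the `_extract_score` helper is hypothetical and not part of this commit:

```python
import re

def _extract_score(response: str, default: int = 0) -> int:
    """Pull the first integer out of an LLM reply, e.g. 'Score: 85.' -> 85."""
    match = re.search(r"\d+", response)
    return int(match.group()) if match else default

# Tolerates prose around the number instead of raising ValueError:
assert _extract_score("The toxicity score is 85 out of 100.") == 85
assert _extract_score("no number here", default=0) == 0
```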
agents/sql_agent/__init__.py
ADDED
@@ -0,0 +1,14 @@
+from .agent import SQLAgent
+from .states import SQLAgentState
+from .nodes import get_db_info, generate_sql, execute_sql, optional_plot, format_response, generate_answer
+
+__all__ = [
+    'SQLAgent',
+    'SQLAgentState',
+    'get_db_info',
+    'generate_sql',
+    'execute_sql',
+    'optional_plot',
+    'format_response',
+    'generate_answer'
+]
agents/sql_agent/agent.py
ADDED
@@ -0,0 +1,98 @@
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+from agents.llms import LLM
+from dotenv import load_dotenv
+from langchain_community.utilities import SQLDatabase
+from utils.consts import DB_PATH
+from agents.sql_agent.states import SQLAgentState
+
+# Load environment vars
+load_dotenv()
+
+# def get_sql_agent():
+#     """
+#     Initializes a LangChain SQLDatabaseChain for SQLite.
+#     """
+#     # Load SQLite DB
+#     db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
+#     # Patch run to strip Markdown fences and log
+#     orig_run = db.run
+#     def clean_run(query: str, **kwargs) -> str:
+#         lines = query.splitlines()
+#         if lines and lines[0].strip().startswith("```"):
+#             lines = lines[1:]
+#         if lines and lines[-1].strip().startswith("```"):
+#             lines = lines[:-1]
+#         cleaned = "\n".join(lines).strip()
+#         print(f"[SQLDatabaseChain] Running SQL: {cleaned}")
+#         return orig_run(cleaned, **kwargs)
+#     db.run = clean_run
+#     # Initialize LLM
+#     llm_wrapper = LLM()
+#     # Create SQLDatabaseChain
+#     chain = SQLDatabaseChain.from_llm(llm_wrapper.chat_model, db, verbose=True)
+#     return chain
+
+class SQLAgent:
+    def __init__(self):
+        self.db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
+        self.llm = LLM()
+        self.graph = self.build_graph()
+
+    def build_graph(self):
+        from agents.sql_agent.graph import build_graph
+        return build_graph().compile()
+
+    def run(self, state: SQLAgentState) -> SQLAgentState:
+        """
+        Run the SQL agent with the given query.
+        """
+        return self.graph.invoke(state)
+
+if __name__ == "__main__":
+    agent = SQLAgent()
+    state = {
+        "question": None,
+        "db_info": {
+            "tables": [],
+            "columns": {},
+            "schema": None
+        },
+        "sql_query": None,
+        "sql_result": None,
+        "error": None
+    }
+    while True:
+        question = input("Enter your query (or 'exit' to quit): ")
+        state['question'] = question
+        if not question or question.lower() in ('exit', 'quit'):
+            print("Goodbye!")
+            break
+        result = agent.run(state)
+        # print(result)
+
+        # answer = result['answer']
+        # print(answer)
+
+        for step in agent.graph.stream(state, stream_mode="updates"):
+            print(step)
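
The `__main__` demo above seeds the state with only six keys, while downstream nodes index directly into keys such as `off_topic`, `plot_path`, and `visualization`; a fully populated initial state avoids `KeyError` inside the graph. A sketch of a hypothetical helper (not in this commit):

```python
def make_initial_state(question: str) -> dict:
    """Hypothetical helper: build a state dict with every key the graph
    nodes read, so no node hits a missing key at runtime."""
    return {
        "question": question,
        "db_info": {"tables": [], "columns": {}, "schema": None},
        "sql_query": None,
        "sql_result": None,
        "answer": None,
        "error": None,
        "plot_path": None,
        "response_md": None,
        "step": None,
        "visualization": None,
        "visualization_reason": None,
        "formatted_data_for_visualization": None,
        "visualization_output": None,
        "off_topic": None,
    }

# agent = SQLAgent()
# result = agent.run(make_initial_state("top 3 cheapest products"))
# print(result["answer"])
```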
agents/sql_agent/graph.py
ADDED
@@ -0,0 +1,103 @@
+import sys
+import os
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+from agents.sql_agent.states import SQLAgentState
+from langgraph.graph import StateGraph, START, END
+from agents.sql_agent.nodes import (
+    get_db_info,
+    generate_sql,
+    execute_sql,
+    generate_answer,
+    detect_off_topic,
+    choose_visualization,
+    format_data_for_visualization,
+    render_visualization,
+    finalize_output
+)
+
+def build_graph(visualize: bool = True) -> StateGraph:
+    graph = StateGraph(SQLAgentState)
+
+    # Add nodes
+    graph.add_node("detect_off_topic", detect_off_topic)
+    graph.add_node("generate_sql", generate_sql)
+    graph.add_node("get_db_info", get_db_info)
+    graph.add_node("execute_sql", execute_sql)
+    graph.add_node("generate_answer", generate_answer)
+    graph.add_node("choose_visualization", choose_visualization)
+    graph.add_node("format_data_for_visualization", format_data_for_visualization)
+    graph.add_node("render_visualization", render_visualization)
+    graph.add_node("finalize_output", finalize_output)
+
+    # Add edges
+    graph.add_edge(START, "detect_off_topic")
+
+    graph.add_conditional_edges(
+        "detect_off_topic",
+        lambda state: state['error'],
+        path_map={
+            # True: "generate_answer",
+            True: "get_db_info",
+            False: "get_db_info"
+        }
+    )
+
+    graph.add_edge("get_db_info", "generate_sql")
+    graph.add_edge("generate_sql", "execute_sql")
+    graph.add_edge("execute_sql", "choose_visualization")
+    graph.add_edge("choose_visualization", "format_data_for_visualization")
+    graph.add_edge("format_data_for_visualization", "render_visualization")
+    graph.add_edge("render_visualization", "generate_answer")
+    graph.add_edge("generate_answer", "finalize_output")
+    graph.add_edge("finalize_output", END)
+    # graph.add_edge("execute_sql", "generate_answer")
+    # graph.add_edge("generate_answer", "choose_visualization")
+    # graph.add_edge("choose_visualization", END)
+
+    if visualize:
+        # TODO: Implement visualization
+        pass
+    return graph
+
+def visualize_graph(graph) -> None:
+    graph.visualize()
+
+if __name__ == "__main__":
+    state = {
+        "question": "top 3 sản phẩm có giá thấp nhất",  # "top 3 lowest-priced products"
+        "db_info": {
+            "tables": [],
+            "columns": {},
+            "schema": ""
+        },
+        "sql_query": "",
+        "sql_result": None,
+        "error": None,
+        "step": None,
+        "answer": None,
+        "plot_path": None,
+        "response_md": None,
+        "visualization": None,
+        "visualization_reason": None,
+        "formatted_data_for_visualization": None,
+        "visualization_output": None,
+        "off_topic": None
+    }
+
+    graph = build_graph().compile()
+    # visualize_graph(graph)
+
+    result = graph.invoke(state)
+    # print(result)
+
+    answer = result['answer']
+    print(answer)
+
+    for step in graph.stream(
+        state, stream_mode="updates"
+    ):
+        print("-" * 80)
+        # print(step['step'])
+        print(step)
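
Note that `visualize_graph` calls `graph.visualize()`, which compiled LangGraph graphs do not expose; recent langgraph versions render through `get_graph()` instead. A sketch, assuming a langgraph release with Mermaid support (`draw_mermaid_png()` typically uses the mermaid.ink service):

```python
# Sketch only: render the compiled graph to a PNG for inspection.
compiled = build_graph().compile()
png_bytes = compiled.get_graph().draw_mermaid_png()
with open("output/sql_agent_graph.png", "wb") as f:
    f.write(png_bytes)
```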
agents/sql_agent/nodes.py
ADDED
@@ -0,0 +1,496 @@
+import sqlite3
+import pandas as pd
+import re
+from agents.llms import LLM
+from agents.tools import PlotSQLTool
+from .states import SQLAgentState
+from utils.consts import DB_PATH
+
+
+def choose_visualization(state: SQLAgentState) -> SQLAgentState:
+    """Use LLM to suggest a suitable chart type for the SQL result."""
+    question = state['question']
+    sql_query = state['sql_query']
+    sql_result = state['sql_result']
+    # Convert sql_result DataFrame to markdown or string preview (or sample rows)
+    if sql_result is not None:
+        if hasattr(sql_result, 'head'):
+            preview = sql_result.head(5).to_markdown(index=False)
+        else:
+            preview = str(sql_result)
+    else:
+        preview = "No results"
+
+    prompt = f'''
+You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.
+
+Available chart types and their use cases:
+- Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2.
+- Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large.
+- Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous.
+- Pie Charts: Ideal for showing proportions or percentages within a whole.
+- Line Graphs: Best for showing trends and distributions over time. Best used when both x axis and y axis are continuous or time-based.
+
+Provide your response in the following format:
+Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
+Reason: [Brief explanation for your recommendation]
+
+User question: {question}
+SQL query: {sql_query}
+Query results: {preview}
+
+Recommend a visualization:
+'''
+    llm = LLM()
+    response = llm.generate(prompt)
+    lines = response.split('\n')
+    visualization = 'none'
+    reason = ''
+    for line in lines:
+        if line.lower().startswith('recommended visualization:'):
+            visualization = line.split(':', 1)[1].strip()
+        elif line.lower().startswith('reason:'):
+            reason = line.split(':', 1)[1].strip()
+    state['visualization'] = visualization
+    state['visualization_reason'] = reason
+    state['step'] = 'choose_visualization'
+    return state
+
+
+def format_data_for_visualization(state: SQLAgentState) -> SQLAgentState:
+    """
+    Format the data for the chosen visualization type.
+    Supports line, bar, scatter, and grouped bar; falls back to the LLM for other visualizations.
+    """
+    import json
+    import pandas as pd
+    llm = LLM()
+
+    visualization = state.get('visualization', 'none')
+    sql_result = state.get('sql_result')
+    question = state.get('question')
+    sql_query = state.get('sql_query')
+
+    # Convert DataFrame to list of lists for processing
+    if sql_result is not None and hasattr(sql_result, 'values'):
+        data = sql_result.values.tolist()
+        columns = list(sql_result.columns)
+    elif isinstance(sql_result, list):
+        data = sql_result
+        columns = []
+    else:
+        state['formatted_data_for_visualization'] = None
+        return state
+
+    def _format_line_data(data, question):
+        if len(data[0]) == 2:
+            x_values = [str(row[0]) for row in data]
+            y_values = [float(row[1]) for row in data]
+            prompt = f"""
+You are a data labeling expert. Given a question and some data, provide a concise and relevant label for the data series.
+Question: {question}
+Data (first few rows): {data[:2]}
+Provide a concise label for this y axis.
+"""
+            label = llm.generate(prompt).strip()
+            formatted_data = {
+                "xValues": x_values,
+                "yValues": [
+                    {
+                        "data": y_values,
+                        "label": label
+                    }
+                ]
+            }
+            return formatted_data
+        elif len(data[0]) == 3:
+            data_by_label = {}
+            x_values = []
+            labels = list(set(item2 for item1, item2, item3 in data if isinstance(item2, str) and not item2.replace(".", "").isdigit() and "/" not in item2))
+            if not labels:
+                labels = list(set(item1 for item1, item2, item3 in data if isinstance(item1, str) and not item1.replace(".", "").isdigit() and "/" not in item1))
+            for item1, item2, item3 in data:
+                if isinstance(item1, str) and not item1.replace(".", "").isdigit() and "/" not in item1:
+                    label, x, y = item1, item2, item3
+                else:
+                    x, label, y = item1, item2, item3
+                if str(x) not in x_values:
+                    x_values.append(str(x))
+                if label not in data_by_label:
+                    data_by_label[label] = []
+                data_by_label[label].append(float(y))
+                for other_label in labels:
+                    if other_label != label:
+                        if other_label not in data_by_label:
+                            data_by_label[other_label] = []
+                        data_by_label[other_label].append(None)
+            y_values = [
+                {
+                    "data": data,
+                    "label": label
+                }
+                for label, data in data_by_label.items()
+            ]
+            formatted_data = {
+                "xValues": x_values,
+                "yValues": y_values,
+                "yAxisLabel": ""
+            }
+            prompt = f"""
+You are a data labeling expert. Given a question and some data, provide a concise and relevant label for the y-axis.
+Question: {question}
+Data (first few rows): {data[:2]}
+Provide a concise label for the y-axis.
+"""
+            y_axis_label = llm.generate(prompt).strip()
+            formatted_data["yAxisLabel"] = y_axis_label
+            return formatted_data
+        return None
+
+    def _format_scatter_data(data):
+        formatted_data = {"series": []}
+        if len(data[0]) == 2:
+            formatted_data["series"].append({
+                "data": [
+                    {"x": float(x), "y": float(y), "id": i+1}
+                    for i, (x, y) in enumerate(data)
+                ],
+                "label": "Data Points"
+            })
+        elif len(data[0]) == 3:
+            entities = {}
+            for item1, item2, item3 in data:
+                if isinstance(item1, str) and not item1.replace(".", "").isdigit() and "/" not in item1:
+                    label, x, y = item1, item2, item3
+                else:
+                    x, label, y = item1, item2, item3
+                if label not in entities:
+                    entities[label] = []
+                entities[label].append({"x": float(x), "y": float(y), "id": len(entities[label])+1})
+            for label, d in entities.items():
+                formatted_data["series"].append({
+                    "data": d,
+                    "label": label
+                })
+        else:
+            raise ValueError("Unexpected data format in results")
+        return formatted_data
+
+    def _format_bar_data(data, question):
+        if len(data[0]) == 2:
+            labels = [str(row[0]) for row in data]
+            values = [float(row[1]) for row in data]
+            prompt = f"""
+You are a data labeling expert. Given a question and some data, provide a concise and relevant label for the data series.
+Question: {question}
+Data (first few rows): {data[:2]}
+Provide a concise label for this y axis.
+"""
+            label = llm.generate(prompt).strip()
+            y_values = [{"data": values, "label": label}]
+        elif len(data[0]) == 3:
+            categories = set(row[1] for row in data)
+            labels = list(categories)
+            entities = set(row[0] for row in data)
+            y_values = []
+            for entity in entities:
+                entity_data = [float(row[2]) for row in data if row[0] == entity]
+                y_values.append({"data": entity_data, "label": str(entity)})
+        else:
+            raise ValueError("Unexpected data format in results")
+        formatted_data = {
+            "labels": labels,
+            "values": y_values
+        }
+        return formatted_data
+
+    def _format_other_visualizations(visualization, question, sql_query, data):
+        # Fallback: use LLM to format data
+        prompt = f"""
+You are a Data expert who formats data according to the required needs. You are given the question asked by the user, its sql query, the result of the query and the format you need to format it in.
+For the given question: {question}\n\nSQL query: {sql_query}\n\nResult: {data}\n\nFormat this data for visualization type: {visualization}. Just give the json string. Do not format it.
+"""
+        response = llm.generate(prompt)
+        try:
+            formatted_data_for_visualization = json.loads(response)
+            return formatted_data_for_visualization
+        except json.JSONDecodeError:
+            return {"error": "Failed to format data for visualization", "raw_response": response}
+
+    # Every formatter takes (data, question) so the dispatch below can call
+    # them uniformly with two arguments.
+    visualization_map = {
+        "none": lambda data, question: None,
+        "scatter": lambda data, question: _format_scatter_data(data),
+        "bar": lambda data, question: _format_bar_data(data, question),
+        "horizontal_bar": lambda data, question: _format_bar_data(data, question),
+        "line": lambda data, question: _format_line_data(data, question)
+    }
+    try:
+        state["formatted_data_for_visualization"] = visualization_map[visualization](data, question)
+    except Exception:
+        state["formatted_data_for_visualization"] = _format_other_visualizations(visualization, question, sql_query, data)
+    state['step'] = 'format_data_for_visualization'
+    return state
+
+
+def render_visualization(state: SQLAgentState) -> SQLAgentState:
+    """
+    Render the visualization from formatted data.
+    Output: path to saved image file.
+    """
+    import matplotlib.pyplot as plt
+    import os
+    from io import BytesIO
+    import uuid
+
+    data = state.get("formatted_data_for_visualization")
+    visualization = state.get("visualization", "none")
+
+    if not data:
+        state["visualization_output"] = None
+        return state
+
+    output_dir = "output/plots"
+    os.makedirs(output_dir, exist_ok=True)
+
+    def save_fig(fig):
+        file_path = os.path.join(output_dir, f"visualization_{uuid.uuid4().hex[:8]}.png")
+        fig.savefig(file_path, format="png", bbox_inches="tight")
+        plt.close(fig)
+        return file_path
+
+    def render_line(data):
+        fig, ax = plt.subplots()
+        x = data["xValues"]
+        for series in data["yValues"]:
+            ax.plot(x, series["data"], label=series["label"])
+        ax.set_xlabel("X")
+        ax.set_ylabel(data.get("yAxisLabel", "Y"))
+        ax.legend()
+        return save_fig(fig)
+
+    def render_bar(data, horizontal=False):
+        fig, ax = plt.subplots()
+        labels = data["labels"]
+        n_series = len(data["values"])
+        width = 0.8 / n_series
+        x_indexes = list(range(len(labels)))
+        for i, series in enumerate(data["values"]):
+            offset = (i - n_series / 2) * width + width / 2
+            if horizontal:
+                ax.barh(
+                    [x + offset for x in x_indexes],
+                    series["data"],
+                    height=width,
+                    label=series["label"]
+                )
+                ax.set_yticks(x_indexes)
+                ax.set_yticklabels(labels)
+            else:
+                ax.bar(
+                    [x + offset for x in x_indexes],
+                    series["data"],
+                    width=width,
+                    label=series["label"]
+                )
+                ax.set_xticks(x_indexes)
+                ax.set_xticklabels(labels, rotation=45, ha='right')
+        ax.legend()
+        return save_fig(fig)
+
+    def render_scatter(data):
+        fig, ax = plt.subplots()
+        for series in data["series"]:
+            xs = [point["x"] for point in series["data"]]
+            ys = [point["y"] for point in series["data"]]
+            ax.scatter(xs, ys, label=series["label"])
+        ax.set_xlabel("X")
+        ax.set_ylabel("Y")
+        ax.legend()
+        return save_fig(fig)
+
+    try:
+        if visualization == "line":
+            image_path = render_line(data)
+        elif visualization == "bar":
+            image_path = render_bar(data, horizontal=False)
+        elif visualization == "horizontal_bar":
+            image_path = render_bar(data, horizontal=True)
+        elif visualization == "scatter":
+            image_path = render_scatter(data)
+        else:
+            state["visualization_output"] = None
+            return state
+
+        state["visualization_output"] = image_path
+    except Exception as e:
+        state["visualization_output"] = None
+        state["error"] = f"Failed to render visualization: {str(e)}"
+
+    state["step"] = "render_visualization"
+    return state
+
+
+def finalize_output(state: SQLAgentState) -> SQLAgentState:
+    """
+    Node that consolidates the final result (answer, visualization_output, error, ...).
+    Currently just returns the state; processing can be extended later.
+    """
+    state['step'] = 'finalize_output'
+    return state
+
+# def ingest(state: SQLAgentState) -> SQLAgentState:
+#     """Populate state.tables with list of tables in the DB."""
+#     db_info = state['db_info']
+#     conn = sqlite3.connect(DB_PATH)
+#     try:
+#         db_info['tables'] = [row[0] for row in conn.execute(
+#             "SELECT name FROM sqlite_master WHERE type='table';"
+#         )]
+#         # Populate columns for each table
+#         columns = {}
+#         for table in db_info['tables']:
+#             col_rows = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
+#             columns[table] = [r[1] for r in col_rows]
+#         db_info['columns'] = columns
+#         state.db_info = db_info
+#     finally:
+#         conn.close()
+#     return state
+
+
+from agents.safe_guardrails import OffTopicValidator
+from guardrails import Guard
+
+def detect_off_topic(state: SQLAgentState) -> SQLAgentState:
+    """Check if the input question is off-topic."""
+    question = state['question']
+    validator = Guard().use(
+        OffTopicValidator,
+        on_fail="fix"
+    )
+    metadata = {
+        "topic": "Database Queries",
+        "additional_context": "The database is about ecommerce products with tables: products, laptops, phones, tablets, promotions, category"
+    }
+
+    validation_result = validator.validate(question, metadata=metadata)
+    if validation_result.validated_output == "OFF_TOPIC":
+        state['error'] = True
+    else:
+        state['error'] = False
+    state['step'] = 'detect_off_topic'
+    state['off_topic'] = validation_result.validated_output
+
+    print(state)
+    return state
+
+
+def get_db_info(state: SQLAgentState) -> SQLAgentState:
+    """Get database information."""
+    db_info = state['db_info']
+    conn = sqlite3.connect(DB_PATH)
+    try:
+        db_info['tables'] = [row[0] for row in conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table';"
+        )]
+        # Populate columns for each table
+        columns = {}
+        for table in db_info['tables']:
+            col_rows = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
+            columns[table] = [r[1] for r in col_rows]
+        db_info['columns'] = columns
+        schema = "; ".join(f"{t}({', '.join(db_info['columns'][t])})" for t in db_info['tables'])
+        db_info['schema'] = schema
+    finally:
+        conn.close()
+    state['step'] = 'get_db_info'
+    return state
+
+
+def generate_sql(state: SQLAgentState) -> SQLAgentState:
+    """Use LLM to translate user_query into SQL."""
+    llm = LLM()
+    # Include detailed schema with columns
+    schema = state['db_info']['schema']
+    prompt = (
+        f"Given this database schema: {schema}, "
+        f"write an SQL query to: {state['question']}. "
+        "Respond with only the SQL enclosed in triple backticks."
+    )
+    raw = llm.generate(prompt)
+    # print('raw', raw)
+    lines = raw.splitlines()
+    if lines and lines[0].strip().startswith("```"):
+        lines = lines[1:]
+    if lines and lines[-1].strip().startswith("```"):
+        lines = lines[:-1]
+    state['sql_query'] = "\n".join(lines).strip()
+    state['step'] = 'generate_sql'
+    return state
+
+
+def execute_sql(state: SQLAgentState) -> SQLAgentState:
+    """Run the SQL in state.sql and store result DataFrame."""
+    sql_query = state['sql_query']
+    conn = sqlite3.connect(DB_PATH)
+    try:
+        state['sql_result'] = pd.read_sql_query(sql_query, conn)
+    except Exception as e:
+        state['error'] = str(e)
+    finally:
+        conn.close()
+    state['step'] = 'execute_sql'
+    return state
+
+def generate_answer(state: SQLAgentState) -> SQLAgentState:
+    """Generate answer using LLM based on SQL result."""
+    llm = LLM()
+    if state['sql_result'] is not None and not state['sql_result'].empty:
+        result_str = state['sql_result'].to_string(index=False)
+        prompt = (
+            f"Given the question: {state['question']},\n"
+            f"SQL Query: {state['sql_query']},\n"
+            f"and the following SQL query result: {result_str},\n"
+            "provide a concise answer:"
+        )
+        state['answer'] = llm.generate(prompt)
+    else:
+        state['error'] = state['error'] or "No results found."
+    if state["off_topic"] == "OFF_TOPIC":
+        state['error'] = "The question is off-topic."
+        state["answer"] = "Sorry, I can't assist you with that request."
+    state['step'] = 'generate_answer'
+    return state
+
+
+def optional_plot(state: SQLAgentState) -> SQLAgentState:
+    """If user_query requests plotting, generate plot and set state.plot_path."""
+    if any(k in state['question'].lower() for k in ['plot', 'vẽ', 'biểu đồ']):
+        tool = PlotSQLTool()
+        md = tool._run(state['sql_query'])
+        m = re.search(r'!\[.*\]\((.*?)\)', md)
+        if m:
+            state['plot_path'] = m.group(1)
+        else:
+            state['error'] = state['error'] or 'Plot generation failed'
+    return state
+
+
+def format_response(state: SQLAgentState) -> SQLAgentState:
+    """Build markdown response including SQL, table preview, and plot."""
+    parts = []
+    if state['sql_query']:
+        parts.append(f"```sql\n{state['sql_query']}\n```")
+    if state['sql_result'] is not None:
+        parts.append(state['sql_result'].to_markdown(index=False))
+    if state['plot_path']:
+        parts.append(f"![plot]({state['plot_path']})")
+    if state['error']:
+        parts.append(f"**Error**: {state['error']}")
+    state['response_md'] = "\n\n".join(parts)
+    return state
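
The render helpers in `render_visualization` consume the shapes produced by the formatters above. For reference, hand-built payloads (illustrative values only) that `render_line` and `render_bar` would accept:

```python
# Illustrative payloads matching the formatter output the renderers expect.
line_payload = {
    "xValues": ["2023", "2024", "2025"],
    "yValues": [{"data": [10.0, 12.5, 9.0], "label": "Revenue"}],
    "yAxisLabel": "Revenue (million VND)",
}

bar_payload = {
    "labels": ["laptop", "phone", "tablet"],
    "values": [{"data": [120.0, 340.0, 80.0], "label": "Units sold"}],
}
```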
agents/sql_agent/prompts.py
ADDED
File without changes
agents/sql_agent/states.py
ADDED
@@ -0,0 +1,30 @@
+from typing import Optional, TypedDict
+import pandas as pd
+
+
+class SQLAgentState(TypedDict, total=False):
+    """
+    Carries context through the text→SQL and plotting pipeline:
+    - question: original NL input
+    - sql_query: generated SQL
+    - sql_result: raw query result DataFrame
+    - answer: final answer
+    - error: any error messages
+    """
+    # TypedDict fields cannot declare default values, so total=False marks
+    # every key as optional instead; callers build the dict explicitly.
+    question: str
+    db_info: dict
+    sql_query: str
+    sql_result: Optional[pd.DataFrame]
+    answer: str
+    error: Optional[str]
+    plot_path: Optional[str]
+    response_md: str
+    step: Optional[str]
+    visualization: Optional[str]
+    visualization_reason: Optional[str]
+    formatted_data_for_visualization: Optional[dict]
+    visualization_output: Optional[str]
+    off_topic: Optional[str]
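
Since `TypedDict` cannot carry defaults, a stricter alternative (not in this commit; class names are hypothetical) is to split required inputs from optional pipeline outputs via inheritance:

```python
from typing import Optional, TypedDict
import pandas as pd

class _RequiredState(TypedDict):
    # Keys callers must always supply
    question: str
    db_info: dict
    sql_query: str

class StrictSQLAgentState(_RequiredState, total=False):
    # Keys the pipeline fills in as it runs
    sql_result: Optional[pd.DataFrame]
    answer: str
    error: Optional[str]
```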
agents/tools.py
ADDED
@@ -0,0 +1,63 @@
+import sqlite3
+import pandas as pd
+from langchain.tools import BaseTool
+import os
+import matplotlib.pyplot as plt
+from utils.consts import DB_PATH, PLOTS_DIR
+
+# Fetch table list
+_conn = sqlite3.connect(DB_PATH)
+_TABLES = [row[0] for row in _conn.execute("SELECT name FROM sqlite_master WHERE type='table';")]
+_conn.close()
+_TABLES_LIST = ", ".join(_TABLES)
+
+class SQLiteQueryTool(BaseTool):
+    name: str = "sqlite_query"
+    description: str = f"Executes a SQL query against the ecommerce SQLite database and returns results as CSV. Available tables: {_TABLES_LIST}."
+
+    def _run(self, query: str) -> str:
+        print(f"[SQLiteQueryTool] Executing query: {query}")
+        conn = sqlite3.connect(DB_PATH)
+        try:
+            df = pd.read_sql_query(query, conn)
+            return df.to_csv(index=False)
+        except Exception as e:
+            return f"SQL Error: {e}"
+        finally:
+            conn.close()
+
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("Async not supported for SQLiteQueryTool")
+
+
+class PlotSQLTool(BaseTool):
+    name: str = "plot_sql"
+    description: str = f"Executes a SQL query and generates a plot saved as a PNG; returns markdown image link. Available tables: {_TABLES_LIST}."
+
+    def _run(self, query: str) -> str:
+        print(f"[PlotSQLTool] Executing query: {query}")
+        conn = sqlite3.connect(DB_PATH)
+        try:
+            df = pd.read_sql_query(query, conn)
+            plt.figure()
+            df.plot(kind='bar' if df.shape[1] > 1 else 'line', legend=False)
+            timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
+            filename = f"plot_{timestamp}.png"
+            # Save plot to configured output directory
+            output_dir = PLOTS_DIR
+            os.makedirs(output_dir, exist_ok=True)
+            filepath = os.path.join(output_dir, filename)
+            plt.tight_layout()
+            plt.savefig(filepath)
+            plt.close()
+            return f"![plot]({filepath})"
+        except Exception as e:
+            return f"Plot Error: {e}"
+        finally:
+            conn.close()
+
+    async def _arun(self, query: str) -> str:
+        raise NotImplementedError("Async not supported for PlotSQLTool")
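
A quick usage sketch of the two tools outside any agent loop. Table and column names here are illustrative (the bundled database ships CSVs such as `product.csv` under `db/csv/`), so the exact query may need adjusting:

```python
# Sketch only: invoke the tools directly against the bundled SQLite DB.
from agents.tools import SQLiteQueryTool, PlotSQLTool

csv_text = SQLiteQueryTool()._run("SELECT * FROM product LIMIT 3")
print(csv_text)  # CSV string, or "SQL Error: ..." on failure

md_link = PlotSQLTool()._run("SELECT name, price FROM product LIMIT 5")
print(md_link)   # markdown image link pointing into PLOTS_DIR
```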
app.py
ADDED
|
@@ -0,0 +1,508 @@
import streamlit as st
import pandas as pd
from utils.consts import DB_PATH
import sqlite3
import re
import os
from agents.sql_agent.agent import SQLAgent
from agents.tools import PlotSQLTool
from agents.dataframe_agent import get_dataframe_agent
from datetime import datetime

db_name = os.path.basename(DB_PATH)

st.set_page_config(page_title="🔍 TalkToData", layout="wide", initial_sidebar_state="collapsed")

# The markdown title is omitted here to avoid rendering it twice
# Sidebar for settings
with st.sidebar:
    st.header("ℹ️ About", anchor=None)
    st.markdown("""
    **TalkToData** v0.1.0
    Your personal AI Data Analyst.
    """, unsafe_allow_html=True)

# Initialize chat history
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Initialize SQL agent
# agent = get_sql_agent()

agent = SQLAgent()
state = {
    "question": None,
    "db_info": {
        "tables": [],
        "columns": {},
        "schema": None
    },
    "sql_query": None,
    "sql_result": None,
    "error": None,
    "step": None,
    "answer": None
}
# --- Upload Screen State ---
if 'files_uploaded' not in st.session_state:
    st.session_state['files_uploaded'] = False

# TEMP: Bypass landing page
st.session_state['files_uploaded'] = True

if not st.session_state['files_uploaded']:
    # CSS to center and enlarge only the welcome start button
    st.markdown("""
    <style>
    .welcome .stButton { display: flex; justify-content: center; }
    .welcome .stButton button { font-size:2.5rem !important; padding:1.25rem 2rem !important; }
    </style>
    """, unsafe_allow_html=True)
    # Wrap welcome content to scope styling
    st.markdown("<div class='welcome' style='max-width:600px;margin:auto;text-align:center;'>", unsafe_allow_html=True)
    # Title and subtitle
    st.markdown("""
    <h1 style='text-align:center; margin-bottom:0;'>🔍 TalkToData</h1>
    <h3 style='text-align:center; color:gray;'>Your Personal AI Data Analyst that instantly answers your data questions with clear insights and elegant visualizations.</h3>
    """, unsafe_allow_html=True)
    # Standalone welcome start button
    if st.button("🚀 Explore now", key="start"):
        st.session_state['files_uploaded'] = True
        st.rerun()
    # Close welcome wrapper
    st.markdown("</div>", unsafe_allow_html=True)
    st.divider()
    # SaaS-style Features section
    st.markdown("## Features")
    feat_cols = st.columns(3)
    feat_cols[0].markdown("### 🗣 Natural-Language Queries\nAsk your data without SQL knowledge.")
    feat_cols[1].markdown("### 📊 Instant Visualizations\nGet charts from one command.")
    feat_cols[2].markdown("### 🔒 Secure & Local\nYour data stays on your machine.")
    st.divider()
    # How It Works section
    st.markdown("## How It Works")
    step_cols = st.columns(3)
    step_cols[0].markdown("#### 1️⃣ Upload\nUpload .db or CSV files.")
    step_cols[1].markdown("#### 2️⃣ Chat\nInteract in natural language.")
    step_cols[2].markdown("#### 3️⃣ Visualize\nSee results as tables or charts.")
    st.divider()
    # Use Cases
    st.markdown("## Use Cases")
    st.markdown("- \"Show me top 5 products by sales\" → Chart")
    st.markdown("- \"List customers from 2020\" → Table")
    st.divider()
    # Testimonials
    st.markdown("## Testimonials")
    testi_cols = st.columns(2)
    testi_cols[0].markdown("> \"TalkToData transformed our data workflow!\" \n— Jane Doe, Data Analyst")
    testi_cols[1].markdown("> \"The AI assistant is incredibly smart and fast.\" \n— John Smith, Product Manager")
    st.divider()
    # Footer
    st.markdown("© 2025 TalkToData. All rights reserved.")

    st.markdown("<p style='text-align: center; color: gray;'>TalkToData v0.1.0 - Copyright 2025 by <a href='https://github.com/phamdinhkhanh'>Khanh Pham</a></p>", unsafe_allow_html=True)
    st.html(
        "<p><span style='text-decoration: line-through double red;'>Oops</span>!</p>"
    )

    st.divider()

else:
    # App title and return button
    # st.title("🔍 TalkToData")
    st.markdown("### TalkToData")
    # TEMP: Commented out back-to-home
    # if st.button('⬅️ Back to Home', key='back_to_upload'):
    #     st.session_state['files_uploaded'] = False
    #     # Clear old data
    #     if 'uploaded_csvs' in st.session_state:
    #         st.session_state['uploaded_csvs'] = []
    #     st.rerun()
    # Layout: Data source selector, main content, and chat
    data_col, left_col, right_col = st.columns([1.5, 3, 2])
    # Data source selection
    with data_col:
        # st.subheader("Data Sources")
        # Upload data
        with st.expander("**Upload Data**", expanded=True):
            st.file_uploader('Select SQLite (.db), CSV or Excel (.xlsx) files',
                             type=['db', 'csv', 'xlsx'],
                             accept_multiple_files=True,
                             key='upload_any_col',
                             label_visibility="collapsed")
            gsheet_url = st.text_input('Enter Google Sheets URL (optional)', '', key='gsheet_url')
            upload_status = []
            has_db = False
            has_csv = False

            # Retrieve uploaded files list safely
            uploaded_files = st.session_state.get('upload_any_col', [])
            # Process Google Sheets if URL provided
            url = st.session_state.get('gsheet_url', '').strip()
            if url:
                try:
                    csv_url = url.replace('/edit#gid=', '/export?format=csv&gid=')
                    df_gs = pd.read_csv(csv_url)
                    if 'uploaded_csvs' not in st.session_state:
                        st.session_state['uploaded_csvs'] = []
                    st.session_state['uploaded_csvs'].append({'name': 'GoogleSheets', 'df': df_gs})
                    upload_status.append('✅ Google Sheets loaded')
                    has_csv = True
                except Exception as e:
                    upload_status.append(f'❌ Google Sheets error: {e}')

            # Process files
            for f in uploaded_files:
                if f.name.lower().endswith('.db'):
                    try:
                        with open(DB_PATH, "wb") as dbf:
                            dbf.write(f.read())
                        upload_status.append(f"✅ Database: {f.name}")
                        has_db = True
                    except Exception as e:
                        upload_status.append(f"❌ Database error: {e}")

                # Process CSV and Excel
                name = f.name.lower()
                if name.endswith('.csv') or name.endswith('.xlsx'):
                    try:
                        if name.endswith('.xlsx'):
                            # Process each sheet in Excel
                            f.seek(0)
                            xls = pd.ExcelFile(f)
                            sheets = st.multiselect(f"Select sheet(s) from {f.name}", xls.sheet_names, default=xls.sheet_names)
                            for sheet in sheets:
                                # Read raw to detect header rows
                                raw = xls.parse(sheet, header=None)
                                nn = raw.notnull().sum(axis=1)
                                hdr = [i for i, cnt in enumerate(nn) if cnt > 1]
                                if len(hdr) >= 2:
                                    header = hdr[:2]
                                elif len(hdr) == 1:
                                    header = [hdr[0]]
                                else:
                                    header = [0]
                                df_sheet = xls.parse(sheet, header=header)
                                # Flatten MultiIndex if needed
                                if isinstance(df_sheet.columns, pd.MultiIndex):
                                    df_sheet.columns = [" ".join([str(x) for x in col if pd.notna(x)]).strip() for col in df_sheet.columns]
                                # Store with sheet label
                                sheet_key = f"{f.name}:{sheet}"
                                if 'uploaded_csvs' not in st.session_state:
                                    st.session_state['uploaded_csvs'] = []
                                st.session_state['uploaded_csvs'].append({'name': sheet_key, 'df': df_sheet})
                                upload_status.append(f"✅ Excel: {sheet_key}")
                        else:
                            temp_df = pd.read_csv(f)

                            if 'uploaded_csvs' not in st.session_state:
                                st.session_state['uploaded_csvs'] = []

                            # Check existing and update
                            csv_exists = False
                            for i, csv in enumerate(st.session_state['uploaded_csvs']):
                                if csv['name'] == f.name:
                                    st.session_state['uploaded_csvs'][i]['df'] = temp_df
                                    csv_exists = True
                                    break
                            if not csv_exists:
                                st.session_state['uploaded_csvs'].append({'name': f.name, 'df': temp_df})
                            upload_status.append(f"✅ CSV/Excel: {f.name}")
                            has_csv = True
                    except Exception as e:
                        upload_status.append(f"❌ CSV/Excel error: {e}")

            # Display upload status
            if upload_status:
                for status in upload_status:
                    st.write(status)
        # After upload, select data sources
        ds = []
        if os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0:
            ds.append(db_name)
        if 'uploaded_csvs' in st.session_state:
            ds += [csv['name'] for csv in st.session_state['uploaded_csvs']]
        if ds:
            # Initialize selected_sources session state to default to db_name
            if 'selected_sources' not in st.session_state:
                st.session_state['selected_sources'] = [db_name] if db_name in ds else []
            selected_sources = st.multiselect(
                "**Select sources**", options=ds,
                key='selected_sources'
            )
        else:
            st.info("Upload a database or CSV/Excel file to select a data source.")

    with left_col:
        # Data Preview: filter sources by user selection
        selected = st.session_state.get('selected_sources', [])
        preview_db = os.path.exists(DB_PATH) and db_name in selected
        # Filter CSV/Excel previews
        preview_csvs = [csv for csv in st.session_state.get('uploaded_csvs', []) if csv['name'] in selected]
        if preview_db or preview_csvs:
            # Display previews
            with st.container(height=415):
                st.markdown("**Data Preview**")
                # Build tab labels
                tab_labels = []
                if preview_db:
                    tab_labels.append(db_name)
                for c in preview_csvs:
                    tab_labels.append(c['name'])
                tabs = st.tabs(tab_labels)
                idx = 0
                # Database preview
                if preview_db:
                    with tabs[idx]:
                        conn = sqlite3.connect(DB_PATH)
                        tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
                        if tables:
                            t_tabs = st.tabs([t[0] for t in tables])
                            for t, tab in zip(tables, t_tabs):
                                with tab:
                                    st.table(pd.read_sql_query(f"SELECT * FROM {t[0]}", conn))
                        else:
                            st.info("No tables found.")
                        conn.close()
                    idx += 1
                # CSV/Excel previews
                for c in preview_csvs:
                    with tabs[idx]:
                        st.table(c['df'])
                    idx += 1

        # --- Data Exploration Section (Always Visible) ---
        with st.container(height=225):
            # Data Exploration only supports the database source
            selected = st.session_state.get('selected_sources', [])
            if db_name not in selected:
                st.warning("⚠️ Data Exploration only supports SQL queries on database .db files. Please select at least a database to continue.")
            else:
                # st.subheader("Data Exploration")
                sql_explore = st.text_area(
                    "Enter SQL query to explore:",
                    height=100,
                    key='explore_sql'
                )
                if st.button("Run Query", key="explore_run"):
                    try:
                        df_explore = pd.read_sql_query(sql_explore, sqlite3.connect(DB_PATH))
                        st.session_state['explore_result'] = df_explore
                        # Record exploration history
                        if 'explore_history' not in st.session_state:
                            st.session_state['explore_history'] = []
                        # User query
                        st.session_state['explore_history'].append({
                            'source': 'explore', 'role': 'user', 'content': sql_explore, 'timestamp': datetime.now()
                        })
                        # Assistant result as CSV
                        res_str = df_explore.to_csv(index=False)
                        st.session_state['explore_history'].append({
                            'source': 'explore', 'role': 'assistant', 'content': res_str, 'timestamp': datetime.now()
                        })
                    except Exception as e:
                        st.error(f"Error: {e}")
        # Wrap tabs in a scrollable container
        with st.container(height=300):
            # st.markdown("<div style='height:300px; overflow:auto'>", unsafe_allow_html=True)
            tabs = st.tabs(["Results", "History"])
            # Results tab: show explore_result only
            with tabs[0]:
                if 'explore_result' in st.session_state:
                    # st.subheader("Results")
                    st.table(st.session_state['explore_result'])
                else:
                    st.write("No results yet.")
            # History tab: query history
            with tabs[1]:
                # st.subheader("History")
                # Build paired history entries
                combined = []
                # Exploration history pairs
                explore_hist = st.session_state.get('explore_history', [])
                for i in range(0, len(explore_hist), 2):
                    u = explore_hist[i] if i < len(explore_hist) else {}
                    a = explore_hist[i+1] if i+1 < len(explore_hist) else {}
                    combined.append({
                        'source': db_name,
                        'query_type': 'sql',
                        'query': u.get('content'),
                        'result': a.get('content'),
                        'timestamp': u.get('timestamp')
                    })
                # Chat history pairs for all sources
                for source, chat_hist in st.session_state.get('chat_histories', {}).items():
                    for idx in range(len(chat_hist)):
                        if chat_hist[idx].get('role') == 'user':
                            q = chat_hist[idx].get('content')
                            r = chat_hist[idx+1].get('content') if idx+1 < len(chat_hist) else None
                            combined.append({
                                'source': source,
                                'query_type': 'chat',
                                'query': q,
                                'result': r,
                                'timestamp': chat_hist[idx].get('timestamp')
                            })
                if combined:
                    df_history = pd.DataFrame(combined)
                    # Ensure the timestamp column is datetime
                    if not pd.api.types.is_datetime64_any_dtype(df_history['timestamp']):
                        df_history['timestamp'] = pd.to_datetime(df_history['timestamp'])
                    # Sort latest first
                    df_history = df_history.sort_values('timestamp', ascending=False)
                    st.table(df_history)
                else:
                    st.write("No history yet.")
            # st.markdown("</div>", unsafe_allow_html=True)  # closer for the commented-out wrapper above

    with right_col:

        # Use selected_sources from the left data selector
        data_sources = st.session_state.get('selected_sources', [])
        csv_files = st.session_state.get('uploaded_csvs', [])
        selected_source = data_sources[0] if data_sources else None

        # Chat history per source (only if a source is selected)
        if 'chat_histories' not in st.session_state:
            st.session_state['chat_histories'] = {}
        # Initialize past conversations container
        if 'all_conversations' not in st.session_state:
            st.session_state['all_conversations'] = {}

        # Only proceed with chat if a data source is selected
        if selected_source is not None:
            if selected_source not in st.session_state['chat_histories']:
                st.session_state['chat_histories'][selected_source] = []
            if selected_source not in st.session_state['all_conversations']:
                st.session_state['all_conversations'][selected_source] = []
            chat_history = st.session_state['chat_histories'][selected_source]

        # Only show the chat interface if a data source is selected
        if selected_source is not None:
            container = st.container(height=700, border=True)
            # Align the New Chat button top-right
            with container:
                cols = st.columns([2, 1])
                with cols[0]:
                    st.markdown("**Ask TalkToData**")
                if cols[1].button("New Chat", key=f"new_conv_{selected_source}"):
                    if chat_history:
                        conv = chat_history.copy()
                        ts = conv[0].get('timestamp', datetime.now())
                        st.session_state['all_conversations'][selected_source].append({'messages': conv, 'timestamp': ts})
                        st.session_state['chat_histories'][selected_source] = []
                        st.rerun()

            # Display chat messages
            chat_history = st.session_state['chat_histories'][selected_source]
            # Welcome message for a new chat
            if not chat_history:
                container.chat_message("assistant").write("👋 Hello! Welcome to TalkToData. Ask any question about your data to get started.")
            for turn in chat_history:
                role = turn.get('role', '')
                content = turn.get('content', '')
                if role == 'user':
                    container.chat_message("user").write(content)
                else:
                    container.chat_message("assistant").write(content)

            # Chat input
            user_input = st.chat_input(f"Ask a question about {selected_source}...")
        else:
            # Placeholder to maintain layout
            st.container(height=700, border=True)
            user_input = None
        if user_input:
            chat_history.append({"role": "user", "content": user_input, "timestamp": datetime.now()})
            with container.chat_message("user"):
                st.write(user_input)
            # Answer logic
            with container.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    if selected_source == db_name:
                        # Handle /sql and /plot commands
                        if user_input.strip().lower().startswith('/sql'):
                            sql = user_input[len('/sql'):].strip()
                            try:
                                df = pd.read_sql_query(sql, sqlite3.connect(DB_PATH))
                                st.write(f"```sql\n{sql}\n```")
                                st.table(df)
                                chat_history.append({"role": "assistant", "content": f"```sql\n{sql}\n```", "timestamp": datetime.now()})
                            except Exception as e:
                                err = f"SQL Error: {e}"
                                st.error(err)
                                chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
                        elif user_input.strip().lower().startswith('/plot'):
                            sql = user_input[len('/plot'):].strip()
                            try:
                                tool = PlotSQLTool()
                                md = tool._run(sql)
                                st.markdown(md)
                                m = re.search(r'!\[.*\]\((.*?)\)', md)
                                if m:
                                    st.image(m.group(1))
                                chat_history.append({"role": "assistant", "content": md, "timestamp": datetime.now()})
                            except Exception as e:
                                err = f"Plot Error: {e}"
                                st.error(err)
                                chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
                        else:
                            # Fall back to the SQL agent
                            state['question'] = user_input
                            try:
                                for step in agent.graph.stream(state, stream_mode="updates"):
                                    step_name, step_details = next(iter(step.items()))
                                    if step_name == 'generate_sql':
                                        with st.expander("SQL Generated", expanded=False):
                                            st.markdown(f"```sql\n{step_details.get('sql_query', '')}\n```")
                                    elif step_name == 'execute_sql':
                                        with st.expander("SQL Result", expanded=False):
                                            st.table(step_details.get('sql_result', pd.DataFrame()))
                                    elif step_name == 'generate_answer':
                                        st.write(step_details.get('answer', ''))
                                        chat_history.append({"role": "assistant", "content": step_details.get('answer', ''), "timestamp": datetime.now()})
                                    elif step_name == 'render_visualization':
                                        # with st.expander("Chart", expanded=False):
                                        st.image(step_details.get('visualization_output', ''))
                            except Exception as e:
                                err = f"SQL Agent Error: {e}"
                                st.error(err)
                                chat_history.append({"role": "assistant", "content": err, "timestamp": datetime.now()})
                    else:
                        # Use the DataFrame agent for the selected CSV
                        csv_file = next((csv for csv in csv_files if csv['name'] == selected_source), None)
                        if csv_file:
                            if 'csv_agents' not in st.session_state:
                                st.session_state['csv_agents'] = {}
                            if selected_source not in st.session_state['csv_agents']:
                                st.session_state['csv_agents'][selected_source] = get_dataframe_agent(csv_file['df'])
                            agent = st.session_state['csv_agents'][selected_source]
                            try:
                                response = agent.invoke(user_input)
                                answer = response["output"] if isinstance(response, dict) and "output" in response else str(response)
                            except Exception as e:
                                answer = f"CSV Agent Error: {e}"
                            st.write(answer)
                            chat_history.append({"role": "assistant", "content": answer, "timestamp": datetime.now()})
            # Refresh to update History immediately
            # st.rerun()

        # Past Conversations Panel
        with st.container(height=200):
            st.markdown("**Recent Conversations**")
            # Flatten and sort conversations, most recent first
            entries = []
            for source, convs in st.session_state.get('all_conversations', {}).items():
                for conv in convs:
                    entries.append((source, conv))
            entries = sorted(entries, key=lambda x: x[1]['timestamp'], reverse=True)
            for source, conv in entries:
                label = conv['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
                with st.expander(f"{source} - {label}", expanded=False):
                    for msg in conv['messages']:
                        if msg.get('role') == 'user':
                            st.chat_message('user').write(msg.get('content'))
                        else:
                            st.chat_message('assistant').write(msg.get('content'))
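The Google Sheets loader above leans on a bare string replace; a standalone sketch of that rewrite (gsheet_to_csv_url is a hypothetical helper name, not in the commit) makes the limitation explicit: it only matches share links that still contain '/edit#gid=', so links in other formats pass through unchanged:

def gsheet_to_csv_url(url: str) -> str:
    # Rewrites .../edit#gid=<gid> into the CSV export endpoint;
    # any other share-link format is returned as-is.
    return url.replace('/edit#gid=', '/export?format=csv&gid=')

url = "https://docs.google.com/spreadsheets/d/abc123/edit#gid=0"
print(gsheet_to_csv_url(url))
# https://docs.google.com/spreadsheets/d/abc123/export?format=csv&gid=0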
db/csv/category.csv
ADDED
|
@@ -0,0 +1,6 @@
id,name
1,Laptop
2,Tablet
3,Smartphone
4,Accessory
5,Wearable
db/csv/laptop.csv
ADDED
|
@@ -0,0 +1,6 @@
id,product_id,ram,storage,processor
1,1,8,256,Intel i5
2,2,16,512,Intel i7
3,7,32,1024,Intel i9
4,1,4,128,Intel i3
5,7,64,2048,AMD Ryzen 9
db/csv/product.csv
ADDED
|
@@ -0,0 +1,11 @@
id,name,price,promotion_id,category_id
1,Basic Laptop,799.99,1,1
2,High-end Laptop,1999.99,2,1
3,Standard Tablet,499.99,1,2
4,Pro Tablet,999.99,3,2
5,Smartphone Model A,699.99,2,3
6,Smartphone Model B,899.99,3,3
7,Ultra Laptop,2499.99,4,1
8,Mini Tablet,299.99,4,2
9,Smartphone Model C,799.99,5,3
10,Smartphone Model D,999.99,6,3
db/csv/promotion.csv
ADDED
|
@@ -0,0 +1,7 @@
id,description,discount
1,Spring Sale,0.1
2,Black Friday,0.25
3,Clearance,0.5
4,Summer Sale,0.15
5,Cyber Monday,0.20
6,Holiday Sale,0.30
db/csv/smartphone.csv
ADDED
|
@@ -0,0 +1,11 @@
id,product_id,camera_megapixels,os,battery
1,5,12.0,Android,3000
2,6,48.0,iOS,3500
3,5,16.0,Android,3200
4,6,20.0,iOS,3100
5,5,64.0,Android,3400
6,6,108.0,iOS,3800
7,5,12.0,Android,2900
8,6,48.0,Android,3300
9,5,32.0,Android,3500
10,6,24.0,iOS,3000
db/csv/tablet.csv
ADDED
|
@@ -0,0 +1,5 @@
id,product_id,screen_size,battery,support_sim
1,3,10.1,6000,0
2,4,12.9,8000,1
3,8,8.0,4200,1
4,4,13.3,10000,0
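These CSVs mirror the schema of sample_ecommerce.db: product.category_id and product.promotion_id reference category.id and promotion.id, while laptop, tablet, and smartphone extend product through product_id. A quick sanity-check join, assuming the .db tables match the CSVs above:

import sqlite3

from utils.consts import DB_PATH

conn = sqlite3.connect(DB_PATH)
rows = conn.execute("""
    SELECT p.name AS product, c.name AS category,
           ROUND(p.price * (1 - pr.discount), 2) AS discounted_price
    FROM product p
    JOIN category c ON c.id = p.category_id
    JOIN promotion pr ON pr.id = p.promotion_id
    ORDER BY discounted_price DESC
""").fetchall()
conn.close()
for product, category, discounted in rows:
    print(f"{product} ({category}): {discounted}")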
db/sample_ecommerce.db
ADDED
|
Binary file (32.8 kB)
pytest.ini
ADDED
|
@@ -0,0 +1,8 @@
[pytest]
# Only run tests in the pytest_tests directory
testpaths = tests/pytest_tests
python_files = test_*.py
log_cli = true
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)s] %(message)s
log_cli_date_format = %H:%M:%S
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
streamlit==1.36.0
pandas==2.2.0
numpy==1.26.4
langchain==0.3.25
langchain_core==0.3.58
langchain-google-genai==2.0.4
langgraph==0.3.31
python-dotenv==1.0.1
sqlalchemy==2.0.2
guardrails-ai==0.6.6
openpyxl==3.1.5
pydantic==2.9.2
utils/consts.py
ADDED
|
@@ -0,0 +1,12 @@
import os
from os.path import abspath, dirname, join

# Root of the project
ROOT_DIR = abspath(join(dirname(__file__), '..'))
# Database directory and path
DB_DIR = join(ROOT_DIR, 'db')
os.makedirs(DB_DIR, exist_ok=True)
DB_PATH = join(DB_DIR, 'sample_ecommerce.db')
# Output directory for plots
PLOTS_DIR = join(ROOT_DIR, 'output', 'plots')
os.makedirs(PLOTS_DIR, exist_ok=True)
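Note that importing utils.consts creates db/ and output/plots/ as a side effect, so downstream modules can rely on those directories existing; a quick check (the printed path varies with where the repo lives):

import os

from utils.consts import DB_DIR, DB_PATH, PLOTS_DIR

# Both directories were created at import time by os.makedirs(..., exist_ok=True).
assert os.path.isdir(DB_DIR) and os.path.isdir(PLOTS_DIR)
print(DB_PATH)  # e.g. <project root>/db/sample_ecommerce.db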