|
|
---
base_model: unsloth/gemma-2-2b-it-bnb-4bit
tags:
- text-generation-inference
- transformers
- unsloth
- gemma2
- text-to-sql
- qlora
- sql-generation
license: apache-2.0
language:
- en
datasets:
- gretelai/synthetic_text_to_sql
pipeline_tag: text-generation
---
|
|
|
|
|
# Gemma-2-2B Text-to-SQL QLoRA Fine-tuned Model |
|
|
|
|
|
- **Developed by:** rajaykumar12959 |
|
|
- **License:** apache-2.0 |
|
|
- **Finetuned from model:** unsloth/gemma-2-2b-it-bnb-4bit |
|
|
- **Dataset:** gretelai/synthetic_text_to_sql |
|
|
- **Task:** Text-to-SQL Generation |
|
|
- **Fine-tuning Method:** QLoRA (Quantized Low-Rank Adaptation) |
|
|
|
|
|
This gemma2 model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
|
|
|
|
|
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth) |
|
|
|
|
|
## Model Description |
|
|
|
|
|
This model is fine-tuned to generate SQL queries from natural language questions and database schemas. It is designed to handle multi-table queries that require JOINs, aggregations, filtering, and other advanced SQL operations.
|
|
|
|
|
### Key Features |
|
|
|
|
|
- ✅ **Multi-table JOINs** (INNER, LEFT, RIGHT) |
|
|
- ✅ **Aggregation functions** (SUM, COUNT, AVG, MIN, MAX) |
|
|
- ✅ **GROUP BY and HAVING clauses** |
|
|
- ✅ **Complex WHERE conditions** |
|
|
- ✅ **Subqueries and CTEs** |
|
|
- ✅ **Date/time operations** |
|
|
- ✅ **String functions and pattern matching** |
|
|
|
|
|
## Training Configuration |
|
|
|
|
|
The model was fine-tuned using QLoRA with the following configuration: |
|
|
|
|
|
```python
# LoRA Configuration
r = 16                               # Rank: 16 is a good balance for 2B models
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_alpha = 16
lora_dropout = 0
bias = "none"
use_gradient_checkpointing = "unsloth"

# Training Parameters
max_seq_length = 2048
per_device_train_batch_size = 2
gradient_accumulation_steps = 4      # Effective batch size = 8
warmup_steps = 5
max_steps = 100                      # Demo configuration - increase to 300+ for production
learning_rate = 2e-4
optim = "adamw_8bit"                 # 8-bit optimizer for memory efficiency
weight_decay = 0.01
lr_scheduler_type = "linear"
```
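
For context, these values are typically passed to Unsloth's `FastLanguageModel.get_peft_model` and to TRL's `SFTTrainer`. The sketch below shows the general wiring under that assumption; exact argument names can differ between TRL versions, and `random_state`, `logging_steps`, and `output_dir` are illustrative choices rather than values from the original notebook. Dataset preparation is covered in the Training Process section below.

```python
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048

# Load the 4-bit base model and tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-2-2b-it-bnb-4bit",
    max_seq_length = max_seq_length,
    load_in_4bit = True,
)

# Attach LoRA adapters using the configuration above.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,  # illustrative seed, not from the original notebook
)

# Supervised fine-tuning on the formatted "text" column.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,       # dataset with a "text" column (see Training Process)
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 100,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        logging_steps = 10,
        output_dir = "outputs",    # illustrative output path
    ),
)
trainer.train()
```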
|
|
|
|
|
## Installation |
|
|
|
|
|
```bash |
|
|
pip install unsloth transformers torch trl datasets |
|
|
``` |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Loading the Model |
|
|
|
|
|
```python
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "rajaykumar12959/gemma-2-2b-text-to-sql-qlora",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

FastLanguageModel.for_inference(model)  # Enable faster inference
```
|
|
|
|
|
### Inference Function |
|
|
|
|
|
```python |
|
|
def inference_text_to_sql(model, tokenizer, schema, question, max_new_tokens=300): |
|
|
""" |
|
|
Perform inference to generate SQL from natural language question and database schema. |
|
|
|
|
|
Args: |
|
|
model: Fine-tuned Gemma model |
|
|
tokenizer: Model tokenizer |
|
|
schema: Database schema as string |
|
|
question: Natural language question |
|
|
max_new_tokens: Maximum tokens to generate |
|
|
|
|
|
Returns: |
|
|
Generated SQL query as string |
|
|
""" |
|
|
# Format the input prompt |
|
|
input_prompt = f"""<start_of_turn>user |
|
|
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. |
|
|
|
|
|
### Schema: |
|
|
{schema} |
|
|
|
|
|
### Question: |
|
|
{question}<end_of_turn> |
|
|
<start_of_turn>model |
|
|
""" |
|
|
|
|
|
# Tokenize input |
|
|
inputs = tokenizer([input_prompt], return_tensors="pt").to("cuda") |
|
|
|
|
|
# Generate output |
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=max_new_tokens, |
|
|
use_cache=True, |
|
|
do_sample=True, |
|
|
temperature=0.1, # Low temperature for more deterministic output |
|
|
top_p=0.9, |
|
|
pad_token_id=tokenizer.eos_token_id |
|
|
) |
|
|
|
|
|
# Decode and clean the result |
|
|
result = tokenizer.batch_decode(outputs)[0] |
|
|
sql_query = result.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip() |
|
|
|
|
|
return sql_query |
|
|
``` |
|
|
|
|
|
### Example Usage |
|
|
|
|
|
#### Example 1: Simple Single-Table Query |
|
|
|
|
|
```python
# Simple employee database
simple_schema = """
CREATE TABLE employees (
    employee_id INT PRIMARY KEY,
    name TEXT,
    department TEXT,
    salary DECIMAL,
    hire_date DATE
);
"""

simple_question = "Find all employees in the 'Engineering' department with salary greater than 75000"

sql_result = inference_text_to_sql(model, tokenizer, simple_schema, simple_question)
print(f"Generated SQL:\n{sql_result}")
```
|
|
|
|
|
**Expected Output:** |
|
|
```sql
SELECT * FROM employees
WHERE department = 'Engineering'
AND salary > 75000;
```
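
#### Example 2: JOIN with Aggregation

Since the card highlights multi-table JOINs and aggregation, a second example shows the same call on a two-table schema. The schema, question, and the commented reference query are illustrative constructions, not taken from the training data or from actual model output.

```python
# Two related tables: customers and their orders
join_schema = """
CREATE TABLE customers (
    customer_id INT PRIMARY KEY,
    name TEXT,
    country TEXT
);
CREATE TABLE orders (
    order_id INT PRIMARY KEY,
    customer_id INT,
    amount DECIMAL,
    order_date DATE,
    FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
);
"""

join_question = "Show the total order amount per customer in 'Germany', highest first"

sql_result = inference_text_to_sql(model, tokenizer, join_schema, join_question)
print(f"Generated SQL:\n{sql_result}")

# A reasonable target query (illustrative, not guaranteed model output):
# SELECT c.name, SUM(o.amount) AS total_amount
# FROM customers c
# JOIN orders o ON c.customer_id = o.customer_id
# WHERE c.country = 'Germany'
# GROUP BY c.name
# ORDER BY total_amount DESC;
```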
|
|
|
|
|
## Training Details |
|
|
|
|
|
### Dataset |
|
|
- **Source:** gretelai/synthetic_text_to_sql |
|
|
- **Size:** 100,000 synthetic text-to-SQL examples |
|
|
- **Columns used:** |
|
|
- `sql_context`: Database schema |
|
|
- `sql_prompt`: Natural language question |
|
|
- `sql`: Target SQL query |
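
The dataset can be pulled directly from the Hugging Face Hub. A minimal sketch using the `datasets` library (the `train` split name follows the dataset's standard layout and is an assumption here):

```python
from datasets import load_dataset

# Load the synthetic text-to-SQL pairs used for fine-tuning.
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")
print(dataset.column_names)  # includes sql_context, sql_prompt, sql
```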
|
|
|
|
|
### Training Process |
|
|
The model uses a custom formatting function to structure the training data: |
|
|
|
|
|
```python
def formatting_prompts_func(examples):
    schemas = examples["sql_context"]
    questions = examples["sql_prompt"]
    outputs = examples["sql"]

    texts = []
    for schema, question, output in zip(schemas, questions, outputs):
        text = gemma_prompt.format(schema, question, output) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}
```
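
`gemma_prompt` and `EOS_TOKEN` are not shown in the excerpt above. A definition consistent with the inference prompt used later in this card might look like the following; treat it as an assumption about the notebook rather than its exact contents:

```python
# Assumed template, mirroring the inference prompt shown in the Usage section;
# the third placeholder holds the target SQL so the model learns to emit it.
gemma_prompt = """<start_of_turn>user
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

### Schema:
{}

### Question:
{}<end_of_turn>
<start_of_turn>model
{}"""

EOS_TOKEN = tokenizer.eos_token  # appended so generation stops after the SQL

# Apply the formatter to produce the "text" column consumed by the trainer.
dataset = dataset.map(formatting_prompts_func, batched=True)
```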
|
|
|
|
|
### Hardware Requirements |
|
|
- **GPU:** Single GPU with 8GB+ VRAM |
|
|
- **Training Time:** ~30 minutes for 100 steps |
|
|
- **Memory Optimization:** 4-bit quantization + 8-bit optimizer |
|
|
|
|
|
## Performance Characteristics |
|
|
|
|
|
### Strengths |
|
|
- Excellent performance on multi-table JOINs |
|
|
- Accurate aggregation and GROUP BY operations |
|
|
- Proper handling of foreign key relationships |
|
|
- Good understanding of filtering logic (WHERE/HAVING) |
|
|
|
|
|
### Model Capabilities Test |
|
|
The model was tested on a complex 4-table JOIN query requiring the following; a sketch of such a test appears after this list:
|
|
1. **Multi-table JOINs** (users → orders → order_items → products) |
|
|
2. **Category filtering** (WHERE p.category = 'Electronics') |
|
|
3. **User grouping** (GROUP BY user fields) |
|
|
4. **Aggregation** (SUM of price × quantity) |
|
|
5. **Aggregate filtering** (HAVING total > 500) |
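
The sketch below reconstructs a test of that shape. The schema, question, and the commented reference query are illustrative reconstructions based on the requirements listed above, not the exact prompts or outputs used in the evaluation.

```python
test_schema = """
CREATE TABLE users (user_id INT PRIMARY KEY, name TEXT, email TEXT);
CREATE TABLE orders (order_id INT PRIMARY KEY, user_id INT, order_date DATE);
CREATE TABLE order_items (item_id INT PRIMARY KEY, order_id INT, product_id INT, quantity INT);
CREATE TABLE products (product_id INT PRIMARY KEY, name TEXT, category TEXT, price DECIMAL);
"""

test_question = (
    "For each user, find the total amount spent on 'Electronics' products, "
    "but only include users whose total is greater than 500"
)

sql_result = inference_text_to_sql(model, tokenizer, test_schema, test_question)
print(sql_result)

# A reference query of the required shape (illustrative):
# SELECT u.user_id, u.name, SUM(p.price * oi.quantity) AS total
# FROM users u
# JOIN orders o ON u.user_id = o.user_id
# JOIN order_items oi ON o.order_id = oi.order_id
# JOIN products p ON oi.product_id = p.product_id
# WHERE p.category = 'Electronics'
# GROUP BY u.user_id, u.name
# HAVING SUM(p.price * oi.quantity) > 500;
```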
|
|
|
|
|
## Limitations |
|
|
|
|
|
- **Training Scale:** Trained with only 100 steps for demonstration. For production use, increase `max_steps` to 300+ |
|
|
- **Context Length:** Limited to 2048 tokens maximum sequence length |
|
|
- **SQL Dialects:** Primarily trained on standard SQL syntax |
|
|
- **Complex Subqueries:** May require additional fine-tuning for highly complex nested queries |
|
|
|
|
|
## Reproduction |
|
|
|
|
|
To reproduce this training: |
|
|
|
|
|
1. **Clone the notebook:** Use the provided `Fine_tune_qlora.ipynb` |
|
|
2. **Install dependencies:** |
|
|
```bash |
|
|
pip install unsloth transformers torch trl datasets |
|
|
``` |
|
|
3. **Configure training:** Adjust `max_steps` in TrainingArguments for longer training |
|
|
4. **Run training:** Execute all cells in the notebook |
|
|
|
|
|
### Production Training Recommendations |
|
|
```python
# For production use, update these parameters:
max_steps = 300                     # Increase from 100
warmup_steps = 10                   # Increase warmup
per_device_train_batch_size = 4     # If you have more GPU memory
```
|
|
|
|
|
## Model Card |
|
|
|
|
|
| Parameter | Value |
|-----------|-------|
| Base Model | Gemma-2-2B (4-bit quantized) |
| Fine-tuning Method | QLoRA |
| LoRA Rank | 16 |
| Training Steps | 100 (demo) |
| Learning Rate | 2e-4 |
| Batch Size | 8 (effective) |
| Max Sequence Length | 2048 |
| Dataset Size | 100k examples |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex
@misc{gemma-2-2b-text-to-sql-qlora,
  author       = {rajaykumar12959},
  title        = {Gemma-2-2B Text-to-SQL QLoRA Fine-tuned Model},
  year         = {2024},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/rajaykumar12959/gemma-2-2b-text-to-sql-qlora}},
}
```
|
|
|
|
|
## Acknowledgments |
|
|
|
|
|
- **Base Model:** Google's Gemma-2-2B via Unsloth optimization |
|
|
- **Dataset:** Gretel AI's synthetic text-to-SQL dataset |
|
|
- **Framework:** Unsloth for efficient fine-tuning and TRL for training |
|
|
- **Method:** QLoRA for parameter-efficient training |
|
|
|
|
|
## License |
|
|
|
|
|
This model is licensed under Apache 2.0. See the LICENSE file for details. |
|
|
|
|
|
--- |
|
|
|
|
|
*This model is intended for research and educational purposes. Please ensure compliance with your organization's data and AI usage policies when using in production environments.* |