---
base_model: unsloth/gemma-2-2b-it-bnb-4bit
tags:
- text-generation-inference
- transformers
- unsloth
- gemma2
- text-to-sql
- qlora
- sql-generation
license: apache-2.0
language:
- en
datasets:
- gretelai/synthetic_text_to_sql
pipeline_tag: text-generation
---
# Gemma-2-2B Text-to-SQL QLoRA Fine-tuned Model
- **Developed by:** rajaykumar12959
- **License:** apache-2.0
- **Finetuned from model:** unsloth/gemma-2-2b-it-bnb-4bit
- **Dataset:** gretelai/synthetic_text_to_sql
- **Task:** Text-to-SQL Generation
- **Fine-tuning Method:** QLoRA (Quantized Low-Rank Adaptation)
This Gemma-2 model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
## Model Description
This model is specifically fine-tuned to generate SQL queries from natural language questions and database schemas. It excels at handling complex multi-table queries requiring JOINs, aggregations, filtering, and advanced SQL operations.
### Key Features
- ✅ **Multi-table JOINs** (INNER, LEFT, RIGHT)
- ✅ **Aggregation functions** (SUM, COUNT, AVG, MIN, MAX)
- ✅ **GROUP BY and HAVING clauses**
- ✅ **Complex WHERE conditions**
- ✅ **Subqueries and CTEs**
- ✅ **Date/time operations**
- ✅ **String functions and pattern matching**
## Training Configuration
The model was fine-tuned using QLoRA with the following configuration:
```python
# LoRA Configuration
r = 16 # Rank: 16 is a good balance for 2B models
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_alpha = 16
lora_dropout = 0
bias = "none"
use_gradient_checkpointing = "unsloth"
# Training Parameters
max_seq_length = 2048
per_device_train_batch_size = 2
gradient_accumulation_steps = 4 # Effective batch size = 8
warmup_steps = 5
max_steps = 100 # Demo configuration - increase to 300+ for production
learning_rate = 2e-4
optim = "adamw_8bit" # 8-bit optimizer for memory efficiency
weight_decay = 0.01
lr_scheduler_type = "linear"
```
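The assignments above map onto Unsloth's `FastLanguageModel.get_peft_model` call that attaches the LoRA adapters before training. A minimal sketch of that wiring, assuming the standard Unsloth API (the `random_state` value is illustrative, not taken from the original notebook):
```python
from unsloth import FastLanguageModel

# Load the 4-bit base model (same loader used in the Usage section below).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-2b-it-bnb-4bit",
    max_seq_length=2048,
    dtype=None,          # auto-detect float16 / bfloat16
    load_in_4bit=True,
)

# Attach LoRA adapters with the configuration listed above.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,   # illustrative seed, not from the original run
)
```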
## Installation
```bash
pip install unsloth transformers torch trl datasets
```
## Usage
### Loading the Model
```python
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048
dtype = None
load_in_4bit = True
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "rajaykumar12959/gemma-2-2b-text-to-sql-qlora",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model) # Enable faster inference
```
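The Unsloth loader above is the intended path. If the repository contains LoRA adapter weights rather than merged weights (an assumption; check the repository file list), the adapter can also be attached with plain `transformers` + `peft`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "unsloth/gemma-2-2b-it-bnb-4bit"                   # pre-quantized 4-bit base
adapter_id = "rajaykumar12959/gemma-2-2b-text-to-sql-qlora"  # this repository

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(base_id, device_map="auto")
model = PeftModel.from_pretrained(base, adapter_id)  # attach the LoRA adapter
model.eval()
```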
### Inference Function
```python
def inference_text_to_sql(model, tokenizer, schema, question, max_new_tokens=300):
    """
    Generate SQL from a natural language question and a database schema.

    Args:
        model: Fine-tuned Gemma model
        tokenizer: Model tokenizer
        schema: Database schema as a string
        question: Natural language question
        max_new_tokens: Maximum tokens to generate

    Returns:
        Generated SQL query as a string
    """
    # Format the input prompt
    input_prompt = f"""<start_of_turn>user
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
### Schema:
{schema}
### Question:
{question}<end_of_turn>
<start_of_turn>model
"""
    # Tokenize the input and move it to the GPU (CUDA is required here)
    inputs = tokenizer([input_prompt], return_tensors="pt").to("cuda")

    # Generate output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=0.1,  # Low temperature for more deterministic output
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode, keep only the model's turn, and strip end-of-turn/EOS markers
    result = tokenizer.batch_decode(outputs)[0]
    sql_query = (
        result.split("<start_of_turn>model")[-1]
        .replace("<end_of_turn>", "")
        .replace("<eos>", "")
        .strip()
    )
    return sql_query
```
### Example Usage
#### Example 1: Simple Single-Table Query
```python
# Simple employee database
simple_schema = """
CREATE TABLE employees (
    employee_id INT PRIMARY KEY,
    name TEXT,
    department TEXT,
    salary DECIMAL,
    hire_date DATE
);
"""
simple_question = "Find all employees in the 'Engineering' department with salary greater than 75000"
sql_result = inference_text_to_sql(model, tokenizer, simple_schema, simple_question)
print(f"Generated SQL:\n{sql_result}")
```
**Expected Output:**
```sql
SELECT * FROM employees
WHERE department = 'Engineering'
AND salary > 75000;
```
## Training Details
### Dataset
- **Source:** gretelai/synthetic_text_to_sql
- **Size:** 100,000 synthetic text-to-SQL examples
- **Columns used** (a loading sketch follows this list):
- `sql_context`: Database schema
- `sql_prompt`: Natural language question
- `sql`: Target SQL query
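A quick way to pull these fields and verify the column names with the `datasets` library (the `split` choice and `print` calls are just illustrative):
```python
from datasets import load_dataset

# Load the training split of the synthetic text-to-SQL dataset.
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")

# Inspect the three fields used for fine-tuning.
example = dataset[0]
print(example["sql_context"])  # database schema (CREATE TABLE statements)
print(example["sql_prompt"])   # natural language question
print(example["sql"])          # target SQL query
```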
### Training Process
The model uses a custom formatting function to structure the training data:
```python
def formatting_prompts_func(examples):
    schemas = examples["sql_context"]
    questions = examples["sql_prompt"]
    outputs = examples["sql"]
    texts = []
    for schema, question, output in zip(schemas, questions, outputs):
        # gemma_prompt and EOS_TOKEN are defined earlier in the training notebook
        text = gemma_prompt.format(schema, question, output) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}
```
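The function references a `gemma_prompt` template and an `EOS_TOKEN` that are defined earlier in the notebook but not shown in this card. A plausible definition, consistent with the inference prompt in the Usage section (an assumption, not the notebook's exact code), plus the usual `dataset.map` call that applies the formatter:
```python
# Assumed template mirroring the inference prompt above; the three placeholders
# are filled with the schema, the question, and the target SQL.
gemma_prompt = """<start_of_turn>user
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
### Schema:
{}
### Question:
{}<end_of_turn>
<start_of_turn>model
{}<end_of_turn>"""

EOS_TOKEN = tokenizer.eos_token  # appended so the model learns when to stop

# Apply the formatter to every example in batches.
dataset = dataset.map(formatting_prompts_func, batched=True)
```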
### Hardware Requirements
- **GPU:** Single GPU with 8GB+ VRAM
- **Training Time:** ~30 minutes for 100 steps
- **Memory Optimization:** 4-bit quantization + 8-bit optimizer
## Performance Characteristics
### Strengths
- Excellent performance on multi-table JOINs
- Accurate aggregation and GROUP BY operations
- Proper handling of foreign key relationships
- Good understanding of filtering logic (WHERE/HAVING)
### Model Capabilities Test
The model was tested on a complex 4-table JOIN query requiring the following (a sketch of such a test appears after the list):
1. **Multi-table JOINs** (users → orders → order_items → products)
2. **Category filtering** (WHERE p.category = 'Electronics')
3. **User grouping** (GROUP BY user fields)
4. **Aggregation** (SUM of price × quantity)
5. **Aggregate filtering** (HAVING total > 500)
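A sketch of how such a test can be run with the `inference_text_to_sql` helper above; the schema and question below are illustrative assumptions, not the exact test case:
```python
# Illustrative 4-table e-commerce schema (assumed, not the original test schema).
complex_schema = """
CREATE TABLE users (user_id INT PRIMARY KEY, name TEXT, email TEXT);
CREATE TABLE orders (order_id INT PRIMARY KEY, user_id INT, order_date DATE);
CREATE TABLE order_items (item_id INT PRIMARY KEY, order_id INT, product_id INT, quantity INT);
CREATE TABLE products (product_id INT PRIMARY KEY, name TEXT, category TEXT, price DECIMAL);
"""

complex_question = (
    "For each user, show the total amount spent on 'Electronics' products, "
    "but only include users whose total spend is greater than 500."
)

sql_result = inference_text_to_sql(model, tokenizer, complex_schema, complex_question)
print(f"Generated SQL:\n{sql_result}")
```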
## Limitations
- **Training Scale:** Trained with only 100 steps for demonstration. For production use, increase `max_steps` to 300+
- **Context Length:** Limited to 2048 tokens maximum sequence length
- **SQL Dialects:** Primarily trained on standard SQL syntax
- **Complex Subqueries:** May require additional fine-tuning for highly complex nested queries
## Reproduction
To reproduce this training:
1. **Clone the notebook:** Use the provided `Fine_tune_qlora.ipynb`
2. **Install dependencies:**
```bash
pip install unsloth transformers torch trl datasets
```
3. **Configure training:** Adjust `max_steps` in TrainingArguments for longer training
4. **Run training:** Execute all cells in the notebook
### Production Training Recommendations
```python
# For production use, update these parameters:
max_steps = 300                     # Increase from 100
warmup_steps = 10                   # Increase warmup
per_device_train_batch_size = 4     # If you have more GPU memory
```
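These values are keyword arguments to the Hugging Face `TrainingArguments` passed into TRL's `SFTTrainer`. A hedged sketch of where they plug in, following the common Unsloth + TRL pattern (`logging_steps`, `seed`, and `output_dir` are illustrative; recent TRL versions move `dataset_text_field` and `max_seq_length` into `trl.SFTConfig`):
```python
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,                      # PEFT model from FastLanguageModel.get_peft_model
    tokenizer=tokenizer,
    train_dataset=dataset,            # dataset formatted with formatting_prompts_func
    dataset_text_field="text",
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=4,   # increased from 2 if GPU memory allows
        gradient_accumulation_steps=4,
        warmup_steps=10,                 # increased from 5
        max_steps=300,                   # increased from 100
        learning_rate=2e-4,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        logging_steps=10,                # illustrative
        seed=3407,                       # illustrative
        output_dir="outputs",            # illustrative
    ),
)
trainer.train()
```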
## Model Card
| Parameter | Value |
|-----------|--------|
| Base Model | Gemma-2-2B (4-bit quantized) |
| Fine-tuning Method | QLoRA |
| LoRA Rank | 16 |
| Training Steps | 100 (demo) |
| Learning Rate | 2e-4 |
| Batch Size | 8 (effective) |
| Max Sequence Length | 2048 |
| Dataset Size | 100k examples |
## Citation
```bibtex
@misc{gemma-2-2b-text-to-sql-qlora,
author = {rajaykumar12959},
title = {Gemma-2-2B Text-to-SQL QLoRA Fine-tuned Model},
year = {2024},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/rajaykumar12959/gemma-2-2b-text-to-sql-qlora}},
}
```
## Acknowledgments
- **Base Model:** Google's Gemma-2-2B via Unsloth optimization
- **Dataset:** Gretel AI's synthetic text-to-SQL dataset
- **Framework:** Unsloth for efficient fine-tuning and TRL for training
- **Method:** QLoRA for parameter-efficient training
## License
This model is licensed under Apache 2.0. See the LICENSE file for details.
---
*This model is intended for research and educational purposes. Please ensure compliance with your organization's data and AI usage policies when using in production environments.*