# Multimodal SLM Fine-Tuning: ChartQA with MatCha
This repository contains the code and documentation for fine-tuning a Small Language Model (SLM) on the ChartQA dataset using Parameter-Efficient Fine-Tuning (LoRA).
This project was developed to demonstrate a complete multimodal fine-tuning pipeline capable of running on a single NVIDIA T4 GPU (16GB VRAM).
## How to Run Inference
The following standalone code snippet demonstrates how to pull the fine-tuned LoRA adapters from Hugging Face, merge them with the base model, and run inference on a custom chart image.
```python
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
from PIL import Image

# 1. Define model IDs and device
base_model_id = "google/matcha-base"
adapter_id = "Sairam22/matcha-chartqa-lora-adapter"  # Replace if your repo name is different
device = "cuda" if torch.cuda.is_available() else "cpu"

# 2. Load base model and processor
print("Loading base model and processor...")
processor = AutoProcessor.from_pretrained(adapter_id)
base_model = AutoModelForImageTextToText.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# 3. Pull adapter from Hugging Face and merge
print("Pulling adapter and merging weights...")
model = PeftModel.from_pretrained(base_model, adapter_id)
model = model.merge_and_unload()
model = model.to(device)

# 4. Prepare image and prompt
image_path = "path_to_your_chart.png"  # Provide the path to a local chart image
image = Image.open(image_path).convert("RGB")
prompt = "Question: What is the highest value in the bar chart?\nAnswer:"

# 5. Process inputs and cast dtypes
inputs = processor(images=image, text=prompt, return_tensors="pt")
# Cast float32 tensors (e.g. the flattened image patches) to float16 to match
# the model weights; leave integer tensors (e.g. attention masks) untouched.
inputs = {
    k: v.to(device, dtype=torch.float16) if v.dtype == torch.float32 else v.to(device)
    for k, v in inputs.items()
}

# 6. Run inference
print("Generating prediction...")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
prediction = processor.decode(outputs[0], skip_special_tokens=True)
print(f"Prediction: {prediction}")
```
## Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0002
- train_batch_size: 2
- eval_batch_size: 8
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 8
- optimizer: AdamW (torch implementation) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- training_steps: 50
- mixed_precision_training: Native AMP
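
For reference, the list above maps onto Hugging Face `Trainer` keyword arguments roughly as follows. This is a sketch expressed as a plain dict: the argument names assume the standard `transformers.TrainingArguments` API, and the exact training script in this repo may differ.

```python
# Sketch: the hyperparameters above as TrainingArguments-style keyword
# arguments (names assume the HF Trainer API; not the verbatim script).
training_kwargs = {
    "learning_rate": 2e-4,
    "per_device_train_batch_size": 2,
    "per_device_eval_batch_size": 8,
    "seed": 42,
    "gradient_accumulation_steps": 4,
    "max_steps": 50,
    "lr_scheduler_type": "linear",
    "fp16": True,  # Native AMP mixed precision
}

# The total train batch size is the physical batch size times the number
# of gradient-accumulation steps: 2 * 4 = 8, matching the list above.
effective_batch = (
    training_kwargs["per_device_train_batch_size"]
    * training_kwargs["gradient_accumulation_steps"]
)
print(effective_batch)
```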
## Decision Log & T4 Optimizations

To ensure this pipeline runs efficiently on a single NVIDIA T4 GPU (16GB VRAM) and within strict time limits, several specific parameter choices were made:
- **Model selection (`google/matcha-base`):** Chosen because it is pre-trained specifically for chart visual-language tasks (it is based on Pix2Struct) and is highly lightweight (~256M parameters), fitting easily into T4 memory.
- **Precision (`fp16=True`):** Casting the base model to float16 cuts memory consumption in half, avoids `c10::Half` dtype-mismatch errors between inputs and weights, and leverages the T4's Tensor Cores to speed up training.
- **LoRA configuration (`r=8`, `alpha=16`):** A rank of 8 keeps trainable parameters under 5% of the total, which keeps optimizer states small and prevents Out-Of-Memory (OOM) errors during training. MatCha's attention projection layers (`query`, `value`) were explicitly targeted.
- **Batch sizing (`batch_size=2`, `gradient_accumulation=4`):** A physical batch size of 2 keeps the forward/backward pass within VRAM limits, while gradient accumulation simulates an effective batch size of 8 for stable loss convergence.
- **Adapter merging (`merge_and_unload()`):** Fulfills the assignment requirement and also speeds up inference by folding the adapter weights directly into the base model matrices, removing the adapter's routing overhead during generation.
- **Training subset:** Due to compute and time constraints, a subset of 100 samples was used for the final epoch to validate the end-to-end pipeline, adapter upload, and inference mechanism.
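
The LoRA settings in the decision log can be sketched as the following configuration. It is written as a plain dict for illustration: the key names mirror `peft.LoraConfig`, but `lora_dropout` and `bias` are assumed defaults not stated in this card, and the actual training script may differ.

```python
# Sketch of the LoRA setup described above (keys mirror peft.LoraConfig;
# lora_dropout and bias are assumptions, not values stated in this card).
lora_kwargs = {
    "r": 8,                                # low-rank dimension
    "lora_alpha": 16,                      # scaling numerator (alpha / r = 2.0)
    "target_modules": ["query", "value"],  # MatCha/Pix2Struct attention projections
    "lora_dropout": 0.05,                  # assumed default
    "bias": "none",                        # assumed default
}

# LoRA scales each adapter update by alpha / r; here that factor is 2.0.
scaling = lora_kwargs["lora_alpha"] / lora_kwargs["r"]
print(scaling)
```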
## Framework versions
- PEFT 0.18.1
- Transformers 5.3.0
- Pytorch 2.10.0+cu128
- Datasets 4.0.0
- Tokenizers 0.22.2