Multimodal SLM Fine-Tuning: ChartQA with MatCha

This repository contains the code and documentation for fine-tuning a Small Language Model (SLM) on the ChartQA dataset using Parameter-Efficient Fine-Tuning (LoRA).

This project was developed to demonstrate a complete multimodal fine-tuning pipeline capable of running on a single NVIDIA T4 GPU (16GB VRAM).

πŸš€ How to Run Inference

The following standalone code snippet demonstrates how to pull the fine-tuned LoRA adapters from Hugging Face, merge them with the base model, and run inference on a custom chart image.

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
from PIL import Image

# 1. Define Model IDs and Device
base_model_id = "google/matcha-base"
adapter_id = "Sairam22/matcha-chartqa-lora-adapter" # Replace if your repo name is different
device = "cuda" if torch.cuda.is_available() else "cpu"

# 2. Load Base Model and Processor
print("Loading base model and processor...")
processor = AutoProcessor.from_pretrained(adapter_id)
base_model = AutoModelForImageTextToText.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

# 3. Pull Adapter from Hugging Face and Merge
print("Pulling adapter and merging weights...")
model = PeftModel.from_pretrained(base_model, adapter_id)
model = model.merge_and_unload()
model = model.to(device)

# 4. Prepare Image and Prompt
image_path = "path_to_your_chart.png" # Provide the path to a local chart image
image = Image.open(image_path).convert("RGB")
prompt = "Question: What is the highest value in the bar chart?\nAnswer:"

# 5. Process Inputs and Cast Dtypes
inputs = processor(images=image, text=prompt, return_tensors="pt")
# Ensure float32 tensors (like images) are cast to float16 to match the model weights
inputs = {k: v.to(device, dtype=torch.float16) if v.dtype == torch.float32 else v.to(device) for k, v in inputs.items()}

# 6. Run Inference
print("Generating prediction...")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
    
prediction = processor.decode(outputs[0], skip_special_tokens=True)
print(f"Prediction: {prediction}")

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0002
  • train_batch_size: 2
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 8
  • optimizer: AdamW (adamw_torch) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
  • lr_scheduler_type: linear
  • training_steps: 50
  • mixed_precision_training: Native AMP
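
The sketch below translates these settings into a TrainingArguments configuration. It is a minimal reconstruction rather than the exact training script; the output_dir value is illustrative and not taken from the card.

from transformers import TrainingArguments

# Minimal sketch matching the hyperparameters listed above.
training_args = TrainingArguments(
    output_dir="matcha-chartqa-lora",   # illustrative path, not from the card
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,      # effective batch size: 2 x 4 = 8
    max_steps=50,
    lr_scheduler_type="linear",
    optim="adamw_torch",                # AdamW, betas=(0.9, 0.999), eps=1e-08
    seed=42,
    fp16=True,                          # Native AMP mixed precision
)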

🧠 Decision Log & T4 Optimizations

To ensure this pipeline runs efficiently on a single NVIDIA T4 GPU (16GB VRAM) and within strict time limits, several specific parameter choices were made:

Model Selection (google/matcha-base): Chosen because it is pre-trained specifically for chart visual language tasks (based on Pix2Struct) and is highly lightweight (~256M parameters), fitting easily into T4 memory.

Precision (fp16=True): Casting the base model to float16 cuts memory consumption in half, prevents datatype mismatch errors (c10::Half), and leverages the T4's Tensor Cores to speed up training.

LoRA Configuration (r=8, alpha=16): A rank of 8 keeps the trainable parameters below 5% of the total, so optimizer states stay small and Out-Of-Memory (OOM) errors are avoided during training. MatCha's attention projection layers (query, value) were explicitly targeted.
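
A minimal sketch of that configuration with the peft library is shown below; the target-module names follow Pix2Struct-style attention naming and are an assumption, not a quote from the training script.

import torch
from transformers import AutoModelForImageTextToText
from peft import LoraConfig, get_peft_model

# Base model loaded in fp16, as in the inference snippet above.
base_model = AutoModelForImageTextToText.from_pretrained(
    "google/matcha-base",
    torch_dtype=torch.float16,
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "value"],  # assumed MatCha/Pix2Struct attention projections
    bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()      # should report well under 5% trainable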

Batch Sizing (batch_size=2, gradient_accumulation=4): A physical batch size of 2 ensures VRAM limits aren't breached during the forward/backward pass, while gradient accumulation simulates an effective batch size of 8 for stable loss convergence.

Adapter Merging (merge_and_unload()): Fulfills the assignment requirement while also improving inference speed by folding the adapter weights directly into the base model matrices, removing the extra LoRA computation during generation.

Training Subset: Due to hardware compute constraints and time limitations, a 100-sample subset was used for the final epoch to validate the end-to-end pipeline, adapter upload, and inference mechanisms.
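
A minimal sketch of how such a subset can be drawn with the datasets library; the dataset id "HuggingFaceM4/ChartQA" and the split name are assumptions, since the card does not name the exact Hub dataset.

from datasets import load_dataset

# Assumed dataset id and split; take a small fixed subset for a quick run.
train_ds = load_dataset("HuggingFaceM4/ChartQA", split="train")
small_train = train_ds.shuffle(seed=42).select(range(100))  # 100-sample subset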

Framework versions

  • PEFT 0.18.1
  • Transformers 5.3.0
  • Pytorch 2.10.0+cu128
  • Datasets 4.0.0
  • Tokenizers 0.22.2