GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

This repository contains the GRIFFIN model presented in the paper GRIFFIN: Effective Token Alignment for Faster Speculative Decoding.

GRIFFIN is a novel framework designed to address token misalignment in speculative decoding, achieving significant speedups over vanilla decoding and other state-of-the-art speculative methods. It combines a token-alignable training strategy, which uses a loss-masking mechanism to exclude highly misaligned tokens from training, with a token-alignable draft model, which introduces input tokens to correct inconsistencies during drafting. Experiments demonstrate that GRIFFIN achieves an average acceptance-length improvement of over 8% and a speedup-ratio improvement exceeding 7% over existing state-of-the-art speculative decoding methods.

For more details and the official implementation, see the GitHub repository.
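
The loss-masking idea behind the token-alignable training strategy can be illustrated with a minimal sketch. The function name token_aligned_loss, the accept_threshold parameter, and the plausibility criterion used for masking are illustrative assumptions, not the paper's exact formulation; refer to the official repository for the actual training code.

import torch
import torch.nn.functional as F

def token_aligned_loss(draft_logits, target_ids, target_probs, accept_threshold=0.1):
    # draft_logits: (batch, seq, vocab) predictions of the draft model
    # target_ids:   (batch, seq) tokens produced by the target model
    # target_probs: (batch, seq, vocab) target-model probabilities
    # Per-token cross-entropy of the draft predictions against the target tokens.
    ce = F.cross_entropy(
        draft_logits.reshape(-1, draft_logits.size(-1)),
        target_ids.reshape(-1),
        reduction="none",
    ).view_as(target_ids).float()
    # Mask out highly misaligned tokens; here, tokens to which the target model
    # itself assigns very low probability (illustrative criterion only).
    keep = target_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) > accept_threshold
    # Average the loss over the tokens that were kept.
    return (ce * keep).sum() / keep.sum().clamp(min=1)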

Overview

GRIFFIN significantly accelerates inference in large language models (LLMs). Below are some performance benchmarks:

[Benchmark figure] Speedup ratios of GRIFFIN when temperature = 0.

[Benchmark figure] Speedup ratios of GRIFFIN when temperature = 1.

[Demo GIF] Acceleration demo of GRIFFIN for LLaMA3-8B on an RTX 4090 GPU.

Inference

You can use the provided eagenerate function for accelerated generation, similar to the generate method in Hugging Face Transformers.

import torch
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template

# Replace with your actual model paths
base_model_path = "meta-llama/Llama-3-8B-Instruct" # Example base model
EAGLE_model_path = "husj576/GRIFFIN-llama3-instruct-8B" # Example GRIFFIN draft model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1  # -1 lets the number of draft tokens be configured automatically
)
model.eval()

your_message="Hello"
conv = get_conversation_template("llama3") # Use the correct conversation template for your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])

print(output)

Note: For chat models like Vicuna, LLaMA2-Chat, and LLaMA3-Instruct, you must use the correct chat template to ensure proper model output and optimal GRIFFIN performance.
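
If you prefer not to depend on FastChat, the prompt can usually be built with the tokenizer's built-in chat template instead. This is an illustrative alternative that assumes the base model's tokenizer ships a chat template; the official example above uses FastChat.

# Sketch: build the prompt with the tokenizer's own chat template instead of FastChat.
messages = [{"role": "user", "content": your_message}]
input_ids = model.tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant header so generation starts in the reply
    return_tensors="pt",
).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)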
