GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
This repository contains the GRIFFIN model presented in the paper GRIFFIN: Effective Token Alignment for Faster Speculative Decoding.
GRIFFIN is a novel framework designed to address token misalignment in speculative decoding, achieving significant speedups over vanilla decoding and other state-of-the-art speculative methods. It combines a token-alignable training strategy, which uses a loss-masking mechanism to exclude highly misaligned tokens from the training loss, with a token-alignable draft model, which introduces input tokens to correct inconsistencies. Experiments show that GRIFFIN improves the average acceptance length by more than 8% and the speedup ratio by more than 7% over state-of-the-art speculative decoding methods.
For more details and the official implementation, see the GitHub repository.
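As a rough illustration of the loss-masking idea, the sketch below drops highly misaligned tokens from the draft model's training loss. The function name, the alignment_scores tensor, and the threshold are hypothetical placeholders for illustration, not the paper's actual criterion:

import torch
import torch.nn.functional as F

def masked_draft_loss(draft_logits, target_ids, alignment_scores, threshold=0.3):
    # Per-token cross-entropy between draft predictions and target tokens.
    per_token_loss = F.cross_entropy(
        draft_logits.view(-1, draft_logits.size(-1)),
        target_ids.view(-1),
        reduction="none",
    )
    # Keep only tokens whose (hypothetical) alignment score clears the
    # threshold; highly misaligned tokens contribute nothing to the gradient.
    mask = (alignment_scores.view(-1) >= threshold).float()
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)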
Overview
GRIFFIN significantly accelerates inference in large language models (LLMs). Below is a performance demonstration:
Acceleration demo of GRIFFIN for LLaMA3-8B on a single RTX 4090 GPU
Inference
You can use the provided eagenerate method for accelerated generation; its interface is similar to the generate method in Hugging Face Transformers.
import torch
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template

# Replace with your actual model paths
base_model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # example base model
EAGLE_model_path = "husj576/GRIFFIN-llama3-instruct-8B"  # example GRIFFIN draft model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1,  # -1: configure the number of draft tokens automatically
)
model.eval()

your_message = "Hello"
conv = get_conversation_template("llama3")  # use the conversation template matching your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenize the prompt and run accelerated generation.
input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
Note: For chat models like Vicuna, LLaMA2-Chat, and LLaMA3-Instruct, you must use the correct chat template to ensure proper model output and optimal GRIFFIN performance.
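For instance, with a Vicuna base model you would request the matching FastChat template instead (a minimal sketch; the message is a placeholder):

from fastchat.model import get_conversation_template

# Select the template matching the base model family (here: Vicuna).
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], "Hello")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()  # formatted with Vicuna's chat markup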