GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
This repository contains the draft model for GRIFFIN, a novel framework designed to accelerate inference in large language models (LLMs) by addressing token misalignment in speculative decoding. GRIFFIN incorporates a token-alignable training strategy and a token-alignable draft model to mitigate this issue, demonstrating significant speedup ratios over existing state-of-the-art methods.
For more details, refer to the paper: GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
The official code and further details can be found on the project's GitHub repository: https://github.com/hsj576/GRIFFIN
Overview
GRIFFIN is a novel framework that addresses token misalignment in speculative decoding: the draft model's proposed tokens can diverge from what the target model would actually generate, so verification compute is wasted on tokens that get rejected. This repository provides the implementation of GRIFFIN, including its token-alignable training strategy and token-alignable draft model.
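To make the alignment problem concrete, below is a minimal, greedy-verification sketch of the generic draft-and-verify loop that speculative decoding builds on. `speculative_step`, `draft_model`, and `target_model` are illustrative placeholders, not GRIFFIN's API; GRIFFIN's acceleration comes from reducing the misaligned tokens this loop discards.

```python
import torch

@torch.no_grad()
def speculative_step(target_model, draft_model, input_ids, k=4):
    """One greedy speculative-decoding step (illustrative sketch only)."""
    # 1. The small draft model proposes k tokens autoregressively.
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        next_tok = logits.argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_tok], dim=-1)

    # 2. The large target model scores all proposals in ONE forward pass.
    target_preds = target_model(draft_ids).logits.argmax(dim=-1)

    # 3. Accept the longest prefix where draft and target agree. Every
    #    misaligned token beyond that prefix is thrown away -- this is the
    #    waste that GRIFFIN's token alignment reduces. (Full speculative
    #    decoding also appends the target's own token at the first
    #    mismatch; omitted here for brevity.)
    n_prompt = input_ids.shape[1]
    accepted = input_ids
    for i in range(k):
        proposed = draft_ids[:, n_prompt + i]
        verified = target_preds[:, n_prompt + i - 1]
        if (proposed == verified).all():
            accepted = torch.cat([accepted, proposed.unsqueeze(-1)], dim=-1)
        else:
            break
    return accepted
```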
GRIFFIN is:
- 4.2x faster than vanilla decoding.
- 1.3x faster than EAGLE-2.
Acceleration demo of GRIFFIN for LLaMA3-8B on a single RTX 4090 GPU.
Sample Usage
You can use the provided eagenerate function for accelerated generation, much like the generate method from Hugging Face Transformers. Here is an example:
```python
import torch
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template

# Replace with the actual paths to your base model and GRIFFIN weights.
base_model_path = "meta-llama/Llama-2-13b-chat-hf"      # example base model
GRIFFIN_model_path = "husj576/GRIFFIN-llama2-chat-13B"  # example GRIFFIN draft model

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=GRIFFIN_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1,  # -1 configures the number of draft tokens automatically
)
model.eval()

your_message = "Hello"
conv = get_conversation_template("llama2")  # use the template matching your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
```
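To sanity-check the reported speedups on your own hardware, a rough timing comparison against vanilla decoding can look like the sketch below. Note that `model.base_model` (mirroring the EAGLE-style wrapper) is an assumption to verify against this repository's EaModel.

```python
import time

def timed(fn, *args, **kwargs):
    # Wall-clock timing with CUDA synchronization so GPU work is counted.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return out, time.perf_counter() - t0

# GRIFFIN-accelerated generation.
_, t_griffin = timed(model.eagenerate, input_ids, temperature=0.5, max_new_tokens=512)

# Vanilla autoregressive baseline on the same prompt.
# NOTE: `model.base_model` is an assumption; point this at the plain
# Hugging Face model if the attribute differs in this repository.
_, t_vanilla = timed(
    model.base_model.generate,
    input_ids,
    do_sample=True,
    temperature=0.5,
    max_new_tokens=512,
)
print(f"Observed speedup: {t_vanilla / t_griffin:.2f}x")
```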
Note: Vicuna, LLaMA2-Chat, and LLaMA3-Instruct are all chat models. You must use the matching chat template; an incorrect template causes abnormal model output and degrades GRIFFIN's performance. An alternative to FastChat templates is sketched below.
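If you prefer not to rely on FastChat template names, the tokenizer's built-in chat template is a sketch of an alternative; it works when the tokenizer ships a chat template (e.g. LLaMA3-Instruct), so verify it against your tokenizer version.

```python
# Sketch: build the prompt with the tokenizer's own chat template instead
# of FastChat; verify the rendered prompt against your base model's format.
messages = [{"role": "user", "content": your_message}]
input_ids = model.tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant header for generation
    return_tensors="pt",
).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
```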
Citation
If you find our work helpful or inspiring, please feel free to cite it.
```bibtex
@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding},
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018},
}
```