GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

This repository contains the draft model for GRIFFIN, a framework that accelerates inference in large language models (LLMs) by addressing token misalignment in speculative decoding. GRIFFIN combines a token-alignable training strategy with a token-alignable draft model to mitigate this misalignment, and achieves significant speedups over existing state-of-the-art methods.

For more details, refer to the paper: GRIFFIN: Effective Token Alignment for Faster Speculative Decoding (arXiv:2502.11018)

The official code and further details can be found on the project's GitHub repository: https://github.com/hsj576/GRIFFIN

Overview

GRIFFIN is a framework for mitigating token misalignment in speculative decoding (a sketch of the verification loop it accelerates follows the list below). This repository provides the implementation of GRIFFIN, including its token-alignable training strategy and token-alignable draft model.

  • GRIFFIN is:
    • 4.2x faster than vanilla decoding.
    • 1.3x faster than EAGLE-2.
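
For context, speculative decoding lets a small draft model propose several tokens that the large target model then verifies in a single forward pass; tokens where the two models disagree are rejected, and this disagreement is the misalignment GRIFFIN targets. Below is a minimal sketch of the greedy (temperature = 0) verification loop; draft_model and target_model are hypothetical callables returning logits and are not part of this repository:

import torch

def speculative_step(draft_model, target_model, input_ids, k=4):
    # 1. The draft model proposes k tokens autoregressively (cheap).
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids)  # (batch, seq_len, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_id], dim=-1)

    # 2. The target model scores every proposed token in ONE forward pass.
    target_pred = target_model(draft_ids).argmax(dim=-1)  # next-token prediction per position

    # 3. Accept the longest prefix of draft tokens the target agrees with;
    #    a misaligned draft token stops acceptance here, which is exactly
    #    the waste that GRIFFIN's token alignment reduces.
    n = input_ids.shape[1]
    accepted = 0
    for i in range(k):
        if draft_ids[0, n + i] == target_pred[0, n + i - 1]:
            accepted += 1
        else:
            break

    # Keep the accepted tokens plus the target's own next token
    # (a correction on mismatch, a bonus token if all k were accepted).
    correction = target_pred[:, n + accepted - 1 : n + accepted]
    return torch.cat([draft_ids[:, : n + accepted], correction], dim=-1)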

[Figure] Speedup ratios of GRIFFIN at temperature = 0.

[Figure] Speedup ratios of GRIFFIN at temperature = 1.

[Demo] Acceleration of GRIFFIN for LLaMA3-8B on a single RTX 4090 GPU.

Sample Usage

You can use the provided eagenerate function for accelerated generation, similar to the generate method from Hugging Face Transformers. Here is an example:

import torch
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template

# Replace with the actual paths to your base model and GRIFFIN weights
base_model_path = "meta-llama/Llama-2-13b-chat-hf"      # example base model
GRIFFIN_model_path = "husj576/GRIFFIN-llama2-chat-13B"  # example GRIFFIN draft model

# Load the base model together with the GRIFFIN draft model
model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=GRIFFIN_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1,  # -1 selects the draft-token budget automatically
)
model.eval()

# Build the prompt with the chat template matching the base model
your_message = "Hello"
conv = get_conversation_template("llama2")  # use the template appropriate for your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenize and run accelerated generation
input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
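
To check the speedup on your own hardware, you can time eagenerate against the base model's standard generation. A rough sketch follows; model.base_model as a handle to the underlying Hugging Face model is an assumption about this repository's internals, so adjust it to your setup:

import time

def timed_tokens_per_sec(generate_fn, input_ids):
    # Measure wall-clock decoding throughput of one generation call (batch size 1).
    torch.cuda.synchronize()
    start = time.time()
    out = generate_fn(input_ids)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    new_tokens = out.shape[1] - input_ids.shape[1]
    return new_tokens / elapsed

griffin_tps = timed_tokens_per_sec(
    lambda ids: model.eagenerate(ids, temperature=0.5, max_new_tokens=512), input_ids)
baseline_tps = timed_tokens_per_sec(
    lambda ids: model.base_model.generate(  # assumed attribute; see note above
        ids, do_sample=True, temperature=0.5, max_new_tokens=512), input_ids)
print(f"GRIFFIN: {griffin_tps:.1f} tok/s, vanilla: {baseline_tps:.1f} tok/s, "
      f"speedup: {griffin_tps / baseline_tps:.2f}x")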

Note: Vicuna, LLaMA2-Chat, and LLaMA3-Instruct are all chat models. You must use the chat template that matches your base model; the wrong template produces abnormal output and degrades GRIFFIN's performance.
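
For example, a small lookup can keep the template choice explicit. These fastchat template names follow common conventions but vary across fastchat versions, so treat them as assumptions and verify them against your installation:

# Hypothetical mapping from base model family to fastchat template name;
# confirm each name against your installed fastchat version.
TEMPLATE_BY_MODEL = {
    "vicuna": "vicuna_v1.1",
    "llama-2-chat": "llama-2",
    "llama-3-instruct": "llama-3",
}

conv = get_conversation_template(TEMPLATE_BY_MODEL["llama-2-chat"])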

Citation

If you find our work helpful or inspiring, please feel free to cite it.

@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding},
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018},
}