Skip to content

量化 KV 缓存

FP8 KV 缓存概览

高效利用内存对于处理大型语言模型至关重要。将 KV(键值)缓存量化为 FP8 格式可以显著减少其内存占用。这种优化使您能够在内存中存储更多 token,从而提高吞吐量并支持更长的上下文窗口。

注意: 当使用 Flash Attention 3 后端配合 FP8 KV 缓存时,注意力计算也会在量化(FP8)域中执行。在此配置下,除了键和值之外,查询也会被量化为 FP8。

支持的 FP8 KV 缓存量化方案

vLLM 支持两种主要的 FP8 KV 缓存量化策略:

  • 按张量量化:
    为每个 Q、K 和 V 张量单独应用一个缩放因子。(q/k/v_scale = [1]
  • 按注意力头量化:
    每个缩放因子对应一个注意力头:q_scale = [num_heads]k/v_scale = [num_kv_heads]

注意:
按注意力头量化目前仅适用于 Flash Attention 后端,且需要使用 llm-compressor 提供的校准流程。

缩放因子校准方法

您可以通过以下三种不同方法配置 vLLM 中量化缩放因子的计算方式:

  1. 无校准(默认缩放因子):
    所有量化缩放因子均设为 1.0
    配置方式:

    kv_cache_dtype="fp8"
    calculate_kv_scales=False
    

  2. 随机 token 校准(实时校准):
    缩放因子在预热期间根据单个批次的随机 token 自动估算,之后固定不变。
    配置方式:

    kv_cache_dtype="fp8"
    calculate_kv_scales=True
    

  3. [推荐] 使用数据集校准(通过 llm-compressor):
    使用精心挑选的校准数据集估算缩放因子,以获得最高精度。
    这需要安装 llm-compressor 库。
    请参见下方示例!

其他 kv_cache_dtype 选项

  • kv_cache_dtype="auto":使用模型的默认数据类型
  • kv_cache_dtype="fp8_e4m3":支持 CUDA 11.8+ 和 ROCm(AMD GPU)
  • kv_cache_dtype="fp8_e5m2":支持 CUDA 11.8+

示例

1. 无校准(kv_cache_dtype="fp8"calculate_kv_scales=False

所有量化缩放因子均设为 1.0。

from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    calculate_kv_scales=False,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)

2. 随机 token 校准(kv_cache_dtype="fp8"calculate_kv_scales=True

缩放因子在预热期间根据单个批次的 token 自动估算。

from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)

3. [推荐] 使用数据集校准(配合 llm-compressor

为了获得最高质量的量化效果,我们推荐使用 llm-compressor 对数据集进行校准。这可以实现诸如按注意力头量化等高级策略。

安装所需包

pip install llmcompressor

示例:将 Llama 注意力机制与 KV 缓存量化至 FP8

"""
使用 llm-compressor 一次性校准将 Llama 注意力机制 + KV 缓存量化至 FP8
(可选择 'tensor' 或 'attn_head' 策略)。
"""

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# -----------------------------
# 配置
# -----------------------------
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
STRATEGY = "tensor"       # 或 "attn_head"
NUM_CALIB_SAMPLES = 512   # 良好的起始值
MAX_SEQ_LEN = 2048

# -----------------------------
# 辅助函数
# -----------------------------
def process_and_tokenize(example, tokenizer: AutoTokenizer):
    """将聊天消息转换为 token。"""
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(
        text,
        padding=False,
        max_length=MAX_SEQ_LEN,
        truncation=True,
        add_special_tokens=False,
    )

def build_recipe(strategy: str) -> QuantizationModifier:
    fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
    return QuantizationModifier(
        config_groups={
            "attention": QuantizationScheme(
                targets=["LlamaAttention"],  # 量化查询:q_scale
                input_activations=fp8_args,
            )
        },
        kv_cache_scheme=fp8_args,           # 量化 KV 缓存:k/v_scale
    )

# -----------------------------
# 主函数
# -----------------------------
def main():
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
    ds = ds.shuffle(seed=42)
    ds = ds.map(
        lambda ex: process_and_tokenize(ex, tokenizer),
        remove_columns=ds.column_names,
    )

    recipe = build_recipe(STRATEGY)
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQ_LEN,
        num_calibration_samples=NUM_CALIB_SAMPLES,
    )

    save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
    model.save_pretrained(save_dir, save_compressed=True)
    tokenizer.save_pretrained(save_dir)

if __name__ == "__main__":
    main()

如需查看更详细和最新的示例,请参阅 llm-compressor 官方示例