CustomOp¶
CustomOp is an abstract class that dispatches an op's forward method to the appropriate backend. It also provides the mechanism through which vLLM and OOT (Out-Of-Tree) plugins register custom ops.
This document explains how CustomOp works in vLLM and how to implement a new CustomOp.
How CustomOp works in vLLM¶
CustomOp manages two dictionaries on its class, which store all custom ops (i.e., op classes, indexed by their registered names) for vLLM and for OOT plugins, respectively.
We can register an op class into the CustomOp system with @CustomOp.register("op_name"). Once registered, op_name and its corresponding class are added to the op_registry dictionary. We can also register an OOT op with @CustomOp.register_oot("op_name"); this mechanism is covered in detail later.
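The registry mechanics can be sketched in a few lines (a simplified illustration, not vLLM's actual implementation; the real CustomOp also manages dispatch and enable/disable state):

```python
# Simplified sketch of CustomOp-style registration: the decorator just
# records the class in a class-level dictionary keyed by the given name.
class CustomOp:
    op_registry: dict = {}      # in-tree vLLM ops
    op_registry_oot: dict = {}  # out-of-tree (OOT) plugin ops

    @classmethod
    def register(cls, name):
        def decorator(op_cls):
            cls.op_registry[name] = op_cls
            return op_cls
        return decorator

    @classmethod
    def register_oot(cls, name):
        def decorator(op_cls):
            cls.op_registry_oot[name] = op_cls
            return op_cls
        return decorator


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    pass
```

After the decorator runs, `CustomOp.op_registry["silu_and_mul"]` holds the `SiluAndMul` class, while the OOT registry stays empty until a plugin registers into it.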
When a CustomOp is called (i.e., its forward() method is invoked) and the op is enabled (e.g., via --compilation_config.custom_ops '["+op_name"]'), it automatically dispatches the forward method to the appropriate backend based on current_platform. Otherwise (i.e., the op is disabled), it simply calls forward_native(), the native PyTorch implementation.
- CPU platform: dispatches to forward_cpu().
- CUDA platform: dispatches to forward_cuda().
- ROCm platform: dispatches to forward_hip(); if forward_hip() is not implemented, it falls back to forward_cuda().
- XPU platform: dispatches to forward_xpu().
- TPU platform: dispatches to forward_tpu().
- OOT platform: dispatches to forward_oot(); this method is only called on OOT platforms.
- Default: dispatches to forward_native(), the final fallback for all platforms.
Note
Note that, due to class inheritance, the dispatch logic is not absolute; a derived class may override this behavior.
In addition, vLLM decides whether a CustomOp is enabled or disabled based on compilation_config.custom_ops. Specifically, if a CustomOp is not listed in compilation_config.custom_ops (i.e., the default configuration is used), the op is enabled when compilation_config.custom_ops contains all, and disabled when it contains none.
Note
Note that all and none cannot both appear in compilation_config.custom_ops.
By default, none is appended to compilation_config.custom_ops if compilation_config.backend == "inductor" and compilation_config.mode != CompilationMode.NONE; otherwise, all is appended. In other words, on platforms that use inductor as the default torch.compile backend, CustomOps are disabled when running in torch compile mode. In that case, Inductor generates (fused) Triton kernels for the disabled custom ops.
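The default described above amounts to roughly the following (an assumed sketch with simplified types, not vLLM's actual code; `compiling` stands in for `compilation_config.mode != CompilationMode.NONE`):

```python
# Append the default "all"/"none" entry to custom_ops depending on the
# compilation backend and mode.
def apply_default(custom_ops: list, backend: str, compiling: bool) -> list:
    if backend == "inductor" and compiling:
        custom_ops.append("none")  # let Inductor fuse the disabled ops
    else:
        custom_ops.append("all")
    return custom_ops
```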
Note
For multimodal models, vLLM force-enables some custom ops, such as MMEncoderAttention and ApplyRotaryEmb, so that the ViT part can use deeply device-optimized kernels for better performance. We can also pass enforce_enable=True to CustomOp's __init__() method to force-enable an op at the object level.
Note that this enforce_enable mechanism will be removed once a separate compilation_config is added for the multimodal part.
How to customize the configuration for CustomOp¶
vLLM also gives users fine-grained control over which custom ops to enable or disable: pass --compilation_config.custom_ops '["..."]' when starting the server.
For example:
- Use --compilation_config.custom_ops '["all"]' to enable all custom ops.
- Use --compilation_config.custom_ops '["none"]' to disable all custom ops.
- Use --compilation_config.custom_ops '["all,-op1"]' to enable all custom ops except op1 (a - prefix means "disable").
- Use --compilation_config.custom_ops '["none,+op1,+op2"]' to enable only op1 and op2 (a + prefix means "enable").
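Resolving these entries into a per-op decision can be sketched like this (illustrative; the entries are assumed to be already split on commas, and vLLM's actual parsing may differ in detail):

```python
# Decide whether `op_name` is enabled given a custom_ops list such as
# ["none", "+op1", "+op2"]: "all"/"none" set the default, and
# "+op"/"-op" entries override it for individual ops.
def is_op_enabled(op_name: str, custom_ops: list) -> bool:
    enabled = "all" in custom_ops
    for item in custom_ops:
        if item == "+" + op_name:
            enabled = True
        elif item == "-" + op_name:
            enabled = False
    return enabled
```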
CustomOp types supported in vLLM¶
1. Attention:
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
@PluggableLayer.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently OOT platforms can still use CustomOp.register_oot to
replace the MLA layer entirely, although we now use PluggableLayer to
register this layer.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention on prefill tokens and
multi-query attention on decode tokens separately.
3. Return the output tensor.
"""
2. Activation:
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
"""An activation function for GeluAndMulSparse.
This activation function is used in Gemma3n. It computes:
up_proj = self.up_proj(x)
gate_proj = self.gate_proj(x)
gate_proj = self._gaussian_topk(gate_proj) # sparsity
activations = self.act_fn(gate_proj) # gelu
down_proj = self.down_proj(activations * up_proj)
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed nickjbrowning/XIELU, we use the xIELU CUDA kernel.
Otherwise, we emit a single warning and use the xIELU Python implementation.
"""
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
# https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
3. MM-Conv:
@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
"""Conv layer with Conv2d."""
@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
"""Conv layer with Conv3d."""
4. Embedding:
@CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501
@CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
Output logits weight matrices used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
5. Linear:
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
reduce_results: If true, call all-reduce on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y = X_iA_i
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Take no effect for replicated linear layers.
"""
6. Logits Processor:
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
1. Gather logits from model hidden_states.
2. Scale logits if needed.
3. Apply logits processors (if any).
"""
7. Mamba:
@CustomOp.register("mamba_mixer")
class MambaMixer(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
@CustomOp.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
@CustomOp.register("plamo2_mamba_mixer")
class Plamo2MambaMixer(MambaBase, CustomOp):
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
8. MoE:
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj/ w2).
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
copy that naming convention here and handle any remapping in the
load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all_reduce on the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers.
"""
@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend(
moe_config=self.moe,
use_ep=self.moe.moe_parallel_config.use_ep,
use_dp=self.moe.moe_parallel_config.dp_size > 1,
)
# AITER only supports gated activations (silu/gelu), so disable it
# for non-gated MoE (is_act_and_mul=False)
self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
self.kernel: mk.FusedMoEModularKernel | None = None
self._is_monolithic = (
current_platform.is_cpu()
or current_platform.is_xpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
)
if self.is_monolithic:
self.apply_monolithic: Callable = self._select_monolithic()
def _select_monolithic(self) -> Callable:
"""Select the monolithic implementation based on platform."""
if current_platform.is_cpu():
return self.forward_monolithic_cpu
elif current_platform.is_xpu():
return self.forward_monolithic_xpu
else:
return self.forward_monolithic_cuda
def forward_native(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids)
@property
def is_monolithic(self) -> bool:
return self._is_monolithic
@property
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_up_dim,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.moe.has_bias:
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
if (
envs.VLLM_ROCM_MOE_PADDING
and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0
):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def _setup_kernel(
self,
layer: Module,
w13: torch.Tensor,
w2: torch.Tensor,
) -> None:
# Shuffle weights to runtime format.
w13, w2 = convert_to_unquantized_kernel_format(
self.unquantized_backend,
layer=layer,
w13_weight=w13,
w2_weight=w2,
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
# Setup Modular Kernel for TP Case
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel(
backend=self.unquantized_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
w13_weights_shuffled, w2_weights_shuffled = (
convert_moe_weights_to_flashinfer_trtllm_block_layout(
_cache_permute_indices,
layer.w13_weight.data,
layer.w2_weight.data,
)
)
layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
self.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
experts_start_id=ep_rank_start,
)
elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
dtype_w13 = layer.w13_weight.dtype
_, n_w13, k_w13 = layer.w13_weight.size()
dtype_w2 = layer.w2_weight.dtype
_, n_w2, k_w2 = layer.w2_weight.size()
if (
envs.VLLM_CPU_SGL_KERNEL
and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
):
packed_w13_weight = torch.ops._C.convert_weight_packed(
layer.w13_weight
)
assert packed_w13_weight.size() == layer.w13_weight.size()
layer.w13_weight.copy_(packed_w13_weight)
del packed_w13_weight
packed_w2_weight = torch.ops._C.convert_weight_packed(
layer.w2_weight
)
assert packed_w2_weight.size() == layer.w2_weight.size()
layer.w2_weight.copy_(packed_w2_weight)
self.cpu_fused_moe: Callable = cpu_fused_moe.SGLFusedMOE(layer)
else:
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else:
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike():
self._setup_kernel(
layer=layer,
w13=layer.w13_weight,
w2=layer.w2_weight,
)
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
return self.kernel(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
)
def forward_monolithic_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401
assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
return torch.ops.vllm.flashinfer_fused_moe_bf16(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
hidden_states=x,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routing_method_type=layer.routing_method_type,
)
def forward_monolithic_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.cpu_fused_moe(
layer,
x,
layer.use_grouped_topk,
layer.top_k,
router_logits,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
layer.global_num_experts,
layer.expert_map,
layer.custom_routing_function,
layer.scoring_func,
layer.routed_scaling_factor,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.activation,
)
def forward_monolithic_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.ipex_fusion(
x,
layer.use_grouped_topk,
layer.top_k,
router_logits,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
custom_routing_function=layer.custom_routing_function,
)
@CustomOp.register("transformers_fused_moe")
class TransformersFusedMoE(FusedMoE):
"""Custom FusedMoE for the Transformers modeling backend."""
9. Norm:
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
"""RMS Normalization with optional gating.
This is a native PyTorch implementation that supports:
- Standard RMS normalization
- Group RMS normalization
- Optional gating with SiLU activation
"""
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
Two differences from the above RMSNorm:
1. x * (1 + w) instead of x * w.
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""
10. Quantization:
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
This CustomOp supports both static and dynamic quantization.
"""
11. Rope:
@CustomOp.register("rotary_embedding")
class RotaryEmbeddingBase(CustomOp):
"""Original rotary positional embedding."""
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
Guide to implementing a new CustomOp¶
Implementing a new CustomOp in vLLM¶
This section is a tutorial on how to implement a new CustomOp in vLLM.
Steps:
- Implement a new op class that inherits from the CustomOp base class.
- Add the @CustomOp.register("op_name") decorator to the op class to register it into the CustomOp system.
- Implement the different forward_xxx() methods as needed.
Take MMEncoderAttention as an example:
Code
```python
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        super().__init__()
        # Init...

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # Call TORCH_SDPA implementation...

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # Call FA or TORCH_SDPA implementation...

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # Call TORCH_SDPA implementation...

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # Call FA implementation...

    def forward_tpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # Call PALLAS implementation...
```
Registering a new CustomOp in an OOT device plugin¶
Today, thanks to vLLM's hardware plugin mechanism, a growing number of OOT device plugins allow vLLM to run seamlessly on different hardware. You can find more details about this mechanism in Introducing vLLM Hardware Plugin, Best Practice from Ascend NPU.
- Official device plugins: vllm-ascend (for Huawei Ascend NPU), vllm-spyre (for Spyre), vllm-gaudi (for Intel Gaudi), vllm-neuron (for AWS Neuron), vllm-meta (for Apple Silicon), etc.
- Unofficial device plugins: vllm-metax (for MetaX GPU), vllm-kunlun (for Baidu Kunlun XPU), etc.
In this context, CustomOp lets hardware vendors seamlessly replace vLLM's ops at runtime with their deeply device-optimized kernels, by registering an OOT CustomOp and implementing its forward_oot() method.
This section shows how to register an OOT CustomOp for a device plugin.
Take MMEncoderAttention as an example:
- Implement a CustomMMEncoderAttention class that inherits from MMEncoderAttention and implements its forward_oot() method.
- Register your CustomMMEncoderAttention into vLLM to replace MMEncoderAttention.
Code
```python
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.custom_op import CustomOp


@CustomOp.register_oot("MMEncoderAttention")
class CustomMMEncoderAttention(MMEncoderAttention):
    def __init__(...):
        super().__init__(...)

    def forward_oot(...):
        # Call optimized device-specific kernels.
        ...
```
In this case, a new entry {"MMEncoderAttention": CustomMMEncoderAttention} is added to op_registry_oot. When an MMEncoderAttention op object is initialized, if its class name (i.e., MMEncoderAttention) is among the keys of op_registry_oot, vLLM replaces the class with the registered one (i.e., CustomMMEncoderAttention) and instantiates that instead.
Afterwards, when this MMEncoderAttention op is called, your forward_oot() is invoked as long as the op is enabled. You therefore get the expected performance on your hardware without modifying vLLM directly.
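The substitution step can be sketched with a minimal stand-in (illustrative only, not vLLM's actual code): at construction time the registry is consulted by class name, and the OOT subclass is instantiated in place of the original.

```python
# Minimal sketch of OOT class substitution keyed on the class name.
op_registry_oot = {}

class CustomOp:
    def __new__(cls, *args, **kwargs):
        oot_cls = op_registry_oot.get(cls.__name__)
        if oot_cls is not None and oot_cls is not cls:
            # Build the registered OOT subclass instead of `cls`.
            return super().__new__(oot_cls)
        return super().__new__(cls)

class MMEncoderAttention(CustomOp):
    pass

class CustomMMEncoderAttention(MMEncoderAttention):
    pass

# Registering the OOT subclass makes plain MMEncoderAttention()
# construct the OOT implementation instead.
op_registry_oot["MMEncoderAttention"] = CustomMMEncoderAttention
```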
You can also register all of your CustomOps in one place for easier management.