Skip to content

vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router

FusedTopKBiasRouter

Bases: BaseRouter

Router using fused top-k with e_score_correction_bias.

Source code in vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
class FusedTopKBiasRouter(BaseRouter):
    """Router that picks experts via ``fused_topk_bias``.

    The router stores an expert score correction bias together with the
    scoring/renormalization settings and forwards them to
    ``fused_topk_bias`` on every routing call. The resulting weights are
    optionally multiplied by ``routed_scaling_factor``.
    """

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        e_score_correction_bias: torch.Tensor,
        scoring_func: str,
        renormalize: bool = True,
        routed_scaling_factor: float = 1.0,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        """Store the routing configuration.

        ``top_k``, ``global_num_experts``, ``eplb_state``, ``enable_eplb``
        and ``indices_type_getter`` are forwarded unchanged to
        ``BaseRouter.__init__``; the remaining arguments are kept on the
        instance for use by ``_compute_routing``.
        """
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        # Settings passed straight through to ``fused_topk_bias``.
        self.scoring_func = scoring_func
        self.renormalize = renormalize
        self.e_score_correction_bias = e_score_correction_bias
        # Multiplier applied to the selected weights after routing.
        self.routed_scaling_factor = routed_scaling_factor

    @property
    def routing_method_type(self) -> RoutingMethodType:
        """Return ``RenormalizeNaive`` when renormalizing, else ``Renormalize``."""
        if self.renormalize:
            return RoutingMethodType.RenormalizeNaive
        return RoutingMethodType.Renormalize

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run ``fused_topk_bias`` and apply the routed scaling factor.

        Returns:
            The ``(topk_weights, topk_ids)`` pair produced by
            ``fused_topk_bias``; the weights are scaled in place when
            ``routed_scaling_factor`` differs from 1.0.
        """
        weights, expert_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=self.e_score_correction_bias.data,
            topk=self.top_k,
            renormalize=self.renormalize,
            scoring_func=self.scoring_func,
            indices_type=indices_type,
        )

        scale = self.routed_scaling_factor
        if scale != 1.0:
            # In-place multiply: no extra tensor is created for the result.
            weights *= scale

        return weights, expert_ids

e_score_correction_bias instance-attribute

e_score_correction_bias = e_score_correction_bias

renormalize instance-attribute

renormalize = renormalize

routed_scaling_factor instance-attribute

routed_scaling_factor = routed_scaling_factor

routing_method_type property

routing_method_type: RoutingMethodType

scoring_func instance-attribute

scoring_func = scoring_func

__init__

__init__(
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    e_score_correction_bias: Tensor,
    scoring_func: str,
    renormalize: bool = True,
    routed_scaling_factor: float = 1.0,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], dtype | None]
    | None = None,
)
Source code in vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
def __init__(
    self,
    top_k: int,
    global_num_experts: int,
    eplb_state: EplbLayerState,
    e_score_correction_bias: torch.Tensor,
    scoring_func: str,
    renormalize: bool = True,
    routed_scaling_factor: float = 1.0,
    enable_eplb: bool = False,
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
    """Store the routing configuration.

    ``top_k``, ``global_num_experts``, ``eplb_state``, ``enable_eplb`` and
    ``indices_type_getter`` are forwarded unchanged to the parent
    initializer; the remaining arguments are kept as instance attributes.
    """
    super().__init__(
        top_k=top_k,
        global_num_experts=global_num_experts,
        eplb_state=eplb_state,
        enable_eplb=enable_eplb,
        indices_type_getter=indices_type_getter,
    )
    # Scoring / renormalization settings used when routing.
    self.scoring_func = scoring_func
    self.renormalize = renormalize
    self.e_score_correction_bias = e_score_correction_bias
    # Multiplier applied to the selected weights after routing.
    self.routed_scaling_factor = routed_scaling_factor

_compute_routing

_compute_routing(
    hidden_states: Tensor,
    router_logits: Tensor,
    indices_type: dtype | None,
) -> tuple[Tensor, Tensor]

Compute routing using fused top-k with bias.

Source code in vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
def _compute_routing(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``fused_topk_bias`` and apply the routed scaling factor.

    Returns:
        The ``(topk_weights, topk_ids)`` pair produced by
        ``fused_topk_bias``; the weights are scaled in place when
        ``routed_scaling_factor`` differs from 1.0.
    """
    weights, expert_ids = fused_topk_bias(
        hidden_states=hidden_states,
        gating_output=router_logits,
        e_score_correction_bias=self.e_score_correction_bias.data,
        topk=self.top_k,
        renormalize=self.renormalize,
        scoring_func=self.scoring_func,
        indices_type=indices_type,
    )

    scale = self.routed_scaling_factor
    if scale != 1.0:
        # In-place multiply: no extra tensor is created for the result.
        weights *= scale

    return weights, expert_ids

fused_topk_bias

fused_topk_bias(
    hidden_states: Tensor,
    gating_output: Tensor,
    e_score_correction_bias: Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
    indices_type: dtype | None = None,
)
Source code in vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
    indices_type: torch.dtype | None = None,
):
    """Select ``topk`` experts per token using bias-corrected scores.

    Two code paths exist:

    * When ``rocm_aiter_ops.is_fused_moe_enabled()`` is false, output
      buffers are allocated here and filled by ``vllm_topk_softmax`` or
      ``vllm_topk_sigmoid``.
    * Otherwise the routing is computed with plain torch ops: the bias is
      added to the scores only to choose the experts, and the returned
      weights are gathered from the unbiased scores.

    Args:
        hidden_states: Token activations; only the size of dim 0 and the
            device are read here (the kernel path unpacks it as 2-D).
        gating_output: Router logits whose last dimension indexes experts.
        e_score_correction_bias: Bias added to the scores for selection.
        topk: Number of experts to select per token.
        renormalize: Whether the selected weights are renormalized.
        scoring_func: Either ``"softmax"`` or ``"sigmoid"``.
        indices_type: Dtype of the returned expert ids; ``torch.int32``
            when ``None``.

    Returns:
        A ``(topk_weights, topk_ids)`` tuple. In the torch path the
        weights are float32 and the ids use ``indices_type`` (or int32),
        both shaped ``(num_tokens, topk)`` after flattening leading dims.

    Raises:
        ValueError: If ``scoring_func`` is not supported.
        AssertionError: On the kernel path, if ``hidden_states`` and
            ``gating_output`` disagree on the number of tokens.
    """
    if not rocm_aiter_ops.is_fused_moe_enabled():
        assert hidden_states.size(0) == gating_output.size(0), (
            "Number of tokens mismatch"
        )

        M, _ = hidden_states.size()

        # Resolve the kernel wrapper first so an unsupported scoring
        # function fails before any output buffer is allocated.
        if scoring_func == "softmax":
            topk_fn = vllm_topk_softmax
        elif scoring_func == "sigmoid":
            topk_fn = vllm_topk_sigmoid
        else:
            raise ValueError(f"Unsupported scoring function: {scoring_func}")

        device = hidden_states.device
        topk_weights = torch.empty(M, topk, dtype=torch.float32, device=device)
        topk_ids = torch.empty(
            M,
            topk,
            dtype=torch.int32 if indices_type is None else indices_type,
            device=device,
        )
        token_expert_indices = torch.empty(
            M, topk, dtype=torch.int32, device=device
        )

        topk_weights, topk_ids = topk_fn(
            topk_weights,
            topk_ids,
            token_expert_indices,
            gating_output,
            renormalize,
            e_score_correction_bias,
        )
        return topk_weights, topk_ids

    n_routed_experts = gating_output.shape[-1]
    if scoring_func == "softmax":
        scores = gating_output.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Flatten once and use the same 2-D view for both expert selection and
    # the gather below, so the 2-D indices always match the gathered tensor
    # (previously the gather used the un-reshaped scores).
    scores = scores.view(-1, n_routed_experts)
    scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
    # Weights come from the unbiased scores; the bias only affects selection.
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_indices.to(
        torch.int32 if indices_type is None else indices_type
    )

vllm_topk_sigmoid

vllm_topk_sigmoid(
    topk_weights: Tensor,
    topk_indices: Tensor,
    token_expert_indices: Tensor,
    gating_output: Tensor,
    renormalize: bool = False,
    e_score_correction_bias: Tensor | None = None,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_sigmoid`` and return the weight/index tensors.

    All arguments are forwarded positionally, in order, to
    ``ops.topk_sigmoid``. The same ``topk_weights`` and ``topk_indices``
    objects that were passed in are returned (presumably filled in place
    by the op -- confirm against the kernel).
    """
    outputs = (topk_weights, topk_indices)
    ops.topk_sigmoid(
        *outputs,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    return outputs

vllm_topk_softmax

vllm_topk_softmax(
    topk_weights: Tensor,
    topk_indices: Tensor,
    token_expert_indices: Tensor,
    gating_output: Tensor,
    renormalize: bool = False,
    e_score_correction_bias: Tensor | None = None,
) -> tuple[Tensor, ...]
Source code in vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Call ``ops.topk_softmax`` and return the weight/index tensors.

    All arguments are forwarded positionally, in order, to
    ``ops.topk_softmax``. The same ``topk_weights`` and ``topk_indices``
    objects that were passed in are returned (presumably filled in place
    by the op -- confirm against the kernel).
    """
    outputs = (topk_weights, topk_indices)
    ops.topk_softmax(
        *outputs,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
    return outputs