vllm.model_executor.layers.hybrid_attn_layer

HybridAttentionLayer

Bases: Attention, AttentionLayerBase

Attention layer that fuses sliding-window KV with an SSM history branch.

This layer is a thin wrapper around the standard Attention module that:

  • Forces the use of HybridAttentionBackend for its attention backend.
  • Owns a HybridSSMAdapter instance representing the history branch.
  • Reuses Attention.get_kv_cache_spec so it continues to expose either a SlidingWindowSpec or FullAttentionSpec for its KV cache.
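A minimal construction sketch, assuming an active vLLM config context (the constructor calls get_current_vllm_config) and hypothetical head/SSM sizes chosen purely for illustration:

from vllm.model_executor.layers.hybrid_attn_layer import HybridAttentionLayer

# Hypothetical sizes for illustration; real values come from the model config.
layer = HybridAttentionLayer(
    num_heads=32,
    head_size=128,
    scale=128 ** -0.5,            # typical 1/sqrt(head_size) scaling
    num_kv_heads=8,
    ssm_state_size=16,
    ssm_conv_kernel_size=4,
    ssm_intermediate_size=4096,
    prefix="model.layers.0.self_attn",
)
# The SSM history branch is attached as a submodule whose KV-spec discovery
# prefix is "model.layers.0.self_attn.ssm".
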
Source code in vllm/model_executor/layers/hybrid_attn_layer.py
class HybridAttentionLayer(Attention, AttentionLayerBase):
    """Attention layer that fuses sliding-window KV with an SSM history branch.

    This layer is a thin wrapper around the standard ``Attention`` module that:

    - Forces the use of ``HybridAttentionBackend`` for its attention backend.
    - Owns a ``HybridSSMAdapter`` instance representing the history branch.
    - Reuses ``Attention.get_kv_cache_spec`` so it continues to expose either
      a ``SlidingWindowSpec`` or ``FullAttentionSpec`` for its KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        *,
        ssm_state_size: int,
        ssm_conv_kernel_size: int,
        ssm_intermediate_size: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # First, initialize the underlying Attention module so that nn.Module
        # internals (such as _modules) are ready before we attach submodules
        # like the SSM adapter.
        #
        # We force the attention backend to be HybridAttentionBackend while
        # reusing all of Attention's internal wiring (KV cache quantization,
        # static forward context registration, etc.).
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            cache_config=cache_config,
            prefix=prefix,
            attn_backend=HybridAttentionBackend,
            **extra_impl_args,
        )

        # Initialize the history branch adapter so it can participate in the
        # v1 KV cache spec discovery with a distinct layer prefix.
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        self.ssm_adapter = HybridSSMAdapter(
            hidden_size=num_heads * head_size,
            ssm_state_size=ssm_state_size,
            conv_kernel_size=ssm_conv_kernel_size,
            intermediate_size=ssm_intermediate_size,
            model_config=model_config,
            cache_config=cache_config or vllm_config.cache_config,
            prefix=f"{prefix}.ssm",
        )

    def get_attn_backend(self) -> type[AttentionBackend]:
        # Satisfy ``AttentionLayerBase`` by returning the concrete backend
        # class for this layer.
        return HybridAttentionBackend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Delegate KV spec computation to the parent Attention implementation.
        return super().get_kv_cache_spec(vllm_config)

ssm_adapter instance-attribute

ssm_adapter = HybridSSMAdapter(
    hidden_size=num_heads * head_size,
    ssm_state_size=ssm_state_size,
    conv_kernel_size=ssm_conv_kernel_size,
    intermediate_size=ssm_intermediate_size,
    model_config=model_config,
    cache_config=cache_config or vllm_config.cache_config,
    prefix=f"{prefix}.ssm",
)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    *,
    ssm_state_size: int,
    ssm_conv_kernel_size: int,
    ssm_intermediate_size: int,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
    **extra_impl_args,
) -> None
Source code in vllm/model_executor/layers/hybrid_attn_layer.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int | None = None,
    *,
    ssm_state_size: int,
    ssm_conv_kernel_size: int,
    ssm_intermediate_size: int,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
    **extra_impl_args,
) -> None:
    # First, initialize the underlying Attention module so that nn.Module
    # internals (such as _modules) are ready before we attach submodules
    # like the SSM adapter.
    #
    # We force the attention backend to be HybridAttentionBackend while
    # reusing all of Attention's internal wiring (KV cache quantization,
    # static forward context registration, etc.).
    super().__init__(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        cache_config=cache_config,
        prefix=prefix,
        attn_backend=HybridAttentionBackend,
        **extra_impl_args,
    )

    # Initialize the history branch adapter so it can participate in the
    # v1 KV cache spec discovery with a distinct layer prefix.
    vllm_config = get_current_vllm_config()
    model_config = vllm_config.model_config
    self.ssm_adapter = HybridSSMAdapter(
        hidden_size=num_heads * head_size,
        ssm_state_size=ssm_state_size,
        conv_kernel_size=ssm_conv_kernel_size,
        intermediate_size=ssm_intermediate_size,
        model_config=model_config,
        cache_config=cache_config or vllm_config.cache_config,
        prefix=f"{prefix}.ssm",
    )
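
As a standalone illustration of the ordering comment above (plain PyTorch, not vLLM-specific): assigning a submodule before nn.Module.__init__ has run raises an error, which is why super().__init__ is called before self.ssm_adapter is attached.

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # Raises "AttributeError: cannot assign module before
        # Module.__init__() call" because _modules does not exist yet.
        self.child = nn.Linear(4, 4)
        super().__init__()

class Works(nn.Module):
    def __init__(self):
        super().__init__()            # sets up _modules first
        self.child = nn.Linear(4, 4)  # now registered as a submodule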

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/hybrid_attn_layer.py
def get_attn_backend(self) -> type[AttentionBackend]:
    # Satisfy ``AttentionLayerBase`` by returning the concrete backend
    # class for this layer.
    return HybridAttentionBackend

get_kv_cache_spec

get_kv_cache_spec(vllm_config: VllmConfig) -> KVCacheSpec
Source code in vllm/model_executor/layers/hybrid_attn_layer.py
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # Delegate KV spec computation to the parent Attention implementation.
    return super().get_kv_cache_spec(vllm_config)
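
A hedged usage sketch (assumes a constructed layer and a vllm_config are available; the spec classes and their import path are assumed from vLLM's v1 KV cache interface, per the class docstring above):

from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

spec = layer.get_kv_cache_spec(vllm_config)
# The parent Attention implementation yields either a sliding-window or a
# full-attention spec for this layer's KV cache.
assert isinstance(spec, (SlidingWindowSpec, FullAttentionSpec))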