Mihir Test Hybrid Architecture

Source https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/mihir_test_hybrid_architecture.py.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Comprehensive test for the Hybrid Attention Architecture (Attention + SSM).

This script performs two levels of testing:
1.  **Synthetic Verification**: Validates the "plumbing" of the Hybrid backend and
    SSM adapter using the deterministic 'prefix-sum' mode described in the
    research paper. This runs without a real model checkpoint.
2.  **Integration Test**: Runs the full vLLM engine with a Hybrid model (if provided)
    to ensure end-to-end stability.

Usage:
    python examples/offline_inference/mihir_test_hybrid_architecture.py

    # To run the integration test with a specific model:
    export HYBRID_MODEL="your-org/step3-text-hybrid"
    python examples/offline_inference/mihir_test_hybrid_architecture.py
"""

import os
import sys
import traceback

import torch

# Ensure we can import from vllm source
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))

from vllm import LLM, SamplingParams
from vllm.v1.attention.backends.hybrid_attn import HybridAttentionImpl
from vllm.model_executor.layers.hybrid_ssm_adapter import HybridSSMAdapter
from vllm.config import ModelConfig, CacheConfig

# Mocking utilities to avoid needing a full model config for the synthetic test
from unittest.mock import MagicMock

def create_mock_configs():
    """Create minimal mock configs to satisfy HybridSSMAdapter initialization."""
    model_config = MagicMock(spec=ModelConfig)
    model_config.dtype = torch.float32
    model_config.get_num_attention_heads.return_value = 4
    model_config.get_head_size.return_value = 16

    cache_config = MagicMock(spec=CacheConfig)
    cache_config.mamba_cache_dtype = "auto"
    cache_config.mamba_ssm_cache_dtype = "auto"
    cache_config.sliding_window = None

    return model_config, cache_config

def run_synthetic_verification():
    """
    Executes the 'Prefix-Sum Verification' described in the research paper (Section 4.1).

    This test:
    1.  Instantiates the real `HybridSSMAdapter` and `HybridAttentionImpl`.
    2.  Sets `VLLM_HYBRID_SSM_MODE='prefix_sum'`.
    3.  Feeds a sequence of tokens [x1, x2, x3...].
    4.  Mocks the attention branch to return 0.
    5.  Verifies that the output is [x1, x1+x2, x1+x2+x3...].

    This proves that the SSM history branch is correctly wired and can carry 
    state across tokens, independent of the standard attention mechanism.
    """
    print("\n" + "=" * 80)
    print("Running Synthetic Prefix-Sum Verification (Paper Section 4.1)")
    print("=" * 80)

    # Force the environment variable for the adapter
    os.environ["VLLM_HYBRID_SSM_MODE"] = "prefix_sum"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    # Setup parameters
    num_heads = 2
    head_size = 8
    num_kv_heads = 2
    intermediate_size = 32
    ssm_state_size = 16
    num_tokens = 5

    # Create mock configs
    model_config, cache_config = create_mock_configs()

    # Instantiate the REAL adapter
    print("Initializing HybridSSMAdapter...")
    adapter = HybridSSMAdapter(
        hidden_size=num_heads * head_size,
        ssm_state_size=ssm_state_size,
        conv_kernel_size=3,
        intermediate_size=intermediate_size,
        model_config=model_config,
        cache_config=cache_config,
        prefix="test_layer.ssm"
    ).to(device)

    # Instantiate the backend
    print("Initializing HybridAttentionImpl...")
    impl = HybridAttentionImpl(
        num_heads=num_heads,
        head_size=head_size,
        scale=1.0,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Mock the Triton attention implementation to return zeros
    # This isolates the contribution of the SSM branch.
    class _ZeroTritonImpl:
        def forward(self, *args, output, **kwargs):
            output.zero_()

    impl._triton_impl = _ZeroTritonImpl()

    # Create inputs: Simple 1, 2, 3, 4, 5 sequence broadcast across heads
    base = torch.arange(1, num_tokens + 1, dtype=torch.float32, device=device)
    # Shape: (num_tokens, num_heads, head_size)
    query = base.view(num_tokens, 1, 1).expand(num_tokens, num_heads, head_size)
    key = torch.zeros_like(query)
    value = torch.zeros_like(query)
    kv_cache = torch.empty(0, device=device)
    output = torch.empty_like(query)

    # Create minimal metadata
    class _Meta:
        def __init__(self, n):
            self.num_prefill_tokens = n
            self.num_actual_tokens = n
            # The adapter unwraps HybridAttentionMetadata, reading either
            # attn_metadata.mamba_metadata or the attributes directly. This
            # minimal object exposes what the 'prefix_sum' path needs (mainly
            # num_prefill_tokens) and points mamba_metadata back at itself.
            self.mamba_metadata = self
            self.triton_metadata = None

    attn_metadata = _Meta(num_tokens)

    # Create a dummy layer object as expected by HybridAttentionImpl
    class _ToyLayer:
        def __init__(self, ssm_adapter):
            self.ssm_adapter = ssm_adapter

    layer = _ToyLayer(adapter)

    print("Executing forward pass...")
    # The backend calls: triton_impl.forward (zeros) + ssm_adapter.forward (prefix sum)
    out = impl.forward(
        layer,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
    )

    # Verify results
    expected_prefix = torch.cumsum(base, dim=0).view(num_tokens, 1, 1).expand_as(out)

    try:
        assert torch.allclose(out, expected_prefix)
        print("\nSUCCESS: Output matches expected prefix sums!")
        print(f"Input (first head, first dim): {base.tolist()}")
        print(f"Output (first head, first dim): {out[:, 0, 0].tolist()}")
        print("State persistence verification passed.")
    except AssertionError:
        print("\nFAILURE: Output does not match expected prefix sums.")
        print(f"Expected: {expected_prefix[:, 0, 0].tolist()}")
        print(f"Got:      {out[:, 0, 0].tolist()}")
        raise
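
    # Clear the override so 'prefix_sum' mode does not leak into the
    # integration test below (assumes the adapter falls back to its normal
    # path when the variable is unset).
    os.environ.pop("VLLM_HYBRID_SSM_MODE", None)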

def run_integration_test(model_name: str):
    """
    Runs the full vLLM engine with the specified hybrid model.
    """
    print("\n" + "=" * 80)
    print(f"Running Integration Test with Model: {model_name}")
    print("=" * 80)

    prompts = [
        "The future of AI is",
        "Hybrid architectures allow",
    ]
    sampling_params = SamplingParams(temperature=0.0)

    try:
        llm = LLM(model=model_name)
        outputs = llm.generate(prompts, sampling_params)

        print("\nGenerated Outputs:\n" + "-" * 60)
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(f"Prompt:    {prompt!r}")
            print(f"Output:    {generated_text!r}")
            print("-" * 60)
        print("\nIntegration test completed successfully (no crashes).")

    except Exception as e:
        print(f"\nIntegration test failed with error: {e}")
        print("Note: Ensure the model config has 'model_type': 'step3_text' and 'use_hybrid_step3_attn': true")

if __name__ == "__main__":
    # 1. Run the synthetic verification (no model required)
    try:
        run_synthetic_verification()
    except Exception:
        traceback.print_exc()
        sys.exit(1)

    # 2. Run integration test if a model is provided
    model_name = os.getenv("HYBRID_MODEL")
    if model_name:
        run_integration_test(model_name)
    else:
        print("\n[INFO] Skipping integration test. Set HYBRID_MODEL env var to run end-to-end.")
        print("Example: export HYBRID_MODEL='your-org/step3-text-hybrid'")