
feat: add Gemma4 31B support (ProportionalRotaryEmbedding, rmsnorm dtype)#2705

Open
ClementLinCF wants to merge 3 commits into main from gemma4-dev

Conversation

@ClementLinCF
Collaborator

Motivation

Enable Gemma 4 31B-it inference on AMD Instinct MI355X (gfx950) via ATOM. Gemma 4 introduces two aiter-level requirements not covered by existing operators: a proportional RoPE variant where frequency exponents use head_size as denominator with partial rotation, and a weight/input dtype mismatch in RMSNorm that causes silent data corruption on HIP/CK kernels.

Technical Details

New `ProportionalRotaryEmbedding` class for Gemma 4's `rope_type: "proportional"`:

  • Frequency exponents use `head_size` (not `rotary_dim`) as the denominator: `inv_freq = 1 / (base ^ (arange(0, 2*rope_angles, 2) / head_size))`
  • Non-rotated dimensions are zero-padded so `cos=1`, `sin=0` (identity rotation)
  • Registered in the `get_rope()` factory under `scaling_type == "proportional"`

Gemma 4 uses `partial_rotary_factor=0.5` for sliding attention (`head_dim=128`, `rotary_dim=64`) and `partial_rotary_factor=0.25` for global attention (`head_dim=256`, `rotary_dim=64`). The standard `RotaryEmbedding` would compute the wrong frequencies because it uses `rotary_dim` as the denominator.
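The frequency layout described above can be sketched in plain PyTorch (this is an illustration of the scheme, not the PR's code, assuming the sliding-attention shapes `head_size=128`, `rotary_dim=64`):

```python
import torch

# Sketch of the proportional RoPE frequency layout, assuming
# head_size=128, rotary_dim=64 (the sliding-attention case).
head_size, rotary_dim, base = 128, 64, 10000.0
rope_angles = rotary_dim // 2                 # 32 rotated angle pairs
nope_angles = head_size // 2 - rope_angles    # 32 identity pairs

# Exponent denominator is head_size, not rotary_dim
inv_freq = 1.0 / (base ** (torch.arange(0, 2 * rope_angles, 2).float() / head_size))
# Zero-pad the non-rotated angles: frequency 0 gives cos=1, sin=0 (identity)
inv_freq = torch.cat([inv_freq, torch.zeros(nope_angles)])

positions = torch.arange(8).float()
freqs = torch.outer(positions, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
assert torch.all(cos[:, rope_angles:] == 1.0)
assert torch.all(sin[:, rope_angles:] == 0.0)
```

With frequency zero in the non-rotated half, the rotation matrix degenerates to the identity there, so only the first `rotary_dim` dimensions of each head are actually rotated.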

Added a dtype guard in `rmsnorm2d_fwd` and `rmsnorm2d_fwd_with_add`:

```python
if weight.dtype != input.dtype:
    weight = weight.to(input.dtype)
```

Gemma 4's RMSNorm weights are float32 while activations are bfloat16. The underlying HIP/CK kernels read weight memory using the input's dtype stride, causing silent data corruption (no error, garbage output) on a dtype mismatch. This cast fixes Gemma 4 and is safe for all other models (a no-op when the dtypes already match).
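To see why the cast is numerically safe, here is a minimal CPU sketch (a reference RMSNorm written for this illustration, not the AITER kernel) comparing outputs with the original float32 weight and the bfloat16-cast weight:

```python
import torch

def rmsnorm_ref(x, w, eps):
    # Reference RMSNorm computed in float32, independent of the HIP/CK kernels
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * w.float()).to(x.dtype)

torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(64, dtype=torch.float32)  # Gemma-style float32 weight

# The guard from this PR: cast weight to the activation dtype
w_cast = w.to(x.dtype) if w.dtype != x.dtype else w
assert w_cast.dtype == torch.bfloat16

# Casting the weight perturbs the result only at bfloat16 rounding precision
ref_f32w = rmsnorm_ref(x, w, 1e-6)
ref_bf16w = rmsnorm_ref(x, w_cast, 1e-6)
assert torch.allclose(ref_f32w, ref_bf16w, atol=0.1)
```

The cast only rounds the weight to bfloat16 precision, which is already the precision of the output tensor, whereas the mismatch path reinterprets the weight bytes outright.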

Test Plan

- [x] ATOM standalone (Config A): Gemma 4 31B-it, TP=8, 8×MI355X, CUDA graphs enabled
- [x] vLLM + ATOM plugin (Config B): same model/hardware
- [x] Correctness: Math/Joke/Capital prompts verified
- [x] Verify rmsnorm dtype fix: no silent corruption with float32 weights + bfloat16 input

Test Result

Concurrency = 2 (low concurrency)

| Metric | Config A (Standalone) | Config B (vLLM Plugin) |
|---|---|---|
| Output tok/s | 169.10 | 136.97 |
| Mean TTFT | 58.89 ms | 178.64 ms |
| Median TTFT | 53.57 ms | 89.83 ms |
| P99 TTFT | 98.28 ms | 1537.27 ms |
| Mean TPOT | 11.46 ms | 13.30 ms |
| Median TPOT | 11.46 ms | 13.02 ms |
| P99 TPOT | 11.61 ms | 19.53 ms |
| Mean E2EL | 1513.80 ms | 1867.69 ms |
| Failed | 0 | 0 |

Concurrency = 40 (high concurrency)

| Metric | Config A (Standalone) | Config B (vLLM Plugin) |
|---|---|---|
| Output tok/s | 2018.86 | 1681.18 |
| Mean TTFT | 645.75 ms | 1066.13 ms |
| Median TTFT | 705.84 ms | 304.08 ms |
| P99 TTFT | 727.92 ms | 2045.75 ms |
| Mean TPOT | 14.85 ms | 15.50 ms |
| Median TPOT | 14.29 ms | 15.46 ms |
| P99 TPOT | 17.53 ms | 29.84 ms |
| Mean E2EL | 2532.09 ms | 3035.50 ms |
| Failed | 0 | 0 |

Submission Checklist

ClementLinCF and others added 2 commits April 8, 2026 11:40
Add ProportionalRotaryEmbedding class that implements Gemma 4-style
partial rotary encoding. Frequency exponents use head_size as the
denominator (not rotary_dim), and non-rotated dimensions are zero-padded
so cos=1, sin=0 (identity rotation). Also register "proportional" as a
recognized scaling type in get_rope().
…orruption

The underlying HIP/CK rmsnorm kernels read weight memory using the
input tensor's dtype stride. When weight has a different dtype from
input (e.g. float32 weight with bfloat16 input), the kernel
misinterprets the weight bytes, causing:
- Half of output values become zero
- Remaining values can be astronomically large (up to 1e24)

This commonly occurs with Gemma-style RMSNorm where the normalization
uses x * (1 + weight). The expression `weight.data + 1.0` promotes
bfloat16 weight to float32 in some contexts (e.g. during model init
when weight is still float32 default).

Fix: auto-cast weight to match input dtype before passing to kernels.
Added to both rmsnorm2d_fwd and rmsnorm2d_fwd_with_add.

Reproduction:
```python
import torch
from aiter import rmsnorm2d_fwd
x = torch.randn(4, 5376, dtype=torch.bfloat16, device="cuda")
w = torch.randn(5376, dtype=torch.float32, device="cuda")  # wrong dtype!
out = rmsnorm2d_fwd(x, w, 1e-6)
# Before fix: ~half zeros, max_abs ~1e24
# After fix: correct output matching native PyTorch
```
@ClementLinCF ClementLinCF requested review from a team and Copilot April 12, 2026 12:42
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2705 --add-label <label>`

Contributor

Copilot AI left a comment


Pull request overview

Adds Gemma 4 31B operator-level compatibility in AITER by introducing a new RoPE variant and hardening RMSNorm against dtype-mismatch corruption on HIP/CK paths.

Changes:

  • Add ProportionalRotaryEmbedding and register it in get_rope() under rope_type == "proportional".
  • Add a dtype guard to cast RMSNorm weight to input.dtype in rmsnorm2d_fwd and rmsnorm2d_fwd_with_add.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
|---|---|
| aiter/rotary_embedding.py | Introduces the new proportional RoPE embedding class and wires it into the RoPE factory. |
| aiter/ops/rmsnorm.py | Adds runtime dtype normalization for RMSNorm weights to avoid silent corruption on mismatch. |
Comments suppressed due to low confidence (1)

aiter/ops/rmsnorm.py:78

  • The dtype-compat guard is added only for rmsnorm2d_fwd and rmsnorm2d_fwd_with_add, but the underlying kernels for other entry points in this module (e.g., rmsnorm/add_rmsnorm and the _smoothquant/_dynamicquant variants) also reinterpret weight as the input dtype in C++/HIP. That means callers can still hit the same silent corruption through those paths. Consider applying the same dtype normalization (or a hard error) consistently across all public rmsnorm wrappers in this file.
```python
def rmsnorm2d_fwd(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> Tensor:
    # Ensure weight dtype matches input dtype. The underlying HIP/CK kernels
    # read weight memory using input's dtype stride, so a dtype mismatch
    # (e.g. float32 weight with bfloat16 input) causes silent data corruption.
    if weight.dtype != input.dtype:
        weight = weight.to(input.dtype)
    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
        out = rmsnorm2d_fwd_ck(input, weight, epsilon, use_model_sensitive_rmsnorm)
    else:
        out = torch.empty_like(input, dtype=input.dtype, device=input.device)
        rmsnorm(out, input, weight, epsilon)
    return out
```
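The reviewer's suggestion to apply the guard consistently across all public rmsnorm wrappers could be sketched as a small decorator (a hypothetical helper name for illustration, not code from this PR):

```python
import functools
import torch

def match_weight_dtype(fn):
    # Hypothetical helper: cast `weight` to `input`'s dtype before the call,
    # so every rmsnorm wrapper gets the same guard without duplicated code.
    @functools.wraps(fn)
    def wrapper(input, weight, *args, **kwargs):
        if weight.dtype != input.dtype:
            weight = weight.to(input.dtype)
        return fn(input, weight, *args, **kwargs)
    return wrapper

@match_weight_dtype
def demo_rmsnorm(input, weight, epsilon):
    # Stand-in for a kernel that requires matching dtypes
    assert input.dtype == weight.dtype, "kernel would silently corrupt here"
    x = input.float()
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon)
    return (normed * weight.float()).to(input.dtype)

out = demo_rmsnorm(torch.randn(2, 8, dtype=torch.bfloat16),
                   torch.randn(8, dtype=torch.float32), 1e-6)
assert out.dtype == torch.bfloat16
```

Decorating each public entry point (including the `_smoothquant`/`_dynamicquant` variants the reviewer mentions) would close the remaining corruption paths in one place.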


Comment on lines +1187 to +1196

```python
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        self.rope_angles = rotary_dim // 2
        self.nope_angles = (head_size // 2) - self.rope_angles
        super().__init__(
            head_size,
            head_size,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )
```
Comment on lines +68 to +72

```python
    # Ensure weight dtype matches input dtype. The underlying HIP/CK kernels
    # read weight memory using input's dtype stride, so a dtype mismatch
    # (e.g. float32 weight with bfloat16 input) causes silent data corruption.
    if weight.dtype != input.dtype:
        weight = weight.to(input.dtype)
```