feat: add Gemma4 31B support (ProportionalRotaryEmbedding, rmsnorm dtype) #2705

Open · ClementLinCF wants to merge 3 commits into main
Conversation
Add ProportionalRotaryEmbedding class that implements Gemma 4-style partial rotary encoding. Frequency exponents use head_size as the denominator (not rotary_dim), and non-rotated dimensions are zero-padded so cos=1, sin=0 (identity rotation). Also register "proportional" as a recognized scaling type in get_rope().
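The partial-rotation scheme described in this commit can be sketched in plain Python (torch-free; the real class builds these values as tensors, and the `head_size`/`rotary_dim` numbers below are illustrative):

```python
import math

def proportional_angles(position, head_size, rotary_dim, base=10000.0):
    """Rotation angles for one position. Only the first rotary_dim/2
    entries rotate; the rest are zero-padded so cos=1, sin=0
    (identity rotation). Note the exponent denominator is head_size,
    not rotary_dim -- the defining property of the "proportional" variant."""
    rope_angles = rotary_dim // 2
    inv_freq = [base ** (-(2 * i) / head_size) for i in range(rope_angles)]
    angles = [position * f for f in inv_freq]
    angles += [0.0] * (head_size // 2 - rope_angles)  # non-rotated dims
    return angles

angles = proportional_angles(position=7, head_size=128, rotary_dim=64)
cos = [math.cos(a) for a in angles]
sin = [math.sin(a) for a in angles]
# The zero-padded tail is the identity rotation: cos=1, sin=0
assert all(c == 1.0 and s == 0.0 for c, s in zip(cos[32:], sin[32:]))
```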
…orruption

The underlying HIP/CK rmsnorm kernels read weight memory using the input tensor's dtype stride. When weight has a different dtype from input (e.g. float32 weight with bfloat16 input), the kernel misinterprets the weight bytes, causing:

- Half of the output values become zero
- Remaining values can be astronomically large (up to 1e24)

This commonly occurs with Gemma-style RMSNorm, where the normalization uses x * (1 + weight). The expression `weight.data + 1.0` promotes bfloat16 weight to float32 in some contexts (e.g. during model init when weight is still the float32 default).

Fix: auto-cast weight to match input dtype before passing to kernels. Added to both rmsnorm2d_fwd and rmsnorm2d_fwd_with_add.

Reproduction:

```python
import torch
from aiter import rmsnorm2d_fwd

x = torch.randn(4, 5376, dtype=torch.bfloat16, device="cuda")
w = torch.randn(5376, dtype=torch.float32, device="cuda")  # wrong dtype!
out = rmsnorm2d_fwd(x, w, 1e-6)
# Before fix: ~half zeros, max_abs ~1e24
# After fix: correct output matching native PyTorch
```
Pull request overview
Adds Gemma 4 31B operator-level compatibility in AITER by introducing a new RoPE variant and hardening RMSNorm against dtype-mismatch corruption on HIP/CK paths.
Changes:
- Add `ProportionalRotaryEmbedding` and register it in `get_rope()` under `rope_type == "proportional"`.
- Add a dtype guard to cast RMSNorm `weight` to `input.dtype` in `rmsnorm2d_fwd` and `rmsnorm2d_fwd_with_add`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `aiter/rotary_embedding.py` | Introduces the new proportional RoPE embedding class and wires it into the RoPE factory. |
| `aiter/ops/rmsnorm.py` | Adds runtime dtype normalization for RMSNorm weights to avoid silent corruption on mismatch. |
Comments suppressed due to low confidence (1)
aiter/ops/rmsnorm.py:78
- The dtype-compat guard is added only for rmsnorm2d_fwd and rmsnorm2d_fwd_with_add, but the underlying kernels for other entry points in this module (e.g., rmsnorm/add_rmsnorm and the _smoothquant/_dynamicquant variants) also reinterpret weight as the input dtype in C++/HIP. That means callers can still hit the same silent corruption through those paths. Consider applying the same dtype normalization (or a hard error) consistently across all public rmsnorm wrappers in this file.
```python
def rmsnorm2d_fwd(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    use_model_sensitive_rmsnorm: int = 0,
) -> Tensor:
    # Ensure weight dtype matches input dtype. The underlying HIP/CK kernels
    # read weight memory using input's dtype stride, so a dtype mismatch
    # (e.g. float32 weight with bfloat16 input) causes silent data corruption.
    if weight.dtype != input.dtype:
        weight = weight.to(input.dtype)
    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
        out = rmsnorm2d_fwd_ck(input, weight, epsilon, use_model_sensitive_rmsnorm)
    else:
        out = torch.empty_like(input, dtype=input.dtype, device=input.device)
        rmsnorm(out, input, weight, epsilon)
    return out
```
Comment on lines +1187 to +1196
```python
self.rope_angles = rotary_dim // 2
self.nope_angles = (head_size // 2) - self.rope_angles
super().__init__(
    head_size,
    head_size,
    max_position_embeddings,
    base,
    is_neox_style,
    dtype,
)
```
```python
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
) -> None:
```
Comment on lines +68 to +72
```python
# Ensure weight dtype matches input dtype. The underlying HIP/CK kernels
# read weight memory using input's dtype stride, so a dtype mismatch
# (e.g. float32 weight with bfloat16 input) causes silent data corruption.
if weight.dtype != input.dtype:
    weight = weight.to(input.dtype)
```
Motivation
Enable Gemma 4 31B-it inference on AMD Instinct MI355X (gfx950) via ATOM. Gemma 4 introduces two aiter-level requirements not covered by existing operators: a proportional RoPE variant where frequency exponents use head_size as denominator with partial rotation, and a weight/input dtype mismatch in RMSNorm that causes silent data corruption on HIP/CK kernels.
Technical Details
New `ProportionalRotaryEmbedding` class for Gemma 4's `rope_type: "proportional"`:

- Frequency exponents use head_size (not rotary_dim) as the denominator: `inv_freq = 1 / (base ^ (arange(0, 2*rope_angles, 2) / head_size))`
- Non-rotated dimensions are zero-padded so cos=1, sin=0 (identity rotation)
- Registered in the `get_rope()` factory under `scaling_type == "proportional"`
Gemma 4 uses partial_rotary_factor=0.5 for sliding attention (head_dim=128, rotary_dim=64) and partial_rotary_factor=0.25 for global attention (head_dim=256, rotary_dim=64). Standard RotaryEmbedding would compute wrong frequencies because it uses rotary_dim as denominator.
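The denominator difference is easy to quantify with the sliding-attention numbers quoted above (plain Python, torch-free; values are illustrative, not from the PR's test runs):

```python
def inv_freq(denominator, rotary_dim, base=10000.0):
    """Inverse frequencies for rotating the first rotary_dim dims."""
    return [base ** (-(2 * i) / denominator) for i in range(rotary_dim // 2)]

# Gemma 4 sliding attention: head_dim=128, rotary_dim=64
proportional = inv_freq(denominator=128, rotary_dim=64)  # head_size denom
standard     = inv_freq(denominator=64,  rotary_dim=64)  # rotary_dim denom

# Highest-index frequency: base^(-62/128) vs base^(-62/64) -- with the
# standard class the exponent decays twice as fast, so the high-index
# frequencies come out far smaller than the model was trained with.
assert proportional[0] == standard[0] == 1.0
assert standard[31] < proportional[31] < 1.0
```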
Added dtype guard in rmsnorm2d_fwd and rmsnorm2d_fwd_with_add:
```python
if weight.dtype != input.dtype:
    weight = weight.to(input.dtype)
```
Gemma 4's RMSNorm weights are float32 while activations are bfloat16. The underlying HIP/CK kernels read weight memory using input's dtype stride, causing silent data corruption (no error, garbage output) on dtype mismatch. This cast fixes Gemma 4 and is safe for all other models (no-op when dtypes already match).
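Why the corruption is silent rather than an error can be illustrated without torch using `struct`: reinterpreting float32 bytes at a 2-byte stride (a stand-in for a 16-bit kernel reading weight with the wrong dtype) yields values unrelated to the originals. The half-precision reinterpretation here is an analogy for the bfloat16 kernel path, not the actual HIP code:

```python
import struct

weights_f32 = [1.0, 1.0, 1.0, 1.0]      # well-behaved float32 weights
raw = struct.pack("<4f", *weights_f32)   # 16 bytes of weight memory

# A kernel striding by 2 bytes instead of 4 sees 8 "half" values where
# the caller stored 4 floats. Even indices land on the low-order zero
# bytes of each float (-> zeros, matching the "half of output values
# become zero" symptom); odd indices land on exponent/mantissa bytes.
as_half = struct.unpack("<8e", raw)
assert len(as_half) == 2 * len(weights_f32)
assert as_half[0] == 0.0                 # low bytes of 1.0f read as 0
assert any(h != 1.0 for h in as_half)    # garbage, not the real weights
```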
Test Plan
[x] ATOM standalone (Config A): Gemma 4 31B-it, TP=8, 8×MI355X, CUDA graphs enabled
[x] vLLM + ATOM plugin (Config B): same model/hardware
[x] Correctness: Math/Joke/Capital prompts verified
[x] Verify rmsnorm dtype fix: no silent corruption with float32 weights + bfloat16 input
Test Result
Concurrency = 2 (low concurrency)
Concurrency = 40 (high concurrency)
Submission Checklist