aiter/ops/rmsnorm.py: 7 additions & 0 deletions

@@ -65,6 +65,11 @@ def rmsnorm2d_fwd(
epsilon: float,
use_model_sensitive_rmsnorm: int = 0,
) -> Tensor:
# Ensure weight dtype matches input dtype. The underlying HIP/CK kernels
# read weight memory using input's dtype stride, so a dtype mismatch
# (e.g. float32 weight with bfloat16 input) causes silent data corruption.
if weight.dtype != input.dtype:
weight = weight.to(input.dtype)
if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
out = rmsnorm2d_fwd_ck(input, weight, epsilon, use_model_sensitive_rmsnorm)
else:
@@ -82,6 +87,8 @@ def rmsnorm2d_fwd_with_add(
epsilon: float,
use_model_sensitive_rmsnorm: int = 0,
) -> None:
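# Same dtype guard as in rmsnorm2d_fwd above: the CK kernels assume
# weight and input share a dtype.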
if weight.dtype != input.dtype:
weight = weight.to(input.dtype)
if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
rmsnorm2d_fwd_with_add_ck(
out,
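A minimal usage sketch of the new guard (an assumption-laden illustration, not part of this PR: it presumes a ROCm build of aiter on a HIP device, that input and weight are the first two positional parameters as the hunk suggests, and illustrative shapes and epsilon). A float32 weight, as many checkpoint loaders produce, is now cast to the input's dtype before kernel dispatch instead of silently corrupting the output:

import torch
from aiter.ops.rmsnorm import rmsnorm2d_fwd

x = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda")
# dtype mismatch on purpose: fp32 weight with bf16 input
w = torch.ones(4096, dtype=torch.float32, device="cuda")

# Previously this mixed-dtype call hit the silent-corruption path described
# in the comment above; the guard now casts w to bfloat16 before dispatch.
out = rmsnorm2d_fwd(x, w, 1e-6)
assert out.dtype == torch.bfloat16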
aiter/rotary_embedding.py: 49 additions & 0 deletions

@@ -1168,6 +1168,46 @@ def get_next_input_positions(
]


class ProportionalRotaryEmbedding(RotaryEmbedding):
"""Gemma4-style proportional RoPE.

Frequency exponents use head_size (not rotary_dim) as the denominator.
Non-rotated dims are zero-padded so cos=1, sin=0 (identity rotation).
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.rope_angles = rotary_dim // 2
self.nope_angles = (head_size // 2) - self.rope_angles
super().__init__(
head_size,
head_size,
max_position_embeddings,
base,
is_neox_style,
dtype,
)

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
freq_exponents = (
torch.arange(0, 2 * self.rope_angles, 2, dtype=dtypes.fp32)
/ self.head_size
)
inv_freq = 1.0 / (base**freq_exponents)
if self.nope_angles > 0:
inv_freq = torch.cat(
[inv_freq, torch.zeros(self.nope_angles, dtype=dtypes.fp32)]
)
return inv_freq


@dataclass
class AiterFusedSetKVBufferArg:
kv_cache: Tuple[torch.Tensor, torch.Tensor]
@@ -1933,6 +1973,15 @@ def get_rope(
dtype,
mrope_section=rope_scaling["mrope_section"],
)
elif scaling_type == "proportional":
rotary_emb = ProportionalRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
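To see what the new class computes, here is a standalone sketch of the proportional inv_freq math from _compute_inv_freq above, using plain torch.float32 in place of aiter's dtypes.fp32; the head_size, rotary_dim, base, and position values are illustrative:

import torch

head_size, rotary_dim, base = 128, 64, 10000.0
rope_angles = rotary_dim // 2               # rotated angle pairs
nope_angles = head_size // 2 - rope_angles  # identity (non-rotated) pairs

# Denominator is head_size, not rotary_dim as in standard RoPE.
exponents = torch.arange(0, 2 * rope_angles, 2, dtype=torch.float32) / head_size
inv_freq = 1.0 / (base ** exponents)
inv_freq = torch.cat([inv_freq, torch.zeros(nope_angles)])

# Zero frequency means angle 0 at every position, so cos=1 and sin=0:
# the non-rotated dimensions pass through unchanged.
pos = torch.arange(8, dtype=torch.float32)
angles = torch.outer(pos, inv_freq)
assert torch.all(angles[:, rope_angles:] == 0)
assert torch.all(torch.cos(angles[:, rope_angles:]) == 1)

With the get_rope branch above, configs whose rope_scaling type is "proportional" are routed to this class and cached in _ROPE_DICT like the other variants.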