diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 64cccaac6e..09630cc35d 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -961,11 +961,10 @@ def cp_p2p_fwd_flash_attn(
         **fa_forward_kwargs,
     )
     rng_states = None
-    if not fa_utils.v2_7_0_plus:
+    if not use_flash_attn_3 and not fa_utils.v2_7_0_plus:
         out_per_step = fa_outputs[4]
         softmax_lse_per_step = fa_outputs[5]
-        if not use_flash_attn_3:
-            rng_states = fa_outputs[7]
+        rng_states = fa_outputs[7]
     else:
         out_per_step = fa_outputs[0]
         softmax_lse_per_step = fa_outputs[1]
@@ -3006,11 +3005,10 @@ def forward(
                             causal=causal,
                             **fa_forward_kwargs,
                         )
-                        if not fa_utils.v2_7_0_plus:
+                        if not use_flash_attn_3 and not fa_utils.v2_7_0_plus:
                             out_per_step[i] = fa_outputs[4]
                             softmax_lse_per_step[i] = fa_outputs[5]
-                            if not use_flash_attn_3:
-                                rng_states[i] = fa_outputs[7]
+                            rng_states[i] = fa_outputs[7]
                         else:
                             out_per_step[i] = fa_outputs[0]
                             softmax_lse_per_step[i] = fa_outputs[1]
@@ -3544,9 +3542,9 @@ def forward(
             causal=causal,
             **fa_forward_kwargs,
         )
-        if not fa_utils.v2_7_0_plus:
+        if not use_flash_attn_3 and not fa_utils.v2_7_0_plus:
            out_, softmax_lse = fa_outputs[4], fa_outputs[5]
-            rng_state = fa_outputs[7] if not use_flash_attn_3 else None
+            rng_state = fa_outputs[7]
         else:
             out_, softmax_lse = fa_outputs[0], fa_outputs[1]
             rng_state = fa_outputs[3] if not use_flash_attn_3 else None
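
All three hunks converge on the same output-unpacking rule for the flash-attention forward call. The sketch below is a minimal, self-contained illustration of that rule; the helper name `unpack_fa_outputs` and its signature are illustrative only (not part of the patch), and the output indices are those implied by the diff: the pre-2.7.0 flash-attn 2 layout exposes out/softmax_lse/rng_state at indices 4/5/7, while flash-attn 2 >= 2.7.0 and flash-attn 3 expose out/softmax_lse at 0/1, with rng_state at 3 for flash-attn 2 only.

```python
def unpack_fa_outputs(fa_outputs, use_flash_attn_3, v2_7_0_plus):
    """Pick out, softmax_lse, and rng_state from a flash-attn forward result.

    Hypothetical helper mirroring the branch logic in this patch; index layout
    is assumed from the diff, not from flash-attn documentation.
    """
    if not use_flash_attn_3 and not v2_7_0_plus:
        # Old flash-attn 2 layout: rng_state is always present at index 7,
        # so no inner use_flash_attn_3 check is needed in this branch.
        out, softmax_lse = fa_outputs[4], fa_outputs[5]
        rng_state = fa_outputs[7]
    else:
        # flash-attn 2 >= 2.7.0 and flash-attn 3 share the compact layout.
        out, softmax_lse = fa_outputs[0], fa_outputs[1]
        # flash-attn 3 does not return an rng_state.
        rng_state = fa_outputs[3] if not use_flash_attn_3 else None
    return out, softmax_lse, rng_state
```

The net effect of the change is that flash-attn 3 outputs are always unpacked through the compact-layout branch, instead of falling into the pre-2.7.0 branch and guarding rng_state with a nested check.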