diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 64cccaac6e..6d5dc887bd 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -4072,6 +4072,7 @@ def attn_forward_func_with_cp(
     enable_mla = k.shape[-1] != v.shape[-1]
     assert not enable_mla or cp_comm_type in [
         "p2p",
+        "a2a",
         "a2a+p2p",
     ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
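
For context, here is a minimal standalone sketch of the updated guard. The check treats MLA as enabled whenever the K and V head dimensions differ, and after this change `"a2a"` joins `"p2p"` and `"a2a+p2p"` as an accepted `cp_comm_type` for MLA. The tensor shapes below are illustrative assumptions, not values from the source:

```python
import torch

# Illustrative MLA-style shapes (assumption): K and V have different head dims.
k = torch.empty(2, 8, 16, 576)  # K head dim 576
v = torch.empty(2, 8, 16, 512)  # V head dim 512

cp_comm_type = "a2a"  # accepted for MLA after this change

# The guard from the diff: MLA is inferred from mismatched last dims,
# and only the listed comm types are allowed when MLA is enabled.
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
    "p2p",
    "a2a",
    "a2a+p2p",
], f"Context parallelism does not support MLA with {cp_comm_type=}!"
print(f"guard passed: MLA with {cp_comm_type=}")
```

Before this change, the same inputs with `cp_comm_type="a2a"` would have tripped the assertion, since `"a2a"` was absent from the allowed list.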