diff --git a/xla/backends/gpu/autotuner/triton.cc b/xla/backends/gpu/autotuner/triton.cc index b11bb7889d721..0dfd531de48f9 100644 --- a/xla/backends/gpu/autotuner/triton.cc +++ b/xla/backends/gpu/autotuner/triton.cc @@ -65,6 +65,12 @@ namespace { std::vector<TritonGemmConfig> GetDefaultTritonConfigs( se::GpuComputeCapability compute_capability) { if (compute_capability.IsRocm()) { + const auto* rocm_cc = compute_capability.rocm_compute_capability(); + if (rocm_cc->gfx9_mi300()) { + return GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI300); + } else if (rocm_cc->gfx9_mi350()) { + return GetTritonConfigsForPlatform(TritonConfigsPlatform::kMI350); + } return GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultRocm); } diff --git a/xla/backends/gpu/autotuner/triton/default_configs/mi300.txtpb b/xla/backends/gpu/autotuner/triton/default_configs/mi300.txtpb new file mode 100644 index 0000000000000..aca70ca28511e --- /dev/null +++ b/xla/backends/gpu/autotuner/triton/default_configs/mi300.txtpb @@ -0,0 +1,47 @@ +# Copyright 2026 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +config { block_m: 32 block_n: 32 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 64 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 128 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 8 block_k: 16 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 128 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 32 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 2 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 1 num_warps: 8 num_ctas: 1 } +config { block_m: 256 block_n: 256 block_k: 32 split_k: 4 num_stages: 2 num_warps: 8 num_ctas: 1 } +config 
{ block_m: 128 block_n: 32 block_k: 32 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 8 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 32 split_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 32 block_k: 32 split_k: 8 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 128 split_k: 2 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 256 block_n: 8 block_k: 32 split_k: 4 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 128 split_k: 2 num_stages: 5 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 128 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 256 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 256 block_n: 8 block_k: 16 split_k: 8 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 8 block_k: 16 split_k: 32 num_stages: 1 num_warps: 2 num_ctas: 1 } diff --git a/xla/backends/gpu/autotuner/triton/default_configs/mi350.txtpb b/xla/backends/gpu/autotuner/triton/default_configs/mi350.txtpb new file mode 100644 index 0000000000000..849069a051ffe --- /dev/null +++ b/xla/backends/gpu/autotuner/triton/default_configs/mi350.txtpb @@ -0,0 +1,54 @@ +# Copyright 2026 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +config { block_m: 32 block_n: 32 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 64 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 128 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 8 block_k: 16 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 8 block_k: 32 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 256 block_n: 256 block_k: 16 split_k: 2 num_stages: 4 num_warps: 8 
num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 64 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 64 split_k: 4 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 16 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 128 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 64 split_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 16 split_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 16 split_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 32 split_k: 8 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 16 split_k: 8 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 8 block_k: 256 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 4 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 8 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 8 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 64 block_k: 128 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 64 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 16 split_k: 4 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 64 split_k: 16 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 256 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { 
block_m: 64 block_n: 16 block_k: 256 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 8 block_k: 32 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 64 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 16 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } diff --git a/xla/backends/gpu/autotuner/triton/triton_configs.cc b/xla/backends/gpu/autotuner/triton/triton_configs.cc index 7d56598826e79..e4ab7a74f7b22 100644 --- a/xla/backends/gpu/autotuner/triton/triton_configs.cc +++ b/xla/backends/gpu/autotuner/triton/triton_configs.cc @@ -58,7 +58,10 @@ const std::vector<TritonGemmConfig>& GetTritonConfigsForPlatform( ParseConfig(configs::get_cuda())}, {TritonConfigsPlatform::kDefaultRocm, ParseConfig(configs::get_rocm())}, - {TritonConfigsPlatform::kHopper, ParseConfig(configs::get_h100())}}); + {TritonConfigsPlatform::kHopper, ParseConfig(configs::get_h100())}, + {TritonConfigsPlatform::kMI300, ParseConfig(configs::get_mi300())}, + {TritonConfigsPlatform::kMI350, + ParseConfig(configs::get_mi350())}}); return kConfigs->at(platform); } diff --git a/xla/backends/gpu/autotuner/triton/triton_configs.h b/xla/backends/gpu/autotuner/triton/triton_configs.h index c6b30359e9965..4d4c9e268adbd 100644 --- a/xla/backends/gpu/autotuner/triton/triton_configs.h +++ b/xla/backends/gpu/autotuner/triton/triton_configs.h @@ -28,6 +28,8 @@ enum class TritonConfigsPlatform { kDefaultCuda, kDefaultRocm, kHopper, + kMI300, + kMI350, }; const std::vector<TritonGemmConfig>& GetTritonConfigsForPlatform(