diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 0a5c630862..722c541b26 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -396,6 +396,13 @@ def asm_gemm_all_solutions(self): ) if self.k / splitK < subK: break + # splitK kernels use a semaphore array of size gdx*gdy; skip + # candidates where the grid exceeds the 1024-entry limit. + if splitK > 1: + gdx = (self.n + tile_n - 1) // tile_n + gdy = (self.m + tile_m - 1) // tile_m + if gdx * gdy > 1024: + continue task_asm.append( ( info,