diff --git a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv index e5d9f3e300..4058a37ad9 100644 --- a/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv +++ b/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv @@ -18,6 +18,7 @@ cu_num,M,N,K,q_dtype_w,libtype,kernelId,splitK,us,kernelName,tflops,bw,errRatio 80,128,1536,5120,torch.int8,asm,3,3,14.8921,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_64x128E,135.19,598.5,0.1017 80,150,1536,5120,torch.int8,asm,4,3,15.3652,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_80x128E,153.55,591.8,0.1017 80,192,1536,1024,torch.int8,asm,1,1,6.9538,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_32x128E,86.86,339.28,0.0 +80,192,1536,5120,torch.float8_e4m3fnuz,cktile,9,0,18.9229,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,159.59,498.72,0.0 80,192,1536,5120,torch.int8,asm,3,2,17.2532,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_64x128E,175.03,546.98,0.0645 80,220,1536,5120,torch.int8,asm,4,2,18.9054,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_80x128E,183.03,511.31,0.0653 80,256,1536,5120,torch.int8,asm,5,2,20.7976,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_96x128E,193.61,478.97,0.0645 @@ -35,7 +36,6 @@ cu_num,M,N,K,q_dtype_w,libtype,kernelId,splitK,us,kernelName,tflops,bw,errRatio 80,4096,8192,1024,torch.int8,asm,6,1,232.1688,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,295.99,343.25,0.0 80,8192,8192,1024,torch.int8,asm,6,1,459.697,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,298.98,328.47,0.0 80,16384,8192,1024,torch.int8,asm,6,1,900.9508,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,305.1,325.88,0.0 -80,192,1536,5120,torch.float8_e4m3fnuz,cktile,9,0,18.9229,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x64_default,159.59,498.72,0.0 256,64,192,1024,torch.float8_e4m3fn,flydsl,989,0,3.199,flydsl_bpreshuflle_16x64x512_F8_F8_B16_1x0x0x3_default,7.87,89.63,0.0 256,32,384,7168,torch.float8_e4m3fn,ck,10,0,10.6799,a8w8_bpreshuffle_128x16x32x512_16x16_16x16_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v1,16.49,281.51,0.0 256,64,384,7168,torch.float8_e4m3fn,ck,10,0,10.0171,a8w8_bpreshuffle_128x16x32x512_16x16_16x16_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v1,35.17,325.49,0.0 @@ -74,10 +74,10 @@ cu_num,M,N,K,q_dtype_w,libtype,kernelId,splitK,us,kernelName,tflops,bw,errRatio 256,32,1280,8192,torch.float8_e4m3fn,ck,10,0,11.5596,a8w8_bpreshuffle_128x16x32x512_16x16_16x16_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v1,58.05,936.87,0.0 256,64,1280,8192,torch.float8_e4m3fn,ck,8,0,11.6259,a8w8_bpreshuffle_128x32x16x512_16x16_16x16_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v1,115.45,961.12,0.0 256,128,1280,8192,torch.float8_e4m3fn,ck,11,0,12.3468,a8w8_bpreshuffle_256x16x64x512_16x16_16x16_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v1,217.41,960.74,0.0 -256,256,1280,8192,torch.float8_e4m3fn,cktile,9,0,14.9227,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_32x64x512_1x4x1_16x16x128_default,359.77,887.12,0.0 -256,512,1280,8192,torch.float8_e4m3fn,ck,114,0,20.8376,a8w8_bpreshuffle_256x64x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3,515.29,767.4,0.0 -256,1024,1280,8192,torch.float8_e4m3fn,cktile,216,0,27.0331,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x3_48x64x256_1x4x1_16x16x128_default,794.39,795.17,0.0 -256,2048,1280,8192,torch.float8_e4m3fn,cktile,99,0,36.1565,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_96x128x256_1x4x1_16x16x128_default,1187.88,899.03,0.0 +256,256,1280,8192,torch.float8_e4m3fn,flydsl,857,0,16.4308,flydsl_bpreshuflle_32x64x512_F8_F8_B16_2x1x0x2_default,326.75,805.7,0.0 +256,512,1280,8192,torch.float8_e4m3fn,ck,114,0,20.7503,a8w8_bpreshuffle_256x64x64x256_16x16_16x16_16x16x1_16x16x1_1x32x1x8_8x8x1_2x1_intrawave_v3,517.46,770.63,0.0 +256,1024,1280,8192,torch.float8_e4m3fn,flydsl,460,0,27.4148,flydsl_bpreshuflle_64x128x256_F8_F8_B16_2x0x1x1_default,783.33,784.1,0.0 +256,2048,1280,8192,torch.float8_e4m3fn,ck,139,0,39.3436,a8w8_bpreshuffle_256x128x128x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3,1091.66,826.2,0.0 256,4096,1280,8192,torch.float8_e4m3fn,ck,138,0,57.904,a8w8_bpreshuffle_256x112x256x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3,1483.48,941.66,0.0 256,8192,1280,8192,torch.float8_e4m3fn,ck,51,0,94.2147,a8w8_bpreshuffle_256x192x128x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,1823.48,1046.19,0.0 256,16384,1280,8192,torch.float8_e4m3fn,ck,154,0,170.1443,a8w8_bpreshuffle_256x128x256x256_16x16_16x16_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3,2019.45,1096.99,0.0 @@ -477,38 +477,6 @@ cu_num,M,N,K,q_dtype_w,libtype,kernelId,splitK,us,kernelName,tflops,bw,errRatio 256,4096,26624,16384,torch.float8_e4m3fn,cktile,10,0,1328.8625,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_256x256x128_1x4x1_16x16x128_default,2689.08,542.89,0.0 256,8192,26624,16384,torch.float8_e4m3fn,cktile,155,0,2500.8282,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x4_256x256x128_1x4x1_16x16x128_default,2857.78,402.52,0.0 256,16384,26624,16384,torch.float8_e4m3fn,cktile,10,0,4993.8298,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_256x256x128_1x4x1_16x16x128_default,2862.26,315.8,0.0 -304,64,1536,5120,torch.int8,asm,0,1,16.7759,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,60.0,500.04,0.0 -304,64,5120,1280,torch.int8,asm,0,1,7.7424,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,108.35,941.68,0.0 -304,128,1536,5120,torch.int8,asm,0,1,17.9602,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,112.1,496.26,0.0 -304,128,5120,1280,torch.int8,asm,1,1,8.5797,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_32x128E,195.55,935.72,0.0 -304,256,1536,5120,torch.int8,asm,0,1,17.6791,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,227.76,563.46,0.0 -304,256,5120,1280,torch.int8,asm,2,1,9.2104,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_48x128E,364.31,1031.74,0.0 -304,512,1536,5120,torch.int8,asm,1,1,18.7067,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_32x128E,430.49,644.62,0.0 -304,512,5120,1280,torch.int8,asm,4,1,11.2508,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_80x128E,596.48,1106.75,0.0 -304,1024,1536,5120,torch.int8,asm,2,1,21.0125,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_48x128E,766.5,773.49,0.0 -304,1024,5120,1280,torch.int8,asm,6,1,18.2637,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,734.89,1004.73,0.0 -304,1664,1536,5120,torch.int8,asm,4,1,27.6261,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_80x128E,947.38,778.1,0.0 -304,1664,5120,1280,torch.int8,asm,6,1,21.5241,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1013.3,1195.07,0.0 -304,4096,1536,5120,torch.int8,asm,6,1,63.8361,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1009.22,648.83,0.0 -304,4096,5120,1280,torch.int8,asm,6,1,58.5804,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,916.47,917.36,0.0 -304,8192,1536,5120,torch.int8,asm,6,1,128.8567,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,999.94,581.83,0.0 -304,8192,5120,1280,torch.int8,asm,6,1,105.657,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1016.25,955.22,0.0 -304,10240,1536,5120,torch.int8,asm,6,1,142.1441,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1133.08,645.47,0.0 -304,10240,5120,1280,torch.int8,asm,6,1,133.9571,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1001.95,929.54,0.0 -304,12288,1536,5120,torch.int8,asm,7,1,185.0056,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_128x128E,1044.69,586.62,0.0 -304,12288,5120,1280,torch.int8,asm,6,1,153.8919,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1046.59,962.44,0.0 -304,16384,1536,5120,torch.int8,asm,6,1,211.7054,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1217.25,671.13,0.0 -304,16384,5120,1280,torch.int8,asm,6,1,195.4489,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1098.74,999.22,0.0 -304,20480,1536,5120,torch.int8,asm,6,1,271.7016,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1185.57,646.43,0.0 -304,20480,5120,1280,torch.int8,asm,6,1,245.7639,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1092.25,986.65,0.0 -304,24576,1536,5120,torch.int8,asm,6,1,328.3654,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1177.19,637.07,0.0 -304,24576,5120,1280,torch.int8,asm,6,1,290.5831,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1108.54,996.85,0.0 -304,30720,1536,5120,torch.int8,asm,6,1,402.8387,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1199.45,644.23,0.0 -304,30720,5120,1280,torch.int8,asm,6,1,361.4235,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1114.08,997.3,0.0 -304,32768,1536,5120,torch.int8,asm,6,1,406.391,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1268.23,679.89,0.0 -304,32768,5120,1280,torch.int8,asm,6,1,379.1231,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1132.87,1012.97,0.0 -304,40960,1536,5120,torch.int8,asm,6,1,525.9013,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1225.03,652.99,0.0 -304,40960,5120,1280,torch.int8,asm,6,1,475.7676,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1128.43,1005.56,0.0 256,32768,26624,16384,torch.float8_e4m3fn,cktile,10,0,10041.205,a8w8_bpreshuffle_cktile_0x0x8x4x1x0x0x0x0x1_256x256x128_1x4x1_16x16x128_default,2847.0,270.68,0.0 256,1,51200,5120,torch.float8_e4m3fn,flydsl,26,0,44.5486,flydsl_bpreshuflle_16x256x512_F8_F8_B16_2x0x0x0_default,11.77,5886.86,0.0 256,16,51200,5120,torch.float8_e4m3fn,flydsl,26,0,45.3929,flydsl_bpreshuflle_16x256x512_F8_F8_B16_2x0x0x0_default,184.8,5812.9,0.0 @@ -549,3 +517,35 @@ cu_num,M,N,K,q_dtype_w,libtype,kernelId,splitK,us,kernelName,tflops,bw,errRatio 256,8192,57344,8192,torch.float8_e4m3fn,flydsl,825,0,2916.8532,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2_default,2638.66,506.16,0.0 256,16384,57344,8192,torch.float8_e4m3fn,flydsl,825,0,5899.997,flydsl_bpreshuflle_256x128x128_F8_F8_B16_2x0x1x2_default,2609.01,420.85,0.0 256,32768,57344,8192,torch.float8_e4m3fn,ck,33,0,12218.2137,a8w8_bpreshuffle_256x256x256x128_16x16_16x16_8x32x1_8x32x1_1x32x1x8_8x8x1_2x1_intrawave_v3,2519.71,368.0,0.0 +304,64,1536,5120,torch.int8,asm,0,1,16.7759,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,60.0,500.04,0.0 +304,128,1536,5120,torch.int8,asm,0,1,17.9602,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,112.1,496.26,0.0 +304,256,1536,5120,torch.int8,asm,0,1,17.6791,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,227.76,563.46,0.0 +304,512,1536,5120,torch.int8,asm,1,1,18.7067,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_32x128E,430.49,644.62,0.0 +304,1024,1536,5120,torch.int8,asm,2,1,21.0125,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_48x128E,766.5,773.49,0.0 +304,1664,1536,5120,torch.int8,asm,4,1,27.6261,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_80x128E,947.38,778.1,0.0 +304,4096,1536,5120,torch.int8,asm,6,1,63.8361,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1009.22,648.83,0.0 +304,8192,1536,5120,torch.int8,asm,6,1,128.8567,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,999.94,581.83,0.0 +304,10240,1536,5120,torch.int8,asm,6,1,142.1441,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1133.08,645.47,0.0 +304,12288,1536,5120,torch.int8,asm,7,1,185.0056,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_128x128E,1044.69,586.62,0.0 +304,16384,1536,5120,torch.int8,asm,6,1,211.7054,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1217.25,671.13,0.0 +304,20480,1536,5120,torch.int8,asm,6,1,271.7016,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1185.57,646.43,0.0 +304,24576,1536,5120,torch.int8,asm,6,1,328.3654,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1177.19,637.07,0.0 +304,30720,1536,5120,torch.int8,asm,6,1,402.8387,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1199.45,644.23,0.0 +304,32768,1536,5120,torch.int8,asm,6,1,406.391,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1268.23,679.89,0.0 +304,40960,1536,5120,torch.int8,asm,6,1,525.9013,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1225.03,652.99,0.0 +304,64,5120,1280,torch.int8,asm,0,1,7.7424,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_16x128E,108.35,941.68,0.0 +304,128,5120,1280,torch.int8,asm,1,1,8.5797,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_32x128E,195.55,935.72,0.0 +304,256,5120,1280,torch.int8,asm,2,1,9.2104,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_48x128E,364.31,1031.74,0.0 +304,512,5120,1280,torch.int8,asm,4,1,11.2508,_ZN5aiter41I8gemm_bf16_perTokenI8_BpreShuffle_80x128E,596.48,1106.75,0.0 +304,1024,5120,1280,torch.int8,asm,6,1,18.2637,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,734.89,1004.73,0.0 +304,1664,5120,1280,torch.int8,asm,6,1,21.5241,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1013.3,1195.07,0.0 +304,4096,5120,1280,torch.int8,asm,6,1,58.5804,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,916.47,917.36,0.0 +304,8192,5120,1280,torch.int8,asm,6,1,105.657,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1016.25,955.22,0.0 +304,10240,5120,1280,torch.int8,asm,6,1,133.9571,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1001.95,929.54,0.0 +304,12288,5120,1280,torch.int8,asm,6,1,153.8919,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1046.59,962.44,0.0 +304,16384,5120,1280,torch.int8,asm,6,1,195.4489,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1098.74,999.22,0.0 +304,20480,5120,1280,torch.int8,asm,6,1,245.7639,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1092.25,986.65,0.0 +304,24576,5120,1280,torch.int8,asm,6,1,290.5831,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1108.54,996.85,0.0 +304,30720,5120,1280,torch.int8,asm,6,1,361.4235,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1114.08,997.3,0.0 +304,32768,5120,1280,torch.int8,asm,6,1,379.1231,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1132.87,1012.97,0.0 +304,40960,5120,1280,torch.int8,asm,6,1,475.7676,_ZN5aiter42I8gemm_bf16_perTokenI8_BpreShuffle_112x256E,1128.43,1005.56,0.0 diff --git a/op_tests/tuning_tests/README.md b/op_tests/tuning_tests/README.md new file mode 100644 index 0000000000..18e1946715 --- /dev/null +++ b/op_tests/tuning_tests/README.md @@ -0,0 +1,77 @@ +# Tuning Tests + +Minimal test suite for validating the aiter tuning infrastructure. + +## Structure + +| File | Level | GPU | What it tests | +|------|-------|-----|---------------| +| `test_csv_validation.py` | 0 | No | Tuned CSV integrity: duplicates, invalid times, errRatio, git conflicts | +| `test_tuner_infra.py` | 1 | No | `base_tuner` utilities: CSV I/O, merge, dedup, calculate, post_process topk | +| `test_mp_tuner_logic.py` | 1 | No | `mp_tuner` polling: timeout, AcceleratorError, KeyError, pool restart | +| `test_tune_pipeline.py` | 2 | Yes | End-to-end: run each tuner on small shapes, verify output CSV | +| `test_run_config.py` | 2 | Yes | Run --run_config on ALL existing tuned CSVs (configs + model_configs) | + +## Tuner family coverage + +| Family | Tuner script | Tuned CSVs validated | run_config | pipeline | +|--------|-------------|---------------------|------------|----------| +| `a8w8` | `csrc/ck_gemm_a8w8/gemm_a8w8_tune.py` | `a8w8_tuned_gemm.csv` | ✓ | ✓ (int8+fp8) | +| `a8w8_bpreshuffle` | `csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py` | `a8w8_bpreshuffle_tuned_gemm*.csv` | ✓ | ✓ (int8+fp8) | +| `a8w8_blockscale` | `csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py` | `a8w8_blockscale_tuned_gemm*.csv` | ✓ | ✓ + shape_grouped | +| `a8w8_blockscale_bpreshuffle` | same + `--preshuffle` | `a8w8_blockscale_bpreshuffle_tuned_gemm*.csv` | ✓ | — | +| `a4w4_blockscale` | `csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py` | `a4w4_blockscale_tuned_gemm*.csv` | ✓ | — | +| `batched_a8w8` | `csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py` | `a8w8_tuned_batched_gemm.csv` | ✓ | ✓ | +| `batched_bf16` | `csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py` | `bf16_tuned_batched_gemm.csv` | ✓ | ✓ + shape_grouped | +| `fmoe` | `csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py` | `tuned_fmoe.csv` + model_configs | ✓ | ✓ (bf16/fp8/int8/gelu) | +| `gradlib_bf16` | `gradlib/gradlib/gemm_tuner.py` | `bf16_tuned_gemm.csv` | — | ✓ (hipBLASLt/ASM/FlyDSL) | + +`test_run_config` auto-discovers tuned CSVs from both `aiter/configs/` and `aiter/configs/model_configs/`, merges them per family. + +## Running + +```bash +# Level 0+1 only (no GPU, <10s) +python3 -m unittest op_tests.tuning_tests.test_csv_validation \ + op_tests.tuning_tests.test_tuner_infra \ + op_tests.tuning_tests.test_mp_tuner_logic -v + +# Level 2: pipeline smoke (~10min) +python3 -m unittest op_tests.tuning_tests.test_tune_pipeline -v + +# Level 2: run_config validation (~20min, all tuned CSVs) +python3 -m unittest op_tests.tuning_tests.test_run_config -v + +# Everything +python3 -m unittest discover -s op_tests/tuning_tests -v +``` + +## Reproducing with custom config + +Use `TUNE_TEST_FAMILY` and `TUNE_TEST_CONFIG` to run `--run_config` on a specific tuned CSV: + +```bash +# Single config (relative path from aiter root) +TUNE_TEST_FAMILY=a8w8_blockscale \ +TUNE_TEST_CONFIG="aiter/configs/a8w8_blockscale_tuned_gemm.csv" \ +python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v + +# Merge multiple configs (pathsep separated, same as AITER_CONFIG_* env) +TUNE_TEST_FAMILY=a8w8_blockscale \ +TUNE_TEST_CONFIG="aiter/configs/a8w8_blockscale_tuned_gemm.csv:aiter/configs/model_configs/a8w8_blockscale_tuned_gemm_ds_v3.csv" \ +python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v + +# Reproduce fmoe issues +TUNE_TEST_FAMILY=fmoe \ +TUNE_TEST_CONFIG="aiter/configs/tuned_fmoe.csv" \ +python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v + +# blockscale with preshuffle +TUNE_TEST_FAMILY=a8w8_blockscale_bpreshuffle \ +TUNE_TEST_CONFIG="aiter/configs/a8w8_blockscale_bpreshuffle_tuned_gemm.csv" \ +python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v +``` + +Available families: `a8w8`, `a8w8_bpreshuffle`, `a8w8_blockscale`, `a8w8_blockscale_bpreshuffle`, `a4w4_blockscale`, `batched_a8w8`, `batched_bf16`, `fmoe` + +The test checks both **exit code** and **per-shape status** — shapes with `ERROR` (kernel crash) or `MISMATCH` (accuracy exceeded errRatio) will fail the test. diff --git a/op_tests/tuning_tests/__init__.py b/op_tests/tuning_tests/__init__.py new file mode 100644 index 0000000000..3c78ed50d9 --- /dev/null +++ b/op_tests/tuning_tests/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. diff --git a/op_tests/tuning_tests/test_csv_validation.py b/op_tests/tuning_tests/test_csv_validation.py new file mode 100644 index 0000000000..6d7ab7cc9d --- /dev/null +++ b/op_tests/tuning_tests/test_csv_validation.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +""" +Level 0: Static validation of tuned/untuned CSV files (no GPU, fast). + +Catches: duplicates, invalid times, high errRatio, git merge conflicts, +missing untuned files. +""" + +import os +import unittest +import pandas as pd + +AITER_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +CONFIGS_DIR = os.path.join(AITER_ROOT, "aiter", "configs") + + +class TestCSVValidation(unittest.TestCase): + + TUNED_CSVS = { + "a8w8": "a8w8_tuned_gemm.csv", + "a8w8_bpreshuffle": "a8w8_bpreshuffle_tuned_gemm.csv", + "a8w8_blockscale": "a8w8_blockscale_tuned_gemm.csv", + "a8w8_blockscale_bpreshuffle": "a8w8_blockscale_bpreshuffle_tuned_gemm.csv", + "a4w4_blockscale": "a4w4_blockscale_tuned_gemm.csv", + "a8w8_batched": "a8w8_tuned_batched_gemm.csv", + "bf16": "bf16_tuned_gemm.csv", + "bf16_batched": "bf16_tuned_batched_gemm.csv", + "fmoe": "tuned_fmoe.csv", + } + + def _load_csv(self, name): + path = os.path.join(CONFIGS_DIR, self.TUNED_CSVS[name]) + if not os.path.exists(path): + self.skipTest(f"{self.TUNED_CSVS[name]} not found") + df = pd.read_csv(path) + df.columns = df.columns.str.strip() + return df + + def _get_key_cols(self, df): + candidates = [ + "cu_num", + "M", + "N", + "K", + "B", + "token", + "model_dim", + "inter_dim", + "expert", + "topk", + ] + return [c for c in candidates if c in df.columns] + + def _check_no_duplicates(self, name, extra_keys=None): + df = self._load_csv(name) + keys = self._get_key_cols(df) + if extra_keys: + keys.extend([k for k in extra_keys if k in df.columns]) + dupes = df[df.duplicated(subset=keys, keep=False)] + self.assertEqual( + len(dupes), + 0, + f"{name}: {len(dupes)} duplicate rows (first 10):\n{dupes.head(10)}", + ) + + def test_a8w8_no_duplicates(self): + self._check_no_duplicates("a8w8", extra_keys=["q_dtype_w"]) + + def test_a8w8_blockscale_no_duplicates(self): + self._check_no_duplicates("a8w8_blockscale") + + def test_fmoe_no_duplicates(self): + self._check_no_duplicates( + "fmoe", + extra_keys=[ + "act_type", + "dtype", + "q_dtype_a", + "q_dtype_w", + "q_type", + "use_g1u1", + "doweight_stage1", + "_tag", + ], + ) + + def test_no_git_conflict_markers(self): + for name, fname in self.TUNED_CSVS.items(): + with self.subTest(csv=name): + path = os.path.join(CONFIGS_DIR, fname) + if not os.path.exists(path): + continue + with open(path, "r") as f: + content = f.read() + for marker in ["<<<<<<<", "=======", ">>>>>>>"]: + self.assertNotIn( + marker, content, f"{name}: git conflict marker '{marker}' found" + ) + + def test_no_invalid_times(self): + for name in self.TUNED_CSVS: + with self.subTest(csv=name): + df = self._load_csv(name) + if "us" not in df.columns: + continue + us = pd.to_numeric(df["us"], errors="coerce") + bad = df[us <= 0] + self.assertEqual( + len(bad), 0, f"{name}: {len(bad)} rows with us <= 0:\n{bad.head(5)}" + ) + + def test_error_ratios_within_bounds(self): + for name in self.TUNED_CSVS: + with self.subTest(csv=name): + df = self._load_csv(name) + if "errRatio" not in df.columns: + continue + err_col = df["errRatio"] + if err_col.dtype == object: + err_col = err_col.str.rstrip("%").astype(float) / 100.0 + else: + err_col = pd.to_numeric(err_col, errors="coerce") + high = df[err_col > 0.2] + self.assertEqual( + len(high), + 0, + f"{name}: {len(high)} rows with errRatio > 0.2:\n{high.head(5)}", + ) + + def test_untuned_csvs_exist(self): + untuned_files = [ + "a8w8_untuned_gemm.csv", + "a8w8_bpreshuffle_untuned_gemm.csv", + "a8w8_blockscale_untuned_gemm.csv", + "a8w8_untuned_batched_gemm.csv", + "bf16_untuned_batched_gemm.csv", + "untuned_fmoe.csv", + ] + for f in untuned_files: + with self.subTest(file=f): + path = os.path.join(CONFIGS_DIR, f) + self.assertTrue(os.path.exists(path), f"Missing: {f}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/op_tests/tuning_tests/test_mp_tuner_logic.py b/op_tests/tuning_tests/test_mp_tuner_logic.py new file mode 100644 index 0000000000..3cf0d4b84e --- /dev/null +++ b/op_tests/tuning_tests/test_mp_tuner_logic.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +""" +Unit tests for mp_tuner polling loop logic. + +Simulates async_result behavior without GPU/multiprocessing to verify: +1. consecutive_timeouts tracks correctly and resets on success +2. half-GPU threshold triggers break at the right time +3. KeyError tasks stay in remaining_tasks and get retried after root-cause restart + +Run: python3 -m unittest op_tests.test_mp_tuner_logic -v +""" + +import time +import unittest +from multiprocessing import TimeoutError as MPTimeoutError + + +class FakeAsyncResult: + """Simulates multiprocessing.AsyncResult for testing polling logic.""" + + def __init__(self, behavior, value=None): + """ + behavior: "ok", "timeout_pending", "timeout_expired", "keyerror", "accelerator" + value: return value for "ok" + """ + self.behavior = behavior + self.value = value + + def get(self, timeout=10): + if self.behavior == "ok": + return self.value + elif self.behavior in ("timeout_pending", "timeout_expired"): + raise MPTimeoutError("timeout") + elif self.behavior == "keyerror": + raise KeyError("12345") + elif self.behavior == "accelerator": + raise type("AcceleratorError", (Exception,), {})("GPU fault") + + +def simulate_poll_round(remaining_tasks, task_start_times, mp_num, timeout): + """ + Simulate one round of the mp_tuner polling loop. + Returns (completed, dummy_failed, pool_restart_needed, broke_early) + """ + completed_this_round = [] + dummy_failed_tasks = [] + consecutive_timeouts = 0 + half_gpu = max(1, (mp_num + 1) // 2) + pool_restart_needed = False + broke_early = False + + for k, async_result in remaining_tasks: + try: + if timeout is not None: + elapsed = time.time() - task_start_times[k] + remaining_time = timeout - elapsed + actual_timeout = max(1, min(10, remaining_time)) + else: + actual_timeout = 10 + + async_result.get(timeout=actual_timeout) + completed_this_round.append((k, async_result)) + consecutive_timeouts = 0 + + except MPTimeoutError: + if timeout is not None: + elapsed = time.time() - task_start_times[k] + if elapsed > timeout: + consecutive_timeouts += 1 + completed_this_round.append((k, async_result)) + pool_restart_needed = True + + if consecutive_timeouts >= half_gpu: + broke_early = True + break + else: + consecutive_timeouts = 0 + + except Exception as e: + error_type = type(e).__name__ + is_mapping_error = error_type == "KeyError" + + if is_mapping_error: + dummy_failed_tasks.append((k, "mapping error")) + elif error_type == "AcceleratorError": + completed_this_round.append((k, async_result)) + pool_restart_needed = True + broke_early = True + break + else: + completed_this_round.append((k, async_result)) + + return completed_this_round, dummy_failed_tasks, pool_restart_needed, broke_early + + +class TestConsecutiveTimeouts(unittest.TestCase): + + def test_single_timeout_no_break_8gpu(self): + """1 stuck GPU out of 8: should NOT break early.""" + mp_num = 8 + timeout = 0.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("timeout_expired")), + (1, FakeAsyncResult("ok", [("info", 1.0, 0.0)])), + (2, FakeAsyncResult("timeout_expired")), + (3, FakeAsyncResult("ok", [("info", 2.0, 0.0)])), + ] + start_times = {k: now - 10 for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertFalse(broke, "Should NOT break early with interleaved success") + self.assertTrue(restart, "Should still need restart (at least 1 timeout)") + self.assertEqual(len(completed), 4, "All tasks should be processed") + + def test_half_gpu_consecutive_triggers_break(self): + """4 consecutive timeouts with 8 GPUs (half=4): should break.""" + mp_num = 8 + timeout = 0.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("timeout_expired")), + (1, FakeAsyncResult("timeout_expired")), + (2, FakeAsyncResult("timeout_expired")), + (3, FakeAsyncResult("timeout_expired")), + (4, FakeAsyncResult("ok", [("info", 1.0, 0.0)])), + ] + start_times = {k: now - 10 for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertTrue(broke, "Should break after 4 consecutive timeouts (half of 8)") + self.assertTrue(restart) + self.assertEqual(len(completed), 4, "Task 4 not polled due to break") + + def test_success_resets_consecutive(self): + """Success in between resets counter: 3 timeouts, 1 ok, 3 timeouts != break for 8 GPU.""" + mp_num = 8 + timeout = 0.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("timeout_expired")), + (1, FakeAsyncResult("timeout_expired")), + (2, FakeAsyncResult("timeout_expired")), + (3, FakeAsyncResult("ok", [("info", 1.0, 0.0)])), + (4, FakeAsyncResult("timeout_expired")), + (5, FakeAsyncResult("timeout_expired")), + (6, FakeAsyncResult("timeout_expired")), + ] + start_times = {k: now - 10 for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertFalse(broke, "Should NOT break: success at task 3 resets counter") + self.assertTrue(restart, "Still need restart from timeouts") + self.assertEqual(len(completed), 7) + + def test_2gpu_half_is_1(self): + """2 GPUs: half=1, single consecutive timeout triggers break.""" + mp_num = 2 + timeout = 0.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("timeout_expired")), + (1, FakeAsyncResult("ok", [("info", 1.0, 0.0)])), + ] + start_times = {k: now - 10 for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertTrue(broke, "2 GPUs: half=1, first timeout should break") + self.assertEqual(len(completed), 1) + + def test_pending_timeout_resets_consecutive(self): + """Task not yet expired (still running) resets consecutive count.""" + mp_num = 4 + timeout = 100.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("timeout_expired")), + (1, FakeAsyncResult("timeout_pending")), + (2, FakeAsyncResult("timeout_expired")), + ] + start_times = { + 0: now - 200, + 1: now, + 2: now - 200, + } + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertFalse(broke, "Pending task resets consecutive, so no break") + self.assertEqual(len(completed), 2) + + +class TestKeyErrorHandling(unittest.TestCase): + + def test_keyerror_stays_in_remaining(self): + """KeyError tasks should NOT be in completed_this_round.""" + mp_num = 4 + timeout = 0.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("keyerror")), + (1, FakeAsyncResult("ok", [("info", 1.0, 0.0)])), + ] + start_times = {k: now - 10 for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + completed_ids = {k for k, _ in completed} + self.assertNotIn(0, completed_ids, "KeyError task should NOT be completed") + self.assertIn(1, completed_ids, "OK task should be completed") + self.assertEqual(len(dummy), 1, "KeyError task should be in dummy_failed") + self.assertFalse(restart, "KeyError alone should NOT trigger restart") + + def test_keyerror_with_timeout_gets_resubmitted(self): + """KeyError tasks wait for root-cause timeout to trigger restart.""" + mp_num = 2 + timeout = 0.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("keyerror")), + (1, FakeAsyncResult("timeout_expired")), + ] + start_times = {k: now - 10 for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + completed_ids = {k for k, _ in completed} + self.assertNotIn(0, completed_ids, "KeyError task stays for resubmit") + self.assertIn(1, completed_ids, "Root-cause timeout is completed") + self.assertTrue(restart, "Timeout should trigger restart") + + new_remaining = [(k, ar) for k, ar in remaining if k not in completed_ids] + self.assertEqual(len(new_remaining), 1) + self.assertEqual( + new_remaining[0][0], 0, "Only KeyError task remains for resubmit" + ) + + def test_keyerror_no_restart_without_root_cause(self): + """If only KeyError tasks remain, no restart, they keep polling.""" + mp_num = 4 + timeout = 100.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("keyerror")), + (1, FakeAsyncResult("keyerror")), + ] + start_times = {k: now for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertFalse(restart, "No restart without root cause") + self.assertEqual(len(completed), 0, "Nothing completed") + self.assertEqual(len(dummy), 2, "Both are mapping errors") + + +class TestAcceleratorError(unittest.TestCase): + + def test_accelerator_breaks_immediately(self): + """AcceleratorError should break immediately and trigger restart.""" + mp_num = 4 + timeout = 100.0 + now = time.time() + remaining = [ + (0, FakeAsyncResult("ok", [("info", 1.0, 0.0)])), + (1, FakeAsyncResult("accelerator")), + (2, FakeAsyncResult("ok", [("info", 2.0, 0.0)])), + ] + start_times = {k: now for k, _ in remaining} + + completed, dummy, restart, broke = simulate_poll_round( + remaining, start_times, mp_num, timeout + ) + self.assertTrue(broke, "AcceleratorError should break") + self.assertTrue(restart, "AcceleratorError should trigger restart") + completed_ids = {k for k, _ in completed} + self.assertIn(0, completed_ids) + self.assertIn(1, completed_ids) + self.assertNotIn(2, completed_ids, "Task 2 not reached due to break") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/op_tests/tuning_tests/test_run_config.py b/op_tests/tuning_tests/test_run_config.py new file mode 100644 index 0000000000..cdbfe2c572 --- /dev/null +++ b/op_tests/tuning_tests/test_run_config.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +""" +Level 2: Run all existing tuned configs through --run_config to verify +the production operator works with every shape in the config CSVs. + +For each tuner family, discovers all tuned CSVs (default + model_configs), +merges them via pathsep, and runs --run_config to benchmark every shape. +Any shape that errors (us=-1 or exception) is reported as a test failure. + +Run: + python3 -m unittest op_tests.tuning_tests.test_run_config -v +""" + +import os +import sys +import subprocess +import unittest + +AITER_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +CONFIGS_DIR = os.path.join(AITER_ROOT, "aiter", "configs") +MODEL_CONFIGS_DIR = os.path.join(CONFIGS_DIR, "model_configs") + +# Override: specify a tuner family and config CSV to test directly. +# TUNE_TEST_FAMILY=a8w8_blockscale TUNE_TEST_CONFIG=/path/to/tuned.csv \ +# python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v +# +# TUNE_TEST_CONFIG supports pathsep (:) for merging multiple CSVs, e.g.: +# TUNE_TEST_CONFIG="configs/a8w8_blockscale_tuned_gemm.csv:model_configs/xxx.csv" +TUNE_TEST_FAMILY = os.environ.get("TUNE_TEST_FAMILY") +TUNE_TEST_CONFIG = os.environ.get("TUNE_TEST_CONFIG") + + +def _gpu_available(): + try: + import torch + + return torch.cuda.is_available() and torch.cuda.device_count() > 0 + except ImportError: + return False + + +def _find_tuned_csvs(pattern): + """Find all tuned CSVs matching pattern in configs/ and model_configs/.""" + found = [] + for d in (CONFIGS_DIR, MODEL_CONFIGS_DIR): + if not os.path.isdir(d): + continue + for f in sorted(os.listdir(d)): + if ( + pattern in f + and "tuned" in f + and "untuned" not in f + and f.endswith(".csv") + ): + found.append(os.path.join(d, f)) + return found + + +def _merge_config_paths(csv_list): + """Merge multiple CSV paths with os.pathsep (like AITER_CONFIG_* env).""" + return os.pathsep.join(csv_list) + + +def _run_config(script, config_csv, timeout=600, extra_args=None): + """Run tuner with --run_config and return result.""" + cmd = [ + sys.executable, + os.path.join(AITER_ROOT, script), + "--run_config", + config_csv, + "--warmup", + "2", + "--iters", + "5", + ] + if extra_args: + cmd.extend(extra_args) + env = os.environ.copy() + script_dir = os.path.dirname(os.path.join(AITER_ROOT, script)) + env["PYTHONPATH"] = script_dir + ":" + env.get("PYTHONPATH", "") + try: + return subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=AITER_ROOT, + env=env, + ) + except subprocess.TimeoutExpired as e: + raise AssertionError( + f"run_config timed out after {timeout}s\n" + f" cmd: {' '.join(cmd)}\n" + f" stdout (last 500): {(e.stdout or b'')[-500:]}\n" + f" stderr (last 500): {(e.stderr or b'')[-500:]}" + ) from None + + +TUNER_FAMILIES = { + "a8w8": { + "script": "csrc/ck_gemm_a8w8/gemm_a8w8_tune.py", + "csv_pattern": "a8w8_tuned_gemm", + "exclude_patterns": ["bpreshuffle", "blockscale", "batched"], + }, + "a8w8_bpreshuffle": { + "script": "csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py", + "csv_pattern": "a8w8_bpreshuffle_tuned_gemm", + "exclude_patterns": ["blockscale"], + }, + "a8w8_blockscale": { + "script": "csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py", + "csv_pattern": "a8w8_blockscale_tuned_gemm", + "exclude_patterns": ["bpreshuffle", "fmoe"], + }, + "a8w8_blockscale_bpreshuffle": { + "script": "csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py", + "csv_pattern": "a8w8_blockscale_bpreshuffle_tuned_gemm", + "exclude_patterns": ["fmoe"], + "extra_args": ["--preshuffle"], + }, + "a4w4_blockscale": { + "script": "csrc/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.py", + "csv_pattern": "a4w4_blockscale_tuned_gemm", + "exclude_patterns": [], + }, + "batched_a8w8": { + "script": "csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py", + "csv_pattern": "a8w8_tuned_batched_gemm", + "exclude_patterns": [], + }, + "batched_bf16": { + "script": "csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py", + "csv_pattern": "bf16_tuned_batched_gemm", + "exclude_patterns": [], + }, + "fmoe": { + "script": "csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py", + "csv_pattern": "tuned_fmoe", + "exclude_patterns": ["untuned", "profile"], + "timeout": 1200, + }, +} + + +@unittest.skipUnless(_gpu_available(), "No GPU available") +class TestRunConfig(unittest.TestCase): + """Run --run_config on all existing tuned CSVs to verify production ops.""" + + def _test_family(self, name): + cfg = TUNER_FAMILIES[name] + pattern = cfg["csv_pattern"] + excludes = cfg.get("exclude_patterns", []) + timeout = cfg.get("timeout", 600) + + csvs = _find_tuned_csvs(pattern) + csvs = [ + c for c in csvs if not any(ex in os.path.basename(c) for ex in excludes) + ] + + if not csvs: + self.skipTest(f"No tuned CSVs found for {name} (pattern={pattern})") + + merged = _merge_config_paths(csvs) + csv_names = [os.path.basename(c) for c in csvs] + extra_args = cfg.get("extra_args", None) + + result = _run_config( + cfg["script"], merged, timeout=timeout, extra_args=extra_args + ) + + output = result.stdout + result.stderr + if result.returncode != 0: + print(f"\n=== {name} run_config FAILED ===") + print(f"CSVs: {csv_names}") + print(f"STDOUT (last 2000):\n{result.stdout[-2000:]}") + print(f"STDERR (last 2000):\n{result.stderr[-2000:]}") + + self.assertEqual( + result.returncode, 0, f"{name} run_config failed (csvs={csv_names})" + ) + + # Parse benchmark result lines from the table output. + # Status column shows: OK, ERROR, MISMATCH + lines = output.split("\n") + error_shapes = [] + mismatch_shapes = [] + for line in lines: + stripped = line.strip() + if "| " not in stripped: + continue + if stripped.endswith("ERROR"): + error_shapes.append(stripped) + elif stripped.endswith("MISMATCH"): + mismatch_shapes.append(stripped) + + failures = [] + if error_shapes: + failures.append( + f"Errors ({len(error_shapes)} shapes):\n" + "\n".join(error_shapes[:20]) + ) + if mismatch_shapes: + failures.append( + f"Accuracy mismatches ({len(mismatch_shapes)} shapes):\n" + + "\n".join(mismatch_shapes[:20]) + ) + + self.assertEqual( + len(failures), + 0, + f"{name} run_config issues (csvs={csv_names}):\n" + "\n".join(failures), + ) + + def test_a8w8(self): + self._test_family("a8w8") + + def test_a8w8_bpreshuffle(self): + self._test_family("a8w8_bpreshuffle") + + def test_a8w8_blockscale(self): + self._test_family("a8w8_blockscale") + + def test_a8w8_blockscale_bpreshuffle(self): + self._test_family("a8w8_blockscale_bpreshuffle") + + def test_a4w4_blockscale(self): + self._test_family("a4w4_blockscale") + + def test_batched_a8w8(self): + self._test_family("batched_a8w8") + + def test_batched_bf16(self): + self._test_family("batched_bf16") + + def test_fmoe(self): + self._test_family("fmoe") + + +@unittest.skipUnless(_gpu_available(), "No GPU available") +@unittest.skipUnless( + TUNE_TEST_FAMILY and TUNE_TEST_CONFIG, + "Set TUNE_TEST_FAMILY and TUNE_TEST_CONFIG to run", +) +class TestRunConfigCustom(unittest.TestCase): + """Run --run_config with user-specified family and config CSV. + + Usage: + TUNE_TEST_FAMILY=a8w8_blockscale \ + TUNE_TEST_CONFIG="aiter/configs/a8w8_blockscale_tuned_gemm.csv" \ + python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v + + # Multiple configs (merged): + TUNE_TEST_FAMILY=a8w8_blockscale \ + TUNE_TEST_CONFIG="aiter/configs/a8w8_blockscale_tuned_gemm.csv:aiter/configs/model_configs/a8w8_blockscale_tuned_gemm_ds_v3.csv" \ + python3 -m unittest op_tests.tuning_tests.test_run_config.TestRunConfigCustom -v + """ + + def test_custom(self): + family = TUNE_TEST_FAMILY + config = TUNE_TEST_CONFIG + self.assertIn( + family, + TUNER_FAMILIES, + f"Unknown family '{family}'. Available: {list(TUNER_FAMILIES.keys())}", + ) + cfg = TUNER_FAMILIES[family] + timeout = cfg.get("timeout", 600) + + # Resolve relative paths against AITER_ROOT + resolved = [] + for p in config.split(os.pathsep): + p = p.strip() + if not p: + continue + if not os.path.isabs(p): + p = os.path.join(AITER_ROOT, p) + self.assertTrue(os.path.exists(p), f"Config not found: {p}") + resolved.append(p) + merged = os.pathsep.join(resolved) + + print( + f"\nRunning {family} --run_config with: {[os.path.basename(p) for p in resolved]}" + ) + result = _run_config(cfg["script"], merged, timeout=timeout) + + output = result.stdout + result.stderr + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout[-3000:]}") + print(f"STDERR:\n{result.stderr[-3000:]}") + self.assertEqual(result.returncode, 0, f"{family} run_config failed") + + lines = output.split("\n") + error_shapes = [] + mismatch_shapes = [] + for line in lines: + stripped = line.strip() + if "| " not in stripped: + continue + if stripped.endswith("ERROR"): + error_shapes.append(stripped) + elif stripped.endswith("MISMATCH"): + mismatch_shapes.append(stripped) + + failures = [] + if error_shapes: + failures.append( + f"Errors ({len(error_shapes)} shapes):\n" + "\n".join(error_shapes[:20]) + ) + if mismatch_shapes: + failures.append( + f"Accuracy mismatches ({len(mismatch_shapes)} shapes):\n" + + "\n".join(mismatch_shapes[:20]) + ) + + self.assertEqual( + len(failures), 0, f"{family} run_config issues:\n" + "\n".join(failures) + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/op_tests/tuning_tests/test_tune_pipeline.py b/op_tests/tuning_tests/test_tune_pipeline.py new file mode 100644 index 0000000000..0b8117a54c --- /dev/null +++ b/op_tests/tuning_tests/test_tune_pipeline.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +""" +Level 2: End-to-end tuning pipeline smoke tests (GPU required). + +Runs each tuner on small shapes, verifies CSV output, and tests +--shape_grouped with profile row count comparison. +""" + +import os +import sys +import csv +import tempfile +import subprocess +import unittest +import pandas as pd + +AITER_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) + + +def _gpu_available(): + try: + import torch + + return torch.cuda.is_available() and torch.cuda.device_count() > 0 + except ImportError: + return False + + +def _get_platform_dtypes(): + """Return (fp8_str, quant_type_str) based on GPU arch.""" + try: + from aiter.jit.utils.chip_info import get_gfx + + gfx = get_gfx() + except Exception: + gfx = "gfx942" + if gfx in ("gfx950", "gfx1250"): + return "torch.float8_e4m3fn", "QuantType.per_1x128" + else: + return "torch.float8_e4m3fnuz", "QuantType.per_Token" + + +def _write_csv(path, header, rows): + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(header) + for row in rows: + writer.writerow(row) + + +def _run_tuner(script, untuned, tuned, extra_args=None, timeout=300): + cmd = [ + sys.executable, + os.path.join(AITER_ROOT, script), + "-i", + untuned, + "-o", + tuned, + "--mp", + "1", + "--warmup", + "2", + "--iters", + "5", + ] + if extra_args: + cmd.extend(extra_args) + env = os.environ.copy() + script_dir = os.path.dirname(os.path.join(AITER_ROOT, script)) + env["PYTHONPATH"] = script_dir + ":" + env.get("PYTHONPATH", "") + try: + return subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=AITER_ROOT, + env=env, + ) + except subprocess.TimeoutExpired as e: + raise AssertionError( + f"Tuner timed out after {timeout}s (likely GPU hang or infinite loop)\n" + f" cmd: {' '.join(cmd)}\n" + f" stdout (last 500): {(e.stdout or b'')[-500:]}\n" + f" stderr (last 500): {(e.stderr or b'')[-500:]}" + ) from None + + +@unittest.skipUnless(_gpu_available(), "No GPU available") +class TestTunePipeline(unittest.TestCase): + """Smoke test: run each tuner on 1 small shape, verify CSV output.""" + + @classmethod + def setUpClass(cls): + fp8, qtype = _get_platform_dtypes() + cls.TUNERS = { + "a8w8": { + "script": "csrc/ck_gemm_a8w8/gemm_a8w8_tune.py", + "header": ["M", "N", "K", "q_dtype_w"], + "shapes": [ + (1, 1024, 512, "torch.int8"), + (1, 1024, 512, fp8), + ], + "keys": ["cu_num", "M", "N", "K", "q_dtype_w"], + }, + "a8w8_blockscale": { + "script": "csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py", + "header": ["M", "N", "K"], + "shapes": [(1, 1024, 512)], + "keys": ["cu_num", "M", "N", "K"], + }, + "a8w8_bpreshuffle": { + "script": "csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py", + "header": ["M", "N", "K", "q_dtype_w"], + "shapes": [ + (1, 1024, 512, "torch.int8"), + (1, 1024, 512, fp8), + ], + "keys": ["cu_num", "M", "N", "K", "q_dtype_w"], + }, + "batched_a8w8": { + "script": "csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py", + "header": ["B", "M", "N", "K"], + "shapes": [(2, 1, 512, 256)], + "keys": ["cu_num", "B", "M", "N", "K"], + }, + "batched_bf16": { + "script": "csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py", + "header": ["B", "M", "N", "K"], + "shapes": [(2, 1, 512, 256)], + "keys": ["cu_num", "B", "M", "N", "K"], + }, + "fmoe": { + "script": "csrc/ck_gemm_moe_2stages_codegen/gemm_moe_tune.py", + "header": [ + "token", + "model_dim", + "inter_dim", + "expert", + "topk", + "act_type", + "dtype", + "q_dtype_a", + "q_dtype_w", + "q_type", + "use_g1u1", + "doweight_stage1", + ], + "shapes": [ + # bf16 (no quant) + ( + 512, + 6144, + 4096, + 8, + 2, + "ActivationType.Silu", + "torch.bfloat16", + "torch.bfloat16", + "torch.bfloat16", + "QuantType.No", + 1, + 0, + ), + # fp8 per-token (platform-aware) + ( + 16, + 7168, + 256, + 256, + 8, + "ActivationType.Silu", + "torch.bfloat16", + fp8, + fp8, + qtype, + 1, + 0, + ), + # int8 per-tensor + ( + 512, + 6144, + 4096, + 8, + 2, + "ActivationType.Silu", + "torch.bfloat16", + "torch.int8", + "torch.int8", + "QuantType.per_Tensor", + 1, + 0, + ), + # Gelu activation + doweight_stage1 + ( + 4, + 2304, + 1536, + 8, + 2, + "ActivationType.Gelu", + "torch.bfloat16", + fp8, + fp8, + qtype, + 1, + 1, + ), + ], + "keys": [ + "cu_num", + "token", + "model_dim", + "inter_dim", + "expert", + "topk", + "act_type", + "dtype", + "q_dtype_a", + "q_dtype_w", + "q_type", + "use_g1u1", + "doweight_stage1", + ], + "timeout": 600, + }, + "gradlib_bf16": { + "script": "gradlib/gradlib/gemm_tuner.py", + "header": [ + "M", + "N", + "K", + "bias", + "dtype", + "outdtype", + "scaleAB", + "bpreshuffle", + ], + "shapes": [ + # decode (M=1): hipBLASLt/ASM typically wins + ( + 1, + 1024, + 512, + "False", + "torch.bfloat16", + "torch.float32", + "False", + "False", + ), + # prefill (large M): FlyDSL has a chance to win + ( + 512, + 5120, + 1280, + "False", + "torch.bfloat16", + "torch.bfloat16", + "False", + "False", + ), + ], + "keys": ["M", "N", "K"], + "timeout": 600, + }, + } + + def _run_one(self, name): + cfg = self.TUNERS[name] + timeout = cfg.get("timeout", 300) + with tempfile.TemporaryDirectory() as tmp: + untuned = os.path.join(tmp, "untuned.csv") + tuned = os.path.join(tmp, "tuned.csv") + _write_csv(untuned, cfg["header"], cfg["shapes"]) + + result = _run_tuner(cfg["script"], untuned, tuned, timeout=timeout) + if result.returncode != 0: + print(f"\n=== {name} STDOUT ===\n{result.stdout[-2000:]}") + print(f"\n=== {name} STDERR ===\n{result.stderr[-2000:]}") + self.assertEqual( + result.returncode, + 0, + f"{name} tuner exited with code {result.returncode}", + ) + self.assertTrue(os.path.exists(tuned), f"{name}: tuned CSV not created") + + df = pd.read_csv(tuned) + df.columns = df.columns.str.strip() + self.assertGreaterEqual( + len(df), + len(cfg["shapes"]), + f"{name}: expected >= {len(cfg['shapes'])} rows", + ) + for key in cfg["keys"]: + self.assertIn(key, df.columns, f"{name}: missing column {key}") + for _, row in df.iterrows(): + us = float(row.get("us", -1)) + self.assertNotEqual(us, 0, f"{name}: us == 0 for {dict(row)}") + + def test_a8w8(self): + self._run_one("a8w8") + + def test_a8w8_blockscale(self): + self._run_one("a8w8_blockscale") + + def test_a8w8_bpreshuffle(self): + self._run_one("a8w8_bpreshuffle") + + def test_batched_a8w8(self): + self._run_one("batched_a8w8") + + def test_batched_bf16(self): + self._run_one("batched_bf16") + + def test_fmoe(self): + self._run_one("fmoe") + + def test_gradlib_bf16(self): + """gradlib spawns an internal subprocess; use /tmp paths that persist.""" + cfg = self.TUNERS["gradlib_bf16"] + timeout = cfg.get("timeout", 300) + untuned = "/tmp/_test_gradlib_untuned.csv" + tuned = "/tmp/_test_gradlib_tuned.csv" + try: + _write_csv(untuned, cfg["header"], cfg["shapes"]) + if os.path.exists(tuned): + os.remove(tuned) + result = _run_tuner(cfg["script"], untuned, tuned, timeout=timeout) + if result.returncode != 0: + print(f"\n=== gradlib STDOUT ===\n{result.stdout[-2000:]}") + print(f"\n=== gradlib STDERR ===\n{result.stderr[-2000:]}") + self.assertEqual(result.returncode, 0, "gradlib tuner failed") + self.assertTrue(os.path.exists(tuned), "gradlib: tuned CSV not created") + finally: + for f in (untuned, tuned): + if os.path.exists(f): + os.remove(f) + + +@unittest.skipUnless(_gpu_available(), "No GPU available") +class TestShapeGrouped(unittest.TestCase): + """Test --shape_grouped: same profile count, correct tuned row count.""" + + CONFIGS = { + "a8w8_blockscale": { + "script": "csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py", + "header": ["M", "N", "K"], + "shapes": [(16, 1536, 7168), (16, 576, 7168), (16, 7168, 256)], + "keys": ["cu_num", "M", "N", "K"], + }, + "batched_bf16": { + "script": "csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py", + "header": ["B", "M", "N", "K"], + "shapes": [(2, 1, 512, 256), (4, 16, 1024, 512)], + "keys": ["cu_num", "B", "M", "N", "K"], + }, + } + + def _run_grouped_vs_ref(self, name): + cfg = self.CONFIGS[name] + num_shapes = len(cfg["shapes"]) + with tempfile.TemporaryDirectory() as tmp: + untuned = os.path.join(tmp, "untuned.csv") + tuned_ref = os.path.join(tmp, "tuned_ref.csv") + profile_ref = os.path.join(tmp, "profile_ref.csv") + tuned = os.path.join(tmp, "tuned.csv") + profile = os.path.join(tmp, "profile.csv") + _write_csv(untuned, cfg["header"], cfg["shapes"]) + + r_ref = _run_tuner( + cfg["script"], untuned, tuned_ref, extra_args=["-o2", profile_ref] + ) + self.assertEqual( + r_ref.returncode, 0, f"{name} ref tuner failed:\n{r_ref.stderr[-1000:]}" + ) + + r = _run_tuner( + cfg["script"], + untuned, + tuned, + extra_args=["--shape_grouped", "-o2", profile], + ) + if r.returncode != 0: + print(f"\n=== {name} grouped STDERR ===\n{r.stderr[-2000:]}") + self.assertEqual(r.returncode, 0, f"{name} grouped tuner failed") + + df = pd.read_csv(tuned) + df.columns = df.columns.str.strip() + self.assertEqual( + len(df), + num_shapes, + f"{name}: expected {num_shapes} tuned rows, got {len(df)}", + ) + + if os.path.exists(profile) and os.path.exists(profile_ref): + prof = pd.read_csv(profile) + prof_ref = pd.read_csv(profile_ref) + self.assertEqual( + len(prof), + len(prof_ref), + f"{name}: profile rows grouped={len(prof)} vs ref={len(prof_ref)}", + ) + + def test_a8w8_blockscale(self): + self._run_grouped_vs_ref("a8w8_blockscale") + + def test_batched_bf16(self): + self._run_grouped_vs_ref("batched_bf16") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/op_tests/tuning_tests/test_tuner_infra.py b/op_tests/tuning_tests/test_tuner_infra.py new file mode 100644 index 0000000000..f9d2ce15d3 --- /dev/null +++ b/op_tests/tuning_tests/test_tuner_infra.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +""" +Level 1: Unit tests for base_tuner infrastructure (no GPU required). + +Covers: CSV I/O, merge/dedup, calculate, post_process (topk selection). +""" + +import os +import tempfile +import unittest +import argparse +import pandas as pd + + +class _StubTuner: + """Lazy-init helper — avoids importing aiter at module level.""" + + _cls = None + + @classmethod + def get(cls): + if cls._cls is None: + from aiter.utility.base_tuner import GemmCommonTuner + + class Stub(GemmCommonTuner): + def _setup_specific_arguments(self): + pass + + def tune(self, *a, **kw): + pass + + def getKernelName(self, kid): + return f"k{kid}" + + cls._cls = Stub + return cls._cls("test") + + +class TestReadCSV(unittest.TestCase): + + def test_strips_whitespace(self): + from aiter.utility.base_tuner import _read_csv + + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write(" M , N , K \n 1 , 2 , 3 \n 4 , 5 , 6 \n") + path = f.name + try: + df = _read_csv(path) + self.assertEqual(list(df.columns), ["M", "N", "K"]) + self.assertEqual(len(df), 2) + finally: + os.unlink(path) + + def test_drops_unnamed_columns(self): + from aiter.utility.base_tuner import _read_csv + + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write("M,N,K,Unnamed: 0\n1,2,3,\n") + path = f.name + try: + df = _read_csv(path) + self.assertNotIn("Unnamed: 0", df.columns) + finally: + os.unlink(path) + + def test_drops_all_na_rows(self): + from aiter.utility.base_tuner import _read_csv + + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write("M,N,K\n1,2,3\n,,\n4,5,6\n") + path = f.name + try: + df = _read_csv(path) + self.assertEqual(len(df), 2) + finally: + os.unlink(path) + + +class TestUpdateTunedf(unittest.TestCase): + + def test_merges_existing_key(self): + tuner = _StubTuner.get() + old = pd.DataFrame( + { + "cu_num": [304], + "M": [1], + "N": [1024], + "K": [512], + "kernelId": [0], + "splitK": [0], + "us": [100.0], + "kernelName": ["old"], + "tflops": [1.0], + "bw": [1.0], + "errRatio": [0.01], + } + ) + new = pd.DataFrame( + { + "cu_num": [304], + "M": [1], + "N": [1024], + "K": [512], + "kernelId": [1], + "splitK": [0], + "us": [50.0], + "kernelName": ["new"], + "tflops": [2.0], + "bw": [2.0], + "errRatio": [0.005], + } + ) + merged = tuner.update_tunedf(old, new) + self.assertEqual(len(merged), 1) + self.assertEqual(float(merged.iloc[0]["us"]), 50.0) + + def test_appends_new_key(self): + tuner = _StubTuner.get() + old = pd.DataFrame( + { + "cu_num": [304], + "M": [1], + "N": [1024], + "K": [512], + "kernelId": [0], + "splitK": [0], + "us": [100.0], + "kernelName": ["k0"], + "tflops": [1.0], + "bw": [1.0], + "errRatio": [0.01], + } + ) + new = pd.DataFrame( + { + "cu_num": [304], + "M": [32], + "N": [2048], + "K": [1024], + "kernelId": [2], + "splitK": [0], + "us": [200.0], + "kernelName": ["k2"], + "tflops": [3.0], + "bw": [3.0], + "errRatio": [0.02], + } + ) + merged = tuner.update_tunedf(old, new) + self.assertEqual(len(merged), 2) + + +class TestSortResults(unittest.TestCase): + + def test_deduplicates(self): + tuner = _StubTuner.get() + df = pd.DataFrame( + { + "cu_num": [304, 304], + "M": [1, 1], + "N": [1024, 1024], + "K": [512, 512], + "kernelId": [0, 1], + "splitK": [0, 0], + "us": [100.0, 50.0], + "kernelName": ["k0", "k1"], + "tflops": [1.0, 2.0], + "bw": [1.0, 2.0], + "errRatio": [0.01, 0.005], + } + ) + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + path = f.name + try: + tuner.sortResults(path, True, tuner.sort_keys) + result = pd.read_csv(path) + self.assertEqual(len(result), 1) + finally: + os.unlink(path) + + +class TestCalculate(unittest.TestCase): + + def test_tflops_bw_positive(self): + tuner = _StubTuner.get() + keys = (304, 128, 4096, 1024) + result = ((keys,), 100.0, 0.01) + tflops, bw = tuner.calculate(result) + self.assertGreater(tflops, 0) + self.assertGreater(bw, 0) + + def test_returns_zero_on_invalid_time(self): + tuner = _StubTuner.get() + keys = (304, 128, 4096, 1024) + result = ((keys,), -1, 1.0) + tflops, bw = tuner.calculate(result) + self.assertEqual(tflops, 0) + self.assertEqual(bw, 0) + + +class TestPostProcess(unittest.TestCase): + """Tests for post_process — especially the topk selection logic.""" + + def _make_args(self, err_ratio=0.05, profile_file=""): + args = argparse.Namespace() + args.errRatio = err_ratio + args.profile_file = profile_file + args.verbose = False + return args + + def _make_result(self, shape_key, kernel_id, split_k, us, err=0.0): + info = (shape_key, kernel_id, split_k, f"kernel_{kernel_id}") + return (info, us, err) + + def test_picks_fastest_per_shape(self): + """Basic: 2 shapes, 3 kernels each, picks fastest.""" + tuner = _StubTuner.get() + args = self._make_args() + rets = [ + self._make_result((304, 1, 1024, 512), 0, 0, 10.0), + self._make_result((304, 1, 1024, 512), 1, 0, 5.0), + self._make_result((304, 1, 1024, 512), 2, 0, 8.0), + self._make_result((304, 32, 2048, 1024), 0, 0, 20.0), + self._make_result((304, 32, 2048, 1024), 1, 0, 12.0), + self._make_result((304, 32, 2048, 1024), 2, 0, 15.0), + ] + resultdf = tuner.post_process(rets, args, topk=1) + self.assertEqual(len(resultdf), 2) + times = sorted(resultdf["us"].tolist()) + self.assertEqual(times, [5.0, 12.0]) + + def test_filters_by_err_ratio(self): + """Kernels exceeding errRatio should be skipped.""" + tuner = _StubTuner.get() + args = self._make_args(err_ratio=0.05) + rets = [ + self._make_result((304, 1, 1024, 512), 0, 0, 5.0, err=0.1), + self._make_result((304, 1, 1024, 512), 1, 0, 10.0, err=0.01), + self._make_result((304, 1, 1024, 512), 2, 0, 8.0, err=0.02), + ] + resultdf = tuner.post_process(rets, args, topk=1) + self.assertEqual(len(resultdf), 1) + self.assertEqual(float(resultdf.iloc[0]["us"]), 8.0) + + def test_filters_invalid_and_inf_times(self): + """us=-1 (error) and us=inf (timeout) should be excluded.""" + tuner = _StubTuner.get() + args = self._make_args() + rets = [ + self._make_result((304, 1, 1024, 512), 0, 0, -1, err=0.0), + self._make_result((304, 1, 1024, 512), 1, 0, float("inf"), err=0.0), + self._make_result((304, 1, 1024, 512), 2, 0, 7.0, err=0.01), + ] + resultdf = tuner.post_process(rets, args, topk=1) + self.assertEqual(len(resultdf), 1) + self.assertEqual(float(resultdf.iloc[0]["us"]), 7.0) + + def test_topk_not_leak_across_shapes(self): + """BUG REGRESSION: topk must not leak between shapes. + + If shape A has 0 valid candidates (all fail errRatio), topk should NOT + be permanently set to 0, causing shape B to also return 0 results. + """ + tuner = _StubTuner.get() + args = self._make_args(err_ratio=0.05) + rets = [ + # Shape A: all kernels fail errRatio → 0 valid candidates + self._make_result((304, 1, 1024, 512), 0, 0, 5.0, err=0.9), + self._make_result((304, 1, 1024, 512), 1, 0, 3.0, err=0.8), + # Shape B: has valid candidates → should still get results + self._make_result((304, 32, 2048, 1024), 0, 0, 10.0, err=0.01), + self._make_result((304, 32, 2048, 1024), 1, 0, 8.0, err=0.02), + ] + resultdf = tuner.post_process(rets, args, topk=1) + shape_b_rows = resultdf[resultdf["M"] == 32] + self.assertGreaterEqual( + len(shape_b_rows), 1, "Shape B should have results even if Shape A has none" + ) + self.assertEqual(float(shape_b_rows.iloc[0]["us"]), 8.0) + + def test_topk_not_shrink_across_shapes(self): + """topk should not shrink when one shape has fewer valid candidates than topk.""" + tuner = _StubTuner.get() + args = self._make_args(err_ratio=0.05) + rets = [ + # Shape A: only 1 valid candidate (topk=2 requested but only 1 available) + self._make_result((304, 1, 1024, 512), 0, 0, 5.0, err=0.01), + self._make_result((304, 1, 1024, 512), 1, 0, 3.0, err=0.9), + # Shape B: 3 valid candidates → should get topk=2 + self._make_result((304, 32, 2048, 1024), 0, 0, 10.0, err=0.01), + self._make_result((304, 32, 2048, 1024), 1, 0, 8.0, err=0.02), + self._make_result((304, 32, 2048, 1024), 2, 0, 12.0, err=0.01), + ] + resultdf = tuner.post_process(rets, args, topk=2) + shape_b_rows = resultdf[resultdf["M"] == 32] + self.assertGreaterEqual( + len(shape_b_rows), + 2, + "Shape B should get topk=2 results even though Shape A only had 1", + ) + + def test_all_shapes_fail(self): + """When all kernels for all shapes fail, should still produce fallback entries.""" + tuner = _StubTuner.get() + args = self._make_args(err_ratio=0.05) + rets = [ + self._make_result((304, 1, 1024, 512), 0, 0, 5.0, err=0.9), + self._make_result((304, 32, 2048, 1024), 0, 0, 10.0, err=0.8), + ] + resultdf = tuner.post_process(rets, args, topk=1) + self.assertEqual(len(resultdf), 2, "Should have fallback entry for each shape") + + def test_single_shape_single_kernel(self): + """Minimal case: 1 shape, 1 kernel → should work.""" + tuner = _StubTuner.get() + args = self._make_args() + rets = [ + self._make_result((304, 1, 1024, 512), 0, 0, 5.0, err=0.01), + ] + resultdf = tuner.post_process(rets, args, topk=1) + self.assertEqual(len(resultdf), 1) + self.assertEqual(float(resultdf.iloc[0]["us"]), 5.0) + + +if __name__ == "__main__": + unittest.main(verbosity=2)