diff --git a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py index d2abf9210d..fc88cf47f3 100644 --- a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py @@ -25,7 +25,7 @@ from flydsl._mlir import ir from flydsl.expr.typing import T -from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl +from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl, const_expr from flydsl._mlir.dialects import llvm, scf, memref from flydsl._mlir.dialects.arith import CmpIPredicate @@ -448,7 +448,9 @@ def moe_gemm1( arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) ) - if use_async_copy and a_elem_vec_pack > 1: + _eff_lds_stride = 0 + _eff_tile_k_bytes = 0 + if const_expr(use_async_copy and a_elem_vec_pack > 1): _eff_lds_stride = lds_stride // a_elem_vec_pack _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack else: @@ -464,11 +466,10 @@ def moe_gemm1( bx_persist = gpu.block_id("y") # persistent WG index by_n = by * arith.constant(tile_n, index=True) - if _is_splitk: + k_base_idx = arith.index(0) + if const_expr(_is_splitk): bz = gpu.block_id("z") # K-batch id k_base_idx = bz * arith.constant(_k_dim, index=True) - else: - k_base_idx = arith.index(0) k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) layout_tx_wave_lane = fx.make_layout((num_waves, 64), stride=(64, 1)) @@ -533,9 +534,9 @@ def moe_gemm1( numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 ) - if is_f16_a: - sx_rsrc = None - else: + sx_rsrc = 1 + sw_rsrc = 1 + if const_expr(not is_f16_a): # A scale: [sorted_size, model_dim/32] pre-scattered by caller c32 = arith.constant(32, index=True) kblk = k_in / c32 @@ -545,9 +546,7 @@ def moe_gemm1( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: - sw_rsrc = None - else: + if const_expr(not is_f16_b): c32 = arith.constant(32, index=True) kblk_w = k_in / c32 mn_w = arith.constant(experts * (2 * inter_dim), index=True) @@ -579,7 +578,7 @@ def moe_gemm1( # Sorted-scale buffer resource for fused mxfp4 quantization _sorted_scale_cols = inter_dim // 32 _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) - if _need_sort: + if const_expr(_need_sort): sorted_scale_rsrc = buffer_ops.create_buffer_resource( arg_out_scale_sorted, max_size=False ) @@ -713,9 +712,8 @@ def load_x_tile(base_k): # N-tile precompute for gate AND up weights gate_n_intra_list = [] gate_n_blk_list = [] - if not gate_only: - up_n_intra_list = [] - up_n_blk_list = [] + up_n_intra_list = [] + up_n_blk_list = [] c_n0_static = experts * (2 * inter_dim) // 16 layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) inter_idx = arith.constant(inter_dim, index=True) @@ -732,7 +730,7 @@ def load_x_tile(base_k): gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) gate_n_blk_list.append(layout_get(gate_coord, 0)) gate_n_intra_list.append(layout_get(gate_coord, 1)) - if not gate_only: + if const_expr(not gate_only): # Up: rows [expert_off + inter_dim, expert_off + 2*inter_dim) up_row_w = gate_row_w + inter_idx up_coord = idx2crd(up_row_w, layout_n_blk_intra) @@ -790,14 +788,14 @@ def load_b_tile(base_k): ) g_packs0.append(gb0) g_packs1.append(gb1) - if not gate_only: + if const_expr(not gate_only): ub0, ub1 = load_b_packs_k64( base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni] ) u_packs0.append(ub0) u_packs1.append(ub1) gate_b_tile.append((g_packs0, g_packs1)) - if not gate_only: + if const_expr(not gate_only): up_b_tile.append((u_packs0, u_packs1)) return gate_b_tile, up_b_tile @@ -810,8 +808,7 @@ def load_b_tile(base_k): ) _gate_scale_bases = [] - if not gate_only: - _up_scale_bases = [] + _up_scale_bases = [] for _ni in range_constexpr(num_acc_n_packed): _col_base = ( by_n @@ -824,7 +821,7 @@ def load_b_tile(base_k): _gate_scale_bases.append( _gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem ) - if not gate_only: + if const_expr(not gate_only): _up_mni = ( expert_off_idx + inter_idx + _col_base ) // arith.constant(32, index=True) @@ -843,13 +840,21 @@ def load_b_tile(base_k): _c2_idx = arith.constant(2, index=True) _scale_mask_lo = arith.constant(0xFF, type=T.i32) - if pack_M < scale_mn_pack: + _m_half_idx = arith.constant(0, type=T.i32) + _m_half_i32 = arith.constant(0, type=T.i32) + _scale_shift = arith.constant(0, type=T.i32) + _scale_shift_hi = arith.constant(0, type=T.i32) + _n_half_idx = arith.constant(0, type=T.i32) + _n_half_i32 = arith.constant(0, type=T.i32) + _bscale_shift = arith.constant(0, type=T.i32) + _bscale_shift_hi = arith.constant(0, type=T.i32) + if const_expr(pack_M < scale_mn_pack): _m_half_idx = (bx_m // _c16_idx) % _c2_idx _m_half_i32 = arith.index_cast(T.i32, _m_half_idx) _scale_shift = _m_half_i32 * arith.constant(8, type=T.i32) _scale_shift_hi = _scale_shift + arith.constant(16, type=T.i32) - if pack_N < scale_mn_pack: + if const_expr(pack_N < scale_mn_pack): _n_half_idx = (n_tile_base // _c16_idx) % _c2_idx _n_half_i32 = arith.index_cast(T.i32, _n_half_idx) _bscale_shift = _n_half_i32 * arith.constant(8, type=T.i32) @@ -857,7 +862,7 @@ def load_b_tile(base_k): def _rearrange_a_scale(raw_i32): """Rearrange scale bytes for pack_M=1: extract m_half's k0,k1 bytes.""" - if pack_M >= scale_mn_pack: + if const_expr(pack_M >= scale_mn_pack): return raw_i32 b_k0 = arith.andi( arith.shrui(raw_i32, _scale_shift), _scale_mask_lo @@ -871,7 +876,7 @@ def _rearrange_a_scale(raw_i32): def _rearrange_b_scale(raw_i32): """Rearrange scale bytes for pack_N=1: extract n_half's k0,k1 bytes.""" - if pack_N >= scale_mn_pack: + if const_expr(pack_N >= scale_mn_pack): return raw_i32 b_k0 = arith.andi( arith.shrui(raw_i32, _bscale_shift), _scale_mask_lo @@ -913,7 +918,7 @@ def prefetch_ab_scale_tile(base_k): gate_b_scale.append( vector.from_elements(T.vec(1, T.i32), [gs]) ) - if not gate_only: + if const_expr(not gate_only): us = buffer_ops.buffer_load( sw_rsrc, _up_scale_bases[ni] + k_off, @@ -933,7 +938,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, @@ -949,7 +954,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): elem_bytes=elem_bytes, ) - if use_async_copy: + if const_expr(use_async_copy): _dma_bytes = 16 _wave_size = 64 _eff_bytes_per_buffer = ( @@ -977,7 +982,7 @@ def dma_x_tile_to_lds(base_k, lds_buffer): global_byte_idx = row_k_dw * c4_idx + col_local_sw global_offset = arith.index_cast(T.i32, global_byte_idx) - if i == 0: + if const_expr(i == 0): lds_addr = memref.extract_aligned_pointer_as_index( lds_buffer ) + wave_id * arith.constant( @@ -1036,7 +1041,7 @@ def prefetch_full_a_from_lds(lds_buffer): mi_val = arith.constant(mi_idx * 16, index=True) curr_row = row_a_lds + mi_val a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) - if is_f8_a: + if const_expr(is_f8_a): a2, a3 = lds_load_packs_k64( curr_row, col_base + 64, lds_buffer ) @@ -1063,7 +1068,7 @@ def compute_tile( up_list = list(acc_up_in) if not gate_only else None mfma_res_ty = vec4_f32 epilogue_pf = None - if prefetch_epilogue and doweight_stage1: + if const_expr(prefetch_epilogue and doweight_stage1): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * arith.index(4) ii_idx_list_pf = [ @@ -1102,7 +1107,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): static_position=[0], dynamic_position=[], ) - if not gate_only: + if const_expr(not gate_only): up_bs_i32 = up_b_scale[ku128 * num_acc_n_packed + ni] up_bs_val = vector.extract( up_bs_i32, static_position=[0], dynamic_position=[] @@ -1110,7 +1115,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for ikxdl in range_constexpr(pack_K): k_idx = ku128 * pack_K + ikxdl gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] - if not gate_only: + if const_expr(not gate_only): up_bp0, up_bp1 = up_b_tile_in[k_idx] for inxdl in range_constexpr(pack_N): ni_idx = ni * pack_N + inxdl @@ -1119,7 +1124,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): gb128 = pack_i64x4_to_i32x8( gb0, gb1, c0_i64, c0_i64 ) - if not gate_only: + if const_expr(not gate_only): ub0 = up_bp0[ni_idx] ub1 = up_bp1[ni_idx] ub128 = pack_i64x4_to_i32x8( @@ -1137,7 +1142,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(pack_M): mi_idx = mi * pack_M + imxdl _a_reg_idx = k_idx * m_repeat + mi_idx - if is_f8_a: + if const_expr(is_f8_a): a0, a1, a2, a3 = a_tile_regs[_a_reg_idx] a128 = pack_i64x4_to_i32x8( a0, a1, a2, a3 @@ -1164,7 +1169,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): ], ) ) - if not gate_only: + if const_expr(not gate_only): up_list[acc_idx] = ( rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, @@ -1189,7 +1194,7 @@ def load_a_subtile(k_idx, mi_idx, lds_buffer): mi_val = arith.constant(mi_idx * 16, index=True) curr_row = row_a_lds + mi_val a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) - if is_f8_a: + if const_expr(is_f8_a): a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer) return (a0, a1, a2, a3) else: @@ -1229,7 +1234,7 @@ def _pack(x0, x1, x2, x3): mfma_res_ty = vec4_f32 gb128 = _pack(gate_b_single[0], gate_b_single[1], c0_i64, c0_i64) - if not gate_only: + if const_expr(not gate_only): ub128 = _pack(up_b_single[0], up_b_single[1], c0_i64, c0_i64) for mi_p in range_constexpr(m_repeat_packed): @@ -1238,7 +1243,7 @@ def _pack(x0, x1, x2, x3): mi_idx = mi_p * pack_M + imxdl a_reg = all_a_tiles[k_idx * m_repeat + mi_idx] - if is_f8_a: + if const_expr(is_f8_a): a128 = _pack(a_reg[0], a_reg[1], a_reg[2], a_reg[3]) else: a128 = _pack(a_reg[0], a_reg[1], c0_i64, c0_i64) @@ -1258,7 +1263,7 @@ def _pack(x0, x1, x2, x3): gate_bs_val, ], ) - if not gate_only: + if const_expr(not gate_only): up_list[acc_idx] = ( rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, @@ -1313,9 +1318,9 @@ def _interleaved_half( # DMA A to OTHER buffer (for next half), non-blocking _abs_k_dma = k_base_idx + arith.constant(next_k_dma_py, index=True) - if use_async_copy and next_k_dma_py < int(_k_dim): + if const_expr(use_async_copy and next_k_dma_py < int(_k_dim)): prefetch_x_to_lds(_abs_k_dma, lds_write) - if not use_async_copy: + if const_expr(not use_async_copy): _x_regs = load_x_tile(_abs_k_dma) # ---- Extract previous scale values ---- @@ -1333,7 +1338,7 @@ def _interleaved_half( static_position=[0], dynamic_position=[], ) - if not gate_only: + if const_expr(not gate_only): _prev_usv = vector.extract( prev_up_bs[0], static_position=[0], @@ -1347,7 +1352,7 @@ def _interleaved_half( for _p in range_constexpr(_pipe_n_phases): # Scale VMEM loads (phase 0 only) - if _pp_has_scale[_p]: + if const_expr(_pp_has_scale[_p]): _new_as_list = [] for _mi_p in range_constexpr(m_repeat_packed): _raw_as = buffer_ops.buffer_load( @@ -1366,7 +1371,7 @@ def _interleaved_half( cache_modifier=0, ) _new_gs = _rearrange_b_scale(_new_gs) - if not gate_only: + if const_expr(not gate_only): _new_us = buffer_ops.buffer_load( sw_rsrc, _up_scale_bases[0] + _k_off, @@ -1379,7 +1384,7 @@ def _interleaved_half( # B VMEM loads for _b_j in range_constexpr(len(_pp_b_loads[_p])): _b_type, _b_ku, _b_ni = _pp_b_loads[_p][_b_j] - if _b_type == "gate": + if const_expr(_b_type == "gate"): _b_gate_all[(_b_ku, _b_ni)] = load_b_packs_k64( _bk, _b_ku, @@ -1452,12 +1457,12 @@ def _interleaved_half( g = _b_gate_all[(ku, ni)] g_packs0.append(g[0]) g_packs1.append(g[1]) - if not gate_only: + if const_expr(not gate_only): u = _b_up_all[(ku, ni)] u_packs0.append(u[0]) u_packs1.append(u[1]) cur_gate_w.append((g_packs0, g_packs1)) - if not gate_only: + if const_expr(not gate_only): cur_up_w.append((u_packs0, u_packs1)) cur_a_scale = [] @@ -1469,12 +1474,12 @@ def _interleaved_half( ) ) cur_gate_bs = [vector.from_elements(T.vec(1, T.i32), [_new_gs])] - if not gate_only: + if const_expr(not gate_only): cur_up_bs = [vector.from_elements(T.vec(1, T.i32), [_new_us])] else: cur_up_bs = None - if not use_async_copy: + if const_expr(not use_async_copy): store_x_tile_to_lds(_x_regs, lds_write) return ( @@ -1492,7 +1497,7 @@ def _interleaved_half( rocdl.sched_barrier(0) k0 = k_base_idx - if use_async_copy: + if const_expr(use_async_copy): prefetch_x_to_lds(k0, lds_x_pong) else: x_regs0 = load_x_tile(k0) @@ -1519,7 +1524,7 @@ def _interleaved_half( _k1 = k_base_idx + arith.constant(tile_k, index=True) rocdl.sched_barrier(0) - if use_async_copy: + if const_expr(use_async_copy): prefetch_x_to_lds(_k1, lds_x_ping) else: _x_regs_prime = load_x_tile(_k1) @@ -1540,7 +1545,7 @@ def _interleaved_half( odd_k_tiles = (num_k_tiles_py % 2) == 1 tail_tiles = 1 if odd_k_tiles else 2 k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) - if k_main2_py < 0: + if const_expr(k_main2_py < 0): k_main2_py = 0 gate_w_pong = gate_w0 up_w_pong = up_w0 @@ -1593,7 +1598,7 @@ def _sched_hints_stage1_gate_up(): # rocdl.sched_dswr(1) # rocdl.sched_barrier(0) - if use_async_copy: + if const_expr(use_async_copy): a_vmem_load = max(1, tile_m // 32) mfma_group = a_vmem_load rocdl.sched_vmem(a_vmem_load) @@ -1603,7 +1608,7 @@ def _sched_hints_stage1_gate_up(): b_vmem_total = k_unroll * num_acc_n * 2 vmem_count = b_vmem_total + 2 + a_vmem_load - if tile_m == 16: + if const_expr(tile_m == 16): for i in range_constexpr(2): rocdl.sched_dsrd(1) rocdl.sched_mfma(1) @@ -1619,11 +1624,11 @@ def _sched_hints_stage1_gate_up(): rocdl.sched_vmem(1) rocdl.sched_mfma(mfma_group) - if tile_m == 32: + if const_expr(tile_m == 32): for i in range_constexpr(vmem_count - a_vmem_load * 4): rocdl.sched_vmem(1) rocdl.sched_mfma(mfma_group) - elif tile_m == 64: + elif const_expr(tile_m == 64): rocdl.sched_vmem(1) rocdl.sched_mfma(1) rocdl.sched_vmem(1) @@ -1637,7 +1642,7 @@ def _sched_hints_stage1_gate_up(): rocdl.sched_barrier(0) - if k_main2_py > 0: + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): next_k_load_1 = k_iv_py + tile_k next_k_load_2 = k_iv_py + tile_k * 2 @@ -1704,7 +1709,7 @@ def _sched_hints_stage1_gate_up(): # _barrier() # scf.YieldOp([]) - if odd_k_tiles: + if const_expr(odd_k_tiles): acc_gate, acc_up, epilogue_pf = compute_tile( acc_gate, acc_up, @@ -1719,7 +1724,8 @@ def _sched_hints_stage1_gate_up(): else: _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) k_tail1 = k_base_idx + _k_tail_rel - if use_async_copy: + x_regs_ping = [] + if const_expr(use_async_copy): prefetch_x_to_lds(k_tail1, lds_x_ping) else: x_regs_ping = load_x_tile(k_tail1) @@ -1739,7 +1745,7 @@ def _sched_hints_stage1_gate_up(): gate_bs_pong, up_bs_pong, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_x_tile_to_lds(x_regs_ping, lds_x_ping) rocdl.s_waitcnt(0) _barrier() @@ -1781,7 +1787,7 @@ def _silu_mul_vec4(gate_v4, up_v4): result_elems.append(g * sig * u) return vector.from_elements(vec4_f32, result_elems) - if not _is_splitk: + if const_expr(not _is_splitk): acc = [None] * (int(num_acc_n) * int(m_repeat)) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(num_acc_n): @@ -1792,7 +1798,7 @@ def _silu_mul_vec4(gate_v4, up_v4): # Output: out[(t*topk+s) * inter_dim + col] = silu(gate) * up # For split-K: skip silu, output gate/up separately with atomic add tw_pf = None - if epilogue_pf is not None: + if const_expr(epilogue_pf is not None): _, tw_pf, _ = epilogue_pf mask24_i32 = arith.constant(0xFFFFFF) @@ -1808,7 +1814,7 @@ def _silu_mul_vec4(gate_v4, up_v4): out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) - if lds_out is None: + if const_expr(lds_out is None): raise RuntimeError("CShuffle epilogue requires lds_out") _apply_weight = doweight_stage1 and not _is_splitk @@ -1823,10 +1829,11 @@ def write_row_to_lds( col_base_local, num_acc_n: int, lds_out, + acc_v, ): - if _apply_weight: + if const_expr(_apply_weight): tw_idx = (mi * 4) + ii - if tw_pf is not None: + if const_expr(tw_pf is not None): tw = tw_pf[tw_idx] else: tw = buffer_ops.buffer_load( @@ -1836,11 +1843,11 @@ def write_row_to_lds( col_local = col_base_local + (ni * 16) acc_idx = mi * num_acc_n + ni v = vector.extract( - acc[acc_idx], static_position=[ii], dynamic_position=[] + acc_v[acc_idx], static_position=[ii], dynamic_position=[] ) - if _apply_weight: + if const_expr(_apply_weight): v = v * tw - if _need_quant: + if const_expr(_need_quant): lds_idx = row_base_lds + col_local vec1_f32 = T.vec(1, f32) v1 = vector.from_elements(vec1_f32, [v]) @@ -1882,6 +1889,32 @@ def _idx_to_llvm_ptr(idx_val, addr_space=1): ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") return llvm.inttoptr(ptr_ty, i64_raw) + def _make_write_row_to_lds(acc_v): + def _write_row_to_lds_bound( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + return write_row_to_lds( + mi=mi, + ii=ii, + row_in_tile=row_in_tile, + row=row, + row_base_lds=row_base_lds, + col_base_local=col_base_local, + num_acc_n=num_acc_n, + lds_out=lds_out, + acc_v=acc_v, + ) + + return _write_row_to_lds_bound + _e_vec = _e_vec_s1 _e_vec_sk = 2 _cshuffle_nlane = min(32, tile_n // _e_vec) @@ -1930,7 +1963,7 @@ def _f32_to_e2m1(qx_f32): e2m1 = arith.minui(rounded, _c7_i32) return (s >> _c28_i32) | e2m1 - if _need_sort: + if const_expr(_need_sort): _n32_sort = _sorted_scale_cols_i32 * _c32_i32 # Mutable slot for split-K N-offset (gate=0, up=inter_dim) @@ -1938,7 +1971,7 @@ def _f32_to_e2m1(qx_f32): def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): fused, row_byte_base = row_ctx - if _need_quant and not _is_splitk: + if const_expr(_need_quant and not _is_splitk): frag_vals = [] for i in range_constexpr(_e_vec): frag_vals.append( @@ -1984,7 +2017,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): ) out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) _pack_bytes = _e_vec // 2 - if _pack_bytes == 1: + if const_expr(_pack_bytes == 1): store_val = arith.TruncIOp(T.i8, packed_i32) store_raw = ( store_val._value @@ -1994,7 +2027,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): llvm.StoreOp( store_raw, out_ptr_v, alignment=1, nontemporal=True ) - elif _pack_bytes == 2: + elif const_expr(_pack_bytes == 2): store_val = arith.TruncIOp(T.i16, packed_i32) store_raw = ( store_val._value @@ -2014,7 +2047,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): packed_raw, out_ptr_v, alignment=4, nontemporal=True ) - if _need_sort: + if const_expr(_need_sort): col_g0_i32 = arith.index_cast(T.i32, col_g0) is_scale_writer = arith.cmpi( CmpIPredicate.eq, col_g0_i32 & _c31_i32, _c0_i32 @@ -2045,7 +2078,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): offset_is_bytes=True, ) scf.YieldOp([]) - elif _is_splitk: + elif const_expr(_is_splitk): col_idx = col_g0 + arith.constant(_sk_n_offset[0], index=True) byte_off_col = col_idx * arith.constant( out_elem_bytes, index=True @@ -2082,7 +2115,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): else (ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()) ) - if gate_only: + if const_expr(gate_only): # gate_only: single pass, by_n covers full [0, 2*inter_dim) _eff_e_vec = _e_vec_sk acc = acc_gate @@ -2107,11 +2140,11 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, + write_row_to_lds=_make_write_row_to_lds(acc), precompute_row=precompute_row, store_pair=store_pair, ) - elif _is_splitk: + elif const_expr(_is_splitk): # Two-pass epilogue: gate then up, each with atomic add _eff_e_vec = _e_vec_sk @@ -2139,7 +2172,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, + write_row_to_lds=_make_write_row_to_lds(acc), precompute_row=precompute_row, store_pair=store_pair, ) @@ -2170,7 +2203,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, + write_row_to_lds=_make_write_row_to_lds(acc), precompute_row=precompute_row, store_pair=store_pair, ) @@ -2196,7 +2229,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): n_tile_base=n_tile_base, lds_out=lds_out, frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, + write_row_to_lds=_make_write_row_to_lds(acc), precompute_row=precompute_row, store_pair=store_pair, ) @@ -2265,7 +2298,7 @@ def launch_mixed_moe_gemm1( allocator_ping.finalize() inter_in = arith.index_cast(ir.IndexType.get(), i32_inter_in.ir_value()) - if gate_only: + if const_expr(gate_only): gx = inter_in / arith.constant(tile_n, index=True) else: gx = ( @@ -2683,7 +2716,7 @@ def check_c_k_valid_gate(base_k): out_nbytes_idx = ( tokens_in * n_in * arith.constant(out_elem_bytes, index=True) ) - if not bool(accumulate): + if const_expr(not bool(accumulate)): out_nbytes_idx = ( tokens_in * arith.index(topk) @@ -2712,10 +2745,10 @@ def check_c_k_valid_gate(base_k): num_valid_idx = arith.index_cast(ir.IndexType.get(), num_valid_i32) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16_a: - sx_rsrc = None - else: - if is_f4_a: + sx_rsrc = 1 + sw_rsrc = 1 + if const_expr(not is_f16_a): + if const_expr(is_f4_a): # A2 microscale: e8m0 in sorted layout [sorted_size, K/32]. # Caller must pre-scatter a2_scale via moe_mxfp4_sort. kblk = _div_pow2(k_in, 32) @@ -2732,9 +2765,7 @@ def check_c_k_valid_gate(base_k): arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: - sw_rsrc = None - else: + if const_expr(not is_f16_b): # Weight microscale buffer (packed i32 holding e8m0 bytes). # Use an exact descriptor size so hardware OOB checking works. kblk_w = _div_pow2(k_in, 32) # K/32 @@ -2783,7 +2814,7 @@ def check_c_k_valid_gate(base_k): _c0_p = arith.constant(0, index=True) _c1_p = arith.constant(1, index=True) - if _persistent: + if const_expr(_persistent): # Expert-phase scheduling: contiguous M-tile dispatch. # grid_y = cu_num, each CTA handles a contiguous chunk of M-tiles: # [bx_persist * tiles_per_block, ..., (bx_persist+1) * tiles_per_block - 1] @@ -2812,7 +2843,7 @@ def check_c_k_valid_gate(base_k): _for_ip.__enter__() _mi_p = _for_persist.induction_variable - if _persistent: + if const_expr(_persistent): _still_active = _for_persist.inner_iter_args[0] bx = bx_persist * _tiles_per_block + _mi_p else: @@ -2835,7 +2866,7 @@ def check_c_k_valid_gate(base_k): CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) ) - if _persistent: + if const_expr(_persistent): # Absolute B-base: no cross-iteration state needed. _expert_b_base = expert_idx * arith.constant( _expert_b_stride, index=True @@ -2862,7 +2893,7 @@ def check_c_k_valid_gate(base_k): # For tile_m < 32 (pack_M < _scale_pack_m): shift a_scale i32 so the # correct bytes land at the op_sel positions we use. - if pack_M < _scale_pack_m: + if const_expr(pack_M < _scale_pack_m): _m_off = _mod_pow2(_div_pow2(bx_m, 16), _scale_pack_m) _m_scale_shift_i32 = arith.index_cast( T.i32, _m_off * arith.constant(8, index=True) @@ -2878,18 +2909,18 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16_a: - if bytes_per_thread_x % 16 != 0: + if const_expr(is_f16_a): + if const_expr(bytes_per_thread_x % 16 != 0): raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" ) x_load_bytes = 16 else: - if bytes_per_thread_x % 16 == 0: + if const_expr(bytes_per_thread_x % 16 == 0): x_load_bytes = 16 - elif bytes_per_thread_x % 8 == 0: + elif const_expr(bytes_per_thread_x % 8 == 0): x_load_bytes = 8 - elif bytes_per_thread_x % 4 == 0: + elif const_expr(bytes_per_thread_x % 4 == 0): x_load_bytes = 4 else: raise ValueError( @@ -2939,7 +2970,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): idx_elem = ( idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) @@ -3001,9 +3032,9 @@ def load_x_tile(base_k): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): parts.append(vector.bitcast(vec4_i32, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): parts.append(vector.bitcast(vec2_i32, x_vec)) else: parts.append(vector.bitcast(vec1_i32, x_vec)) @@ -3031,7 +3062,7 @@ def load_x_tile(base_k): by_n = by * arith.constant(tile_n, index=True) - if pack_N < _scale_pack_n: + if const_expr(pack_N < _scale_pack_n): _global_n_base = expert_off_idx + by_n + n_tile_base _n_off = _mod_pow2(_div_pow2(_global_n_base, 16), _scale_pack_n) _n_scale_shift_i32 = arith.index_cast( @@ -3153,7 +3184,7 @@ def load_scale(arg_scale, rsrc, scale_info, ku, mni): return vector.from_elements(T.vec(1, T.i32), [s]) def _apply_k_shift(scale_vec, k_shift_bits): - if k_shift_bits > 0: + if const_expr(k_shift_bits > 0): val = vector.extract( scale_vec, static_position=[0], dynamic_position=[] ) @@ -3212,7 +3243,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, @@ -3227,7 +3258,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): lds_store_8b_xor16( arith, vector, @@ -3294,7 +3325,7 @@ def compute_tile( a1_prefetch=None, b_hi_loader=None, ): - if b_hi_loader is not None: + if const_expr(b_hi_loader is not None): b_tile_full = [None] * k_unroll for i in range_constexpr(_b_split_ku): b_tile_full[i] = b_tile_in[i] @@ -3305,8 +3336,8 @@ def compute_tile( epilogue_pf = None bias = None - if prefetch_epilogue: - if enable_bias: + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): bias = [] for ni in range_constexpr(num_acc_n): global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 @@ -3317,7 +3348,7 @@ def compute_tile( ) ) tw_pf = None - if doweight_stage2: + if const_expr(doweight_stage2): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * arith.index(4) ii_idx_list_pf = [ @@ -3355,7 +3386,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): _pack_K_shift = (pack_K - 1).bit_length() _pack_K_mask = pack_K - 1 - if b_hi_loader is not None: + if const_expr(b_hi_loader is not None): _b_hi = b_hi_loader() for _bhi_i in range_constexpr(len(_b_hi)): b_tile_full[_b_split_ku + _bhi_i] = _b_hi[_bhi_i] @@ -3373,7 +3404,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): a_scale_val = vector.extract( a_scale_i32, static_position=[0], dynamic_position=[] ) - if _m_scale_shift_i32 is not None: + if const_expr(_m_scale_shift_i32 is not None): a_scale_val = arith.shrui( a_scale_val, _m_scale_shift_i32 ) @@ -3384,7 +3415,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): static_position=[0], dynamic_position=[], ) - if _n_scale_shift_i32 is not None: + if const_expr(_n_scale_shift_i32 is not None): b_scale_val = arith.shrui( b_scale_val, _n_scale_shift_i32 ) @@ -3395,13 +3426,13 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): mi_val = arith.constant(mi_idx * 16, index=True) curr_row_a_lds = row_a_lds + mi_val - if ( + if const_expr( (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0) ): a0, a1 = a0_prefetch - elif ( + elif const_expr( (a1_prefetch is not None) and (k_idx == 1) and (mi_idx == 0) @@ -3412,7 +3443,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): curr_row_a_lds, col_base0, lds_base ) - if is_f8_a: + if const_expr(is_f8_a): col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64( curr_row_a_lds, col_base1, lds_base @@ -3473,25 +3504,25 @@ def hot_loop_scheduler(): rocdl.sched_dsrd(2) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) - if num_acc_n < 4: + if const_expr(num_acc_n < 4): rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) # DS-write hints near the end: match total A LDS-store micro-ops per thread. dswr_tail = num_x_loads - if dswr_tail > sche_iters: + if const_expr(dswr_tail > sche_iters): dswr_tail = sche_iters dswr_start = sche_iters - dswr_tail @@ -3500,13 +3531,13 @@ def hot_loop_scheduler(): rocdl.sched_mfma(mfma_group) rocdl.sched_dsrd(1) rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: + if const_expr(sche_i >= dswr_start - 1): rocdl.sched_dswr(1) rocdl.sched_barrier(0) def _k_shift_bits(k_py): - if pack_K >= _scale_pack_k: + if const_expr(pack_K >= _scale_pack_k): return 0 return ((k_py // 128) % _scale_pack_k) * _scale_pack_m * 8 @@ -3531,7 +3562,7 @@ def _k_base(k_py): # Prologue -- B-first. k0 = arith.index(0) - if _b_split_enabled: + if const_expr(_b_split_enabled): b_cur = load_b_tile_lo(k0) else: b_cur = load_b_tile(k0) @@ -3571,7 +3602,7 @@ def _k_base(k_py): odd_k_tiles = (num_k_tiles_py % 2) == 1 tail_tiles = 1 if odd_k_tiles else 2 k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) - if k_main2_py < 0: + if const_expr(k_main2_py < 0): k_main2_py = 0 c2_tile_k = arith.constant(tile_k * 2, index=True) @@ -3586,7 +3617,7 @@ def _make_b_hi_loader(base_k): """Create a b_hi_loader callable for a given base_k.""" return lambda _bk=base_k: load_b_tile_hi(_bk) - if k_main2_py > 0: + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): k_iv = arith.index(k_iv_py) next_k1 = k_iv + tile_k @@ -3671,7 +3702,7 @@ def _make_b_hi_loader(base_k): else None ) - if odd_k_tiles: + if const_expr(odd_k_tiles): # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). acc, epilogue_pf = compute_tile( acc, @@ -3749,7 +3780,7 @@ def _make_b_hi_loader(base_k): tw_pf = None bias_pf = None - if epilogue_pf is not None: + if const_expr(epilogue_pf is not None): _, tw_pf, bias_pf = epilogue_pf mask24_i32 = arith.constant(0xFFFFFF) @@ -3767,7 +3798,7 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): ) # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). - if lds_out is None: + if const_expr(lds_out is None): raise RuntimeError( "FLIR_MOE_STAGE2_CSHUFFLE=1 but lds_out is not allocated/aliased." ) @@ -3796,9 +3827,9 @@ def write_row_to_lds( num_acc_n: int, lds_out, ): - if doweight_stage2: + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii - if tw_pf is not None: + if const_expr(tw_pf is not None): tw = tw_pf[tw_idx] else: tw = buffer_ops.buffer_load( @@ -3811,12 +3842,12 @@ def write_row_to_lds( v = vector.extract( acc[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if const_expr(is_int8): v = arith.sitofp(f32, v) - if enable_bias: + if const_expr(enable_bias): v = v + bias_pf[ni] - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -3840,7 +3871,7 @@ def precompute_row(*, row_local, row): t_idx = arith.index_cast(ir.IndexType.get(), t) s_idx = arith.index_cast(ir.IndexType.get(), s) ts_idx = t_idx * arith.constant(topk, index=True) + s_idx - if accumulate: + if const_expr(accumulate): row_byte_base = out_base_idx + t_idx * arith.constant( model_dim * out_elem_bytes, index=True ) @@ -3860,7 +3891,7 @@ def _idx_to_llvm_ptr(idx_val, addr_space=1): def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): fused, row_byte_base = row_ctx - if not bool(accumulate): + if const_expr(not bool(accumulate)): # ---- 64-bit global store path (avoids i32 offset overflow) ---- col_idx = col_g0 byte_off_col = col_idx * arith.constant( @@ -3922,7 +3953,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): _all_valid = arith.andi(blk_valid, arith.andi(exp_valid, tile_has_tokens)) - if _persistent: + if const_expr(_persistent): # Short-circuit: contiguous tiles are monotonically increasing, # so once bx_m >= num_valid_ids all remaining tiles are invalid. _cur_active = arith.andi(_still_active, blk_valid) @@ -3992,6 +4023,7 @@ def launch_mixed_moe_gemm2( n_in = arith.index_cast(T.index, i32_n_in) gx = n_in / arith.constant(tile_n, index=True) + gy = arith.constant(0, index=True) if _persistent: gy = arith.constant(_cu_num, index=True) else: diff --git a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py index 0cb4539953..0766ec00a8 100644 --- a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py +++ b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py @@ -27,7 +27,7 @@ import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl.expr import arith, vector, range_constexpr +from flydsl.expr import arith, vector, range_constexpr, const_expr from flydsl.expr.typing import T, Int32 from flydsl.expr.arith import ArithValue, CmpIPredicate from flydsl.compiler.kernel_function import CompilationContext @@ -203,7 +203,7 @@ def _f32_to_e2m1(qx_f32): vec_bf16_ty = T.vec(VEC, T.bf16) vec_f32_ty = T.vec(VEC, f32) - if vec_dw == 1: + if const_expr(vec_dw == 1): vec1_i32_ty = T.vec(1, i32) gate_vec = vector.from_elements(vec1_i32_ty, [gate_raw]) up_vec = vector.from_elements(vec1_i32_ty, [up_raw]) @@ -267,12 +267,12 @@ def _f32_to_e2m1(qx_f32): ) _pack_bytes = VEC // 2 - if _pack_bytes == 1: + if const_expr(_pack_bytes == 1): store_val = arith.TruncIOp(T.i8, packed_i32) buffer_ops.buffer_store( store_val, out_rsrc, fp4_byte_off, offset_is_bytes=True ) - elif _pack_bytes == 2: + elif const_expr(_pack_bytes == 2): store_val = arith.TruncIOp(T.i16, packed_i32) buffer_ops.buffer_store( store_val, out_rsrc, fp4_byte_off, offset_is_bytes=True