From c7082a7b0c5e8bdb0b786e3b2d2eb2cf8780c721 Mon Sep 17 00:00:00 2001 From: Qiuweihong <953950914@qq.com> Date: Fri, 3 Apr 2026 10:14:23 +0800 Subject: [PATCH] implement jump table on riscv64 --- src/cmd/compile/internal/riscv64/ssa.go | 40 +++++++++++++++++++ .../compile/internal/ssa/_gen/RISCV64.rules | 2 + .../compile/internal/ssa/_gen/RISCV64Ops.go | 5 +++ src/cmd/compile/internal/ssa/opGen.go | 26 ++++++------ .../compile/internal/ssa/rewriteRISCV64.go | 13 ++++++ src/cmd/internal/obj/riscv/obj.go | 7 ++++ src/cmd/internal/sys/arch.go | 1 + test/codegen/switch.go | 6 +++ 8 files changed, 88 insertions(+), 12 deletions(-) diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go index b34256c1766230..adf6b1361514d9 100644 --- a/src/cmd/compile/internal/riscv64/ssa.go +++ b/src/cmd/compile/internal/riscv64/ssa.go @@ -1042,6 +1042,46 @@ func ssaGenBlock(s *ssagen.State, b, next *ssa.Block) { p.From.Reg = b.Controls[0].Reg() } + case ssa.BlockRISCV64JUMPTABLE: + // Jump table: + // TMP = base + index*8 (SH3ADD if Zba else SLLI+ADD). + // Load slot into TMP, then indirect JMP through TMP. + var p *obj.Prog + if buildcfg.GORISCV64 >= 22 { + p = s.Prog(riscv.ASH3ADD) + p.From.Type = obj.TYPE_REG + p.From.Reg = b.Controls[1].Reg() + p.Reg = b.Controls[0].Reg() + p.To.Type = obj.TYPE_REG + p.To.Reg = riscv.REG_TMP + } else { + p = s.Prog(riscv.ASLLI) + p.From.Type = obj.TYPE_CONST + p.From.Offset = 3 + p.Reg = b.Controls[0].Reg() + p.To.Type = obj.TYPE_REG + p.To.Reg = riscv.REG_TMP + + p = s.Prog(riscv.AADD) + p.From.Type = obj.TYPE_REG + p.From.Reg = riscv.REG_TMP + p.Reg = b.Controls[1].Reg() + p.To.Type = obj.TYPE_REG + p.To.Reg = riscv.REG_TMP + } + + p = s.Prog(riscv.AMOV) + p.From.Type = obj.TYPE_MEM + p.From.Reg = riscv.REG_TMP + p.To.Type = obj.TYPE_REG + p.To.Reg = riscv.REG_TMP + + p = s.Prog(obj.AJMP) + p.To.Type = obj.TYPE_MEM + p.To.Reg = riscv.REG_TMP + // Save jump tables for later resolution of the target blocks. + s.JumpTables = append(s.JumpTables, b) + default: b.Fatalf("Unhandled block: %s", b.LongString()) } diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules index ed7e142d0881bb..9082abeb0fd383 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules @@ -39,6 +39,8 @@ (Select1 (Sub64borrow x y c)) => (OR (SLTU x s:(SUB x y)) (SLTU s (SUB s c))) +(JumpTable idx) => (JUMPTABLE {makeJumpTableSym(b)} idx (MOVaddr {makeJumpTableSym(b)} (SB))) + // (x + y) / 2 => (x / 2) + (y / 2) + (x & y & 1) (Avg64u x y) => (ADD (ADD (SRLI [1] x) (SRLI [1] y)) (ANDI [1] (AND x y))) diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go index 5ce3b0e99d8715..747737423d9bcb 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go @@ -550,6 +550,11 @@ func init() { {name: "BGEZ", controls: 1}, {name: "BLTZ", controls: 1}, {name: "BGTZ", controls: 1}, + // JUMPTABLE implements jump tables. + // Aux is the symbol (an *obj.LSym) for the jump table. + // control[0] is the index into the jump table. + // control[1] is the address of the jump table (the address of the symbol stored in Aux). + {name: "JUMPTABLE", controls: 2, aux: "Sym"}, } archs = append(archs, arch{ diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go index 67a594416d7d2e..9688e13bb4de58 100644 --- a/src/cmd/compile/internal/ssa/opGen.go +++ b/src/cmd/compile/internal/ssa/opGen.go @@ -151,6 +151,7 @@ const ( BlockRISCV64BGEZ BlockRISCV64BLTZ BlockRISCV64BGTZ + BlockRISCV64JUMPTABLE BlockS390XBRC BlockS390XCRJ @@ -296,18 +297,19 @@ var blockString = [...]string{ BlockPPC64FGT: "FGT", BlockPPC64FGE: "FGE", - BlockRISCV64BEQ: "BEQ", - BlockRISCV64BNE: "BNE", - BlockRISCV64BLT: "BLT", - BlockRISCV64BGE: "BGE", - BlockRISCV64BLTU: "BLTU", - BlockRISCV64BGEU: "BGEU", - BlockRISCV64BEQZ: "BEQZ", - BlockRISCV64BNEZ: "BNEZ", - BlockRISCV64BLEZ: "BLEZ", - BlockRISCV64BGEZ: "BGEZ", - BlockRISCV64BLTZ: "BLTZ", - BlockRISCV64BGTZ: "BGTZ", + BlockRISCV64BEQ: "BEQ", + BlockRISCV64BNE: "BNE", + BlockRISCV64BLT: "BLT", + BlockRISCV64BGE: "BGE", + BlockRISCV64BLTU: "BLTU", + BlockRISCV64BGEU: "BGEU", + BlockRISCV64BEQZ: "BEQZ", + BlockRISCV64BNEZ: "BNEZ", + BlockRISCV64BLEZ: "BLEZ", + BlockRISCV64BGEZ: "BGEZ", + BlockRISCV64BLTZ: "BLTZ", + BlockRISCV64BGTZ: "BGTZ", + BlockRISCV64JUMPTABLE: "JUMPTABLE", BlockS390XBRC: "BRC", BlockS390XCRJ: "CRJ", diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index 43df9db6bc8445..0b9020507f46e7 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -11222,6 +11222,19 @@ func rewriteBlockRISCV64(b *Block) bool { b.resetWithControl(BlockRISCV64BNEZ, v0) return true } + case BlockJumpTable: + // match: (JumpTable idx) + // result: (JUMPTABLE {makeJumpTableSym(b)} idx (MOVaddr {makeJumpTableSym(b)} (SB))) + for { + idx := b.Controls[0] + v0 := b.NewValue0(b.Pos, OpRISCV64MOVaddr, typ.Uintptr) + v0.Aux = symToAux(makeJumpTableSym(b)) + v1 := b.NewValue0(b.Pos, OpSB, typ.Uintptr) + v0.AddArg(v1) + b.resetWithControl2(BlockRISCV64JUMPTABLE, idx, v0) + b.Aux = symToAux(makeJumpTableSym(b)) + return true + } } return false } diff --git a/src/cmd/internal/obj/riscv/obj.go b/src/cmd/internal/obj/riscv/obj.go index 50c687f722a1f8..7e46ede1f79046 100644 --- a/src/cmd/internal/obj/riscv/obj.go +++ b/src/cmd/internal/obj/riscv/obj.go @@ -5063,6 +5063,13 @@ func assemble(ctxt *obj.Link, cursym *obj.LSym, newprog obj.ProgAlloc) { } obj.MarkUnsafePoints(ctxt, cursym.Func().Text, newprog, isUnsafePoint, nil) + + // generate jump table entries. + for _, jt := range cursym.Func().JumpTables { + for i, p := range jt.Targets { + jt.Sym.WriteAddr(ctxt, int64(i)*8, 8, cursym, p.Pc) + } + } } func isUnsafePoint(p *obj.Prog) bool { diff --git a/src/cmd/internal/sys/arch.go b/src/cmd/internal/sys/arch.go index 2738fe3b54c9e5..c2c5afa332680e 100644 --- a/src/cmd/internal/sys/arch.go +++ b/src/cmd/internal/sys/arch.go @@ -239,6 +239,7 @@ var ArchRISCV64 = &Arch{ MinLC: 2, Alignment: 8, // riscv unaligned loads work, but are really slow (trap + simulated by OS) CanMergeLoads: false, + CanJumpTable: true, HasLR: true, FixedFrameSize: 8, // LR } diff --git a/test/codegen/switch.go b/test/codegen/switch.go index d59ef4f2eb78f6..1dea2f92d7cb8b 100644 --- a/test/codegen/switch.go +++ b/test/codegen/switch.go @@ -26,6 +26,8 @@ func square(x int) int { // amd64:`JMP \(.*\)\(.*\)$` // arm64:`MOVD \(R.*\)\(R.*<<3\)` `JMP \(R.*\)$` // loong64: `ALSLV` `MOVV` `JMP` + // riscv64/rva20u64:`SLLI`,`ADD`,`MOV\s\(X31\), X31`,`JALR\sX0, \(X31\)$` + // riscv64/rva22u64,riscv64/rva23u64:`SH3ADD`,`MOV\s\(X31\), X31`,`JALR\sX0, \(X31\)$` switch x { case 1: return 1 @@ -53,6 +55,8 @@ func length(x string) int { // amd64:`JMP \(.*\)\(.*\)$` // arm64:`MOVD \(R.*\)\(R.*<<3\)` `JMP \(R.*\)$` // loong64:`ALSLV` `MOVV` `JMP` + // riscv64/rva20u64:`SLLI`,`ADD`,`MOV\s\(X31\), X31`,`JALR\sX0, \(X31\)$` + // riscv64/rva22u64,riscv64/rva23u64:`SH3ADD`,`MOV\s\(X31\), X31`,`JALR\sX0, \(X31\)$` switch x { case "a": return 1 @@ -106,6 +110,8 @@ func mimetype(ext string) string { func typeSwitch(x any) int { // amd64:`JMP \(.*\)\(.*\)$` // arm64:`MOVD \(R.*\)\(R.*<<3\)` `JMP \(R.*\)$` + // riscv64/rva20u64:`SLLI`,`ADD`,`MOV\s\(X31\), X31`,`JALR\sX0, \(X31\)$` + // riscv64/rva22u64,riscv64/rva23u64:`SH3ADD`,`MOV\s\(X31\), X31`,`JALR\sX0, \(X31\)$` switch x.(type) { case int: return 0