-
Notifications
You must be signed in to change notification settings - Fork 1.8k
perf: replace atomic scatter with pair-buffer tile binning and cooperative large-splat processing #8586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf: replace atomic scatter with pair-buffer tile binning and cooperative large-splat processing #8586
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -77,30 +77,34 @@ let _randomColorRaw = null; | |||||||||||
| * GSplatManager manages the rendering of splats using a work buffer, where all active splats are | ||||||||||||
| * stored and rendered from. | ||||||||||||
| * | ||||||||||||
| * GPU sorting (WebGPU only): | ||||||||||||
| * 1. [culling] Frustum cull: a fragment shader tests each bounding sphere against frustum | ||||||||||||
| * planes and writes results into a bit-packed nodeVisibilityTexture (1 bit per sphere). | ||||||||||||
| * 2. Interval compaction: operates on contiguous intervals of splats (one per octree node) | ||||||||||||
| * rather than individual pixels. A cull/count pass writes each interval's splat count | ||||||||||||
| * (or 0 if culled) into a count buffer. A prefix sum produces output offsets. A scatter | ||||||||||||
| * pass expands visible intervals into compactedSplatIds (flat list of work-buffer pixel | ||||||||||||
| * indices). The last prefix sum element gives visibleCount. | ||||||||||||
| * 3. Generate sort keys: an indirect compute dispatch (visibleCount threads) reads each | ||||||||||||
| * Shared culling + compaction (GPU sorting and compute renderer, WebGPU only): | ||||||||||||
| * Interval compaction operates on contiguous intervals of splats (one per octree node). | ||||||||||||
| * 1. Cull + count (compute): each interval's bounding sphere is tested against frustum | ||||||||||||
| * planes (or a fisheye cone). The pass writes the interval's splat count (or 0 if | ||||||||||||
| * culled) into a count buffer. | ||||||||||||
| * 2. Prefix sum: exclusive prefix sum over the count buffer produces output offsets. | ||||||||||||
| * The last element gives visibleCount. | ||||||||||||
| * 3. Scatter (compute): one workgroup per interval expands visible intervals into | ||||||||||||
| * compactedSplatIds (flat list of work-buffer pixel indices). | ||||||||||||
| * | ||||||||||||
| * Raster renderer — GPU sorting (WebGPU, {@link GSplatQuadRenderer}): | ||||||||||||
| * Uses shared steps 1-3 above, then: | ||||||||||||
| * 4. Generate sort keys: an indirect compute dispatch (visibleCount threads) reads each | ||||||||||||
| * compactedSplatIds[i] to look up the splat's depth and writes a sort key to keysBuffer. | ||||||||||||
| * 4. Radix sort: an indirect GPU radix sort over keysBuffer, with compactedSplatIds supplied | ||||||||||||
| * 5. Radix sort: an indirect GPU radix sort over keysBuffer, with compactedSplatIds supplied | ||||||||||||
| * as initial values, produces a buffer of sorted splat IDs directly. | ||||||||||||
| * 5. Render: the vertex shader reads sortedSplatIds[vertexId] → splatId. | ||||||||||||
| * 6. Render: the vertex shader reads sortedSplatIds[vertexId] → splatId. | ||||||||||||
| * | ||||||||||||
| * CPU sorting (WebGPU): | ||||||||||||
| * Raster renderer — CPU sorting (WebGPU and WebGL, {@link GSplatQuadRenderer}): | ||||||||||||
| * 1. Sort on worker: camera position and splat centers are sent to a web worker which | ||||||||||||
| * performs a radix sort and returns the sorted order as orderBuffer (storage buffer). | ||||||||||||
| * 2. Interval compaction with frustum culling (same as GPU path steps 1-2). | ||||||||||||
| * 3. Render: the vertex shader reads compactedSplatIds[vertexId] → splatId. | ||||||||||||
| * performs a counting sort and returns the sorted order as orderBuffer. | ||||||||||||
| * 2. Render: the vertex shader reads orderBuffer[vertexId] → splatId. | ||||||||||||
| * No culling or compaction is used. | ||||||||||||
| * | ||||||||||||
| * CPU sorting (WebGL): | ||||||||||||
| * 1. Sort on worker: same as the WebGPU CPU path, producing orderBuffer (texture). | ||||||||||||
| * 2. Render: the vertex shader reads orderBuffer[vertexId] → splatId directly. | ||||||||||||
| * No culling or compaction is available on WebGL. | ||||||||||||
| * Compute tiled renderer (WebGPU only, {@link GSplatComputeLocalRenderer}): | ||||||||||||
| * Uses shared steps 1-3 above, then runs a fully compute-based tiled pipeline: | ||||||||||||
| * project splats into a cache, bin into screen tiles, sort per-tile by depth, and rasterize | ||||||||||||
| * front-to-back. See {@link GSplatComputeLocalRenderer} for the full pass breakdown. | ||||||||||||
| * | ||||||||||||
| * @ignore | ||||||||||||
| */ | ||||||||||||
|
|
@@ -175,8 +179,8 @@ class GSplatManager { | |||||||||||
| indirectDrawSlot = -1; | ||||||||||||
|
|
||||||||||||
| /** | ||||||||||||
| * Indirect dispatch slot index for key gen (first of 2 consecutive slots). | ||||||||||||
| * Slot 0 = key gen (256 threads/workgroup), slot 1 = sort (1024 elements/workgroup). | ||||||||||||
| * Indirect dispatch slot index for key gen (first of 3 consecutive slots). | ||||||||||||
| * Slot +0 = key gen, slot +1 = sort, slot +2 = place-entries. | ||||||||||||
| * | ||||||||||||
| * @type {number} | ||||||||||||
| */ | ||||||||||||
|
|
@@ -1646,8 +1650,8 @@ class GSplatManager { | |||||||||||
| this.intervalCompaction.dispatchCompact(this.workBuffer.frustumCuller, numIntervals, totalActiveSplats, this.renderer.fisheyeProj.enabled); | ||||||||||||
|
|
||||||||||||
| // Extract the visible count from the prefix sum into sortElementCountBuffer. | ||||||||||||
| // writeIndirectArgs is the only path that does this; the indirect draw/dispatch | ||||||||||||
| // slots are unused by the local renderer but required by the API. | ||||||||||||
| // writeIndirectArgs is the only path that does this. The local renderer uses | ||||||||||||
| // dispatch slot +2 (place-entries) for indirect dispatch. | ||||||||||||
|
||||||||||||
| // writeIndirectArgs is the only path that does this. The local renderer uses | |
| // dispatch slot +2 (place-entries) for indirect dispatch. | |
| // writeIndirectArgs is the only path that does this. The local compute renderer | |
| // prepares its own indirect dispatch args in private buffers and does not use | |
| // a third slot in the shared indirect-dispatch buffer. |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,49 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Cooperative large-splat place-entries pass: one workgroup (256 threads) per large splat. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // The regular PlaceEntries pass skips large splats (flagged via the high bit of | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // splatPairCount). This pass picks them up and spreads each splat's pair writes across | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // 256 threads, eliminating the long tail caused by single threads looping over hundreds | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // of pairs with scattered tileEntries writes. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Reuses the same largeSplatIds buffer and indirect dispatch dimensions as | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // LargeTileCount — one workgroup per large splat. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| export const computeGsplatLocalPlaceEntriesLargeSource = /* wgsl */` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const WG_SIZE: u32 = 256u; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(0) var<storage, read> pairBuffer: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(1) var<storage, read> splatPairStart: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(2) var<storage, read> splatPairCount: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(3) var<storage, read> tileSplatCounts: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(4) var<storage, read_write> tileEntries: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(5) var<storage, read> largeSplatIds: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(6) var<storage, read> largeSplatCount: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @compute @workgroup_size(256) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fn main( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @builtin(workgroup_id) wgId: vec3u, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @builtin(num_workgroups) numWorkgroups: vec3u, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @builtin(local_invocation_index) lid: u32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let largeSplatIdx = wgId.y * numWorkgroups.x + wgId.x; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let numLarge = min(largeSplatCount[0], arrayLength(&largeSplatIds)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (largeSplatIdx >= numLarge) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let threadIdx = largeSplatIds[largeSplatIdx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let pairCount = splatPairCount[threadIdx] & 0x7FFFFFFFu; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (pairCount == 0u) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let start = splatPairStart[threadIdx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var j = lid; j < pairCount; j += WG_SIZE) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let packed = pairBuffer[start + j]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let tileIdx = packed >> 16u; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let localOff = packed & 0xFFFFu; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var j = lid; j < pairCount; j += WG_SIZE) { | |
| let packed = pairBuffer[start + j]; | |
| let tileIdx = packed >> 16u; | |
| let localOff = packed & 0xFFFFu; | |
| tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx; | |
| let pairBufferLen = arrayLength(&pairBuffer); | |
| let tileSplatCountsLen = arrayLength(&tileSplatCounts); | |
| let tileEntriesLen = arrayLength(&tileEntries); | |
| for (var j = lid; j < pairCount; j += WG_SIZE) { | |
| let pairIdx = start + j; | |
| if (pairIdx >= pairBufferLen) { | |
| continue; | |
| } | |
| let packed = pairBuffer[pairIdx]; | |
| let tileIdx = packed >> 16u; | |
| let localOff = packed & 0xFFFFu; | |
| if (tileIdx >= tileSplatCountsLen) { | |
| continue; | |
| } | |
| let entryIdx = tileSplatCounts[tileIdx] + localOff; | |
| if (entryIdx >= tileEntriesLen) { | |
| continue; | |
| } | |
| tileEntries[entryIdx] = threadIdx; |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,43 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Lightweight place-entries pass: reads (tileIdx, localOffset) pairs written by the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // count+pair-write pass and places splat indices into per-tile entry lists using the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // prefix-summed tile offsets. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Dispatched per-splat (same shape as the count pass). Each thread reads its pair range | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // via splatPairStart/splatPairCount and writes to tileEntries at deterministic positions. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| export const computeGsplatLocalPlaceEntriesSource = /* wgsl */` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(0) var<storage, read> pairBuffer: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(1) var<storage, read> splatPairStart: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(2) var<storage, read> splatPairCount: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(3) var<storage, read> tileSplatCounts: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(4) var<storage, read_write> tileEntries: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @group(0) @binding(5) var<storage, read> sortElementCount: array<u32>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @compute @workgroup_size(256) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(num_workgroups) numWorkgroups: vec3u) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let threadIdx = gid.y * (numWorkgroups.x * 256u) + gid.x; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let numVisible = sortElementCount[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (threadIdx >= numVisible) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let rawCount = splatPairCount[threadIdx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // High bit marks large splats handled by the cooperative LargePlaceEntries pass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rawCount == 0u || (rawCount & 0x80000000u) != 0u) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let count = rawCount; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let start = splatPairStart[threadIdx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var j: u32 = 0u; j < count; j++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let packed = pairBuffer[start + j]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let tileIdx = packed >> 16u; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let localOff = packed & 0xFFFFu; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // tileSplatCounts has been prefix-summed, so it holds the start offset for each tile. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // localOff is the within-tile position assigned by atomicAdd during the count pass. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var j: u32 = 0u; j < count; j++) { | |
| let packed = pairBuffer[start + j]; | |
| let tileIdx = packed >> 16u; | |
| let localOff = packed & 0xFFFFu; | |
| // tileSplatCounts has been prefix-summed, so it holds the start offset for each tile. | |
| // localOff is the within-tile position assigned by atomicAdd during the count pass. | |
| tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx; | |
| let pairBufferLen = arrayLength(&pairBuffer); | |
| let tileCountsLen = arrayLength(&tileSplatCounts); | |
| let tileEntriesLen = arrayLength(&tileEntries); | |
| for (var j: u32 = 0u; j < count; j++) { | |
| let pairIndex = start + j; | |
| if (pairIndex >= pairBufferLen) { | |
| continue; | |
| } | |
| let packed = pairBuffer[pairIndex]; | |
| let tileIdx = packed >> 16u; | |
| let localOff = packed & 0xFFFFu; | |
| if (tileIdx >= tileCountsLen) { | |
| continue; | |
| } | |
| // tileSplatCounts has been prefix-summed, so it holds the start offset for each tile. | |
| // localOff is the within-tile position assigned by atomicAdd during the count pass. | |
| let entryIndex = tileSplatCounts[tileIdx] + localOff; | |
| if (entryIndex >= tileEntriesLen) { | |
| continue; | |
| } | |
| tileEntries[entryIndex] = threadIdx; |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The doc/comment says the indirect dispatch slot is the “first of 3 consecutive slots” (key gen, sort, place-entries), but
compute-gsplat-write-indirect-argsstill writes only two dispatch arg triplets (key gen + sort). The compute local renderer also builds its own indirect args buffers for count/place-entries rather than using the shared slot. Please update this comment to match the actual indirect-dispatch usage to avoid misleading future changes.