Merged
428 changes: 380 additions & 48 deletions src/scene/gsplat-unified/gsplat-compute-local-renderer.js

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions src/scene/gsplat-unified/gsplat-local-dispatch-set.js
@@ -63,7 +63,13 @@ class GSplatLocalDispatchSet
_countComputeFisheye = null;

/** @type {Compute} */
scatterCompute;
placeEntriesCompute;

/** @type {Compute} */
largeSplatCompute;

/** @type {Compute} */
largePlaceEntriesCompute;

/** @type {Compute} */
classifyCompute;
@@ -89,9 +95,6 @@ class GSplatLocalDispatchSet
/** @type {StorageBuffer|null} */
_tileSplatCountsBuffer = null;

/** @type {StorageBuffer|null} */
_tileWriteCursorsBuffer = null;

/** @type {StorageBuffer|null} */
_smallTileListBuffer = null;

@@ -197,7 +200,6 @@ class GSplatLocalDispatchSet
if (requiredTileSlots <= this._allocatedTileCapacity) return;

this._tileSplatCountsBuffer?.destroy();
this._tileWriteCursorsBuffer?.destroy();
this._smallTileListBuffer?.destroy();
this._largeTileListBuffer?.destroy();
this._largeTileOverflowBasesBuffer?.destroy();
@@ -209,7 +211,6 @@

this._allocatedTileCapacity = requiredTileSlots;
this._tileSplatCountsBuffer = new StorageBuffer(this.device, requiredTileSlots * 4, BUFFERUSAGE_COPY_DST | BUFFERUSAGE_COPY_SRC);
this._tileWriteCursorsBuffer = new StorageBuffer(this.device, numTiles * 4, BUFFERUSAGE_COPY_DST);
this._smallTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._largeTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._largeTileOverflowBasesBuffer = new StorageBuffer(this.device, numTiles * 4);
@@ -388,7 +389,6 @@
}

this._tileSplatCountsBuffer?.destroy();
this._tileWriteCursorsBuffer?.destroy();
this._smallTileListBuffer?.destroy();
this._largeTileListBuffer?.destroy();
this._largeTileOverflowBasesBuffer?.destroy();
50 changes: 27 additions & 23 deletions src/scene/gsplat-unified/gsplat-manager.js
@@ -77,30 +77,34 @@ let _randomColorRaw = null;
* GSplatManager manages the rendering of splats using a work buffer, where all active splats are
* stored and rendered from.
*
* GPU sorting (WebGPU only):
* 1. [culling] Frustum cull: a fragment shader tests each bounding sphere against frustum
* planes and writes results into a bit-packed nodeVisibilityTexture (1 bit per sphere).
* 2. Interval compaction: operates on contiguous intervals of splats (one per octree node)
* rather than individual pixels. A cull/count pass writes each interval's splat count
* (or 0 if culled) into a count buffer. A prefix sum produces output offsets. A scatter
* pass expands visible intervals into compactedSplatIds (flat list of work-buffer pixel
* indices). The last prefix sum element gives visibleCount.
* 3. Generate sort keys: an indirect compute dispatch (visibleCount threads) reads each
* Shared culling + compaction (GPU sorting and compute renderer, WebGPU only):
* Interval compaction operates on contiguous intervals of splats (one per octree node).
* 1. Cull + count (compute): each interval's bounding sphere is tested against frustum
* planes (or a fisheye cone). The pass writes the interval's splat count (or 0 if
* culled) into a count buffer.
* 2. Prefix sum: exclusive prefix sum over the count buffer produces output offsets.
* The last element gives visibleCount.
* 3. Scatter (compute): one workgroup per interval expands visible intervals into
* compactedSplatIds (flat list of work-buffer pixel indices).
*
* Raster renderer — GPU sorting (WebGPU, {@link GSplatQuadRenderer}):
* Uses shared steps 1-3 above, then:
* 4. Generate sort keys: an indirect compute dispatch (visibleCount threads) reads each
* compactedSplatIds[i] to look up the splat's depth and writes a sort key to keysBuffer.
* 4. Radix sort: an indirect GPU radix sort over keysBuffer, with compactedSplatIds supplied
* 5. Radix sort: an indirect GPU radix sort over keysBuffer, with compactedSplatIds supplied
* as initial values, produces a buffer of sorted splat IDs directly.
* 5. Render: the vertex shader reads sortedSplatIds[vertexId] → splatId.
* 6. Render: the vertex shader reads sortedSplatIds[vertexId] → splatId.
*
* CPU sorting (WebGPU):
* Raster renderer — CPU sorting (WebGPU and WebGL, {@link GSplatQuadRenderer}):
* 1. Sort on worker: camera position and splat centers are sent to a web worker which
* performs a radix sort and returns the sorted order as orderBuffer (storage buffer).
* 2. Interval compaction with frustum culling (same as GPU path steps 1-2).
* 3. Render: the vertex shader reads compactedSplatIds[vertexId] → splatId.
* performs a counting sort and returns the sorted order as orderBuffer.
* 2. Render: the vertex shader reads orderBuffer[vertexId] → splatId.
* No culling or compaction is used.
*
* CPU sorting (WebGL):
* 1. Sort on worker: same as the WebGPU CPU path, producing orderBuffer (texture).
* 2. Render: the vertex shader reads orderBuffer[vertexId] → splatId directly.
* No culling or compaction is available on WebGL.
* Compute tiled renderer (WebGPU only, {@link GSplatComputeLocalRenderer}):
* Uses shared steps 1-3 above, then runs a fully compute-based tiled pipeline:
* project splats into a cache, bin into screen tiles, sort per-tile by depth, and rasterize
* front-to-back. See {@link GSplatComputeLocalRenderer} for the full pass breakdown.
*
* @ignore
*/
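The shared steps 1-3 described in the doc comment above can be sketched on the CPU. This is a minimal JavaScript illustration only, assuming per-interval `{ start, count }` splat ranges; `compactIntervals` and `isVisible` are illustrative names, not engine API.

```javascript
// CPU sketch of the shared cull + count / prefix sum / scatter sequence.
// intervals: array of { start, count } splat ranges (one per octree node).
// isVisible(i): stands in for the frustum (or fisheye cone) test of interval i.
function compactIntervals(intervals, isVisible) {
    const n = intervals.length;

    // Step 1: cull + count. Each interval contributes its splat count, or 0 if culled.
    const counts = intervals.map((iv, i) => (isVisible(i) ? iv.count : 0));

    // Step 2: exclusive prefix sum. offsets[i] is where interval i's splats start
    // in the compacted list; the final running total is visibleCount.
    const offsets = new Uint32Array(n);
    let visibleCount = 0;
    for (let i = 0; i < n; i++) {
        offsets[i] = visibleCount;
        visibleCount += counts[i];
    }

    // Step 3: scatter. Expand each visible interval into a flat list of
    // work-buffer pixel indices.
    const compactedSplatIds = new Uint32Array(visibleCount);
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < counts[i]; j++) {
            compactedSplatIds[offsets[i] + j] = intervals[i].start + j;
        }
    }
    return { compactedSplatIds, visibleCount };
}
```

On the GPU the three steps run as separate dispatches, with step 3 using one workgroup per interval; the sequential loops here only model the data flow.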
@@ -175,8 +179,8 @@ class GSplatManager
indirectDrawSlot = -1;

/**
* Indirect dispatch slot index for key gen (first of 2 consecutive slots).
* Slot 0 = key gen (256 threads/workgroup), slot 1 = sort (1024 elements/workgroup).
* Indirect dispatch slot index for key gen (first of 3 consecutive slots).
* Slot +0 = key gen, slot +1 = sort, slot +2 = place-entries.
Copilot AI Apr 10, 2026

The doc/comment says the indirect dispatch slot is the “first of 3 consecutive slots” (key gen, sort, place-entries), but compute-gsplat-write-indirect-args still writes only two dispatch arg triplets (key gen + sort). The compute local renderer also builds its own indirect args buffers for count/place-entries rather than using the shared slot. Please update this comment to match the actual indirect-dispatch usage to avoid misleading future changes.

Suggested change
* Indirect dispatch slot index for key gen (first of 3 consecutive slots).
* Slot +0 = key gen, slot +1 = sort, slot +2 = place-entries.
* Indirect dispatch slot index for GPU-sort indirect dispatch args.
* Slot +0 = key gen, slot +1 = sort.
* Place-entries/count indirect args are built separately by the compute
* local renderer and do not use this shared slot.

*
* @type {number}
*/
@@ -1646,8 +1650,8 @@ class GSplatManager
this.intervalCompaction.dispatchCompact(this.workBuffer.frustumCuller, numIntervals, totalActiveSplats, this.renderer.fisheyeProj.enabled);

// Extract the visible count from the prefix sum into sortElementCountBuffer.
// writeIndirectArgs is the only path that does this; the indirect draw/dispatch
// slots are unused by the local renderer but required by the API.
// writeIndirectArgs is the only path that does this. The local renderer uses
// dispatch slot +2 (place-entries) for indirect dispatch.
Copilot AI Apr 10, 2026

This comment states the local compute renderer uses “dispatch slot +2 (place-entries) for indirect dispatch”, but the compute local renderer now generates indirect dispatch args in its own private buffers (PlaceEntryPrep/LargeSplatPrep) and does not rely on a third slot in the shared indirect-dispatch buffer. Please adjust the comment so it matches the current implementation.

Suggested change
// writeIndirectArgs is the only path that does this. The local renderer uses
// dispatch slot +2 (place-entries) for indirect dispatch.
// writeIndirectArgs is the only path that does this. The local compute renderer
// prepares its own indirect dispatch args in private buffers and does not use
// a third slot in the shared indirect-dispatch buffer.

this.allocateAndWriteIntervalIndirectArgs(numIntervals);

const ic = /** @type {GSplatIntervalCompaction} */ (this.intervalCompaction);
@@ -0,0 +1,49 @@
// Cooperative large-splat place-entries pass: one workgroup (256 threads) per large splat.
//
// The regular PlaceEntries pass skips large splats (flagged via the high bit of
// splatPairCount). This pass picks them up and spreads each splat's pair writes across
// 256 threads, eliminating the long tail caused by single threads looping over hundreds
// of pairs with scattered tileEntries writes.
//
// Reuses the same largeSplatIds buffer and indirect dispatch dimensions as
// LargeTileCount — one workgroup per large splat.
export const computeGsplatLocalPlaceEntriesLargeSource = /* wgsl */`

const WG_SIZE: u32 = 256u;

@group(0) @binding(0) var<storage, read> pairBuffer: array<u32>;
@group(0) @binding(1) var<storage, read> splatPairStart: array<u32>;
@group(0) @binding(2) var<storage, read> splatPairCount: array<u32>;
@group(0) @binding(3) var<storage, read> tileSplatCounts: array<u32>;
@group(0) @binding(4) var<storage, read_write> tileEntries: array<u32>;
@group(0) @binding(5) var<storage, read> largeSplatIds: array<u32>;
@group(0) @binding(6) var<storage, read> largeSplatCount: array<u32>;

@compute @workgroup_size(256)
fn main(
    @builtin(workgroup_id) wgId: vec3u,
    @builtin(num_workgroups) numWorkgroups: vec3u,
    @builtin(local_invocation_index) lid: u32
) {
    let largeSplatIdx = wgId.y * numWorkgroups.x + wgId.x;
    let numLarge = min(largeSplatCount[0], arrayLength(&largeSplatIds));
    if (largeSplatIdx >= numLarge) {
        return;
    }

    let threadIdx = largeSplatIds[largeSplatIdx];
    let pairCount = splatPairCount[threadIdx] & 0x7FFFFFFFu;
    if (pairCount == 0u) {
        return;
    }

    let start = splatPairStart[threadIdx];

    for (var j = lid; j < pairCount; j += WG_SIZE) {
        let packed = pairBuffer[start + j];
        let tileIdx = packed >> 16u;
        let localOff = packed & 0xFFFFu;
        tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx;
Copilot AI Apr 10, 2026

Same overflow risk as the non-cooperative PlaceEntries pass: if start + j is out of bounds for pairBuffer (or the computed tileEntries index exceeds its capacity), robust buffer access can yield zeros and then write incorrect entries into tile 0. Add bounds checks using arrayLength(&pairBuffer) / arrayLength(&tileEntries) (or a maxEntries uniform) so large-splat processing skips overflowed pairs instead of corrupting output.

Suggested change
for (var j = lid; j < pairCount; j += WG_SIZE) {
    let packed = pairBuffer[start + j];
    let tileIdx = packed >> 16u;
    let localOff = packed & 0xFFFFu;
    tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx;
let pairBufferLen = arrayLength(&pairBuffer);
let tileSplatCountsLen = arrayLength(&tileSplatCounts);
let tileEntriesLen = arrayLength(&tileEntries);
for (var j = lid; j < pairCount; j += WG_SIZE) {
    let pairIdx = start + j;
    if (pairIdx >= pairBufferLen) {
        continue;
    }
    let packed = pairBuffer[pairIdx];
    let tileIdx = packed >> 16u;
    let localOff = packed & 0xFFFFu;
    if (tileIdx >= tileSplatCountsLen) {
        continue;
    }
    let entryIdx = tileSplatCounts[tileIdx] + localOff;
    if (entryIdx >= tileEntriesLen) {
        continue;
    }
    tileEntries[entryIdx] = threadIdx;

    }
}
`;
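The pair encoding both place-entries passes rely on packs a tile index into the high 16 bits of each `u32` and a within-tile offset into the low 16 bits, while the high bit of a splat's pair count flags it as "large" for the cooperative pass. A minimal CPU-side JavaScript illustration of that scheme (helper names here are illustrative, not engine API):

```javascript
// High bit of splatPairCount marks a splat as "large": skipped by the regular
// PlaceEntries pass and handled by the cooperative per-workgroup pass instead.
const LARGE_SPLAT_FLAG = 0x80000000;

// Pack (tileIdx, localOff) into one u32: tile index in the high 16 bits,
// within-tile offset in the low 16 bits. Both must fit in 16 bits.
function packPair(tileIdx, localOff) {
    return ((tileIdx << 16) | (localOff & 0xFFFF)) >>> 0;
}

// Inverse of packPair: mirrors `packed >> 16u` and `packed & 0xFFFFu` in WGSL.
function unpackPair(packed) {
    return { tileIdx: packed >>> 16, localOff: packed & 0xFFFF };
}

function isLargeSplat(rawPairCount) {
    return (rawPairCount & LARGE_SPLAT_FLAG) !== 0;
}

// Mirrors `splatPairCount[threadIdx] & 0x7FFFFFFFu`: the count with the flag masked off.
function pairCountOf(rawPairCount) {
    return rawPairCount & 0x7FFFFFFF;
}
```

The 16/16 split caps both the tile count and the per-tile splat count at 65536 per pair; overflow handling is the subject of the review comments on this diff.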
@@ -0,0 +1,43 @@
// Lightweight place-entries pass: reads (tileIdx, localOffset) pairs written by the
// count+pair-write pass and places splat indices into per-tile entry lists using the
// prefix-summed tile offsets.
//
// Dispatched per-splat (same shape as the count pass). Each thread reads its pair range
// via splatPairStart/splatPairCount and writes to tileEntries at deterministic positions.
export const computeGsplatLocalPlaceEntriesSource = /* wgsl */`

@group(0) @binding(0) var<storage, read> pairBuffer: array<u32>;
@group(0) @binding(1) var<storage, read> splatPairStart: array<u32>;
@group(0) @binding(2) var<storage, read> splatPairCount: array<u32>;
@group(0) @binding(3) var<storage, read> tileSplatCounts: array<u32>;
@group(0) @binding(4) var<storage, read_write> tileEntries: array<u32>;
@group(0) @binding(5) var<storage, read> sortElementCount: array<u32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(num_workgroups) numWorkgroups: vec3u) {
    let threadIdx = gid.y * (numWorkgroups.x * 256u) + gid.x;
    let numVisible = sortElementCount[0];
    if (threadIdx >= numVisible) {
        return;
    }

    let rawCount = splatPairCount[threadIdx];
    // High bit marks large splats handled by the cooperative LargePlaceEntries pass
    if (rawCount == 0u || (rawCount & 0x80000000u) != 0u) {
        return;
    }
    let count = rawCount;

    let start = splatPairStart[threadIdx];

    for (var j: u32 = 0u; j < count; j++) {
        let packed = pairBuffer[start + j];
        let tileIdx = packed >> 16u;
        let localOff = packed & 0xFFFFu;

        // tileSplatCounts has been prefix-summed, so it holds the start offset for each tile.
        // localOff is the within-tile position assigned by atomicAdd during the count pass.
        tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx;
Copilot AI Apr 10, 2026

This pass assumes start + count stays within pairBuffer and that the resulting tileEntries[tileOffset + localOff] index is in-bounds. If globalPairCounter ever allocates past pairBuffer capacity (or if tile prefix-sum total exceeds tileEntries capacity), WebGPU’s robust buffer access can turn out-of-bounds pairBuffer reads into zeros, which then corrupts tile 0’s entries instead of just dropping work. Add explicit bounds checks using arrayLength(&pairBuffer) / arrayLength(&tileEntries) (or pass a maxEntries uniform like the old scatter path) so overflow degrades by skipping writes rather than writing incorrect entries.

Suggested change
for (var j: u32 = 0u; j < count; j++) {
    let packed = pairBuffer[start + j];
    let tileIdx = packed >> 16u;
    let localOff = packed & 0xFFFFu;
    // tileSplatCounts has been prefix-summed, so it holds the start offset for each tile.
    // localOff is the within-tile position assigned by atomicAdd during the count pass.
    tileEntries[tileSplatCounts[tileIdx] + localOff] = threadIdx;
let pairBufferLen = arrayLength(&pairBuffer);
let tileCountsLen = arrayLength(&tileSplatCounts);
let tileEntriesLen = arrayLength(&tileEntries);
for (var j: u32 = 0u; j < count; j++) {
    let pairIndex = start + j;
    if (pairIndex >= pairBufferLen) {
        continue;
    }
    let packed = pairBuffer[pairIndex];
    let tileIdx = packed >> 16u;
    let localOff = packed & 0xFFFFu;
    if (tileIdx >= tileCountsLen) {
        continue;
    }
    // tileSplatCounts has been prefix-summed, so it holds the start offset for each tile.
    // localOff is the within-tile position assigned by atomicAdd during the count pass.
    let entryIndex = tileSplatCounts[tileIdx] + localOff;
    if (entryIndex >= tileEntriesLen) {
        continue;
    }
    tileEntries[entryIndex] = threadIdx;

    }
}
`;
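The count / prefix-sum / place sequence the pass above participates in can be sketched sequentially on the CPU. This is a minimal JavaScript model of the data flow only; `buildTileEntries` and its inputs are illustrative names, not engine API, and the GPU version hands out `localOff` with `atomicAdd` across threads rather than a serial counter.

```javascript
// splatTiles: per-splat arrays of tile indices the splat overlaps.
// Returns the per-tile entry lists packed into one flat tileEntries array.
function buildTileEntries(splatTiles, numTiles) {
    // Count pass: per-tile counts, handing out each splat's within-tile
    // position (localOff) in order.
    const counts = new Uint32Array(numTiles);
    const pairs = [];
    splatTiles.forEach((tiles, splatIdx) => {
        for (const tileIdx of tiles) {
            pairs.push({ splatIdx, tileIdx, localOff: counts[tileIdx]++ });
        }
    });

    // Prefix sum turns per-tile counts into per-tile start offsets.
    const offsets = new Uint32Array(numTiles);
    let running = 0;
    for (let t = 0; t < numTiles; t++) {
        offsets[t] = running;
        running += counts[t];
    }

    // Place pass: each pair writes its splat index at a deterministic
    // position, tile start offset + within-tile offset, so no atomics are
    // needed here and the result is order-independent.
    const tileEntries = new Uint32Array(running);
    for (const { splatIdx, tileIdx, localOff } of pairs) {
        tileEntries[offsets[tileIdx] + localOff] = splatIdx;
    }
    return { tileEntries, offsets };
}
```

Because every pair's destination is fully determined by (tile offset, localOff), the place pass can be split freely across threads, which is what lets the cooperative large-splat variant stride a single splat's pairs over a whole workgroup.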

This file was deleted.
