Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
442 changes: 392 additions & 50 deletions src/scene/gsplat-unified/gsplat-compute-local-renderer.js

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions src/scene/gsplat-unified/gsplat-local-dispatch-set.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ class GSplatLocalDispatchSet {
_countComputeFisheye = null;

/** @type {Compute} */
scatterCompute;
placeEntriesCompute;

/** @type {Compute} */
largeSplatCompute;

/** @type {Compute} */
largePlaceEntriesCompute;

/** @type {Compute} */
classifyCompute;
Expand All @@ -89,9 +95,6 @@ class GSplatLocalDispatchSet {
/** @type {StorageBuffer|null} */
_tileSplatCountsBuffer = null;

/** @type {StorageBuffer|null} */
_tileWriteCursorsBuffer = null;

/** @type {StorageBuffer|null} */
_smallTileListBuffer = null;

Expand Down Expand Up @@ -197,7 +200,6 @@ class GSplatLocalDispatchSet {
if (requiredTileSlots <= this._allocatedTileCapacity) return;

this._tileSplatCountsBuffer?.destroy();
this._tileWriteCursorsBuffer?.destroy();
this._smallTileListBuffer?.destroy();
this._largeTileListBuffer?.destroy();
this._largeTileOverflowBasesBuffer?.destroy();
Expand All @@ -209,7 +211,6 @@ class GSplatLocalDispatchSet {

this._allocatedTileCapacity = requiredTileSlots;
this._tileSplatCountsBuffer = new StorageBuffer(this.device, requiredTileSlots * 4, BUFFERUSAGE_COPY_DST | BUFFERUSAGE_COPY_SRC);
this._tileWriteCursorsBuffer = new StorageBuffer(this.device, numTiles * 4, BUFFERUSAGE_COPY_DST);
this._smallTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._largeTileListBuffer = new StorageBuffer(this.device, numTiles * 4);
this._largeTileOverflowBasesBuffer = new StorageBuffer(this.device, numTiles * 4);
Expand Down Expand Up @@ -388,7 +389,6 @@ class GSplatLocalDispatchSet {
}

this._tileSplatCountsBuffer?.destroy();
this._tileWriteCursorsBuffer?.destroy();
this._smallTileListBuffer?.destroy();
this._largeTileListBuffer?.destroy();
this._largeTileOverflowBasesBuffer?.destroy();
Expand Down
51 changes: 28 additions & 23 deletions src/scene/gsplat-unified/gsplat-manager.js
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,34 @@ let _randomColorRaw = null;
* GSplatManager manages the rendering of splats using a work buffer, where all active splats are
* stored and rendered from.
*
* GPU sorting (WebGPU only):
* 1. [culling] Frustum cull: a fragment shader tests each bounding sphere against frustum
* planes and writes results into a bit-packed nodeVisibilityTexture (1 bit per sphere).
* 2. Interval compaction: operates on contiguous intervals of splats (one per octree node)
* rather than individual pixels. A cull/count pass writes each interval's splat count
* (or 0 if culled) into a count buffer. A prefix sum produces output offsets. A scatter
* pass expands visible intervals into compactedSplatIds (flat list of work-buffer pixel
* indices). The last prefix sum element gives visibleCount.
* 3. Generate sort keys: an indirect compute dispatch (visibleCount threads) reads each
* Shared culling + compaction (GPU sorting and compute renderer, WebGPU only):
* Interval compaction operates on contiguous intervals of splats (one per octree node).
* 1. Cull + count (compute): each interval's bounding sphere is tested against frustum
* planes (or a fisheye cone). The pass writes the interval's splat count (or 0 if
* culled) into a count buffer.
* 2. Prefix sum: exclusive prefix sum over the count buffer produces output offsets.
* The last element gives visibleCount.
* 3. Scatter (compute): one workgroup per interval expands visible intervals into
* compactedSplatIds (flat list of work-buffer pixel indices).
*
* Raster renderer — GPU sorting (WebGPU, {@link GSplatQuadRenderer}):
* Uses shared steps 1-3 above, then:
* 4. Generate sort keys: an indirect compute dispatch (visibleCount threads) reads each
* compactedSplatIds[i] to look up the splat's depth and writes a sort key to keysBuffer.
* 4. Radix sort: an indirect GPU radix sort over keysBuffer, with compactedSplatIds supplied
* 5. Radix sort: an indirect GPU radix sort over keysBuffer, with compactedSplatIds supplied
* as initial values, produces a buffer of sorted splat IDs directly.
* 5. Render: the vertex shader reads sortedSplatIds[vertexId] → splatId.
* 6. Render: the vertex shader reads sortedSplatIds[vertexId] → splatId.
*
* CPU sorting (WebGPU):
* Raster renderer — CPU sorting (WebGPU and WebGL, {@link GSplatQuadRenderer}):
* 1. Sort on worker: camera position and splat centers are sent to a web worker which
* performs a radix sort and returns the sorted order as orderBuffer (storage buffer).
* 2. Interval compaction with frustum culling (same as GPU path steps 1-2).
* 3. Render: the vertex shader reads compactedSplatIds[vertexId] → splatId.
* performs a counting sort and returns the sorted order as orderBuffer.
* 2. Render: the vertex shader reads orderBuffer[vertexId] → splatId.
* No culling or compaction is used.
*
* CPU sorting (WebGL):
* 1. Sort on worker: same as the WebGPU CPU path, producing orderBuffer (texture).
* 2. Render: the vertex shader reads orderBuffer[vertexId] → splatId directly.
* No culling or compaction is available on WebGL.
* Compute tiled renderer (WebGPU only, {@link GSplatComputeLocalRenderer}):
* Uses shared steps 1-3 above, then runs a fully compute-based tiled pipeline:
* project splats into a cache, bin into screen tiles, sort per-tile by depth, and rasterize
* front-to-back. See {@link GSplatComputeLocalRenderer} for the full pass breakdown.
*
* @ignore
*/
Expand Down Expand Up @@ -175,8 +179,9 @@ class GSplatManager {
indirectDrawSlot = -1;

/**
* Indirect dispatch slot index for key gen (first of 2 consecutive slots).
* Slot 0 = key gen (256 threads/workgroup), slot 1 = sort (1024 elements/workgroup).
* Indirect dispatch slot index for GPU-sort indirect dispatch args.
* Slot +0 = key gen, slot +1 = sort. The compute local renderer builds
* its own indirect args in private buffers and does not use these slots.
*
* @type {number}
*/
Expand Down Expand Up @@ -1646,8 +1651,8 @@ class GSplatManager {
this.intervalCompaction.dispatchCompact(this.workBuffer.frustumCuller, numIntervals, totalActiveSplats, this.renderer.fisheyeProj.enabled);

// Extract the visible count from the prefix sum into sortElementCountBuffer.
// writeIndirectArgs is the only path that does this; the indirect draw/dispatch
// slots are unused by the local renderer but required by the API.
// writeIndirectArgs is the only path that does this. The local compute renderer
// prepares its own indirect dispatch args in private buffers.
this.allocateAndWriteIntervalIndirectArgs(numIntervals);

const ic = /** @type {GSplatIntervalCompaction} */ (this.intervalCompaction);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Cooperative large-splat place-entries pass: one workgroup (256 threads) per large splat.
//
// The regular PlaceEntries pass skips large splats (flagged via the high bit of
// splatPairCount). This pass picks them up and spreads each splat's pair writes across
// 256 threads, eliminating the long tail caused by single threads looping over hundreds
// of pairs with scattered tileEntries writes.
//
// Reuses the same largeSplatIds buffer and indirect dispatch dimensions as
// LargeTileCount — one workgroup per large splat.
//
// Bindings:
//   0 pairBuffer       (read)       packed pairs: (tileIdx << 16) | localOffset
//   1 splatPairStart   (read)       per-splat start index into pairBuffer
//   2 splatPairCount   (read)       per-splat pair count; high bit flags a large splat
//   3 tileSplatCounts  (read)       per-tile base offsets into tileEntries
//                                   (presumably prefix-summed, matching the regular
//                                   PlaceEntries pass — confirm against pipeline setup)
//   4 tileEntries      (read_write) output: flat per-tile splat-index lists
//   5 largeSplatIds    (read)       compacted list of large-splat indices
//   6 largeSplatCount  (read)       element [0] = number of valid entries in largeSplatIds
//
// All reads/writes are bounds-guarded with arrayLength, so an over-sized indirect
// dispatch or a stale count cannot write out of range.
export const computeGsplatLocalPlaceEntriesLargeSource = /* wgsl */`

const WG_SIZE: u32 = 256u;

@group(0) @binding(0) var<storage, read> pairBuffer: array<u32>;
@group(0) @binding(1) var<storage, read> splatPairStart: array<u32>;
@group(0) @binding(2) var<storage, read> splatPairCount: array<u32>;
@group(0) @binding(3) var<storage, read> tileSplatCounts: array<u32>;
@group(0) @binding(4) var<storage, read_write> tileEntries: array<u32>;
@group(0) @binding(5) var<storage, read> largeSplatIds: array<u32>;
@group(0) @binding(6) var<storage, read> largeSplatCount: array<u32>;

@compute @workgroup_size(256)
fn main(
@builtin(workgroup_id) wgId: vec3u,
@builtin(num_workgroups) numWorkgroups: vec3u,
@builtin(local_invocation_index) lid: u32
) {
let largeSplatIdx = wgId.y * numWorkgroups.x + wgId.x;
let numLarge = min(largeSplatCount[0], arrayLength(&largeSplatIds));
if (largeSplatIdx >= numLarge) {
return;
}

let threadIdx = largeSplatIds[largeSplatIdx];
let pairCount = splatPairCount[threadIdx] & 0x7FFFFFFFu;
if (pairCount == 0u) {
return;
}

let start = splatPairStart[threadIdx];
let pairLen = arrayLength(&pairBuffer);
let tileEntriesLen = arrayLength(&tileEntries);

for (var j = lid; j < pairCount; j += WG_SIZE) {
let pairIdx = start + j;
if (pairIdx >= pairLen) { break; }

let packed = pairBuffer[pairIdx];
let tileIdx = packed >> 16u;
let localOff = packed & 0xFFFFu;

let entryIdx = tileSplatCounts[tileIdx] + localOff;
if (entryIdx < tileEntriesLen) {
tileEntries[entryIdx] = threadIdx;
}
}
}
`;
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Lightweight place-entries pass: reads (tileIdx, localOffset) pairs written by the
// count+pair-write pass and places splat indices into per-tile entry lists using the
// prefix-summed tile offsets.
//
// Dispatched per-splat (same shape as the count pass). Each thread reads its pair range
// via splatPairStart/splatPairCount and writes to tileEntries at deterministic positions.
//
// Bindings:
//   0 pairBuffer       (read)       packed pairs: (tileIdx << 16) | localOffset
//   1 splatPairStart   (read)       per-splat start index into pairBuffer
//   2 splatPairCount   (read)       per-splat pair count; high bit flags large splats
//                                   that a cooperative companion pass handles instead
//   3 tileSplatCounts  (read)       prefix-summed per-tile start offsets into tileEntries
//   4 tileEntries      (read_write) output: flat per-tile splat-index lists
//   5 sortElementCount (read)       element [0] = number of visible splats (threads)
//
// Writes are deterministic: localOff was assigned by atomicAdd in the count pass, so
// each (tile, localOff) slot has exactly one writer. All indexed accesses are
// bounds-guarded with arrayLength against over-sized indirect dispatches.
export const computeGsplatLocalPlaceEntriesSource = /* wgsl */`

@group(0) @binding(0) var<storage, read> pairBuffer: array<u32>;
@group(0) @binding(1) var<storage, read> splatPairStart: array<u32>;
@group(0) @binding(2) var<storage, read> splatPairCount: array<u32>;
@group(0) @binding(3) var<storage, read> tileSplatCounts: array<u32>;
@group(0) @binding(4) var<storage, read_write> tileEntries: array<u32>;
@group(0) @binding(5) var<storage, read> sortElementCount: array<u32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3u, @builtin(num_workgroups) numWorkgroups: vec3u) {
let threadIdx = gid.y * (numWorkgroups.x * 256u) + gid.x;
let numVisible = sortElementCount[0];
if (threadIdx >= numVisible) {
return;
}

let rawCount = splatPairCount[threadIdx];
// High bit marks large splats handled by the cooperative LargePlaceEntries pass
if (rawCount == 0u || (rawCount & 0x80000000u) != 0u) {
return;
}
let count = rawCount;

let start = splatPairStart[threadIdx];
let pairLen = arrayLength(&pairBuffer);
let tileEntriesLen = arrayLength(&tileEntries);

for (var j: u32 = 0u; j < count; j++) {
let pairIdx = start + j;
if (pairIdx >= pairLen) { break; }

let packed = pairBuffer[pairIdx];
let tileIdx = packed >> 16u;
let localOff = packed & 0xFFFFu;

// tileSplatCounts has been prefix-summed, so it holds the start offset for each tile.
// localOff is the within-tile position assigned by atomicAdd during the count pass.
let entryIdx = tileSplatCounts[tileIdx] + localOff;
if (entryIdx < tileEntriesLen) {
tileEntries[entryIdx] = threadIdx;
}
}
}
`;

This file was deleted.

Loading