Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
mode: BroadcastMode,
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
dataSize: SQLMetric,
buildThreads: SQLMetric): BuildSideRelation = {

val (buildKeys, isNullAware) = mode match {
case mode1: HashedRelationBroadcastMode =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,8 @@ class VeloxMetricsApi extends MetricsApi with Logging {
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"),
"buildThreads" -> SQLMetrics.createMetric(sparkContext, "build threads")
)

override def genColumnarSubqueryBroadcastMetrics(
Expand Down Expand Up @@ -667,7 +668,10 @@ class VeloxMetricsApi extends MetricsApi with Logging {
"numOutputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of output bytes"),
"loadLazyVectorTime" -> SQLMetrics.createNanoTimingMetric(
sparkContext,
"time of loading lazy vectors")
"time of loading lazy vectors"),
"buildHashTableTime" -> SQLMetrics.createTimingMetric(
sparkContext,
"time to build hash table")
)

override def genHashJoinTransformerMetricsUpdater(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
mode: BroadcastMode,
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
dataSize: SQLMetric,
buildThreads: SQLMetric): BuildSideRelation = {

val buildKeys = mode match {
case mode1: HashedRelationBroadcastMode =>
Expand Down Expand Up @@ -851,22 +852,31 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
numOutputRows += serialized.map(_.numRows).sum
dataSize += rawSize

val rawThreads =
math
.ceil(dataSize.value.toDouble / VeloxConfig.get.veloxBroadcastHashTableBuildTargetBytes)
.toInt
val buildThreadsValue = if (rawThreads < 1) 1 else rawThreads
buildThreads += buildThreadsValue

if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
UnsafeColumnarBuildSideRelation(
newOutput,
serialized.flatMap(_.offHeapData().asScala),
mode,
newBuildKeys,
offload)
offload,
buildThreadsValue)
}
} else {
ColumnarBuildSideRelation(
newOutput,
serialized.flatMap(_.onHeapData().asScala).toArray,
mode,
newBuildKeys,
offload)
offload,
buildThreadsValue)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class VeloxConfig(conf: SQLConf) extends GlutenConfig(conf) {
def enableBroadcastBuildOncePerExecutor: Boolean =
getConf(VELOX_BROADCAST_BUILD_HASHTABLE_ONCE_PER_EXECUTOR)

def veloxBroadcastHashTableBuildThreads: Int =
getConf(COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS)
def veloxBroadcastHashTableBuildTargetBytes: Long =
getConf(COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_TARGET_BYTES)

def veloxOrcScanEnabled: Boolean =
getConf(VELOX_ORC_SCAN_ENABLED)
Expand Down Expand Up @@ -206,12 +206,11 @@ object VeloxConfig extends ConfigRegistry {
.intConf
.createOptional

val COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_THREADS =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.broadcastHashTableBuildThreads")
.doc("The number of threads used to build the broadcast hash table. " +
"If not set or set to 0, it will use the default number of threads (available processors).")
.intConf
.createWithDefault(1)
val COLUMNAR_VELOX_BROADCAST_HASH_TABLE_BUILD_TARGET_BYTES =
buildStaticConf("spark.gluten.velox.broadcast.build.targetBytesPerThread")
.doc("It is used to calculate the number of hash table build threads.")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("32MB")

val COLUMNAR_VELOX_ASYNC_TIMEOUT =
buildStaticConf("spark.gluten.sql.columnar.backend.velox.asyncTimeoutOnTaskStopping")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

import io.substrait.proto.JoinRel
Expand Down Expand Up @@ -158,7 +159,7 @@ case class BroadcastHashJoinExecTransformer(
buildBroadcastTableId,
isNullAwareAntiJoin,
bloomFilterPushdownSize,
VeloxConfig.get.veloxBroadcastHashTableBuildThreads
metrics.get("buildHashTableTime")
)
val broadcastRDD = VeloxBroadcastBuildSideRDD(sparkContext, broadcast, context)
// FIXME: Do we have to make build side a RDD?
Expand All @@ -176,4 +177,4 @@ case class BroadcastHashJoinContext(
buildHashTableId: String,
isNullAwareAntiJoin: Boolean = false,
bloomFilterPushdownSize: Long,
broadcastHashTableBuildThreads: Int)
buildHashTableTimeMetric: Option[SQLMetric] = None)
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ object ColumnarBuildSideRelation {
batches: Array[Array[Byte]],
mode: BroadcastMode,
newBuildKeys: Seq[Expression] = Seq.empty,
offload: Boolean = false): ColumnarBuildSideRelation = {
offload: Boolean = false,
buildThreads: Int = 1): ColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
// Bind each key to the build-side output so simple cols become BoundReference
Expand All @@ -66,7 +67,8 @@ object ColumnarBuildSideRelation {
batches,
BroadcastModeUtils.toSafe(boundMode),
newBuildKeys,
offload)
offload,
buildThreads)
}
}

Expand All @@ -75,7 +77,8 @@ case class ColumnarBuildSideRelation(
batches: Array[Array[Byte]],
safeBroadcastMode: SafeBroadcastMode,
newBuildKeys: Seq[Expression],
offload: Boolean)
offload: Boolean,
buildThreads: Int)
extends BuildSideRelation
with Logging
with KnownSizeEstimation {
Expand Down Expand Up @@ -156,6 +159,7 @@ case class ColumnarBuildSideRelation(
broadcastContext: BroadcastHashJoinContext): (Long, ColumnarBuildSideRelation) =
synchronized {
if (hashTableData == 0) {
val startTime = System.nanoTime()
val runtime = Runtimes.contextInstance(
BackendsApiManager.getBackendName,
"ColumnarBuildSideRelation#buildHashTable")
Expand Down Expand Up @@ -215,10 +219,15 @@ case class ColumnarBuildSideRelation(
SubstraitUtil.toNameStruct(newOutput).toByteArray,
broadcastContext.isNullAwareAntiJoin,
broadcastContext.bloomFilterPushdownSize,
broadcastContext.broadcastHashTableBuildThreads
buildThreads
)

jniWrapper.close(serializeHandle)

// Update build hash table time metric
val elapsedTime = System.nanoTime() - startTime
broadcastContext.buildHashTableTimeMetric.foreach(_ += elapsedTime / 1000000)

(hashTableData, this)
} else {
(HashJoinBuilder.cloneHashTable(hashTableData), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ object UnsafeColumnarBuildSideRelation {
batches: Seq[UnsafeByteArray],
mode: BroadcastMode,
newBuildKeys: Seq[Expression] = Seq.empty,
offload: Boolean = false): UnsafeColumnarBuildSideRelation = {
offload: Boolean = false,
buildThreads: Int = 1): UnsafeColumnarBuildSideRelation = {
val boundMode = mode match {
case HashedRelationBroadcastMode(keys, isNullAware) =>
// Bind each key to the build-side output so simple cols become BoundReference
Expand All @@ -71,7 +72,8 @@ object UnsafeColumnarBuildSideRelation {
batches,
BroadcastModeUtils.toSafe(boundMode),
newBuildKeys,
offload)
offload,
buildThreads)
}
}

Expand All @@ -91,7 +93,8 @@ class UnsafeColumnarBuildSideRelation(
private var batches: Seq[UnsafeByteArray],
private var safeBroadcastMode: SafeBroadcastMode,
private var newBuildKeys: Seq[Expression],
private var offload: Boolean)
private var offload: Boolean,
private var buildThreads: Int)
extends BuildSideRelation
with Externalizable
with Logging
Expand All @@ -113,7 +116,7 @@ class UnsafeColumnarBuildSideRelation(

/** needed for serialization. */
def this() = {
this(null, null, null, Seq.empty, false)
this(null, null, null, Seq.empty, false, 1)
}

private[unsafe] def getBatches(): Seq[UnsafeByteArray] = {
Expand All @@ -125,6 +128,7 @@ class UnsafeColumnarBuildSideRelation(
def buildHashTable(broadcastContext: BroadcastHashJoinContext): (Long, BuildSideRelation) =
synchronized {
if (hashTableData == 0) {
val startTime = System.nanoTime()
val runtime = Runtimes.contextInstance(
BackendsApiManager.getBackendName,
"UnsafeColumnarBuildSideRelation#buildHashTable")
Expand Down Expand Up @@ -185,10 +189,15 @@ class UnsafeColumnarBuildSideRelation(
SubstraitUtil.toNameStruct(newOutput).toByteArray,
broadcastContext.isNullAwareAntiJoin,
broadcastContext.bloomFilterPushdownSize,
broadcastContext.broadcastHashTableBuildThreads
buildThreads
)

jniWrapper.close(serializeHandle)

// Update build hash table time metric
val elapsedTime = System.nanoTime() - startTime
broadcastContext.buildHashTableTimeMetric.foreach(_ += elapsedTime / 1000000)

(hashTableData, this)
} else {
(HashJoinBuilder.cloneHashTable(hashTableData), null)
Expand All @@ -205,6 +214,7 @@ class UnsafeColumnarBuildSideRelation(
out.writeObject(batches.toArray)
out.writeObject(newBuildKeys)
out.writeBoolean(offload)
out.writeInt(buildThreads)
}

override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
Expand All @@ -213,6 +223,7 @@ class UnsafeColumnarBuildSideRelation(
kryo.writeClassAndObject(out, batches.toArray)
kryo.writeClassAndObject(out, newBuildKeys)
out.writeBoolean(offload)
out.writeInt(buildThreads)
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
Expand All @@ -221,6 +232,7 @@ class UnsafeColumnarBuildSideRelation(
batches = in.readObject().asInstanceOf[Array[UnsafeByteArray]].toSeq
newBuildKeys = in.readObject().asInstanceOf[Seq[Expression]]
offload = in.readBoolean()
buildThreads = in.readInt()
}

override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
Expand All @@ -229,6 +241,7 @@ class UnsafeColumnarBuildSideRelation(
batches = kryo.readClassAndObject(in).asInstanceOf[Array[UnsafeByteArray]].toSeq
newBuildKeys = kryo.readClassAndObject(in).asInstanceOf[Seq[Expression]]
offload = in.readBoolean()
buildThreads = in.readInt()
}

private def transformProjection: UnsafeProjection = safeBroadcastMode match {
Expand Down
33 changes: 15 additions & 18 deletions cpp/velox/jni/VeloxJniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <jni/JniCommon.h>
#include <velox/connectors/hive/PartitionIdGenerator.h>
#include <velox/exec/OperatorUtils.h>
#include <folly/futures/Future.h>
#include <folly/executors/CPUThreadPoolExecutor.h>

#include <exception>
#include "JniUdf.h"
Expand Down Expand Up @@ -946,7 +948,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
jbyteArray namedStruct,
jboolean isNullAwareAntiJoin,
jlong bloomFilterPushdownSize,
jint broadcastHashTableBuildThreads) {
jint numThreads) {
JNI_METHOD_START
const auto hashTableId = jStringToCString(env, tableId);

Expand Down Expand Up @@ -985,17 +987,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
cb.push_back(ObjectStore::retrieve<ColumnarBatch>(handle));
}

size_t maxThreads = broadcastHashTableBuildThreads > 0
? std::min((size_t)broadcastHashTableBuildThreads, (size_t)32)
: std::min((size_t)std::thread::hardware_concurrency(), (size_t)32);

// Heuristic: Each thread should process at least a certain number of batches to justify parallelism overhead.
// 32 batches is roughly 128k rows, which is a reasonable granularity for a single thread.
constexpr size_t kMinBatchesPerThread = 32;
size_t numThreads = std::min(maxThreads, (handleCount + kMinBatchesPerThread - 1) / kMinBatchesPerThread);
numThreads = std::max((size_t)1, numThreads);

if (numThreads <= 1) {
if (numThreads == 1) {
auto builder = nativeHashTableBuild(
hashJoinKeys,
names,
Expand All @@ -1020,16 +1012,20 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
return gluten::getHashTableObjStore()->save(builder);
}

std::vector<std::thread> threads;

// Use thread pool (executor) instead of creating threads directly
auto executor = VeloxBackend::get()->executor();

std::vector<std::shared_ptr<gluten::HashTableBuilder>> hashTableBuilders(numThreads);
std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> otherTables(numThreads);
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(numThreads);

for (size_t t = 0; t < numThreads; ++t) {
size_t start = (handleCount * t) / numThreads;
size_t end = (handleCount * (t + 1)) / numThreads;

threads.emplace_back([&, t, start, end]() {
// Submit task to thread pool
auto future = folly::via(executor, [&, t, start, end]() {
std::vector<std::shared_ptr<gluten::ColumnarBatch>> threadBatches;
for (size_t i = start; i < end; ++i) {
threadBatches.push_back(cb[i]);
Expand All @@ -1050,11 +1046,12 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
hashTableBuilders[t] = std::move(builder);
otherTables[t] = std::move(hashTableBuilders[t]->uniqueTable());
});

futures.push_back(std::move(future));
}

for (auto& thread : threads) {
thread.join();
}
// Wait for all tasks to complete
folly::collectAll(futures).wait();

auto mainTable = std::move(otherTables[0]);
std::vector<std::unique_ptr<facebook::velox::exec::BaseHashTable>> tables;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ trait SparkPlanExecApi {
mode: BroadcastMode,
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation
dataSize: SQLMetric,
buildThreads: SQLMetric = null): BuildSideRelation

def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
mode.canonicalized
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
mode,
child,
longMetric("numOutputRows"),
longMetric("dataSize"))
longMetric("dataSize"),
longMetric("buildThreads"))
}

val broadcasted = GlutenTimeMetric.millis(longMetric("broadcastTime")) {
Expand Down
Loading