diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 6d59e114d39..ec0e692aaf4 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -566,12 +566,6 @@ class QueryConfig { static constexpr const char* kSparkJsonIgnoreNullFields = "spark.json_ignore_null_fields"; - /// If true, collect_list aggregate function will ignore nulls in the input. - /// Defaults to true to match Spark's default behavior. Set to false to - /// include nulls (RESPECT NULLS). Introduced in Spark 4.2 (SPARK-55256). - static constexpr const char* kSparkCollectListIgnoreNulls = - "spark.collect_list.ignore_nulls"; - /// The number of local parallel table writer operators per task. static constexpr const char* kTaskWriterCount = "task_writer_count"; @@ -1394,10 +1388,6 @@ class QueryConfig { return get(kSparkJsonIgnoreNullFields, true); } - bool sparkCollectListIgnoreNulls() const { - return get(kSparkCollectListIgnoreNulls, true); - } - bool exprTrackCpuUsage() const { return get(kExprTrackCpuUsage, false); } diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index 0ba1c32fc10..a6c1aca3279 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -54,15 +54,15 @@ General Aggregate Functions ``hash`` cannot be null. -.. spark:function:: collect_list(x) -> array<[same as x]> +.. spark:function:: collect_list(x [, ignoreNulls]) -> array<[same as x]> - Returns an array created from the input ``x`` elements. By default, - ignores null inputs and returns an empty array when all inputs are null. + Returns an array created from the input ``x`` elements. + When ``ignoreNulls`` is ``true`` (default), null inputs are excluded and + an empty array is returned when all inputs are null. - When the configuration property ``spark.collect_list.ignore_nulls`` is set - to ``false``, null values are included in the output array (RESPECT NULLS - behavior). In this mode, an all-null input produces an array of nulls - instead of an empty array. + When ``ignoreNulls`` is ``false`` (RESPECT NULLS), null values are included + in the output array. In this mode, an all-null input produces an array of + nulls instead of an empty array. .. spark:function:: collect_set(x [, ignoreNulls]) -> array<[same as x]> diff --git a/velox/exec/SimpleAggregateAdapter.h b/velox/exec/SimpleAggregateAdapter.h index 24a35955e5e..f9dd1b4ca2b 100644 --- a/velox/exec/SimpleAggregateAdapter.h +++ b/velox/exec/SimpleAggregateAdapter.h @@ -762,10 +762,12 @@ class SimpleAggregateAdapter : public Aggregate { } } + protected: + std::unique_ptr fn_; + + private: std::vector inputDecoded_; DecodedVector intermediateDecoded_; - - std::unique_ptr fn_; }; } // namespace facebook::velox::exec diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp index 4e779a0b5e8..1afa58ee141 100644 --- a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -18,6 +18,7 @@ #include "velox/exec/SimpleAggregateAdapter.h" #include "velox/functions/lib/aggregates/ValueList.h" +#include "velox/vector/ConstantVector.h" using namespace facebook::velox::aggregate; using namespace facebook::velox::exec; @@ -44,14 +45,6 @@ class CollectListAggregate { // aggregation uses the accumulator path, which correctly respects the config. bool ignoreNulls_{true}; - void initialize( - core::AggregationNode::Step /*step*/, - const std::vector& /*argTypes*/, - const TypePtr& /*resultType*/, - const core::QueryConfig& config) { - ignoreNulls_ = config.sparkCollectListIgnoreNulls(); - } - struct AccumulatorType { ValueList elements_; @@ -114,16 +107,40 @@ class CollectListAggregate { }; }; +// Adapter that overrides setConstantInputs to read the ignoreNulls flag. +class CollectListAdapter : public SimpleAggregateAdapter { + public: + using SimpleAggregateAdapter::SimpleAggregateAdapter; + + void setConstantInputs( + const std::vector& constantInputs) override { + if (constantInputs.size() >= 2 && constantInputs[1] != nullptr && + !constantInputs[1]->isNullAt(0)) { + fn_->ignoreNulls_ = + constantInputs[1]->as>()->valueAt(0); + } + } +}; + AggregateRegistrationResult registerCollectList( const std::string& name, bool withCompanionFunctions, bool overwrite) { std::vector> signatures{ + // collect_list(E) -> array(E): default ignoreNulls=true. + exec::AggregateFunctionSignatureBuilder() + .typeVariable("E") + .returnType("array(E)") + .intermediateType("array(E)") + .argumentType("E") + .build(), + // collect_list(E, ignoreNulls) -> array(E): explicit flag. exec::AggregateFunctionSignatureBuilder() .typeVariable("E") .returnType("array(E)") .intermediateType("array(E)") .argumentType("E") + .constantArgumentType("boolean") .build()}; return exec::registerAggregateFunction( name, @@ -133,9 +150,7 @@ AggregateRegistrationResult registerCollectList( const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& config) -> std::unique_ptr { - VELOX_CHECK_EQ( - argTypes.size(), 1, "{} takes at most one argument", name); - return std::make_unique>( + return std::make_unique( step, argTypes, resultType, &config); }, withCompanionFunctions, diff --git a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp index a5935df9b43..58dd4366d19 100644 --- a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp @@ -124,24 +124,28 @@ TEST_F(CollectListAggregateTest, allNullsInput) { {}); } -std::unordered_map makeConfig(bool ignoreNulls) { - return {{"spark.collect_list.ignore_nulls", ignoreNulls ? "true" : "false"}}; +TEST_F(CollectListAggregateTest, explicitIgnoreNullsTrue) { + // 2-arg form with ignoreNulls=true should behave same as 1-arg. + auto input = makeRowVector({makeNullableFlatVector( + {1, 2, std::nullopt, 4, std::nullopt, 6})}); + auto expected = + makeRowVector({makeArrayVectorFromJson({"[1, 2, 4, 6]"})}); + testAggregations( + {input}, + {}, + {"spark_collect_list(c0, true)"}, + {"array_sort(a0)"}, + {expected}); } TEST_F(CollectListAggregateTest, respectNulls) { - // When ignoreNulls is false (RESPECT NULLS), nulls should be included. + // 2-arg form with ignoreNulls=false (RESPECT NULLS). auto input = makeRowVector({makeNullableFlatVector( {1, 2, std::nullopt, 4, std::nullopt, 6})}); auto expected = makeRowVector({makeNullableArrayVector( std::vector>>{ {1, 2, std::nullopt, 4, std::nullopt, 6}})}); - std::vector expectedResult{expected}; - testAggregations( - {input}, - {}, - {"spark_collect_list(c0)"}, - expectedResult, - makeConfig(false)); + testAggregations({input}, {}, {"spark_collect_list(c0, false)"}, {expected}); } TEST_F(CollectListAggregateTest, respectNullsGroupBy) { @@ -153,30 +157,20 @@ TEST_F(CollectListAggregateTest, respectNullsGroupBy) { makeNullableArrayVector( std::vector>>{ {std::nullopt, 1}, {2, std::nullopt, 3}})}); - std::vector expectedResult{expected}; testAggregations( {data}, {"c0"}, - {"spark_collect_list(c1)"}, + {"spark_collect_list(c1, false)"}, {"c0", "a0"}, - expectedResult, - makeConfig(false)); + {expected}); } TEST_F(CollectListAggregateTest, respectNullsAllNulls) { - // When all inputs are null and ignoreNulls is false, output should be an - // array of nulls (not an empty array). auto input = makeRowVector({makeAllNullFlatVector(3)}); auto expected = makeRowVector({makeNullableArrayVector( std::vector>>{ {std::nullopt, std::nullopt, std::nullopt}})}); - std::vector expectedResult{expected}; - testAggregations( - {input}, - {}, - {"spark_collect_list(c0)"}, - expectedResult, - makeConfig(false)); + testAggregations({input}, {}, {"spark_collect_list(c0, false)"}, {expected}); } } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test diff --git a/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp b/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp index 664b2fa6d3c..e52d6e12fe0 100644 --- a/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp +++ b/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp @@ -87,6 +87,8 @@ int main(int argc, char** argv) { // Velox registers a 2-arg collect_set(T, boolean) signature that Spark // doesn't support. The fuzzer may pick this signature and fail. "collect_set", + // Same as collect_set — 2-arg signature not supported by Spark. + "collect_list", "first_ignore_null", "last_ignore_null", "regr_replacement",