diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index c0b8cfd66be1..accf6488045d 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -96,12 +96,24 @@ public Predicate convert(Filter filter, boolean ignoreFailure) {
         }
     }
 
+    private boolean isNaN(Object value) {
+        if (value instanceof Float) {
+            return Float.isNaN((Float) value);
+        } else if (value instanceof Double) {
+            return Double.isNaN((Double) value);
+        }
+        return false;
+    }
+
     public Predicate convert(Filter filter) {
         if (filter instanceof EqualTo) {
             EqualTo eq = (EqualTo) filter;
-            // TODO deal with isNaN
             int index = fieldIndex(eq.attribute());
             Object literal = convertLiteral(index, eq.value());
+            if (isNaN(literal)) {
+                // NaN != NaN, so equality with NaN should never match
+                return PredicateBuilder.alwaysFalse();
+            }
             return builder.equal(index, literal);
         } else if (filter instanceof EqualNullSafe) {
             EqualNullSafe eq = (EqualNullSafe) filter;
@@ -116,21 +128,37 @@ public Predicate convert(Filter filter) {
             GreaterThan gt = (GreaterThan) filter;
             int index = fieldIndex(gt.attribute());
             Object literal = convertLiteral(index, gt.value());
+            if (isNaN(literal)) {
+                // Any comparison with NaN is false
+                return PredicateBuilder.alwaysFalse();
+            }
             return builder.greaterThan(index, literal);
         } else if (filter instanceof GreaterThanOrEqual) {
             GreaterThanOrEqual gt = (GreaterThanOrEqual) filter;
             int index = fieldIndex(gt.attribute());
             Object literal = convertLiteral(index, gt.value());
+            if (isNaN(literal)) {
+                // Any comparison with NaN is false
+                return PredicateBuilder.alwaysFalse();
+            }
             return builder.greaterOrEqual(index, literal);
         } else if (filter instanceof LessThan) {
             LessThan lt = (LessThan) filter;
             int index = fieldIndex(lt.attribute());
             Object literal = convertLiteral(index, lt.value());
+            if (isNaN(literal)) {
+                // Any comparison with NaN is false
+                return PredicateBuilder.alwaysFalse();
+            }
             return builder.lessThan(index, literal);
         } else if (filter instanceof LessThanOrEqual) {
             LessThanOrEqual lt = (LessThanOrEqual) filter;
             int index = fieldIndex(lt.attribute());
             Object literal = convertLiteral(index, lt.value());
+            if (isNaN(literal)) {
+                // Any comparison with NaN is false
+                return PredicateBuilder.alwaysFalse();
+            }
             return builder.lessOrEqual(index, literal);
         } else if (filter instanceof In) {
             In in = (In) filter;
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
index 1493dfc49c76..47353b1aa426 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala
@@ -44,13 +44,25 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
     }
   }
 
+  private def isNaN(value: Any): Boolean = {
+    value match {
+      case f: Float => f.isNaN
+      case d: Double => d.isNaN
+      case _ => false
+    }
+  }
+
   private def convert(sparkPredicate: SparkPredicate): Predicate = {
     sparkPredicate.name() match {
       case EQUAL_TO =>
         sparkPredicate match {
           case BinaryPredicate(transform, literal) =>
-            // TODO deal with isNaN
-            builder.equal(transform, literal)
+            if (isNaN(literal)) {
+              // NaN != NaN, so equality with NaN should never match
+              PredicateBuilder.alwaysFalse()
+            } else {
+              builder.equal(transform, literal)
+            }
           case _ =>
             throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
         }
@@ -70,7 +82,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
       case GREATER_THAN =>
         sparkPredicate match {
           case BinaryPredicate(transform, literal) =>
-            builder.greaterThan(transform, literal)
+            if (isNaN(literal)) {
+              // Any comparison with NaN is false
+              PredicateBuilder.alwaysFalse()
+            } else {
+              builder.greaterThan(transform, literal)
+            }
           case _ =>
             throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
         }
@@ -78,7 +95,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
       case GREATER_THAN_OR_EQUAL =>
         sparkPredicate match {
           case BinaryPredicate(transform, literal) =>
-            builder.greaterOrEqual(transform, literal)
+            if (isNaN(literal)) {
+              // Any comparison with NaN is false
+              PredicateBuilder.alwaysFalse()
+            } else {
+              builder.greaterOrEqual(transform, literal)
+            }
           case _ =>
             throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
         }
@@ -86,7 +108,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
      case LESS_THAN =>
        sparkPredicate match {
          case BinaryPredicate(transform, literal) =>
-            builder.lessThan(transform, literal)
+            if (isNaN(literal)) {
+              // Any comparison with NaN is false
+              PredicateBuilder.alwaysFalse()
+            } else {
+              builder.lessThan(transform, literal)
+            }
           case _ =>
             throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
         }
@@ -94,7 +121,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
       case LESS_THAN_OR_EQUAL =>
         sparkPredicate match {
           case BinaryPredicate(transform, literal) =>
-            builder.lessOrEqual(transform, literal)
+            if (isNaN(literal)) {
+              // Any comparison with NaN is false
+              PredicateBuilder.alwaysFalse()
+            } else {
+              builder.lessOrEqual(transform, literal)
+            }
           case _ =>
             throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
         }
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala
index f7e0bba63f14..c5e714afaab0 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala
@@ -127,6 +127,13 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
     checkAnswer(sql(s"SELECT float_col from test_tbl WHERE $filter"), Seq(Row(1.0f)))
     assert(scanFilesCount(filter) == 1)
 
+    // Test NaN handling - equality with NaN should return AlwaysFalse
+    val nanFilter = "float_col = CAST('NaN' AS FLOAT)"
+    val nanPredicate = converter.convert(v2Filter(nanFilter)).get
+    assert(
+      nanPredicate.equals(PredicateBuilder.alwaysFalse()),
+      "NaN equality should return AlwaysFalse")
+
     filter = "double_col = 1.0"
     actual = converter.convert(v2Filter(filter)).get
     assert(actual.equals(builder.equal(6, 1.0d)))
@@ -507,6 +514,85 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
     assert(filesScanned == 4, s"Expected 4 files but scanned $filesScanned files")
   }
 
AlwaysFalse") { + // Test float_col = NaN should always return false (no matching rows) + val filter1 = "float_col = CAST('NaN' AS FLOAT)" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(PredicateBuilder.alwaysFalse())) + + // Verify no files are scanned (AlwaysFalse should skip all files) + val filesScanned1 = scanFilesWithPredicate(predicate1) + assert( + filesScanned1 == 0, + s"Expected 0 files for NaN equality but scanned $filesScanned1 files") + + // Test double_col = NaN should always return false + val filter2 = "double_col = CAST('NaN' AS DOUBLE)" + val predicate2 = converter.convert(v2Filter(filter2)).get + assert(predicate2.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned2 = scanFilesWithPredicate(predicate2) + assert( + filesScanned2 == 0, + s"Expected 0 files for NaN equality but scanned $filesScanned2 files") + } + + test("V2Filter: GreaterThan with NaN should return AlwaysFalse") { + // Test float_col > NaN should always return false + val filter1 = "float_col > CAST('NaN' AS FLOAT)" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned1 = scanFilesWithPredicate(predicate1) + assert( + filesScanned1 == 0, + s"Expected 0 files for NaN comparison but scanned $filesScanned1 files") + + // Test double_col > NaN should always return false + val filter2 = "double_col > CAST('NaN' AS DOUBLE)" + val predicate2 = converter.convert(v2Filter(filter2)).get + assert(predicate2.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned2 = scanFilesWithPredicate(predicate2) + assert( + filesScanned2 == 0, + s"Expected 0 files for NaN comparison but scanned $filesScanned2 files") + } + + test("V2Filter: LessThanOrEqual with NaN should return AlwaysFalse") { + // Test float_col <= NaN should always return false + val filter1 = "float_col <= CAST('NaN' AS FLOAT)" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned1 = scanFilesWithPredicate(predicate1) + assert( + filesScanned1 == 0, + s"Expected 0 files for NaN comparison but scanned $filesScanned1 files") + } + + test("V2Filter: float and double normal operations not affected by NaN handling") { + // Verify that normal float/double queries still work correctly + val filter1 = "float_col = 1.0" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(builder.equal(5, 1.0f))) + checkAnswer(sql(s"SELECT float_col FROM test_tbl WHERE $filter1"), Seq(Row(1.0f))) + + val filter2 = "double_col > 2.0" + val predicate2 = converter.convert(v2Filter(filter2)).get + assert(predicate2.equals(builder.greaterThan(6, 2.0d))) + checkAnswer( + sql(s"SELECT double_col FROM test_tbl WHERE $filter2 ORDER BY double_col"), + Seq(Row(3.0d), Row(4.0d))) + + val filter3 = "float_col <= 3.0" + val predicate3 = converter.convert(v2Filter(filter3)).get + assert(predicate3.equals(builder.lessOrEqual(5, 3.0f))) + checkAnswer( + sql(s"SELECT float_col FROM test_tbl WHERE $filter3 ORDER BY float_col"), + Seq(Row(1.0f), Row(2.0f), Row(3.0f))) + } + private def v2Filter(str: String, tableName: String = "test_tbl"): SparkPredicate = { val condition = sql(s"SELECT * FROM $tableName WHERE $str").queryExecution.optimizedPlan .collectFirst { case f: Filter => f }