diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index 4352716ebd76..deb3e9e3fbe4 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -49,6 +49,7 @@ add_arrow_test(utility-test SOURCES align_util_test.cc atfork_test.cc + base64_test.cc byte_size_test.cc byte_stream_split_test.cc cache_test.cc diff --git a/cpp/src/arrow/util/base64.h b/cpp/src/arrow/util/base64.h index 5b80e19d896b..a575fee45132 100644 --- a/cpp/src/arrow/util/base64.h +++ b/cpp/src/arrow/util/base64.h @@ -20,6 +20,7 @@ #include #include +#include "arrow/result.h" #include "arrow/util/visibility.h" namespace arrow { @@ -29,7 +30,7 @@ ARROW_EXPORT std::string base64_encode(std::string_view s); ARROW_EXPORT -std::string base64_decode(std::string_view s); +arrow::Result base64_decode(std::string_view s); } // namespace util } // namespace arrow diff --git a/cpp/src/arrow/util/base64_test.cc b/cpp/src/arrow/util/base64_test.cc new file mode 100644 index 000000000000..38f99ea5e6a1 --- /dev/null +++ b/cpp/src/arrow/util/base64_test.cc @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/base64.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { +namespace util { + +TEST(Base64DecodeTest, ValidInputs) { + ASSERT_OK_AND_ASSIGN(auto empty, arrow::util::base64_decode("")); + EXPECT_EQ(empty, ""); + + ASSERT_OK_AND_ASSIGN(auto two_paddings, arrow::util::base64_decode("Zg==")); + EXPECT_EQ(two_paddings, "f"); + + ASSERT_OK_AND_ASSIGN(auto one_padding, arrow::util::base64_decode("Zm8=")); + EXPECT_EQ(one_padding, "fo"); + + ASSERT_OK_AND_ASSIGN(auto no_padding, arrow::util::base64_decode("Zm9v")); + EXPECT_EQ(no_padding, "foo"); + + ASSERT_OK_AND_ASSIGN(auto multiblock, arrow::util::base64_decode("SGVsbG8gd29ybGQ=")); + EXPECT_EQ(multiblock, "Hello world"); +} + +TEST(Base64DecodeTest, BinaryOutput) { + // 'A' maps to index 0 — same zero value used for padding slots + // verifies the 'A' bug is not present + ASSERT_OK_AND_ASSIGN(auto all_A, arrow::util::base64_decode("AAAA")); + EXPECT_EQ(all_A, std::string("\x00\x00\x00", 3)); + + // Arbitrary non-ASCII output bytes + ASSERT_OK_AND_ASSIGN(auto binary, arrow::util::base64_decode("AP8A")); + EXPECT_EQ(binary, std::string("\x00\xff\x00", 3)); +} + +TEST(Base64DecodeTest, InvalidLength) { + ASSERT_RAISES_WITH_MESSAGE( + Invalid, "Invalid: Invalid base64 input: length is not a multiple of 4", + arrow::util::base64_decode("abc")); +} + +TEST(Base64DecodeTest, InvalidCharacters) { + ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab$=")); + + // Non-ASCII byte + std::string non_ascii = std::string("abc") + static_cast(0xFF); + ASSERT_RAISES(Invalid, arrow::util::base64_decode(non_ascii)); + + // Corruption mid-string across multiple blocks + ASSERT_RAISES(Invalid, arrow::util::base64_decode("aGVs$G8gd29ybGQ=")); +} + +TEST(Base64DecodeTest, InvalidPadding) { + // Padding in wrong position within block + ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab=c")); + + // 3 padding characters — exceeds maximum of 2 + ASSERT_RAISES(Invalid, arrow::util::base64_decode("a===")); + + // 4 padding characters + ASSERT_RAISES(Invalid, arrow::util::base64_decode("====")); + + // Padding in non-final block across multiple blocks + ASSERT_RAISES(Invalid, arrow::util::base64_decode("Zm8=Zm8=")); +} + +} // namespace util +} // namespace arrow diff --git a/cpp/src/arrow/vendored/base64.cpp b/cpp/src/arrow/vendored/base64.cpp index 6f53c0524e71..d36f3e21eec1 100644 --- a/cpp/src/arrow/vendored/base64.cpp +++ b/cpp/src/arrow/vendored/base64.cpp @@ -30,7 +30,9 @@ */ #include "arrow/util/base64.h" +#include "arrow/result.h" #include +#include namespace arrow { namespace util { @@ -40,11 +42,6 @@ static const std::string base64_chars = "abcdefghijklmnopqrstuvwxyz" "0123456789+/"; - -static inline bool is_base64(unsigned char c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} - static std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) { std::string ret; int i = 0; @@ -93,38 +90,67 @@ std::string base64_encode(std::string_view string_to_encode) { return base64_encode(bytes_to_encode, in_len); } -std::string base64_decode(std::string_view encoded_string) { +Result base64_decode(std::string_view encoded_string) { size_t in_len = encoded_string.size(); int i = 0; - int j = 0; int in_ = 0; + int padding_count = 0; + int block_padding = 0; + bool padding_started = false; unsigned char char_array_4[4], char_array_3[3]; std::string ret; - while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i ==4) { - for (i = 0; i <4; i++) - char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff; + if (encoded_string.size() % 4 != 0) { + return Status::Invalid("Invalid base64 input: length is not a multiple of 4"); + } - char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + while (in_len--) { + unsigned char c = encoded_string[in_]; - for (i = 0; (i < 3); i++) - ret += char_array_3[i]; - i = 0; + if (c == '=') { + padding_started = true; + padding_count++; + + if (padding_count > 2) { + return Status::Invalid("Invalid base64 input: too many padding characters"); + } + + char_array_4[i++] = 0; + } else { + if (padding_started) { + return Status::Invalid("Invalid base64 input: padding characters must be at the end"); + } + + if (base64_chars.find(c) == std::string::npos) { + return Status::Invalid( + "Invalid base64 input: contains non-base64 byte at position " + + std::to_string(in_)); + } + + char_array_4[i++] = c; } - } - if (i) { - for (j = 0; j < i; j++) - char_array_4[j] = base64_chars.find(char_array_4[j]) & 0xff; + in_++; + + if (i == 4) { + for (i = 0; i < 4; i++) { + if (char_array_4[i] != 0) { + char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff; + } + } + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + block_padding = padding_count; - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + for (i = 0; i < 3 - block_padding; i++) { + ret += char_array_3[i]; + } - for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; + i = 0; + } } return ret;