From a07e6f6f81953fa8881b258c1cb3df23a1ad9989 Mon Sep 17 00:00:00 2001 From: John Kew Date: Wed, 22 Apr 2026 14:33:40 -0700 Subject: [PATCH 1/2] Fix CouponList deserialization count validation bug --- hll/include/CouponList-internal.hpp | 17 +++++++- hll/test/CouponListTest.cpp | 60 +++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/hll/include/CouponList-internal.hpp b/hll/include/CouponList-internal.hpp index c92820e2..cef28caf 100644 --- a/hll/include/CouponList-internal.hpp +++ b/hll/include/CouponList-internal.hpp @@ -99,6 +99,13 @@ CouponList* CouponList::newList(const void* bytes, size_t len, const A& al const bool emptyFlag = ((data[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false); const uint32_t couponCount = data[hll_constants::LIST_COUNT_BYTE]; + // Reject LIST counts larger than the fixed LIST capacity. + const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE; + if (couponCount > listCapacity) { + throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount > capacity. Found couponCount: " + + std::to_string(couponCount) + + ", capacity: " + std::to_string(listCapacity)); + } const uint32_t couponsInArray = (compact ? couponCount : (1 << HllUtil::computeLgArrInts(LIST, couponCount, lgK))); const size_t expectedLength = hll_constants::LIST_INT_ARR_START + (couponsInArray * sizeof(uint32_t)); if (len < expectedLength) { @@ -146,11 +153,19 @@ CouponList* CouponList::newList(std::istream& is, const A& allocator) { const bool oooFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::OUT_OF_ORDER_FLAG_MASK) ? true : false); const bool emptyFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false); + const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE]; + // Reject LIST counts larger than the fixed LIST capacity. + const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE; + if (couponCount > listCapacity) { + throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount > capacity. Found couponCount: " + + std::to_string(couponCount) + + ", capacity: " + std::to_string(listCapacity)); + } + ClAlloc cla(allocator); CouponList* sketch = new (cla.allocate(1)) CouponList(lgK, tgtHllType, mode, allocator); using coupon_list_ptr = std::unique_ptr, std::function*)>>; coupon_list_ptr ptr(sketch, sketch->get_deleter()); - const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE]; sketch->couponCount_ = couponCount; sketch->putOutOfOrderFlag(oooFlag); // should always be false for LIST diff --git a/hll/test/CouponListTest.cpp b/hll/test/CouponListTest.cpp index 81a2bff2..ec5ae852 100644 --- a/hll/test/CouponListTest.cpp +++ b/hll/test/CouponListTest.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include "hll.hpp" #include "CouponList.hpp" @@ -180,4 +181,63 @@ TEST_CASE("coupon list: check corrupt stream data", "[coupon_list]") { REQUIRE_THROWS_AS(hll_sketch::deserialize(ss), std::invalid_argument); } +TEST_CASE("coupon list: rejects malformed coupon count in bytes", "[coupon_list]") { + // Reject declared LIST counts beyond the fixed LIST capacity. + uint8_t lgK = 8; + hll_sketch sk1(lgK); + sk1.update(1); + sk1.update(2); + auto sketchBytes = sk1.serialize_compact(); + + const uint8_t malformedCount = 10; + // Keep the input long enough to validate the declared LIST count. + const size_t requiredLen = hll_constants::LIST_INT_ARR_START + + malformedCount * sizeof(uint32_t); + if (sketchBytes.size() < requiredLen) { + sketchBytes.resize(requiredLen, 0); + } + sketchBytes[hll_constants::LIST_COUNT_BYTE] = malformedCount; + + REQUIRE_THROWS_AS( + hll_sketch::deserialize(sketchBytes.data(), sketchBytes.size()), + std::invalid_argument); + REQUIRE_THROWS_AS( + CouponList>::newList( + sketchBytes.data(), sketchBytes.size(), std::allocator()), + std::invalid_argument); +} + +TEST_CASE("coupon list: rejects malformed coupon count in stream", "[coupon_list]") { + uint8_t lgK = 8; + hll_sketch sk1(lgK); + sk1.update(1); + sk1.update(2); + std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary); + sk1.serialize_compact(ss); + + const uint8_t malformedCount = 10; + const size_t requiredLen = hll_constants::LIST_INT_ARR_START + + malformedCount * sizeof(uint32_t); + // Keep the stream long enough to validate the declared LIST count. + ss.seekp(0, std::ios::end); + const auto endPos = static_cast(ss.tellp()); + if (endPos < requiredLen) { + std::vector pad(requiredLen - endPos, 0); + ss.write(pad.data(), pad.size()); + } + + ss.seekp(hll_constants::LIST_COUNT_BYTE); + ss.put(static_cast(malformedCount)); + + ss.clear(); + ss.seekg(0); + REQUIRE_THROWS_AS(hll_sketch::deserialize(ss), std::invalid_argument); + + ss.clear(); + ss.seekg(0); + REQUIRE_THROWS_AS( + CouponList>::newList(ss, std::allocator()), + std::invalid_argument); +} + } /* namespace datasketches */ From 805eaf3c5d96bf6263298880be10da30a604faec Mon Sep 17 00:00:00 2001 From: jihuayu Date: Fri, 5 Jun 2026 21:16:33 +0800 Subject: [PATCH 2/2] Reject full CouponList during deserialization --- hll/include/CouponList-internal.hpp | 12 ++++++------ hll/test/CouponListTest.cpp | 12 ++++++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/hll/include/CouponList-internal.hpp b/hll/include/CouponList-internal.hpp index cef28caf..5b067a59 100644 --- a/hll/include/CouponList-internal.hpp +++ b/hll/include/CouponList-internal.hpp @@ -99,10 +99,10 @@ CouponList* CouponList::newList(const void* bytes, size_t len, const A& al const bool emptyFlag = ((data[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false); const uint32_t couponCount = data[hll_constants::LIST_COUNT_BYTE]; - // Reject LIST counts larger than the fixed LIST capacity. + // Reject LIST counts at or above the fixed LIST capacity. const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE; - if (couponCount > listCapacity) { - throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount > capacity. Found couponCount: " + if (couponCount >= listCapacity) { + throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount >= capacity. Found couponCount: " + std::to_string(couponCount) + ", capacity: " + std::to_string(listCapacity)); } @@ -154,10 +154,10 @@ CouponList* CouponList::newList(std::istream& is, const A& allocator) { const bool emptyFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false); const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE]; - // Reject LIST counts larger than the fixed LIST capacity. + // Reject LIST counts at or above the fixed LIST capacity. const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE; - if (couponCount > listCapacity) { - throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount > capacity. Found couponCount: " + if (couponCount >= listCapacity) { + throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount >= capacity. Found couponCount: " + std::to_string(couponCount) + ", capacity: " + std::to_string(listCapacity)); } diff --git a/hll/test/CouponListTest.cpp b/hll/test/CouponListTest.cpp index ec5ae852..7ad4f700 100644 --- a/hll/test/CouponListTest.cpp +++ b/hll/test/CouponListTest.cpp @@ -182,14 +182,16 @@ TEST_CASE("coupon list: check corrupt stream data", "[coupon_list]") { } TEST_CASE("coupon list: rejects malformed coupon count in bytes", "[coupon_list]") { - // Reject declared LIST counts beyond the fixed LIST capacity. + // Reject declared LIST counts at or beyond the fixed LIST capacity. uint8_t lgK = 8; hll_sketch sk1(lgK); sk1.update(1); sk1.update(2); - auto sketchBytes = sk1.serialize_compact(); + const auto validBytes = sk1.serialize_compact(); - const uint8_t malformedCount = 10; + const uint8_t listCapacity = static_cast(1u << hll_constants::LG_INIT_LIST_SIZE); + const uint8_t malformedCount = listCapacity; + auto sketchBytes = validBytes; // Keep the input long enough to validate the declared LIST count. const size_t requiredLen = hll_constants::LIST_INT_ARR_START + malformedCount * sizeof(uint32_t); @@ -212,10 +214,12 @@ TEST_CASE("coupon list: rejects malformed coupon count in stream", "[coupon_list hll_sketch sk1(lgK); sk1.update(1); sk1.update(2); + + const uint8_t listCapacity = static_cast(1u << hll_constants::LG_INIT_LIST_SIZE); + const uint8_t malformedCount = listCapacity; std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary); sk1.serialize_compact(ss); - const uint8_t malformedCount = 10; const size_t requiredLen = hll_constants::LIST_INT_ARR_START + malformedCount * sizeof(uint32_t); // Keep the stream long enough to validate the declared LIST count.