diff --git a/hll/include/CouponList-internal.hpp b/hll/include/CouponList-internal.hpp
index c92820e2..5b067a59 100644
--- a/hll/include/CouponList-internal.hpp
+++ b/hll/include/CouponList-internal.hpp
@@ -99,6 +99,13 @@ CouponList* CouponList::newList(const void* bytes, size_t len, const A& al
const bool emptyFlag = ((data[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false);
const uint32_t couponCount = data[hll_constants::LIST_COUNT_BYTE];
+ // Reject LIST counts at or above the fixed LIST capacity.
+ const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE;
+ if (couponCount >= listCapacity) {
+ throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount >= capacity. Found couponCount: "
+ + std::to_string(couponCount)
+ + ", capacity: " + std::to_string(listCapacity));
+ }
const uint32_t couponsInArray = (compact ? couponCount : (1 << HllUtil::computeLgArrInts(LIST, couponCount, lgK)));
const size_t expectedLength = hll_constants::LIST_INT_ARR_START + (couponsInArray * sizeof(uint32_t));
if (len < expectedLength) {
@@ -146,11 +153,19 @@ CouponList* CouponList::newList(std::istream& is, const A& allocator) {
const bool oooFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::OUT_OF_ORDER_FLAG_MASK) ? true : false);
const bool emptyFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false);
+ const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE];
+ // Reject LIST counts at or above the fixed LIST capacity.
+ const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE;
+ if (couponCount >= listCapacity) {
+ throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount >= capacity. Found couponCount: "
+ + std::to_string(couponCount)
+ + ", capacity: " + std::to_string(listCapacity));
+ }
+
ClAlloc cla(allocator);
CouponList* sketch = new (cla.allocate(1)) CouponList(lgK, tgtHllType, mode, allocator);
using coupon_list_ptr = std::unique_ptr, std::function*)>>;
coupon_list_ptr ptr(sketch, sketch->get_deleter());
- const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE];
sketch->couponCount_ = couponCount;
sketch->putOutOfOrderFlag(oooFlag); // should always be false for LIST
diff --git a/hll/test/CouponListTest.cpp b/hll/test/CouponListTest.cpp
index 81a2bff2..7ad4f700 100644
--- a/hll/test/CouponListTest.cpp
+++ b/hll/test/CouponListTest.cpp
@@ -24,6 +24,7 @@
#include
#include
#include
+#include
#include "hll.hpp"
#include "CouponList.hpp"
@@ -180,4 +181,67 @@ TEST_CASE("coupon list: check corrupt stream data", "[coupon_list]") {
REQUIRE_THROWS_AS(hll_sketch::deserialize(ss), std::invalid_argument);
}
+TEST_CASE("coupon list: rejects malformed coupon count in bytes", "[coupon_list]") {
+ // Reject declared LIST counts at or beyond the fixed LIST capacity.
+ uint8_t lgK = 8;
+ hll_sketch sk1(lgK);
+ sk1.update(1);
+ sk1.update(2);
+ const auto validBytes = sk1.serialize_compact();
+
+ const uint8_t listCapacity = static_cast(1u << hll_constants::LG_INIT_LIST_SIZE);
+ const uint8_t malformedCount = listCapacity;
+ auto sketchBytes = validBytes;
+ // Keep the input long enough to validate the declared LIST count.
+ const size_t requiredLen = hll_constants::LIST_INT_ARR_START
+ + malformedCount * sizeof(uint32_t);
+ if (sketchBytes.size() < requiredLen) {
+ sketchBytes.resize(requiredLen, 0);
+ }
+ sketchBytes[hll_constants::LIST_COUNT_BYTE] = malformedCount;
+
+ REQUIRE_THROWS_AS(
+ hll_sketch::deserialize(sketchBytes.data(), sketchBytes.size()),
+ std::invalid_argument);
+ REQUIRE_THROWS_AS(
+ CouponList>::newList(
+ sketchBytes.data(), sketchBytes.size(), std::allocator()),
+ std::invalid_argument);
+}
+
+TEST_CASE("coupon list: rejects malformed coupon count in stream", "[coupon_list]") {
+ uint8_t lgK = 8;
+ hll_sketch sk1(lgK);
+ sk1.update(1);
+ sk1.update(2);
+
+ const uint8_t listCapacity = static_cast(1u << hll_constants::LG_INIT_LIST_SIZE);
+ const uint8_t malformedCount = listCapacity;
+ std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);
+ sk1.serialize_compact(ss);
+
+ const size_t requiredLen = hll_constants::LIST_INT_ARR_START
+ + malformedCount * sizeof(uint32_t);
+ // Keep the stream long enough to validate the declared LIST count.
+ ss.seekp(0, std::ios::end);
+ const auto endPos = static_cast(ss.tellp());
+ if (endPos < requiredLen) {
+ std::vector pad(requiredLen - endPos, 0);
+ ss.write(pad.data(), pad.size());
+ }
+
+ ss.seekp(hll_constants::LIST_COUNT_BYTE);
+ ss.put(static_cast(malformedCount));
+
+ ss.clear();
+ ss.seekg(0);
+ REQUIRE_THROWS_AS(hll_sketch::deserialize(ss), std::invalid_argument);
+
+ ss.clear();
+ ss.seekg(0);
+ REQUIRE_THROWS_AS(
+ CouponList>::newList(ss, std::allocator()),
+ std::invalid_argument);
+}
+
} /* namespace datasketches */