Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion hll/include/CouponList-internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ CouponList<A>* CouponList<A>::newList(const void* bytes, size_t len, const A& al
const bool emptyFlag = ((data[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false);

const uint32_t couponCount = data[hll_constants::LIST_COUNT_BYTE];
// Reject LIST counts at or above the fixed LIST capacity.
const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE;
if (couponCount >= listCapacity) {
throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount >= capacity. Found couponCount: "
+ std::to_string(couponCount)
+ ", capacity: " + std::to_string(listCapacity));
}
Comment thread
tisonkun marked this conversation as resolved.
const uint32_t couponsInArray = (compact ? couponCount : (1 << HllUtil<A>::computeLgArrInts(LIST, couponCount, lgK)));
const size_t expectedLength = hll_constants::LIST_INT_ARR_START + (couponsInArray * sizeof(uint32_t));
if (len < expectedLength) {
Expand Down Expand Up @@ -146,11 +153,19 @@ CouponList<A>* CouponList<A>::newList(std::istream& is, const A& allocator) {
const bool oooFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::OUT_OF_ORDER_FLAG_MASK) ? true : false);
const bool emptyFlag = ((listHeader[hll_constants::FLAGS_BYTE] & hll_constants::EMPTY_FLAG_MASK) ? true : false);

const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE];
// Reject LIST counts at or above the fixed LIST capacity.
const uint32_t listCapacity = 1u << hll_constants::LG_INIT_LIST_SIZE;
if (couponCount >= listCapacity) {
throw std::invalid_argument("Attempt to deserialize invalid CouponList with couponCount >= capacity. Found couponCount: "
+ std::to_string(couponCount)
+ ", capacity: " + std::to_string(listCapacity));
}
Comment thread
tisonkun marked this conversation as resolved.

ClAlloc cla(allocator);
CouponList<A>* sketch = new (cla.allocate(1)) CouponList<A>(lgK, tgtHllType, mode, allocator);
using coupon_list_ptr = std::unique_ptr<CouponList<A>, std::function<void(HllSketchImpl<A>*)>>;
coupon_list_ptr ptr(sketch, sketch->get_deleter());
const uint32_t couponCount = listHeader[hll_constants::LIST_COUNT_BYTE];
sketch->couponCount_ = couponCount;
sketch->putOutOfOrderFlag(oooFlag); // should always be false for LIST

Expand Down
64 changes: 64 additions & 0 deletions hll/test/CouponListTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <string>
#include <exception>
#include <stdexcept>
#include <vector>

#include "hll.hpp"
#include "CouponList.hpp"
Expand Down Expand Up @@ -180,4 +181,67 @@ TEST_CASE("coupon list: check corrupt stream data", "[coupon_list]") {
REQUIRE_THROWS_AS(hll_sketch::deserialize(ss), std::invalid_argument);
}

TEST_CASE("coupon list: rejects malformed coupon count in bytes", "[coupon_list]") {
// Reject declared LIST counts at or beyond the fixed LIST capacity.
uint8_t lgK = 8;
hll_sketch sk1(lgK);
sk1.update(1);
sk1.update(2);
const auto validBytes = sk1.serialize_compact();

const uint8_t listCapacity = static_cast<uint8_t>(1u << hll_constants::LG_INIT_LIST_SIZE);
const uint8_t malformedCount = listCapacity;
Comment thread
tisonkun marked this conversation as resolved.
auto sketchBytes = validBytes;
// Keep the input long enough to validate the declared LIST count.
const size_t requiredLen = hll_constants::LIST_INT_ARR_START
+ malformedCount * sizeof(uint32_t);
if (sketchBytes.size() < requiredLen) {
sketchBytes.resize(requiredLen, 0);
}
sketchBytes[hll_constants::LIST_COUNT_BYTE] = malformedCount;

REQUIRE_THROWS_AS(
hll_sketch::deserialize(sketchBytes.data(), sketchBytes.size()),
std::invalid_argument);
REQUIRE_THROWS_AS(
CouponList<std::allocator<uint8_t>>::newList(
sketchBytes.data(), sketchBytes.size(), std::allocator<uint8_t>()),
std::invalid_argument);
}

TEST_CASE("coupon list: rejects malformed coupon count in stream", "[coupon_list]") {
uint8_t lgK = 8;
hll_sketch sk1(lgK);
sk1.update(1);
sk1.update(2);

const uint8_t listCapacity = static_cast<uint8_t>(1u << hll_constants::LG_INIT_LIST_SIZE);
const uint8_t malformedCount = listCapacity;
std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);
sk1.serialize_compact(ss);

const size_t requiredLen = hll_constants::LIST_INT_ARR_START
+ malformedCount * sizeof(uint32_t);
// Keep the stream long enough to validate the declared LIST count.
ss.seekp(0, std::ios::end);
const auto endPos = static_cast<size_t>(ss.tellp());
if (endPos < requiredLen) {
std::vector<char> pad(requiredLen - endPos, 0);
ss.write(pad.data(), pad.size());
Comment thread
tisonkun marked this conversation as resolved.
}

ss.seekp(hll_constants::LIST_COUNT_BYTE);
ss.put(static_cast<char>(malformedCount));

ss.clear();
ss.seekg(0);
REQUIRE_THROWS_AS(hll_sketch::deserialize(ss), std::invalid_argument);

ss.clear();
ss.seekg(0);
REQUIRE_THROWS_AS(
CouponList<std::allocator<uint8_t>>::newList(ss, std::allocator<uint8_t>()),
std::invalid_argument);
}

} /* namespace datasketches */
Loading