add adaptive batch size heuristic for filtered search#309
add adaptive batch size heuristic for filtered search#309yuejiaointel wants to merge 13 commits intomainfrom
Conversation
rfsaliev
left a comment
There was a problem hiding this comment.
Thank you for the good proposal.
Requested changes:
Please apply such improvements to
range_search()and in `vamana_index_impl.h as well
Suggestions:
There are some performance related suggestions in comments.
But during the review, I found, that: compute_filtered_batch_size() logic is prediction of further amount of processing based on previous processing results and requested amount of matches aka:
PredictFurtherProcessing(processed, hits, goal)
So, I would declare this function more generic, and move it to utilities header with more common signature and reuse in vamana_index_impl.h as well:
In such case,
% max_batch_sizeoperation should be applied outside of this function
/// @param processed - number of already processed elements (total_checked)
/// @param hits - number of matched elements (found)
/// @param goal - number of requested elements to be matched (needed)
/// @param hint - result to be returned if prediction is failed, e.g. other params == 0
size_t predict_further_processing(size_t processed, size_t hits, size_t goal, size_t hint) {
if (processed * hits * goal == 0) {
return hint;
}
// use prediction formula below
...
}| @@ -136,6 +153,8 @@ class DynamicVamanaIndexImpl { | |||
| } | |||
| } | |||
| } | |||
| batch_size = | |||
| compute_filtered_batch_size(found, k, total_checked, batch_size); | |||
There was a problem hiding this comment.
Good idea, but, from performance perspective, I would slightly change the code:
- Compute the batch size at the beginning of the
do-whileloop - it will avoid computation whenfound==k - Increment
total_checkedout-of theforloop. - It might make sense to set initial batch size the
maxofkandsearch_window_size
E.g.
| size_t total_checked = 0; | |
| auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); | |
| do { | |
| batch_size = | |
| compute_filtered_batch_size(found, k, total_checked, batch_size); | |
| iterator.next(batch_size); | |
| for (auto& neighbor : iterator.results()) { | |
| if (filter->is_member(neighbor.id())) { | |
| result.set(neighbor, i, found); | |
| found++; | |
| if (found == k) { | |
| break; | |
| } | |
| } | |
| } | |
| total_checked += iterator.size(); | |
There was a problem hiding this comment.
Thx, added these change
| double hit_rate = static_cast<double>(found) / total_checked; | ||
| return static_cast<size_t>((needed - found) / hit_rate); |
There was a problem hiding this comment.
I would also try to improve performance here:
- FP64 computation is not very performant
- Computation precision is not very important here
- There is potential issues in SVS BatchIterator in case of huge batch size
So, I would use the following formula:
hit_rate_inv = 1 / hit_rate = checked / foundresult = (needed - found) / hit_rate = (needed - found) * hit_rate_inv = needed * checked / found - checked- The formula
needed * checked / found - checkedis most precise, but there is the bigger risk of overflow for hugeneededandcheckedvalues
| double hit_rate = static_cast<double>(found) / total_checked; | |
| return static_cast<size_t>((needed - found) / hit_rate); | |
| auto hit_rate = total_checked / found + 1; // found == 0 is handled above; +1 to increase result eliminating INT precision issues | |
| return (needed - found) * hit_rate % max_batch_size; // max_batch_size - constant |
Alternative (assuming, that FP32 is fast enough):
| double hit_rate = static_cast<double>(found) / total_checked; | |
| return static_cast<size_t>((needed - found) / hit_rate); | |
| float new_batch_size = static_cast<float>(needed) * total_checked / found - total_checked; | |
| return static_cast<size_t>(new_batch_size) % max_batch_size; |
There was a problem hiding this comment.
thx added, probably need to run some benchmarks before knowing exact performance
- Rename compute_filtered_batch_size to predict_further_processing and move to svs_runtime_utils.h for reuse - Use float arithmetic instead of double for hit rate calculation - Compute batch size at loop start to avoid unnecessary computation - Use iterator.size() instead of per-element increment for total_checked - Initial batch size = max(k, search_window_size) - Apply adaptive batch size to vamana_index_impl.h filtered search
- Cap batch size with std::min instead of modulo to avoid SIGFPE - Add comments explaining adaptive batch sizing logic
769bcf5 to
ee06f00
Compare
rfsaliev
left a comment
There was a problem hiding this comment.
It seems like max_batch_size calculation issue.
bindings/cpp/src/vamana_index_impl.h
Outdated
| // Use adaptive batch sizing: start with at least k candidates, | ||
| // then adjust based on observed filter hit rate. | ||
| auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); | ||
| const auto max_batch_size = batch_size; |
There was a problem hiding this comment.
IMHO, the max_batch_size value should be (compile-time?) constant based on the generic SVS Vamana performance instead of current k or search_window_size.
For example:
k == search_window_size == 10filter->is_member()returnstruefor 10% of results- after first iteration, the
predict_function will return(10 - 1) * 10 / 1 == 90 - but next
batch_sizewill be limited tomax_batch_size == 10 - So, we will have 10 small iterations instead of 1 big enough
There was a problem hiding this comment.
thx for the suggestion and agreed, the cap is too restrictive, removed the cap entirely and added a filter_stop early exist heuristic instead :
if hit rate falls below thresold (set by user) after getting some hits, we give up and return empty, and iterator should be able to handle large batch size by growing the search buffer
There was a problem hiding this comment.
there are some discussions about this during the benchmark results discussion , pulled you in to the chat
- Remove max_batch_size cap that limited adaptive sizing effectiveness - Add filter_stop param to SearchParams (default 0 = never give up) - Add should_stop_filtered_search() helper in svs_runtime_utils.h - If hit rate falls below filter_stop after first round, return empty so caller can fall back to exact search
Verifies that search with filter_stop=0.5 gives up and returns unspecified results when hit rate (~10%) is below threshold.
| iterator.next(k); | ||
| batch_size = | ||
| predict_further_processing(total_checked, found, k, batch_size); | ||
| iterator.next(batch_size); |
There was a problem hiding this comment.
What will happen on the second iteration in case if:
- filter_stop = 0.0
- batch_size = k = 100
- found = 1
?
How big bach_size will be here?
There was a problem hiding this comment.
batch size will be (100-1)*100/1 = 9900, with 1% hit to find remaining 99 results, we need 9900 more, is that too large?
There was a problem hiding this comment.
@ibhati , can you please clarify if batch_size=9900 is suitable for Vamana BatchIterator?
Thank you.
There was a problem hiding this comment.
checked with Ishwar and was told max size should be up to the number of vectors in the index, changed, thx for this question!
Enables early exit by default so OpenSearch can test the heuristic without plumbing a new search parameter through the stack.
Batch size can never exceed the index size since there are no more vectors to check beyond that.
Add max_batch_size parameter instead of capping at each call site.
Keep early exit opt-in only. OpenSearch can set filter_stop=0.01 when ready to test the heuristic.
| // Selective search with IDSelector | ||
| auto old_sp = impl_->get_search_parameters(); | ||
| impl_->set_search_parameters(sp); | ||
| const float filter_stop = params ? params->filter_stop : 0.0f; |
There was a problem hiding this comment.
It seems like this construction enforce user to always provide proper filter_stop value in case when user want to configure SearchParams.
To eliminate this issue, I would recommend to initialize filter_stop field in SearchParams with Unspecify<float> and use set_if_specified() here:
| const float filter_stop = params ? params->filter_stop : 0.0f; | |
| float filter_stop = svs_default_filter_stop_defined_somewhere; | |
| if (params) { | |
| set_if_specified(filter_stop, params->filter_stop); | |
| } |
| // If the hit rate after the first round falls below this threshold, | ||
| // stop and return empty results (caller can fall back to exact search). | ||
| // Default 0 means never give up. | ||
| float filter_stop = 0.0f; |
There was a problem hiding this comment.
| float filter_stop = 0.0f; | |
| float filter_stop = Unspecify<float>(); |
| size_t found = 0; | ||
| size_t total_checked = 0; | ||
| auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); | ||
| const auto max_batch_size = impl_->size(); |
There was a problem hiding this comment.
This can be moved out of the search_closure
bindings/cpp/src/svs_runtime_utils.h
Outdated
| // If no hits yet, returns `hint` unchanged. | ||
| // Result is capped at `max_batch_size` (e.g., number of vectors in the index). | ||
| inline size_t predict_further_processing( | ||
| size_t processed, size_t hits, size_t goal, size_t hint, size_t max_batch_size |
There was a problem hiding this comment.
| size_t processed, size_t hits, size_t goal, size_t hint, size_t max_batch_size | |
| size_t processed, size_t hits, size_t goal, size_t hint, size_t max_value |
There was a problem hiding this comment.
thx, implemented these changes
…ax_value - Use Unspecify<float>() for filter_stop default, set_if_specified pattern - Move max_batch_size (impl size) out of search_closure - Rename max_batch_size to max_value in predict_further_processing
Currently the filtered k-NN search loop uses batch_size = k when calling iterator.next(). When the filter is restrictive (e.g., 1% of IDs pass), this results in many expensive graph traversal rounds to collect enough valid results.
This PR introduces a heuristic that adapts the batch size based on observed filter hit rate:
For example, with k=10 and a 10% filter pass rate: instead of ~100 rounds of 10 candidates, it converges in ~2 rounds.