Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benches/agg_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ fn terms_7(index: &Index) {
});
execute_agg(index, agg_req);
}

fn terms_all_unique(index: &Index) {
let agg_req = json!({
"my_texts": { "terms": { "field": "text_all_unique_terms" } },
Expand Down
1 change: 1 addition & 0 deletions src/aggregation/bucket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mod histogram;
mod range;
mod term_agg;
mod term_missing_agg;
mod term_ord_to_str_cache;

use std::collections::HashMap;
use std::fmt;
Expand Down
212 changes: 175 additions & 37 deletions src/aggregation/bucket/term_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use columnar::{
NumericalValue, StrColumn,
};
use common::{BitSet, TinySet};
use rustc_hash::FxHashMap;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};

use super::term_ord_to_str_cache::{StringArena, StringRef, TermOrdToStrCache};
use super::{CustomOrder, Order, OrderTarget};
use crate::aggregation::agg_data::{
build_segment_agg_collectors, AggRefNode, AggregationsSegmentCtx,
Expand Down Expand Up @@ -397,6 +398,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
} else if is_top_level && max_term_id < MAX_NUM_TERMS_FOR_VEC {
Expand All @@ -408,6 +410,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
} else if max_term_id < 8_000_000 && is_top_level {
Expand All @@ -422,6 +425,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
} else {
Expand All @@ -435,6 +439,7 @@ pub(crate) fn build_segment_term_collector(
bucket_id_provider,
max_term_id,
terms_req_data,
term_ord_cache: None,
};
Ok(Box::new(collector))
}
Expand Down Expand Up @@ -470,6 +475,9 @@ trait TermAggregationMap: Clone + Debug + 'static {
/// Returns the term aggregation as a vector of (term_id, bucket) pairs,
/// in any order.
fn into_vec(self) -> Vec<(u64, Bucket)>;

/// Collects all term ordinals present in this map into the given set.
///
/// Ordinals are added to `set` on top of whatever it already contains, so a
/// single set can be filled from several maps in turn to obtain the union of
/// their term ordinals.
fn collect_term_ords(&self, set: &mut FxHashSet<u64>);
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -622,6 +630,20 @@ impl TermAggregationMap for PagedTermMap {

Self { pages, mem_usage }
}

/// Adds the term ordinal of every bit set in the presence sets of each
/// allocated page to `set`.
///
/// The ordinal is rebuilt from its position:
/// `(page_idx << PAGE_SHIFT) + presence_idx * 64 + bit`.
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
    for (page_idx, page_slot) in self.pages.iter().enumerate() {
        // Pages that were never allocated contribute no ordinals.
        let Some(page) = page_slot else {
            continue;
        };
        let page_start = (page_idx << PAGE_SHIFT) as u64;
        for (presence_idx, &presence) in page.presence.iter().enumerate() {
            let presence_start = presence_idx * 64;
            set.extend(
                presence
                    .into_iter()
                    .map(|bit| page_start + (presence_start + bit as usize) as u64),
            );
        }
    }
}
}

impl TermAggregationMap for HashMapTermBuckets {
Expand All @@ -648,6 +670,10 @@ impl TermAggregationMap for HashMapTermBuckets {
fn new(_max_term_id: u64, _bucket_id_provider: &mut BucketIdProvider) -> Self {
Self::default()
}

/// Adds every key of the bucket map (the term ordinals) to `set`.
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
    for &term_ord in self.bucket_map.keys() {
        set.insert(term_ord);
    }
}
}

/// An optimized term map implementation for a compact set of term ordinals.
Expand Down Expand Up @@ -704,6 +730,14 @@ impl TermAggregationMap for VecTermBucketsNoAgg {
.collect(),
}
}

/// Adds the index of every slot with a non-zero count to `set`.
///
/// The position inside `buckets` is used as the term ordinal.
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
    let non_empty_ords = self
        .buckets
        .iter()
        .enumerate()
        .filter_map(|(term_ord, &doc_count)| (doc_count > 0).then_some(term_ord as u64));
    set.extend(non_empty_ords);
}
}

/// An optimized term map implementation for a compact set of term ordinals.
Expand Down Expand Up @@ -753,6 +787,63 @@ impl TermAggregationMap for VecTermBuckets {
.collect(),
}
}

/// Adds the index of every bucket whose `count` is non-zero to `set`.
///
/// The position inside `buckets` is used as the term ordinal.
fn collect_term_ords(&self, set: &mut FxHashSet<u64>) {
    set.extend(
        self.buckets
            .iter()
            .enumerate()
            .filter(|&(_, bucket)| bucket.count > 0)
            .map(|(term_ord, _)| term_ord as u64),
    );
}
}

/// Builds a [`TermOrdToStrCache`] covering the term ordinals used by all
/// `parent_buckets`.
///
/// Steps:
/// 1. Gather the union of term ordinals from every parent bucket.
/// 2. Drop the missing-value sentinel (`missing_value_for_accessor`), if any.
/// 3. Sort the ordinals and drop the highest one if it is not smaller than
///    the number of terms in `dictionary`.
/// 4. Resolve the remaining ordinals through `dictionary`, storing each term
///    in a `StringArena`.
/// 5. Convert the request's `missing` key (if set) to an `IntermediateKey`.
///
/// # Errors
///
/// Propagates any I/O error returned by `Dictionary::sorted_ords_to_term_cb`.
///
/// # Panics
///
/// Panics if a resolved term is not valid UTF-8, or if
/// `sorted_ords_to_term_cb` returns `false` (asserted as `all_found`).
fn build_term_ord_cache<TermMap: TermAggregationMap>(
    parent_buckets: &[TermMap],
    dictionary: &Dictionary,
    term_req: &TermsAggReqData,
) -> std::io::Result<TermOrdToStrCache> {
    // Initial sizing guess of 64 ordinals per parent bucket.
    let mut seen_ords: FxHashSet<u64> =
        FxHashSet::with_capacity_and_hasher(parent_buckets.len() * 64, FxBuildHasher);
    parent_buckets
        .iter()
        .for_each(|term_map| term_map.collect_term_ords(&mut seen_ords));

    // The sentinel stands for "missing" and is removed before resolving terms.
    if let Some(sentinel) = term_req.missing_value_for_accessor {
        seen_ords.remove(&sentinel);
    }

    let mut term_ords: Vec<u64> = seen_ords.into_iter().collect();
    term_ords.sort_unstable();

    // NOTE(review): only the single highest ordinal is checked against the
    // dictionary size; presumably at most one out-of-range ordinal can occur
    // here -- confirm.
    let has_out_of_range_tail = term_ords
        .last()
        .is_some_and(|&highest| highest >= dictionary.num_terms() as u64);
    if has_out_of_range_tail {
        term_ords.pop();
    }

    let mut string_refs: Vec<StringRef> = Vec::with_capacity(term_ords.len());
    let mut string_arena = StringArena::default();
    let all_found: bool = dictionary.sorted_ords_to_term_cb(&term_ords, |raw_term| {
        let term = std::str::from_utf8(raw_term).expect("could not convert to str");
        let term_ref = string_arena.register_str(term);
        string_refs.push(term_ref);
    })?;
    assert!(all_found);

    let missing_key: Option<IntermediateKey> = match term_req.req.missing.as_ref() {
        Some(Key::Str(text)) => Some(IntermediateKey::Str(text.clone())),
        Some(Key::F64(value)) => Some(IntermediateKey::F64(*value)),
        Some(Key::U64(value)) => Some(IntermediateKey::U64(*value)),
        Some(Key::I64(value)) => Some(IntermediateKey::I64(*value)),
        None => None,
    };

    Ok(TermOrdToStrCache::new(
        term_ords,
        string_refs,
        string_arena,
        missing_key,
    ))
}

/// The collector puts values from the fast field into the correct buckets and does a conversion to
Expand All @@ -765,6 +856,7 @@ struct SegmentTermCollector<TermMap: TermAggregationMap, B: SubAggBuffer> {
bucket_id_provider: BucketIdProvider,
max_term_id: u64,
terms_req_data: TermsAggReqData,
term_ord_cache: Option<TermOrdToStrCache>,
}

pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
Expand All @@ -783,6 +875,17 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
) -> crate::Result<()> {
// TODO: avoid prepare_max_bucket here and handle empty buckets.
self.prepare_max_bucket(bucket, agg_data)?;

if self.terms_req_data.column_type == ColumnType::Str && self.term_ord_cache.is_none() {
if let Some(str_dict_column) = &self.terms_req_data.str_dict_column {
self.term_ord_cache = Some(build_term_ord_cache(
&self.parent_buckets,
str_dict_column.dictionary(),
&self.terms_req_data,
)?);
}
}

let bucket = std::mem::replace(
&mut self.parent_buckets[bucket as usize],
TermMap::new(0, &mut self.bucket_id_provider),
Expand All @@ -797,6 +900,7 @@ impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentAggregationCollector
.map(BufferedSubAggs::get_sub_agg_collector),
bucket,
agg_data,
self.term_ord_cache.as_ref(),
)?;
results.push(name, IntermediateAggregationResult::Bucket(bucket))?;
Ok(())
Expand Down Expand Up @@ -957,6 +1061,7 @@ where
mut sub_agg_collector: Option<&mut dyn SegmentAggregationCollector>,
term_buckets: TermMap,
agg_data: &AggregationsSegmentCtx,
term_ord_cache: Option<&TermOrdToStrCache>,
) -> crate::Result<IntermediateBucketResult> {
let mut entries: Vec<(u64, Bucket)> = term_buckets.into_vec();

Expand Down Expand Up @@ -1005,43 +1110,76 @@ where
.map(|el| el.dictionary())
.unwrap_or_else(|| &fallback_dict);

if let Some((intermediate_key, bucket)) = extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(intermediate_key, intermediate_entry);
}

// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);

let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();
if let Some(cache) = term_ord_cache {
// Use cached term resolution: missing value is handled via the cache.
if let Some(missing_sentinel) = term_req.missing_value_for_accessor {
if let Some(pos) = entries.iter().position(|(tid, _)| *tid == missing_sentinel)
{
let (_tid, bucket) = entries.swap_remove(pos);
if let Some(missing_key) = cache.missing_key() {
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(missing_key.clone(), intermediate_entry);
}
}
}

let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
for (term_ord, bucket) in entries {
if let Some(term_str) = cache.get(term_ord) {
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)?;
dict.insert(
IntermediateKey::Str(term_str.to_string()),
intermediate_entry,
);
}
}
} else {
if let Some((intermediate_key, bucket)) =
extract_missing_value(&mut entries, term_req)
{
let intermediate_entry = into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;

let mut intermediate_entry_it = intermediate_entries.into_iter();

term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
})?;
)?;
dict.insert(intermediate_key, intermediate_entry);
}

// Sort by term ord
entries.sort_unstable_by_key(|bucket| bucket.0);

let (term_ids, buckets): (Vec<u64>, Vec<Bucket>) = entries.into_iter().unzip();

let intermediate_entries: Vec<IntermediateTermBucketEntry> = buckets
.into_iter()
.map(|bucket| {
into_intermediate_bucket_entry(
bucket,
reborrow_opt_collector(&mut sub_agg_collector),
agg_data,
)
})
.collect::<crate::Result<_>>()?;

let mut intermediate_entry_it = intermediate_entries.into_iter();

term_dict.sorted_ords_to_term_cb(&term_ids[..], |term| {
let intermediate_entry = intermediate_entry_it.next().unwrap();
dict.insert(
IntermediateKey::Str(
String::from_utf8(term.to_vec()).expect("could not convert to String"),
),
intermediate_entry,
);
})?;
}

if term_req.req.min_doc_count == 0 {
// TODO: Handle rev streaming for descending sorting by keys
Expand Down Expand Up @@ -1162,13 +1300,13 @@ where
impl<TermMap: TermAggregationMap, B: SubAggBuffer> SegmentTermCollector<TermMap, B> {
#[inline]
fn collect_terms_with_docs(
iter: impl Iterator<Item = (crate::DocId, u64)>,
doc_term_ord_iter: impl Iterator<Item = (crate::DocId, u64)>,
term_buckets: &mut TermMap,
bucket_id_provider: &mut BucketIdProvider,
sub_agg: &mut BufferedSubAggs<B>,
) {
for (doc, term_id) in iter {
let bucket_id = term_buckets.term_entry(term_id, bucket_id_provider);
for (doc, term_ord) in doc_term_ord_iter {
let bucket_id = term_buckets.term_entry(term_ord, bucket_id_provider);
sub_agg.push(bucket_id, doc);
}
}
Expand Down
Loading
Loading