diff --git a/src/docset.rs b/src/docset.rs index c02bbbfc35..206ac096ae 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -138,6 +138,31 @@ pub trait DocSet: Send { buffer.len() } + /// Fills a given mutable buffer with the next doc ids smaller than `horizon`. + /// + /// Unlike [`DocSet::fill_buffer`], this method must not advance past a doc id greater than or + /// equal to `horizon`. + fn fill_buffer_up_to( + &mut self, + horizon: DocId, + buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + if self.doc() == TERMINATED { + return 0; + } + for (pos, buffer_val) in buffer.iter_mut().enumerate() { + let doc = self.doc(); + if doc >= horizon { + return pos; + } + *buffer_val = doc; + if self.advance() == TERMINATED { + return pos + 1; + } + } + buffer.len() + } + /// Returns the current document /// Right after creating a new `DocSet`, the docset points to the first document. /// @@ -251,6 +276,14 @@ impl DocSet for &mut dyn DocSet { (**self).fill_buffer(buffer) } + fn fill_buffer_up_to( + &mut self, + horizon: DocId, + buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + (**self).fill_buffer_up_to(horizon, buffer) + } + fn fill_bitset_block( &mut self, min_doc: DocId, diff --git a/src/postings/block_segment_postings.rs b/src/postings/block_segment_postings.rs index 61a9681312..66df3b2de9 100644 --- a/src/postings/block_segment_postings.rs +++ b/src/postings/block_segment_postings.rs @@ -240,6 +240,42 @@ impl BlockSegmentPostings { self.freq_decoder.output_array() } + pub(crate) fn copy_docs_and_term_freqs( + &self, + block_offset: usize, + horizon: DocId, + docs: &mut [DocId], + term_freqs: &mut [u32], + ) -> usize { + debug_assert_eq!(docs.len(), term_freqs.len()); + let block_docs = self.docs(); + let remaining_docs_in_block = block_docs.len().saturating_sub(block_offset); + let max_len = remaining_docs_in_block.min(docs.len()); + if max_len == 0 { + return 0; + } + + let source_docs = &block_docs[block_offset..block_offset + max_len]; + let len = if source_docs[max_len - 1] < horizon { + max_len + } else { + source_docs + .iter() + .position(|&doc| doc >= horizon) + .unwrap_or(max_len) + }; + + docs[..len].copy_from_slice(&source_docs[..len]); + + let block_freqs = self.freq_output_array(); + if block_freqs.len() >= block_offset + len { + term_freqs[..len].copy_from_slice(&block_freqs[block_offset..block_offset + len]); + } else { + term_freqs[..len].fill(1); + } + len + } + /// Return the frequency at index `idx` of the block. #[inline] pub fn freq(&self, idx: usize) -> u32 { diff --git a/src/postings/mod.rs b/src/postings/mod.rs index 13b6761cfa..6b41ca0f1f 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -532,6 +532,16 @@ pub(crate) mod tests { fn score(&mut self) -> Score { self.0.score() } + + #[inline] + fn can_score_doc(&self) -> bool { + self.0.can_score_doc() + } + + #[inline] + fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score { + self.0.score_doc(doc, term_freq) + } } pub fn test_skip_against_unoptimized Box>( diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs index e8928b90dc..1f60116e15 100644 --- a/src/postings/segment_postings.rs +++ b/src/postings/segment_postings.rs @@ -1,6 +1,6 @@ use common::HasLen; -use crate::docset::DocSet; +use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN}; use crate::fastfield::AliveBitSet; use crate::positions::PositionReader; use crate::postings::compression::COMPRESSION_BLOCK_SIZE; @@ -151,6 +151,34 @@ impl SegmentPostings { position_reader, } } + + pub(crate) fn fill_buffer_up_to_with_term_freqs( + &mut self, + horizon: DocId, + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + let mut num_elems = 0; + while num_elems < COLLECT_BLOCK_BUFFER_LEN && self.doc() < horizon { + let copied = self.block_cursor.copy_docs_and_term_freqs( + self.cur, + horizon, + &mut docs[num_elems..], + &mut term_freqs[num_elems..], + ); + if copied == 0 { + break; + } + num_elems += copied; + self.cur += copied; + + if self.cur == COMPRESSION_BLOCK_SIZE { + self.cur = 0; + self.block_cursor.advance(); + } + } + num_elems + } } impl DocSet for SegmentPostings { diff --git a/src/query/all_query.rs b/src/query/all_query.rs index 5431a3a1bb..ea6c2c686e 100644 --- a/src/query/all_query.rs +++ b/src/query/all_query.rs @@ -109,6 +109,16 @@ impl Scorer for AllScorer { fn score(&mut self) -> Score { 1.0 } + + #[inline] + fn can_score_doc(&self) -> bool { + true + } + + #[inline] + fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score { + 1.0 + } } #[cfg(test)] diff --git a/src/query/bm25.rs b/src/query/bm25.rs index d662f0eb56..ec931de954 100644 --- a/src/query/bm25.rs +++ b/src/query/bm25.rs @@ -1,5 +1,9 @@ +use std::cell::RefCell; +use std::num::NonZeroUsize; use std::sync::Arc; +use lru::LruCache; + use crate::fieldnorm::FieldNormReader; use crate::query::Explanation; use crate::schema::Field; @@ -59,7 +63,9 @@ fn cached_tf_component(fieldnorm: u32, average_fieldnorm: Score) -> Score { K1 * (1.0 - B + B * fieldnorm as Score / average_fieldnorm) } -fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> { +const BM25_TF_CACHE_CAPACITY: usize = 64; + +fn compute_tf_cache_uncached(average_fieldnorm: Score) -> Arc<[Score; 256]> { let mut cache: [Score; 256] = [0.0; 256]; for (fieldnorm_id, cache_mut) in cache.iter_mut().enumerate() { let fieldnorm = FieldNormReader::id_to_fieldnorm(fieldnorm_id as u8); @@ -68,6 +74,36 @@ fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> { Arc::new(cache) } +thread_local! { + static TF_CACHES: RefCell>> = RefCell::new(LruCache::new( + NonZeroUsize::new(BM25_TF_CACHE_CAPACITY).unwrap(), + )); +} + +/// The cache is shared across all [Bm25Weight] with the same average fieldnorm on the same thread. +/// It is stored in a thread local LRU cache. +/// +/// On one query all terms on the same field will share the same average fieldnorm, and thus the +/// same cache. This will lower cache pressure. +/// +/// Even between queries (on the same thread), the cache will be reused, which allows the cache to +/// better learn the memory address of the cache and access patterns. +/// +/// Thread local is used in order to be defensive about potential contention on the cache. +fn compute_tf_cache(average_fieldnorm: Score) -> Arc<[Score; 256]> { + let cache_key = average_fieldnorm.to_bits(); + TF_CACHES.with(|cache_by_average_fieldnorm| { + let mut cache_by_average_fieldnorm = cache_by_average_fieldnorm.borrow_mut(); + if let Some(cache) = cache_by_average_fieldnorm.get(&cache_key) { + return cache.clone(); + } + + let cache = compute_tf_cache_uncached(average_fieldnorm); + cache_by_average_fieldnorm.put(cache_key, cache.clone()); + cache + }) +} + /// A struct used for computing BM25 scores. #[derive(Clone)] pub struct Bm25Weight { @@ -229,7 +265,7 @@ impl Bm25Weight { #[cfg(test)] mod tests { - use super::idf; + use super::{idf, Bm25Weight}; use crate::{assert_nearly_equals, Score}; #[test] @@ -237,4 +273,12 @@ mod tests { let score: Score = 2.0; assert_nearly_equals!(idf(1, 2), score.ln()); } + + #[test] + fn test_bm25_tf_cache_is_shared_for_same_average_fieldnorm() { + let weight1 = Bm25Weight::for_one_term(1, 10, 3.0); + let weight2 = Bm25Weight::for_one_term(2, 10, 3.0); + + assert!(std::sync::Arc::ptr_eq(&weight1.cache, &weight2.cache)); + } } diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index f62cffb57e..03515223d0 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -91,10 +91,14 @@ fn into_box_scorer( num_docs: u32, ) -> Box { match scorer { - SpecializedScorer::TermUnion(term_scorers) => { - let union_scorer = - BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs); - Box::new(union_scorer) + SpecializedScorer::TermUnion(mut term_scorers) => { + if term_scorers.len() == 1 { + Box::new(term_scorers.pop().unwrap()) + } else { + let union_scorer = + BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs); + Box::new(union_scorer) + } } SpecializedScorer::TermIntersection(term_scorers) => { let boxed_scorers: Vec> = term_scorers diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index 69847d7507..4391ee145a 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -112,6 +112,14 @@ impl DocSet for BoostScorer { self.underlying.fill_buffer(buffer) } + fn fill_buffer_up_to( + &mut self, + horizon: DocId, + buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + self.underlying.fill_buffer_up_to(horizon, buffer) + } + fn doc(&self) -> u32 { self.underlying.doc() } @@ -138,6 +146,27 @@ impl Scorer for BoostScorer { fn score(&mut self) -> Score { self.underlying.score() * self.boost } + + #[inline] + fn can_score_doc(&self) -> bool { + self.underlying.can_score_doc() + } + + #[inline] + fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score { + self.underlying.score_doc(doc, term_freq) * self.boost + } + + #[inline] + fn fill_buffer_up_to_with_term_freqs( + &mut self, + horizon: DocId, + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + self.underlying + .fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs) + } } #[cfg(test)] diff --git a/src/query/const_score_query.rs b/src/query/const_score_query.rs index d07e6a96f4..87c016f9a1 100644 --- a/src/query/const_score_query.rs +++ b/src/query/const_score_query.rs @@ -141,6 +141,16 @@ impl Scorer for ConstScorer { fn score(&mut self) -> Score { self.score } + + #[inline] + fn can_score_doc(&self) -> bool { + true + } + + #[inline] + fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score { + self.score + } } #[cfg(test)] diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs index 2b4b54c008..dbb1b8aab0 100644 --- a/src/query/disjunction.rs +++ b/src/query/disjunction.rs @@ -315,6 +315,20 @@ mod tests { fn score(&mut self) -> Score { self.foo.get(self.cursor).map(|x| x.1).unwrap_or(0.0) } + + #[inline] + fn can_score_doc(&self) -> bool { + true + } + + #[inline] + fn score_doc(&mut self, doc: DocId, _term_freq: u32) -> Score { + self.foo + .iter() + .find(|(candidate_doc, _)| *candidate_doc == doc) + .map(|(_, score)| *score) + .unwrap_or(0.0) + } } #[test] diff --git a/src/query/empty_query.rs b/src/query/empty_query.rs index 2fa1772bdc..1a817270bd 100644 --- a/src/query/empty_query.rs +++ b/src/query/empty_query.rs @@ -59,6 +59,16 @@ impl Scorer for EmptyScorer { fn score(&mut self) -> Score { 0.0 } + + #[inline] + fn can_score_doc(&self) -> bool { + true + } + + #[inline] + fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score { + 0.0 + } } #[cfg(test)] diff --git a/src/query/score_combiner.rs b/src/query/score_combiner.rs index 2fe760c3d8..6ba7fcf86c 100644 --- a/src/query/score_combiner.rs +++ b/src/query/score_combiner.rs @@ -1,5 +1,40 @@ +use crate::docset::{DocSet, TERMINATED}; use crate::query::Scorer; -use crate::Score; +use crate::{DocId, Score}; + +struct ScoreOnlyScorer { + doc: DocId, + score: Score, +} + +impl DocSet for ScoreOnlyScorer { + fn advance(&mut self) -> DocId { + self.doc = TERMINATED; + TERMINATED + } + + fn doc(&self) -> DocId { + self.doc + } + + fn size_hint(&self) -> u32 { + 1 + } +} + +impl Scorer for ScoreOnlyScorer { + fn score(&mut self) -> Score { + self.score + } + + fn can_score_doc(&self) -> bool { + true + } + + fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score { + self.score + } +} /// The `ScoreCombiner` trait defines how to compute /// an overall score given a list of scores. @@ -10,6 +45,17 @@ pub trait ScoreCombiner: Default + Clone + Send + Copy + 'static { /// or not. fn update(&mut self, scorer: &mut TScorer); + /// Aggregates the score combiner with an already computed score. + fn update_score(&mut self, doc: DocId, score: Score) { + let mut scorer = ScoreOnlyScorer { doc, score }; + self.update(&mut scorer); + } + + /// Returns true if this combiner needs scorer scores to compute its state. + fn requires_scoring() -> bool { + true + } + /// Clears the score combiner state back to its initial state. fn clear(&mut self); @@ -27,6 +73,12 @@ pub struct DoNothingCombiner; impl ScoreCombiner for DoNothingCombiner { fn update(&mut self, _scorer: &mut TScorer) {} + fn update_score(&mut self, _doc: DocId, _score: Score) {} + + fn requires_scoring() -> bool { + false + } + fn clear(&mut self) {} #[inline] @@ -42,10 +94,16 @@ pub struct SumCombiner { } impl ScoreCombiner for SumCombiner { + #[inline] fn update(&mut self, scorer: &mut TScorer) { self.score += scorer.score(); } + #[inline] + fn update_score(&mut self, _doc: DocId, score: Score) { + self.score += score; + } + fn clear(&mut self) { self.score = 0.0; } @@ -77,12 +135,19 @@ impl DisjunctionMaxCombiner { } impl ScoreCombiner for DisjunctionMaxCombiner { + #[inline] fn update(&mut self, scorer: &mut TScorer) { let score = scorer.score(); self.max = Score::max(score, self.max); self.sum += score; } + #[inline] + fn update_score(&mut self, _doc: DocId, score: Score) { + self.max = Score::max(score, self.max); + self.sum += score; + } + fn clear(&mut self) { self.max = 0.0; self.sum = 0.0; diff --git a/src/query/scorer.rs b/src/query/scorer.rs index e91fc2fbce..9ffd69892e 100644 --- a/src/query/scorer.rs +++ b/src/query/scorer.rs @@ -2,8 +2,8 @@ use std::ops::DerefMut; use downcast_rs::impl_downcast; -use crate::docset::DocSet; -use crate::Score; +use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN}; +use crate::{DocId, Score}; /// Scored set of documents matching a query within a specific segment. /// @@ -13,6 +13,36 @@ pub trait Scorer: downcast_rs::Downcast + DocSet + 'static { /// /// This method will perform a bit of computation and is not cached. fn score(&mut self) -> Score; + + /// Returns true if [`Scorer::score_doc`] can score buffered docs without + /// repositioning the scorer. + /// + /// Scorers whose [`Scorer::score_doc`] needs term frequencies must also override + /// [`Scorer::fill_buffer_up_to_with_term_freqs`]. + fn can_score_doc(&self) -> bool { + false + } + + /// Returns the score for `doc` with its term frequency. + fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score { + panic!( + "score_doc is not supported by this scorer. You need check can_score_doc() before \ + calling this method." + ) + } + + /// Fills docs up to `horizon`. + /// + /// The default implementation does not fill `term_freqs`. Scorers whose + /// [`Scorer::score_doc`] reads term frequencies must override this method. + fn fill_buffer_up_to_with_term_freqs( + &mut self, + horizon: DocId, + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + _term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + DocSet::fill_buffer_up_to(self, horizon, docs) + } } impl_downcast!(Scorer); @@ -22,4 +52,25 @@ impl Scorer for Box { fn score(&mut self) -> Score { self.deref_mut().score() } + + #[inline] + fn can_score_doc(&self) -> bool { + self.as_ref().can_score_doc() + } + + #[inline] + fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score { + self.deref_mut().score_doc(doc, term_freq) + } + + #[inline] + fn fill_buffer_up_to_with_term_freqs( + &mut self, + horizon: DocId, + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + self.deref_mut() + .fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs) + } } diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index 20512f7b42..492bb83e8e 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -1,4 +1,4 @@ -use crate::docset::DocSet; +use crate::docset::{DocSet, COLLECT_BLOCK_BUFFER_LEN}; use crate::fieldnorm::FieldNormReader; use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings}; use crate::query::bm25::Bm25Weight; @@ -147,6 +147,27 @@ impl Scorer for TermScorer { let term_freq = self.term_freq(); self.similarity_weight.score(fieldnorm_id, term_freq) } + + #[inline] + fn can_score_doc(&self) -> bool { + true + } + + #[inline] + fn score_doc(&mut self, doc: DocId, term_freq: u32) -> Score { + let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc); + self.similarity_weight.score(fieldnorm_id, term_freq) + } + + fn fill_buffer_up_to_with_term_freqs( + &mut self, + horizon: DocId, + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + ) -> usize { + self.postings + .fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs) + } } #[cfg(test)] diff --git a/src/query/union/buffered_union.rs b/src/query/union/buffered_union.rs index e4cfe0ba38..f2cb0bcc46 100644 --- a/src/query/union/buffered_union.rs +++ b/src/query/union/buffered_union.rs @@ -10,23 +10,7 @@ use crate::{DocId, Score}; // of upcoming document IDs (the "horizon"). const HORIZON_NUM_TINYBITSETS: usize = HORIZON as usize / 64; const HORIZON: u32 = 64u32 * 64u32; - -// `drain_filter` is not stable yet. -// This function is similar except that it does is not unstable, and -// it does not keep the original vector ordering. -// -// Elements are dropped and not yielded. -fn unordered_drain_filter(v: &mut Vec, mut predicate: P) -where P: FnMut(&mut T) -> bool { - let mut i = 0; - while i < v.len() { - if predicate(&mut v[i]) { - v.swap_remove(i); - } else { - i += 1; - } - } -} +const GROUPED_INSERT_MAX_BUCKET_SPAN: u32 = 2; /// Creates a `DocSet` that iterate through the union of two or more `DocSet`s. pub struct BufferedUnionScorer { @@ -53,31 +37,213 @@ pub struct BufferedUnionScorer { score: Score, /// Number of documents in the segment. num_docs: u32, + /// Scratch buffer for block-based refill. + refill_docs: [DocId; COLLECT_BLOCK_BUFFER_LEN], + /// Scratch buffer for term frequencies matching `refill_docs`. + refill_term_freqs: [u32; COLLECT_BLOCK_BUFFER_LEN], + /// Whether all children support scoring buffered docs after advancing. + use_score_doc_refill: bool, } -fn refill( - scorers: &mut Vec, +#[inline] +fn union_bucket( + bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], + bucket_pos: u32, + tinyset: TinySet, +) { + debug_assert!((bucket_pos as usize) < HORIZON_NUM_TINYBITSETS); + // `bucket` comes from a doc delta below `HORIZON`; there are exactly + // `HORIZON / 64` buckets in the refill window. + bitsets[bucket_pos as usize] = bitsets[bucket_pos as usize].union(tinyset); +} + +#[inline] +fn insert_delta(bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], delta: DocId) { + debug_assert!(delta < HORIZON); + // `delta < HORIZON`, so `delta / 64` is in the bitset array. The bit + // offset is reduced modulo 64 before being inserted in the TinySet. + bitsets[delta as usize / 64].insert_mut(delta % 64u32); +} + +fn insert_and_score_full_buffer( + scorer: &mut TScorer, + docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN], bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], score_combiner: &mut [TScoreCombiner; HORIZON as usize], min_doc: DocId, ) { - unordered_drain_filter(scorers, |scorer| { - let horizon = min_doc + HORIZON; - loop { - let doc = scorer.doc(); - if doc >= horizon { - return false; + debug_assert!(docs.windows(2).all(|pair| pair[0] < pair[1])); + debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc < HORIZON); + + let first_delta = docs[0] - min_doc; + let last_delta = docs[COLLECT_BLOCK_BUFFER_LEN - 1] - min_doc; + let first_bucket = first_delta / 64; + let last_bucket = last_delta / 64; + + // Common for very dense scorers: 64 distinct doc ids in one 64-doc bucket + // means all bits in that bucket are present. + if first_bucket == last_bucket { + union_bucket(bitsets, first_bucket, TinySet::full()); + score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc); + return; + } + + // 64 sorted distinct integers spanning exactly 64 values are consecutive. + // If they cross a TinySet boundary, this is just the suffix of the first + // bucket plus the prefix of the second bucket. + if last_delta - first_delta == COLLECT_BLOCK_BUFFER_LEN as u32 - 1 { + union_bucket( + bitsets, + first_bucket, + TinySet::range_greater_or_equal(first_delta % 64u32), + ); + union_bucket( + bitsets, + last_bucket, + TinySet::range_lower((last_delta + 1) % 64u32), + ); + score_full_buffer(scorer, docs, term_freqs, score_combiner, min_doc); + return; + } + + // Grouping wins only for very dense buffers that hit the same TinySet many + // times. Once the 64 docs are spread farther, a straight pass is cheaper. + if last_bucket - first_bucket <= GROUPED_INSERT_MAX_BUCKET_SPAN { + let mut bucket = first_bucket; + let mut tinyset = TinySet::empty(); + for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) { + let delta = doc - min_doc; + let delta_bucket = delta / 64; + if delta_bucket != bucket { + union_bucket(bitsets, bucket, tinyset); + bucket = delta_bucket; + tinyset = TinySet::empty(); } - // add this document + tinyset.insert_mut(delta % 64u32); + let score = scorer.score_doc(doc, term_freq); + update_score_combiner(score_combiner, delta, doc, score); + } + union_bucket(bitsets, bucket, tinyset); + } else { + for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) { let delta = doc - min_doc; - bitsets[(delta / 64) as usize].insert_mut(delta % 64u32); - score_combiner[delta as usize].update(scorer); - if scorer.advance() == TERMINATED { - // remove the docset, it has been entirely consumed. - return true; + insert_delta(bitsets, delta); + // TODO: score_doc access the field_norm reader for each _term_, instead of once per + // doc. We could optimize this by caching the field norm for the doc, and + // reusing it for all terms in the doc. + let score = scorer.score_doc(doc, term_freq); + update_score_combiner(score_combiner, delta, doc, score); + } + } +} + +#[inline] +fn update_score_combiner( + score_combiner: &mut [TScoreCombiner; HORIZON as usize], + delta: DocId, + doc: DocId, + score: Score, +) { + debug_assert!(delta < HORIZON); + // Full and partial refill only buffer docs below `horizon`, so their + // deltas are always in the score-combiner window. + score_combiner[delta as usize].update_score(doc, score); +} + +fn score_full_buffer( + scorer: &mut TScorer, + docs: &[DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &[u32; COLLECT_BLOCK_BUFFER_LEN], + score_combiner: &mut [TScoreCombiner; HORIZON as usize], + min_doc: DocId, +) { + for (&doc, &term_freq) in docs.iter().zip(term_freqs.iter()) { + let score = scorer.score_doc(doc, term_freq); + update_score_combiner(score_combiner, doc - min_doc, doc, score); + } +} + +fn refill_scorer_with_score_docs( + scorer: &mut TScorer, + bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], + score_combiner: &mut [TScoreCombiner; HORIZON as usize], + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + min_doc: DocId, + horizon: DocId, +) { + loop { + let len = scorer.fill_buffer_up_to_with_term_freqs(horizon, docs, term_freqs); + if len == COLLECT_BLOCK_BUFFER_LEN { + debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] != TERMINATED); + debug_assert!(docs[COLLECT_BLOCK_BUFFER_LEN - 1] < horizon); + insert_and_score_full_buffer( + scorer, + docs, + term_freqs, + bitsets, + score_combiner, + min_doc, + ); + } else { + for (&doc, &term_freq) in docs[..len].iter().zip(term_freqs[..len].iter()) { + let delta = doc - min_doc; + insert_delta(bitsets, delta); + let score = scorer.score_doc(doc, term_freq); + update_score_combiner(score_combiner, delta, doc, score); } + break; + } + } +} + +fn refill_scorer_from_current_doc( + scorer: &mut TScorer, + bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], + score_combiner: &mut [TScoreCombiner; HORIZON as usize], + min_doc: DocId, + horizon: DocId, +) { + loop { + let doc = scorer.doc(); + if doc >= horizon { + break; } - }); + let delta = doc - min_doc; + insert_delta(bitsets, delta); + debug_assert!(delta < HORIZON); + score_combiner[delta as usize].update(scorer); + scorer.advance(); + } +} + +fn refill( + scorers: &mut Vec, + bitsets: &mut [TinySet; HORIZON_NUM_TINYBITSETS], + score_combiner: &mut [TScoreCombiner; HORIZON as usize], + docs: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN], + term_freqs: &mut [u32; COLLECT_BLOCK_BUFFER_LEN], + min_doc: DocId, + use_score_doc_refill: bool, +) { + let horizon = min_doc + HORIZON; + for scorer in scorers.iter_mut() { + if use_score_doc_refill { + refill_scorer_with_score_docs( + scorer, + bitsets, + score_combiner, + docs, + term_freqs, + min_doc, + horizon, + ); + } else { + refill_scorer_from_current_doc(scorer, bitsets, score_combiner, min_doc, horizon); + } + } + scorers.retain(|scorer| scorer.doc() != TERMINATED); } impl BufferedUnionScorer { @@ -87,6 +253,8 @@ impl BufferedUnionScorer TScoreCombiner, num_docs: u32, ) -> BufferedUnionScorer { + let use_score_doc_refill = + TScoreCombiner::requires_scoring() && docsets.iter().all(Scorer::can_score_doc); let non_empty_docsets: Vec = docsets .into_iter() .filter(|docset| docset.doc() != TERMINATED) @@ -100,6 +268,9 @@ impl BufferedUnionScorer BufferedUnionScorer= to the target. - unordered_drain_filter(&mut self.docsets, |docset| { + for docset in &mut self.docsets { if docset.doc() < target { docset.seek(target); } - docset.doc() == TERMINATED - }); + } + self.docsets.retain(|docset| docset.doc() != TERMINATED); // at this point all of the docsets // are positioned on a doc >= to the target. diff --git a/src/query/union/mod.rs b/src/query/union/mod.rs index 825ee219bc..1f11aaf3db 100644 --- a/src/query/union/mod.rs +++ b/src/query/union/mod.rs @@ -10,6 +10,8 @@ pub use simple_union::SimpleUnion; mod tests { use std::collections::BTreeSet; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use common::BitSet; @@ -18,8 +20,8 @@ mod tests { use crate::postings::tests::test_skip_against_unoptimized; use crate::query::score_combiner::DoNothingCombiner; use crate::query::union::bitset_union::BitSetPostingUnion; - use crate::query::{BitSetDocSet, ConstScorer, VecDocSet}; - use crate::{tests, DocId}; + use crate::query::{BitSetDocSet, ConstScorer, Scorer, VecDocSet}; + use crate::{tests, DocId, Score}; fn vec_doc_set_from_docs_list( docs_list: &[Vec], @@ -66,6 +68,61 @@ mod tests { } BitSetDocSet::from(doc_bitset) } + + struct CountingScorer { + docset: VecDocSet, + score_calls: Arc, + score_doc_calls: Arc, + } + + impl CountingScorer { + fn new( + doc_ids: Vec, + score_calls: Arc, + score_doc_calls: Arc, + ) -> Self { + CountingScorer { + docset: VecDocSet::from(doc_ids), + score_calls, + score_doc_calls, + } + } + } + + impl DocSet for CountingScorer { + fn advance(&mut self) -> DocId { + self.docset.advance() + } + + fn seek(&mut self, target: DocId) -> DocId { + self.docset.seek(target) + } + + fn doc(&self) -> DocId { + self.docset.doc() + } + + fn size_hint(&self) -> u32 { + self.docset.size_hint() + } + } + + impl Scorer for CountingScorer { + fn score(&mut self) -> Score { + self.score_calls.fetch_add(1, Ordering::SeqCst); + 1.0 + } + + fn can_score_doc(&self) -> bool { + true + } + + fn score_doc(&mut self, _doc: DocId, _term_freq: u32) -> Score { + self.score_doc_calls.fetch_add(1, Ordering::SeqCst); + 1.0 + } + } + fn aux_test_union(docs_list: &[Vec]) { for constructor in [ posting_list_union_from_docs_list, @@ -168,6 +225,22 @@ mod tests { ]); } + #[test] + fn test_do_nothing_combiner_does_not_score_buffered_docs() { + let score_calls = Arc::new(AtomicUsize::new(0)); + let score_doc_calls = Arc::new(AtomicUsize::new(0)); + let scorers = vec![ + CountingScorer::new(vec![1, 3, 5], score_calls.clone(), score_doc_calls.clone()), + CountingScorer::new(vec![2, 3, 6], score_calls.clone(), score_doc_calls.clone()), + ]; + + let mut union = BufferedUnionScorer::build(scorers, DoNothingCombiner::default, 10); + + assert_eq!(union.count_including_deleted(), 5); + assert_eq!(score_calls.load(Ordering::SeqCst), 0); + assert_eq!(score_doc_calls.load(Ordering::SeqCst), 0); + } + fn test_aux_union_skip(docs_list: &[Vec], skip_targets: Vec) { for constructor in [ posting_list_union_from_docs_list,