mirror of
https://github.com/infinilabs/coco-app.git
synced 2026-05-18 13:14:53 +02:00
fix: apply local results weight to scores generated by rerank() (#1036)
* fix: apply local results weight to scores generated by rerank() We have two kinds of scores: 1. Scores calculated by query sources 2. In the post-search phase, we re-score the documents in `boosted_levenshtein_rerank()` When applying the setting "local search results weight", previous implementation modified the first score. Its effect was not noticeable, since we collect documents evenly from the hits returned by query sources. This commit refactors the implementation so that the setting adjusts the second score, making the effect obvious. * refactor: keep O(1) lookup && update var names * chore: one more var name to update
This commit is contained in:
@@ -21,6 +21,7 @@ Information about release notes of Coco App is provided here.
|
||||
- fix: avoid recentering when resizing to compact after leaving extension #1030
|
||||
- fix: fix incorrect window position after hiding #1034
|
||||
- fix: fix arrow keys not working after closing the context menu #1035
|
||||
- fix: apply local results weight to scores generated by rerank() #1036
|
||||
|
||||
### ✈️ Improvements
|
||||
|
||||
|
||||
@@ -304,20 +304,6 @@ async fn query_coco_fusion_multi_query_sources(
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
* Apply settings: local query source weight
|
||||
*/
|
||||
let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle);
|
||||
// Scores remain unchanged if it is 1.0
|
||||
if local_query_source_weight != 1.0 {
|
||||
for (query_source, hits) in all_hits_grouped_by_query_source.iter_mut() {
|
||||
if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE {
|
||||
hits.iter_mut()
|
||||
.for_each(|hit| hit.score = hit.score * local_query_source_weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Sort hits within each source by score (descending) in case data sources
|
||||
* do not sort them
|
||||
@@ -336,28 +322,29 @@ async fn query_coco_fusion_multi_query_sources(
|
||||
* 1. All sources have hits returned
|
||||
* 2. Query sources with many hits won't dominate
|
||||
*/
|
||||
let mut final_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
|
||||
let mut pruned: HashMap<&str, &[QueryHits]> = HashMap::new();
|
||||
let mut final_hits_grouped_by_query_source: HashMap<QuerySource, Vec<QueryHits>> =
|
||||
HashMap::new();
|
||||
let mut pruned: HashMap<&QuerySource, &[QueryHits]> = HashMap::new();
|
||||
|
||||
// Include at least 2 hits from each query source
|
||||
let max_hits_per_source = (size as usize / n_sources).max(2);
|
||||
for (query_source, hits) in all_hits_grouped_by_query_source.iter() {
|
||||
let hits_taken = if hits.len() > max_hits_per_source {
|
||||
pruned.insert(&query_source.id, &hits[max_hits_per_source..]);
|
||||
pruned.insert(&query_source, &hits[max_hits_per_source..]);
|
||||
hits[0..max_hits_per_source].to_vec()
|
||||
} else {
|
||||
hits.clone()
|
||||
};
|
||||
|
||||
final_hits_grouped_by_source_id.insert(query_source.id.clone(), hits_taken);
|
||||
final_hits_grouped_by_query_source.insert(query_source.clone(), hits_taken);
|
||||
}
|
||||
|
||||
let final_hits_len = final_hits_grouped_by_source_id
|
||||
let final_hits_len = final_hits_grouped_by_query_source
|
||||
.iter()
|
||||
.fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
|
||||
.fold(0, |acc: usize, (_source, hits)| acc + hits.len());
|
||||
let pruned_len = pruned
|
||||
.iter()
|
||||
.fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
|
||||
.fold(0, |acc: usize, (_source, hits)| acc + hits.len());
|
||||
|
||||
/*
|
||||
* If we still need more hits, take the highest-scoring from `pruned`
|
||||
@@ -371,8 +358,8 @@ async fn query_coco_fusion_multi_query_sources(
|
||||
let n_take = n_have.min(n_need);
|
||||
|
||||
for _ in 0..n_take {
|
||||
let mut highest_score_hit: Option<(&str, &QueryHits)> = None;
|
||||
for (source_id, sorted_hits) in pruned.iter_mut() {
|
||||
let mut highest_score_hit: Option<(&QuerySource, &QueryHits)> = None;
|
||||
for (source, sorted_hits) in pruned.iter_mut() {
|
||||
if sorted_hits.is_empty() {
|
||||
continue;
|
||||
}
|
||||
@@ -387,7 +374,7 @@ async fn query_coco_fusion_multi_query_sources(
|
||||
};
|
||||
|
||||
if have_higher_score_hit {
|
||||
highest_score_hit = Some((*source_id, hit));
|
||||
highest_score_hit = Some((*source, hit));
|
||||
|
||||
// Advance sorted_hits by 1 element, if have
|
||||
if sorted_hits.len() == 1 {
|
||||
@@ -398,24 +385,38 @@ async fn query_coco_fusion_multi_query_sources(
|
||||
}
|
||||
}
|
||||
|
||||
let (source_id, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set");
|
||||
let (source, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set");
|
||||
|
||||
final_hits_grouped_by_source_id
|
||||
.get_mut(source_id)
|
||||
.expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_source_id`, so it should exist")
|
||||
final_hits_grouped_by_query_source
|
||||
.get_mut(source)
|
||||
.expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_query_source`, so it should exist")
|
||||
.push(hit.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Re-rank the final hits
|
||||
* Re-rank (re-score) the final hits
|
||||
*/
|
||||
if n_sources > 1 {
|
||||
boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_source_id);
|
||||
boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_query_source);
|
||||
}
|
||||
|
||||
/*
|
||||
* Apply settings "local search results weight" to the scores
|
||||
*/
|
||||
let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle);
|
||||
// Scores remain unchanged if it is 1.0
|
||||
if local_query_source_weight != 1.0 {
|
||||
for (query_source, hits) in final_hits_grouped_by_query_source.iter_mut() {
|
||||
if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE {
|
||||
hits.iter_mut()
|
||||
.for_each(|hit| hit.score = hit.score * local_query_source_weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut final_hits = Vec::new();
|
||||
for (_source_id, hits) in final_hits_grouped_by_source_id {
|
||||
for (_source, hits) in final_hits_grouped_by_query_source {
|
||||
final_hits.extend(hits);
|
||||
}
|
||||
|
||||
@@ -451,13 +452,13 @@ use strsim::levenshtein;
|
||||
|
||||
fn boosted_levenshtein_rerank(
|
||||
query: &str,
|
||||
all_hits_grouped_by_source_id: &mut HashMap<String, Vec<QueryHits>>,
|
||||
final_hits_grouped_by_query_source: &mut HashMap<QuerySource, Vec<QueryHits>>,
|
||||
) {
|
||||
let query_lower = query.to_lowercase();
|
||||
|
||||
for (source_id, hits) in all_hits_grouped_by_source_id.iter_mut() {
|
||||
for (source, hits) in final_hits_grouped_by_query_source.iter_mut() {
|
||||
// Skip special sources like calculator
|
||||
if source_id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
|
||||
if source.id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user