fix: apply local results weight to scores generated by rerank() (#1036)

* fix: apply local results weight to scores generated by rerank()

We have two kinds of scores:

1. Scores calculated by query sources
2. Scores recomputed in the post-search phase, when the documents are
   re-scored by `boosted_levenshtein_rerank()`

When applying the setting "local search results weight", the previous
implementation modified the first kind of score. The effect was not
noticeable, since we collect documents evenly from the hits returned by
query sources.

This commit refactors the implementation so that the setting adjusts the
second kind of score, making the effect obvious.

* refactor: keep O(1) lookup && update var names

* chore: one more var name to update
This commit is contained in:
SteveLauC
2025-12-30 14:26:01 +08:00
committed by GitHub
parent b4856e61e6
commit 90e787058c
2 changed files with 36 additions and 34 deletions

View File

@@ -21,6 +21,7 @@ Information about release notes of Coco App is provided here.
- fix: avoid recentering when resizing to compact after leaving extension #1030
- fix: fix incorrect window position after hiding #1034
- fix: fix arrow keys not working after closing the context menu #1035
- fix: apply local results weight to scores generated by rerank() #1036
### ✈️ Improvements

View File

@@ -304,20 +304,6 @@ async fn query_coco_fusion_multi_query_sources(
});
}
/*
* Apply settings: local query source weight
*/
let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle);
// Scores remain unchanged if it is 1.0
if local_query_source_weight != 1.0 {
for (query_source, hits) in all_hits_grouped_by_query_source.iter_mut() {
if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE {
hits.iter_mut()
.for_each(|hit| hit.score = hit.score * local_query_source_weight);
}
}
}
/*
* Sort hits within each source by score (descending) in case data sources
* do not sort them
@@ -336,28 +322,29 @@ async fn query_coco_fusion_multi_query_sources(
* 1. All sources have hits returned
* 2. Query sources with many hits won't dominate
*/
let mut final_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
let mut pruned: HashMap<&str, &[QueryHits]> = HashMap::new();
let mut final_hits_grouped_by_query_source: HashMap<QuerySource, Vec<QueryHits>> =
HashMap::new();
let mut pruned: HashMap<&QuerySource, &[QueryHits]> = HashMap::new();
// Include at least 2 hits from each query source
let max_hits_per_source = (size as usize / n_sources).max(2);
for (query_source, hits) in all_hits_grouped_by_query_source.iter() {
let hits_taken = if hits.len() > max_hits_per_source {
pruned.insert(&query_source.id, &hits[max_hits_per_source..]);
pruned.insert(&query_source, &hits[max_hits_per_source..]);
hits[0..max_hits_per_source].to_vec()
} else {
hits.clone()
};
final_hits_grouped_by_source_id.insert(query_source.id.clone(), hits_taken);
final_hits_grouped_by_query_source.insert(query_source.clone(), hits_taken);
}
let final_hits_len = final_hits_grouped_by_source_id
let final_hits_len = final_hits_grouped_by_query_source
.iter()
.fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
.fold(0, |acc: usize, (_source, hits)| acc + hits.len());
let pruned_len = pruned
.iter()
.fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
.fold(0, |acc: usize, (_source, hits)| acc + hits.len());
/*
* If we still need more hits, take the highest-scoring from `pruned`
@@ -371,8 +358,8 @@ async fn query_coco_fusion_multi_query_sources(
let n_take = n_have.min(n_need);
for _ in 0..n_take {
let mut highest_score_hit: Option<(&str, &QueryHits)> = None;
for (source_id, sorted_hits) in pruned.iter_mut() {
let mut highest_score_hit: Option<(&QuerySource, &QueryHits)> = None;
for (source, sorted_hits) in pruned.iter_mut() {
if sorted_hits.is_empty() {
continue;
}
@@ -387,7 +374,7 @@ async fn query_coco_fusion_multi_query_sources(
};
if have_higher_score_hit {
highest_score_hit = Some((*source_id, hit));
highest_score_hit = Some((*source, hit));
// Advance sorted_hits by 1 element, if have
if sorted_hits.len() == 1 {
@@ -398,24 +385,38 @@ async fn query_coco_fusion_multi_query_sources(
}
}
let (source_id, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set");
let (source, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set");
final_hits_grouped_by_source_id
.get_mut(source_id)
.expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_source_id`, so it should exist")
final_hits_grouped_by_query_source
.get_mut(source)
.expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_query_source`, so it should exist")
.push(hit.clone());
}
}
/*
* Re-rank the final hits
* Re-rank (re-score) the final hits
*/
if n_sources > 1 {
boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_source_id);
boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_query_source);
}
/*
* Apply settings "local search results weight" to the scores
*/
let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle);
// Scores remain unchanged if it is 1.0
if local_query_source_weight != 1.0 {
for (query_source, hits) in final_hits_grouped_by_query_source.iter_mut() {
if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE {
hits.iter_mut()
.for_each(|hit| hit.score = hit.score * local_query_source_weight);
}
}
}
let mut final_hits = Vec::new();
for (_source_id, hits) in final_hits_grouped_by_source_id {
for (_source, hits) in final_hits_grouped_by_query_source {
final_hits.extend(hits);
}
@@ -451,13 +452,13 @@ use strsim::levenshtein;
fn boosted_levenshtein_rerank(
query: &str,
all_hits_grouped_by_source_id: &mut HashMap<String, Vec<QueryHits>>,
final_hits_grouped_by_query_source: &mut HashMap<QuerySource, Vec<QueryHits>>,
) {
let query_lower = query.to_lowercase();
for (source_id, hits) in all_hits_grouped_by_source_id.iter_mut() {
for (source, hits) in final_hits_grouped_by_query_source.iter_mut() {
// Skip special sources like calculator
if source_id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
if source.id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
continue;
}