diff --git a/docs/content.en/docs/release-notes/_index.md b/docs/content.en/docs/release-notes/_index.md index c7eb0480..a81ad133 100644 --- a/docs/content.en/docs/release-notes/_index.md +++ b/docs/content.en/docs/release-notes/_index.md @@ -21,6 +21,7 @@ Information about release notes of Coco App is provided here. - fix: avoid recentering when resizing to compact after leaving extension #1030 - fix: fix incorrect window position after hiding #1034 - fix: fix arrow keys not working after closing the context menu #1035 +- fix: apply local results weight to scores generated by rerank() #1036 ### ✈️ Improvements diff --git a/src-tauri/src/search/mod.rs b/src-tauri/src/search/mod.rs index a73bd6ed..4302aebd 100644 --- a/src-tauri/src/search/mod.rs +++ b/src-tauri/src/search/mod.rs @@ -304,20 +304,6 @@ async fn query_coco_fusion_multi_query_sources( }); } - /* - * Apply settings: local query source weight - */ - let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle); - // Scores remain unchanged if it is 1.0 - if local_query_source_weight != 1.0 { - for (query_source, hits) in all_hits_grouped_by_query_source.iter_mut() { - if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE { - hits.iter_mut() - .for_each(|hit| hit.score = hit.score * local_query_source_weight); - } - } - } - /* * Sort hits within each source by score (descending) in case data sources * do not sort them @@ -336,28 +322,29 @@ async fn query_coco_fusion_multi_query_sources( * 1. All sources have hits returned * 2. Query sources with many hits won't dominate */ - let mut final_hits_grouped_by_source_id: HashMap> = HashMap::new(); - let mut pruned: HashMap<&str, &[QueryHits]> = HashMap::new(); + let mut final_hits_grouped_by_query_source: HashMap> = + HashMap::new(); + let mut pruned: HashMap<&QuerySource, &[QueryHits]> = HashMap::new(); // Include at least 2 hits from each query source let max_hits_per_source = (size as usize / n_sources).max(2); for (query_source, hits) in all_hits_grouped_by_query_source.iter() { let hits_taken = if hits.len() > max_hits_per_source { - pruned.insert(&query_source.id, &hits[max_hits_per_source..]); + pruned.insert(&query_source, &hits[max_hits_per_source..]); hits[0..max_hits_per_source].to_vec() } else { hits.clone() }; - final_hits_grouped_by_source_id.insert(query_source.id.clone(), hits_taken); + final_hits_grouped_by_query_source.insert(query_source.clone(), hits_taken); } - let final_hits_len = final_hits_grouped_by_source_id + let final_hits_len = final_hits_grouped_by_query_source .iter() - .fold(0, |acc: usize, (_source_id, hits)| acc + hits.len()); + .fold(0, |acc: usize, (_source, hits)| acc + hits.len()); let pruned_len = pruned .iter() - .fold(0, |acc: usize, (_source_id, hits)| acc + hits.len()); + .fold(0, |acc: usize, (_source, hits)| acc + hits.len()); /* * If we still need more hits, take the highest-scoring from `pruned` @@ -371,8 +358,8 @@ async fn query_coco_fusion_multi_query_sources( let n_take = n_have.min(n_need); for _ in 0..n_take { - let mut highest_score_hit: Option<(&str, &QueryHits)> = None; - for (source_id, sorted_hits) in pruned.iter_mut() { + let mut highest_score_hit: Option<(&QuerySource, &QueryHits)> = None; + for (source, sorted_hits) in pruned.iter_mut() { if sorted_hits.is_empty() { continue; } @@ -387,7 +374,7 @@ async fn query_coco_fusion_multi_query_sources( }; if have_higher_score_hit { - highest_score_hit = Some((*source_id, hit)); + highest_score_hit = Some((*source, hit)); // Advance sorted_hits by 1 element, if have if sorted_hits.len() == 1 { @@ -398,24 +385,38 @@ async fn query_coco_fusion_multi_query_sources( } } - let (source_id, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set"); + let (source, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set"); - final_hits_grouped_by_source_id - .get_mut(source_id) - .expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_source_id`, so it should exist") + final_hits_grouped_by_query_source + .get_mut(source) + .expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_query_source`, so it should exist") .push(hit.clone()); } } /* - * Re-rank the final hits + * Re-rank (re-score) the final hits */ if n_sources > 1 { - boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_source_id); + boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_query_source); + } + + /* + * Apply settings "local search results weight" to the scores + */ + let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle); + // Scores remain unchanged if it is 1.0 + if local_query_source_weight != 1.0 { + for (query_source, hits) in final_hits_grouped_by_query_source.iter_mut() { + if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE { + hits.iter_mut() + .for_each(|hit| hit.score = hit.score * local_query_source_weight); + } + } } let mut final_hits = Vec::new(); - for (_source_id, hits) in final_hits_grouped_by_source_id { + for (_source, hits) in final_hits_grouped_by_query_source { final_hits.extend(hits); } @@ -451,13 +452,13 @@ use strsim::levenshtein; fn boosted_levenshtein_rerank( query: &str, - all_hits_grouped_by_source_id: &mut HashMap>, + final_hits_grouped_by_query_source: &mut HashMap>, ) { let query_lower = query.to_lowercase(); - for (source_id, hits) in all_hits_grouped_by_source_id.iter_mut() { + for (source, hits) in final_hits_grouped_by_query_source.iter_mut() { // Skip special sources like calculator - if source_id == crate::extension::built_in::calculator::DATA_SOURCE_ID { + if source.id == crate::extension::built_in::calculator::DATA_SOURCE_ID { continue; }