Mirror of https://github.com/infinilabs/coco-app.git
refactor: refactoring application search (#134)
src-tauri/Cargo.lock (generated; 97 lines changed)
@@ -492,16 +492,6 @@ dependencies = [
  "alloc-stdlib",
 ]
 
-[[package]]
-name = "bstr"
-version = "1.11.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0"
-dependencies = [
- "memchr",
- "serde",
-]
-
 [[package]]
 name = "bumpalo"
 version = "3.16.0"
@@ -671,6 +661,7 @@ dependencies = [
  "base64 0.13.1",
  "dirs 5.0.1",
  "futures",
+ "fuzzy_prefix_search",
  "hostname",
  "lazy_static",
  "log",
@@ -680,7 +671,6 @@ dependencies = [
  "pizza-common",
  "plist",
  "reqwest",
- "rust_search",
  "serde",
  "serde_json",
  "tauri",
@@ -886,25 +876,6 @@ dependencies = [
  "crossbeam-utils",
 ]
 
-[[package]]
-name = "crossbeam-deque"
-version = "0.8.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
-dependencies = [
- "crossbeam-epoch",
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-epoch"
-version = "0.9.18"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
-dependencies = [
- "crossbeam-utils",
-]
-
 [[package]]
 name = "crossbeam-utils"
 version = "0.8.20"
@@ -984,7 +955,7 @@ dependencies = [
  "ident_case",
  "proc-macro2",
  "quote",
- "strsim 0.11.1",
+ "strsim",
  "syn 2.0.90",
 ]
@@ -1638,6 +1609,12 @@ dependencies = [
  "slab",
 ]
 
+[[package]]
+name = "fuzzy_prefix_search"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a904d7ec1d39e73f21e8446175cfcd3d4265313919000044e300ecf8d9967dec"
+
 [[package]]
 name = "fxhash"
 version = "0.2.1"
@@ -1898,19 +1875,6 @@ dependencies = [
  "x11-dl",
 ]
 
-[[package]]
-name = "globset"
-version = "0.4.15"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "15f1ce686646e7f1e19bf7d5533fe443a45dbfb990e00629110797578b42fb19"
-dependencies = [
- "aho-corasick",
- "bstr",
- "log",
- "regex-automata",
- "regex-syntax",
-]
-
 [[package]]
 name = "gobject-sys"
 version = "0.18.0"
@@ -2258,22 +2222,6 @@ dependencies = [
  "unicode-normalization",
 ]
 
-[[package]]
-name = "ignore"
-version = "0.4.23"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d89fd380afde86567dfba715db065673989d6253f42b88179abd3eae47bda4b"
-dependencies = [
- "crossbeam-deque",
- "globset",
- "log",
- "memchr",
- "regex-automata",
- "same-file",
- "walkdir",
- "winapi-util",
-]
-
 [[package]]
 name = "image"
 version = "0.25.5"
@@ -2847,16 +2795,6 @@ dependencies = [
  "autocfg",
 ]
 
-[[package]]
-name = "num_cpus"
-version = "1.16.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
-dependencies = [
- "hermit-abi 0.3.9",
- "libc",
-]
-
 [[package]]
 name = "num_enum"
 version = "0.7.3"
@@ -4010,19 +3948,6 @@ dependencies = [
  "trim-in-place",
 ]
 
-[[package]]
-name = "rust_search"
-version = "2.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d27d7be20245d289c9dde663f06521de08663d73cbaefc45785aa65d02022378"
-dependencies = [
- "dirs 4.0.0",
- "ignore",
- "num_cpus",
- "regex",
- "strsim 0.10.0",
-]
-
 [[package]]
 name = "rustc-demangle"
 version = "0.1.24"
@@ -4560,12 +4485,6 @@ dependencies = [
  "quote",
 ]
 
-[[package]]
-name = "strsim"
-version = "0.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
-
 [[package]]
 name = "strsim"
 version = "0.11.1"
src-tauri/Cargo.toml

@@ -54,7 +54,7 @@ hostname = "0.3"
 plist = "1.7"
 base64 = "0.13"
 walkdir = "2"
-rust_search = "2.0.0"
+fuzzy_prefix_search = "0.2"
 
 [profile.dev]
 incremental = true # Compile your binary in smaller steps.
src-tauri/src/common/search.rs

@@ -1,9 +1,9 @@
-use std::collections::HashMap;
-use std::error::Error;
+use crate::common::document::Document;
 use reqwest::Response;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use crate::common::document::Document;
+use std::collections::HashMap;
+use std::error::Error;
 #[derive(Debug, Serialize, Deserialize)]
 pub struct SearchResponse<T> {
     pub took: u64,
@@ -64,7 +64,7 @@ pub async fn parse_search_hits<T>(
 where
     T: for<'de> Deserialize<'de> + std::fmt::Debug,
 {
-    let response=parse_search_response(response).await?;
+    let response = parse_search_response(response).await?;
 
     Ok(response.hits.hits)
 }
@@ -87,7 +87,7 @@ where
     Ok(parse_search_hits(response).await?.into_iter().map(|hit| (hit._source, hit._score)).collect())
 }
 
-#[derive(Debug,Clone,Serialize)]
+#[derive(Debug, Clone, Serialize)]
 pub struct SearchQuery {
     pub from: u64,
     pub size: u64,
@@ -104,35 +104,36 @@ impl SearchQuery {
     }
 }
 
-#[derive(Debug,Clone, Serialize)]
-pub struct QuerySource{
+#[derive(Debug, Clone, Serialize)]
+pub struct QuerySource {
     pub r#type: String, //coco-server/local/ etc.
     pub id: String, //coco server's id
     pub name: String, //coco server's name, local computer name, etc.
 }
 
-#[derive(Debug,Clone, Serialize)]
+#[derive(Debug, Clone, Serialize)]
 pub struct QueryHits {
     pub source: Option<QuerySource>,
     pub score: f64,
     pub document: Document,
 }
 
-#[derive(Debug,Clone, Serialize)]
-pub struct FailedRequest{
+#[derive(Debug, Clone, Serialize)]
+pub struct FailedRequest {
     pub source: QuerySource,
     pub status: u16,
     pub error: Option<String>,
     pub reason: Option<String>,
 }
 
-#[derive(Debug,Clone, Serialize)]
+#[derive(Debug, Clone, Serialize)]
 pub struct QueryResponse {
     pub source: QuerySource,
-    pub hits: Vec<(Document,f64)>,
+    pub hits: Vec<(Document, f64)>,
     pub total_hits: usize,
 }
 
-#[derive(Debug,Clone, Serialize)]
+#[derive(Debug, Clone, Serialize)]
 pub struct MultiSourceQueryResponse {
     pub failed: Vec<FailedRequest>,
     pub hits: Vec<QueryHits>,
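All of these response types derive Serialize because they cross the Tauri IPC boundary as JSON. As a rough illustration of the wire shape, here is a minimal, self-contained sketch; the Document stand-in is hypothetical (the real type lives in common::document), and serde's derive feature plus serde_json are assumed:

use serde::Serialize;

// Hypothetical minimal stand-in for the real `Document` type defined in
// common::document; present only to keep the example self-contained.
#[derive(Debug, Clone, Serialize)]
struct Document {
    id: String,
    title: String,
}

#[derive(Debug, Clone, Serialize)]
struct QuerySource {
    r#type: String,
    id: String,
    name: String,
}

#[derive(Debug, Clone, Serialize)]
struct QueryHits {
    source: Option<QuerySource>,
    score: f64,
    document: Document,
}

fn main() {
    let hit = QueryHits {
        source: Some(QuerySource {
            r#type: "local".into(),
            id: "local_applications".into(),
            name: "My Computer".into(),
        }),
        score: 1.5,
        document: Document { id: "app-1".into(), title: "Safari".into() },
    };
    // `r#type` serializes as a plain "type" key in the JSON output.
    println!("{}", serde_json::to_string_pretty(&hit).unwrap());
}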
src-tauri/src/local/application.rs

@@ -1,13 +1,14 @@
 use crate::common::document::{DataSourceReference, Document};
 use crate::common::search::{QueryResponse, QuerySource, SearchQuery};
 use crate::common::traits::{SearchError, SearchSource};
+use crate::local::LOCAL_QUERY_SOURCE_TYPE;
 use async_trait::async_trait;
 use base64::encode;
 use dirs::data_dir;
+use fuzzy_prefix_search::Trie;
 use hostname;
 use plist::Value;
-use rust_search::SearchBuilder;
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::fs;
 use std::path::{Path, PathBuf};
 use std::process::{Command, Stdio};
@@ -17,7 +18,7 @@ pub struct ApplicationSearchSource {
     base_score: f64,
     app_dirs: Vec<PathBuf>,
     icons: HashMap<String, PathBuf>, // Map app names to their icon paths
-    search_locations: Vec<String>, // Cached search locations
+    application_paths: fuzzy_prefix_search::Trie<String>, // Cached trie of lowercased app names -> bundle paths
 }
 
 /// Extracts the app icon from the `.app` bundle or system icons and converts it to PNG format.
@@ -139,10 +140,12 @@ fn convert_icns_to_png(app_dir: &Path, icns_path: &Path, app_data_folder: &Path)
     let icon_storage_dir = app_data_folder.join("coco-appIcons");
     fs::create_dir_all(&icon_storage_dir).ok();
 
+    // dbg!("app_name:", &app_name);
+
     let output_png_path = icon_storage_dir.join(format!("{}.png", app_name));
 
     if output_png_path.exists() {
         return Some(output_png_path);
     }
 
+    // dbg!("Converting ICNS to PNG:", &output_png_path);
+
     // Run the `sips` command to convert the ICNS to PNG
@@ -204,10 +207,7 @@ impl ApplicationSearchSource {
         let mut icons = HashMap::new();
 
-        // Collect search locations as strings
-        let search_locations: Vec<String> = app_dirs
-            .iter()
-            .map(|dir| dir.to_string_lossy().to_string())
-            .collect();
+        let mut applications = Trie::new();
 
         // Iterate over the directories to find .app files and extract icons
         for app_dir in &app_dirs {
@@ -216,10 +216,21 @@ impl ApplicationSearchSource {
                 let file_path = entry.path();
                 if file_path.is_dir() && file_path.extension() == Some("app".as_ref()) {
                     if let Some(app_data_folder) = data_dir() {
                         // dbg!(&file_path);
+                        let file_path_str = file_path.to_string_lossy().to_string(); // Convert to owned String if needed
+                        if file_path.parent().unwrap().to_str().unwrap().contains(".app/Contents/") {
+                            continue;
+                        }
+                        let search_word = file_path.file_name()
+                            .unwrap() // unwrap() might panic if there's no file name
+                            .to_str()
+                            .unwrap() // unwrap() might panic if it's not valid UTF-8
+                            .trim_end_matches(".app")
+                            .to_lowercase(); // to_lowercase returns a String, which is owned
+
+                        let search_word_ref = search_word.as_str(); // Get a reference to the string slice
+                        applications.insert(search_word_ref, file_path_str.clone());
                         if let Some(icon_path) = extract_icon_from_app_bundle(&file_path, &app_data_folder) {
                             // dbg!("Icon found for:", &file_path,&icon_path);
-                            icons.insert(file_path.to_string_lossy().to_string(), icon_path);
+                            icons.insert(file_path_str, icon_path);
                         } else {
                             dbg!("No icon found for:", &file_path);
                         }
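The constructor now builds a fuzzy_prefix_search::Trie up front instead of caching directory strings for rust_search. Below is a standalone sketch of the trie round-trip, using only the calls that appear in this diff (Trie::new, insert, and search_within_distance_scored); the result fields word, data, and score are assumptions inferred from their usage in the query path further down:

use fuzzy_prefix_search::Trie;

fn main() {
    // Build the trie once at startup, mapping lowercased app names to
    // bundle paths, as the constructor above does.
    let mut applications: Trie<String> = Trie::new();
    applications.insert("safari", "/Applications/Safari.app".to_string());
    applications.insert("slack", "/Applications/Slack.app".to_string());

    // Typo-tolerant lookup with an edit-distance budget of 3, matching
    // the query path below.
    let results = applications.search_within_distance_scored("safri", 3);
    for result in results {
        // `word` is the matched key, `data` the values stored under it,
        // `score` a closeness measure (field names assumed from the diff).
        println!("{} -> {:?} (score {})", result.word, result.data, result.score);
    }
}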
@@ -232,7 +243,7 @@ impl ApplicationSearchSource {
             base_score,
             app_dirs,
             icons,
-            search_locations, // Cached search locations
+            application_paths: applications,
         }
     }
 }
@@ -249,9 +260,9 @@ fn clean_app_name(path: &Path) -> Option<String> {
 impl SearchSource for ApplicationSearchSource {
     fn get_type(&self) -> QuerySource {
         QuerySource {
-            r#type: "Local".into(),
+            r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
             name: hostname::get().unwrap_or("My Computer".into()).to_string_lossy().into(),
-            id: "local_app_1".into(),
+            id: "local_applications".into(),
         }
     }
@@ -270,54 +281,50 @@ impl SearchSource for ApplicationSearchSource {
             });
         }
 
-        // Use cached search locations directly
-        if self.search_locations.is_empty() {
-            return Ok(QueryResponse {
-                source: self.get_type(),
-                hits: Vec::new(),
-                total_hits: 0,
-            });
-        }
-        let more_locations = self.search_locations[1..].to_vec();
-
-        // Use rust_search to find matching .app files
-        let results = SearchBuilder::default()
-            .search_input(&query_string)
-            .location(&self.search_locations[0]) // First location
-            .more_locations(more_locations) // Remaining locations
-            .depth(3) // Set search depth
-            .ext("app") // Only look for .app files
-            .limit(query.size as usize) // Limit results
-            .ignore_case()
-            .build()
-            .collect::<HashSet<String>>();
-
-        let mut total_hits = results.len();
+        let mut total_hits = 0;
         let mut hits = Vec::new();
 
-        for path in results {
-            let file_name_str = clean_app_name(Path::new(&path)).unwrap_or_else(|| path.clone());
+        let mut results = self.application_paths.search_within_distance_scored(&query_string, 3);
 
-            let mut doc = Document::new(
-                Some(DataSourceReference {
-                    r#type: Some("Local".into()),
-                    name: Some(path.clone()),
-                    id: Some(file_name_str.clone()),
-                }),
-                path.clone(),
-                "Application".to_string(),
-                file_name_str.clone(),
-                path.clone(),
-            );
-
-            // Attach icon if available
-            if let Some(icon_path) = self.icons.get(&path) {
-                if let Ok(icon_data) = read_icon_and_encode(icon_path) {
-                    doc.icon = Some(format!("data:image/png;base64,{}", icon_data));
-                }
+        // Check for NaN or extreme score values and handle them properly
+        results.sort_by(|a, b| {
+            // If either score is NaN, consider them equal (you can customize this logic as needed)
+            if a.score.is_nan() || b.score.is_nan() {
+                std::cmp::Ordering::Equal
+            } else {
+                // Otherwise, compare the scores as usual
+                b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)
             }
+        });
 
-            hits.push((doc, self.base_score));
+        if !results.is_empty() {
+            for result in results {
+                let file_name_str = result.word;
+                let file_path_str = result.data.get(0).unwrap().to_string();
+                let file_path = PathBuf::from(file_path_str.clone());
+                let cleaned_file_name = clean_app_name(&file_path).unwrap();
+                total_hits += 1;
+                let mut doc = Document::new(
+                    Some(DataSourceReference {
+                        r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
+                        name: Some("Applications".into()),
+                        id: Some(file_name_str.clone()),
+                    }),
+                    file_path_str.clone(),
+                    "Application".to_string(),
+                    cleaned_file_name,
+                    file_path_str.clone(),
+                );
+
+                // Attach icon if available
+                if let Some(icon_path) = self.icons.get(file_path_str.as_str()) {
+                    if let Ok(icon_data) = read_icon_and_encode(icon_path) {
+                        doc.icon = Some(format!("data:image/png;base64,{}", icon_data));
+                    }
+                }
+
+                hits.push((doc, self.base_score + result.score as f64));
+            }
         }
 
         Ok(QueryResponse {
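Because f64 is only PartialOrd, a plain sort_by over fuzzy-match scores needs the NaN guard used above. Here is a minimal standalone sketch of the same descending sort, plus std's total_cmp as an alternative that needs no guard (stable since Rust 1.62):

use std::cmp::Ordering;

fn main() {
    let mut scores: Vec<f64> = vec![0.8, f64::NAN, 0.95, 0.2];

    // Descending sort that treats any comparison involving NaN as "equal",
    // the same guard used in query() above.
    scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
    println!("{:?}", scores);

    // Alternative: total_cmp imposes a total order on all floats, so no
    // guard is needed (a positive NaN sorts ahead of every number here).
    scores.sort_by(|a, b| b.total_cmp(a));
    println!("{:?}", scores);
}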
@@ -335,4 +342,4 @@ fn read_icon_and_encode(icon_path: &Path) -> Result<String, std::io::Error> {
 
     // Encode the data to base64
     Ok(encode(&icon_data))
-}
+}
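read_icon_and_encode produces the payload for the data: URLs attached to each hit above. A sketch of that step in isolation, assuming the base64 0.13 crate pinned in Cargo.toml (its top-level encode accepts any AsRef<[u8]>); the file path is hypothetical:

use std::fs;
use std::io;
use std::path::Path;

// Read a PNG from disk and wrap it in a data URL, as query() does for
// each hit's icon.
fn icon_data_url(icon_path: &Path) -> Result<String, io::Error> {
    let icon_data = fs::read(icon_path)?;
    Ok(format!("data:image/png;base64,{}", base64::encode(&icon_data)))
}

fn main() -> Result<(), io::Error> {
    // Hypothetical path, for illustration only.
    let url = icon_data_url(Path::new("/tmp/example.png"))?;
    println!("data URL is {} bytes", url.len());
    Ok(())
}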
src-tauri/src/local/mod.rs

@@ -1,2 +1,4 @@
 pub mod application;
 pub mod file_system;
+
+pub const LOCAL_QUERY_SOURCE_TYPE: &str = "local";
Search command module (defines the query_coco_fusion Tauri command):

@@ -1,11 +1,10 @@
 use crate::common::register::SearchSourceRegistry;
-use crate::common::search::{FailedRequest, MultiSourceQueryResponse, QuerySource, SearchQuery};
+use crate::common::search::{FailedRequest, MultiSourceQueryResponse, QueryHits, QuerySource, SearchQuery};
 use crate::common::traits::{SearchError, SearchSource};
 use futures::stream::FuturesUnordered;
 use futures::StreamExt;
 use std::collections::HashMap;
 use tauri::{AppHandle, Manager, Runtime};
 
 #[tauri::command]
 pub async fn query_coco_fusion<R: Runtime>(
     app_handle: AppHandle<R>,
@@ -15,11 +14,11 @@ pub async fn query_coco_fusion<R: Runtime>(
 ) -> Result<MultiSourceQueryResponse, SearchError> {
     let search_sources = app_handle.state::<SearchSourceRegistry>();
 
-    let sources_future = search_sources.get_sources(); // Don't await yet
+    let sources_future = search_sources.get_sources();
     let mut futures = FuturesUnordered::new();
     let mut sources = HashMap::new();
 
-    let sources_list = sources_future.await; // Now we await
+    let sources_list = sources_future.await;
 
     for query_source in sources_list {
         let query_source_type = query_source.get_type().clone();
@@ -33,17 +32,29 @@ pub async fn query_coco_fusion<R: Runtime>(
         }));
     }
 
-    let mut docs_collector = crate::server::search::DocumentsSizedCollector::new(size);
     let mut total_hits = 0;
     let mut failed_requests = Vec::new();
+    let mut all_hits: Vec<(String, QueryHits, f64)> = Vec::new();
+    let mut hits_per_source: HashMap<String, Vec<(QueryHits, f64)>> = HashMap::new();
 
     while let Some(result) = futures.next().await {
         match result {
             Ok(Ok(response)) => {
                 total_hits += response.total_hits;
+                let source_id = response.source.id.clone();
+
                 for (doc, score) in response.hits {
-                    // dbg!("Found hit:", &doc.title, &score);
-                    docs_collector.push(response.source.id.clone(), doc, score);
+                    let query_hit = QueryHits {
+                        source: Some(response.source.clone()),
+                        score,
+                        document: doc,
+                    };
+
+                    all_hits.push((source_id.clone(), query_hit.clone(), score));
+
+                    hits_per_source.entry(source_id.clone())
+                        .or_insert_with(Vec::new)
+                        .push((query_hit, score));
                 }
             }
             Ok(Err(err)) => {
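The command drains a FuturesUnordered so responses are processed in completion order rather than submission order: a slow Coco server cannot stall hits from the local source. A self-contained sketch of that pattern (source names are illustrative; futures::executor::block_on stands in for the Tauri async runtime):

use futures::stream::{FuturesUnordered, StreamExt};

fn main() {
    let mut futures = FuturesUnordered::new();
    // One future per search source; illustrative names only.
    for source_id in ["local_applications", "coco_server_1", "coco_server_2"] {
        futures.push(async move {
            // A real source would run its query here.
            (source_id, 42usize) // (source, number of hits)
        });
    }

    futures::executor::block_on(async {
        // Results arrive in completion order, not submission order.
        while let Some((source, hits)) = futures.next().await {
            println!("{source}: {hits} hits");
        }
    });
}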
@@ -73,13 +84,57 @@ pub async fn query_coco_fusion<R: Runtime>(
         }
     }
 
-    let all_hits = docs_collector.documents_with_sources(&sources);
+    // Sort hits within each source by score (descending)
+    for hits in hits_per_source.values_mut() {
+        hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+    }
 
-    // dbg!(&all_hits);
+    let total_sources = hits_per_source.len();
+    let max_hits_per_source = if total_sources > 0 { size as usize / total_sources } else { size as usize };
+
+    let mut final_hits = Vec::new();
+    let mut seen_docs = std::collections::HashSet::new(); // To track documents we've already added
+
+    // Distribute hits fairly across sources
+    for (_source_id, hits) in &mut hits_per_source {
+        let take_count = hits.len().min(max_hits_per_source);
+        for (doc, _) in hits.drain(0..take_count) {
+            if !seen_docs.contains(&doc.document.id) {
+                seen_docs.insert(doc.document.id.clone());
+                final_hits.push(doc);
+            }
+        }
+    }
+
+    // If we still need more hits, take the highest-scoring remaining ones
+    if final_hits.len() < size as usize {
+        let remaining_needed = size as usize - final_hits.len();
+
+        // Sort all hits by score descending, removing duplicates by document ID
+        all_hits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
+
+        let extra_hits = all_hits.into_iter()
+            .filter(|(source_id, _, _)| hits_per_source.contains_key(source_id)) // Only take from known sources
+            .filter_map(|(_, doc, _)| {
+                if !seen_docs.contains(&doc.document.id) {
+                    seen_docs.insert(doc.document.id.clone());
+                    Some(doc)
+                } else {
+                    None
+                }
+            })
+            .take(remaining_needed)
+            .collect::<Vec<_>>();
+
+        final_hits.extend(extra_hits);
+    }
+
+    // Sort final hits by score descending
+    final_hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
 
     Ok(MultiSourceQueryResponse {
         failed: failed_requests,
-        hits: all_hits,
+        hits: final_hits,
         total_hits,
     })
 }
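Note the per-source quota above uses integer division, so size / total_sources rounds down; the top-up pass afterwards is what fills the remainder from the best leftover hits. A standalone sketch of the allocation arithmetic with illustrative data:

use std::collections::HashMap;

fn main() {
    let size = 10usize; // requested page size
    let mut hits_per_source: HashMap<&str, Vec<(&str, f64)>> = HashMap::new();
    hits_per_source.insert("apps", vec![("Safari", 3.0), ("Slack", 2.5)]);
    hits_per_source.insert("server", vec![("doc-a", 2.8), ("doc-b", 1.9)]);

    // Equal share per source; integer division rounds down.
    let max_per_source = size / hits_per_source.len().max(1);

    let mut final_hits: Vec<(&str, f64)> = Vec::new();
    for hits in hits_per_source.values_mut() {
        let take = hits.len().min(max_per_source);
        final_hits.extend(hits.drain(0..take));
    }

    // Global descending re-sort, mirroring the final sort above.
    final_hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    println!("{final_hits:?}");
}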
src-tauri/src/server/search.rs

@@ -1,4 +1,5 @@
 use crate::common::document::Document;
+use crate::common::search::{parse_search_response, QueryHits, QueryResponse, QuerySource, SearchQuery};
 use crate::common::server::Server;
 use crate::common::traits::{SearchError, SearchSource};
 use crate::server::http_client::HttpClient;
@@ -9,8 +10,6 @@ use ordered_float::OrderedFloat;
 use reqwest::{Client, Method, RequestBuilder};
 use std::collections::HashMap;
 use std::hash::Hash;
 use std::pin::Pin;
-use crate::common::search::{parse_search_response, QueryHits, QueryResponse, QuerySource, SearchQuery};
-
 pub(crate) struct DocumentsSizedCollector {
     size: u64,
     /// Documents and scores
@@ -50,13 +49,14 @@ impl DocumentsSizedCollector {
     pub(crate) fn documents_with_sources(self, x: &HashMap<String, QuerySource>) -> Vec<QueryHits> {
         let mut grouped_docs: Vec<QueryHits> = Vec::new();
 
-        for (source_id, doc, _) in self.docs.into_iter() {
+        for (source_id, doc, score) in self.docs.into_iter() {
             // Try to get the source from the hashmap
             let source = x.get(&source_id).cloned();
 
             // Push the document and source into the result
             grouped_docs.push(QueryHits {
                 source,
+                score: score.into_inner(),
                 document: doc,
             });
         }
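The collector stores scores as OrderedFloat (note the ordered_float import above) so scored entries can live in ordered containers; f64 itself is not Ord because NaN breaks totality, and into_inner() recovers the raw f64 when building QueryHits. A minimal sketch of that design choice:

use ordered_float::OrderedFloat;
use std::collections::BTreeSet;

fn main() {
    // OrderedFloat<f64> implements Ord, so (score, id) pairs can be kept
    // sorted in a BTreeSet, unlike bare f64 keys.
    let mut scored: BTreeSet<(OrderedFloat<f64>, &str)> = BTreeSet::new();
    scored.insert((OrderedFloat(0.7), "doc-a"));
    scored.insert((OrderedFloat(1.2), "doc-b"));
    scored.insert((OrderedFloat(0.9), "doc-c"));

    // Iterate from highest to lowest score; into_inner() returns the raw f64.
    for (score, id) in scored.iter().rev() {
        println!("{id}: {}", score.into_inner());
    }
}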
@@ -95,8 +95,6 @@ impl CocoSearchSource {
             .query(query_strings)
     }
 }
-use futures::future::join_all;
-use std::sync::Arc;
 
 #[async_trait]
 impl SearchSource for CocoSearchSource {