refactor: refactoring search response (#119)

* fix: fix paging with from

* refactor: refactoring search response
Medcl
2025-02-06 17:09:04 +08:00
committed by GitHub
parent aa6d17e942
commit fb37da5f6c
8 changed files with 197 additions and 122 deletions

View File

@@ -0,0 +1,56 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct RichLabel {
+    pub label: Option<String>,
+    pub key: Option<String>,
+    pub icon: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct DataSourceReference {
+    pub r#type: Option<String>,
+    pub name: Option<String>,
+    pub id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct UserInfo {
+    pub avatar: Option<String>,
+    pub username: Option<String>,
+    pub userid: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EditorInfo {
+    pub user: UserInfo,
+    pub timestamp: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Document {
+    pub id: String,
+    pub created: Option<String>,
+    pub updated: Option<String>,
+    pub source: Option<DataSourceReference>,
+    pub r#type: Option<String>,
+    pub category: Option<String>,
+    pub subcategory: Option<String>,
+    pub categories: Option<Vec<String>>,
+    pub rich_categories: Option<Vec<RichLabel>>,
+    pub title: Option<String>,
+    pub summary: Option<String>,
+    pub lang: Option<String>,
+    pub content: Option<String>,
+    pub icon: Option<String>,
+    pub thumbnail: Option<String>,
+    pub cover: Option<String>,
+    pub tags: Option<Vec<String>>,
+    pub url: Option<String>,
+    pub size: Option<i64>,
+    pub metadata: Option<HashMap<String, serde_json::Value>>,
+    pub payload: Option<HashMap<String, serde_json::Value>>,
+    pub owner: Option<UserInfo>,
+    pub last_updated_by: Option<EditorInfo>,
+}
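
For orientation, a minimal sketch of how `Document` deserializes, assuming serde's default field-name mapping (the diff shows no `#[serde(rename)]` attributes; `r#type` maps to the JSON key "type", and all values below are illustrative):

    fn parse_example() -> Result<(), serde_json::Error> {
        let json = r#"{
            "id": "doc-1",
            "title": "Quarterly report",
            "source": { "type": "google_drive", "name": "Drive", "id": "ds-1" },
            "tags": ["finance", "2025"]
        }"#;
        let doc: Document = serde_json::from_str(json)?;
        assert_eq!(doc.title.as_deref(), Some("Quarterly report"));
        // Every field except `id` is an Option and defaults to None when absent.
        Ok(())
    }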

View File

@@ -5,3 +5,4 @@ pub mod auth;
 pub mod datasource;
 pub mod connector;
 pub mod search_response;
+pub mod document;

View File

@@ -40,41 +40,37 @@ pub struct SearchHit<T> {
     pub _score: Option<f32>,
     pub _source: T, // This will hold the type we pass in (e.g., DataSource)
 }
 
-pub async fn parse_search_results<T>(response: Response) -> Result<Vec<T>, Box<dyn Error>>
+pub async fn parse_search_hits<T>(
+    response: Response,
+) -> Result<Vec<SearchHit<T>>, Box<dyn Error>>
 where
     T: for<'de> Deserialize<'de> + std::fmt::Debug,
 {
-    // Log the response status and headers
-    // dbg!(&response.status());
-    // dbg!(&response.headers());
-    // Parse the response body to a serde::Value
     let body = response
         .json::<Value>()
         .await
         .map_err(|e| format!("Failed to parse JSON: {}", e))?;
-    // Log the raw body before further processing
-    // dbg!(&body);
-    // Deserialize into the generic search response
     let search_response: SearchResponse<T> = serde_json::from_value(body)
         .map_err(|e| format!("Failed to deserialize search response: {}", e))?;
-    // Log the deserialized search response
-    // dbg!(&search_response);
+    Ok(search_response.hits.hits)
+}
-    // Collect the _source part from all hits
-    let results: Vec<T> = search_response
-        .hits
-        .hits
-        .into_iter()
-        .map(|hit| hit._source)
-        .collect();
+
+pub async fn parse_search_results<T>(
+    response: Response,
+) -> Result<Vec<T>, Box<dyn Error>>
+where
+    T: for<'de> Deserialize<'de> + std::fmt::Debug,
+{
+    Ok(parse_search_hits(response).await?.into_iter().map(|hit| hit._source).collect())
-    // Log the final results before returning
-    // dbg!(&results);
-    Ok(results)
 }
+
+pub async fn parse_search_results_with_score<T>(
+    response: Response,
+) -> Result<Vec<(T, Option<f32>)>, Box<dyn Error>>
+where
+    T: for<'de> Deserialize<'de> + std::fmt::Debug,
+{
+    Ok(parse_search_hits(response).await?.into_iter().map(|hit| (hit._source, hit._score)).collect())
+}
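
The refactor layers the helpers: `parse_search_hits` keeps the raw `SearchHit`s, and the two projections reuse it. A hypothetical caller (the wrapper function and error handling here are illustrative, not part of this commit):

    // Sketch only: assumes `response` came from a /query/_search request.
    async fn demo(response: reqwest::Response) -> Result<(), Box<dyn std::error::Error>> {
        // Keep the scores when ranking matters:
        let scored: Vec<(Document, Option<f32>)> =
            parse_search_results_with_score(response).await?;
        for (doc, score) in &scored {
            println!("{:?} -> {:?}", doc.title, score);
        }
        Ok(())
    }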

View File

@@ -126,7 +126,7 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, String> {
     })?;
 
     // Log the parsed results
-    dbg!("Parsed connectors: {:?}", &datasources);
+    // dbg!("Parsed connectors: {:?}", &datasources);
 
     // Save the connectors to the cache
     save_connectors_to_cache(&id, datasources.clone());

View File

@@ -1,75 +1,103 @@
+use std::collections::HashMap;
+use std::hash::Hash;
 use ordered_float::OrderedFloat;
 use reqwest::Method;
 use serde::Serialize;
-use tauri::{ AppHandle, Runtime};
+use tauri::{AppHandle, Runtime};
 use serde_json::Map as JsonMap;
 use serde_json::Value as Json;
 use crate::server::http_client::{HttpClient, HTTP_CLIENT};
 use crate::server::servers::{get_all_servers, get_server_token, get_servers_as_hashmap};
 use futures::stream::{FuturesUnordered, StreamExt};
+use crate::common::document::Document;
+use crate::common::search_response::parse_search_results_with_score;
+use crate::common::server::Server;
 
 struct DocumentsSizedCollector {
     size: u64,
-    /// Documents and socres
+    /// Documents and scores
     ///
-    /// Sorted by score, in descending order.
-    docs: Vec<(JsonMap<String, Json>, OrderedFloat<f64>)>,
+    /// Sorted by score, in descending order. (Server ID, Document, Score)
+    docs: Vec<(String, Document, OrderedFloat<f64>)>,
 }
 
 impl DocumentsSizedCollector {
     fn new(size: u64) -> Self {
         // there will be size + 1 documents in docs at max
-        let docs = Vec::with_capacity((size + 1).try_into().expect("overflow"));
+        let docs = Vec::with_capacity((size + 1) as usize);
 
         Self { size, docs }
     }
 
-    fn push(&mut self, item: JsonMap<String, Json>, score: f64) {
+    fn push(&mut self, server_id: String, item: Document, score: f64) {
         let score = OrderedFloat(score);
-        let insert_idx = match self.docs.binary_search_by(|(_doc, s)| score.cmp(s)) {
+        let insert_idx = match self.docs.binary_search_by(|(_, _, s)| score.cmp(s)) {
             Ok(idx) => idx,
             Err(idx) => idx,
         };
-        self.docs.insert(insert_idx, (item, score));
+        self.docs.insert(insert_idx, (server_id, item, score));
 
-        // cast usize to u64 is safe
+        // Ensure we do not exceed `size`
         if self.docs.len() as u64 > self.size {
-            self.docs.truncate(self.size.try_into().expect(
-                "self.size < a number of type usize, it can be expressed using usize, we are safe",
-            ));
+            self.docs.truncate(self.size as usize);
         }
     }
 
-    fn documents(self) -> impl ExactSizeIterator<Item = JsonMap<String, Json>> {
-        self.docs.into_iter().map(|(doc, _score)| doc)
+    fn documents(self) -> impl ExactSizeIterator<Item = Document> {
+        self.docs.into_iter().map(|(_, doc, _)| doc)
     }
 
+    // New function to return documents grouped by server_id
+    fn documents_by_server_id(self, servers: &HashMap<String, Server>) -> Vec<QueryHits> {
+        let mut grouped_docs: Vec<QueryHits> = Vec::new();
+        for (server_id, doc, _) in self.docs.into_iter() {
+            let source = QuerySource {
+                r#type: Some("coco-server".to_string()),
+                name: Some(servers.get(&server_id).map(|s| s.name.clone()).unwrap_or_default()),
+                id: Some(server_id.clone()),
+            };
+            grouped_docs.push(QueryHits {
+                source,
+                document: doc,
+            });
+        }
+        grouped_docs
+    }
 }
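
The collector keeps at most `size` documents, ordered by descending score across all servers. The old unit test covering this (removed at the bottom of this diff) no longer compiles against the three-tuple layout; a rough equivalent would be the following, where `sample_doc()` is a hypothetical helper that builds a `Document` with only `id` set:

    fn collector_example() {
        let mut collector = DocumentsSizedCollector::new(3);
        for i in 0..10 {
            collector.push("server-1".to_string(), sample_doc(), i as f64);
        }
        // Only the three highest scores survive, in descending order.
        assert_eq!(collector.docs.len(), 3);
        assert!(collector.docs.iter().map(|(_, _, s)| *s).eq([
            OrderedFloat(9.0),
            OrderedFloat(8.0),
            OrderedFloat(7.0),
        ]));
    }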
 
+#[derive(Debug, Serialize)]
+pub struct QuerySource {
+    pub r#type: Option<String>, // coco-server/local/ etc.
+    pub name: Option<String>,   // coco server's name, local computer name, etc.
+    pub id: Option<String>,     // coco server's id
+}
+
+#[derive(Debug, Serialize)]
+pub struct QueryHits {
+    pub source: QuerySource,
+    pub document: Document,
+}
+
+#[derive(Debug, Serialize)]
+pub struct FailedRequest {
+    pub source: QuerySource,
+    pub status: u16,
+    pub error: Option<String>,
+    pub reason: Option<String>,
+}
+
 #[derive(Debug, Serialize)]
 pub struct QueryResponse {
-    failed_coco_servers: Vec<String>,
-    documents: Vec<JsonMap<String, Json>>,
-    total_hits: u64,
+    failed: Vec<FailedRequest>,
+    hits: Vec<QueryHits>,
+    total_hits: usize,
 }
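
Since these structs derive `Serialize` with serde's default field naming, the response handed back over Tauri IPC should take roughly this shape (illustrative values; `r#type` serializes as "type"):

    {
      "failed": [
        { "source": { "type": "coco-server", "name": "My Coco", "id": "abc" },
          "status": 0, "error": "connection refused", "reason": null }
      ],
      "hits": [
        { "source": { "type": "coco-server", "name": "My Coco", "id": "abc" },
          "document": { "id": "doc-1", "title": "Quarterly report" } }
      ],
      "total_hits": 2
    }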
 
 fn get_name(provider_info: &JsonMap<String, Json>) -> &str {
     provider_info
         .get("name")
         .expect("provider info does not have a [name] field")
         .as_str()
         .expect("field [name] should be a string")
 }
 
 fn get_public(provider_info: &JsonMap<String, Json>) -> bool {
     provider_info
         .get("public")
         .expect("provider info does not have a [public] field")
         .as_bool()
         .expect("field [public] should be a boolean")
 }
 
 #[tauri::command]
 pub async fn query_coco_servers<R: Runtime>(
     app_handle: AppHandle<R>,
@@ -86,7 +114,7 @@ pub async fn query_coco_servers<R: Runtime>(
     let mut futures = FuturesUnordered::new();
     let size_for_each_request = (from + size).to_string();
 
-    for (_,server) in coco_servers {
+    for (_, server) in &coco_servers {
         let url = HttpClient::join_url(&server.endpoint, "/query/_search");
         let client = HTTP_CLIENT.lock().await; // Acquire the lock on HTTP_CLIENT
         let mut request_builder = client.request(Method::GET, url);
@@ -98,85 +126,76 @@ pub async fn query_coco_servers<R: Runtime>(
         }
 
         let query_strings_cloned = query_strings.clone(); // Clone for each iteration
-        let size=size_for_each_request.clone();
+        let from = from.to_string();
+        let size = size_for_each_request.clone();
 
         let future = async move {
             let response = request_builder
-                .query(&[("from", "0"), ("size", size.as_str())])
+                .query(&[("from", from.as_str()), ("size", size.as_str())])
                 .query(&query_strings_cloned) // Use cloned instance
                 .send()
                 .await;
-            (server.id, response)
+            (server.id.clone(), response)
         };
         futures.push(future);
     }
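
With the paging fix above, a call with from = 10 and size = 10 (illustrative values) now forwards ?from=10&size=20 to each server, where the old code always sent ?from=0&size=20; `size_for_each_request` remains `from + size`.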
 
     let mut total_hits = 0;
-    let mut failed_coco_servers = Vec::new();
+    let mut failed_requests: Vec<FailedRequest> = Vec::new();
     let mut docs_collector = DocumentsSizedCollector::new(size);
 
-    while let Some((name, res_response)) = futures.next().await {
+    // Helper function to create a failed request
+    fn create_failed_request(server_id: &str, coco_servers: &HashMap<String, Server>, error: &str, status: u16) -> FailedRequest {
+        FailedRequest {
+            source: QuerySource {
+                r#type: Some("coco-server".to_string()),
+                name: Some(coco_servers.get(server_id).map(|s| s.name.clone()).unwrap_or_default()),
+                id: Some(server_id.to_string()),
+            },
+            status,
+            error: Some(error.to_string()),
+            reason: None,
+        }
+    }
+
+    // Iterate over the stream of futures
+    while let Some((server_id, res_response)) = futures.next().await {
         match res_response {
             Ok(response) => {
-                if let Ok(mut body) = response.json::<JsonMap<String, Json>>().await {
-                    if let Some(Json::Object(mut hits)) = body.remove("hits") {
-                        if let Some(Json::Number(hits_total_value)) = hits.get("total").and_then(|t| t.get("value")) {
-                            if let Some(hits_total) = hits_total_value.as_u64() {
-                                total_hits += hits_total;
-                            }
-                        }
-                        if let Some(Json::Array(hits_hits)) = hits.remove("hits") {
-                            for hit in hits_hits.into_iter().filter_map(|h| h.as_object().cloned()) {
-                                if let (Some(Json::Number(score)), Some(Json::Object(source))) = (hit.get("_score"), hit.get("_source")) {
-                                    if let Some(score_value) = score.as_f64() {
-                                        docs_collector.push(source.clone(), score_value);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
+                let status_code = response.status().as_u16();
+                match parse_search_results_with_score(response).await {
+                    Ok(documents) => {
+                        total_hits += documents.len(); // No need for `&` here, as `len` is `usize`
+                        for (doc, score) in documents {
+                            let score = score.unwrap_or(0.0) as f64;
+                            docs_collector.push(server_id.clone(), doc, score);
+                        }
+                    }
+                    Err(err) => {
+                        failed_requests.push(create_failed_request(&server_id, &coco_servers, &err.to_string(), status_code));
+                    }
+                }
             }
-            Err(_) => failed_coco_servers.push(name),
+            Err(err) => {
+                failed_requests.push(create_failed_request(&server_id, &coco_servers, &err.to_string(), 0));
+            }
         }
     }
 
-    let docs = docs_collector.documents().collect();
+    let docs = docs_collector.documents_by_server_id(&coco_servers);
 
     // dbg!(&total_hits);
-    // dbg!(&failed_coco_servers);
+    // dbg!(&failed_requests);
     // dbg!(&docs);
 
-    Ok(QueryResponse {
-        failed_coco_servers,
+    let query_response = QueryResponse {
+        failed: failed_requests,
+        hits: docs,
         total_hits,
-        documents: docs,
-    })
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_docs_collector() {
-        let mut collector = DocumentsSizedCollector::new(3);
-        for i in 0..10 {
-            collector.push(JsonMap::new(), i as f64);
-        }
-
-        assert_eq!(collector.docs.len(), 3);
-        assert!(collector
-            .docs
-            .into_iter()
-            .map(|(_doc, score)| score)
-            .eq(vec![
-                OrderedFloat(9.0),
-                OrderedFloat(8.0),
-                OrderedFloat(7.0)
-            ]));
-    }
-}
+    };
+
+    // print to json
+    // println!("{}", serde_json::to_string_pretty(&query_response).unwrap());
+
+    Ok(query_response)
 }