Mirror of https://github.com/infinilabs/coco-app.git
refactor: refactoring search response (#119)
* fix: fix paging with from
* refactor: refactoring search response
src-tauri/src/common/document.rs (new file, 56 lines)
@@ -0,0 +1,56 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Serialize, Deserialize)]
pub struct RichLabel {
    pub label: Option<String>,
    pub key: Option<String>,
    pub icon: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct DataSourceReference {
    pub r#type: Option<String>,
    pub name: Option<String>,
    pub id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct UserInfo {
    pub avatar: Option<String>,
    pub username: Option<String>,
    pub userid: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct EditorInfo {
    pub user: UserInfo,
    pub timestamp: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Document {
    pub id: String,
    pub created: Option<String>,
    pub updated: Option<String>,
    pub source: Option<DataSourceReference>,
    pub r#type: Option<String>,
    pub category: Option<String>,
    pub subcategory: Option<String>,
    pub categories: Option<Vec<String>>,
    pub rich_categories: Option<Vec<RichLabel>>,
    pub title: Option<String>,
    pub summary: Option<String>,
    pub lang: Option<String>,
    pub content: Option<String>,
    pub icon: Option<String>,
    pub thumbnail: Option<String>,
    pub cover: Option<String>,
    pub tags: Option<Vec<String>>,
    pub url: Option<String>,
    pub size: Option<i64>,
    pub metadata: Option<HashMap<String, serde_json::Value>>,
    pub payload: Option<HashMap<String, serde_json::Value>>,
    pub owner: Option<UserInfo>,
    pub last_updated_by: Option<EditorInfo>,
}
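Every field on Document except `id` is an Option, so sparse payloads from different connectors deserialize cleanly. A hypothetical unit test (not part of this commit; it assumes serde and serde_json as dependencies, which the derives already imply) illustrating that behavior, as it would sit at the bottom of document.rs:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_document_deserializes() {
        // Only `id` is required; every absent Option field becomes None.
        let json = r#"{
            "id": "doc-42",
            "title": "Getting started",
            "source": { "type": "google_drive", "name": "Drive", "id": "ds-1" },
            "tags": ["guide"]
        }"#;

        let doc: Document = serde_json::from_str(json).expect("valid document");
        assert_eq!(doc.id, "doc-42");
        assert_eq!(doc.title.as_deref(), Some("Getting started"));
        assert!(doc.summary.is_none()); // absent Option fields default to None
    }
}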
src-tauri/src/common/mod.rs
@@ -5,3 +5,4 @@ pub mod auth;
 pub mod datasource;
 pub mod connector;
 pub mod search_response;
+pub mod document;
src-tauri/src/common/search_response.rs
@@ -40,41 +40,37 @@ pub struct SearchHit<T> {
     pub _score: Option<f32>,
     pub _source: T, // This will hold the type we pass in (e.g., DataSource)
 }
 
-pub async fn parse_search_results<T>(response: Response) -> Result<Vec<T>, Box<dyn Error>>
+pub async fn parse_search_hits<T>(
+    response: Response,
+) -> Result<Vec<SearchHit<T>>, Box<dyn Error>>
 where
     T: for<'de> Deserialize<'de> + std::fmt::Debug,
 {
     // Log the response status and headers
     // dbg!(&response.status());
     // dbg!(&response.headers());
 
     // Parse the response body to a serde::Value
     let body = response
         .json::<Value>()
         .await
         .map_err(|e| format!("Failed to parse JSON: {}", e))?;
 
     // Log the raw body before further processing
     // dbg!(&body);
 
     // Deserialize into the generic search response
     let search_response: SearchResponse<T> = serde_json::from_value(body)
         .map_err(|e| format!("Failed to deserialize search response: {}", e))?;
 
     // Log the deserialized search response
     // dbg!(&search_response);
+    Ok(search_response.hits.hits)
+}
 
-    // Collect the _source part from all hits
-    let results: Vec<T> = search_response
-        .hits
-        .hits
-        .into_iter()
-        .map(|hit| hit._source)
-        .collect();
+pub async fn parse_search_results<T>(
+    response: Response,
+) -> Result<Vec<T>, Box<dyn Error>>
+where
+    T: for<'de> Deserialize<'de> + std::fmt::Debug,
+{
+    Ok(parse_search_hits(response).await?.into_iter().map(|hit| hit._source).collect())
+}
 
-    // Log the final results before returning
-    // dbg!(&results);
-
-    Ok(results)
+pub async fn parse_search_results_with_score<T>(
+    response: Response,
+) -> Result<Vec<(T, Option<f32>)>, Box<dyn Error>>
+where
+    T: for<'de> Deserialize<'de> + std::fmt::Debug,
+{
+    Ok(parse_search_hits(response).await?.into_iter().map(|hit| (hit._source, hit._score)).collect())
+}
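All three helpers assume the same Elasticsearch-style envelope: documents under hits.hits[]._source and scores under _score. A self-contained sketch (mirror types and a made-up Doc, not the crate's actual SearchResponse) showing how the (T, Option<f32>) pairs returned by parse_search_results_with_score fall out of that shape:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct SearchHit<T> {
    _score: Option<f32>,
    _source: T,
}

#[derive(Debug, Deserialize)]
struct HitsWrapper<T> {
    hits: Vec<SearchHit<T>>,
}

#[derive(Debug, Deserialize)]
struct SearchResponse<T> {
    hits: HitsWrapper<T>,
}

#[derive(Debug, Deserialize)]
struct Doc {
    id: String,
    title: Option<String>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // serde ignores unknown fields (like hits.total) by default.
    let body = r#"{
        "hits": {
            "total": { "value": 1 },
            "hits": [
                { "_score": 1.5, "_source": { "id": "doc-1", "title": "hello" } }
            ]
        }
    }"#;

    let parsed: SearchResponse<Doc> = serde_json::from_str(body)?;
    let pairs: Vec<(Doc, Option<f32>)> = parsed
        .hits
        .hits
        .into_iter()
        .map(|hit| (hit._source, hit._score))
        .collect();

    assert_eq!(pairs[0].0.id, "doc-1");
    assert_eq!(pairs[0].1, Some(1.5));
    Ok(())
}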
@@ -126,7 +126,7 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, Stri
     })?;
 
     // Log the parsed results
-    dbg!("Parsed connectors: {:?}", &datasources);
+    // dbg!("Parsed connectors: {:?}", &datasources);
 
     // Save the connectors to the cache
     save_connectors_to_cache(&id, datasources.clone());
@@ -1,75 +1,103 @@
 use std::collections::HashMap;
 use std::hash::Hash;
 use ordered_float::OrderedFloat;
 use reqwest::Method;
 use serde::Serialize;
-use tauri::{ AppHandle, Runtime};
+use tauri::{AppHandle, Runtime};
 use serde_json::Map as JsonMap;
 use serde_json::Value as Json;
 use crate::server::http_client::{HttpClient, HTTP_CLIENT};
 use crate::server::servers::{get_all_servers, get_server_token, get_servers_as_hashmap};
 use futures::stream::{FuturesUnordered, StreamExt};
+use crate::common::document::Document;
+use crate::common::search_response::parse_search_results_with_score;
+use crate::common::server::Server;
 
 struct DocumentsSizedCollector {
     size: u64,
-    /// Documents and socres
+    /// Documents and scores
     ///
-    /// Sorted by score, in descending order.
-    docs: Vec<(JsonMap<String, Json>, OrderedFloat<f64>)>,
+    /// Sorted by score, in descending order. (Server ID, Document, Score)
+    docs: Vec<(String, Document, OrderedFloat<f64>)>,
 }
 
 impl DocumentsSizedCollector {
     fn new(size: u64) -> Self {
         // there will be size + 1 documents in docs at max
-        let docs = Vec::with_capacity((size + 1).try_into().expect("overflow"));
+        let docs = Vec::with_capacity((size + 1) as usize);
 
         Self { size, docs }
     }
 
-    fn push(&mut self, item: JsonMap<String, Json>, score: f64) {
+    fn push(&mut self, server_id: String, item: Document, score: f64) {
         let score = OrderedFloat(score);
-        let insert_idx = match self.docs.binary_search_by(|(_doc, s)| score.cmp(s)) {
+        let insert_idx = match self.docs.binary_search_by(|(_, _, s)| score.cmp(s)) {
             Ok(idx) => idx,
             Err(idx) => idx,
         };
 
-        self.docs.insert(insert_idx, (item, score));
+        self.docs.insert(insert_idx, (server_id, item, score));
 
-        // cast usize to u64 is safe
+        // Ensure we do not exceed `size`
        if self.docs.len() as u64 > self.size {
-            self.docs.truncate(self.size.try_into().expect(
-                "self.size < a number of type usize, it can be expressed using usize, we are safe",
-            ));
+            self.docs.truncate(self.size as usize);
         }
     }
 
-    fn documents(self) -> impl ExactSizeIterator<Item = JsonMap<String, Json>> {
-        self.docs.into_iter().map(|(doc, _score)| doc)
+    fn documents(self) -> impl ExactSizeIterator<Item = Document> {
+        self.docs.into_iter().map(|(_, doc, _)| doc)
     }
 
+    // New function to return documents grouped by server_id
+    fn documents_by_server_id(self, x: &HashMap<String, Server>) -> Vec<QueryHits> {
+        let mut grouped_docs: Vec<QueryHits> = Vec::new();
+
+        for (server_id, doc, _) in self.docs.into_iter() {
+            let source = QuerySource {
+                r#type: Some("coco-server".to_string()),
+                name: Some(x.get(&server_id).map(|s| s.name.clone()).unwrap_or_default()),
+                id: Some(server_id.clone()),
+            };
+
+            grouped_docs.push(QueryHits {
+                source,
+                document: doc,
+            });
+        }
+
+        grouped_docs
+    }
 }
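DocumentsSizedCollector is a top-k merge buffer: binary-search insertion keeps the vector sorted by score descending, and truncation caps it at `size`. A generic, self-contained sketch of the same technique (using ordered_float, as above; the types here are illustrative, not the crate's):

use ordered_float::OrderedFloat;

struct SizedCollector<T> {
    size: usize,
    items: Vec<(T, OrderedFloat<f64>)>, // sorted by score, descending
}

impl<T> SizedCollector<T> {
    fn new(size: usize) -> Self {
        // at most size + 1 items live here before truncation
        Self { size, items: Vec::with_capacity(size + 1) }
    }

    fn push(&mut self, item: T, score: f64) {
        let score = OrderedFloat(score);
        // Reversed comparison (score.cmp(s), not s.cmp(&score)) makes
        // binary_search_by treat the vector as descending.
        let idx = match self.items.binary_search_by(|(_, s)| score.cmp(s)) {
            Ok(idx) | Err(idx) => idx,
        };
        self.items.insert(idx, (item, score));
        self.items.truncate(self.size); // drop the lowest-scored overflow
    }
}

fn main() {
    let mut c = SizedCollector::new(3);
    for i in 0..10 {
        c.push(i, i as f64);
    }
    let scores: Vec<f64> = c.items.iter().map(|(_, s)| s.0).collect();
    assert_eq!(scores, vec![9.0, 8.0, 7.0]); // top 3 by score survive
}

This mirrors what the removed test module (dropped later in this diff) asserted against the old two-tuple collector.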
 
+#[derive(Debug, Serialize)]
+pub struct QuerySource {
+    pub r#type: Option<String>, // coco-server/local/ etc.
+    pub name: Option<String>,   // coco server's name, local computer name, etc.
+    pub id: Option<String>,     // coco server's id
+}
+
+#[derive(Debug, Serialize)]
+pub struct QueryHits {
+    pub source: QuerySource,
+    pub document: Document,
+}
+
+#[derive(Debug, Serialize)]
+pub struct FailedRequest {
+    pub source: QuerySource,
+    pub status: u16,
+    pub error: Option<String>,
+    pub reason: Option<String>,
+}
+
 #[derive(Debug, Serialize)]
 pub struct QueryResponse {
-    failed_coco_servers: Vec<String>,
-    documents: Vec<JsonMap<String, Json>>,
-    total_hits: u64,
+    failed: Vec<FailedRequest>,
+    hits: Vec<QueryHits>,
+    total_hits: usize,
 }
 
 fn get_name(provider_info: &JsonMap<String, Json>) -> &str {
     provider_info
         .get("name")
         .expect("provider info does not have a [name] field")
         .as_str()
         .expect("field [name] should be a string")
 }
 
 fn get_public(provider_info: &JsonMap<String, Json>) -> bool {
     provider_info
         .get("public")
         .expect("provider info does not have a [public] field")
         .as_bool()
         .expect("field [public] should be a string")
 }
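For reference, this is roughly the JSON the refactored QueryResponse serializes to, mirroring the commented-out to_string_pretty call near the end of this diff; the values below are hypothetical:

use serde_json::json;

fn main() {
    // Mirror of the QueryResponse shape: failed[], hits[], total_hits.
    let response = json!({
        "failed": [
            { "source": { "type": "coco-server", "name": "My Coco", "id": "srv-1" },
              "status": 503, "error": "connection refused", "reason": null }
        ],
        "hits": [
            { "source": { "type": "coco-server", "name": "Team Coco", "id": "srv-2" },
              "document": { "id": "doc-1", "title": "Quarterly plan" } }
        ],
        "total_hits": 1
    });
    println!("{}", serde_json::to_string_pretty(&response).unwrap());
}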
 #[tauri::command]
 pub async fn query_coco_servers<R: Runtime>(
     app_handle: AppHandle<R>,
@@ -86,7 +114,7 @@ pub async fn query_coco_servers<R: Runtime>(
     let mut futures = FuturesUnordered::new();
     let size_for_each_request = (from + size).to_string();
 
-    for (_,server) in coco_servers {
+    for (_, server) in &coco_servers {
         let url = HttpClient::join_url(&server.endpoint, "/query/_search");
         let client = HTTP_CLIENT.lock().await; // Acquire the lock on HTTP_CLIENT
         let mut request_builder = client.request(Method::GET, url);
@@ -98,85 +126,76 @@ pub async fn query_coco_servers<R: Runtime>(
         }
         let query_strings_cloned = query_strings.clone(); // Clone for each iteration
 
-        let size=size_for_each_request.clone();
+        let from = from.to_string();
+        let size = size_for_each_request.clone();
         let future = async move {
             let response = request_builder
-                .query(&[("from", "0"), ("size", size.as_str())])
+                .query(&[("from", from.as_str()), ("size", size.as_str())])
                 .query(&query_strings_cloned) // Use cloned instance
                 .send()
                 .await;
-            (server.id, response)
+            (server.id.clone(), response)
         };
 
         futures.push(future);
     }
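The fan-out above pushes one future per server into FuturesUnordered and drains them in completion order; the paging fix means each request now forwards the caller's `from` instead of hardcoding "0", while still asking each server for `from + size` documents before the client-side merge trims to `size`. A runnable sketch of the pattern (tokio runtime, hypothetical server IDs; query_server stands in for the real HTTP call):

use futures::stream::{FuturesUnordered, StreamExt};

async fn query_server(server_id: String, from: usize, size: usize) -> (String, Result<usize, String>) {
    // Stand-in for request_builder.query(&[("from", ...), ("size", ...)]).send()
    let requested = from + size; // same enlarged per-server window as the diff
    (server_id, Ok(requested))
}

#[tokio::main]
async fn main() {
    let servers = vec!["server-a".to_string(), "server-b".to_string()];
    let (from, size) = (10usize, 10usize);

    let mut futures = FuturesUnordered::new();
    for id in servers {
        futures.push(query_server(id, from, size));
    }

    // next().await yields results in completion order, not submission order.
    while let Some((id, result)) = futures.next().await {
        match result {
            Ok(n) => println!("{id}: requested {n} docs"),
            Err(e) => eprintln!("{id}: failed: {e}"),
        }
    }
}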
 
     let mut total_hits = 0;
-    let mut failed_coco_servers = Vec::new();
+    let mut failed_requests: Vec<FailedRequest> = Vec::new();
     let mut docs_collector = DocumentsSizedCollector::new(size);
 
-    while let Some((name, res_response)) = futures.next().await {
+    // Helper function to create failed request
+    fn create_failed_request(server_id: &str, coco_servers: &HashMap<String, Server>, error: &str, status: u16) -> FailedRequest {
+        FailedRequest {
+            source: QuerySource {
+                r#type: Some("coco-server".to_string()),
+                name: Some(coco_servers.get(server_id).map(|s| s.name.clone()).unwrap_or_default()),
+                id: Some(server_id.to_string()),
+            },
+            status,
+            error: Some(error.to_string()),
+            reason: None,
+        }
+    }
+
+    // Iterate over the stream of futures
+    while let Some((server_id, res_response)) = futures.next().await {
         match res_response {
             Ok(response) => {
-                if let Ok(mut body) = response.json::<JsonMap<String, Json>>().await {
-                    if let Some(Json::Object(mut hits)) = body.remove("hits") {
-                        if let Some(Json::Number(hits_total_value)) = hits.get("total").and_then(|t| t.get("value")) {
-                            if let Some(hits_total) = hits_total_value.as_u64() {
-                                total_hits += hits_total;
-                            }
-                        }
-                        if let Some(Json::Array(hits_hits)) = hits.remove("hits") {
-                            for hit in hits_hits.into_iter().filter_map(|h| h.as_object().cloned()) {
-                                if let (Some(Json::Number(score)), Some(Json::Object(source))) = (hit.get("_score"), hit.get("_source")) {
-                                    if let Some(score_value) = score.as_f64() {
-                                        docs_collector.push(source.clone(), score_value);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
+                let status_code = response.status().as_u16();
+                match parse_search_results_with_score(response).await {
+                    Ok(documents) => {
+                        total_hits += documents.len(); // No need for `&` here, as `len` is `usize`
+                        for (doc, score) in documents {
+                            let score = score.unwrap_or(0.0) as f64;
+                            docs_collector.push(server_id.clone(), doc, score);
+                        }
+                    }
+                    Err(err) => {
+                        failed_requests.push(create_failed_request(&server_id, &coco_servers, &err.to_string(), status_code));
+                    }
+                }
             }
-            Err(_) => failed_coco_servers.push(name),
+            Err(err) => {
+                failed_requests.push(create_failed_request(&server_id, &coco_servers, &err.to_string(), 0));
+            }
         }
     }
 
-    let docs=docs_collector.documents().collect();
+    let docs = docs_collector.documents_by_server_id(&coco_servers);
 
     // dbg!(&total_hits);
-    // dbg!(&failed_coco_servers);
+    // dbg!(&failed_requests);
     // dbg!(&docs);
 
-    Ok(QueryResponse {
-        failed_coco_servers,
-        total_hits,
-        documents: docs,
-    })
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_docs_collector() {
-        let mut collector = DocumentsSizedCollector::new(3);
-
-        for i in 0..10 {
-            collector.push(JsonMap::new(), i as f64);
-        }
-
-        assert_eq!(collector.docs.len(), 3);
-        assert!(collector
-            .docs
-            .into_iter()
-            .map(|(_doc, score)| score)
-            .eq(vec![
-                OrderedFloat(9.0),
-                OrderedFloat(8.0),
-                OrderedFloat(7.0)
-            ]));
-    }
-}
+    let query_response = QueryResponse {
+        failed: failed_requests,
+        hits: docs,
+        total_hits,
+    };
+
+    // print to json
+    // println!("{}", serde_json::to_string_pretty(&query_response).unwrap());
+
+    Ok(query_response)
+}