refacotr: refactoring assistant api (#195)

* refacotr: refactoring assistant api

* update release notes
This commit is contained in:
Medcl
2025-02-25 15:01:32 +08:00
committed by GitHub
parent 0e645a32a3
commit 7c88e7374b
12 changed files with 270 additions and 36 deletions

View File

@@ -0,0 +1,169 @@
use crate::common::assistant::InitChatMessage;
use crate::common::http::GetResponse;
use crate::server::http_client::HttpClient;
use reqwest::Response;
use std::collections::HashMap;
use tauri::{AppHandle, Runtime};
#[tauri::command]
pub async fn new_chat<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
message: String,
) -> Result<GetResponse, String> {
let body = if !message.is_empty() {
let message = InitChatMessage { message: Some(message) };
let body = reqwest::Body::from(serde_json::to_string(&message).unwrap());
Some(body)
} else {
None
};
let query_params: Option<HashMap<String, String>> = None;
let response = HttpClient::post(&server_id, "/chat/_new", query_params, body)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
if response.status().as_u16() < 200 || response.status().as_u16() >= 400 {
return Err("Failed to send message".to_string());
}
let chat_response: GetResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse response JSON: {}", e))?;
// Check the result and status fields
if chat_response.result != "created" {
return Err(format!("Unexpected result: {}", chat_response.result));
}
Ok(chat_response)
}
#[tauri::command]
pub async fn chat_history<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
from: u32,
size: u32,
) -> Result<String, String> {
let mut query_params = HashMap::new();
if from > 0 {
query_params.insert("from".to_string(), from.to_string());
}
if size > 0 {
query_params.insert("size".to_string(), size.to_string());
}
let response = HttpClient::get(&server_id, "/chat/_history", Some(query_params))
.await
.map_err(|e| format!("Error get sessions: {}", e))?;
handle_raw_response(response).await?
}
async fn handle_raw_response(response: Response) -> Result<Result<String, String>, String> {
Ok(if response.status().as_u16() < 200 || response.status().as_u16() >= 400 {
Err("Failed to send message".to_string())
} else {
let body = response.text().await.map_err(|e| format!("Failed to parse response JSON: {}", e))?;
Ok(body)
})
}
#[tauri::command]
pub async fn session_chat_history<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
session_id: String,
from: u32,
size: u32,
) -> Result<String, String> {
let mut query_params = HashMap::new();
if from > 0 {
query_params.insert("from".to_string(), from.to_string());
}
if size > 0 {
query_params.insert("size".to_string(), size.to_string());
}
let path = format!("/chat/{}/_history", session_id);
let response = HttpClient::get(&server_id, path.as_str(), Some(query_params))
.await
.map_err(|e| format!("Error get session message: {}", e))?;
handle_raw_response(response).await?
}
#[tauri::command]
pub async fn open_session_chat<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
session_id: String,
) -> Result<String, String> {
let mut query_params = HashMap::new();
let path = format!("/chat/{}/_open", session_id);
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error open session: {}", e))?;
handle_raw_response(response).await?
}
#[tauri::command]
pub async fn close_session_chat<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
session_id: String,
) -> Result<String, String> {
let mut query_params = HashMap::new();
let path = format!("/chat/{}/_close", session_id);
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error close session: {}", e))?;
handle_raw_response(response).await?
}
#[tauri::command]
pub async fn cancel_session_chat<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
session_id: String,
) -> Result<String, String> {
let mut query_params = HashMap::new();
let path = format!("/chat/{}/_cancel", session_id);
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
handle_raw_response(response).await?
}
#[tauri::command]
pub async fn send_message<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
session_id: String,
websocket_id: String,
query_params: Option<HashMap<String, String>>, //search,deep_thinking
) -> Result<String, String> {
let path = format!("/chat/{}/_send", session_id);
let mut headers = HashMap::new();
headers.insert("WEBSOCKET-SESSION-ID".to_string(), websocket_id);
let response = HttpClient::advanced_post(&server_id, path.as_str(), Some(headers), query_params, None)
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
handle_raw_response(response).await?
}

View File

@@ -0,0 +1,6 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitChatMessage {
pub message: Option<String>,
}

View File

@@ -0,0 +1,16 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct GetResponse {
pub _id: String,
pub _source: Source,
pub result: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Source {
pub id: String,
pub created: String,
pub updated: String,
pub status: String,
}

View File

@@ -8,6 +8,8 @@ pub mod search;
pub mod document;
pub mod traits;
pub mod register;
pub mod assistant;
pub mod http;
pub static MAIN_WINDOW_LABEL: &str = "main";
pub static SETTINGS_WINDOW_LABEL: &str = "settings";
pub static SETTINGS_WINDOW_LABEL: &str = "settings";

View File

@@ -7,6 +7,7 @@ mod shortcut;
mod util;
mod setup;
mod assistant;
use crate::common::register::SearchSourceRegistry;
// use crate::common::traits::SearchSource;
@@ -103,7 +104,13 @@ pub fn run() {
server::datasource::get_datasources_by_server,
server::connector::get_connectors_by_server,
search::query_coco_fusion,
// server::get_user_profiles,
assistant::chat_history,
assistant::new_chat,
assistant::send_message,
assistant::session_chat_history,
assistant::open_session_chat,
assistant::close_session_chat,
assistant::cancel_session_chat,
// server::get_coco_server_datasources,
// server::get_coco_server_connectors,
server::websocket::connect_to_server,

View File

@@ -28,7 +28,7 @@ pub async fn handle_sso_callback<R: Runtime>(
let path = request_access_token_url(&request_id, &code);
// Send the request for the access token using the util::http::HttpClient::get method
let response = HttpClient::get(&server_id, &path)
let response = HttpClient::get(&server_id, &path, None)
.await
.map_err(|e| format!("Failed to send request to the server: {}", e))?;

View File

@@ -96,7 +96,7 @@ pub async fn get_connectors_from_cache_or_remote(
pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, String> {
// Use the generic GET method from HttpClient
let resp = HttpClient::get(&id, "/connector/_search")
let resp = HttpClient::get(&id, "/connector/_search",None)
.await
.map_err(|e| {
// dbg!("Error fetching connector for id {}: {}", &id, &e);

View File

@@ -91,7 +91,7 @@ pub async fn get_datasources_by_server<R: Runtime>(
) -> Result<Vec<DataSource>, String> {
// Perform the async HTTP request outside the cache lock
let resp = HttpClient::get(&id, "/datasource/_search")
let resp = HttpClient::get(&id, "/datasource/_search",None)
.await
.map_err(|e| {
// dbg!("Error fetching datasource: {}", &e);

View File

@@ -1,8 +1,10 @@
use crate::server::servers::{get_server_by_id, get_server_token};
use std::time::Duration;
use http::HeaderName;
use once_cell::sync::Lazy;
use reqwest::{Client, Method, RequestBuilder};
use std::collections::HashMap;
use std::time::Duration;
use tauri::ipc::RuntimeCapability;
use tokio::sync::Mutex;
pub static HTTP_CLIENT: Lazy<Mutex<Client>> = Lazy::new(|| {
@@ -29,26 +31,22 @@ impl HttpClient {
pub async fn send_raw_request(
method: Method,
url: &str,
headers: Option<reqwest::header::HeaderMap>,
query_params: Option<HashMap<String, String>>,
headers: Option<HashMap<String, String>>,
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
let request_builder = Self::get_request_builder(method, url, headers, body).await;
let mut request_builder = Self::get_request_builder(method, url, headers, query_params, body).await;
// Send the request
let response = match request_builder.send().await {
Ok(resp) => resp,
Err(e) => {
dbg!("Failed to send request: {}", &e);
return Err(format!("Failed to send request: {}", e));
}
};
let response = request_builder.send().await
.map_err(|e| format!("Failed to send request: {}", e))?;
Ok(response)
}
pub async fn get_request_builder(
method: Method,
url: &str,
headers: Option<reqwest::header::HeaderMap>,
headers: Option<HashMap<String, String>>,
query_params: Option<HashMap<String, String>>, // Add query parameters
body: Option<reqwest::Body>,
) -> RequestBuilder {
let client = HTTP_CLIENT.lock().await; // Acquire the lock on HTTP_CLIENT
@@ -56,11 +54,21 @@ impl HttpClient {
// Build the request
let mut request_builder = client.request(method.clone(), url);
// Add headers if present
if let Some(h) = headers {
request_builder = request_builder.headers(h);
let mut req_headers = reqwest::header::HeaderMap::new();
for (key, value) in h.into_iter() {
let _ = req_headers.insert(
HeaderName::from_bytes(key.as_bytes()).unwrap(),
reqwest::header::HeaderValue::from_str(&value).unwrap(),
);
}
request_builder = request_builder.headers(req_headers);
}
if let Some(query) = query_params {
request_builder = request_builder.query(&query);
}
// Add body if present
if let Some(b) = body {
request_builder = request_builder.body(b);
@@ -73,6 +81,8 @@ impl HttpClient {
server_id: &str,
method: Method,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<HashMap<String, String>>,
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
// Fetch the server using the server_id
@@ -84,50 +94,72 @@ impl HttpClient {
// Retrieve the token for the server (token is optional)
let token = get_server_token(server_id).map(|t| t.access_token.clone());
// Create headers map (optional "X-API-TOKEN" header)
let mut headers = reqwest::header::HeaderMap::new();
let mut headers = if let Some(custom_headers) = custom_headers {
custom_headers
} else {
let mut headers = HashMap::new();
headers
};
if let Some(t) = token {
headers.insert(
"X-API-TOKEN",
reqwest::header::HeaderValue::from_str(&t).unwrap(),
"X-API-TOKEN".to_string(),
t,
);
}
// dbg!(&server_id);
// dbg!(&url);
// dbg!(&headers);
Self::send_raw_request(method, &url, Some(headers), body).await
Self::send_raw_request(method, &url, query_params, Some(headers), body).await
} else {
Err("Server not found".to_string())
}
}
// Convenience method for GET requests (as it's the most common)
pub async fn get(server_id: &str, path: &str) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::GET, path, None).await
pub async fn get(server_id: &str, path: &str, query_params: Option<HashMap<String, String>>, // Add query parameters
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::GET, path, None, query_params, None).await
}
// Convenience method for POST requests
pub async fn post(
server_id: &str,
path: &str,
body: reqwest::Body,
query_params: Option<HashMap<String, String>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::POST, path, Some(body)).await
HttpClient::send_request(server_id, Method::POST, path, None, query_params, body).await
}
pub async fn advanced_post(
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<HashMap<String, String>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::POST, path, custom_headers, query_params, body).await
}
// Convenience method for PUT requests
pub async fn put(
server_id: &str,
path: &str,
body: reqwest::Body,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<HashMap<String, String>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::PUT, path, Some(body)).await
HttpClient::send_request(server_id, Method::PUT, path, custom_headers, query_params, body).await
}
// Convenience method for DELETE requests
pub async fn delete(server_id: &str, path: &str) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::DELETE, path, None).await
pub async fn delete(server_id: &str, path: &str, custom_headers: Option<HashMap<String, String>>,
query_params: Option<HashMap<String, String>>, // Add query parameters
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::DELETE, path, custom_headers, query_params, None).await
}
}

View File

@@ -8,7 +8,7 @@ pub async fn get_user_profiles<R: Runtime>(
server_id: String,
) -> Result<UserProfile, String> {
// Use the generic GET method from HttpClient
let response = HttpClient::get(&server_id, "/account/profile")
let response = HttpClient::get(&server_id, "/account/profile", None)
.await
.map_err(|e| format!("Error fetching profile: {}", e))?;

View File

@@ -294,7 +294,7 @@ pub async fn refresh_coco_server_info<R: Runtime>(
let profile = server.profile;
// Use the HttpClient to send the request
let response = HttpClient::get(&id, "/provider/_info") // Assuming "/provider-info" is the endpoint
let response = HttpClient::get(&id, "/provider/_info", None) // Assuming "/provider-info" is the endpoint
.await
.map_err(|e| format!("Failed to send request to the server: {}", e))?;
@@ -366,7 +366,7 @@ pub async fn add_coco_server<R: Runtime>(
let url = provider_info_url(&endpoint);
// Use the HttpClient to fetch provider information
let response = HttpClient::send_raw_request(Method::GET, url.as_str(), None, None)
let response = HttpClient::send_raw_request(Method::GET, url.as_str(), None, None, None)
.await
.map_err(|e| format!("Failed to send request to the server: {}", e))?;