mirror of
https://github.com/infinilabs/coco-app.git
synced 2025-12-28 16:06:28 +01:00
refacotr: refactoring assistant api (#195)
* refacotr: refactoring assistant api * update release notes
This commit is contained in:
169
src-tauri/src/assistant/mod.rs
Normal file
169
src-tauri/src/assistant/mod.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use crate::common::assistant::InitChatMessage;
|
||||
use crate::common::http::GetResponse;
|
||||
use crate::server::http_client::HttpClient;
|
||||
use reqwest::Response;
|
||||
use std::collections::HashMap;
|
||||
use tauri::{AppHandle, Runtime};
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn new_chat<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
message: String,
|
||||
) -> Result<GetResponse, String> {
|
||||
let body = if !message.is_empty() {
|
||||
let message = InitChatMessage { message: Some(message) };
|
||||
let body = reqwest::Body::from(serde_json::to_string(&message).unwrap());
|
||||
Some(body)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let query_params: Option<HashMap<String, String>> = None;
|
||||
|
||||
let response = HttpClient::post(&server_id, "/chat/_new", query_params, body)
|
||||
.await
|
||||
.map_err(|e| format!("Error sending message: {}", e))?;
|
||||
|
||||
if response.status().as_u16() < 200 || response.status().as_u16() >= 400 {
|
||||
return Err("Failed to send message".to_string());
|
||||
}
|
||||
|
||||
let chat_response: GetResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse response JSON: {}", e))?;
|
||||
|
||||
// Check the result and status fields
|
||||
if chat_response.result != "created" {
|
||||
return Err(format!("Unexpected result: {}", chat_response.result));
|
||||
}
|
||||
|
||||
Ok(chat_response)
|
||||
}
|
||||
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn chat_history<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
from: u32,
|
||||
size: u32,
|
||||
) -> Result<String, String> {
|
||||
let mut query_params = HashMap::new();
|
||||
if from > 0 {
|
||||
query_params.insert("from".to_string(), from.to_string());
|
||||
}
|
||||
if size > 0 {
|
||||
query_params.insert("size".to_string(), size.to_string());
|
||||
}
|
||||
|
||||
let response = HttpClient::get(&server_id, "/chat/_history", Some(query_params))
|
||||
.await
|
||||
.map_err(|e| format!("Error get sessions: {}", e))?;
|
||||
|
||||
handle_raw_response(response).await?
|
||||
}
|
||||
|
||||
async fn handle_raw_response(response: Response) -> Result<Result<String, String>, String> {
|
||||
Ok(if response.status().as_u16() < 200 || response.status().as_u16() >= 400 {
|
||||
Err("Failed to send message".to_string())
|
||||
} else {
|
||||
let body = response.text().await.map_err(|e| format!("Failed to parse response JSON: {}", e))?;
|
||||
Ok(body)
|
||||
})
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn session_chat_history<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
session_id: String,
|
||||
from: u32,
|
||||
size: u32,
|
||||
) -> Result<String, String> {
|
||||
let mut query_params = HashMap::new();
|
||||
if from > 0 {
|
||||
query_params.insert("from".to_string(), from.to_string());
|
||||
}
|
||||
if size > 0 {
|
||||
query_params.insert("size".to_string(), size.to_string());
|
||||
}
|
||||
|
||||
let path = format!("/chat/{}/_history", session_id);
|
||||
|
||||
let response = HttpClient::get(&server_id, path.as_str(), Some(query_params))
|
||||
.await
|
||||
.map_err(|e| format!("Error get session message: {}", e))?;
|
||||
|
||||
handle_raw_response(response).await?
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn open_session_chat<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
session_id: String,
|
||||
) -> Result<String, String> {
|
||||
let mut query_params = HashMap::new();
|
||||
let path = format!("/chat/{}/_open", session_id);
|
||||
|
||||
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
|
||||
.await
|
||||
.map_err(|e| format!("Error open session: {}", e))?;
|
||||
|
||||
handle_raw_response(response).await?
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn close_session_chat<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
session_id: String,
|
||||
) -> Result<String, String> {
|
||||
let mut query_params = HashMap::new();
|
||||
let path = format!("/chat/{}/_close", session_id);
|
||||
|
||||
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
|
||||
.await
|
||||
.map_err(|e| format!("Error close session: {}", e))?;
|
||||
|
||||
handle_raw_response(response).await?
|
||||
}
|
||||
#[tauri::command]
|
||||
pub async fn cancel_session_chat<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
session_id: String,
|
||||
) -> Result<String, String> {
|
||||
let mut query_params = HashMap::new();
|
||||
let path = format!("/chat/{}/_cancel", session_id);
|
||||
|
||||
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
|
||||
.await
|
||||
.map_err(|e| format!("Error cancel session: {}", e))?;
|
||||
|
||||
handle_raw_response(response).await?
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn send_message<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
server_id: String,
|
||||
session_id: String,
|
||||
websocket_id: String,
|
||||
query_params: Option<HashMap<String, String>>, //search,deep_thinking
|
||||
) -> Result<String, String> {
|
||||
let path = format!("/chat/{}/_send", session_id);
|
||||
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("WEBSOCKET-SESSION-ID".to_string(), websocket_id);
|
||||
|
||||
let response = HttpClient::advanced_post(&server_id, path.as_str(), Some(headers), query_params, None)
|
||||
.await
|
||||
.map_err(|e| format!("Error cancel session: {}", e))?;
|
||||
|
||||
handle_raw_response(response).await?
|
||||
}
|
||||
|
||||
|
||||
6
src-tauri/src/common/assistant.rs
Normal file
6
src-tauri/src/common/assistant.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InitChatMessage {
|
||||
pub message: Option<String>,
|
||||
}
|
||||
16
src-tauri/src/common/http.rs
Normal file
16
src-tauri/src/common/http.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GetResponse {
|
||||
pub _id: String,
|
||||
pub _source: Source,
|
||||
pub result: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Source {
|
||||
pub id: String,
|
||||
pub created: String,
|
||||
pub updated: String,
|
||||
pub status: String,
|
||||
}
|
||||
@@ -8,6 +8,8 @@ pub mod search;
|
||||
pub mod document;
|
||||
pub mod traits;
|
||||
pub mod register;
|
||||
pub mod assistant;
|
||||
pub mod http;
|
||||
|
||||
pub static MAIN_WINDOW_LABEL: &str = "main";
|
||||
pub static SETTINGS_WINDOW_LABEL: &str = "settings";
|
||||
pub static SETTINGS_WINDOW_LABEL: &str = "settings";
|
||||
|
||||
@@ -7,6 +7,7 @@ mod shortcut;
|
||||
mod util;
|
||||
|
||||
mod setup;
|
||||
mod assistant;
|
||||
|
||||
use crate::common::register::SearchSourceRegistry;
|
||||
// use crate::common::traits::SearchSource;
|
||||
@@ -103,7 +104,13 @@ pub fn run() {
|
||||
server::datasource::get_datasources_by_server,
|
||||
server::connector::get_connectors_by_server,
|
||||
search::query_coco_fusion,
|
||||
// server::get_user_profiles,
|
||||
assistant::chat_history,
|
||||
assistant::new_chat,
|
||||
assistant::send_message,
|
||||
assistant::session_chat_history,
|
||||
assistant::open_session_chat,
|
||||
assistant::close_session_chat,
|
||||
assistant::cancel_session_chat,
|
||||
// server::get_coco_server_datasources,
|
||||
// server::get_coco_server_connectors,
|
||||
server::websocket::connect_to_server,
|
||||
|
||||
@@ -28,7 +28,7 @@ pub async fn handle_sso_callback<R: Runtime>(
|
||||
let path = request_access_token_url(&request_id, &code);
|
||||
|
||||
// Send the request for the access token using the util::http::HttpClient::get method
|
||||
let response = HttpClient::get(&server_id, &path)
|
||||
let response = HttpClient::get(&server_id, &path, None)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to send request to the server: {}", e))?;
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ pub async fn get_connectors_from_cache_or_remote(
|
||||
|
||||
pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, String> {
|
||||
// Use the generic GET method from HttpClient
|
||||
let resp = HttpClient::get(&id, "/connector/_search")
|
||||
let resp = HttpClient::get(&id, "/connector/_search",None)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// dbg!("Error fetching connector for id {}: {}", &id, &e);
|
||||
|
||||
@@ -91,7 +91,7 @@ pub async fn get_datasources_by_server<R: Runtime>(
|
||||
) -> Result<Vec<DataSource>, String> {
|
||||
|
||||
// Perform the async HTTP request outside the cache lock
|
||||
let resp = HttpClient::get(&id, "/datasource/_search")
|
||||
let resp = HttpClient::get(&id, "/datasource/_search",None)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// dbg!("Error fetching datasource: {}", &e);
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use crate::server::servers::{get_server_by_id, get_server_token};
|
||||
use std::time::Duration;
|
||||
|
||||
use http::HeaderName;
|
||||
use once_cell::sync::Lazy;
|
||||
use reqwest::{Client, Method, RequestBuilder};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tauri::ipc::RuntimeCapability;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub static HTTP_CLIENT: Lazy<Mutex<Client>> = Lazy::new(|| {
|
||||
@@ -29,26 +31,22 @@ impl HttpClient {
|
||||
pub async fn send_raw_request(
|
||||
method: Method,
|
||||
url: &str,
|
||||
headers: Option<reqwest::header::HeaderMap>,
|
||||
query_params: Option<HashMap<String, String>>,
|
||||
headers: Option<HashMap<String, String>>,
|
||||
body: Option<reqwest::Body>,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
let request_builder = Self::get_request_builder(method, url, headers, body).await;
|
||||
let mut request_builder = Self::get_request_builder(method, url, headers, query_params, body).await;
|
||||
|
||||
// Send the request
|
||||
let response = match request_builder.send().await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
dbg!("Failed to send request: {}", &e);
|
||||
return Err(format!("Failed to send request: {}", e));
|
||||
}
|
||||
};
|
||||
let response = request_builder.send().await
|
||||
.map_err(|e| format!("Failed to send request: {}", e))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn get_request_builder(
|
||||
method: Method,
|
||||
url: &str,
|
||||
headers: Option<reqwest::header::HeaderMap>,
|
||||
headers: Option<HashMap<String, String>>,
|
||||
query_params: Option<HashMap<String, String>>, // Add query parameters
|
||||
body: Option<reqwest::Body>,
|
||||
) -> RequestBuilder {
|
||||
let client = HTTP_CLIENT.lock().await; // Acquire the lock on HTTP_CLIENT
|
||||
@@ -56,11 +54,21 @@ impl HttpClient {
|
||||
// Build the request
|
||||
let mut request_builder = client.request(method.clone(), url);
|
||||
|
||||
// Add headers if present
|
||||
|
||||
if let Some(h) = headers {
|
||||
request_builder = request_builder.headers(h);
|
||||
let mut req_headers = reqwest::header::HeaderMap::new();
|
||||
for (key, value) in h.into_iter() {
|
||||
let _ = req_headers.insert(
|
||||
HeaderName::from_bytes(key.as_bytes()).unwrap(),
|
||||
reqwest::header::HeaderValue::from_str(&value).unwrap(),
|
||||
);
|
||||
}
|
||||
request_builder = request_builder.headers(req_headers);
|
||||
}
|
||||
|
||||
if let Some(query) = query_params {
|
||||
request_builder = request_builder.query(&query);
|
||||
}
|
||||
// Add body if present
|
||||
if let Some(b) = body {
|
||||
request_builder = request_builder.body(b);
|
||||
@@ -73,6 +81,8 @@ impl HttpClient {
|
||||
server_id: &str,
|
||||
method: Method,
|
||||
path: &str,
|
||||
custom_headers: Option<HashMap<String, String>>,
|
||||
query_params: Option<HashMap<String, String>>,
|
||||
body: Option<reqwest::Body>,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
// Fetch the server using the server_id
|
||||
@@ -84,50 +94,72 @@ impl HttpClient {
|
||||
// Retrieve the token for the server (token is optional)
|
||||
let token = get_server_token(server_id).map(|t| t.access_token.clone());
|
||||
|
||||
// Create headers map (optional "X-API-TOKEN" header)
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
let mut headers = if let Some(custom_headers) = custom_headers {
|
||||
custom_headers
|
||||
} else {
|
||||
let mut headers = HashMap::new();
|
||||
headers
|
||||
};
|
||||
|
||||
if let Some(t) = token {
|
||||
headers.insert(
|
||||
"X-API-TOKEN",
|
||||
reqwest::header::HeaderValue::from_str(&t).unwrap(),
|
||||
"X-API-TOKEN".to_string(),
|
||||
t,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// dbg!(&server_id);
|
||||
// dbg!(&url);
|
||||
// dbg!(&headers);
|
||||
|
||||
Self::send_raw_request(method, &url, Some(headers), body).await
|
||||
Self::send_raw_request(method, &url, query_params, Some(headers), body).await
|
||||
} else {
|
||||
Err("Server not found".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience method for GET requests (as it's the most common)
|
||||
pub async fn get(server_id: &str, path: &str) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::GET, path, None).await
|
||||
pub async fn get(server_id: &str, path: &str, query_params: Option<HashMap<String, String>>, // Add query parameters
|
||||
) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::GET, path, None, query_params, None).await
|
||||
}
|
||||
|
||||
// Convenience method for POST requests
|
||||
pub async fn post(
|
||||
server_id: &str,
|
||||
path: &str,
|
||||
body: reqwest::Body,
|
||||
query_params: Option<HashMap<String, String>>, // Add query parameters
|
||||
body: Option<reqwest::Body>,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::POST, path, Some(body)).await
|
||||
HttpClient::send_request(server_id, Method::POST, path, None, query_params, body).await
|
||||
}
|
||||
|
||||
pub async fn advanced_post(
|
||||
server_id: &str,
|
||||
path: &str,
|
||||
custom_headers: Option<HashMap<String, String>>,
|
||||
query_params: Option<HashMap<String, String>>, // Add query parameters
|
||||
body: Option<reqwest::Body>,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::POST, path, custom_headers, query_params, body).await
|
||||
}
|
||||
|
||||
// Convenience method for PUT requests
|
||||
pub async fn put(
|
||||
server_id: &str,
|
||||
path: &str,
|
||||
body: reqwest::Body,
|
||||
custom_headers: Option<HashMap<String, String>>,
|
||||
query_params: Option<HashMap<String, String>>, // Add query parameters
|
||||
body: Option<reqwest::Body>,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::PUT, path, Some(body)).await
|
||||
HttpClient::send_request(server_id, Method::PUT, path, custom_headers, query_params, body).await
|
||||
}
|
||||
|
||||
// Convenience method for DELETE requests
|
||||
pub async fn delete(server_id: &str, path: &str) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::DELETE, path, None).await
|
||||
pub async fn delete(server_id: &str, path: &str, custom_headers: Option<HashMap<String, String>>,
|
||||
query_params: Option<HashMap<String, String>>, // Add query parameters
|
||||
) -> Result<reqwest::Response, String> {
|
||||
HttpClient::send_request(server_id, Method::DELETE, path, custom_headers, query_params, None).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ pub async fn get_user_profiles<R: Runtime>(
|
||||
server_id: String,
|
||||
) -> Result<UserProfile, String> {
|
||||
// Use the generic GET method from HttpClient
|
||||
let response = HttpClient::get(&server_id, "/account/profile")
|
||||
let response = HttpClient::get(&server_id, "/account/profile", None)
|
||||
.await
|
||||
.map_err(|e| format!("Error fetching profile: {}", e))?;
|
||||
|
||||
|
||||
@@ -294,7 +294,7 @@ pub async fn refresh_coco_server_info<R: Runtime>(
|
||||
let profile = server.profile;
|
||||
|
||||
// Use the HttpClient to send the request
|
||||
let response = HttpClient::get(&id, "/provider/_info") // Assuming "/provider-info" is the endpoint
|
||||
let response = HttpClient::get(&id, "/provider/_info", None) // Assuming "/provider-info" is the endpoint
|
||||
.await
|
||||
.map_err(|e| format!("Failed to send request to the server: {}", e))?;
|
||||
|
||||
@@ -366,7 +366,7 @@ pub async fn add_coco_server<R: Runtime>(
|
||||
let url = provider_info_url(&endpoint);
|
||||
|
||||
// Use the HttpClient to fetch provider information
|
||||
let response = HttpClient::send_raw_request(Method::GET, url.as_str(), None, None)
|
||||
let response = HttpClient::send_raw_request(Method::GET, url.as_str(), None, None, None)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to send request to the server: {}", e))?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user