mirror of
https://github.com/infinilabs/coco-app.git
synced 2025-12-16 11:37:47 +01:00
refactor: tighten up Coco servers state management (#790)
* refactor: tighten up Coco servers state management * ignore unused warnings * log out if the failed request has status 401
This commit is contained in:
@@ -1,8 +1,23 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use thiserror::Error;
|
||||
|
||||
fn serialize_optional_status_code<S>(
|
||||
status_code: &Option<StatusCode>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match status_code {
|
||||
Some(code) => serializer.serialize_str(&format!("{:?}", code)),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ErrorCause {
|
||||
#[serde(default)]
|
||||
pub r#type: Option<String>,
|
||||
@@ -11,7 +26,7 @@ pub struct ErrorCause {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
#[allow(unused)]
|
||||
pub struct ErrorDetail {
|
||||
#[serde(default)]
|
||||
pub root_cause: Option<Vec<ErrorCause>>,
|
||||
@@ -24,18 +39,22 @@ pub struct ErrorDetail {
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ErrorResponse {
|
||||
#[serde(default)]
|
||||
pub error: Option<ErrorDetail>,
|
||||
#[serde(default)]
|
||||
#[allow(unused)]
|
||||
pub status: Option<u16>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Serialize)]
|
||||
pub enum SearchError {
|
||||
#[error("HttpError: {0}")]
|
||||
HttpError(String),
|
||||
#[error("HttpError: status code [{status_code:?}], msg [{msg}]")]
|
||||
HttpError {
|
||||
#[serde(serialize_with = "serialize_optional_status_code")]
|
||||
status_code: Option<StatusCode>,
|
||||
msg: String,
|
||||
},
|
||||
|
||||
#[error("ParseError: {0}")]
|
||||
ParseError(String),
|
||||
@@ -43,12 +62,7 @@ pub enum SearchError {
|
||||
#[error("Timeout occurred")]
|
||||
Timeout,
|
||||
|
||||
#[error("UnknownError: {0}")]
|
||||
#[allow(dead_code)]
|
||||
Unknown(String),
|
||||
|
||||
#[error("InternalError: {0}")]
|
||||
#[allow(dead_code)]
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
@@ -59,7 +73,10 @@ impl From<reqwest::Error> for SearchError {
|
||||
} else if err.is_decode() {
|
||||
SearchError::ParseError(err.to_string())
|
||||
} else {
|
||||
SearchError::HttpError(err.to_string())
|
||||
SearchError::HttpError {
|
||||
status_code: err.status(),
|
||||
msg: err.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,20 +83,6 @@ where
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn parse_search_results_with_score<T>(
|
||||
response: Response,
|
||||
) -> Result<Vec<(T, Option<f64>)>, Box<dyn Error>>
|
||||
where
|
||||
T: for<'de> Deserialize<'de> + std::fmt::Debug,
|
||||
{
|
||||
Ok(parse_search_hits(response)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|hit| (hit._source, hit._score))
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct SearchQuery {
|
||||
pub from: u64,
|
||||
|
||||
@@ -50,9 +50,17 @@ pub struct Server {
|
||||
pub updated: String,
|
||||
#[serde(default = "default_enabled_type")]
|
||||
pub enabled: bool,
|
||||
/// Public Coco servers can be used without signing in.
|
||||
#[serde(default = "default_bool_type")]
|
||||
pub public: bool,
|
||||
|
||||
/// A coco server is available if:
|
||||
///
|
||||
/// 1. It is still online, we check this via the `GET /base_url/provider/_info`
|
||||
/// interface.
|
||||
/// 2. A user is logged in to this Coco server, i.e., a token is stored in the
|
||||
/// `SERVER_TOKEN_LIST_CACHE`.
|
||||
/// For public Coco servers, requirement 2 is not needed.
|
||||
#[serde(default = "default_available_type")]
|
||||
pub available: bool,
|
||||
|
||||
@@ -84,7 +92,10 @@ pub struct ServerAccessToken {
|
||||
#[serde(default = "default_empty_string")] // Custom default function for empty string
|
||||
pub id: String,
|
||||
pub access_token: String,
|
||||
pub expired_at: u32, //unix timestamp in seconds
|
||||
/// Unix timestamp in seconds
|
||||
///
|
||||
/// Currently, this is UNUSED.
|
||||
pub expired_at: u32,
|
||||
}
|
||||
|
||||
impl ServerAccessToken {
|
||||
|
||||
@@ -169,7 +169,7 @@ pub fn run() {
|
||||
#[cfg(any(target_os = "macos", target_os = "windows"))]
|
||||
extension::built_in::file_search::config::set_file_system_config,
|
||||
server::synthesize::synthesize,
|
||||
util::file::get_file_icon,
|
||||
util::file::get_file_icon,
|
||||
util::app_lang::update_app_lang,
|
||||
])
|
||||
.setup(|app| {
|
||||
@@ -273,7 +273,7 @@ pub async fn init<R: Runtime>(app_handle: &AppHandle<R>) {
|
||||
log::error!("Failed to load server tokens: {}", err);
|
||||
}
|
||||
|
||||
let coco_servers = server::servers::get_all_servers();
|
||||
let coco_servers = server::servers::get_all_servers().await;
|
||||
|
||||
// Get the registry from Tauri's state
|
||||
// let registry: State<SearchSourceRegistry> = app_handle.state::<SearchSourceRegistry>();
|
||||
@@ -562,12 +562,12 @@ fn set_up_tauri_logger() -> TauriPlugin<tauri::Wry> {
|
||||
// When running the built binary, set `COCO_LOG` to `coco_lib=trace` to capture all logs
|
||||
// that come from Coco in the log file, which helps with debugging.
|
||||
if !tauri::is_dev() {
|
||||
// We have absolutely no guarantee that we (We have control over the Rust
|
||||
// code, but definitely no idea about the libc C code, all the shared objects
|
||||
// that we will link) will not concurrently read/write `envp`, so just use unsafe.
|
||||
unsafe {
|
||||
std::env::set_var("COCO_LOG", "coco_lib=trace");
|
||||
}
|
||||
// We have absolutely no guarantee that we (We have control over the Rust
|
||||
// code, but definitely no idea about the libc C code, all the shared objects
|
||||
// that we will link) will not concurrently read/write `envp`, so just use unsafe.
|
||||
unsafe {
|
||||
std::env::set_var("COCO_LOG", "coco_lib=trace");
|
||||
}
|
||||
}
|
||||
|
||||
let mut builder = tauri_plugin_log::Builder::new();
|
||||
|
||||
@@ -4,9 +4,12 @@ use crate::common::search::{
|
||||
FailedRequest, MultiSourceQueryResponse, QueryHits, QueryResponse, QuerySource, SearchQuery,
|
||||
};
|
||||
use crate::common::traits::SearchSource;
|
||||
use crate::server::servers::logout_coco_server;
|
||||
use crate::server::servers::mark_server_as_offline;
|
||||
use function_name::named;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::FuturesUnordered;
|
||||
use reqwest::StatusCode;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
@@ -14,7 +17,7 @@ use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use tauri::{AppHandle, Manager, Runtime};
|
||||
use tokio::time::error::Elapsed;
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
/// Helper function to return the Future used for querying querysources.
|
||||
///
|
||||
@@ -191,9 +194,38 @@ pub async fn query_coco_fusion<R: Runtime>(
|
||||
query_source.id,
|
||||
search_error
|
||||
);
|
||||
|
||||
let mut status_code_num: u16 = 0;
|
||||
|
||||
if let SearchError::HttpError {
|
||||
status_code: opt_status_code,
|
||||
msg: _,
|
||||
} = search_error
|
||||
{
|
||||
if let Some(status_code) = opt_status_code {
|
||||
status_code_num = status_code.as_u16();
|
||||
if status_code != StatusCode::OK {
|
||||
if status_code == StatusCode::UNAUTHORIZED {
|
||||
// This Coco server is unavailable. In addition to marking it as
|
||||
// unavailable, we need to log out because the status code is 401.
|
||||
logout_coco_server(app_handle.clone(), query_source.id.clone()).await.unwrap_or_else(|e| {
|
||||
panic!(
|
||||
"the search request to Coco server [id {}, name {}] failed with status code {}, the login token is invalid, we are trying to log out, but failed with error [{}]",
|
||||
query_source.id, query_source.name, StatusCode::UNAUTHORIZED, e
|
||||
);
|
||||
})
|
||||
} else {
|
||||
// This Coco server is unavailable
|
||||
mark_server_as_offline(app_handle.clone(), &query_source.id)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
failed_requests.push(FailedRequest {
|
||||
source: query_source,
|
||||
status: 0,
|
||||
status: status_code_num,
|
||||
error: Some(search_error.to_string()),
|
||||
reason: None,
|
||||
});
|
||||
|
||||
@@ -45,10 +45,12 @@ pub async fn upload_attachment(
|
||||
form = form.part("files", part);
|
||||
}
|
||||
|
||||
let server = get_server_by_id(&server_id).ok_or("Server not found")?;
|
||||
let server = get_server_by_id(&server_id)
|
||||
.await
|
||||
.ok_or("Server not found")?;
|
||||
let url = HttpClient::join_url(&server.endpoint, &format!("attachment/_upload"));
|
||||
|
||||
let token = get_server_token(&server_id).await?;
|
||||
let token = get_server_token(&server_id).await;
|
||||
let mut headers = HashMap::new();
|
||||
if let Some(token) = token {
|
||||
headers.insert("X-API-TOKEN".to_string(), token.access_token);
|
||||
|
||||
@@ -20,15 +20,15 @@ pub async fn handle_sso_callback<R: Runtime>(
|
||||
code: String,
|
||||
) -> Result<(), String> {
|
||||
// Retrieve the server details using the server ID
|
||||
let server = get_server_by_id(&server_id);
|
||||
let server = get_server_by_id(&server_id).await;
|
||||
|
||||
let expire_in = 3600; // TODO, need to update to actual expire_in value
|
||||
if let Some(mut server) = server {
|
||||
// Save the access token for the server
|
||||
let access_token = ServerAccessToken::new(server_id.clone(), code.clone(), expire_in);
|
||||
// dbg!(&server_id, &request_id, &code, &token);
|
||||
save_access_token(server_id.clone(), access_token);
|
||||
persist_servers_token(&app_handle)?;
|
||||
save_access_token(server_id.clone(), access_token).await;
|
||||
persist_servers_token(&app_handle).await?;
|
||||
|
||||
// Register the server to the search source
|
||||
try_register_server_to_search_source(app_handle.clone(), &server).await;
|
||||
@@ -41,7 +41,7 @@ pub async fn handle_sso_callback<R: Runtime>(
|
||||
Ok(p) => {
|
||||
server.profile = Some(p);
|
||||
server.available = true;
|
||||
save_server(&server);
|
||||
save_server(&server).await;
|
||||
persist_servers(&app_handle).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ pub fn get_connector_by_id(server_id: &str, connector_id: &str) -> Option<Connec
|
||||
}
|
||||
|
||||
pub async fn refresh_all_connectors<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
|
||||
let servers = get_all_servers();
|
||||
let servers = get_all_servers().await;
|
||||
|
||||
// Collect all the tasks for fetching and refreshing connectors
|
||||
let mut server_map = HashMap::new();
|
||||
|
||||
@@ -34,7 +34,7 @@ pub fn get_datasources_from_cache(server_id: &str) -> Option<HashMap<String, Dat
|
||||
pub async fn refresh_all_datasources<R: Runtime>(_app_handle: &AppHandle<R>) -> Result<(), String> {
|
||||
// dbg!("Attempting to refresh all datasources");
|
||||
|
||||
let servers = get_all_servers();
|
||||
let servers = get_all_servers().await;
|
||||
|
||||
let mut server_map = HashMap::new();
|
||||
|
||||
|
||||
@@ -175,14 +175,14 @@ impl HttpClient {
|
||||
body: Option<reqwest::Body>,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
// Fetch the server using the server_id
|
||||
let server = get_server_by_id(server_id);
|
||||
let server = get_server_by_id(server_id).await;
|
||||
if let Some(s) = server {
|
||||
// Construct the URL
|
||||
let url = HttpClient::join_url(&s.endpoint, path);
|
||||
|
||||
// Retrieve the token for the server (token is optional)
|
||||
let token = get_server_token(server_id)
|
||||
.await?
|
||||
.await
|
||||
.map(|t| t.access_token.clone());
|
||||
|
||||
let mut headers = if let Some(custom_headers) = custom_headers {
|
||||
@@ -205,7 +205,7 @@ impl HttpClient {
|
||||
|
||||
Self::send_raw_request(method, &url, query_params, Some(headers), body).await
|
||||
} else {
|
||||
Err("Server not found".to_string())
|
||||
Err(format!("Server [{}] not found", server_id))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ use crate::server::http_client::HttpClient;
|
||||
use async_trait::async_trait;
|
||||
// use futures::stream::StreamExt;
|
||||
use ordered_float::OrderedFloat;
|
||||
use reqwest::StatusCode;
|
||||
use std::collections::HashMap;
|
||||
// use std::hash::Hash;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct DocumentsSizedCollector {
|
||||
@@ -44,7 +44,7 @@ impl DocumentsSizedCollector {
|
||||
}
|
||||
}
|
||||
|
||||
fn documents(self) -> impl ExactSizeIterator<Item=Document> {
|
||||
fn documents(self) -> impl ExactSizeIterator<Item = Document> {
|
||||
self.docs.into_iter().map(|(_, doc, _)| doc)
|
||||
}
|
||||
|
||||
@@ -108,7 +108,18 @@ impl SearchSource for CocoSearchSource {
|
||||
|
||||
let response = HttpClient::get(&self.server.id, &url, Some(query_params))
|
||||
.await
|
||||
.map_err(|e| SearchError::HttpError(format!("{}", e)))?;
|
||||
.map_err(|e| SearchError::HttpError {
|
||||
status_code: None,
|
||||
msg: format!("{}", e),
|
||||
})?;
|
||||
let status_code = response.status();
|
||||
|
||||
if ![StatusCode::OK, StatusCode::CREATED].contains(&status_code) {
|
||||
return Err(SearchError::HttpError {
|
||||
status_code: Some(status_code),
|
||||
msg: format!("Request failed with status code [{}]", status_code),
|
||||
});
|
||||
}
|
||||
|
||||
// Use the helper function to parse the response body
|
||||
let response_body = get_response_body_text(response)
|
||||
@@ -123,7 +134,6 @@ impl SearchSource for CocoSearchSource {
|
||||
let parsed: SearchResponse<Document> = serde_json::from_str(&response_body)
|
||||
.map_err(|e| SearchError::ParseError(format!("{}", e)))?;
|
||||
|
||||
|
||||
// Process the parsed response
|
||||
total_hits = parsed.hits.total.value as usize;
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::COCO_TAURI_STORE;
|
||||
use crate::common::http::get_response_body_text;
|
||||
use crate::common::register::SearchSourceRegistry;
|
||||
use crate::common::server::{AuthProvider, Provider, Server, ServerAccessToken, Sso, Version};
|
||||
@@ -5,68 +6,72 @@ use crate::server::connector::fetch_connectors_by_server;
|
||||
use crate::server::datasource::datasource_search;
|
||||
use crate::server::http_client::HttpClient;
|
||||
use crate::server::search::CocoSearchSource;
|
||||
use crate::COCO_TAURI_STORE;
|
||||
use lazy_static::lazy_static;
|
||||
use function_name;
|
||||
use http::StatusCode;
|
||||
use reqwest::Method;
|
||||
use serde_json::from_value;
|
||||
use serde_json::Value as JsonValue;
|
||||
use serde_json::from_value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::LazyLock;
|
||||
use tauri::Runtime;
|
||||
use tauri::{AppHandle, Manager};
|
||||
use tauri_plugin_store::StoreExt;
|
||||
// Assuming you're using serde_json
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
lazy_static! {
|
||||
static ref SERVER_CACHE: Arc<RwLock<HashMap<String, Server>>> =
|
||||
Arc::new(RwLock::new(HashMap::new()));
|
||||
static ref SERVER_TOKEN: Arc<RwLock<HashMap<String, ServerAccessToken>>> =
|
||||
Arc::new(RwLock::new(HashMap::new()));
|
||||
}
|
||||
/// Coco sever list
|
||||
static SERVER_LIST_CACHE: LazyLock<RwLock<HashMap<String, Server>>> =
|
||||
LazyLock::new(|| RwLock::new(HashMap::new()));
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn check_server_exists(id: &str) -> bool {
|
||||
let cache = SERVER_CACHE.read().unwrap(); // Acquire read lock
|
||||
cache.contains_key(id)
|
||||
}
|
||||
/// If a server has a token stored here that has not expired, it is considered logged in.
|
||||
///
|
||||
/// Since the `expire_at` field of `struct ServerAccessToken` is currently unused,
|
||||
/// all servers stored here are treated as logged in.
|
||||
static SERVER_TOKEN_LIST_CACHE: LazyLock<RwLock<HashMap<String, ServerAccessToken>>> =
|
||||
LazyLock::new(|| RwLock::new(HashMap::new()));
|
||||
|
||||
pub fn get_server_by_id(id: &str) -> Option<Server> {
|
||||
let cache = SERVER_CACHE.read().unwrap(); // Acquire read lock
|
||||
/// `SERVER_LIST_CACHE` will be stored in KV store COCO_TAURI_STORE, under this key.
|
||||
pub const COCO_SERVERS: &str = "coco_servers";
|
||||
|
||||
/// `SERVER_TOKEN_LIST_CACHE` will be stored in KV store COCO_TAURI_STORE, under this key.
|
||||
const COCO_SERVER_TOKENS: &str = "coco_server_tokens";
|
||||
|
||||
pub async fn get_server_by_id(id: &str) -> Option<Server> {
|
||||
let cache = SERVER_LIST_CACHE.read().await;
|
||||
cache.get(id).cloned()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_server_token(id: &str) -> Result<Option<ServerAccessToken>, String> {
|
||||
let cache = SERVER_TOKEN.read().map_err(|err| err.to_string())?;
|
||||
pub async fn get_server_token(id: &str) -> Option<ServerAccessToken> {
|
||||
let cache = SERVER_TOKEN_LIST_CACHE.read().await;
|
||||
|
||||
Ok(cache.get(id).cloned())
|
||||
cache.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn save_access_token(server_id: String, token: ServerAccessToken) -> bool {
|
||||
let mut cache = SERVER_TOKEN.write().unwrap();
|
||||
pub async fn save_access_token(server_id: String, token: ServerAccessToken) -> bool {
|
||||
let mut cache = SERVER_TOKEN_LIST_CACHE.write().await;
|
||||
cache.insert(server_id, token).is_none()
|
||||
}
|
||||
|
||||
fn check_endpoint_exists(endpoint: &str) -> bool {
|
||||
let cache = SERVER_CACHE.read().unwrap();
|
||||
async fn check_endpoint_exists(endpoint: &str) -> bool {
|
||||
let cache = SERVER_LIST_CACHE.read().await;
|
||||
cache.values().any(|server| server.endpoint == endpoint)
|
||||
}
|
||||
|
||||
pub fn save_server(server: &Server) -> bool {
|
||||
let mut cache = SERVER_CACHE.write().unwrap();
|
||||
cache.insert(server.id.clone(), server.clone()).is_none() // If the server id did not exist, `insert` will return `None`
|
||||
/// Return true if `server` does not exists in the server list, i.e., it is a newly-added
|
||||
/// server.
|
||||
pub async fn save_server(server: &Server) -> bool {
|
||||
let mut cache = SERVER_LIST_CACHE.write().await;
|
||||
cache.insert(server.id.clone(), server.clone()).is_none()
|
||||
}
|
||||
|
||||
fn remove_server_by_id(id: String) -> bool {
|
||||
/// Return the removed `Server` if it exists in the server list.
|
||||
async fn remove_server_by_id(id: &str) -> Option<Server> {
|
||||
log::debug!("remove server by id: {}", &id);
|
||||
let mut cache = SERVER_CACHE.write().unwrap();
|
||||
let deleted = cache.remove(id.as_str());
|
||||
deleted.is_some()
|
||||
let mut cache = SERVER_LIST_CACHE.write().await;
|
||||
cache.remove(id)
|
||||
}
|
||||
|
||||
pub async fn persist_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
|
||||
let cache = SERVER_CACHE.read().unwrap(); // Acquire a read lock, not a write lock, since you're not modifying the cache
|
||||
let cache = SERVER_LIST_CACHE.read().await;
|
||||
|
||||
// Convert HashMap to Vec for serialization (iterating over values of HashMap)
|
||||
let servers: Vec<Server> = cache.values().cloned().collect();
|
||||
@@ -86,14 +91,16 @@ pub async fn persist_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<()
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove_server_token(id: &str) -> bool {
|
||||
/// Return true if the server token of the server specified by `id` exists in
|
||||
/// the token list and gets deleted.
|
||||
pub async fn remove_server_token(id: &str) -> bool {
|
||||
log::debug!("remove server token by id: {}", &id);
|
||||
let mut cache = SERVER_TOKEN.write().unwrap();
|
||||
let mut cache = SERVER_TOKEN_LIST_CACHE.write().await;
|
||||
cache.remove(id).is_some()
|
||||
}
|
||||
|
||||
pub fn persist_servers_token<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
|
||||
let cache = SERVER_TOKEN.read().unwrap(); // Acquire a read lock, not a write lock, since you're not modifying the cache
|
||||
pub async fn persist_servers_token<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
|
||||
let cache = SERVER_TOKEN_LIST_CACHE.read().await;
|
||||
|
||||
// Convert HashMap to Vec for serialization (iterating over values of HashMap)
|
||||
let servers: Vec<ServerAccessToken> = cache.values().cloned().collect();
|
||||
@@ -173,26 +180,42 @@ pub async fn load_servers_token<R: Runtime>(
|
||||
servers.ok_or_else(|| "Failed to read servers from store: No servers found".to_string())?;
|
||||
|
||||
// Convert each item in the JsonValue array to a Server
|
||||
if let JsonValue::Array(servers_array) = servers {
|
||||
// Deserialize each JsonValue into Server, filtering out any errors
|
||||
let deserialized_tokens: Vec<ServerAccessToken> = servers_array
|
||||
.into_iter()
|
||||
.filter_map(|server_json| from_value(server_json).ok()) // Only keep valid Server instances
|
||||
.collect();
|
||||
match servers {
|
||||
JsonValue::Array(servers_array) => {
|
||||
let mut deserialized_tokens: Vec<ServerAccessToken> =
|
||||
Vec::with_capacity(servers_array.len());
|
||||
for server_json in servers_array {
|
||||
match from_value(server_json.clone()) {
|
||||
Ok(token) => {
|
||||
deserialized_tokens.push(token);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"failed to deserialize JSON [{}] to [struct ServerAccessToken], error [{}], store [{}] key [{}] is possibly corrupted!",
|
||||
server_json, e, COCO_TAURI_STORE, COCO_SERVER_TOKENS
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if deserialized_tokens.is_empty() {
|
||||
return Err("Failed to deserialize any servers from the store.".to_string());
|
||||
if deserialized_tokens.is_empty() {
|
||||
return Err("Failed to deserialize any servers from the store.".to_string());
|
||||
}
|
||||
|
||||
for server in deserialized_tokens.iter() {
|
||||
save_access_token(server.id.clone(), server.clone()).await;
|
||||
}
|
||||
|
||||
log::debug!("loaded {:?} servers's token", &deserialized_tokens.len());
|
||||
|
||||
Ok(deserialized_tokens)
|
||||
}
|
||||
|
||||
for server in deserialized_tokens.iter() {
|
||||
save_access_token(server.id.clone(), server.clone());
|
||||
_ => {
|
||||
unreachable!(
|
||||
"coco server tokens should be stored in an array under store [{}] key [{}], but it is not",
|
||||
COCO_TAURI_STORE, COCO_SERVER_TOKENS
|
||||
);
|
||||
}
|
||||
|
||||
log::debug!("loaded {:?} servers's token", &deserialized_tokens.len());
|
||||
|
||||
Ok(deserialized_tokens)
|
||||
} else {
|
||||
Err("Failed to read servers from store: Invalid format".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,26 +237,41 @@ pub async fn load_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<Vec<S
|
||||
servers.ok_or_else(|| "Failed to read servers from store: No servers found".to_string())?;
|
||||
|
||||
// Convert each item in the JsonValue array to a Server
|
||||
if let JsonValue::Array(servers_array) = servers {
|
||||
// Deserialize each JsonValue into Server, filtering out any errors
|
||||
let deserialized_servers: Vec<Server> = servers_array
|
||||
.into_iter()
|
||||
.filter_map(|server_json| from_value(server_json).ok()) // Only keep valid Server instances
|
||||
.collect();
|
||||
match servers {
|
||||
JsonValue::Array(servers_array) => {
|
||||
let mut deserialized_servers = Vec::with_capacity(servers_array.len());
|
||||
for server_json in servers_array {
|
||||
match from_value(server_json.clone()) {
|
||||
Ok(server) => {
|
||||
deserialized_servers.push(server);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"failed to deserialize JSON [{}] to [struct Server], error [{}], store [{}] key [{}] is possibly corrupted!",
|
||||
server_json, e, COCO_TAURI_STORE, COCO_SERVERS
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if deserialized_servers.is_empty() {
|
||||
return Err("Failed to deserialize any servers from the store.".to_string());
|
||||
if deserialized_servers.is_empty() {
|
||||
return Err("Failed to deserialize any servers from the store.".to_string());
|
||||
}
|
||||
|
||||
for server in deserialized_servers.iter() {
|
||||
save_server(&server).await;
|
||||
}
|
||||
|
||||
log::debug!("load servers: {:?}", &deserialized_servers);
|
||||
|
||||
Ok(deserialized_servers)
|
||||
}
|
||||
|
||||
for server in deserialized_servers.iter() {
|
||||
save_server(&server);
|
||||
_ => {
|
||||
unreachable!(
|
||||
"coco servers should be stored in an array under store [{}] key [{}], but it is not",
|
||||
COCO_TAURI_STORE, COCO_SERVERS
|
||||
);
|
||||
}
|
||||
|
||||
log::debug!("load servers: {:?}", &deserialized_servers);
|
||||
|
||||
Ok(deserialized_servers)
|
||||
} else {
|
||||
Err("Failed to read servers from store: Invalid format".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,7 +288,7 @@ pub async fn load_or_insert_default_server<R: Runtime>(
|
||||
}
|
||||
|
||||
let default = get_default_server();
|
||||
save_server(&default);
|
||||
save_server(&default).await;
|
||||
|
||||
log::debug!("loaded default servers");
|
||||
|
||||
@@ -259,33 +297,22 @@ pub async fn load_or_insert_default_server<R: Runtime>(
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_coco_servers<R: Runtime>(
|
||||
_app_handle: AppHandle<R>,
|
||||
app_handle: AppHandle<R>,
|
||||
) -> Result<Vec<Server>, String> {
|
||||
//hard fresh all server's info, in order to get the actual health
|
||||
refresh_all_coco_server_info(_app_handle.clone()).await;
|
||||
refresh_all_coco_server_info(app_handle.clone()).await;
|
||||
|
||||
let servers: Vec<Server> = get_all_servers();
|
||||
let servers: Vec<Server> = get_all_servers().await;
|
||||
Ok(servers)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_servers_as_hashmap() -> HashMap<String, Server> {
|
||||
let cache = SERVER_CACHE.read().unwrap();
|
||||
cache.clone()
|
||||
}
|
||||
|
||||
pub fn get_all_servers() -> Vec<Server> {
|
||||
let cache = SERVER_CACHE.read().unwrap();
|
||||
pub async fn get_all_servers() -> Vec<Server> {
|
||||
let cache = SERVER_LIST_CACHE.read().await;
|
||||
cache.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// We store added Coco servers in the Tauri store using this key.
|
||||
pub const COCO_SERVERS: &str = "coco_servers";
|
||||
|
||||
const COCO_SERVER_TOKENS: &str = "coco_server_tokens";
|
||||
|
||||
pub async fn refresh_all_coco_server_info<R: Runtime>(app_handle: AppHandle<R>) {
|
||||
let servers = get_all_servers();
|
||||
let servers = get_all_servers().await;
|
||||
for server in servers {
|
||||
let _ = refresh_coco_server_info(app_handle.clone(), server.id.clone()).await;
|
||||
}
|
||||
@@ -298,7 +325,7 @@ pub async fn refresh_coco_server_info<R: Runtime>(
|
||||
) -> Result<Server, String> {
|
||||
// Retrieve the server from the cache
|
||||
let cached_server = {
|
||||
let cache = SERVER_CACHE.read().unwrap();
|
||||
let cache = SERVER_LIST_CACHE.read().await;
|
||||
cache.get(&id).cloned()
|
||||
};
|
||||
|
||||
@@ -313,19 +340,16 @@ pub async fn refresh_coco_server_info<R: Runtime>(
|
||||
let profile = server.profile;
|
||||
|
||||
// Send request to fetch updated server info
|
||||
let response = HttpClient::get(&id, "/provider/_info", None)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to contact the server: {}", e));
|
||||
|
||||
if response.is_err() {
|
||||
let _ = mark_server_as_offline(app_handle, &id).await;
|
||||
return Err(response.err().unwrap());
|
||||
}
|
||||
|
||||
let response = response?;
|
||||
let response = match HttpClient::get(&id, "/provider/_info", None).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
mark_server_as_offline(app_handle, &id).await;
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
let _ = mark_server_as_offline(app_handle, &id).await;
|
||||
mark_server_as_offline(app_handle, &id).await;
|
||||
return Err(format!("Request failed with status: {}", response.status()));
|
||||
}
|
||||
|
||||
@@ -336,19 +360,25 @@ pub async fn refresh_coco_server_info<R: Runtime>(
|
||||
let mut updated_server: Server = serde_json::from_str(&body)
|
||||
.map_err(|e| format!("Failed to deserialize the response: {}", e))?;
|
||||
|
||||
// Mark server as online
|
||||
let _ = mark_server_as_online(app_handle.clone(), &id).await;
|
||||
|
||||
// Restore local state
|
||||
updated_server.id = id.clone();
|
||||
updated_server.builtin = is_builtin;
|
||||
updated_server.enabled = is_enabled;
|
||||
updated_server.available = true;
|
||||
updated_server.available = {
|
||||
if server.public {
|
||||
// Public Coco servers are available as long as they are online.
|
||||
true
|
||||
} else {
|
||||
// For non-public Coco servers, we still need to check if it is
|
||||
// logged in, i.e., has a token stored in `SERVER_TOKEN_LIST_CACHE`.
|
||||
get_server_token(&id).await.is_some()
|
||||
}
|
||||
};
|
||||
updated_server.profile = profile;
|
||||
trim_endpoint_last_forward_slash(&mut updated_server);
|
||||
|
||||
// Save and persist
|
||||
save_server(&updated_server);
|
||||
save_server(&updated_server).await;
|
||||
persist_servers(&app_handle)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to persist servers: {}", e))?;
|
||||
@@ -371,10 +401,10 @@ pub async fn add_coco_server<R: Runtime>(
|
||||
|
||||
let endpoint = endpoint.trim_end_matches('/');
|
||||
|
||||
if check_endpoint_exists(endpoint) {
|
||||
if check_endpoint_exists(endpoint).await {
|
||||
log::debug!(
|
||||
"This Coco server has already been registered: {:?}",
|
||||
&endpoint
|
||||
"trying to register a Coco server [{}] that has already been registered",
|
||||
endpoint
|
||||
);
|
||||
return Err("This Coco server has already been registered.".into());
|
||||
}
|
||||
@@ -386,6 +416,15 @@ pub async fn add_coco_server<R: Runtime>(
|
||||
|
||||
log::debug!("Get provider info response: {:?}", &response);
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
log::debug!(
|
||||
"trying to register a Coco server [{}] that is possibly down",
|
||||
endpoint
|
||||
);
|
||||
|
||||
return Err("This Coco server is possibly down".into());
|
||||
}
|
||||
|
||||
let body = get_response_body_text(response).await?;
|
||||
|
||||
let mut server: Server = serde_json::from_str(&body)
|
||||
@@ -393,15 +432,32 @@ pub async fn add_coco_server<R: Runtime>(
|
||||
|
||||
trim_endpoint_last_forward_slash(&mut server);
|
||||
|
||||
// The JSON returned from `provider/_info` won't have this field, serde will set
|
||||
// it to an empty string during deserialization, we need to set a valid value here.
|
||||
if server.id.is_empty() {
|
||||
server.id = pizza_common::utils::uuid::Uuid::new().to_string();
|
||||
}
|
||||
|
||||
// Use the default name, if it is not set.
|
||||
if server.name.is_empty() {
|
||||
server.name = "Coco Server".to_string();
|
||||
}
|
||||
|
||||
save_server(&server);
|
||||
// Update the `available` field
|
||||
if server.public {
|
||||
// Serde already sets this to true, but just to make the code clear, do it again.
|
||||
server.available = true;
|
||||
} else {
|
||||
let opt_token = get_server_token(&server.id).await;
|
||||
assert!(
|
||||
opt_token.is_none(),
|
||||
"this Coco server is newly-added, we should have no token stored for it!"
|
||||
);
|
||||
// This is a non-public Coco server, and it is not logged in, so it is unavailable.
|
||||
server.available = false;
|
||||
}
|
||||
|
||||
save_server(&server).await;
|
||||
try_register_server_to_search_source(app_handle.clone(), &server).await;
|
||||
|
||||
persist_servers(&app_handle)
|
||||
@@ -413,6 +469,7 @@ pub async fn add_coco_server<R: Runtime>(
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
#[function_name::named]
|
||||
pub async fn remove_coco_server<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
id: String,
|
||||
@@ -420,51 +477,104 @@ pub async fn remove_coco_server<R: Runtime>(
|
||||
let registry = app_handle.state::<SearchSourceRegistry>();
|
||||
registry.remove_source(id.as_str()).await;
|
||||
|
||||
remove_server_token(id.as_str());
|
||||
remove_server_by_id(id);
|
||||
|
||||
let opt_server = remove_server_by_id(id.as_str()).await;
|
||||
let Some(server) = opt_server else {
|
||||
panic!(
|
||||
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
|
||||
function_name!(),
|
||||
id
|
||||
);
|
||||
};
|
||||
persist_servers(&app_handle)
|
||||
.await
|
||||
.expect("failed to save servers");
|
||||
persist_servers_token(&app_handle).expect("failed to save server tokens");
|
||||
|
||||
// Only non-public Coco servers require tokens
|
||||
if !server.public {
|
||||
// If is logged in, clear the token as well.
|
||||
let deleted = remove_server_token(id.as_str()).await;
|
||||
if deleted {
|
||||
persist_servers_token(&app_handle)
|
||||
.await
|
||||
.expect("failed to save server tokens");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
#[function_name::named]
|
||||
pub async fn enable_server<R: Runtime>(app_handle: AppHandle<R>, id: String) -> Result<(), ()> {
|
||||
println!("enable_server: {}", id);
|
||||
let opt_server = get_server_by_id(id.as_str()).await;
|
||||
|
||||
let server = get_server_by_id(id.as_str());
|
||||
if let Some(mut server) = server {
|
||||
server.enabled = true;
|
||||
save_server(&server);
|
||||
let Some(mut server) = opt_server else {
|
||||
panic!(
|
||||
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
|
||||
function_name!(),
|
||||
id
|
||||
);
|
||||
};
|
||||
|
||||
// Register the server to the search source
|
||||
try_register_server_to_search_source(app_handle.clone(), &server).await;
|
||||
server.enabled = true;
|
||||
save_server(&server).await;
|
||||
|
||||
persist_servers(&app_handle)
|
||||
.await
|
||||
.expect("failed to save servers");
|
||||
}
|
||||
// Register the server to the search source
|
||||
try_register_server_to_search_source(app_handle.clone(), &server).await;
|
||||
|
||||
persist_servers(&app_handle)
|
||||
.await
|
||||
.expect("failed to save servers");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
#[function_name::named]
|
||||
pub async fn disable_server<R: Runtime>(app_handle: AppHandle<R>, id: String) -> Result<(), ()> {
|
||||
let opt_server = get_server_by_id(id.as_str()).await;
|
||||
|
||||
let Some(mut server) = opt_server else {
|
||||
panic!(
|
||||
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
|
||||
function_name!(),
|
||||
id
|
||||
);
|
||||
};
|
||||
|
||||
server.enabled = false;
|
||||
|
||||
let registry = app_handle.state::<SearchSourceRegistry>();
|
||||
registry.remove_source(id.as_str()).await;
|
||||
|
||||
save_server(&server).await;
|
||||
persist_servers(&app_handle)
|
||||
.await
|
||||
.expect("failed to save servers");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// For non-public Coco servers, we add it to the search source as long as it is
|
||||
/// enabled.
|
||||
///
|
||||
/// For public Coco server, an extra token is required.
|
||||
pub async fn try_register_server_to_search_source(
|
||||
app_handle: AppHandle<impl Runtime>,
|
||||
server: &Server,
|
||||
) {
|
||||
if server.enabled {
|
||||
log::trace!(
|
||||
"Server {} is public: {} and available: {}",
|
||||
"Server [name: {}, id: {}] is public: {} and available: {}",
|
||||
&server.name,
|
||||
&server.id,
|
||||
&server.public,
|
||||
&server.available
|
||||
);
|
||||
|
||||
if !server.public {
|
||||
let token = get_server_token(&server.id).await;
|
||||
let opt_token = get_server_token(&server.id).await;
|
||||
|
||||
if !token.is_ok() || token.is_ok() && token.unwrap().is_none() {
|
||||
if opt_token.is_none() {
|
||||
log::debug!("Server {} is not public and no token was found", &server.id);
|
||||
return;
|
||||
}
|
||||
@@ -476,113 +586,110 @@ pub async fn try_register_server_to_search_source(
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn mark_server_as_online<R: Runtime>(
|
||||
app_handle: AppHandle<R>, id: &str) -> Result<(), ()> {
|
||||
// println!("server_is_offline: {}", id);
|
||||
let server = get_server_by_id(id);
|
||||
#[function_name::named]
|
||||
#[allow(unused)]
|
||||
async fn mark_server_as_online<R: Runtime>(app_handle: AppHandle<R>, id: &str) {
|
||||
let server = get_server_by_id(id).await;
|
||||
if let Some(mut server) = server {
|
||||
server.available = true;
|
||||
server.health = None;
|
||||
save_server(&server);
|
||||
save_server(&server).await;
|
||||
|
||||
try_register_server_to_search_source(app_handle.clone(), &server).await;
|
||||
} else {
|
||||
log::warn!(
|
||||
"[{}()] invoked with a server [{}] that does not exist!",
|
||||
function_name!(),
|
||||
id
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn mark_server_as_offline<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
id: &str,
|
||||
) -> Result<(), ()> {
|
||||
// println!("server_is_offline: {}", id);
|
||||
let server = get_server_by_id(id);
|
||||
#[function_name::named]
|
||||
pub(crate) async fn mark_server_as_offline<R: Runtime>(app_handle: AppHandle<R>, id: &str) {
|
||||
let server = get_server_by_id(id).await;
|
||||
if let Some(mut server) = server {
|
||||
server.available = false;
|
||||
server.health = None;
|
||||
save_server(&server);
|
||||
save_server(&server).await;
|
||||
|
||||
let registry = app_handle.state::<SearchSourceRegistry>();
|
||||
registry.remove_source(id).await;
|
||||
} else {
|
||||
log::warn!(
|
||||
"[{}()] invoked with a server [{}] that does not exist!",
|
||||
function_name!(),
|
||||
id
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn disable_server<R: Runtime>(app_handle: AppHandle<R>, id: String) -> Result<(), ()> {
|
||||
let server = get_server_by_id(id.as_str());
|
||||
if let Some(mut server) = server {
|
||||
server.enabled = false;
|
||||
|
||||
let registry = app_handle.state::<SearchSourceRegistry>();
|
||||
registry.remove_source(id.as_str()).await;
|
||||
|
||||
save_server(&server);
|
||||
persist_servers(&app_handle)
|
||||
.await
|
||||
.expect("failed to save servers");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
#[function_name::named]
|
||||
pub async fn logout_coco_server<R: Runtime>(
|
||||
app_handle: AppHandle<R>,
|
||||
id: String,
|
||||
) -> Result<(), String> {
|
||||
log::debug!("Attempting to log out server by id: {}", &id);
|
||||
|
||||
// Check if server token exists
|
||||
if let Some(_token) = get_server_token(id.as_str()).await? {
|
||||
log::debug!("Found server token for id: {}", &id);
|
||||
// Check if the server exists
|
||||
let Some(mut server) = get_server_by_id(id.as_str()).await else {
|
||||
panic!(
|
||||
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
|
||||
function_name!(),
|
||||
id
|
||||
);
|
||||
};
|
||||
|
||||
// Clear server profile
|
||||
server.profile = None;
|
||||
// Logging out from a non-public Coco server makes it unavailable
|
||||
if !server.public {
|
||||
server.available = false;
|
||||
}
|
||||
// Save the updated server data
|
||||
save_server(&server).await;
|
||||
// Persist the updated server data
|
||||
if let Err(e) = persist_servers(&app_handle).await {
|
||||
log::debug!("Failed to save server for id: {}. Error: {:?}", &id, &e);
|
||||
return Err(format!("Failed to save server: {}", &e));
|
||||
}
|
||||
|
||||
let has_token = get_server_token(id.as_str()).await.is_some();
|
||||
if server.public {
|
||||
if has_token {
|
||||
panic!("Public Coco server won't have token")
|
||||
}
|
||||
} else {
|
||||
assert!(
|
||||
has_token,
|
||||
"This is a non-public Coco server, and it is logged in, we should have a token"
|
||||
);
|
||||
// Remove the server token from cache
|
||||
remove_server_token(id.as_str());
|
||||
remove_server_token(id.as_str()).await;
|
||||
|
||||
// Persist the updated tokens
|
||||
if let Err(e) = persist_servers_token(&app_handle) {
|
||||
if let Err(e) = persist_servers_token(&app_handle).await {
|
||||
log::debug!("Failed to save tokens for id: {}. Error: {:?}", &id, &e);
|
||||
return Err(format!("Failed to save tokens: {}", &e));
|
||||
}
|
||||
} else {
|
||||
// Log the case where server token is not found
|
||||
log::debug!("No server token found for id: {}", &id);
|
||||
}
|
||||
|
||||
// Check if the server exists
|
||||
if let Some(mut server) = get_server_by_id(id.as_str()) {
|
||||
log::debug!("Found server for id: {}", &id);
|
||||
|
||||
// Clear server profile
|
||||
server.profile = None;
|
||||
let _ = mark_server_as_offline(app_handle.clone(), id.as_str()).await;
|
||||
|
||||
// Save the updated server data
|
||||
save_server(&server);
|
||||
|
||||
// Persist the updated server data
|
||||
if let Err(e) = persist_servers(&app_handle).await {
|
||||
log::debug!("Failed to save server for id: {}. Error: {:?}", &id, &e);
|
||||
return Err(format!("Failed to save server: {}", &e));
|
||||
}
|
||||
} else {
|
||||
// Log the case where server is not found
|
||||
log::debug!("No server found for id: {}", &id);
|
||||
return Err(format!("No server found for id: {}", id));
|
||||
// Remove it from the search source if it becomes unavailable
|
||||
if !server.available {
|
||||
let registry = app_handle.state::<SearchSourceRegistry>();
|
||||
registry.remove_source(id.as_str()).await;
|
||||
}
|
||||
|
||||
log::debug!("Successfully logged out server with id: {}", &id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes the trailing slash from the server's endpoint if present.
|
||||
/// Helper function to remove the trailing slash from the server's endpoint if present.
|
||||
fn trim_endpoint_last_forward_slash(server: &mut Server) {
|
||||
if server.endpoint.ends_with('/') {
|
||||
server.endpoint.pop(); // Remove the last character
|
||||
while server.endpoint.ends_with('/') {
|
||||
server.endpoint.pop();
|
||||
}
|
||||
let endpoint = &mut server.endpoint;
|
||||
while endpoint.ends_with('/') {
|
||||
endpoint.pop();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -591,42 +698,47 @@ fn provider_info_url(endpoint: &str) -> String {
|
||||
format!("{endpoint}/provider/_info")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_endpoint_last_forward_slash() {
|
||||
let mut server = Server {
|
||||
id: "test".to_string(),
|
||||
builtin: false,
|
||||
enabled: true,
|
||||
name: "".to_string(),
|
||||
endpoint: "https://example.com///".to_string(),
|
||||
provider: Provider {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_trim_endpoint_last_forward_slash() {
|
||||
let mut server = Server {
|
||||
id: "test".to_string(),
|
||||
builtin: false,
|
||||
enabled: true,
|
||||
name: "".to_string(),
|
||||
icon: "".to_string(),
|
||||
website: "".to_string(),
|
||||
eula: "".to_string(),
|
||||
privacy_policy: "".to_string(),
|
||||
banner: "".to_string(),
|
||||
description: "".to_string(),
|
||||
},
|
||||
version: Version {
|
||||
number: "".to_string(),
|
||||
},
|
||||
minimal_client_version: None,
|
||||
updated: "".to_string(),
|
||||
public: false,
|
||||
available: false,
|
||||
health: None,
|
||||
profile: None,
|
||||
auth_provider: AuthProvider {
|
||||
sso: Sso {
|
||||
url: "".to_string(),
|
||||
endpoint: "https://example.com///".to_string(),
|
||||
provider: Provider {
|
||||
name: "".to_string(),
|
||||
icon: "".to_string(),
|
||||
website: "".to_string(),
|
||||
eula: "".to_string(),
|
||||
privacy_policy: "".to_string(),
|
||||
banner: "".to_string(),
|
||||
description: "".to_string(),
|
||||
},
|
||||
},
|
||||
priority: 0,
|
||||
stats: None,
|
||||
};
|
||||
version: Version {
|
||||
number: "".to_string(),
|
||||
},
|
||||
minimal_client_version: None,
|
||||
updated: "".to_string(),
|
||||
public: false,
|
||||
available: false,
|
||||
health: None,
|
||||
profile: None,
|
||||
auth_provider: AuthProvider {
|
||||
sso: Sso {
|
||||
url: "".to_string(),
|
||||
},
|
||||
},
|
||||
priority: 0,
|
||||
stats: None,
|
||||
};
|
||||
|
||||
trim_endpoint_last_forward_slash(&mut server);
|
||||
trim_endpoint_last_forward_slash(&mut server);
|
||||
|
||||
assert_eq!(server.endpoint, "https://example.com");
|
||||
assert_eq!(server.endpoint, "https://example.com");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tauri::{AppHandle, Emitter, Runtime};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::{connect_async_tls_with_config, Connector};
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
|
||||
use tokio_tungstenite::{Connector, connect_async_tls_with_config};
|
||||
#[derive(Default)]
|
||||
pub struct WebSocketManager {
|
||||
connections: Arc<Mutex<HashMap<String, Arc<WebSocketInstance>>>>,
|
||||
@@ -53,9 +53,11 @@ pub async fn connect_to_server<R: Runtime>(
|
||||
// Disconnect old connection first
|
||||
disconnect(client_id.clone(), state.clone()).await.ok();
|
||||
|
||||
let server = get_server_by_id(&id).ok_or(format!("Server with ID {} not found", id))?;
|
||||
let server = get_server_by_id(&id)
|
||||
.await
|
||||
.ok_or(format!("Server with ID {} not found", id))?;
|
||||
let endpoint = convert_to_websocket(&server.endpoint)?;
|
||||
let token = get_server_token(&id).await?.map(|t| t.access_token.clone());
|
||||
let token = get_server_token(&id).await.map(|t| t.access_token.clone());
|
||||
|
||||
let mut request =
|
||||
tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(&endpoint)
|
||||
@@ -95,8 +97,8 @@ pub async fn connect_to_server<R: Runtime>(
|
||||
true, // disable_nagle
|
||||
Some(connector), // Connector
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("WebSocket TLS error: {:?}", e))?;
|
||||
.await
|
||||
.map_err(|e| format!("WebSocket TLS error: {:?}", e))?;
|
||||
|
||||
let (cancel_tx, mut cancel_rx) = mpsc::channel(1);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user