From 14fbf2ac5d76d751fc4dda18130312242340b345 Mon Sep 17 00:00:00 2001 From: SteveLauC Date: Thu, 17 Jul 2025 15:08:14 +0800 Subject: [PATCH] refactor: do status code check before deserializing response (#767) * refactor: do status code check before deserializing response This commit adds a status code check to the following requests, only when this check passes, we deserialize the response JSON body: - get_connectors_by_server - mcp_server_search - datasource_search A helper function `status_code_check(response, allowed_status_codes)` is added to make refactoring easier. * chore: release notes --- docs/content.en/docs/release-notes/_index.md | 1 + src-tauri/src/server/connector.rs | 4 ++- src-tauri/src/server/datasource.rs | 7 +++-- src-tauri/src/server/http_client.rs | 29 +++++++++++++++++++- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/docs/content.en/docs/release-notes/_index.md b/docs/content.en/docs/release-notes/_index.md index 133132f9..12c8520b 100644 --- a/docs/content.en/docs/release-notes/_index.md +++ b/docs/content.en/docs/release-notes/_index.md @@ -39,6 +39,7 @@ Information about release notes of Coco Server is provided here. - chore: make optional fields optional #758 - chore: search-chat components add formatUrl & think data & icons url #765 - chore: Coco app http request headers #744 +- refactor: do status code check before deserializing response #767 ## 0.6.0 (2025-06-29) diff --git a/src-tauri/src/server/connector.rs b/src-tauri/src/server/connector.rs index 8e736d9d..25864bc0 100644 --- a/src-tauri/src/server/connector.rs +++ b/src-tauri/src/server/connector.rs @@ -1,7 +1,8 @@ use crate::common::connector::Connector; use crate::common::search::parse_search_results; -use crate::server::http_client::HttpClient; +use crate::server::http_client::{HttpClient, status_code_check}; use crate::server::servers::get_all_servers; +use http::StatusCode; use lazy_static::lazy_static; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -107,6 +108,7 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result, Stri // dbg!("Error fetching connector for id {}: {}", &id, &e); format!("Error fetching connector: {}", e) })?; + status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?; // Parse the search results directly from the response body let datasource: Vec = parse_search_results(resp) diff --git a/src-tauri/src/server/datasource.rs b/src-tauri/src/server/datasource.rs index ed9ec9e6..b7ce7db8 100644 --- a/src-tauri/src/server/datasource.rs +++ b/src-tauri/src/server/datasource.rs @@ -1,8 +1,9 @@ use crate::common::datasource::DataSource; use crate::common::search::parse_search_results; use crate::server::connector::get_connector_by_id; -use crate::server::http_client::HttpClient; +use crate::server::http_client::{HttpClient, status_code_check}; use crate::server::servers::get_all_servers; +use http::StatusCode; use lazy_static::lazy_static; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -25,7 +26,7 @@ pub fn save_datasource_to_cache(server_id: &str, datasources: Vec) { #[allow(dead_code)] pub fn get_datasources_from_cache(server_id: &str) -> Option> { let cache = DATASOURCE_CACHE.read().unwrap(); // Acquire read lock - // dbg!("cache: {:?}", &cache); + // dbg!("cache: {:?}", &cache); let server_cache = cache.get(server_id)?; // Get the server's cache Some(server_cache.clone()) } @@ -95,6 +96,7 @@ pub async fn datasource_search( let resp = HttpClient::post(id, "/datasource/_search", query_params, None) .await .map_err(|e| format!("Error fetching datasource: {}", e))?; + status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?; // Parse the search results from the response let datasources: Vec = parse_search_results(resp).await.map_err(|e| { @@ -117,6 +119,7 @@ pub async fn mcp_server_search( let resp = HttpClient::post(id, "/mcp_server/_search", query_params, None) .await .map_err(|e| format!("Error fetching datasource: {}", e))?; + status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?; // Parse the search results from the response let mcp_server: Vec = parse_search_results(resp).await.map_err(|e| { diff --git a/src-tauri/src/server/http_client.rs b/src-tauri/src/server/http_client.rs index b1ea1ad6..33b96e06 100644 --- a/src-tauri/src/server/http_client.rs +++ b/src-tauri/src/server/http_client.rs @@ -1,7 +1,7 @@ use crate::server::servers::{get_server_by_id, get_server_token}; use crate::util::app_lang::get_app_lang; use crate::util::platform::Platform; -use http::{HeaderName, HeaderValue}; +use http::{HeaderName, HeaderValue, StatusCode}; use once_cell::sync::Lazy; use reqwest::{Client, Method, RequestBuilder}; use std::collections::HashMap; @@ -285,3 +285,30 @@ impl HttpClient { .await } } + +/// Helper function to check status code. +/// +/// If the status code is not in the `allowed_status_codes` list, return an error. +pub(crate) fn status_code_check( + response: &reqwest::Response, + allowed_status_codes: &[StatusCode], +) -> Result<(), String> { + let status_code = response.status(); + + if !allowed_status_codes.contains(&status_code) { + let msg = format!( + "Response of request [{}] status code failed: status code [{}], which is not in the 'allow' list {:?}", + response.url(), + status_code, + allowed_status_codes + .iter() + .map(|status| status.to_string()) + .collect::>() + ); + log::warn!("{}", msg); + + Err(msg) + } else { + Ok(()) + } +}