refactor: do status code check before deserializing response (#767)

* refactor: do status code check before deserializing response

This commit adds a status code check to the following requests, only when
this check passes, we deserialize the response JSON body:

- get_connectors_by_server
- mcp_server_search
- datasource_search

A helper function `status_code_check(response, allowed_status_codes)`
is added to make refactoring easier.

* chore: release notes
This commit is contained in:
SteveLauC
2025-07-17 15:08:14 +08:00
committed by GitHub
parent 494e2f0d8a
commit 14fbf2ac5d
4 changed files with 37 additions and 4 deletions

View File

@@ -39,6 +39,7 @@ Information about release notes of Coco Server is provided here.
- chore: make optional fields optional #758
- chore: search-chat components add formatUrl & think data & icons url #765
- chore: Coco app http request headers #744
- refactor: do status code check before deserializing response #767
## 0.6.0 (2025-06-29)

View File

@@ -1,7 +1,8 @@
use crate::common::connector::Connector;
use crate::common::search::parse_search_results;
use crate::server::http_client::HttpClient;
use crate::server::http_client::{HttpClient, status_code_check};
use crate::server::servers::get_all_servers;
use http::StatusCode;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
@@ -107,6 +108,7 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, Stri
// dbg!("Error fetching connector for id {}: {}", &id, &e);
format!("Error fetching connector: {}", e)
})?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results directly from the response body
let datasource: Vec<Connector> = parse_search_results(resp)

View File

@@ -1,8 +1,9 @@
use crate::common::datasource::DataSource;
use crate::common::search::parse_search_results;
use crate::server::connector::get_connector_by_id;
use crate::server::http_client::HttpClient;
use crate::server::http_client::{HttpClient, status_code_check};
use crate::server::servers::get_all_servers;
use http::StatusCode;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
@@ -25,7 +26,7 @@ pub fn save_datasource_to_cache(server_id: &str, datasources: Vec<DataSource>) {
#[allow(dead_code)]
pub fn get_datasources_from_cache(server_id: &str) -> Option<HashMap<String, DataSource>> {
let cache = DATASOURCE_CACHE.read().unwrap(); // Acquire read lock
// dbg!("cache: {:?}", &cache);
// dbg!("cache: {:?}", &cache);
let server_cache = cache.get(server_id)?; // Get the server's cache
Some(server_cache.clone())
}
@@ -95,6 +96,7 @@ pub async fn datasource_search(
let resp = HttpClient::post(id, "/datasource/_search", query_params, None)
.await
.map_err(|e| format!("Error fetching datasource: {}", e))?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results from the response
let datasources: Vec<DataSource> = parse_search_results(resp).await.map_err(|e| {
@@ -117,6 +119,7 @@ pub async fn mcp_server_search(
let resp = HttpClient::post(id, "/mcp_server/_search", query_params, None)
.await
.map_err(|e| format!("Error fetching datasource: {}", e))?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results from the response
let mcp_server: Vec<DataSource> = parse_search_results(resp).await.map_err(|e| {

View File

@@ -1,7 +1,7 @@
use crate::server::servers::{get_server_by_id, get_server_token};
use crate::util::app_lang::get_app_lang;
use crate::util::platform::Platform;
use http::{HeaderName, HeaderValue};
use http::{HeaderName, HeaderValue, StatusCode};
use once_cell::sync::Lazy;
use reqwest::{Client, Method, RequestBuilder};
use std::collections::HashMap;
@@ -285,3 +285,30 @@ impl HttpClient {
.await
}
}
/// Helper function to check status code.
///
/// If the status code is not in the `allowed_status_codes` list, return an error.
pub(crate) fn status_code_check(
response: &reqwest::Response,
allowed_status_codes: &[StatusCode],
) -> Result<(), String> {
let status_code = response.status();
if !allowed_status_codes.contains(&status_code) {
let msg = format!(
"Response of request [{}] status code failed: status code [{}], which is not in the 'allow' list {:?}",
response.url(),
status_code,
allowed_status_codes
.iter()
.map(|status| status.to_string())
.collect::<Vec<String>>()
);
log::warn!("{}", msg);
Err(msg)
} else {
Ok(())
}
}