From c471a83821e14919578e4e06a4dc78cd7184460a Mon Sep 17 00:00:00 2001 From: ayangweb <75017711+ayangweb@users.noreply.github.com> Date: Fri, 30 May 2025 17:18:52 +0800 Subject: [PATCH] feat: support third party extensions (#572) * refactor: support third party extensions * fix tests * fix: assistant_get error * aaa * bbb * ccc * ddd * fix: aa * fix: aa * sss * fix:asds * eee * refactor: loosen restriction of query string length * fix: input auto * feat: add ai overview trigger condition configuration * refactor: continue chatting to select the corresponding mini-helper * chore: settings width height * aaa --------- Co-authored-by: Steve Lau Co-authored-by: rain <15911122312@163.com> --- docs/content.en/docs/release-notes/_index.md | 2 + package.json | 1 + pnpm-lock.yaml | 3 + public/assets/fonts/icons/extension.js | 1 + src-tauri/Cargo.lock | 57 +- src-tauri/Cargo.toml | 4 + .../assets/extension/AIOverview/plugin.json | 8 + .../assets/extension/Applications/plugin.json | 9 + .../assets/extension/Calculator/plugin.json | 9 + .../extension/QuickAIAccess/plugin.json | 8 + src-tauri/rust-toolchain.toml | 2 +- src-tauri/src/common/document.rs | 70 ++ src-tauri/src/common/traits.rs | 2 - .../src/extension/built_in/ai_overview.rs | 1 + .../built_in}/application/mod.rs | 0 .../built_in}/application/with_feature.rs | 137 +-- .../built_in}/application/without_feature.rs | 46 +- .../built_in}/calculator.rs | 6 +- .../src/extension/built_in/file_system.rs | 1 + src-tauri/src/extension/built_in/mod.rs | 310 +++++++ .../built_in/pizza_engine_runtime.rs | 51 ++ .../src/extension/built_in/quick_ai_access.rs | 1 + src-tauri/src/extension/mod.rs | 825 ++++++++++++++++++ src-tauri/src/extension/third_party.rs | 733 ++++++++++++++++ src-tauri/src/lib.rs | 39 +- src-tauri/src/local/mod.rs | 164 ---- .../src/{local/file_system.rs => mod.rs} | 0 src-tauri/src/search/mod.rs | 27 +- src-tauri/src/server/search.rs | 32 +- src-tauri/src/util/mod.rs | 1 - src-tauri/tauri.conf.json | 4 +- src/components/Assistant/AssistantList.tsx | 21 +- src/components/ChatMessage/MessageActions.tsx | 39 +- src/components/ChatMessage/index.tsx | 11 +- src/components/Common/Icons/AiSummaryIcon.tsx | 49 -- src/components/Search/AiOverview.tsx | 91 ++ src/components/Search/AiSummary.tsx | 49 -- src/components/Search/AskAi.tsx | 26 +- src/components/Search/AssistantManager.tsx | 13 +- src/components/Search/AutoResizeTextarea.tsx | 103 ++- src/components/Search/ContextMenu.tsx | 209 ++--- src/components/Search/DocumentList.tsx | 15 +- src/components/Search/DropdownList.tsx | 32 +- src/components/Search/DropdownListItem.tsx | 26 +- src/components/Search/InputBox.tsx | 48 +- src/components/Search/InputControls.tsx | 49 +- src/components/Search/Search.tsx | 9 +- src/components/Search/SearchIcons.tsx | 4 +- src/components/Search/SearchSource.tsx | 6 +- .../Extensions/components/Content/index.tsx | 321 ++++--- .../components/Details/AiOverview/index.tsx | 91 ++ .../components/Details/Application/index.tsx | 37 +- .../components/Details/Applications/index.tsx | 2 +- .../{QuickAiAccess => SharedAi}/index.tsx | 91 +- .../Extensions/components/Details/index.tsx | 69 +- src/components/Settings/Extensions/index.tsx | 253 ++---- src/components/Settings/SettingsSelectPro.tsx | 8 +- src/components/Settings/SettingsToggle.tsx | 28 +- src/constants/index.ts | 2 - src/hooks/useKeyboardNavigation.ts | 14 +- src/hooks/useScript.ts | 15 +- src/hooks/useSearch.ts | 217 +++-- src/hooks/useStreamChat.ts | 135 +++ src/hooks/useSyncStore.ts | 30 +- src/main.css | 12 +- src/routes/layout.tsx | 16 + src/stores/extensionsStore.ts | 33 + src/stores/searchStore.ts | 11 + src/types/platform.ts | 4 +- src/types/search.ts | 5 +- src/utils/index.ts | 4 + src/utils/platform.ts | 12 + src/utils/tauriAdapter.ts | 5 +- tailwind.config.js | 2 +- tsup.config.ts | 2 +- 75 files changed, 3674 insertions(+), 1099 deletions(-) create mode 100644 public/assets/fonts/icons/extension.js create mode 100644 src-tauri/assets/extension/AIOverview/plugin.json create mode 100644 src-tauri/assets/extension/Applications/plugin.json create mode 100644 src-tauri/assets/extension/Calculator/plugin.json create mode 100644 src-tauri/assets/extension/QuickAIAccess/plugin.json create mode 100644 src-tauri/src/extension/built_in/ai_overview.rs rename src-tauri/src/{local => extension/built_in}/application/mod.rs (100%) rename src-tauri/src/{local => extension/built_in}/application/with_feature.rs (91%) rename src-tauri/src/{local => extension/built_in}/application/without_feature.rs (77%) rename src-tauri/src/{local => extension/built_in}/calculator.rs (97%) create mode 100644 src-tauri/src/extension/built_in/file_system.rs create mode 100644 src-tauri/src/extension/built_in/mod.rs create mode 100644 src-tauri/src/extension/built_in/pizza_engine_runtime.rs create mode 100644 src-tauri/src/extension/built_in/quick_ai_access.rs create mode 100644 src-tauri/src/extension/mod.rs create mode 100644 src-tauri/src/extension/third_party.rs delete mode 100644 src-tauri/src/local/mod.rs rename src-tauri/src/{local/file_system.rs => mod.rs} (100%) delete mode 100644 src/components/Common/Icons/AiSummaryIcon.tsx create mode 100644 src/components/Search/AiOverview.tsx delete mode 100644 src/components/Search/AiSummary.tsx create mode 100644 src/components/Settings/Extensions/components/Details/AiOverview/index.tsx rename src/components/Settings/Extensions/components/Details/{QuickAiAccess => SharedAi}/index.tsx (58%) create mode 100644 src/hooks/useStreamChat.ts diff --git a/docs/content.en/docs/release-notes/_index.md b/docs/content.en/docs/release-notes/_index.md index 8cef906b..7c69dfa5 100644 --- a/docs/content.en/docs/release-notes/_index.md +++ b/docs/content.en/docs/release-notes/_index.md @@ -111,6 +111,8 @@ Information about release notes of Coco Server is provided here. - feat: data sources support displaying customized icons #432 - feat: add shortcut key conflict hint and reset function #442 - feat: updated to include error message #465 +- feat: support third party extensions #572 +- feat: support ai overview #572 ### Bug fix diff --git a/package.json b/package.json index 5b2057f4..d77f8cc6 100644 --- a/package.json +++ b/package.json @@ -62,6 +62,7 @@ "tauri-plugin-macos-permissions-api": "^2.3.0", "tauri-plugin-screenshots-api": "^2.2.0", "tauri-plugin-windows-version-api": "^2.0.0", + "type-fest": "^4.41.0", "use-debounce": "^10.0.4", "uuid": "^11.1.0", "wavesurfer.js": "^7.9.5", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4dbaef35..6e940e55 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -140,6 +140,9 @@ importers: tauri-plugin-windows-version-api: specifier: ^2.0.0 version: 2.0.0 + type-fest: + specifier: ^4.41.0 + version: 4.41.0 use-debounce: specifier: ^10.0.4 version: 10.0.4(react@18.3.1) diff --git a/public/assets/fonts/icons/extension.js b/public/assets/fonts/icons/extension.js new file mode 100644 index 00000000..1983928c --- /dev/null +++ b/public/assets/fonts/icons/extension.js @@ -0,0 +1 @@ +window._iconfont_svg_string_4934333='',(t=>{var l=(a=(a=document.getElementsByTagName("script"))[a.length-1]).getAttribute("data-injectcss"),a=a.getAttribute("data-disable-injectsvg");if(!a){var c,h,i,F,e,o=function(l,a){a.parentNode.insertBefore(l,a)};if(l&&!t.__iconfont__svg__cssinject__){t.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(l){console&&console.log(l)}}c=function(){var l,a=document.createElement("div");a.innerHTML=t._iconfont_svg_string_4934333,(a=a.getElementsByTagName("svg")[0])&&(a.setAttribute("aria-hidden","true"),a.style.position="absolute",a.style.width=0,a.style.height=0,a.style.overflow="hidden",a=a,(l=document.body).firstChild?o(a,l.firstChild):l.appendChild(a))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(c,0):(h=function(){document.removeEventListener("DOMContentLoaded",h,!1),c()},document.addEventListener("DOMContentLoaded",h,!1)):document.attachEvent&&(i=c,F=t.document,e=!1,d(),F.onreadystatechange=function(){"complete"==F.readyState&&(F.onreadystatechange=null,p())})}function p(){e||(e=!0,i())}function d(){try{F.documentElement.doScroll("left")}catch(l){return void setTimeout(d,50)}p()}})(window); \ No newline at end of file diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 708e7f19..ab6bfb02 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -823,13 +823,16 @@ dependencies = [ name = "coco" version = "0.4.0" dependencies = [ + "anyhow", "applications", "async-trait", "base64 0.13.1", "chinese-number", "chrono", + "derive_more 2.0.1", "dirs 5.0.1", "enigo", + "function_name", "futures", "futures-util", "hostname", @@ -847,6 +850,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "serde_plain", "strsim 0.10.0", "tauri", "tauri-build", @@ -1291,6 +1295,27 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", + "unicode-xid", +] + [[package]] name = "digest" version = "0.10.7" @@ -1826,6 +1851,21 @@ dependencies = [ "libc", ] +[[package]] +name = "function_name" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1ab577a896d09940b5fe12ec5ae71f9d8211fff62c919c03a3750a9901e98a7" +dependencies = [ + "function_name-proc-macro", +] + +[[package]] +name = "function_name-proc-macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673464e1e314dd67a0fd9544abc99e8eb28d0c7e3b69b033bcff9b2d00b87333" + [[package]] name = "funty" version = "2.0.0" @@ -5328,7 +5368,7 @@ checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe" dependencies = [ "bitflags 1.3.2", "cssparser", - "derive_more", + "derive_more 0.99.20", "fxhash", "log", "matches", @@ -5403,6 +5443,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_repr" version = "0.1.20" @@ -7040,6 +7089,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 7029c030..23ce256c 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -93,6 +93,10 @@ chinese-number = "0.7" num2words = "1" tauri-plugin-log = "2" chrono = "0.4.41" +serde_plain = "1.0.2" +derive_more = { version = "2.0.1", features = ["display"] } +anyhow = "1.0.98" +function_name = "0.3.0" [target."cfg(target_os = \"macos\")".dependencies] tauri-nspanel = { git = "https://github.com/ahkohd/tauri-nspanel", branch = "v2" } diff --git a/src-tauri/assets/extension/AIOverview/plugin.json b/src-tauri/assets/extension/AIOverview/plugin.json new file mode 100644 index 00000000..54e27027 --- /dev/null +++ b/src-tauri/assets/extension/AIOverview/plugin.json @@ -0,0 +1,8 @@ +{ + "id": "AIOverview", + "title": "AI Overview", + "description": "...", + "icon": "font_a-AIOverview", + "type": "ai_extension", + "enabled": true +} diff --git a/src-tauri/assets/extension/Applications/plugin.json b/src-tauri/assets/extension/Applications/plugin.json new file mode 100644 index 00000000..61977cc0 --- /dev/null +++ b/src-tauri/assets/extension/Applications/plugin.json @@ -0,0 +1,9 @@ +{ + "id": "Applications", + "platforms": ["macos", "linux", "windows"], + "title": "Applications", + "description": "...", + "icon": "font_Application", + "type": "group", + "enabled": true +} diff --git a/src-tauri/assets/extension/Calculator/plugin.json b/src-tauri/assets/extension/Calculator/plugin.json new file mode 100644 index 00000000..8d8ab0da --- /dev/null +++ b/src-tauri/assets/extension/Calculator/plugin.json @@ -0,0 +1,9 @@ +{ + "id": "Calculator", + "title": "Calculator", + "platforms": ["macos", "linux", "windows"], + "description": "...", + "icon": "font_Calculator", + "type": "calculator", + "enabled": true +} diff --git a/src-tauri/assets/extension/QuickAIAccess/plugin.json b/src-tauri/assets/extension/QuickAIAccess/plugin.json new file mode 100644 index 00000000..d3c3494b --- /dev/null +++ b/src-tauri/assets/extension/QuickAIAccess/plugin.json @@ -0,0 +1,8 @@ +{ + "id": "QuickAIAccess", + "title": "Quick AI Access", + "description": "...", + "icon": "font_a-QuickAIAccess", + "type": "ai_extension", + "enabled": true +} diff --git a/src-tauri/rust-toolchain.toml b/src-tauri/rust-toolchain.toml index af7a52c6..62665c35 100644 --- a/src-tauri/rust-toolchain.toml +++ b/src-tauri/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-10-29" \ No newline at end of file +channel = "nightly-2025-02-28" \ No newline at end of file diff --git a/src-tauri/src/common/document.rs b/src-tauri/src/common/document.rs index 663f50ae..cf2d94d3 100644 --- a/src-tauri/src/common/document.rs +++ b/src-tauri/src/common/document.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use crate::hide_coco; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RichLabel { pub label: Option, @@ -29,6 +31,72 @@ pub struct EditorInfo { pub timestamp: Option, } +/// Defines the action that would be performed when a document gets opened. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) enum OnOpened { + /// Launch the application + Application { app_path: String }, + /// Open the URL. + Document { url: String }, + /// Spawn a child process to run the `CommandAction`. + Command { + action: crate::extension::CommandAction, + }, +} + +impl OnOpened { + pub(crate) fn url(&self) -> String { + match self { + Self::Application { app_path } => app_path.clone(), + Self::Document { url } => url.clone(), + Self::Command { action } => { + const WHITESPACE: &str = " "; + let mut ret = action.exec.clone(); + ret.push_str(WHITESPACE); + ret.push_str(action.args.join(WHITESPACE).as_str()); + + ret + } + } + } +} + +#[tauri::command] +pub(crate) async fn open(on_opened: OnOpened) -> Result<(), String> { + log::debug!("open({})", on_opened.url()); + + use crate::util::open as homemade_tauri_shell_open; + use crate::GLOBAL_TAURI_APP_HANDLE; + use std::process::Command; + + let global_tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + + match on_opened { + OnOpened::Application { app_path } => { + homemade_tauri_shell_open(global_tauri_app_handle.clone(), app_path).await? + } + OnOpened::Document { url } => { + homemade_tauri_shell_open(global_tauri_app_handle.clone(), url).await? + } + OnOpened::Command { action } => { + let mut cmd = Command::new(action.exec); + cmd.args(action.args); + let output = cmd.output().map_err(|e| e.to_string())?; + if !output.status.success() { + return Err(format!( + "Command failed, stderr [{}]", + String::from_utf8_lossy(&output.stderr) + )); + } + } + } + + hide_coco(global_tauri_app_handle.clone()).await; + Ok(()) +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct Document { pub id: String, @@ -48,6 +116,8 @@ pub struct Document { pub thumbnail: Option, pub cover: Option, pub tags: Option>, + /// What will happen if we open this document. + pub on_opened: Option, pub url: Option, pub size: Option, pub metadata: Option>, diff --git a/src-tauri/src/common/traits.rs b/src-tauri/src/common/traits.rs index a4a1cca2..84683075 100644 --- a/src-tauri/src/common/traits.rs +++ b/src-tauri/src/common/traits.rs @@ -1,5 +1,4 @@ use crate::common::error::SearchError; -// use std::{future::Future, pin::Pin}; use crate::common::search::SearchQuery; use crate::common::search::{QueryResponse, QuerySource}; use async_trait::async_trait; @@ -10,4 +9,3 @@ pub trait SearchSource: Send + Sync { async fn search(&self, query: SearchQuery) -> Result; } - diff --git a/src-tauri/src/extension/built_in/ai_overview.rs b/src-tauri/src/extension/built_in/ai_overview.rs new file mode 100644 index 00000000..692c787d --- /dev/null +++ b/src-tauri/src/extension/built_in/ai_overview.rs @@ -0,0 +1 @@ +pub(super) const EXTENSION_ID: &str = "AIOverview"; \ No newline at end of file diff --git a/src-tauri/src/local/application/mod.rs b/src-tauri/src/extension/built_in/application/mod.rs similarity index 100% rename from src-tauri/src/local/application/mod.rs rename to src-tauri/src/extension/built_in/application/mod.rs diff --git a/src-tauri/src/local/application/with_feature.rs b/src-tauri/src/extension/built_in/application/with_feature.rs similarity index 91% rename from src-tauri/src/local/application/with_feature.rs rename to src-tauri/src/extension/built_in/application/with_feature.rs index 7e40e0c6..95b4e100 100644 --- a/src-tauri/src/local/application/with_feature.rs +++ b/src-tauri/src/extension/built_in/application/with_feature.rs @@ -1,13 +1,14 @@ -use super::super::SearchSourceState; -use super::super::Task; -use super::super::RUNTIME_TX; -use super::AppEntry; +use super::super::pizza_engine_runtime::SearchSourceState; +use super::super::pizza_engine_runtime::Task; +use super::super::pizza_engine_runtime::RUNTIME_TX; +use super::super::Extension; use super::AppMetadata; -use crate::common::document::{DataSourceReference, Document}; +use crate::common::document::{DataSourceReference, Document, OnOpened}; use crate::common::error::SearchError; use crate::common::search::{QueryResponse, QuerySource, SearchQuery}; use crate::common::traits::SearchSource; -use crate::local::LOCAL_QUERY_SOURCE_TYPE; +use crate::extension::ExtensionType; +use crate::extension::LOCAL_QUERY_SOURCE_TYPE; use crate::util::open; use crate::GLOBAL_TAURI_APP_HANDLE; use applications::{App, AppTrait}; @@ -326,7 +327,7 @@ impl Task for SearchApplicationsTask { async fn exec(&mut self, state: &mut Option>) { let callback = self.callback.take().unwrap(); - let disabled_app_list = get_disabled_app_list(self.tauri_app_handle.clone()); + let disabled_app_list = get_disabled_app_list(&self.tauri_app_handle); // TODO: search via alias, implement this when Pizza engine supports update let dsl = format!( @@ -551,19 +552,24 @@ fn pizza_engine_hits_to_coco_hits( FieldValue::Text(string) => string, _ => unreachable!("field icon is of type Text"), }; + let on_opened = OnOpened::Application { + app_path: app_path.clone(), + }; + let url = on_opened.url(); let coco_document = Document { source: Some(DataSourceReference { r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()), name: Some(QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME.into()), id: Some(QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME.into()), - icon: None, + icon: Some(String::from("font_Application")), }), id: app_path.clone(), category: Some("Application".to_string()), title: Some(app_name.clone()), - url: Some(app_path), icon: Some(app_icon_path), + on_opened: Some(on_opened), + url: Some(url), ..Default::default() }; @@ -574,12 +580,7 @@ fn pizza_engine_hits_to_coco_hits( coco_hits } -#[tauri::command] -pub async fn set_app_alias( - tauri_app_handle: AppHandle, - app_path: String, - alias: String, -) { +pub fn set_app_alias(tauri_app_handle: &AppHandle, app_path: &str, alias: &str) { let store = tauri_app_handle .store(TAURI_STORE_APP_ALIAS) .unwrap_or_else(|_| panic!("store [{}] not found/loaded", TAURI_STORE_APP_ALIAS)); @@ -649,42 +650,42 @@ fn register_app_hotkey_upon_start( Ok(()) } -#[tauri::command] -pub async fn register_app_hotkey( - tauri_app_handle: AppHandle, - app_path: String, - hotkey: String, +pub fn register_app_hotkey( + tauri_app_handle: &AppHandle, + app_path: &str, + hotkey: &str, ) -> Result<(), String> { + // Ignore the error as it may not be registered + unregister_app_hotkey(tauri_app_handle, app_path)?; + let app_hotkey_store = tauri_app_handle .store(TAURI_STORE_APP_HOTKEY) .unwrap_or_else(|_| panic!("store [{}] not found/loaded", TAURI_STORE_APP_HOTKEY)); - app_hotkey_store.set(app_path.clone(), hotkey.as_str()); + app_hotkey_store.set(app_path, hotkey); tauri_app_handle .global_shortcut() - .on_shortcut(hotkey.as_str(), app_hotkey_handler(app_path)) + .on_shortcut(hotkey, app_hotkey_handler(app_path.into())) .map_err(|e| e.to_string())?; Ok(()) } -#[tauri::command] -pub async fn unregister_app_hotkey( - tauri_app_handle: AppHandle, - app_path: String, +pub fn unregister_app_hotkey( + tauri_app_handle: &AppHandle, + app_path: &str, ) -> Result<(), String> { let app_hotkey_store = tauri_app_handle .store(TAURI_STORE_APP_HOTKEY) .unwrap_or_else(|_| panic!("store [{}] not found/loaded", TAURI_STORE_APP_HOTKEY)); - let Some(hotkey) = app_hotkey_store.get(app_path.as_str()) else { - let error_msg = format!( + let Some(hotkey) = app_hotkey_store.get(app_path) else { + warn!( "unregister an Application hotkey that does not exist app: [{}]", app_path, ); - warn!("{}", error_msg); - return Err(error_msg); + return Ok(()); }; let hotkey = match hotkey { @@ -692,11 +693,18 @@ pub async fn unregister_app_hotkey( _ => unreachable!("hotkey should be stored in a string"), }; - let deleted = app_hotkey_store.delete(app_path.as_str()); + let deleted = app_hotkey_store.delete(app_path); if !deleted { return Err("failed to delete application hotkey from store".into()); } + if !tauri_app_handle + .global_shortcut() + .is_registered(hotkey.as_str()) + { + panic!("inconsistent state, tauri store a hotkey is stored in the tauri store but it is not registered"); + } + tauri_app_handle .global_shortcut() .unregister(hotkey.as_str()) @@ -705,7 +713,7 @@ pub async fn unregister_app_hotkey( Ok(()) } -fn get_disabled_app_list(tauri_app_handle: AppHandle) -> Vec { +fn get_disabled_app_list(tauri_app_handle: &AppHandle) -> Vec { let store = tauri_app_handle .store(TAURI_STORE_DISABLED_APP_LIST_AND_SEARCH_PATH) .unwrap_or_else(|_| { @@ -732,10 +740,19 @@ fn get_disabled_app_list(tauri_app_handle: AppHandle) -> Vec( - tauri_app_handle: AppHandle, - app_path: String, +pub fn is_app_search_enabled(app_path: &str) -> bool { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + + let disabled_app_list = get_disabled_app_list(tauri_app_handle); + + disabled_app_list.iter().all(|path| path != app_path) +} + +pub fn disable_app_search( + tauri_app_handle: &AppHandle, + app_path: &str, ) -> Result<(), String> { let store = tauri_app_handle .store(TAURI_STORE_DISABLED_APP_LIST_AND_SEARCH_PATH) @@ -748,24 +765,26 @@ pub async fn disable_app_search( let mut disabled_app_list = get_disabled_app_list(tauri_app_handle); - if disabled_app_list.contains(&app_path) { + if disabled_app_list + .iter() + .any(|disabled_app| disabled_app == app_path) + { return Err(format!( "trying to disable an app that is disabled [{}]", app_path )); } - disabled_app_list.push(app_path); + disabled_app_list.push(app_path.into()); store.set(TAURI_STORE_KEY_DISABLED_APP_LIST, disabled_app_list); Ok(()) } -#[tauri::command] -pub async fn enable_app_search( - tauri_app_handle: AppHandle, - app_path: String, +pub fn enable_app_search( + tauri_app_handle: &AppHandle, + app_path: &str, ) -> Result<(), String> { let store = tauri_app_handle .store(TAURI_STORE_DISABLED_APP_LIST_AND_SEARCH_PATH) @@ -879,7 +898,7 @@ pub async fn get_app_search_path(tauri_app_handle: AppHandle) -> #[tauri::command] pub async fn get_app_list( tauri_app_handle: AppHandle, -) -> Result, String> { +) -> Result, String> { let search_paths = get_app_search_path(tauri_app_handle.clone()).await; let apps = list_app_in(search_paths)?; @@ -910,14 +929,12 @@ pub async fn get_app_list( let store = tauri_app_handle .store(TAURI_STORE_APP_HOTKEY) .unwrap_or_else(|_| panic!("store [{}] not found/loaded", TAURI_STORE_APP_HOTKEY)); - let opt_string = store.get(&path).map(|json| match json { + store.get(&path).map(|json| match json { Json::String(s) => s, _ => unreachable!("app hotkey should be stored in a string"), - }); - - opt_string.unwrap_or(String::new()) + }) }; - let is_disabled = { + let enabled = { let store = tauri_app_handle .store(TAURI_STORE_DISABLED_APP_LIST_AND_SEARCH_PATH) .unwrap_or_else(|_| panic!("store [{}] not found/loaded", TAURI_STORE_APP_HOTKEY)); @@ -942,16 +959,26 @@ pub async fn get_app_list( _ => unreachable!("disabled app list should be stored in an array"), }; - disabled_app_list.contains(&path) + !disabled_app_list.contains(&path) }; - let app_entry = AppEntry { - path, - name, - icon_path, - alias, + let app_entry = Extension { + id: path, + title: name, + platforms: None, + // Leave it empty as it won't be used + description: String::new(), + icon: icon_path, + r#type: ExtensionType::Application, + action: None, + quick_link: None, + commands: None, + scripts: None, + quick_links: None, + alias: Some(alias), hotkey, - is_disabled, + enabled, + settings: None, }; app_entries.push(app_entry); diff --git a/src-tauri/src/local/application/without_feature.rs b/src-tauri/src/extension/built_in/application/without_feature.rs similarity index 77% rename from src-tauri/src/local/application/without_feature.rs rename to src-tauri/src/extension/built_in/application/without_feature.rs index 1ca683c2..1538d209 100644 --- a/src-tauri/src/local/application/without_feature.rs +++ b/src-tauri/src/extension/built_in/application/without_feature.rs @@ -1,11 +1,11 @@ +use super::super::Extension; +use super::AppMetadata; use crate::common::error::SearchError; use crate::common::search::{QueryResponse, QuerySource, SearchQuery}; use crate::common::traits::SearchSource; -use crate::local::LOCAL_QUERY_SOURCE_TYPE; +use crate::extension::LOCAL_QUERY_SOURCE_TYPE; use async_trait::async_trait; use tauri::{AppHandle, Runtime}; -use super::AppEntry; -use super::AppMetadata; pub(crate) const QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME: &str = "Applications"; @@ -39,46 +39,45 @@ impl SearchSource for ApplicationSearchSource { } } -#[tauri::command] -pub async fn set_app_alias(_app_path: String, _alias: String) -> Result<(), String> { +pub fn set_app_alias(_tauri_app_handle: &AppHandle, _app_path: &str, _alias: &str) { unreachable!("app list should be empty, there is no way this can be invoked") } -#[tauri::command] -pub async fn register_app_hotkey( - _tauri_app_handle: AppHandle, - _app_path: String, - _hotkey: String, +pub fn register_app_hotkey( + _tauri_app_handle: &AppHandle, + _app_path: &str, + _hotkey: &str, ) -> Result<(), String> { unreachable!("app list should be empty, there is no way this can be invoked") } -#[tauri::command] -pub async fn unregister_app_hotkey( - _tauri_app_handle: AppHandle, - _app_path: String, +pub fn unregister_app_hotkey( + _tauri_app_handle: &AppHandle, + _app_path: &str, ) -> Result<(), String> { unreachable!("app list should be empty, there is no way this can be invoked") } -#[tauri::command] -pub async fn disable_app_search( - _tauri_app_handle: AppHandle, - _app_path: String, +pub fn disable_app_search( + _tauri_app_handle: &AppHandle, + _app_path: &str, ) -> Result<(), String> { // no-op Ok(()) } -#[tauri::command] -pub async fn enable_app_search( - _tauri_app_handle: AppHandle, - _app_path: String, +pub fn enable_app_search( + _tauri_app_handle: &AppHandle, + _app_path: &str, ) -> Result<(), String> { // no-op Ok(()) } +pub fn is_app_search_enabled(_app_path: &str) -> bool { + false +} + #[tauri::command] pub async fn add_app_search_path( _tauri_app_handle: AppHandle, @@ -103,11 +102,10 @@ pub async fn get_app_search_path(_tauri_app_handle: AppHandle) -> Vec::new() } - #[tauri::command] pub async fn get_app_list( _tauri_app_handle: AppHandle, -) -> Result, String> { +) -> Result, String> { // Return an empty list Ok(Vec::new()) } diff --git a/src-tauri/src/local/calculator.rs b/src-tauri/src/extension/built_in/calculator.rs similarity index 97% rename from src-tauri/src/local/calculator.rs rename to src-tauri/src/extension/built_in/calculator.rs index 15654298..b8e978a4 100644 --- a/src-tauri/src/local/calculator.rs +++ b/src-tauri/src/extension/built_in/calculator.rs @@ -1,4 +1,4 @@ -use super::LOCAL_QUERY_SOURCE_TYPE; +use super::super::LOCAL_QUERY_SOURCE_TYPE; use crate::common::{ document::{DataSourceReference, Document}, error::SearchError, @@ -116,7 +116,7 @@ impl SearchSource for CalculatorSource { }); }; - // Trim the leading and tailing whitespace so that our later if condition + // Trim the leading and tailing whitespace so that our later if condition // will only be evaluated against non-whitespace characters. let query_string = query_string.trim(); @@ -146,7 +146,7 @@ impl SearchSource for CalculatorSource { r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()), name: Some(DATA_SOURCE_ID.into()), id: Some(DATA_SOURCE_ID.into()), - icon: None, + icon: Some(String::from("font_Calculator")), }), ..Default::default() }; diff --git a/src-tauri/src/extension/built_in/file_system.rs b/src-tauri/src/extension/built_in/file_system.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src-tauri/src/extension/built_in/file_system.rs @@ -0,0 +1 @@ + diff --git a/src-tauri/src/extension/built_in/mod.rs b/src-tauri/src/extension/built_in/mod.rs new file mode 100644 index 00000000..a5c0854b --- /dev/null +++ b/src-tauri/src/extension/built_in/mod.rs @@ -0,0 +1,310 @@ +//! Built-in extensions and related stuff. + +pub mod ai_overview; +pub mod application; +pub mod calculator; +pub mod file_system; +pub mod pizza_engine_runtime; +pub mod quick_ai_access; + +use super::Extension; +use crate::extension::{alter_extension_json_file, load_extension_from_json_file}; +use crate::{SearchSourceRegistry, GLOBAL_TAURI_APP_HANDLE}; +use std::path::PathBuf; +use std::sync::LazyLock; +use tauri::path::BaseDirectory; +use tauri::Manager; + +pub(crate) static BUILT_IN_EXTENSION_DIRECTORY: LazyLock = LazyLock::new(|| { + let mut resource_dir = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set") + .path() + .resolve("assets", BaseDirectory::Resource) + .expect( + "User home directory not found, which should be impossible on desktop environments", + ); + resource_dir.push("extension"); + + resource_dir +}); + +pub(super) async fn init_built_in_extension( + extension: &Extension, + search_source_registry: &SearchSourceRegistry, +) { + log::trace!("initializing built-in extensions"); + + if extension.id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { + search_source_registry + .register_source(application::ApplicationSearchSource) + .await; + log::debug!("built-in extension [{}] initialized", extension.id); + } + + if extension.id == calculator::DATA_SOURCE_ID { + let calculator_search = calculator::CalculatorSource::new(2000f64); + search_source_registry + .register_source(calculator_search) + .await; + log::debug!("built-in extension [{}] initialized", extension.id); + } +} + +pub(crate) fn is_extension_built_in(extension_id: &str) -> bool { + if extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { + return true; + } + + if extension_id.starts_with(&format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + )) { + return true; + } + + if extension_id == calculator::DATA_SOURCE_ID { + return true; + } + + if extension_id == quick_ai_access::EXTENSION_ID { + return true; + } + + if extension_id == ai_overview::EXTENSION_ID { + return true; + } + + false +} + +pub(crate) async fn enable_built_in_extension(extension_id: &str) -> Result<(), String> { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let search_source_registry_tauri_state = tauri_app_handle.state::(); + + let update_extension = |extension: &mut Extension| -> Result<(), String> { + extension.enabled = true; + Ok(()) + }; + + if extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { + search_source_registry_tauri_state + .register_source(application::ApplicationSearchSource) + .await; + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + + return Ok(()); + } + + // Check if this is an application + let application_prefix = format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + ); + if extension_id.starts_with(&application_prefix) { + let app_path = &extension_id[application_prefix.len()..]; + application::enable_app_search(tauri_app_handle, app_path)?; + return Ok(()); + } + + if extension_id == calculator::DATA_SOURCE_ID { + let calculator_search = calculator::CalculatorSource::new(2000f64); + search_source_registry_tauri_state + .register_source(calculator_search) + .await; + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + return Ok(()); + } + + if extension_id == quick_ai_access::EXTENSION_ID { + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + return Ok(()); + } + + if extension_id == ai_overview::EXTENSION_ID { + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + return Ok(()); + } + + Ok(()) +} + +pub(crate) async fn disable_built_in_extension(extension_id: &str) -> Result<(), String> { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let search_source_registry_tauri_state = tauri_app_handle.state::(); + + let update_extension = |extension: &mut Extension| -> Result<(), String> { + extension.enabled = false; + Ok(()) + }; + + if extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { + search_source_registry_tauri_state + .remove_source(extension_id) + .await; + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + return Ok(()); + } + + // Check if this is an application + let application_prefix = format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + ); + if extension_id.starts_with(&application_prefix) { + let app_path = &extension_id[application_prefix.len()..]; + application::disable_app_search(tauri_app_handle, app_path)?; + return Ok(()); + } + + if extension_id == calculator::DATA_SOURCE_ID { + search_source_registry_tauri_state + .remove_source(extension_id) + .await; + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + return Ok(()); + } + + if extension_id == quick_ai_access::EXTENSION_ID { + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + + return Ok(()); + } + + if extension_id == ai_overview::EXTENSION_ID { + alter_extension_json_file( + &BUILT_IN_EXTENSION_DIRECTORY.as_path(), + extension_id, + update_extension, + )?; + + return Ok(()); + } + + Ok(()) +} + +pub(crate) fn set_built_in_extension_alias(extension_id: &str, alias: &str) { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + + let application_prefix = format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + ); + if extension_id.starts_with(&application_prefix) { + let app_path = &extension_id[application_prefix.len()..]; + application::set_app_alias(tauri_app_handle, app_path, alias); + } +} + +pub(crate) fn register_built_in_extension_hotkey( + extension_id: &str, + hotkey: &str, +) -> Result<(), String> { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let application_prefix = format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + ); + if extension_id.starts_with(&application_prefix) { + let app_path = &extension_id[application_prefix.len()..]; + application::register_app_hotkey(&tauri_app_handle, app_path, hotkey)?; + } + Ok(()) +} + +pub(crate) fn unregister_built_in_extension_hotkey(extension_id: &str) -> Result<(), String> { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let application_prefix = format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + ); + if extension_id.starts_with(&application_prefix) { + let app_path = &extension_id[application_prefix.len()..]; + application::unregister_app_hotkey(&tauri_app_handle, app_path)?; + } + Ok(()) +} + +pub(crate) async fn is_built_in_extension_enabled(extension_id: &str) -> Result { + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let search_source_registry_tauri_state = tauri_app_handle.state::(); + + if extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { + return Ok(search_source_registry_tauri_state + .get_source(extension_id) + .await + .is_some()); + } + + // Check if this is an application + let application_prefix = format!( + "{}.", + application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME + ); + if extension_id.starts_with(&application_prefix) { + let app_path = &extension_id[application_prefix.len()..]; + return Ok(application::is_app_search_enabled(app_path)); + } + + if extension_id == calculator::DATA_SOURCE_ID { + return Ok(search_source_registry_tauri_state + .get_source(extension_id) + .await + .is_some()); + } + + if extension_id == quick_ai_access::EXTENSION_ID { + let extension = + load_extension_from_json_file(&BUILT_IN_EXTENSION_DIRECTORY.as_path(), extension_id)?; + return Ok(extension.enabled); + } + + if extension_id == ai_overview::EXTENSION_ID { + let extension = + load_extension_from_json_file(&BUILT_IN_EXTENSION_DIRECTORY.as_path(), extension_id)?; + return Ok(extension.enabled); + } + + unreachable!("extension [{}] is not a built-in extension", extension_id) +} diff --git a/src-tauri/src/extension/built_in/pizza_engine_runtime.rs b/src-tauri/src/extension/built_in/pizza_engine_runtime.rs new file mode 100644 index 00000000..fd1974f0 --- /dev/null +++ b/src-tauri/src/extension/built_in/pizza_engine_runtime.rs @@ -0,0 +1,51 @@ +//! We use Pizza Engine to index applications and local files. The engine will be +//! run in the thread/runtime defined in this file. +//! +//! # Why such a thread/runtime is needed +//! +//! Generally, Tokio async runtime requires all the async tasks running on it to be +//! `Send` and `Sync`, but the async tasks created by Pizza Engine are not, +//! which forces us to create a dedicated thread/runtime to execute them. + +use std::any::Any; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::OnceLock; + +pub(crate) trait SearchSourceState { + #[cfg_attr(not(feature = "use_pizza_engine"), allow(unused))] + fn as_mut_any(&mut self) -> &mut dyn Any; +} + +#[async_trait::async_trait(?Send)] +pub(crate) trait Task: Send + Sync { + fn search_source_id(&self) -> &'static str; + + async fn exec(&mut self, state: &mut Option>); +} + +pub(crate) static RUNTIME_TX: OnceLock>> = + OnceLock::new(); + +pub(crate) fn start_pizza_engine_runtime() { + std::thread::spawn(|| { + let rt = tokio::runtime::Runtime::new().unwrap(); + + let main = async { + let mut states: HashMap>> = HashMap::new(); + + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + RUNTIME_TX.set(tx).unwrap(); + + while let Some(mut task) = rx.recv().await { + let opt_search_source_state = match states.entry(task.search_source_id().into()) { + Entry::Occupied(o) => o.into_mut(), + Entry::Vacant(v) => v.insert(None), + }; + task.exec(opt_search_source_state).await; + } + }; + + rt.block_on(main); + }); +} diff --git a/src-tauri/src/extension/built_in/quick_ai_access.rs b/src-tauri/src/extension/built_in/quick_ai_access.rs new file mode 100644 index 00000000..67750d97 --- /dev/null +++ b/src-tauri/src/extension/built_in/quick_ai_access.rs @@ -0,0 +1 @@ +pub(super) const EXTENSION_ID: &str = "QuickAIAccess"; diff --git a/src-tauri/src/extension/mod.rs b/src-tauri/src/extension/mod.rs new file mode 100644 index 00000000..d71bdaf0 --- /dev/null +++ b/src-tauri/src/extension/mod.rs @@ -0,0 +1,825 @@ +pub(crate) mod built_in; +mod third_party; + +use crate::common::document::OnOpened; +use crate::{common::register::SearchSourceRegistry, GLOBAL_TAURI_APP_HANDLE}; +use anyhow::Context; +use derive_more::Display; +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value as Json; +use std::collections::HashSet; +use std::ffi::OsStr; +use std::path::Path; +use tauri::Manager; +use third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE; + +pub const LOCAL_QUERY_SOURCE_TYPE: &str = "local"; +const PLUGIN_JSON_FILE_NAME: &str = "plugin.json"; +const ASSETS_DIRECTORY_FILE_NAME: &str = "assets"; + +#[derive(Debug, Deserialize, Serialize, Copy, Clone, Hash, PartialEq, Eq, Display)] +#[serde(rename_all(serialize = "lowercase", deserialize = "lowercase"))] +enum Platform { + #[display("macOS")] + Macos, + #[display("Linux")] + Linux, + #[display("windows")] + Windows, +} + +/// Helper function to determine the current platform. +fn current_platform() -> Platform { + let os_str = std::env::consts::OS; + serde_plain::from_str(os_str).unwrap_or_else(|_e| { + panic!("std::env::consts::OS is [{}], which is not a valid value for [enum Platform], valid values: ['macos', 'linux', 'windows']", os_str) + }) +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct Extension { + /// Unique extension identifier. + id: String, + /// Extension name. + title: String, + /// Platforms supported by this extension. + /// + /// If `None`, then this extension can be used on all the platforms. + #[serde(skip_serializing_if = "Option::is_none")] + platforms: Option>, + /// Extension description. + description: String, + //// Specify the icon for this extension, multi options are available: + /// + /// 1. It can be a path to the icon file, the path can be + /// + /// * relative (relative to the "assets" directory) + /// * absolute + /// 2. It can be a font class code, e.g., 'font_coco', if you want to use + /// Coco's built-in icons. + /// + /// In cases where your icon file is named similarly to a font class code, Coco + /// will treat it as an icon file if it exists, i.e., if file `/assets/font_coco` + /// exists, then Coco will use this file rather than the built-in 'font_coco' icon. + icon: String, + r#type: ExtensionType, + /// If this is a Command extension, then action defines the operation to execute + /// when the it is triggered. + #[serde(skip_serializing_if = "Option::is_none")] + action: Option, + /// The link to open if this is a QuickLink extension. + #[serde(skip_serializing_if = "Option::is_none")] + quick_link: Option, + + // If this extension is of type Group or Extension, then it behaves like a + // directory, i.e., it could contain sub items. + commands: Option>, + scripts: Option>, + quick_links: Option>, + + /// The alias of the extension. + /// + /// Extension of type Group and Extension cannot have alias. + /// + #[serde(skip_serializing_if = "Option::is_none")] + alias: Option, + /// The hotkey of the extension. + /// + /// Extension of type Group and Extension cannot have hotkey. + #[serde(skip_serializing_if = "Option::is_none")] + hotkey: Option, + + /// Is this extension enabled. + enabled: bool, + + /// Extension settings + #[serde(skip_serializing_if = "Option::is_none")] + settings: Option, +} + +impl Extension { + /// Whether this extension could be searched. + pub(crate) fn searchable(&self) -> bool { + self.on_opened().is_some() + } + /// Return what will happen when we open this extension. + /// + /// `None` if it cannot be opened. + pub(crate) fn on_opened(&self) -> Option { + match self.r#type { + ExtensionType::Group => None, + ExtensionType::Extension => None, + ExtensionType::Command => Some(OnOpened::Command { + action: self.action.clone().unwrap_or_else(|| { + panic!( + "Command extension [{}]'s [action] field is not set, something wrong with your extension validity check", self.id + ) + }), + }), + ExtensionType::Application => Some(OnOpened::Application { + app_path: self.id.clone(), + }), + ExtensionType::Script => todo!("not supported yet"), + ExtensionType::Quicklink => todo!("not supported yet"), + ExtensionType::Setting => todo!("not supported yet"), + ExtensionType::Calculator => None, + ExtensionType::AiExtension => None, + } + } + + /// Perform `how` against the extension specified by `extension_id`. + /// + /// Please note that `extension_id` could point to a sub extension. + pub(crate) fn modify( + &mut self, + extension_id: &str, + how: impl FnOnce(&mut Self) -> Result<(), String>, + ) -> Result<(), String> { + let (parent_extension_id, opt_sub_extension_id) = split_extension_id(extension_id); + assert_eq!( + parent_extension_id, self.id, + "modify() should be invoked against a parent extension" + ); + + let Some(sub_extension_id) = opt_sub_extension_id else { + how(self)?; + return Ok(()); + }; + + // Search in commands + if let Some(ref mut commands) = self.commands { + if let Some(command) = commands.iter_mut().find(|cmd| cmd.id == sub_extension_id) { + how(command)?; + return Ok(()); + } + } + + // Search in scripts + if let Some(ref mut scripts) = self.scripts { + if let Some(script) = scripts.iter_mut().find(|scr| scr.id == sub_extension_id) { + how(script)?; + return Ok(()); + } + } + + // Search in quick_links + if let Some(ref mut quick_links) = self.quick_links { + if let Some(link) = quick_links + .iter_mut() + .find(|lnk| lnk.id == sub_extension_id) + { + how(link)?; + return Ok(()); + } + } + + Err(format!( + "extension [{}] not found in {:?}", + extension_id, self + )) + } + + /// Get the extension specified by `extension_id`. + /// + /// Please note that `extension_id` could point to a sub extension. + pub(crate) fn get_extension_mut(&mut self, extension_id: &str) -> Option<&mut Self> { + let (parent_extension_id, opt_sub_extension_id) = split_extension_id(extension_id); + if parent_extension_id != self.id { + return None; + } + + let Some(sub_extension_id) = opt_sub_extension_id else { + return Some(self); + }; + + self.get_sub_extension_mut(sub_extension_id) + } + + pub(crate) fn get_sub_extension_mut(&mut self, sub_extension_id: &str) -> Option<&mut Self> { + if !self.r#type.contains_sub_items() { + return None; + } + + if let Some(ref mut commands) = self.commands { + if let Some(sub_ext) = commands.iter_mut().find(|cmd| cmd.id == sub_extension_id) { + return Some(sub_ext); + } + } + if let Some(ref mut scripts) = self.scripts { + if let Some(sub_ext) = scripts + .iter_mut() + .find(|script| script.id == sub_extension_id) + { + return Some(sub_ext); + } + } + if let Some(ref mut quick_links) = self.quick_links { + if let Some(sub_ext) = quick_links + .iter_mut() + .find(|link| link.id == sub_extension_id) + { + return Some(sub_ext); + } + } + + None + } +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub(crate) struct CommandAction { + pub(crate) exec: String, + pub(crate) args: Vec, +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct QuickLink { + link: String, +} + +#[derive(Debug, PartialEq, Deserialize, Serialize, Clone, Display)] +#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +pub enum ExtensionType { + #[display("Group")] + Group, + #[display("Extension")] + Extension, + #[display("Command")] + Command, + #[display("Application")] + Application, + #[display("Script")] + Script, + #[display("Quicklink")] + Quicklink, + #[display("Setting")] + Setting, + #[display("Calculator")] + Calculator, + #[display("AI Extension")] + AiExtension, +} + +impl ExtensionType { + pub(crate) fn contains_sub_items(&self) -> bool { + self == &Self::Group || self == &Self::Extension + } +} + +fn canonicalize_relative_icon_path( + extension_dir: &Path, + extension: &mut Extension, +) -> Result<(), String> { + fn _canonicalize_relative_icon_path( + extension_dir: &Path, + extension: &mut Extension, + ) -> Result<(), String> { + let icon_str = &extension.icon; + let icon_path = Path::new(icon_str); + + if icon_path.is_relative() { + let absolute_icon_path = { + let mut assets_directory = extension_dir.join(ASSETS_DIRECTORY_FILE_NAME); + assets_directory.push(icon_path); + + assets_directory + }; + + if absolute_icon_path.try_exists().map_err(|e| e.to_string())? { + extension.icon = absolute_icon_path + .into_os_string() + .into_string() + .expect("path should be UTF-8 encoded"); + } + } + + Ok(()) + } + + _canonicalize_relative_icon_path(extension_dir, extension)?; + + if let Some(commands) = &mut extension.commands { + for command in commands { + _canonicalize_relative_icon_path(extension_dir, command)?; + } + } + + if let Some(scripts) = &mut extension.scripts { + for script in scripts { + _canonicalize_relative_icon_path(extension_dir, script)?; + } + } + + if let Some(quick_links) = &mut extension.quick_links { + for quick_link in quick_links { + _canonicalize_relative_icon_path(extension_dir, quick_link)?; + } + } + + Ok(()) +} + +fn list_extensions_under_directory(directory: &Path) -> Result<(bool, Vec), String> { + let mut found_invalid_extensions = false; + + let extension_directory = std::fs::read_dir(&directory).map_err(|e| e.to_string())?; + let current_platform = current_platform(); + + let mut extensions = Vec::new(); + for res_extension_dir in extension_directory { + let extension_dir = res_extension_dir.map_err(|e| e.to_string())?; + let file_type = extension_dir.file_type().map_err(|e| e.to_string())?; + if !file_type.is_dir() { + found_invalid_extensions = true; + log::warn!( + "invalid extension [{}]: a valid extension should be a directory, but it is not", + extension_dir.file_name().display() + ); + + // Skip invalid extension + continue; + } + + let plugin_json_file_path = { + let mut path = extension_dir.path(); + path.push(PLUGIN_JSON_FILE_NAME); + + path + }; + + if !plugin_json_file_path.is_file() { + found_invalid_extensions = true; + log::warn!( + "invalid extension: [{}]: extension file [{}] should be a JSON file, but it is not", + extension_dir.file_name().display(), + plugin_json_file_path.display() + ); + + // Skip invalid extension + continue; + } + + let mut extension = match serde_json::from_reader::<_, Extension>( + std::fs::File::open(&plugin_json_file_path).map_err(|e| e.to_string())?, + ) { + Ok(extension) => extension, + Err(e) => { + found_invalid_extensions = true; + log::warn!( + "invalid extension: [{}]: extension file [{}] is invalid, error: '{}'", + extension_dir.file_name().display(), + plugin_json_file_path.display(), + e + ); + continue; + } + }; + + // Turn it into an absolute path if it is a valid relative path because frontend code need this. + canonicalize_relative_icon_path(&extension_dir.path(), &mut extension)?; + + if !validate_extension( + &extension, + &extension_dir.file_name(), + &extensions, + current_platform, + ) { + found_invalid_extensions = true; + // Skip invalid extension + continue; + } + + extensions.push(extension); + } + + log::debug!( + "loaded extensions: {:?}", + extensions + .iter() + .map(|ext| ext.id.as_str()) + .collect::>() + ); + + Ok((found_invalid_extensions, extensions)) +} + +/// Return value: +/// +/// * boolean: indicates if we found any invalid extensions +/// * Vec: loaded extensions +#[tauri::command] +pub(crate) async fn list_extensions() -> Result<(bool, Vec), String> { + log::trace!("loading extensions"); + + let third_party_dir = third_party::THIRD_PARTY_EXTENSION_DIRECTORY.as_path(); + if !third_party_dir.try_exists().map_err(|e| e.to_string())? { + tokio::fs::create_dir_all(third_party_dir) + .await + .map_err(|e| e.to_string())?; + } + let (third_party_found_invalid_extension, mut third_party_extensions) = + list_extensions_under_directory(third_party_dir)?; + + let built_in_dir = built_in::BUILT_IN_EXTENSION_DIRECTORY.as_path(); + let (built_in_found_invalid_extension, built_in_extensions) = + list_extensions_under_directory(built_in_dir)?; + + let found_invalid_extension = + third_party_found_invalid_extension || built_in_found_invalid_extension; + let extensions = { + third_party_extensions.extend(built_in_extensions); + + third_party_extensions + }; + + Ok((found_invalid_extension, extensions)) +} + +/// Helper function to validate `extension`, return `true` if it is valid. +fn validate_extension( + extension: &Extension, + extension_dir_name: &OsStr, + listed_extensions: &[Extension], + current_platform: Platform, +) -> bool { + if OsStr::new(&extension.id) != extension_dir_name { + log::warn!( + "invalid extension []: id [{}] and extension directory name [{}] do not match", + extension.id, + extension_dir_name.display() + ); + return false; + } + + // Extension ID should be unique + if listed_extensions.iter().any(|ext| ext.id == extension.id) { + log::warn!( + "invalid extension []: extension with id [{}] already exists", + extension.id, + ); + return false; + } + + if !validate_extension_or_sub_item(extension) { + return false; + } + + // Extension is incompatible + if let Some(ref platforms) = extension.platforms { + if !platforms.contains(¤t_platform) { + log::warn!("extension [{}] is not compatible with the current platform [{}], it is available to {:?}", extension.id, current_platform, platforms.iter().map(|os|os.to_string()).collect::>()); + return false; + } + } + + if let Some(ref commands) = extension.commands { + if !validate_sub_items(&extension.id, commands) { + return false; + } + } + + if let Some(ref scripts) = extension.scripts { + if !validate_sub_items(&extension.id, scripts) { + return false; + } + } + + if let Some(ref quick_links) = extension.quick_links { + if !validate_sub_items(&extension.id, quick_links) { + return false; + } + } + + true +} + +/// Checks that can be performed against an extension or a sub item. +fn validate_extension_or_sub_item(extension: &Extension) -> bool { + // Only + // + // 1. letters + // 2. hyphens + // 3. numbers + // + // are allowed in the ID. + if !extension + .id + .chars() + .all(|c| c.is_ascii_alphabetic() || c == '-') + { + log::warn!( + "invalid extension [{}], [id] should contain only letters, numbers, or hyphens", + extension.id + ); + return false; + } + + // If field `action` is Some, then it should be a Command + if extension.action.is_some() && extension.r#type != ExtensionType::Command { + log::warn!( + "invalid extension [{}], [action] is set for a non-Command extension", + extension.id + ); + return false; + } + + if extension.r#type == ExtensionType::Command && extension.action.is_none() { + log::warn!( + "invalid extension [{}], [action] should be set for a Command extension", + extension.id + ); + return false; + } + + // If field `quick_link` is Some, then it should be a QuickLink + if extension.quick_link.is_some() && extension.r#type != ExtensionType::Quicklink { + log::warn!( + "invalid extension [{}], [quick_link] is set for a non-QuickLink extension", + extension.id + ); + return false; + } + + if extension.r#type == ExtensionType::Quicklink && extension.quick_link.is_none() { + log::warn!( + "invalid extension [{}], [quick_link] should be set for a QuickLink extension", + extension.id + ); + return false; + } + + // Group and Extension cannot have alias + if extension.alias.is_some() { + if extension.r#type == ExtensionType::Group || extension.r#type == ExtensionType::Extension + { + log::warn!( + "invalid extension [{}], extension of type [{:?}] cannot have alias", + extension.id, + extension.r#type + ); + return false; + } + } + + // Group and Extension cannot have hotkey + if extension.hotkey.is_some() { + if extension.r#type == ExtensionType::Group || extension.r#type == ExtensionType::Extension + { + log::warn!( + "invalid extension [{}], extension of type [{:?}] cannot have hotkey", + extension.id, + extension.r#type + ); + return false; + } + } + + if extension.commands.is_some() + || extension.scripts.is_some() + || extension.quick_links.is_some() + { + if extension.r#type != ExtensionType::Group && extension.r#type != ExtensionType::Extension + { + log::warn!( + "invalid extension [{}], only extension of type [Group] and [Extension] can have sub-items", + extension.id, + ); + return false; + } + } + + true +} + +/// Helper function to check sub-items. +fn validate_sub_items(extension_id: &str, sub_items: &[Extension]) -> bool { + for (sub_item_index, sub_item) in sub_items.iter().enumerate() { + // If field `action` is Some, then it should be a Command + if sub_item.action.is_some() && sub_item.r#type != ExtensionType::Command { + log::warn!( + "invalid extension sub-item [{}-{}]: [action] is set for a non-Command extension", + extension_id, + sub_item.id + ); + return false; + } + + if sub_item.r#type == ExtensionType::Group || sub_item.r#type == ExtensionType::Extension { + log::warn!( + "invalid extension sub-item [{}-{}]: sub-item should not be of type [Group] or [Extension]", + extension_id, sub_item.id + ); + return false; + } + + let sub_item_with_same_id_count = sub_items + .iter() + .enumerate() + .filter(|(_idx, ext)| ext.id == sub_item.id) + .filter(|(idx, _ext)| *idx != sub_item_index) + .count(); + if sub_item_with_same_id_count != 0 { + log::warn!( + "invalid extension [{}]: found more than one sub-items with the same ID [{}]", + extension_id, + sub_item.id + ); + return false; + } + + if !validate_extension_or_sub_item(sub_item) { + return false; + } + + if sub_item.platforms.is_some() { + log::warn!( + "invalid extension [{}]: key [platforms] should not be set in sub-items", + extension_id, + ); + return false; + } + } + + true +} + +pub(crate) async fn init_extensions(mut extensions: Vec) -> Result<(), String> { + log::trace!("initializing extensions"); + + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let search_source_registry_tauri_state = tauri_app_handle.state::(); + + built_in::application::ApplicationSearchSource::init(tauri_app_handle.clone()).await?; + + // Init the built-in enabled extensions + for built_in_extension in extensions + .extract_if(.., |ext| built_in::is_extension_built_in(&ext.id)) + .filter(|ext| ext.enabled) + { + built_in::init_built_in_extension(&built_in_extension, &search_source_registry_tauri_state) + .await; + } + + // Now the third-party extensions + let third_party_search_source = third_party::ThirdPartyExtensionsSearchSource::new(extensions); + third_party_search_source + .restore_extensions_hotkey() + .await?; + let third_party_search_source_clone = third_party_search_source.clone(); + // Set the global search source so that we can access it in `#[tauri::command]`s + // ignore the result because this function will be invoked twice, which + // means this global variable will be set twice. + let _ = THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.set(third_party_search_source_clone); + search_source_registry_tauri_state + .register_source(third_party_search_source) + .await; + + Ok(()) +} + +#[tauri::command] +pub(crate) async fn enable_extension(extension_id: String) -> Result<(), String> { + println!("enable_extension: {}", extension_id); + + if built_in::is_extension_built_in(&extension_id) { + built_in::enable_built_in_extension(&extension_id).await?; + return Ok(()); + } + + third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").enable_extension(&extension_id).await +} + +#[tauri::command] +pub(crate) async fn disable_extension(extension_id: String) -> Result<(), String> { + println!("disable_extension: {}", extension_id); + + if built_in::is_extension_built_in(&extension_id) { + built_in::disable_built_in_extension(&extension_id).await?; + return Ok(()); + } + third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").disable_extension(&extension_id).await +} + +#[tauri::command] +pub(crate) async fn set_extension_alias(extension_id: String, alias: String) -> Result<(), String> { + if built_in::is_extension_built_in(&extension_id) { + built_in::set_built_in_extension_alias(&extension_id, &alias); + return Ok(()); + } + third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").set_extension_alias(&extension_id, &alias).await +} + +#[tauri::command] +pub(crate) async fn register_extension_hotkey( + extension_id: String, + hotkey: String, +) -> Result<(), String> { + println!("register_extension_hotkey: {}, {}", extension_id, hotkey); + + if built_in::is_extension_built_in(&extension_id) { + built_in::register_built_in_extension_hotkey(&extension_id, &hotkey)?; + return Ok(()); + } + third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").register_extension_hotkey(&extension_id, &hotkey).await +} + +/// NOTE: this function won't error out if the extension specified by `extension_id` +/// has no hotkey set because we need it to behave like this. +#[tauri::command] +pub(crate) async fn unregister_extension_hotkey(extension_id: String) -> Result<(), String> { + if built_in::is_extension_built_in(&extension_id) { + built_in::unregister_built_in_extension_hotkey(&extension_id)?; + return Ok(()); + } + third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").unregister_extension_hotkey(&extension_id).await?; + + Ok(()) +} + +#[tauri::command] +pub(crate) async fn is_extension_enabled(extension_id: String) -> Result { + if built_in::is_extension_built_in(&extension_id) { + return built_in::is_built_in_extension_enabled(&extension_id).await; + } + third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").is_extension_enabled(&extension_id).await +} + +fn split_extension_id(extension_id: &str) -> (&str, Option<&str>) { + match extension_id.find('.') { + Some(idx) => (&extension_id[..idx], Some(&extension_id[idx + 1..])), + None => (extension_id, None), + } +} + +fn load_extension_from_json_file( + extension_directory: &Path, + extension_id: &str, +) -> Result { + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + let json_file_path = { + let mut extension_directory_path = extension_directory.join(parent_extension_id); + extension_directory_path.push(PLUGIN_JSON_FILE_NAME); + + extension_directory_path + }; + + let mut extension = serde_json::from_reader::<_, Extension>( + std::fs::File::open(&json_file_path) + .with_context(|| { + format!( + "the [{}] file for extension [{}] is missing or broken", + PLUGIN_JSON_FILE_NAME, parent_extension_id + ) + }) + .map_err(|e| e.to_string())?, + ) + .map_err(|e| e.to_string())?; + + canonicalize_relative_icon_path(extension_directory, &mut extension)?; + + Ok(extension) +} + +fn alter_extension_json_file( + extension_directory: &Path, + extension_id: &str, + how: impl Fn(&mut Extension) -> Result<(), String>, +) -> Result<(), String> { + log::debug!( + "altering extension JSON file for extension [{}]", + extension_id + ); + + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + let json_file_path = { + let mut extension_directory_path = extension_directory.join(parent_extension_id); + extension_directory_path.push(PLUGIN_JSON_FILE_NAME); + + extension_directory_path + }; + + let mut extension = serde_json::from_reader::<_, Extension>( + std::fs::File::open(&json_file_path) + .with_context(|| { + format!( + "the [{}] file for extension [{}] is missing or broken", + PLUGIN_JSON_FILE_NAME, parent_extension_id + ) + }) + .map_err(|e| e.to_string())?, + ) + .map_err(|e| e.to_string())?; + + extension.modify(extension_id, how)?; + + std::fs::write( + &json_file_path, + serde_json::to_string_pretty(&extension).map_err(|e| e.to_string())?, + ) + .map_err(|e| e.to_string())?; + + Ok(()) +} diff --git a/src-tauri/src/extension/third_party.rs b/src-tauri/src/extension/third_party.rs new file mode 100644 index 00000000..37b690dc --- /dev/null +++ b/src-tauri/src/extension/third_party.rs @@ -0,0 +1,733 @@ +use super::alter_extension_json_file; +use super::Extension; +use super::LOCAL_QUERY_SOURCE_TYPE; +use crate::common::document::open; +use crate::common::document::DataSourceReference; +use crate::common::document::Document; +use crate::common::error::SearchError; +use crate::common::search::QueryResponse; +use crate::common::search::QuerySource; +use crate::common::search::SearchQuery; +use crate::common::traits::SearchSource; +use crate::extension::split_extension_id; +use crate::GLOBAL_TAURI_APP_HANDLE; +use async_trait::async_trait; +use function_name::named; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::LazyLock; +use std::sync::OnceLock; +use tauri::async_runtime; +use tauri::Manager; +use tauri_plugin_global_shortcut::GlobalShortcutExt; +use tauri_plugin_global_shortcut::ShortcutState; +use tokio::sync::RwLock; + +pub(crate) static THIRD_PARTY_EXTENSION_DIRECTORY: LazyLock = LazyLock::new(|| { + let mut app_data_dir = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set") + .path() + .app_data_dir() + .expect( + "User home directory not found, which should be impossible on desktop environments", + ); + app_data_dir.push("extension"); + + app_data_dir +}); + +/// All the third-party extensions will be registered as one search source. +/// +/// Since some `#[tauri::command]`s need to access it, we store it in a global +/// static variable as well. +#[derive(Debug, Clone)] +pub(super) struct ThirdPartyExtensionsSearchSource { + inner: Arc, +} + +impl ThirdPartyExtensionsSearchSource { + pub(super) fn new(extensions: Vec) -> Self { + Self { + inner: Arc::new(ThirdPartyExtensionsSearchSourceInner { + extensions: RwLock::new(extensions), + }), + } + } + + #[named] + pub(super) async fn enable_extension(&self, extension_id: &str) -> Result<(), String> { + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + + let mut extensions_write_lock = self.inner.extensions.write().await; + let opt_index = extensions_write_lock + .iter() + .position(|ext| ext.id == parent_extension_id); + + let Some(index) = opt_index else { + return Err(format!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + )); + }; + + let extension = extensions_write_lock + .get_mut(index) + .expect("just checked this extension exists"); + + let update_extension = |ext: &mut Extension| -> Result<(), String> { + if ext.enabled { + return Err(format!( + "{} invoked with an extension that is already enabled [{}]", + function_name!(), + extension_id + )); + } + ext.enabled = true; + + Ok(()) + }; + + extension.modify(extension_id, update_extension)?; + alter_extension_json_file( + &THIRD_PARTY_EXTENSION_DIRECTORY, + extension_id, + update_extension, + )?; + + Ok(()) + } + + #[named] + pub(super) async fn disable_extension(&self, extension_id: &str) -> Result<(), String> { + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + + let mut extensions_write_lock = self.inner.extensions.write().await; + let opt_index = extensions_write_lock + .iter() + .position(|ext| ext.id == parent_extension_id); + + let Some(index) = opt_index else { + return Err(format!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + )); + }; + + let extension = extensions_write_lock + .get_mut(index) + .expect("just checked this extension exists"); + + let update_extension = |ext: &mut Extension| -> Result<(), String> { + if !ext.enabled { + return Err(format!( + "{} invoked with an extension that is already enabled [{}]", + function_name!(), + extension_id + )); + } + ext.enabled = false; + + Ok(()) + }; + + extension.modify(extension_id, update_extension)?; + alter_extension_json_file( + &THIRD_PARTY_EXTENSION_DIRECTORY, + extension_id, + update_extension, + )?; + + Ok(()) + } + + #[named] + pub(super) async fn set_extension_alias( + &self, + extension_id: &str, + alias: &str, + ) -> Result<(), String> { + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + + let mut extensions_write_lock = self.inner.extensions.write().await; + let opt_index = extensions_write_lock + .iter() + .position(|ext| ext.id == parent_extension_id); + + let Some(index) = opt_index else { + log::warn!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + ); + return Ok(()); + }; + + let extension = extensions_write_lock + .get_mut(index) + .expect("just checked this extension exists"); + + let update_extension = |ext: &mut Extension| -> Result<(), String> { + ext.alias = Some(alias.to_string()); + Ok(()) + }; + + extension.modify(extension_id, update_extension)?; + alter_extension_json_file( + &THIRD_PARTY_EXTENSION_DIRECTORY, + extension_id, + update_extension, + )?; + + Ok(()) + } + + pub(super) async fn restore_extensions_hotkey(&self) -> Result<(), String> { + fn set_up_hotkey( + tauri_app_handle: &tauri::AppHandle, + extension: &Extension, + ) -> Result<(), String> { + if let Some(ref hotkey) = extension.hotkey { + let on_opened = extension.on_opened().unwrap_or_else(|| panic!( "extension has hotkey, but on_open() returns None, extension ID [{}], extension type [{:?}]", extension.id, extension.r#type)); + + let extension_id_clone = extension.id.clone(); + + tauri_app_handle + .global_shortcut() + .on_shortcut(hotkey.as_str(), move |_tauri_app_handle, _hotkey, event| { + let on_opened_clone = on_opened.clone(); + let extension_id_clone = extension_id_clone.clone(); + if event.state() == ShortcutState::Pressed { + async_runtime::spawn(async move { + let result = open(on_opened_clone).await; + if let Err(msg) = result { + log::warn!( + "failed to open extension [{}], error [{}]", + extension_id_clone, + msg + ); + } + }); + } + }) + .map_err(|e| e.to_string())?; + } + + Ok(()) + } + + let extensions_read_lock = self.inner.extensions.read().await; + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + + for extension in extensions_read_lock.iter() { + if extension.r#type.contains_sub_items() { + if let Some(commands) = &extension.commands { + for command in commands.iter().filter(|cmd| cmd.enabled) { + set_up_hotkey(tauri_app_handle, command)?; + } + } + + if let Some(scripts) = &extension.scripts { + for script in scripts.iter().filter(|script| script.enabled) { + set_up_hotkey(tauri_app_handle, script)?; + } + } + + if let Some(quick_links) = &extension.quick_links { + for quick_link in quick_links.iter().filter(|link| link.enabled) { + set_up_hotkey(tauri_app_handle, quick_link)?; + } + } + } else { + set_up_hotkey(tauri_app_handle, extension)?; + } + } + + Ok(()) + } + + #[named] + pub(super) async fn register_extension_hotkey( + &self, + extension_id: &str, + hotkey: &str, + ) -> Result<(), String> { + self.unregister_extension_hotkey(extension_id).await?; + + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + let mut extensions_write_lock = self.inner.extensions.write().await; + let opt_index = extensions_write_lock + .iter() + .position(|ext| ext.id == parent_extension_id); + + let Some(index) = opt_index else { + return Err(format!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + )); + }; + + let mut extension = extensions_write_lock + .get_mut(index) + .expect("just checked this extension exists"); + + let update_extension = |ext: &mut Extension| -> Result<(), String> { + ext.hotkey = Some(hotkey.into()); + Ok(()) + }; + + // Update extension (memory and file) + extension.modify(extension_id, update_extension)?; + alter_extension_json_file( + &THIRD_PARTY_EXTENSION_DIRECTORY, + extension_id, + update_extension, + )?; + + // To make borrow checker happy + let extension_dbg_string = format!("{:?}", extension); + extension = match extension.get_extension_mut(extension_id) { + Some(ext) => ext, + None => { + panic!( + "extension [{}] should be found in {}", + extension_id, extension_dbg_string + ) + } + }; + + // Set hotkey + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + let on_opened = extension.on_opened().unwrap_or_else(|| panic!( + "setting hotkey for an extension that cannot be opened, extension ID [{}], extension type [{:?}]", extension_id, extension.r#type, + )); + + let extension_id_clone = extension_id.to_string(); + tauri_app_handle + .global_shortcut() + .on_shortcut(hotkey, move |_tauri_app_handle, _hotkey, event| { + let on_opened_clone = on_opened.clone(); + let extension_id_clone = extension_id_clone.clone(); + if event.state() == ShortcutState::Pressed { + async_runtime::spawn(async move { + let result = open(on_opened_clone).await; + if let Err(msg) = result { + log::warn!( + "failed to open extension [{}], error [{}]", + extension_id_clone, + msg + ); + } + }); + } + }) + .map_err(|e| e.to_string())?; + + Ok(()) + } + + /// NOTE: this function won't error out if the extension specified by `extension_id` + /// has no hotkey set because we need it to behave like this. + #[named] + pub(super) async fn unregister_extension_hotkey( + &self, + extension_id: &str, + ) -> Result<(), String> { + let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id); + + let mut extensions_write_lock = self.inner.extensions.write().await; + let opt_index = extensions_write_lock + .iter() + .position(|ext| ext.id == parent_extension_id); + + let Some(index) = opt_index else { + return Err(format!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + )); + }; + + let parent_extension = extensions_write_lock + .get_mut(index) + .expect("just checked this extension exists"); + let Some(extension) = parent_extension.get_extension_mut(extension_id) else { + return Err(format!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + )); + }; + + let Some(hotkey) = extension.hotkey.clone() else { + log::warn!( + "extension [{}] has no hotkey set, but we are trying to unregister it", + extension_id + ); + return Ok(()); + }; + + let update_extension = |extension: &mut Extension| -> Result<(), String> { + extension.hotkey = None; + Ok(()) + }; + + parent_extension.modify(extension_id, update_extension)?; + alter_extension_json_file( + &THIRD_PARTY_EXTENSION_DIRECTORY, + extension_id, + update_extension, + )?; + + // Set hotkey + let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE + .get() + .expect("global tauri app handle not set"); + tauri_app_handle + .global_shortcut() + .unregister(hotkey.as_str()) + .map_err(|e| e.to_string())?; + + Ok(()) + } + + #[named] + pub(super) async fn is_extension_enabled(&self, extension_id: &str) -> Result { + let (parent_extension_id, opt_sub_extension_id) = split_extension_id(extension_id); + + let extensions_read_lock = self.inner.extensions.read().await; + let opt_index = extensions_read_lock + .iter() + .position(|ext| ext.id == parent_extension_id); + + let Some(index) = opt_index else { + return Err(format!( + "{} invoked with an extension that does not exist [{}]", + function_name!(), + extension_id + )); + }; + + let extension = extensions_read_lock + .get(index) + .expect("just checked this extension exists"); + + if let Some(sub_extension_id) = opt_sub_extension_id { + // For a sub-extension, it is enabled iff: + // + // 1. Its parent extension is enabled, and + // 2. It is enabled + if !extension.enabled { + return Ok(false); + } + + if let Some(ref commands) = extension.commands { + if let Some(sub_ext) = commands.iter().find(|cmd| cmd.id == sub_extension_id) { + return Ok(sub_ext.enabled); + } + } + if let Some(ref scripts) = extension.scripts { + if let Some(sub_ext) = scripts.iter().find(|script| script.id == sub_extension_id) { + return Ok(sub_ext.enabled); + } + } + if let Some(ref commands) = extension.commands { + if let Some(sub_ext) = commands + .iter() + .find(|quick_link| quick_link.id == sub_extension_id) + { + return Ok(sub_ext.enabled); + } + } + + Err(format!( + "{} invoked with a sub-extension that does not exist [{}/{}]", + function_name!(), + parent_extension_id, + sub_extension_id + )) + } else { + Ok(extension.enabled) + } + } +} + +pub(super) static THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE: OnceLock = + OnceLock::new(); + +#[derive(Debug)] +struct ThirdPartyExtensionsSearchSourceInner { + extensions: RwLock>, +} + +#[async_trait] +impl SearchSource for ThirdPartyExtensionsSearchSource { + fn get_type(&self) -> QuerySource { + QuerySource { + r#type: LOCAL_QUERY_SOURCE_TYPE.into(), + name: hostname::get() + .unwrap_or("My Computer".into()) + .to_string_lossy() + .into(), + id: "extensions".into(), + } + } + + async fn search(&self, query: SearchQuery) -> Result { + let Some(query_string) = query.query_strings.get("query") else { + return Ok(QueryResponse { + source: self.get_type(), + hits: Vec::new(), + total_hits: 0, + }); + }; + + let mut hits = Vec::new(); + let extensions_read_lock = self.inner.extensions.read().await; + let query_lower = query_string.to_lowercase(); + + for extension in extensions_read_lock.iter().filter(|ext| ext.enabled) { + if extension.r#type.contains_sub_items() { + if let Some(ref commands) = extension.commands { + for command in commands.iter().filter(|cmd| cmd.enabled) { + if let Some(hit) = extension_to_hit(command, &query_lower) { + hits.push(hit); + } + } + } + + if let Some(ref scripts) = extension.scripts { + for script in scripts.iter().filter(|script| script.enabled) { + if let Some(hit) = extension_to_hit(script, &query_lower) { + hits.push(hit); + } + } + } + + if let Some(ref quick_links) = extension.quick_links { + for quick_link in quick_links.iter().filter(|link| link.enabled) { + if let Some(hit) = extension_to_hit(quick_link, &query_lower) { + hits.push(hit); + } + } + } + } else { + if let Some(hit) = extension_to_hit(extension, &query_lower) { + hits.push(hit); + } + } + } + + let total_hits = hits.len(); + + Ok(QueryResponse { + source: self.get_type(), + hits, + total_hits, + }) + } +} + +fn extension_to_hit(extension: &Extension, query_lower: &str) -> Option<(Document, f64)> { + if !extension.searchable() { + return None; + } + + let mut total_score = 0.0; + + // Score based on title match + // Title is considered more important, so it gets a higher weight. + if let Some(title_score) = + calculate_text_similarity(&query_lower, &extension.title.to_lowercase()) + { + total_score += title_score * 1.0; // Weight for title + } + + // Score based on alias match if available + // Alias is considered less important than title, so it gets a lower weight. + if let Some(alias) = &extension.alias { + if let Some(alias_score) = calculate_text_similarity(&query_lower, &alias.to_lowercase()) { + total_score += alias_score * 0.7; // Weight for alias + } + } + + // Only include if there's some relevance (score is meaningfully positive) + if total_score > 0.01 { + let on_opened = extension.on_opened().unwrap_or_else(|| { + panic!( + "extension (id [{}], type [{:?}]) is searchable, and should have a valid on_opened", + extension.id, extension.r#type + ) + }); + let url = on_opened.url(); + + let document = Document { + id: extension.id.clone(), + title: Some(extension.title.clone()), + icon: Some(extension.icon.clone()), + on_opened: Some(on_opened), + url: Some(url), + category: Some(extension.r#type.to_string()), + source: Some(DataSourceReference { + id: Some(format!("{:?}", extension.r#type)), + name: Some(format!("{:?}", extension.r#type)), + icon: None, + r#type: Some(format!("{:?}", extension.r#type)), + }), + + ..Default::default() + }; + + Some((document, total_score)) + } else { + None + } +} + +// Calculates a similarity score between a query and a text, aiming for a [0, 1] range. +// Assumes query and text are already lowercased. +fn calculate_text_similarity(query: &str, text: &str) -> Option { + if query.is_empty() || text.is_empty() { + return None; + } + + if text == query { + return Some(1.0); // Perfect match + } + + let query_len = query.len() as f64; + let text_len = text.len() as f64; + let ratio = query_len / text_len; + let mut score: f64 = 0.0; + + // Case 1: Text starts with the query (prefix match) + // Score: base 0.5, bonus up to 0.4 for how much of `text` is covered by `query`. Max 0.9. + if text.starts_with(query) { + score = score.max(0.5 + 0.4 * ratio); + } + + // Case 2: Text contains the query (substring match, not necessarily prefix) + // Score: base 0.3, bonus up to 0.3. Max 0.6. + // `score.max` ensures that if it's both a prefix and contains, the higher score (prefix) is taken. + if text.contains(query) { + score = score.max(0.3 + 0.3 * ratio); + } + + // Case 3: Fallback for "all query characters exist in text" (order-independent) + if score < 0.2 { + if query.chars().all(|c_q| text.contains(c_q)) { + score = score.max(0.15); // Fixed low score for this weaker match type + } + } + + if score > 0.0 { + // Cap non-perfect matches slightly below 1.0 to make perfect (1.0) distinct. + Some(score.min(0.95)) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Helper function for approximate floating point comparison + fn approx_eq(a: f64, b: f64) -> bool { + (a - b).abs() < 1e-10 + } + + #[test] + fn test_empty_strings() { + assert_eq!(calculate_text_similarity("", "text"), None); + assert_eq!(calculate_text_similarity("query", ""), None); + assert_eq!(calculate_text_similarity("", ""), None); + } + + #[test] + fn test_perfect_match() { + assert_eq!(calculate_text_similarity("text", "text"), Some(1.0)); + assert_eq!(calculate_text_similarity("a", "a"), Some(1.0)); + } + + #[test] + fn test_prefix_match() { + // For "te" and "text": + // score = 0.5 + 0.4 * (2/4) = 0.5 + 0.2 = 0.7 + let score = calculate_text_similarity("te", "text").unwrap(); + assert!(approx_eq(score, 0.7)); + + // For "tex" and "text": + // score = 0.5 + 0.4 * (3/4) = 0.5 + 0.3 = 0.8 + let score = calculate_text_similarity("tex", "text").unwrap(); + assert!(approx_eq(score, 0.8)); + } + + #[test] + fn test_substring_match() { + // For "ex" and "text": + // score = 0.3 + 0.3 * (2/4) = 0.3 + 0.15 = 0.45 + let score = calculate_text_similarity("ex", "text").unwrap(); + assert!(approx_eq(score, 0.45)); + + // Prefix should score higher than substring + assert!( + calculate_text_similarity("te", "text").unwrap() + > calculate_text_similarity("ex", "text").unwrap() + ); + } + + #[test] + fn test_character_presence() { + // Characters present but not in sequence + // "tac" in "contact" - not a substring, but all chars exist + let score = calculate_text_similarity("tac", "contact").unwrap(); + assert!(approx_eq(0.3 + 0.3 * (3.0 / 7.0), score)); + + assert!(calculate_text_similarity("ac", "contact").is_some()); + + // Should not apply if some characters are missing + assert_eq!(calculate_text_similarity("xyz", "contact"), None); + } + + #[test] + fn test_combined_scenarios() { + // Test that character presence fallback doesn't override higher scores + // "tex" is a prefix of "text" with score 0.8 + let score = calculate_text_similarity("tex", "text").unwrap(); + assert!(approx_eq(score, 0.8)); + + // Test a case where the characters exist but it's already a substring + // "act" is a substring of "contact" with score > 0.2, so fallback won't apply + let expected_score = 0.3 + 0.3 * (3.0 / 7.0); + let actual_score = calculate_text_similarity("act", "contact").unwrap(); + assert!(approx_eq(actual_score, expected_score)); + } + + #[test] + fn test_no_similarity() { + assert_eq!(calculate_text_similarity("xyz", "test"), None); + } + + #[test] + fn test_score_capping() { + // Use a long query that's a prefix of a slightly longer text + let long_text = "abcdefghijklmnopqrstuvwxyz"; + let long_prefix = "abcdefghijklmnopqrstuvwxy"; // All but last letter + + // Expected score would be 0.5 + 0.4 * (25/26) = 0.5 + 0.385 = 0.885 + let expected_score = 0.5 + 0.4 * (25.0 / 26.0); + let actual_score = calculate_text_similarity(long_prefix, long_text).unwrap(); + assert!(approx_eq(actual_score, expected_score)); + + // Verify that non-perfect matches are capped at 0.95 + assert!(calculate_text_similarity("almost", "almost perfect").unwrap() <= 0.95); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index af7ec2e2..06e1e15d 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,7 +1,7 @@ mod assistant; mod autostart; mod common; -mod local; +mod extension; mod search; mod server; mod settings; @@ -142,25 +142,24 @@ pub fn run() { server::attachment::get_attachment, server::attachment::delete_attachment, server::transcription::transcription, - util::open, server::system_settings::get_system_settings, simulate_mouse_click, - local::get_disabled_local_query_sources, - local::enable_local_query_source, - local::disable_local_query_source, - local::application::get_app_list, - local::application::get_app_search_path, - local::application::get_app_metadata, - local::application::set_app_alias, - local::application::register_app_hotkey, - local::application::unregister_app_hotkey, - local::application::disable_app_search, - local::application::enable_app_search, - local::application::add_app_search_path, - local::application::remove_app_search_path, + extension::built_in::application::get_app_list, + extension::built_in::application::get_app_search_path, + extension::built_in::application::get_app_metadata, + extension::built_in::application::add_app_search_path, + extension::built_in::application::remove_app_search_path, + extension::list_extensions, + extension::enable_extension, + extension::disable_extension, + extension::set_extension_alias, + extension::register_extension_hotkey, + extension::unregister_extension_hotkey, + extension::is_extension_enabled, settings::set_allow_self_signature, settings::get_allow_self_signature, - assistant::ask_ai + assistant::ask_ai, + crate::common::document::open, ]) .setup(|app| { let app_handle = app.handle().clone(); @@ -262,7 +261,7 @@ pub async fn init(app_handle: &AppHandle) { .await; } - local::start_pizza_engine_runtime(); + extension::built_in::pizza_engine_runtime::start_pizza_engine_runtime(); } #[tauri::command] @@ -418,7 +417,11 @@ fn open_settings(app: &tauri::AppHandle) { #[tauri::command] async fn get_app_search_source(app_handle: AppHandle) -> Result<(), String> { - local::init_local_search_source(&app_handle).await?; + let (_found_invalid_extensions, extensions) = extension::list_extensions() + .await + .map_err(|e| e.to_string())?; + extension::init_extensions(extensions).await?; + let _ = server::connector::refresh_all_connectors(&app_handle).await; let _ = server::datasource::refresh_all_datasources(&app_handle).await; diff --git a/src-tauri/src/local/mod.rs b/src-tauri/src/local/mod.rs deleted file mode 100644 index 6fb16c29..00000000 --- a/src-tauri/src/local/mod.rs +++ /dev/null @@ -1,164 +0,0 @@ -pub mod application; -pub mod calculator; -pub mod file_system; - -use std::any::Any; -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::sync::OnceLock; - -use crate::common::register::SearchSourceRegistry; -use serde_json::Value as Json; -use tauri::{AppHandle, Manager, Runtime}; -use tauri_plugin_store::StoreExt; - -pub const LOCAL_QUERY_SOURCE_TYPE: &str = "local"; -pub const TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE: &str = "local_query_source_enabled_state"; - -trait SearchSourceState { - #[cfg_attr(not(feature = "use_pizza_engine"), allow(unused))] - fn as_mut_any(&mut self) -> &mut dyn Any; -} - -#[async_trait::async_trait(?Send)] -trait Task: Send + Sync { - fn search_source_id(&self) -> &'static str; - - async fn exec(&mut self, state: &mut Option>); -} - -static RUNTIME_TX: OnceLock>> = OnceLock::new(); - -pub(crate) fn start_pizza_engine_runtime() { - std::thread::spawn(|| { - let rt = tokio::runtime::Runtime::new().unwrap(); - - let main = async { - let mut states: HashMap>> = HashMap::new(); - - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); - RUNTIME_TX.set(tx).unwrap(); - - while let Some(mut task) = rx.recv().await { - let opt_search_source_state = match states.entry(task.search_source_id().into()) { - Entry::Occupied(o) => o.into_mut(), - Entry::Vacant(v) => v.insert(None), - }; - task.exec(opt_search_source_state).await; - } - }; - - rt.block_on(main); - }); -} - -pub(crate) async fn init_local_search_source( - app_handle: &AppHandle, -) -> Result<(), String> { - let enabled_status_store = app_handle - .store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE) - .map_err(|e| e.to_string())?; - if enabled_status_store.is_empty() { - enabled_status_store.set( - application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME, - Json::Bool(true), - ); - enabled_status_store.set(calculator::DATA_SOURCE_ID, Json::Bool(true)); - } - let registry = app_handle.state::(); - - application::ApplicationSearchSource::init(app_handle.clone()).await?; - - for (id, enabled) in enabled_status_store.entries() { - let enabled = match enabled { - Json::Bool(b) => b, - _ => unreachable!("enabled state should be stored as a boolean"), - }; - - if enabled { - if id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { - registry - .register_source(application::ApplicationSearchSource) - .await; - } - - if id == calculator::DATA_SOURCE_ID { - let calculator_search = calculator::CalculatorSource::new(2000f64); - registry.register_source(calculator_search).await; - } - } - } - - Ok(()) -} - -#[tauri::command] -pub async fn get_disabled_local_query_sources(app_handle: AppHandle) -> Vec { - let enabled_status_store = app_handle - .store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE) - .unwrap_or_else(|e| { - panic!( - "tauri store [{}] should exist and be loaded, but that's not true due to error [{}]", - TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e - ) - }); - let mut disabled_local_query_sources = Vec::new(); - - for (id, enabled) in enabled_status_store.entries() { - let enabled = match enabled { - Json::Bool(b) => b, - _ => unreachable!("enabled state should be stored as a boolean"), - }; - - if !enabled { - disabled_local_query_sources.push(id); - } - } - - disabled_local_query_sources -} - -#[tauri::command] -pub async fn enable_local_query_source( - app_handle: AppHandle, - query_source_id: String, -) { - let registry = app_handle.state::(); - if query_source_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME { - let application_search = application::ApplicationSearchSource; - registry.register_source(application_search).await; - } - if query_source_id == calculator::DATA_SOURCE_ID { - let calculator_search = calculator::CalculatorSource::new(2000f64); - registry.register_source(calculator_search).await; - } - - let enabled_status_store = app_handle - .store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE) - .unwrap_or_else(|e| { - panic!( - "tauri store [{}] should exist and be loaded, but that's not true due to error [{}]", - TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e - ) - }); - enabled_status_store.set(query_source_id, Json::Bool(true)); -} - -#[tauri::command] -pub async fn disable_local_query_source( - app_handle: AppHandle, - query_source_id: String, -) { - let registry = app_handle.state::(); - registry.remove_source(&query_source_id).await; - - let enabled_status_store = app_handle - .store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE) - .unwrap_or_else(|e| { - panic!( - "tauri store [{}] should exist and be loaded, but that's not true due to error [{}]", - TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e - ) - }); - enabled_status_store.set(query_source_id, Json::Bool(false)); -} diff --git a/src-tauri/src/local/file_system.rs b/src-tauri/src/mod.rs similarity index 100% rename from src-tauri/src/local/file_system.rs rename to src-tauri/src/mod.rs diff --git a/src-tauri/src/search/mod.rs b/src-tauri/src/search/mod.rs index 95aa1a67..b90a6db9 100644 --- a/src-tauri/src/search/mod.rs +++ b/src-tauri/src/search/mod.rs @@ -3,7 +3,6 @@ use crate::common::register::SearchSourceRegistry; use crate::common::search::{ FailedRequest, MultiSourceQueryResponse, QueryHits, QuerySource, SearchQuery, }; -use crate::local; use futures::stream::FuturesUnordered; use futures::StreamExt; use std::cmp::Reverse; @@ -20,7 +19,10 @@ pub async fn query_coco_fusion( query_strings: HashMap, query_timeout: u64, ) -> Result { - let query_keyword = query_strings.get("query").unwrap_or(&"".to_string()).clone(); + let query_keyword = query_strings + .get("query") + .unwrap_or(&"".to_string()) + .clone(); let query_source_to_search = query_strings.get("querysource"); @@ -28,7 +30,6 @@ pub async fn query_coco_fusion( let sources_future = search_sources.get_sources(); let mut futures = FuturesUnordered::new(); - let mut sources = HashMap::new(); let sources_list = sources_future.await; @@ -52,8 +53,6 @@ pub async fn query_coco_fusion( } } - sources.insert(query_source_type.id.clone(), query_source_type); - let query = SearchQuery::new(from, size, query_strings.clone()); let query_source_clone = query_source.clone(); // Clone Arc to avoid ownership issues @@ -62,7 +61,7 @@ pub async fn query_coco_fusion( timeout(timeout_duration, async { query_source_clone.search(query).await }) - .await + .await })); } @@ -159,23 +158,22 @@ pub async fn query_coco_fusion( let mut unique_sources = HashSet::new(); for hit in &final_hits { if let Some(source) = &hit.source { - if source.id != local::calculator::DATA_SOURCE_ID { + if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID { unique_sources.insert(&source.id); } } } log::debug!( - "Multiple sources found: {:?}, no rerank needed", - unique_sources - ); + "Multiple sources found: {:?}, no rerank needed", + unique_sources + ); if unique_sources.len() < 1 { need_rerank = false; // If we have hits from multiple sources, we don't need to rerank } if need_rerank && final_hits.len() > 1 { - // Precollect (index, title) let titles_to_score: Vec<(usize, &str)> = final_hits .iter() @@ -184,7 +182,7 @@ pub async fn query_coco_fusion( let source = hit.source.as_ref()?; let title = hit.document.title.as_deref()?; - if source.id != local::calculator::DATA_SOURCE_ID { + if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID { Some((idx, title)) } else { None @@ -203,7 +201,8 @@ pub async fn query_coco_fusion( for (idx, score) in scored_hits.into_iter().take(size as usize) { final_hits[idx].score = score; } - } else if final_hits.len() < size as usize { // If we still need more hits, take the highest-scoring remaining ones + } else if final_hits.len() < size as usize { + // If we still need more hits, take the highest-scoring remaining ones let remaining_needed = size as usize - final_hits.len(); @@ -275,4 +274,4 @@ fn boosted_levenshtein_rerank(query: &str, titles: Vec<(usize, &str)>) -> Vec<(u (idx, score.min(1.0) as f64) }) .collect() -} \ No newline at end of file +} diff --git a/src-tauri/src/server/search.rs b/src-tauri/src/server/search.rs index be480f98..c2afc508 100644 --- a/src-tauri/src/server/search.rs +++ b/src-tauri/src/server/search.rs @@ -1,4 +1,4 @@ -use crate::common::document::Document; +use crate::common::document::{Document, OnOpened}; use crate::common::error::SearchError; use crate::common::http::get_response_body_text; use crate::common::search::{QueryHits, QueryResponse, QuerySource, SearchQuery, SearchResponse}; @@ -45,7 +45,7 @@ impl DocumentsSizedCollector { } } - fn documents(self) -> impl ExactSizeIterator { + fn documents(self) -> impl ExactSizeIterator { self.docs.into_iter().map(|(_, doc, _)| doc) } @@ -103,11 +103,7 @@ impl SearchSource for CocoSearchSource { query_args.insert(key, JsonValue::String(value)); } - let response = HttpClient::get( - &self.server.id, - &url, - Some(query_args), - ) + let response = HttpClient::get(&self.server.id, &url, Some(query_args)) .await .map_err(|e| SearchError::HttpError(format!("{}", e)))?; @@ -116,7 +112,6 @@ impl SearchSource for CocoSearchSource { .await .map_err(|e| SearchError::ParseError(e))?; - // Check if the response body is empty if !response_body.is_empty() { // Parse the search response from the body text @@ -125,14 +120,21 @@ impl SearchSource for CocoSearchSource { // Process the parsed response total_hits = parsed.hits.total.value as usize; - hits = parsed - .hits - .hits - .into_iter() - .map(|hit| (hit._source, hit._score.unwrap_or(0.0))) // Default _score to 0.0 if None - .collect(); - } + for hit in parsed.hits.hits { + let mut document = hit._source; + // Default _score to 0.0 if None + let score = hit._score.unwrap_or(0.0); + let on_opened = document + .url + .as_ref() + .map(|url| OnOpened::Document { url: url.clone() }); + // Set the `on_opened` field as it won't be returned from Coco server + document.on_opened = on_opened; + + hits.push((document, score)); + } + } // Return the final result Ok(QueryResponse { diff --git a/src-tauri/src/util/mod.rs b/src-tauri/src/util/mod.rs index a0c78d45..30f85aa8 100644 --- a/src-tauri/src/util/mod.rs +++ b/src-tauri/src/util/mod.rs @@ -67,7 +67,6 @@ fn get_linux_desktop_environment() -> Option { // // tauri_plugin_shell::open() is deprecated, but we still use it. #[allow(deprecated)] -#[tauri::command] pub async fn open(app_handle: AppHandle, path: String) -> Result<(), String> { if cfg!(target_os = "linux") { let borrowed_path = Path::new(&path); diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index 3b10b8ed..c4ed0afd 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -42,6 +42,8 @@ "url": "/ui/settings", "width": 1000, "height": 700, + "minHeight": 700, + "minWidth": 1000, "center": true, "transparent": true, "maximizable": false, @@ -105,7 +107,7 @@ } } }, - "resources": ["assets", "icons"] + "resources": ["assets/**/*", "icons"] }, "plugins": { "features": { diff --git a/src/components/Assistant/AssistantList.tsx b/src/components/Assistant/AssistantList.tsx index 5da62886..9b05d374 100644 --- a/src/components/Assistant/AssistantList.tsx +++ b/src/components/Assistant/AssistantList.tsx @@ -1,4 +1,4 @@ -import { useState, useRef, useCallback } from "react"; +import { useState, useRef, useCallback, useEffect } from "react"; import { ChevronDownIcon, RefreshCw } from "lucide-react"; import { useTranslation } from "react-i18next"; import { isNil } from "lodash-es"; @@ -16,6 +16,7 @@ import PopoverInput from "@/components/Common/PopoverInput"; import { AssistantFetcher } from "./AssistantFetcher"; import AssistantItem from "./AssistantItem"; import Pagination from "@/components/Common/Pagination"; +import { useSearchStore } from "@/stores/searchStore"; interface AssistantListProps { assistantIDs?: string[]; @@ -37,6 +38,11 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) { const searchInputRef = useRef(null); const [keyword, setKeyword] = useState(""); const debounceKeyword = useDebounce(keyword, { wait: 500 }); + const askAiAssistantId = useSearchStore((state) => state.askAiAssistantId); + const setAskAiAssistantId = useSearchStore((state) => { + return state.setAskAiAssistantId; + }); + const assistantList = useConnectStore((state) => state.assistantList); const { fetchAssistant } = AssistantFetcher({ debounceKeyword, @@ -62,6 +68,19 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) { const [highlightIndex, setHighlightIndex] = useState(-1); const [isKeyboardActive, setIsKeyboardActive] = useState(false); + useEffect(() => { + if (!askAiAssistantId || assistantList.length === 0) return; + + const matched = assistantList.find((item) => { + return item._id === askAiAssistantId; + }); + + if (!matched) return; + + setCurrentAssistant(matched); + setAskAiAssistantId(void 0); + }, [assistantList, askAiAssistantId]); + useKeyPress( ["uparrow", "downarrow", "enter"], (event, key) => { diff --git a/src/components/ChatMessage/MessageActions.tsx b/src/components/ChatMessage/MessageActions.tsx index 493a4334..f6e2f68a 100644 --- a/src/components/ChatMessage/MessageActions.tsx +++ b/src/components/ChatMessage/MessageActions.tsx @@ -1,5 +1,6 @@ import { COPY_BUTTON_ID } from "@/constants"; import { useSearchStore } from "@/stores/searchStore"; +import clsx from "clsx"; import { Check, Copy, @@ -14,6 +15,8 @@ interface MessageActionsProps { id: string; content: string; question?: string; + actionClassName?: string; + actionIconSize?: number; onResend?: () => void; } @@ -23,6 +26,8 @@ export const MessageActions = ({ id, content, question, + actionClassName, + actionIconSize, onResend, }: MessageActionsProps) => { const [copied, setCopied] = useState(false); @@ -89,7 +94,7 @@ export const MessageActions = ({ const goAskAi = useSearchStore((state) => state.goAskAi); return ( -
+
{!isRefreshOnly && ( )} @@ -116,6 +133,10 @@ export const MessageActions = ({ ? "text-[#1990FF] dark:text-[#1990FF]" : "text-[#666666] dark:text-[#A3A3A3]" }`} + style={{ + width: actionIconSize, + height: actionIconSize, + }} /> )} @@ -132,6 +153,10 @@ export const MessageActions = ({ ? "text-[#1990FF] dark:text-[#1990FF]" : "text-[#666666] dark:text-[#A3A3A3]" }`} + style={{ + width: actionIconSize, + height: actionIconSize, + }} /> )} @@ -146,6 +171,10 @@ export const MessageActions = ({ ? "text-[#1990FF] dark:text-[#1990FF]" : "text-[#666666] dark:text-[#A3A3A3]" }`} + style={{ + width: actionIconSize, + height: actionIconSize, + }} /> )} @@ -162,6 +191,10 @@ export const MessageActions = ({ ? "text-[#1990FF] dark:text-[#1990FF]" : "text-[#666666] dark:text-[#A3A3A3]" }`} + style={{ + width: actionIconSize, + height: actionIconSize, + }} /> )} diff --git a/src/components/ChatMessage/index.tsx b/src/components/ChatMessage/index.tsx index dcc89be8..53b92644 100644 --- a/src/components/ChatMessage/index.tsx +++ b/src/components/ChatMessage/index.tsx @@ -30,6 +30,9 @@ interface ChatMessageProps { onResend?: (value: string) => void; loadingStep?: Record; hide_assistant?: boolean; + rootClassName?: string; + actionClassName?: string; + actionIconSize?: number; } export const ChatMessage = memo(function ChatMessage({ @@ -45,6 +48,9 @@ export const ChatMessage = memo(function ChatMessage({ onResend, loadingStep, hide_assistant = false, + rootClassName, + actionClassName, + actionIconSize, }: ChatMessageProps) { const { t } = useTranslation(); @@ -144,6 +150,8 @@ export const ChatMessage = memo(function ChatMessage({ id={message._id} content={messageContent || response?.message_chunk || ""} question={question} + actionClassName={actionClassName} + actionIconSize={actionIconSize} onResend={() => { onResend && onResend(question); }} @@ -166,7 +174,8 @@ export const ChatMessage = memo(function ChatMessage({ [isAssistant ? "justify-start" : "justify-end"], { hidden: visibleStartPage, - } + }, + rootClassName )} >
= (props) => { - const { size = 16, color } = props; - - return ( - - 编组 3 - - - - - - - - - - - - ); -}; - -export default AiSummaryIcon; diff --git a/src/components/Search/AiOverview.tsx b/src/components/Search/AiOverview.tsx new file mode 100644 index 00000000..4e51120f --- /dev/null +++ b/src/components/Search/AiOverview.tsx @@ -0,0 +1,91 @@ +import { ChevronUp, Sparkles } from "lucide-react"; +import { FC, useState } from "react"; +import clsx from "clsx"; +import { useStreamChat } from "@/hooks/useStreamChat"; +import { useExtensionsStore } from "@/stores/extensionsStore"; +import { ChatMessage } from "../ChatMessage"; + +interface AiSummaryProps { + message: string; +} + +const AiOverview: FC = (props) => { + const { message } = props; + const aiOverviewServer = useExtensionsStore((state) => { + return state.aiOverviewServer; + }); + const aiOverviewAssistant = useExtensionsStore((state) => { + return state.aiOverviewAssistant; + }); + + const [expand, setExpand] = useState(true); + const [visible, setVisible] = useState(false); + + const { isTyping, chunkData, loadingStep } = useStreamChat({ + message, + clientId: "ai-overview-client-id", + server: aiOverviewServer, + assistant: aiOverviewAssistant, + setVisible, + }); + + return ( +
+
{ + setExpand(!expand); + }} + > + +
+ +
+ + AI Overview +
+ +
+
+ +
+
+ +
+
+ ); +}; + +export default AiOverview; diff --git a/src/components/Search/AiSummary.tsx b/src/components/Search/AiSummary.tsx deleted file mode 100644 index 63661111..00000000 --- a/src/components/Search/AiSummary.tsx +++ /dev/null @@ -1,49 +0,0 @@ -import { ChevronUp, Copy, SquareArrowOutUpRight, Volume2 } from "lucide-react"; -import { useState } from "react"; -import AiSummaryIcon from "../Common/Icons/AiSummaryIcon"; -import clsx from "clsx"; -import Markdown from "../ChatMessage/Markdown"; - -const AiSummary = () => { - const [expand, setExpand] = useState(true); - - return ( -
-
{ - setExpand(!expand); - }} - > - -
- -
- - AI Summarize -
- -
- -
- -
- - - - - -
-
- ); -}; - -export default AiSummary; diff --git a/src/components/Search/AskAi.tsx b/src/components/Search/AskAi.tsx index 5abaac60..08d842f6 100644 --- a/src/components/Search/AskAi.tsx +++ b/src/components/Search/AskAi.tsx @@ -9,7 +9,7 @@ import { useEffect, useRef, useState } from "react"; import { noop } from "lodash-es"; import { ChatMessage } from "../ChatMessage"; -import { ASK_AI_CLIENT_ID, COPY_BUTTON_ID } from "@/constants"; +import { COPY_BUTTON_ID } from "@/constants"; import { useSearchStore } from "@/stores/searchStore"; import platformAdapter from "@/utils/platformAdapter"; import useMessageChunkData from "@/hooks/useMessageChunkData"; @@ -75,6 +75,9 @@ const AskAi = () => { return state.setAskAiServerId; }); const state = useReactive({}); + const setAskAiAssistantId = useSearchStore((state) => { + return state.setAskAiAssistantId; + }); useEffect(() => { if (state.serverId) return; @@ -97,12 +100,10 @@ const AskAi = () => { useMount(async () => { try { unlisten.current = await platformAdapter.listenEvent( - ASK_AI_CLIENT_ID, + "quick-ai-access-client-id", ({ payload }) => { console.log("ask_ai", JSON.parse(payload)); - setIsTyping(true); - const chunkData = JSON.parse(payload); if (chunkData?._id) { @@ -115,6 +116,13 @@ const AskAi = () => { return; } + // If the chunk data does not contain a message_chunk, we ignore it + if (!chunkData.message_chunk) { + return; + } + + setIsTyping(true); + setLoadingStep(() => ({ query_intent: false, tools: false, @@ -164,15 +172,12 @@ const AskAi = () => { const { serverId, assistantId } = state; - console.log("serverId", serverId); - console.log("assistantId", assistantId); - try { await platformAdapter.invokeBackend("ask_ai", { message: askAiMessage, serverId, assistantId, - clientId: ASK_AI_CLIENT_ID, + clientId: "quick-ai-access-client-id", }); } catch (error) { addError(String(error)); @@ -184,7 +189,7 @@ const AskAi = () => { if (isTyping) return; - const { serverId } = state; + const { serverId, assistantId } = state; if ((isMac && metaKey) || (!isMac && ctrlKey)) { await platformAdapter.commands("open_session_chat", { @@ -195,7 +200,8 @@ const AskAi = () => { platformAdapter.emitEvent("toggle-to-chat-mode"); setAskAiServerId(serverId); - return setAskAiSessionId(sessionIdRef.current); + setAskAiSessionId(sessionIdRef.current); + return setAskAiAssistantId(assistantId); } const copyButton = document.getElementById(COPY_BUTTON_ID); diff --git a/src/components/Search/AssistantManager.tsx b/src/components/Search/AssistantManager.tsx index ad4833a8..439fad6f 100644 --- a/src/components/Search/AssistantManager.tsx +++ b/src/components/Search/AssistantManager.tsx @@ -38,16 +38,16 @@ export function useAssistantManager({ const [assistantDetail, setAssistantDetail] = useState({}); const assistant_get = useCallback(async () => { + if (!askAI?.id) return; if (isTauri) { + if (!askAI?.querySource?.id) return; const res = await platformAdapter.commands("assistant_get", { serverId: askAI?.querySource?.id, assistantId: askAI?.id, }); setAssistantDetail(res); } else { - const [error, res]: any = await Get(`/assistant/${askAI?.id}`, { - id: askAI?.id, - }); + const [error, res]: any = await Get(`/assistant/${askAI?.id}`); if (error) { console.error("assistant", error); return; @@ -57,6 +57,8 @@ export function useAssistantManager({ }, [askAI]); const handleAskAi = (event: React.KeyboardEvent) => { + if (!isTauri) return; + askAIRef.current = cloneDeep(askAI); if (!askAIRef.current) return; @@ -67,7 +69,6 @@ export function useAssistantManager({ if (!selectedAssistant && isEmpty(value)) return; - assistant_get(); changeInput(""); setAskAiMessage(!goAskAi && selectedAssistant ? "" : value); setGoAskAi(true); @@ -84,7 +85,9 @@ export function useAssistantManager({ return setGoAskAi(false); } - if (key === "Tab" && !isChatMode) { + if (key === "Tab" && !isChatMode && isTauri) { + assistant_get(); + return handleAskAi(e); } diff --git a/src/components/Search/AutoResizeTextarea.tsx b/src/components/Search/AutoResizeTextarea.tsx index 74042e7d..c87dc5a7 100644 --- a/src/components/Search/AutoResizeTextarea.tsx +++ b/src/components/Search/AutoResizeTextarea.tsx @@ -1,4 +1,4 @@ -import { useBoolean } from "ahooks"; +import { useBoolean, useDebounceFn } from "ahooks"; import { useRef, useImperativeHandle, @@ -8,6 +8,10 @@ import { } from "react"; import { useTranslation } from "react-i18next"; +const LINE_HEIGHT = 24; // 1.5rem +const MAX_FIRST_LINE_WIDTH = 470; // Width in pixels for first line +const MAX_HEIGHT = 240; // 15rem + interface AutoResizeTextareaProps { input: string; setInput: (value: string) => void; @@ -37,6 +41,77 @@ const AutoResizeTextarea = forwardRef< const textareaRef = useRef(null); const [isComposition, { setTrue, setFalse }] = useBoolean(); + // Memoize resize logic + const { run: debouncedResize } = useDebounceFn( + () => { + const textarea = textareaRef.current; + if (!textarea) return; + + // Reset height to auto to get the correct scrollHeight + textarea.style.height = "auto"; + + // Create a hidden span to measure first line width + const span = document.createElement("span"); + span.style.visibility = "hidden"; + span.style.position = "absolute"; + span.style.whiteSpace = "pre"; + span.style.font = window.getComputedStyle(textarea).font; + + // Get first line content + const content = textarea.value; + const firstLineEnd = + content.indexOf("\n") === -1 ? content.length : content.indexOf("\n"); + span.textContent = content.slice(0, firstLineEnd); + document.body.appendChild(span); + + // Calculate lines based on first line width + const firstLineWidth = span.offsetWidth; + document.body.removeChild(span); + + // Start with 1 line + let lines = 1; + + // Add a line if first line exceeds max width + if (firstLineWidth > MAX_FIRST_LINE_WIDTH) { + lines += 1; + } + + // Add lines based on scrollHeight for remaining content + const scrollHeight = textarea.scrollHeight; + const remainingLines = Math.floor( + (scrollHeight - LINE_HEIGHT) / LINE_HEIGHT + ); + lines += Math.max(0, remainingLines); + + // Calculate final height + const newHeight = Math.min(lines * LINE_HEIGHT, MAX_HEIGHT); + + // Only update if height actually changed + if (textarea.style.height !== `${newHeight}px`) { + textarea.style.height = `${newHeight}px`; + onLineCountChange?.(lines); + } + }, + { wait: 100 } + ); + + // Handle input changes and initial setup + useEffect(() => { + if (textareaRef.current) { + debouncedResize(); + } + }, [input, debouncedResize]); + + useEffect(() => { + if (textareaRef.current) { + requestAnimationFrame(() => { + // Set cursor position to end + const length = textareaRef.current?.value.length || 0; + textareaRef.current?.setSelectionRange(length, length); + }); + } + }, [lineCount]); + // Expose methods to the parent via ref useImperativeHandle(ref, () => ({ reset: () => { @@ -47,13 +122,6 @@ const AutoResizeTextarea = forwardRef< }, })); - useEffect(() => { - if (textareaRef.current) { - const length = textareaRef.current.value.length; - textareaRef.current.setSelectionRange(length, length); - } - }, [lineCount]); - const handleKeyPress = (event: KeyboardEvent) => { if (isComposition) { return event.stopPropagation(); @@ -62,18 +130,6 @@ const AutoResizeTextarea = forwardRef< handleKeyDown?.(event); }; - useEffect(() => { - if (textareaRef.current) { - textareaRef.current.style.height = "auto"; - const newHeight = Math.min(textareaRef.current.scrollHeight, 15 * 16); // 15rem ≈ 15 * 16px - textareaRef.current.style.height = `${newHeight}px`; - - const lineHeight = 24; // 1.5rem = 24px - const lineCount = Math.ceil(newHeight / lineHeight); - onLineCountChange?.(lineCount); - } - }, [input]); - return (