2 Commits

Author SHA1 Message Date
ayang
9ee6b9a6c9 feat: add file upload failure handling and alert message 2025-05-16 14:32:11 +08:00
ayang
24b1758b11 refactor: enabling the InputExtra component 2025-05-15 15:50:03 +08:00
357 changed files with 12257 additions and 34530 deletions

2
.env
View File

@@ -1,3 +1,5 @@
COCO_SERVER_URL=http://localhost:9000 #https://coco.infini.cloud #http://localhost:9000
COCO_WEBSOCKET_URL=ws://localhost:9000/ws #wss://coco.infini.cloud/ws #ws://localhost:9000/ws
#TAURI_DEV_HOST=0.0.0.0

View File

@@ -1,18 +0,0 @@
name: Enforce no dependency pizza-engine
on:
pull_request:
jobs:
main:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name:
working-directory: ./src-tauri
run: |
# if cargo remove pizza-engine succeeds, then it is in our dependency list, fail the CI pipeline.
if cargo remove pizza-engine; then exit 1; fi

View File

@@ -1,70 +0,0 @@
name: Frontend Code Check
on:
pull_request:
# Only run it when Frontend code changes
paths:
- 'src/**'
- 'tsup.config.ts'
- 'package.json'
jobs:
check:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}
fetch-depth: 0
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
# No need to pass the version arg as it is specified by "packageManager" in package.json
- name: Install pnpm
uses: pnpm/action-setup@v4
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Switch platformAdapter to Web adapter
shell: bash
run: >
node -e "const fs=require('fs');const f='src/utils/platformAdapter.ts';
let s=fs.readFileSync(f,'utf8');
s=s.replace(/import\\s*\\{\\s*createTauriAdapter\\s*\\}\\s*from\\s*\\\"\\.\\/tauriAdapter\\\";/,'import { createWebAdapter } from \\\"./webAdapter\\\";');
s=s.replace(/let\\s+platformAdapter\\s*=\\s*createTauriAdapter\\(\\);/,'let platformAdapter = createWebAdapter();');
fs.writeFileSync(f,s);"
- name: Build web (Tauri dependency check)
run: pnpm build:web
- name: Verify no Tauri refs in web output
shell: bash
run: |
if grep -R -n -E '@tauri-apps|tauri-plugin' out/search-chat; then
echo 'Tauri references found in web build output';
exit 1;
else
echo 'No Tauri references found';
fi
- name: Restore platformAdapter to Tauri adapter
shell: bash
run: >
node -e "const fs=require('fs');const f='src/utils/platformAdapter.ts';
let s=fs.readFileSync(f,'utf8');
s=s.replace(/import\\s*\\{\\s*createWebAdapter\\s*\\}\\s*from\\s*\\\"\\.\\/webAdapter\\\";/,'import { createTauriAdapter } from \\\"./tauriAdapter\\\";');
s=s.replace(/let\\s+platformAdapter\\s*=\\s*createWebAdapter\\(\\);/,'let platformAdapter = createTauriAdapter();');
fs.writeFileSync(f,s);"
- name: Build frontend
run: pnpm build

View File

@@ -9,16 +9,10 @@ on:
jobs:
create-release:
runs-on: ubuntu-latest
outputs:
APP_VERSION: ${{ steps.get-version.outputs.APP_VERSION }}
RELEASE_BODY: ${{ steps.get-changelog.outputs.RELEASE_BODY }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set output
id: vars
run: echo "tag=${GITHUB_REF#refs/*/}" >> $GITHUB_OUTPUT
@@ -28,28 +22,11 @@ jobs:
with:
node-version: 20
- name: Get build version
shell: bash
id: get-version
run: |
PACKAGE_VERSION=$(jq -r '.version' package.json)
CARGO_VERSION=$(grep -m 1 '^version =' src-tauri/Cargo.toml | sed -E 's/.*"([^"]+)".*/\1/')
if [ "$PACKAGE_VERSION" != "$CARGO_VERSION" ]; then
echo "::error::Version mismatch!"
else
echo "Version match: $PACKAGE_VERSION"
fi
echo "APP_VERSION=$PACKAGE_VERSION" >> $GITHUB_OUTPUT
- name: Generate changelog
id: get-changelog
run: |
CHANGELOG_BODY=$(npx changelogithub --draft --name ${{ steps.vars.outputs.tag }})
echo "RELEASE_BODY<<EOF" >> $GITHUB_OUTPUT
echo "$CHANGELOG_BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
id: create_release
run: npx changelogithub --draft --name ${{ steps.vars.outputs.tag }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
build-app:
needs: create-release
@@ -75,23 +52,11 @@ jobs:
target: "x86_64-unknown-linux-gnu"
- platform: "ubuntu-22.04-arm"
target: "aarch64-unknown-linux-gnu"
env:
APP_VERSION: ${{ needs.create-release.outputs.APP_VERSION }}
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Checkout dependency repository
uses: actions/checkout@v4
with:
repository: 'infinilabs/pizza'
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
submodules: recursive
ref: main
path: pizza
- name: Setup node
uses: actions/setup-node@v4
with:
@@ -100,41 +65,17 @@ jobs:
with:
version: latest
- name: Install rust target
run: rustup target add ${{ matrix.target }}
- name: Install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-22.04')
run: |
sudo apt-get update
sudo apt-get install -y libwebkit2gtk-4.1-dev libappindicator3-dev librsvg2-dev patchelf xdg-utils libtracker-sparql-3.0-dev
sudo apt-get install -y libwebkit2gtk-4.1-dev libappindicator3-dev librsvg2-dev patchelf xdg-utils
# On Windows, we need to generate bindings for 'searchapi.h' using bindgen.
# And bindgen relies on 'libclang'
# https://rust-lang.github.io/rust-bindgen/requirements.html#windows
#
# We don't need to install it because it is already included in GitHub
# Action runner image:
# https://github.com/actions/runner-images/blob/main/images/windows/Windows2025-Readme.md#language-and-runtime
- name: Add Rust build target
working-directory: src-tauri
shell: bash
run: |
rustup target add ${{ matrix.target }} || true
- name: Add pizza engine as a dependency
working-directory: src-tauri
shell: bash
run: |
BUILD_ARGS="--target ${{ matrix.target }}"
if [[ "${{matrix.target }}" != "i686-pc-windows-msvc" ]]; then
echo "Adding pizza engine as a dependency for ${{matrix.platform }}-${{matrix.target }}"
( cargo add --path ../pizza/lib/engine --features query_string_parser,persistence )
BUILD_ARGS+=" --features use_pizza_engine"
else
echo "Skipping pizza engine dependency for ${{matrix.platform }}-${{matrix.target }}"
fi
echo "BUILD_ARGS=${BUILD_ARGS}" >> $GITHUB_ENV
- name: Install Rust stable
run: rustup toolchain install stable
- name: Rust cache
uses: swatinem/rust-cache@v2
@@ -149,8 +90,8 @@ jobs:
- name: Install app dependencies and build web
run: pnpm install --frozen-lockfile
- name: Build the coco at ${{ matrix.platform}} for ${{ matrix.target }} @ ${{ env.APP_VERSION }}
- name: Build the app
uses: tauri-apps/tauri-action@v0
env:
CI: false
@@ -166,8 +107,8 @@ jobs:
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
with:
tagName: ${{ github.ref_name }}
releaseName: Coco ${{ env.APP_VERSION }}
releaseBody: "${{ needs.create-release.outputs.RELEASE_BODY }}"
releaseName: Coco ${{ needs.create-release.outputs.APP_VERSION }}
releaseBody: ""
releaseDraft: true
prerelease: false
args: ${{ env.BUILD_ARGS }}
args: --target ${{ matrix.target }}

View File

@@ -1,69 +0,0 @@
name: Rust Code Check
on:
pull_request:
# Only run it when Rust code changes
paths:
- 'src-tauri/**'
jobs:
check:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Checkout dependency (pizza-engine) repository
uses: actions/checkout@v4
with:
repository: 'infinilabs/pizza'
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
submodules: recursive
ref: main
path: pizza
- name: Install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-latest')
run: |
sudo apt-get update
sudo apt-get install -y libwebkit2gtk-4.1-dev libappindicator3-dev librsvg2-dev patchelf xdg-utils libtracker-sparql-3.0-dev
# On Windows, we need to generate bindings for 'searchapi.h' using bindgen.
# And bindgen relies on 'libclang'
# https://rust-lang.github.io/rust-bindgen/requirements.html#windows
#
# We don't need to install it because it is already included in GitHub
# Action runner image:
# https://github.com/actions/runner-images/blob/main/images/windows/Windows2025-Readme.md#language-and-runtime
- name: Add pizza engine as a dependency
working-directory: src-tauri
shell: bash
run: cargo add --path ../pizza/lib/engine --features query_string_parser,persistence
- name: Format check
working-directory: src-tauri
shell: bash
run: |
rustup component add rustfmt
cargo fmt --all --check
- name: Check compilation (Without Pizza engine enabled)
working-directory: ./src-tauri
run: cargo check
- name: Check compilation (With Pizza engine enabled)
working-directory: ./src-tauri
run: cargo check --features use_pizza_engine
- name: Run tests (Without Pizza engine)
working-directory: ./src-tauri
run: cargo test
- name: Run tests (With Pizza engine)
working-directory: ./src-tauri
run: cargo test --features use_pizza_engine

11
.vscode/settings.json vendored
View File

@@ -8,14 +8,11 @@
"clsx",
"codegen",
"dataurl",
"deeplink",
"deepthink",
"dtolnay",
"dyld",
"elif",
"errmsg",
"fullscreen",
"fulltext",
"headlessui",
"Icdbb",
"icns",
@@ -32,8 +29,6 @@
"localstorage",
"lucide",
"maximizable",
"mdast",
"meval",
"Minimizable",
"msvc",
"nord",
@@ -43,11 +38,9 @@
"overscan",
"partialize",
"patchelf",
"Quicklink",
"Raycast",
"rehype",
"reqwest",
"rerank",
"rgba",
"rustup",
"screenshotable",
@@ -62,7 +55,6 @@
"traptitech",
"unlisten",
"unlistener",
"unlisteners",
"unminimize",
"uuidv",
"VITE",
@@ -83,6 +75,5 @@
"i18n-ally.keystyle": "nested",
"editor.tabSize": 2,
"editor.insertSpaces": true,
"editor.detectIndentation": false,
"i18n-ally.displayLanguage": "zh"
"editor.detectIndentation": false
}

View File

@@ -78,8 +78,4 @@ clean-rebuild:
$(MAKE) dev-build
add-dep-pizza-engine:
cd src-tauri && cargo add --git ssh://git@github.com/infinilabs/pizza.git pizza-engine --features query_string_parser,persistence
dev-build-with-pizza: add-dep-pizza-engine
@echo "Starting desktop development with Pizza Engine pulled in..."
RUST_BACKTRACE=1 pnpm tauri dev --features use_pizza_engine
cd src-tauri && cargo add --git ssh://git@github.com/infinilabs/pizza.git pizza-engine --features query_string_parser,persistence

View File

@@ -64,9 +64,9 @@ At Coco AI, we aim to streamline workplace collaboration by centralizing access
### Prerequisites
- [Node.js >= 18.12](https://nodejs.org/en/download/)
- [Rust (latest stable)](https://www.rust-lang.org/tools/install)
- [pnpm (package manager)](https://pnpm.io/installation)
- Node.js >= 18.12
- Rust (latest stable)
- pnpm (package manager)
### Development Setup
@@ -91,8 +91,6 @@ pnpm tauri build
- [Coco App Documentation](https://docs.infinilabs.com/coco-app/main/)
- [Coco Server Documentation](https://docs.infinilabs.com/coco-server/main/)
- [DeepWiki Coco App](https://deepwiki.com/infinilabs/coco-app)
- [DeepWiki Coco Server](https://deepwiki.com/infinilabs/coco-server)
- [Tauri Documentation](https://tauri.app/)
## Contributors

View File

@@ -1,22 +0,0 @@
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "new-york",
"rsc": false,
"tsx": true,
"tailwind": {
"config": "tailwind.config.js",
"css": "src/main.css",
"baseColor": "neutral",
"cssVariables": true,
"prefix": ""
},
"iconLibrary": "lucide",
"aliases": {
"components": "@/components",
"utils": "@/lib/utils",
"ui": "@/components/ui",
"lib": "@/lib",
"hooks": "@/hooks"
},
"registries": {}
}

View File

@@ -9,7 +9,7 @@ Coco AI is a fully open-source, cross-platform unified search and productivity t
{{% load-img "/img/coco-preview.gif" "" %}}
For more details on Coco Server, visit: [https://docs.infinilabs.com/coco-server/](https://docs.infinilabs.com/coco-server/).
For more details on Coco Server, visit: [https://docs.infinilabs.com/coco-app/](https://docs.infinilabs.com/coco-app/).
## Community

View File

@@ -1,35 +1,21 @@
---
weight: 10
title: "macOS"
title: "Mac OS"
asciinema: true
---
# macOS
# Mac OS
## Download Coco AI
Go to [coco.rs](https://coco.rs/) and download the package of your architecture:
Goto [https://coco.rs/](https://coco.rs/)
{{% load-img "/img/macos/mac-download-app.png" "" %}}
It should be placed in your "Downloads" folder:
{{% load-img "/img/macos/mac-zip-file.png" "" %}}
{{% load-img "/img/download-mac-app.png" "" %}}
## Unzip DMG file
Unzip the file:
{{% load-img "/img/macos/mac-unzip-zip-file.png" "" %}}
You will get a `dmg` file:
{{% load-img "/img/macos/mac-dmg.png" "" %}}
{{% load-img "/img/unzip-dmg-file.png" "" %}}
## Drag to Application Folder
Double click the `dmg` file, a window will pop up. Then drag the "Coco-AI" app to
your "Applications" folder:
{{% load-img "/img/macos/drag-to-app-folder.png" "" %}}
{{% load-img "/img/drag-to-application-folder.png" "" %}}

View File

@@ -13,16 +13,8 @@ asciinema: true
[x11_protocol]: https://en.wikipedia.org/wiki/X_Window_System
[if_x11]: https://unix.stackexchange.com/q/202891/498440
## Install dependencies
```sh
$ sudo apt-get update
$ sudo apt-get install -y libwebkit2gtk-4.1-dev libappindicator3-dev librsvg2-dev patchelf xdg-utils libtracker-sparql-3.0-dev
```
## Go to the download page
Download page: [link](https://coco.rs/#install)
## Goto [https://coco.rs/](https://coco.rs/)
## Download the package

View File

@@ -5,7 +5,7 @@ title: "Release Notes"
# Release Notes
Information about release notes of Coco App is provided here.
Information about release notes of Coco Server is provided here.
## Latest (In development)
@@ -13,224 +13,6 @@ Information about release notes of Coco App is provided here.
### 🚀 Features
### 🐛 Bug fix
- fix: search_extension should not panic when ext is not found #983
- fix: persist configuration settings properly #987
### ✈️ Improvements
## 0.9.0 (2025-11-19)
### ❌ Breaking changes
### 🚀 Features
- feat: support switching groups via keyboard shortcuts #911
- feat: support opening logs from about page #915
- feat: support moving cursor with home and end keys #918
- feat: support pageup/pagedown to navigate search results #920
- feat: standardize multi-level menu label structure #925
- feat(View Extension): page field now accepts HTTP(s) links #925
- feat: return sub-exts when extension type exts themselves are matched #928
- feat: open quick ai with modifier key + enter #939
- feat: allow navigate back when cursor is at the beginning #940
- feat(extension compatibility): minimum_coco_version #946
- feat: add compact mode for window #947
- feat: advanced settings search debounce & local query source weight #950
- feat: add window opacity configuration option #963
- feat: add auto collapse delay for compact mode #981
### 🐛 Bug fix
- fix: automatic update of service list #913
- fix: duplicate chat content #916
- fix: resolve pinned window shortcut not working #917
- fix: WM ext does not work when operating focused win from another display #919
- fix(Window Management): Next/Previous Desktop do not work #926
- fix: fix page rapidly flickering issue #935
- fix(view extension): broken search bar UI when opening extensions via hotkey #938
- fix: allow deletion after selecting all text #943
- fix: prevent shaking when switching between chat and search pages #955
- fix: prevent duplicate login success messages #977
- fix: fix quick ai not continuing conversation #979
### ✈️ Improvements
- refactor: improve sorting logic of search results #910
- style: add dark drop shadow to images #912
- chore: add cross-domain configuration for web component #921
- refactor: retry if AXUIElementSetAttributeValue() does not work #924
- refactor(calculator): skip evaluation if expr is in form "num => num" #929
- chore: use a custom log directory #930
- chore: bump tauri_nspanel to v2.1 #933
- refactor: show_coco/hide_coco now use NSPanel's function on macOS #933
- refactor: procedure that convert_pages() into a func #934
- refactor(post-search): collect at least 2 documents from each query source #948
- refactor: custom_version_comparator() now compares semantic versions #941
- chore: center the main window vertically #959
- refactor(view extension): load HTML/resources via local HTTP server #973
## 0.8.0 (2025-09-28)
### ❌ Breaking changes
- chore: update request accesstoken api #866
### 🚀 Features
- feat: enhance ui for skipped version #834
- feat: support installing local extensions #749
- feat: support sending files in chat messages #764
- feat: sub extension can set 'platforms' now #847
- feat: add extension uninstall option in settings #855
- feat: impl extension settings 'hide_before_open' #862
- feat: index both en/zh_CN app names and show app name in chosen language #875
- feat: support context menu in debug mode #882
- feat: file search for Linux/GNOME #884
- feat: file search for Linux/KDE #886
- feat: extension Window Management for macOS #892
- feat: new extension type View #894
- feat: support opening file in its containing folder #900
### 🐛 Bug fix
- fix: fix issue with update check failure #833
- fix: web component login state #857
- fix: shortcut key not opening extension store #877
- fix: set up hotkey on main thread or Windows will complain #879
- fix: resolve deeplink login issue #881
- fix: use kill_on_drop() to avoid zombie proc in error case #887
- fix: settings window rendering/loading issue 889
- fix: ensure search paths are indexed #896
- fix: bump applications-rs to fix empty app name issue #898
### ✈️ Improvements
- refactor: calling service related interfaces #831
- refactor: split query_coco_fusion() #836
- chore: web component loading font icon #838
- chore: delete unused code files and dependencies #841
- chore: ignore tauri::AppHandle's generic argument R #845
- refactor: check Extension/plugin.json from all sources #846
- refactor: pinning window won't set CanJoinAllSpaces on macOS #854
- build: web component build error #858
- refactor: coordinate third-party extension operations using lock #867
- refactor: index iOS apps and macOS apps that store icon in Assets.car #872
- refactor: accept both '-' and '\_' as locale str separator #876
- refactor: relax the file search conditions on macOS #883
- refactor: ensure Coco won't take focus #891
- chore: skip login check for web widget #895
- chore: convertFileSrc() "link[href]" and "img[src]" #901
## 0.7.1 (2025-07-27)
### ❌ Breaking changes
### 🚀 Features
### 🐛 Bug fix
- fix: correct enter key behavior #828
### ✈️ Improvements
- chore: web component add notification component #825
- refactor: collection behavior defaults to `MoveToActiveSpace`, and only use `CanJoinAllSpaces` when window is pinned #829
## 0.7.0 (2025-07-25)
### ❌ Breaking changes
### 🚀 Features
- feat: file search using spotlight #705
- feat: voice input support in both search and chat modes #732
- feat: text to speech now powered by LLM #750
- feat: file search for Windows #762
### 🐛 Bug fix
- fix(file search): apply filters before from/size parameters #741
- fix(file search): searching by name&content does not search file name #743
- fix: prevent window from hiding when moved on Windows #748
- fix: unregister ext hotkey when it gets deleted #770
- fix: indexing apps does not respect search scope config #773
- fix: restore missing category titles on subpages #772
- fix: correct incorrect assistant display when quick ai access #779
- fix: resolved minor issues with voice playback #780
- fix: fixed incorrect taskbar icon display on linux #783
- fix: fix data inconsistency issue on secondary pages #784
- fix: incorrect status when installing extension #789
- fix: increase read_timeout for HTTP streaming stability #798
- fix: enter key problem #794
- fix: fix selection issue after renaming #800
- fix: fix shortcut issue in windows context menu #804
- fix: panic caused by "state() called before manage()" #806
- fix: fix multiline input issue #808
- fix: fix ctrl+k not working #815
- fix: fix update window config sync #818
- fix: fix enter key on subpages #819
- fix: panic on Ubuntu (GNOME) when opening apps #821
### ✈️ Improvements
- refactor: prioritize stat(2) when checking if a file is dir #737
- refactor: change File Search ext type to extension #738
- refactor: create chat & send chat api #739
- chore: icon support for more file types #740
- chore: replace meval-rs with our fork to clear dep warning #745
- refactor: adjusted assistant, datasource, mcp_server interface parameters #746
- refactor: adjust extension code hierarchy #747
- chore: bump dep applications-rs #751
- chore: rename QuickLink/quick_link to Quicklink/quicklink #752
- chore: assistant params & styles #753
- chore: make optional fields optional #758
- chore: search-chat components add formatUrl & think data & icons url #765
- chore: Coco app http request headers #744
- refactor: do status code check before deserializing response #767
- style: splash adapts to the width of mobile phones #768
- chore: search-chat add language and formatUrl parameters #775
- chore: not request the interface if not logged in #795
- refactor: clean up unsupported characters from query string in Win Search #802
- chore: display backtrace in panic log #805
## 0.6.0 (2025-06-29)
### ❌ Breaking changes
### 🚀 Features
- feat: support `Tab` and `Enter` for delete dialog buttons #700
- feat: add check for updates #701
- feat: impl extension store #699
- feat: support back navigation via delete key #717
### 🐛 Bug fix
- fix: quick ai state synchronous #693
- fix: toggle extension should register/unregister hotkey #691
- fix: take coco server back on refresh #696
- fix: some input fields couldnt accept spaces #709
- fix: context menu search not working #713
- fix: open extension store display #724
### ✈️ Improvements
- refactor: use author/ext_id as extension unique identifier #643
- refactor: refactoring search api #679
- chore: continue to chat page display #690
- chore: improve server list selection with enter key #692
- chore: add message for latest version check #703
- chore: log command execution results #718
- chore: adjust styles and add button reindex #719
## 0.5.0 (2025-06-13)
### ❌ Breaking changes
### 🚀 Features
- feat: check or enter to close the list of assistants #469
- feat: add dimness settings for pinned window #470
- feat: supports Shift + Enter input box line feeds #472
@@ -242,59 +24,17 @@ Information about release notes of Coco App is provided here.
- feat: the search input box supports multi-line input #501
- feat: websocket support self-signed TLS #504
- feat: add option to allow self-signed certificates #509
- feat: add AI summary component #518
- feat: dynamic log level via env var COCO_LOG #535
- feat: add quick AI access to search mode #556
- feat: rerank search results #561
- feat: ai overview support is enabled with shortcut #597
- feat: add key monitoring during reset #615
- feat: calculator extension add description #623
- feat: support right-click actions after text selection #624
- feat: add ai overview minimum number of search results configuration #625
- feat: add internationalized translations of AI-related extensions #632
- feat: context menu support for secondary pages #680
### 🐛 Bug fix
- fix: solve the problem of modifying the assistant in the chat #476
- fix: several issues around search #502
- fix: fixed the newly created session has no title when it is deleted #511
- fix: loading chat history for potential empty attachments
- fix: datasource & MCP list synchronization update #521
- fix: app icon & category icon #529
- fix: show only enabled datasource & MCP list
- fix: server image loading failure #534
- fix: panic when fetching app metadata on Windows #538
- fix: service switching error #539
- fix: switch server assistant and session unchanged #540
- fix: history list height #550
- fix: secondary page cannot be searched #551
- fix: the scroll button is not displayed by default #552
- fix: suggestion list position #553
- fix: independent chat window has no data #554
- fix: resolved navigation error on continue chat action #558
- fix: make extension search source respect parameter datasource #576
- fix: fixed issue with incorrect login status #600
- fix: new chat assistant id not found #603
- fix: resolve regex error on older macOS versions #605
- fix: fix chat log update and sorting issues #612
- fix: resolved an issue where number keys were not working on the web #616
- fix: do not panic when the datasource specified does not exist #618
- fix: fixed modifier keys not working with continue chat #619
- fix: invalid DSL error if input contains multiple lines #620
- fix: fix ai overview hidden height before message #622
- fix: tab key hides window in chat mode #641
- fix: arrow keys still navigated search when menu opened with Cmd+K #642
- fix: input lost when reopening dialog after search #644
- fix: web page unmount event #645
- fix: fix the problem of local path not opening #650
- fix: number keys not following settings #661
- fix: fix problem with up and down key indexing #676
- fix: arrow inserting escape sequences #683
### ✈️ Improvements
- chore: adjust list error message #475
- fix: solve the problem of modifying the assistant in the chat #476
- chore: refine wording on search failure
- choresearch and MCP show hidden logic #494
- chore: greetings show hidden logic #496
@@ -305,32 +45,6 @@ Information about release notes of Coco App is provided here.
- refactor: optimized the modification operation of the numeric input box #508
- style: modify the style of the search input box #513
- style: chat input icons show #515
- refactor: refactoring icon component #514
- refactor: optimizing list styles in markdown content #520
- feat: add a component for text reading aloud #522
- style: history component styles #528
- style: search error styles #533
- chore: skip register server that not logged in #536
- refactor: service info related components #537
- chore: chat content can be copied #539
- refactor: refactoring search error #541
- chore: add assistant count #542
- chore: add global login judgment #544
- chore: mark server offline on user logout #546
- chore: logout update server profile #549
- chore: assistant keyboard events and mouse events #559
- chore: web component start page config #560
- chore: assistant chat placeholder & refactor input box components #566
- refactor: input box related components #568
- chore: mark unavailable server to offline on refresh info #569
- chore: only show available servers in chat #570
- refactor: search result related components #571
- chore: initialize current assistant from history #606
- chore: add onContextMenu event #629
- chore: more logs for the setup process #634
- chore: copy supports http protocol #639
- refactor: use author/ext_id as extension unique identifier #643
- chore: add special character filtering #668
## 0.4.0 (2025-04-27)
@@ -360,8 +74,6 @@ Information about release notes of Coco App is provided here.
- feat: data sources support displaying customized icons #432
- feat: add shortcut key conflict hint and reset function #442
- feat: updated to include error message #465
- feat: support third party extensions #572
- feat: support ai overview #572
### Bug fix

BIN
docs/static/img/download-mac-app.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 239 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 586 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 299 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 650 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 441 KiB

BIN
docs/static/img/unzip-dmg-file.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

View File

@@ -1,7 +1,7 @@
{
"name": "coco",
"private": true,
"version": "0.9.0",
"version": "0.4.0",
"type": "module",
"scripts": {
"dev": "vite",
@@ -18,8 +18,8 @@
"release-beta": "release-it --preRelease=beta --preReleaseBase=1"
},
"dependencies": {
"@ant-design/icons": "^6.0.0",
"@headlessui/react": "^2.2.2",
"@radix-ui/react-slot": "^1.2.3",
"@tauri-apps/api": "^2.5.0",
"@tauri-apps/plugin-autostart": "~2.2.0",
"@tauri-apps/plugin-deep-link": "^2.2.1",
@@ -27,16 +27,15 @@
"@tauri-apps/plugin-global-shortcut": "~2.0.0",
"@tauri-apps/plugin-http": "~2.0.2",
"@tauri-apps/plugin-log": "~2.4.0",
"@tauri-apps/plugin-opener": "^2.5.0",
"@tauri-apps/plugin-os": "^2.2.1",
"@tauri-apps/plugin-process": "^2.2.1",
"@tauri-apps/plugin-shell": "^2.2.1",
"@tauri-apps/plugin-updater": "github:infinilabs/tauri-plugin-updater#v2",
"@tauri-apps/plugin-websocket": "~2.3.0",
"@tauri-apps/plugin-window": "2.0.0-alpha.1",
"@wavesurfer/react": "^1.0.11",
"ahooks": "^3.8.4",
"axios": "^1.12.0",
"class-variance-authority": "^0.7.1",
"axios": "^1.9.0",
"clsx": "^2.1.1",
"dayjs": "^1.11.13",
"dotenv": "^16.5.0",
@@ -45,7 +44,6 @@
"i18next-browser-languagedetector": "^8.1.0",
"lodash-es": "^4.17.21",
"lucide-react": "^0.461.0",
"mdast-util-gfm-autolink-literal": "2.0.0",
"mermaid": "^11.6.0",
"nanoid": "^5.1.5",
"react": "^18.3.1",
@@ -60,13 +58,10 @@
"remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.1",
"remark-math": "^6.0.0",
"tailwind-merge": "^3.3.1",
"tailwindcss-animate": "^1.0.7",
"tauri-plugin-fs-pro-api": "^2.4.0",
"tauri-plugin-macos-permissions-api": "^2.3.0",
"tauri-plugin-screenshots-api": "^2.2.0",
"tauri-plugin-windows-version-api": "^2.0.0",
"type-fest": "^4.41.0",
"use-debounce": "^10.0.4",
"uuid": "^11.1.0",
"wavesurfer.js": "^7.9.5",
@@ -94,6 +89,5 @@
"tsx": "^4.19.4",
"typescript": "^5.8.3",
"vite": "^5.4.19"
},
"packageManager": "pnpm@10.11.0+sha512.6540583f41cc5f628eb3d9773ecee802f4f9ef9923cc45b69890fb47991d4b092964694ec3a4f738a420c918a333062c8b925d312f42e4f0c263eb603551f977"
}
}

1207
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
(() => {})();

3777
src-tauri/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,9 @@
[package]
name = "coco"
version = "0.9.0"
version = "0.4.0"
description = "Search, connect, collaborate all in one place."
authors = ["INFINI Labs"]
edition = "2024"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
@@ -15,7 +15,6 @@ crate-type = ["staticlib", "cdylib", "rlib"]
[build-dependencies]
tauri-build = { version = "2", features = ["default"] }
cfg-if = "1.0.1"
[features]
default = ["desktop"]
@@ -45,13 +44,14 @@ use_pizza_engine = []
[dependencies]
pizza-common = { git = "https://github.com/infinilabs/pizza-common", branch = "main" }
tauri = { version = "2", features = ["protocol-asset", "macos-private-api", "tray-icon", "image-ico", "image-png"] }
tauri = { version = "2", features = ["protocol-asset", "macos-private-api", "tray-icon", "image-ico", "image-png", "unstable"] }
tauri-plugin-shell = "2"
serde = { version = "1", features = ["derive"] }
# Need `arbitrary_precision` feature to support storing u128
# see: https://docs.rs/serde_json/latest/serde_json/struct.Number.html#method.from_u128
serde_json = { version = "1", features = ["arbitrary_precision", "preserve_order"] }
serde_json = { version = "1", features = ["arbitrary_precision"] }
tauri-plugin-http = "2"
tauri-plugin-websocket = "2"
tauri-plugin-deep-link = "2.0.0"
tauri-plugin-store = "2.2.0"
tauri-plugin-os = "2"
@@ -62,7 +62,7 @@ tauri-plugin-drag = "2"
tauri-plugin-macos-permissions = "2"
tauri-plugin-fs-pro = "2"
tauri-plugin-screenshots = "2"
applications = { git = "https://github.com/infinilabs/applications-rs", rev = "b5fac4034a40d42e72f727f1aa1cc1f19fe86653" }
applications = { git = "https://github.com/infinilabs/applications-rs", rev = "7bb507e6b12f73c96f3a52f0578d0246a689f381" }
tokio-native-tls = "0.3" # For wss connections
tokio = { version = "1", features = ["full"] }
tokio-tungstenite = { version = "0.20", features = ["native-tls"] }
@@ -81,67 +81,24 @@ plist = "1.7"
base64 = "0.13"
walkdir = "2"
log = "0.4"
strsim = "0.10"
futures-util = "0.3.31"
url = "2.5.2"
http = "1.1.0"
tungstenite = "0.24.0"
tokio-util = "0.7.14"
tauri-plugin-windows-version = "2"
meval = { git = "https://github.com/infinilabs/meval-rs" }
meval = "0.2"
chinese-number = "0.7"
num2words = "1"
tauri-plugin-log = "2"
chrono = "0.4.41"
serde_plain = "1.0.2"
derive_more = { version = "2.0.1", features = ["display"] }
anyhow = "1.0.98"
function_name = "0.3.0"
regex = "1.11.1"
borrowme = "0.0.15"
tauri-plugin-opener = "2"
async-recursion = "1.1.1"
zip = "4.0.0"
url = "2.5.2"
camino = "1.1.10"
tokio-stream = { version = "0.1.17", features = ["io-util"] }
sysinfo = "0.35.2"
indexmap = { version = "2.10.0", features = ["serde"] }
strum = { version = "0.27.2", features = ["derive"] }
sys-locale = "0.3.2"
tauri-plugin-prevent-default = "1"
oneshot = "0.1.11"
bitflags = "2.9.3"
cfg-if = "1.0.1"
dunce = "1.0.5"
urlencoding = "2.1.3"
scraper = "0.17"
toml = "0.8"
path-clean = "1.0.1"
actix-files = "0.6.8"
actix-web = "4.11.0"
[dev-dependencies]
tempfile = "3.23.0"
[target."cfg(target_os = \"macos\")".dependencies]
tauri-nspanel = { git = "https://github.com/ahkohd/tauri-nspanel", branch = "v2.1" }
objc2-app-kit = { version = "0.3.1", features = ["NSWindow"] }
objc2 = "0.6.2"
objc2-core-foundation = {version = "0.3.1", features = ["CFString", "CFCGTypes", "CFArray"] }
objc2-application-services = { version = "0.3.1", features = ["HIServices"] }
objc2-core-graphics = { version = "=0.3.1", features = ["CGEvent"] }
[target."cfg(target_os = \"linux\")".dependencies]
gio = "0.21.2"
glib = "0.21.2"
tracker-rs = "0.7"
which = "8.0.0"
configparser = "3.1.0"
tauri-nspanel = { git = "https://github.com/ahkohd/tauri-nspanel", branch = "v2" }
[target."cfg(any(target_os = \"macos\", windows, target_os = \"linux\"))".dependencies]
tauri-plugin-single-instance = { version = "2.0.0", features = ["deep-link"] }
serde = { version = "1.0.219", features = ["derive"], optional = true }
[profile.dev]
incremental = true # Compile your binary in smaller steps.
@@ -157,13 +114,6 @@ strip = true # Ensures debug symbols are removed.
tauri-plugin-autostart = "^2.2"
tauri-plugin-global-shortcut = "2"
tauri-plugin-updater = { git = "https://github.com/infinilabs/plugins-workspace", branch = "v2" }
# This should be compatible with the semver used by `tauri-plugin-updater`
semver = { version = "1", features = ["serde"] }
[target."cfg(target_os = \"windows\")".dependencies]
enigo="0.3"
windows = { version = "0.61", features = ["Win32_Foundation", "Win32_System_Com", "Win32_System_Ole", "Win32_System_Search", "Win32_UI_Shell_PropertiesSystem", "Win32_Data"] }
windows-sys = { version = "0.61", features = ["Win32", "Win32_System", "Win32_System_Com"] }
[target."cfg(target_os = \"windows\")".build-dependencies]
bindgen = "0.72.1"

View File

@@ -1,42 +1,3 @@
fn main() {
tauri_build::build();
// If env var `GITHUB_ACTIONS` exists, we are running in CI, set up the `ci`
// attribute
if std::env::var("GITHUB_ACTIONS").is_ok() {
println!("cargo:rustc-cfg=ci");
}
// Notify `rustc` of this `cfg` attribute to suppress unknown attribute warnings.
//
// unexpected condition name: `ci`
println!("cargo::rustc-check-cfg=cfg(ci)");
// Bindgen searchapi.h on Windows as the windows create does not provide
// bindings for it
cfg_if::cfg_if! {
if #[cfg(target_os = "windows")] {
use std::env;
use std::path::PathBuf;
let wrapper_header = r#"#include <windows.h>
#include <searchapi.h>"#;
let searchapi_bindings = bindgen::Builder::default()
.header_contents("wrapper.h", wrapper_header)
.generate()
.expect("failed to generate bindings for <searchapi.h>");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
searchapi_bindings
.write_to_file(out_path.join("searchapi_bindings.rs"))
.expect("couldn't write bindings to <OUT_DIR/searchapi_bindings.rs>")
// Looks like there is no need to link the library that contains the
// implementation of functions declared in 'searchapi.h' manually as
// the FFI bindings work (without doing that).
//
// This is wield, I do not expect the linker will link it automatically.
}
}
tauri_build::build()
}

View File

@@ -2,7 +2,7 @@
"$schema": "../gen/schemas/desktop-schema.json",
"identifier": "default",
"description": "Capability for the main window",
"windows": ["main", "chat", "settings", "check"],
"windows": ["main", "chat", "settings"],
"permissions": [
"core:default",
"core:event:allow-emit",
@@ -37,6 +37,9 @@
"http:allow-fetch-cancel",
"http:allow-fetch-read-body",
"http:allow-fetch-send",
"websocket:default",
"websocket:allow-connect",
"websocket:allow-send",
"autostart:allow-enable",
"autostart:allow-disable",
"autostart:allow-is-enabled",
@@ -68,8 +71,6 @@
"process:default",
"updater:default",
"windows-version:default",
"log:default",
"opener:default",
"core:window:allow-unminimize"
"log:default"
]
}

View File

@@ -1,2 +1,2 @@
[toolchain]
channel = "nightly-2025-06-26"
channel = "nightly-2024-10-29"

View File

@@ -1,34 +1,30 @@
use crate::common;
use crate::common::assistant::ChatRequestMessage;
use crate::common::http::convert_query_params_to_strings;
use crate::common::register::SearchSourceRegistry;
use crate::common::http::GetResponse;
use crate::server::http_client::HttpClient;
use crate::{common, server::servers::COCO_SERVERS};
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use futures_util::TryStreamExt;
use http::Method;
use serde_json::Value;
use std::collections::HashMap;
use tauri::{AppHandle, Emitter, Manager};
use tokio::io::AsyncBufReadExt;
use tauri::{AppHandle, Runtime};
#[tauri::command]
pub async fn chat_history(
_app_handle: AppHandle,
pub async fn chat_history<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
from: u32,
size: u32,
query: Option<String>,
) -> Result<String, String> {
let mut query_params = Vec::new();
// Add from/size as number values
query_params.push(format!("from={}", from));
query_params.push(format!("size={}", size));
let mut query_params: HashMap<String, Value> = HashMap::new();
if from > 0 {
query_params.insert("from".to_string(), from.into());
}
if size > 0 {
query_params.insert("size".to_string(), size.into());
}
if let Some(query) = query {
if !query.is_empty() {
query_params.push(format!("query={}", query.to_string()));
query_params.insert("query".to_string(), query.into());
}
}
@@ -43,18 +39,20 @@ pub async fn chat_history(
}
#[tauri::command]
pub async fn session_chat_history(
_app_handle: AppHandle,
pub async fn session_chat_history<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
session_id: String,
from: u32,
size: u32,
) -> Result<String, String> {
let mut query_params = Vec::new();
// Add from/size as number values
query_params.push(format!("from={}", from));
query_params.push(format!("size={}", size));
let mut query_params: HashMap<String, Value> = HashMap::new();
if from > 0 {
query_params.insert("from".to_string(), from.into());
}
if size > 0 {
query_params.insert("size".to_string(), size.into());
}
let path = format!("/chat/{}/_history", session_id);
@@ -66,14 +64,15 @@ pub async fn session_chat_history(
}
#[tauri::command]
pub async fn open_session_chat(
_app_handle: AppHandle,
pub async fn open_session_chat<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
session_id: String,
) -> Result<String, String> {
let query_params = HashMap::new();
let path = format!("/chat/{}/_open", session_id);
let response = HttpClient::post(&server_id, path.as_str(), None, None)
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error open session: {}", e))?;
@@ -81,30 +80,30 @@ pub async fn open_session_chat(
}
#[tauri::command]
pub async fn close_session_chat(
_app_handle: AppHandle,
pub async fn close_session_chat<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
session_id: String,
) -> Result<String, String> {
let query_params = HashMap::new();
let path = format!("/chat/{}/_close", session_id);
let response = HttpClient::post(&server_id, path.as_str(), None, None)
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error close session: {}", e))?;
common::http::get_response_body_text(response).await
}
#[tauri::command]
pub async fn cancel_session_chat(
_app_handle: AppHandle,
pub async fn cancel_session_chat<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
session_id: String,
query_params: Option<HashMap<String, Value>>,
) -> Result<String, String> {
let query_params = HashMap::new();
let path = format!("/chat/{}/_cancel", session_id);
let query_params = convert_query_params_to_strings(query_params);
let response = HttpClient::post(&server_id, path.as_str(), query_params, None)
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
@@ -112,161 +111,75 @@ pub async fn cancel_session_chat(
}
#[tauri::command]
pub async fn chat_create(
app_handle: AppHandle,
pub async fn new_chat<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
message: Option<String>,
attachments: Option<Vec<String>>,
websocket_id: String,
message: String,
query_params: Option<HashMap<String, Value>>,
client_id: String,
) -> Result<(), String> {
println!("chat_create message: {:?}", message);
println!("chat_create attachments: {:?}", attachments);
let message_empty = message.as_ref().map_or(true, |m| m.is_empty());
let attachments_empty = attachments.as_ref().map_or(true, |a| a.is_empty());
if message_empty && attachments_empty {
return Err("Message and attachments are empty".to_string());
}
let body = {
let request_message: ChatRequestMessage = ChatRequestMessage {
message,
attachments,
) -> Result<GetResponse, String> {
let body = if !message.is_empty() {
let message = ChatRequestMessage {
message: Some(message),
};
println!("chat_create body: {:?}", request_message);
Some(
serde_json::to_string(&request_message)
serde_json::to_string(&message)
.map_err(|e| format!("Failed to serialize message: {}", e))?
.into(),
)
} else {
None
};
let response = HttpClient::advanced_post(
&server_id,
"/chat/_create",
None,
convert_query_params_to_strings(query_params),
body,
)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
let mut headers = HashMap::new();
headers.insert("WEBSOCKET-SESSION-ID".to_string(), websocket_id.into());
if response.status() == 429 {
log::warn!("Rate limit exceeded for chat create");
return Err("Rate limited".to_string());
let response =
HttpClient::advanced_post(&server_id, "/chat/_new", Some(headers), query_params, body)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
let body_text = common::http::get_response_body_text(response).await?;
let chat_response: GetResponse =
serde_json::from_str(&body_text).map_err(|e| format!("Failed to parse response JSON: {}", e))?;
if chat_response.result != "created" {
return Err(format!("Unexpected result: {}", chat_response.result));
}
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
let stream = response.bytes_stream();
let reader = tokio_util::io::StreamReader::new(
stream.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
);
let mut lines = tokio::io::BufReader::new(reader).lines();
log::info!("client_id_create: {}", &client_id);
while let Ok(Some(line)) = lines.next_line().await {
log::info!("Received chat stream line: {}", &line);
if let Err(err) = app_handle.emit(&client_id, line) {
log::error!("Emit failed: {:?}", err);
let _ = app_handle.emit("chat-create-error", format!("Emit failed: {:?}", err));
}
}
Ok(())
Ok(chat_response)
}
#[tauri::command]
pub async fn chat_chat(
app_handle: AppHandle,
pub async fn send_message<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
websocket_id: String,
session_id: String,
message: Option<String>,
attachments: Option<Vec<String>>,
message: String,
query_params: Option<HashMap<String, Value>>, //search,deep_thinking
client_id: String,
) -> Result<(), String> {
println!("chat_chat message: {:?}", message);
println!("chat_chat attachments: {:?}", attachments);
let message_empty = message.as_ref().map_or(true, |m| m.is_empty());
let attachments_empty = attachments.as_ref().map_or(true, |a| a.is_empty());
if message_empty && attachments_empty {
return Err("Message and attachments are empty".to_string());
}
let body = {
let request_message = ChatRequestMessage {
message,
attachments,
};
println!("chat_chat body: {:?}", request_message);
Some(
serde_json::to_string(&request_message)
.map_err(|e| format!("Failed to serialize message: {}", e))?
.into(),
)
) -> Result<String, String> {
let path = format!("/chat/{}/_send", session_id);
let msg = ChatRequestMessage {
message: Some(message),
};
let path = format!("/chat/{}/_chat", session_id);
let mut headers = HashMap::new();
headers.insert("WEBSOCKET-SESSION-ID".to_string(), websocket_id.into());
let body = reqwest::Body::from(serde_json::to_string(&msg).unwrap());
let response = HttpClient::advanced_post(
&server_id,
path.as_str(),
None,
convert_query_params_to_strings(query_params),
body,
Some(headers),
query_params,
Some(body),
)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
if response.status() == 429 {
log::warn!("Rate limit exceeded for chat create");
return Err("Rate limited".to_string());
}
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
let stream = response.bytes_stream();
let reader = tokio_util::io::StreamReader::new(
stream.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
);
let mut lines = tokio::io::BufReader::new(reader).lines();
let mut first_log = true;
log::info!("client_id: {}", &client_id);
while let Ok(Some(line)) = lines.next_line().await {
log::info!("Received chat stream line: {}", &line);
if first_log {
log::info!("first stream line: {}", &line);
first_log = false;
}
if let Err(err) = app_handle.emit(&client_id, line) {
log::error!("Emit failed: {:?}", err);
print!("Error sending message: {:?}", err);
let _ = app_handle.emit("chat-create-error", format!("Emit failed: {:?}", err));
}
}
Ok(())
common::http::get_response_body_text(response).await
}
#[tauri::command]
@@ -306,194 +219,40 @@ pub async fn update_session_chat(
None,
Some(reqwest::Body::from(serde_json::to_string(&body).unwrap())),
)
.await
.map_err(|e| format!("Error updating session: {}", e))?;
.await
.map_err(|e| format!("Error updating session: {}", e))?;
Ok(response.status().is_success())
}
#[tauri::command]
pub async fn assistant_search(
_app_handle: AppHandle,
pub async fn assistant_search<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
query_params: Option<Vec<String>>,
from: u32,
size: u32,
query: Option<HashMap<String, Value>>,
) -> Result<Value, String> {
let response = HttpClient::post(&server_id, "/assistant/_search", query_params, None)
.await
.map_err(|e| format!("Error searching assistants: {}", e))?;
let mut body = serde_json::json!({
"from": from,
"size": size,
});
response
.json::<Value>()
.await
.map_err(|err| err.to_string())
}
if let Some(q) = query {
body["query"] = serde_json::to_value(q).map_err(|e| e.to_string())?;
}
#[tauri::command]
pub async fn assistant_get(
_app_handle: AppHandle,
server_id: String,
assistant_id: String,
) -> Result<Value, String> {
let response = HttpClient::get(
let response = HttpClient::post(
&server_id,
&format!("/assistant/{}", assistant_id),
None, // headers
)
.await
.map_err(|e| format!("Error getting assistant: {}", e))?;
response
.json::<Value>()
.await
.map_err(|err| err.to_string())
}
/// Gets the information of the assistant specified by `assistant_id` by querying **all**
/// Coco servers.
///
/// Returns as soon as the assistant is found on any Coco server.
#[tauri::command]
pub async fn assistant_get_multi(
app_handle: AppHandle,
assistant_id: String,
) -> Result<Value, String> {
let search_sources = app_handle.state::<SearchSourceRegistry>();
let sources_future = search_sources.get_sources();
let sources_list = sources_future.await;
let mut futures = FuturesUnordered::new();
for query_source in &sources_list {
let query_source_type = query_source.get_type();
if query_source_type.r#type != COCO_SERVERS {
// Assistants only exists on Coco servers.
continue;
}
let coco_server_id = query_source_type.id.clone();
let path = format!("/assistant/{}", assistant_id);
let fut = async move {
let res_response = HttpClient::get(
&coco_server_id,
&path,
None, // headers
)
.await;
match res_response {
Ok(response) => response
.json::<serde_json::Value>()
.await
.map_err(|e| e.to_string()),
Err(e) => Err(e),
}
};
futures.push(fut);
}
while let Some(res_response_json) = futures.next().await {
let response_json = match res_response_json {
Ok(json) => json,
Err(e) => return Err(e),
};
// Example response JSON
//
// When assistant is not found:
// ```json
// {
// "_id": "ID",
// "result": "not_found"
// }
// ```
//
// When assistant is found:
// ```json
// {
// "_id": "ID",
// "_source": {...}
// "found": true
// }
// ```
if let Some(found) = response_json.get("found") {
if found == true {
return Ok(response_json);
}
}
}
Err(format!(
"could not find Assistant [{}] on all the Coco servers",
assistant_id
))
}
use regex::Regex;
/// Remove all `"icon": "..."` fields from a JSON string
pub fn remove_icon_fields(json: &str) -> String {
// Regex to match `"icon": "..."` fields, including base64 or escaped strings
let re = Regex::new(r#""icon"\s*:\s*"[^"]*"(,?)"#).unwrap();
// Replace with empty string, or just remove trailing comma if needed
re.replace_all(json, |caps: &regex::Captures| {
if &caps[1] == "," {
"".to_string() // keep comma removal logic safe
} else {
"".to_string()
}
})
.to_string()
}
#[tauri::command]
pub async fn ask_ai(
app_handle: AppHandle,
message: String,
server_id: String,
assistant_id: String,
client_id: String,
) -> Result<(), String> {
let cleaned = remove_icon_fields(message.as_str());
let body = serde_json::json!({ "message": cleaned });
let path = format!("/assistant/{}/_ask", assistant_id);
println!("Sending request to {}", &path);
let response = HttpClient::send_request(
server_id.as_str(),
Method::POST,
path.as_str(),
None,
"/assistant/_search",
None,
Some(reqwest::Body::from(body.to_string())),
)
.await?;
.await
.map_err(|e| format!("Error searching assistants: {}", e))?;
if response.status() == 429 {
log::warn!("Rate limit exceeded for assistant: {}", &assistant_id);
return Ok(());
}
if !response.status().is_success() {
return Err(format!("Request Failed: {}", response.status()));
}
let stream = response.bytes_stream();
let reader = tokio_util::io::StreamReader::new(
stream.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
);
let mut lines = tokio::io::BufReader::new(reader).lines();
while let Ok(Some(line)) = lines.next_line().await {
dbg!("Received line: {}", &line);
let _ = app_handle.emit(&client_id, line).map_err(|err| {
println!("Failed to emit: {:?}", err);
});
}
Ok(())
response
.json::<Value>()
.await
.map_err(|err| err.to_string())
}

View File

@@ -1,48 +1,43 @@
use std::{fs::create_dir, io::Read};
use tauri::{AppHandle, Manager};
use tauri::{Manager, Runtime};
use tauri_plugin_autostart::ManagerExt;
/// If the state reported from the OS and the state stored by us differ, our state is
/// prioritized and seen as the correct one. Update the OS state to make them consistent.
pub fn ensure_autostart_state_consistent(tauri_app_handle: &AppHandle) -> Result<(), String> {
let autostart_manager = tauri_app_handle.autolaunch();
// Start or stop according to configuration
pub fn enable_autostart(app: &mut tauri::App) {
use tauri_plugin_autostart::MacosLauncher;
use tauri_plugin_autostart::ManagerExt;
let os_state = autostart_manager.is_enabled().map_err(|e| e.to_string())?;
let coco_stored_state = current_autostart(tauri_app_handle).map_err(|e| e.to_string())?;
app.handle()
.plugin(tauri_plugin_autostart::init(
MacosLauncher::AppleScript,
None,
))
.unwrap();
if os_state != coco_stored_state {
log::warn!(
"autostart inconsistent states, OS state [{}], Coco state [{}], config file could be deleted or corrupted",
os_state,
coco_stored_state
);
log::info!("trying to correct the inconsistent states");
let autostart_manager = app.autolaunch();
let result = if coco_stored_state {
autostart_manager.enable()
} else {
autostart_manager.disable()
};
// close autostart
// autostart_manager.disable().unwrap();
// return;
match result {
Ok(_) => {
log::info!("inconsistent autostart states fixed");
}
Err(e) => {
log::error!(
"failed to fix inconsistent autostart state due to error [{}]",
e
);
return Err(e.to_string());
}
}
match (
autostart_manager.is_enabled(),
current_autostart(app.app_handle()),
) {
(Ok(false), Ok(true)) => match autostart_manager.enable() {
Ok(_) => println!("Autostart enabled successfully."),
Err(err) => eprintln!("Failed to enable autostart: {}", err),
},
(Ok(true), Ok(false)) => match autostart_manager.disable() {
Ok(_) => println!("Autostart disable successfully."),
Err(err) => eprintln!("Failed to disable autostart: {}", err),
},
_ => (),
}
Ok(())
}
fn current_autostart(app: &tauri::AppHandle) -> Result<bool, String> {
fn current_autostart<R: Runtime>(app: &tauri::AppHandle<R>) -> Result<bool, String> {
use std::fs::File;
let path = app.path().app_config_dir().unwrap();
@@ -65,7 +60,10 @@ fn current_autostart(app: &tauri::AppHandle) -> Result<bool, String> {
}
#[tauri::command]
pub async fn change_autostart(app: tauri::AppHandle, open: bool) -> Result<(), String> {
pub async fn change_autostart<R: Runtime>(
app: tauri::AppHandle<R>,
open: bool,
) -> Result<(), String> {
use std::fs::File;
use std::io::Write;

View File

@@ -3,22 +3,19 @@ use serde_json::Value;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequestMessage {
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub attachments: Option<Vec<String>>,
}
#[allow(dead_code)]
pub struct NewChatResponse {
pub _id: String,
pub _source: Session,
pub _source: Source,
pub result: String,
pub payload: Option<Value>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Session {
pub struct Source {
pub id: String,
pub created: String,
pub updated: String,
@@ -26,11 +23,4 @@ pub struct Session {
pub title: Option<String>,
pub summary: Option<String>,
pub manually_renamed_title: bool,
pub visible: Option<bool>,
pub context: Option<SessionContext>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionContext {
pub attachments: Option<Vec<String>>,
}

View File

@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug,Clone, Serialize, Deserialize)]
pub struct Connector {
pub id: String,
pub created: Option<String>,
@@ -13,7 +13,7 @@ pub struct Connector {
pub url: Option<String>,
pub assets: Option<ConnectorAssets>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug,Clone, Serialize, Deserialize)]
pub struct ConnectorAssets {
pub icons: Option<std::collections::HashMap<String, String>>,
}
}

View File

@@ -18,4 +18,4 @@ pub struct DataSource {
pub struct ConnectorConfig {
pub id: Option<String>,
pub config: Option<serde_json::Value>, // Using serde_json::Value to handle any type of config
}
}

View File

@@ -1,12 +1,5 @@
#[cfg(target_os = "macos")]
use crate::extension::built_in::window_management::actions::Action;
use crate::extension::view_extension::serve_files_in;
use crate::extension::{ExtensionPermission, ExtensionSettings, ViewExtensionUISettings};
use log::debug;
use serde::{Deserialize, Serialize};
use serde_json::Value as Json;
use std::collections::HashMap;
use tauri::{AppHandle, Emitter};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RichLabel {
@@ -36,266 +29,6 @@ pub struct EditorInfo {
pub timestamp: Option<String>,
}
/// Defines the action that would be performed when a [document](Document) gets opened.
///
/// "Document" is a uniform type that the backend uses to send the search results
/// back to the frontend. Since Coco can search many sources, "Document" can
/// represent different things, application, web page, local file, extensions, and
/// so on. Each has its own specific open action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum OnOpened {
/// Launch the application
Application { app_path: String },
/// Open the URL.
Document { url: String },
/// Perform this WM action.
#[cfg(target_os = "macos")]
WindowManagementAction { action: Action },
/// The document is an extension.
Extension(ExtensionOnOpened),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ExtensionOnOpened {
/// Different types of extensions have different open behaviors.
pub(crate) ty: ExtensionOnOpenedType,
/// Extensions settings. Some could affect open action.
///
/// Optional because not all extensions have their settings.
pub(crate) settings: Option<ExtensionSettings>,
/// Permission needed by this extension.
///
/// We do permission check when opening this permission. Currently, we only
/// do this to View extensions.
pub(crate) permission: Option<ExtensionPermission>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum ExtensionOnOpenedType {
/// Spawn a child process to run the `CommandAction`.
Command {
action: crate::extension::CommandAction,
},
/// Open the `link`.
//
// NOTE that this variant has the same definition as `struct Quicklink`, but we
// cannot use it directly, its `link` field should be deserialized/serialized
// from/to a string, but we need a JSON object here.
//
// See also the comments in `struct Quicklink`.
Quicklink {
link: crate::extension::QuicklinkLink,
open_with: Option<String>,
},
View {
/// Extension name
name: String,
// An absolute path to the extension icon or a font code.
icon: String,
/// Path to the HTML file that coco will load and render.
///
/// It should be an absolute path or Tauri cannot open it.
page: String,
ui: Option<ViewExtensionUISettings>,
},
}
impl OnOpened {
pub(crate) fn url(&self) -> String {
match self {
Self::Application { app_path } => app_path.clone(),
Self::Document { url } => url.clone(),
#[cfg(target_os = "macos")]
Self::WindowManagementAction { action: _ } => {
// We don't have URL for this
String::from("N/A")
}
Self::Extension(ext_on_opened) => {
match &ext_on_opened.ty {
ExtensionOnOpenedType::Command { action } => {
const WHITESPACE: &str = " ";
let mut ret = action.exec.clone();
ret.push_str(WHITESPACE);
if let Some(ref args) = action.args {
ret.push_str(args.join(WHITESPACE).as_str());
}
ret
}
// Currently, our URL is static and does not support dynamic parameters.
// The URL of a quicklink is nearly useless without such dynamic user
// inputs, so until we have dynamic URL support, we just use "N/A".
ExtensionOnOpenedType::Quicklink { .. } => String::from("N/A"),
ExtensionOnOpenedType::View {
name: _,
icon: _,
page: _,
ui: _,
} => {
// We currently don't have URL for this kind of extension.
String::from("N/A")
}
}
}
}
}
}
#[tauri::command]
pub(crate) async fn open(
tauri_app_handle: AppHandle,
on_opened: OnOpened,
extra_args: Option<HashMap<String, Json>>,
) -> Result<(), String> {
use crate::util::open as homemade_tauri_shell_open;
use std::process::Command;
match on_opened {
OnOpened::Application { app_path } => {
log::debug!("open application [{}]", app_path);
homemade_tauri_shell_open(tauri_app_handle.clone(), app_path).await?
}
OnOpened::Document { url } => {
log::debug!("open document [{}]", url);
homemade_tauri_shell_open(tauri_app_handle.clone(), url).await?
}
#[cfg(target_os = "macos")]
OnOpened::WindowManagementAction { action } => {
log::debug!("perform Window Management action [{:?}]", action);
crate::extension::built_in::window_management::perform_action_on_main_thread(
&tauri_app_handle,
action,
)?;
}
OnOpened::Extension(ext_on_opened) => {
// Apply the settings that would affect open behavior
if let Some(settings) = ext_on_opened.settings {
if let Some(should_hide) = settings.hide_before_open {
if should_hide {
crate::hide_coco(tauri_app_handle.clone()).await;
}
}
}
let permission = ext_on_opened.permission;
match ext_on_opened.ty {
ExtensionOnOpenedType::Command { action } => {
log::debug!("open (execute) command [{:?}]", action);
let mut cmd = Command::new(action.exec);
if let Some(args) = action.args {
cmd.args(args);
}
let output = cmd.output().map_err(|e| e.to_string())?;
// Sometimes, we wanna see the result in logs even though it doesn't fail.
log::debug!(
"executing open(Command) result, exit code: [{}], stdout: [{}], stderr: [{}]",
output.status,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
if !output.status.success() {
log::warn!(
"executing open(Command) failed, exit code: [{}], stdout: [{}], stderr: [{}]",
output.status,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
return Err(format!(
"Command failed, stderr [{}]",
String::from_utf8_lossy(&output.stderr)
));
}
}
ExtensionOnOpenedType::Quicklink {
link,
open_with: opt_open_with,
} => {
let url = link.concatenate_url(&extra_args);
log::debug!("open quicklink [{}] with [{:?}]", url, opt_open_with);
cfg_if::cfg_if! {
// The `open_with` functionality is only supported on macOS, provided
// by the `open -a` command.
if #[cfg(target_os = "macos")] {
let mut cmd = Command::new("open");
if let Some(ref open_with) = opt_open_with {
cmd.arg("-a");
cmd.arg(open_with.as_str());
}
cmd.arg(&url);
let output = cmd.output().map_err(|e| format!("failed to spawn [open] due to error [{}]", e))?;
if !output.status.success() {
return Err(format!(
"failed to open with app {:?}: {}",
opt_open_with,
String::from_utf8_lossy(&output.stderr)
));
}
} else {
homemade_tauri_shell_open(tauri_app_handle.clone(), url).await?
}
}
}
ExtensionOnOpenedType::View {
name,
icon,
page,
ui,
} => {
let page_path = Utf8Path::new(&page);
let directory = page_path.parent().unwrap_or_else(|| {
panic!("View extension page path should have a parent, i.e., it should be under a directory, but [{}] does not", page);
});
let mut url = serve_files_in(directory.as_ref()).await;
/*
* Emit an event to let the frontend code open this extension.
*
* Payload `view_extension_opened` contains the information needed
* to do that.
*
* See "src/pages/main/index.tsx" for more info.
*/
use camino::Utf8Path;
use serde_json::Value as Json;
use serde_json::to_value;
let html_filename = page_path
.file_name()
.unwrap_or_else(|| {
panic!("View extension page path should have a file name, but [{}] does not have one", page);
}).to_string();
url.push('/');
url.push_str(&html_filename);
let html_file_url = url;
debug!("View extension listening on: {}", html_file_url);
let view_extension_opened: [Json; 5] = [
Json::String(name),
Json::String(icon),
Json::String(html_file_url),
to_value(permission).unwrap(),
to_value(ui).unwrap(),
];
tauri_app_handle
.emit("open_view_extension", view_extension_opened)
.unwrap();
}
}
}
}
Ok(())
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Document {
pub id: String,
@@ -315,8 +48,6 @@ pub struct Document {
pub thumbnail: Option<String>,
pub cover: Option<String>,
pub tags: Option<Vec<String>>,
/// What will happen if we open this document.
pub on_opened: Option<OnOpened>,
pub url: Option<String>,
pub size: Option<i64>,
pub metadata: Option<HashMap<String, serde_json::Value>>,

View File

@@ -1,67 +1,34 @@
use reqwest::StatusCode;
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use thiserror::Error;
fn serialize_optional_status_code<S>(
status_code: &Option<StatusCode>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match status_code {
Some(code) => serializer.serialize_str(&format!("{:?}", code)),
None => serializer.serialize_none(),
}
}
#[allow(unused)]
#[derive(Debug, Deserialize)]
pub struct ErrorCause {
#[serde(default)]
pub r#type: Option<String>,
#[serde(default)]
pub reason: Option<String>,
}
#[derive(Debug, Deserialize)]
#[allow(unused)]
pub struct ErrorDetail {
#[serde(default)]
pub root_cause: Option<Vec<ErrorCause>>,
#[serde(default)]
pub r#type: Option<String>,
#[serde(default)]
pub reason: Option<String>,
#[serde(default)]
pub caused_by: Option<ErrorCause>,
pub reason: String,
pub status: u16,
}
#[derive(Debug, Deserialize)]
pub struct ErrorResponse {
#[serde(default)]
pub error: Option<ErrorDetail>,
#[serde(default)]
#[allow(unused)]
pub status: Option<u16>,
pub error: ErrorDetail,
}
#[derive(Debug, Error, Serialize)]
pub enum SearchError {
#[error("HttpError: status code [{status_code:?}], msg [{msg}]")]
HttpError {
#[serde(serialize_with = "serialize_optional_status_code")]
status_code: Option<StatusCode>,
msg: String,
},
#[error("HTTP request failed: {0}")]
HttpError(String),
#[error("ParseError: {0}")]
#[error("Invalid response format: {0}")]
ParseError(String),
#[error("Timeout occurred")]
Timeout,
#[error("InternalError: {0}")]
#[error("Unknown error: {0}")]
#[allow(dead_code)]
Unknown(String),
#[error("InternalError error: {0}")]
#[allow(dead_code)]
InternalError(String),
}
@@ -72,10 +39,7 @@ impl From<reqwest::Error> for SearchError {
} else if err.is_decode() {
SearchError::ParseError(err.to_string())
} else {
SearchError::HttpError {
status_code: err.status(),
msg: err.to_string(),
}
SearchError::HttpError(err.to_string())
}
}
}
}

View File

@@ -2,8 +2,6 @@ use crate::common;
use reqwest::Response;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tauri_plugin_store::JsonValue;
#[derive(Debug, Serialize, Deserialize)]
pub struct GetResponse {
@@ -42,34 +40,13 @@ pub async fn get_response_body_text(response: Response) -> Result<String, String
Ok(parsed_error) => {
dbg!(&parsed_error);
Err(format!(
"Server error ({}): {:?}",
status, parsed_error.error
"Server error ({}): {}",
parsed_error.error.status, parsed_error.error.reason
))
}
Err(_) => {
log::warn!("Failed to parse error response: {}", &body);
Err(fallback_error)
}
Err(_) => Err(fallback_error),
}
} else {
Ok(body)
}
}
pub fn convert_query_params_to_strings(
query_params: Option<HashMap<String, JsonValue>>,
) -> Option<Vec<String>> {
query_params.map(|map| {
map.into_iter()
.filter_map(|(k, v)| match v {
JsonValue::String(s) => Some(format!("{}={}", k, s)),
JsonValue::Number(n) => Some(format!("{}={}", k, n)),
JsonValue::Bool(b) => Some(format!("{}={}", k, b)),
_ => {
eprintln!("Skipping unsupported query value for key '{}': {:?}", k, v);
None
}
})
.collect()
})
}
}

View File

@@ -1,17 +1,16 @@
pub mod assistant;
pub mod auth;
pub mod connector;
pub mod datasource;
pub mod document;
pub mod error;
pub mod health;
pub mod http;
pub mod profile;
pub mod register;
pub mod search;
pub mod server;
pub mod auth;
pub mod datasource;
pub mod connector;
pub mod search;
pub mod document;
pub mod traits;
pub mod register;
pub mod assistant;
pub mod http;
pub mod error;
pub static MAIN_WINDOW_LABEL: &str = "main";
pub static SETTINGS_WINDOW_LABEL: &str = "settings";
pub static CHECK_WINDOW_LABEL: &str = "check";

View File

@@ -13,4 +13,4 @@ pub struct UserProfile {
pub email: String,
pub avatar: Option<String>,
pub preferences: Option<Preferences>,
}
}

View File

@@ -7,8 +7,8 @@ use std::error::Error;
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResponse<T> {
pub took: Option<u64>,
pub timed_out: Option<bool>,
pub took: u64,
pub timed_out: bool,
pub _shards: Option<Shards>,
pub hits: Hits<T>,
}
@@ -25,7 +25,7 @@ pub struct Shards {
pub struct Hits<T> {
pub total: Total,
pub max_score: Option<f32>,
pub hits: Option<Vec<SearchHit<T>>>,
pub hits: Vec<SearchHit<T>>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -36,9 +36,9 @@ pub struct Total {
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchHit<T> {
pub _index: Option<String>,
pub _type: Option<String>,
pub _id: Option<String>,
pub _index: String,
pub _type: String,
pub _id: String,
pub _score: Option<f64>,
pub _source: T, // This will hold the type we pass in (e.g., DataSource)
}
@@ -58,18 +58,13 @@ where
Ok(search_response)
}
use serde::de::DeserializeOwned;
pub async fn parse_search_hits<T>(response: Response) -> Result<Vec<SearchHit<T>>, Box<dyn Error>>
where
T: DeserializeOwned + std::fmt::Debug,
T: for<'de> Deserialize<'de> + std::fmt::Debug,
{
let response = parse_search_response(response).await?;
match response.hits.hits {
Some(hits) => Ok(hits),
None => Ok(Vec::new()),
}
Ok(response.hits.hits)
}
pub async fn parse_search_results<T>(response: Response) -> Result<Vec<T>, Box<dyn Error>>
@@ -83,6 +78,20 @@ where
.collect())
}
#[allow(dead_code)]
pub async fn parse_search_results_with_score<T>(
response: Response,
) -> Result<Vec<(T, Option<f64>)>, Box<dyn Error>>
where
T: for<'de> Deserialize<'de> + std::fmt::Debug,
{
Ok(parse_search_hits(response)
.await?
.into_iter()
.map(|hit| (hit._source, hit._score))
.collect())
}
#[derive(Debug, Clone, Serialize)]
pub struct SearchQuery {
pub from: u64,
@@ -100,7 +109,7 @@ impl SearchQuery {
}
}
#[derive(Debug, Clone, Serialize, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Serialize)]
pub struct QuerySource {
pub r#type: String, //coco-server/local/ etc.
pub id: String, //coco server's id

View File

@@ -1,8 +1,6 @@
use crate::common::health::Health;
use crate::common::profile::UserProfile;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -50,17 +48,9 @@ pub struct Server {
pub updated: String,
#[serde(default = "default_enabled_type")]
pub enabled: bool,
/// Public Coco servers can be used without signing in.
#[serde(default = "default_bool_type")]
pub public: bool,
/// A coco server is available if:
///
/// 1. It is still online, we check this via the `GET /base_url/provider/_info`
/// interface.
/// 2. A user is logged in to this Coco server, i.e., a token is stored in the
/// `SERVER_TOKEN_LIST_CACHE`.
/// For public Coco servers, requirement 2 is not needed.
#[serde(default = "default_available_type")]
pub available: bool,
@@ -70,7 +60,6 @@ pub struct Server {
pub auth_provider: AuthProvider,
#[serde(default = "default_priority_type")]
pub priority: u32,
pub stats: Option<HashMap<String, Value>>,
}
impl PartialEq for Server {
@@ -92,10 +81,7 @@ pub struct ServerAccessToken {
#[serde(default = "default_empty_string")] // Custom default function for empty string
pub id: String,
pub access_token: String,
/// Unix timestamp in seconds
///
/// Currently, this is UNUSED.
pub expired_at: u32,
pub expired_at: u32, //unix timestamp in seconds
}
impl ServerAccessToken {

View File

@@ -1,16 +1,13 @@
use crate::common::error::SearchError;
// use std::{future::Future, pin::Pin};
use crate::common::search::SearchQuery;
use crate::common::search::{QueryResponse, QuerySource};
use async_trait::async_trait;
use tauri::AppHandle;
#[async_trait]
pub trait SearchSource: Send + Sync {
fn get_type(&self) -> QuerySource;
async fn search(
&self,
tauri_app_handle: AppHandle,
query: SearchQuery,
) -> Result<QueryResponse, SearchError>;
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError>;
}

View File

@@ -1,5 +0,0 @@
# Complete Coco extension API list grouped by its category.
fs = [
"read_dir"
]

View File

@@ -1,22 +0,0 @@
//! File system APIs
use tokio::fs::read_dir as tokio_read_dir;
#[tauri::command]
pub(crate) async fn read_dir(path: String) -> Result<Vec<String>, String> {
let mut iter = tokio_read_dir(path).await.map_err(|e| e.to_string())?;
let mut file_names = Vec::new();
loop {
let opt_entry = iter.next_entry().await.map_err(|e| e.to_string())?;
let Some(entry) = opt_entry else {
break;
};
let file_name = entry.file_name().to_string_lossy().into_owned();
file_names.push(file_name);
}
Ok(file_names)
}

View File

@@ -1,21 +0,0 @@
//! The Rust implementation of the Coco extension APIs.
//!
//! Extension developers do not use these Rust APIs directly, they use our
//! [Typescript library][ts_lib], which eventually calls these APIs.
//!
//! [ts_lib]: https://github.com/infinilabs/coco-api
pub(crate) mod fs;
use std::collections::HashMap;
/// Return all the available APIs grouped by their category.
#[tauri::command]
pub(crate) fn apis() -> HashMap<String, Vec<String>> {
static APIS_TOML: &str = include_str!("./apis.toml");
let apis: HashMap<String, Vec<String>> =
toml::from_str(APIS_TOML).expect("Failed to parse apis.toml file");
apis
}

View File

@@ -1,13 +0,0 @@
pub(super) const EXTENSION_ID: &str = "AIOverview";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "AIOverview",
"name": "AI Overview",
"description": "...",
"icon": "font_a-AIOverview",
"type": "ai_extension",
"enabled": true
}
"#;

File diff suppressed because it is too large Load Diff

View File

@@ -1,219 +0,0 @@
use super::super::LOCAL_QUERY_SOURCE_TYPE;
use crate::common::{
document::{DataSourceReference, Document},
error::SearchError,
search::{QueryResponse, QuerySource, SearchQuery},
traits::SearchSource,
};
use async_trait::async_trait;
use chinese_number::{ChineseCase, ChineseCountMethod, ChineseVariant, NumberToChinese};
use num2words::Num2Words;
use serde_json::Value;
use std::collections::HashMap;
use tauri::AppHandle;
pub(crate) const DATA_SOURCE_ID: &str = "Calculator";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "Calculator",
"name": "Calculator",
"platforms": ["macos", "linux", "windows"],
"description": "...",
"icon": "font_Calculator",
"type": "calculator",
"enabled": true
}
"#;
pub struct CalculatorSource {
base_score: f64,
}
impl CalculatorSource {
pub fn new(base_score: f64) -> Self {
CalculatorSource { base_score }
}
}
fn parse_query(query: &str) -> Value {
let mut query_json = serde_json::Map::new();
let operators = ["+", "-", "*", "/", "%"];
let found_operators: Vec<_> = query
.chars()
.filter(|c| operators.contains(&c.to_string().as_str()))
.collect();
if found_operators.len() == 1 {
let operation = match found_operators[0] {
'+' => "sum",
'-' => "subtract",
'*' => "multiply",
'/' => "divide",
'%' => "remainder",
_ => "expression",
};
query_json.insert("type".to_string(), Value::String(operation.to_string()));
} else {
query_json.insert("type".to_string(), Value::String("expression".to_string()));
}
query_json.insert("value".to_string(), Value::String(query.to_string()));
Value::Object(query_json)
}
fn parse_result(num: f64) -> Value {
let mut result_json = serde_json::Map::new();
let to_zh = num
.to_chinese(
ChineseVariant::Simple,
ChineseCase::Upper,
ChineseCountMethod::TenThousand,
)
.unwrap_or(num.to_string());
let to_en = Num2Words::new(num)
.to_words()
.map(|s| {
let mut chars = s.chars();
let mut result = String::new();
let mut capitalize = true;
while let Some(c) = chars.next() {
if c == ' ' || c == '-' {
result.push(c);
capitalize = true;
} else if capitalize {
result.extend(c.to_uppercase());
capitalize = false;
} else {
result.push(c);
}
}
result
})
.unwrap_or(num.to_string());
result_json.insert("value".to_string(), Value::String(num.to_string()));
result_json.insert("toZh".to_string(), Value::String(to_zh));
result_json.insert("toEn".to_string(), Value::String(to_en));
Value::Object(result_json)
}
#[async_trait]
impl SearchSource for CalculatorSource {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(DATA_SOURCE_ID.into())
.to_string_lossy()
.into(),
id: DATA_SOURCE_ID.into(),
}
}
async fn search(
&self,
_tauri_app_handle: AppHandle,
query: SearchQuery,
) -> Result<QueryResponse, SearchError> {
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
// Trim the leading and tailing whitespace so that our later if condition
// will only be evaluated against non-whitespace characters.
let query_string = query_string.trim();
if query_string.is_empty() {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
let query_string_clone = query_string.to_string();
let query_source = self.get_type();
let base_score = self.base_score;
let closure = move || -> QueryResponse {
let Ok(tokens) = meval::tokenizer::tokenize(&query_string_clone) else {
// Invalid expression, return nothing.
return QueryResponse {
source: query_source,
hits: Vec::new(),
total_hits: 0,
};
};
// If it is only a number, no need to evaluate it as the result is
// this number.
// Actually, there is no need to return the result back to the users
// in such case because letting them know "x = x" is meaningless.
if tokens.len() == 1 && matches!(tokens[0], meval::tokenizer::Token::Number(_)) {
return QueryResponse {
source: query_source,
hits: Vec::new(),
total_hits: 0,
};
}
let res_num = meval::eval_str(&query_string_clone);
match res_num {
Ok(num) => {
let mut payload: HashMap<String, Value> = HashMap::new();
let payload_query = parse_query(&query_string_clone);
let payload_result = parse_result(num);
payload.insert("query".to_string(), payload_query);
payload.insert("result".to_string(), payload_result);
let doc = Document {
id: DATA_SOURCE_ID.to_string(),
category: Some(DATA_SOURCE_ID.to_string()),
payload: Some(payload),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(DATA_SOURCE_ID.into()),
id: Some(DATA_SOURCE_ID.into()),
icon: Some(String::from("font_Calculator")),
}),
..Default::default()
};
QueryResponse {
source: query_source,
hits: vec![(doc, base_score)],
total_hits: 1,
}
}
Err(_) => QueryResponse {
source: query_source,
hits: Vec::new(),
total_hits: 0,
},
}
};
let spawn_result = tokio::task::spawn_blocking(closure).await;
match spawn_result {
Ok(response) => Ok(response),
Err(e) => std::panic::resume_unwind(e.into_panic()),
}
}
}

View File

@@ -1,216 +0,0 @@
//! File Search configuration entries definition and getter/setter functions.
use crate::extension::built_in::file_search::implementation::apply_config;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::sync::LazyLock;
use tauri::AppHandle;
use tauri_plugin_store::StoreExt;
// Tauri store keys for file system configuration
const TAURI_STORE_FILE_SYSTEM_CONFIG: &str = "file_system_config";
const TAURI_STORE_KEY_SEARCH_BY: &str = "search_by";
const TAURI_STORE_KEY_SEARCH_PATHS: &str = "search_paths";
const TAURI_STORE_KEY_EXCLUDE_PATHS: &str = "exclude_paths";
const TAURI_STORE_KEY_FILE_TYPES: &str = "file_types";
static HOME_DIR: LazyLock<String> = LazyLock::new(|| {
let os_string = dirs::home_dir()
.expect("$HOME should be set")
.into_os_string();
os_string
.into_string()
.expect("User home directory should be encoded with UTF-8")
});
#[derive(Debug, Clone, Serialize, Deserialize, Copy, PartialEq)]
pub enum SearchBy {
Name,
NameAndContents,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileSearchConfig {
pub search_paths: Vec<String>,
pub exclude_paths: Vec<String>,
pub file_types: Vec<String>,
pub search_by: SearchBy,
}
impl Default for FileSearchConfig {
fn default() -> Self {
Self {
search_paths: vec![
format!("{}/Documents", HOME_DIR.as_str()),
format!("{}/Desktop", HOME_DIR.as_str()),
format!("{}/Downloads", HOME_DIR.as_str()),
],
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
}
}
}
impl FileSearchConfig {
pub(crate) fn get(tauri_app_handle: &AppHandle) -> Self {
let store = tauri_app_handle
.store(TAURI_STORE_FILE_SYSTEM_CONFIG)
.unwrap_or_else(|e| {
panic!(
"store [{}] not found/loaded, error [{}]",
TAURI_STORE_FILE_SYSTEM_CONFIG, e
)
});
// Default value, will be used when specific config entries are not set
let default_config = FileSearchConfig::default();
let search_paths = {
if let Some(search_paths) = store.get(TAURI_STORE_KEY_SEARCH_PATHS) {
match search_paths {
Value::Array(arr) => {
let mut vec = Vec::with_capacity(arr.len());
for v in arr {
match v {
Value::String(s) => vec.push(s),
other => panic!(
"Expected all elements of 'search_paths' to be strings, but found: {:?}",
other
),
}
}
vec
}
other => panic!(
"Expected 'search_paths' to be an array of strings in the file system config store, but got: {:?}",
other
),
}
} else {
store.set(
TAURI_STORE_KEY_SEARCH_PATHS,
default_config.search_paths.as_slice(),
);
default_config.search_paths
}
};
let exclude_paths = {
if let Some(exclude_paths) = store.get(TAURI_STORE_KEY_EXCLUDE_PATHS) {
match exclude_paths {
Value::Array(arr) => {
let mut vec = Vec::with_capacity(arr.len());
for v in arr {
match v {
Value::String(s) => vec.push(s),
other => panic!(
"Expected all elements of 'exclude_paths' to be strings, but found: {:?}",
other
),
}
}
vec
}
other => panic!(
"Expected 'exclude_paths' to be an array of strings in the file system config store, but got: {:?}",
other
),
}
} else {
store.set(
TAURI_STORE_KEY_EXCLUDE_PATHS,
default_config.exclude_paths.as_slice(),
);
default_config.exclude_paths
}
};
let file_types = {
if let Some(file_types) = store.get(TAURI_STORE_KEY_FILE_TYPES) {
match file_types {
Value::Array(arr) => {
let mut vec = Vec::with_capacity(arr.len());
for v in arr {
match v {
Value::String(s) => vec.push(s),
other => panic!(
"Expected all elements of 'file_types' to be strings, but found: {:?}",
other
),
}
}
vec
}
other => panic!(
"Expected 'file_types' to be an array of strings in the file system config store, but got: {:?}",
other
),
}
} else {
store.set(
TAURI_STORE_KEY_FILE_TYPES,
default_config.file_types.as_slice(),
);
default_config.file_types
}
};
let search_by = {
if let Some(search_by) = store.get(TAURI_STORE_KEY_SEARCH_BY) {
serde_json::from_value(search_by.clone()).unwrap_or_else(|e| {
panic!(
"Failed to deserialize 'search_by' from file system config store. Invalid JSON: {:?}, error: {}",
search_by, e
)
})
} else {
store.set(
TAURI_STORE_KEY_SEARCH_BY,
serde_json::to_value(default_config.search_by).unwrap(),
);
default_config.search_by
}
};
Self {
search_by,
search_paths,
exclude_paths,
file_types,
}
}
}
// Tauri commands for managing file system configuration
#[tauri::command]
pub async fn get_file_system_config(tauri_app_handle: AppHandle) -> FileSearchConfig {
FileSearchConfig::get(&tauri_app_handle)
}
#[tauri::command]
pub async fn set_file_system_config(
tauri_app_handle: AppHandle,
config: FileSearchConfig,
) -> Result<(), String> {
let store = tauri_app_handle
.store(TAURI_STORE_FILE_SYSTEM_CONFIG)
.map_err(|e| e.to_string())?;
store.set(TAURI_STORE_KEY_SEARCH_PATHS, config.search_paths.as_slice());
store.set(
TAURI_STORE_KEY_EXCLUDE_PATHS,
config.exclude_paths.as_slice(),
);
store.set(TAURI_STORE_KEY_FILE_TYPES, config.file_types.as_slice());
store.set(
TAURI_STORE_KEY_SEARCH_BY,
serde_json::to_value(config.search_by).unwrap(),
);
// Apply the config when we know that this set operation won't fail
apply_config(&config)?;
Ok(())
}

View File

@@ -1,388 +0,0 @@
//! File system powered by GNOME's Tracker engine.
use super::super::super::EXTENSION_ID;
use super::super::super::config::FileSearchConfig;
use super::super::should_be_filtered_out;
use crate::common::document::DataSourceReference;
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::util::file::sync_get_file_icon;
use crate::{
common::document::{Document, OnOpened},
extension::built_in::file_search::config::SearchBy,
};
use camino::Utf8Path;
use gio::Cancellable;
use gio::Settings;
use gio::prelude::SettingsExtManual;
use glib::GString;
use glib::collections::strv::StrV;
use tracker::{SparqlConnection, SparqlCursor, prelude::SparqlCursorExtManual};
/// The service that we will connect to.
const SERVICE_NAME: &str = "org.freedesktop.Tracker3.Miner.Files";
/// Tracker won't return scores when we are not using full-text seach. In that
/// case, we use this score.
const SCORE: f64 = 1.0;
/// Helper function to return different SPARQL queries depending on the different configurations.
fn query_sparql(query_string: &str, config: &FileSearchConfig) -> String {
match config.search_by {
SearchBy::Name => {
// Cannot use the inverted index as that searches for all the attributes,
// but we only want to search the filename.
format!(
"SELECT nie:url(?file_item) WHERE {{ ?file_item nfo:fileName ?fileName . FILTER(regex(?fileName, '{query_string}', 'i')) }}"
)
}
SearchBy::NameAndContents => {
// Full-text search against all attributes
// OR
// filename search
format!(
"SELECT nie:url(?file_item) fts:rank(?file_item) WHERE {{ {{ ?file_item fts:match '{query_string}' }} UNION {{ ?file_item nfo:fileName ?fileName . FILTER(regex(?fileName, '{query_string}', 'i')) }} }} ORDER BY DESC fts:rank(?file_item)"
)
}
}
}
/// Helper function to replace unsupported characters with whitespace.
///
/// Tracker will error out if it encounters these characters.
///
/// The complete list of unsupported characters is unknown and we don't know how
/// to escape them, so let's replace them.
fn query_string_cleanup(old: &str) -> String {
const UNSUPPORTED_CHAR: [char; 3] = ['\'', '\n', '\\'];
// Using len in bytes is ok
let mut chars = Vec::with_capacity(old.len());
for char in old.chars() {
if UNSUPPORTED_CHAR.contains(&char) {
chars.push(' ');
} else {
chars.push(char);
}
}
chars.into_iter().collect()
}
struct Query {
conn: SparqlConnection,
cursor: SparqlCursor,
}
impl Query {
fn new(query_string: &str, config: &FileSearchConfig) -> Result<Self, String> {
let query_string = query_string_cleanup(query_string);
let sparql = query_sparql(&query_string, config);
let conn =
SparqlConnection::bus_new(SERVICE_NAME, None, None).map_err(|e| e.to_string())?;
let cursor = conn
.query(&sparql, Cancellable::NONE)
.map_err(|e| e.to_string())?;
Ok(Self { conn, cursor })
}
}
impl Drop for Query {
fn drop(&mut self) {
self.cursor.close();
self.conn.close();
}
}
impl Iterator for Query {
/// It yields a tuple `(file path, score)`
type Item = Result<(String, f64), String>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let has_next = match self
.cursor
.next(Cancellable::NONE)
.map_err(|e| e.to_string())
{
Ok(has_next) => has_next,
Err(err_str) => return Some(Err(err_str)),
};
if !has_next {
return None;
}
// The first column is the URL
let file_url_column = self.cursor.string(0);
// It could be None (or NULL ptr if you use C), I have no clue why.
let opt_str = file_url_column.as_ref().map(|gstr| gstr.as_str());
match opt_str {
Some(url) => {
// The returned URL has a prefix that we need to trim
const PREFIX: &str = "file://";
const PREFIX_LEN: usize = PREFIX.len();
let file_path = url[PREFIX_LEN..].to_string();
assert!(!file_path.is_empty());
assert_ne!(file_path, "/", "file search should not hit the root path");
let score = {
// The second column is the score, this column may not
// exist. We use SCORE if the real value is absent.
let score_column = self.cursor.string(1);
let opt_score_str = score_column.as_ref().map(|g_str| g_str.as_str());
let opt_score = opt_score_str.map(|str| {
str.parse::<f64>()
.expect("score should be valid for type f64")
});
opt_score.unwrap_or(SCORE)
};
return Some(Ok((file_path, score)));
}
None => {
// another try
continue;
}
}
}
}
}
pub(crate) async fn hits(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
// Special cases that will make querying faster.
if query_string.is_empty() || size == 0 || config.search_paths.is_empty() {
return Ok(Vec::new());
}
let mut result_hits = Vec::with_capacity(size);
let need_to_skip = {
if matches!(config.search_by, SearchBy::Name) {
// We don't use full-text search in this case, the returned documents
// won't be scored, the query hits won't be sorted, so processing the
// from parameter is meaningless.
false
} else {
from > 0
}
};
let mut num_skipped = 0;
let should_skip = from;
let query = Query::new(query_string, config)?;
for res_entry in query {
let (file_path, score) = res_entry?;
// This should be called before processing the `from` parameter.
if should_be_filtered_out(config, &file_path, true, true, true) {
continue;
}
// Process the `from` parameter.
if need_to_skip && num_skipped < should_skip {
// Skip this
num_skipped += 1;
continue;
}
let icon = sync_get_file_icon(&file_path);
let file_path_of_type_path = camino::Utf8Path::new(&file_path);
let r#where = file_path_of_type_path
.parent()
.unwrap_or_else(|| {
panic!(
"expect path [{}] to have a parent, but it does not",
file_path
);
})
.to_string();
let file_name = file_path_of_type_path.file_name().unwrap_or_else(|| {
panic!(
"expect path [{}] to have a file name, but it does not",
file_path
);
});
let on_opened = OnOpened::Document {
url: file_path.to_string(),
};
let doc = Document {
id: file_path.to_string(),
title: Some(file_name.to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(EXTENSION_ID.into()),
id: Some(EXTENSION_ID.into()),
icon: Some(String::from("font_Filesearch")),
}),
category: Some(r#where),
on_opened: Some(on_opened),
url: Some(file_path),
icon: Some(icon.to_string()),
..Default::default()
};
result_hits.push((doc, score));
// Collected enough documents, return
if result_hits.len() >= size {
break;
}
}
Ok(result_hits)
}
fn ensure_path_in_recursive_indexing_scope(list: &mut StrV, path: &str) {
for item in list.iter() {
let item_path = Utf8Path::new(item.as_str());
let path = Utf8Path::new(path);
// It is already covered or listed
if path.starts_with(item_path) {
return;
}
}
list.push(
GString::from_utf8_checked(path.as_bytes().to_vec())
.expect("search_path_str contains an interior NUL"),
);
}
fn ensure_path_and_descendants_not_in_single_indexing_scope(list: &mut StrV, path: &str) {
// Indexes to the items that should be removed
let mut item_to_remove = Vec::new();
for (idx, item) in list.iter().enumerate() {
let item_path = Utf8Path::new(item.as_str());
let path = Utf8Path::new(path);
if item_path.starts_with(path) {
item_to_remove.push(idx);
}
}
// Reverse the indexes so that the remove operation won't invalidate them.
for idx in item_to_remove.into_iter().rev() {
list.remove(idx);
}
}
pub(crate) fn apply_config(config: &FileSearchConfig) -> Result<(), String> {
// Tracker provides the following configuration entries to allow users to
// tweak the indexing scope:
//
// 1. ignored-directories: A list of names, directories with such names will be ignored.
// ['po', 'CVS', 'core-dumps', 'lost+found']
// 2. ignored-directories-with-content: Avoid any directory containing a file blocklisted here
// ['.trackerignore', '.git', '.hg', '.nomedia']
// 3. ignored-files: List of file patterns to avoid
// ['*~', '*.o', '*.la', '*.lo', '*.loT', '*.in', '*.m4', '*.rej', ...]
// 4. index-recursive-directories: List of directories to index recursively
// ['&DESKTOP', '&DOCUMENTS', '&MUSIC', '&PICTURES', '&VIDEOS']
// 5. index-single-directories: List of directories to index without inspecting subfolders,
// ['$HOME', '&DOWNLOAD']
//
// The first 3 entries specify patterns, in order to use them, we have to walk
// through the whole directory tree listed in search paths, which is impractical.
// So we only use the last 2 entries.
//
//
// Just want to mention that setting search path to "/home" could break Tracker:
//
// ```text
// Unknown target graph for uri:'file:///home' and mime:'inode/directory'
// ```
//
// See the related bug reports:
//
// https://gitlab.gnome.org/GNOME/localsearch/-/issues/313
// https://bugs.launchpad.net/bugs/2077181
//
//
// There is nothing we can do.
const TRACKER_SETTINGS_SCHEMA: &str = "org.freedesktop.Tracker3.Miner.Files";
const KEY_INDEX_RECURSIVE_DIRECTORIES: &str = "index-recursive-directories";
const KEY_INDEX_SINGLE_DIRECTORIES: &str = "index-single-directories";
let search_paths = &config.search_paths;
let settings = Settings::new(TRACKER_SETTINGS_SCHEMA);
let mut recursive_list: StrV = settings.strv(KEY_INDEX_RECURSIVE_DIRECTORIES);
let mut single_list: StrV = settings.strv(KEY_INDEX_SINGLE_DIRECTORIES);
for search_path in search_paths {
// We want our search path to be included in the recursive directories or
// any directory within the list covers it.
ensure_path_in_recursive_indexing_scope(&mut recursive_list, search_path);
// We want our search path and its any descendants are not listed in
// the index directories list.
ensure_path_and_descendants_not_in_single_indexing_scope(&mut single_list, search_path);
}
settings
.set_strv(KEY_INDEX_RECURSIVE_DIRECTORIES, recursive_list)
.expect("key is not read-only");
settings
.set_strv(KEY_INDEX_SINGLE_DIRECTORIES, single_list)
.expect("key is not be read-only");
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_query_string_cleanup_basic() {
assert_eq!(query_string_cleanup("test"), "test");
assert_eq!(query_string_cleanup("hello world"), "hello world");
assert_eq!(query_string_cleanup("file.txt"), "file.txt");
}
#[test]
fn test_query_string_cleanup_unsupported_chars() {
assert_eq!(query_string_cleanup("test'file"), "test file");
assert_eq!(query_string_cleanup("test\nfile"), "test file");
assert_eq!(query_string_cleanup("test\\file"), "test file");
}
#[test]
fn test_query_string_cleanup_multiple_unsupported() {
assert_eq!(query_string_cleanup("test'file\nname"), "test file name");
assert_eq!(query_string_cleanup("test\'file"), "test file");
assert_eq!(query_string_cleanup("\n'test"), " test");
}
#[test]
fn test_query_string_cleanup_edge_cases() {
assert_eq!(query_string_cleanup(""), "");
assert_eq!(query_string_cleanup("'"), " ");
assert_eq!(query_string_cleanup("\n"), " ");
assert_eq!(query_string_cleanup("\\"), " ");
assert_eq!(query_string_cleanup(" '\n\\ "), " ");
}
#[test]
fn test_query_string_cleanup_mixed_content() {
assert_eq!(
query_string_cleanup("document's content\nwith\\backslash"),
"document s content with backslash"
);
assert_eq!(
query_string_cleanup("path/to'file\nextension\\test"),
"path/to file extension test"
);
}
}

View File

@@ -1,308 +0,0 @@
//! File search for KDE, powered by its Baloo engine.
use super::super::super::EXTENSION_ID;
use super::super::super::config::FileSearchConfig;
use super::super::super::config::SearchBy;
use super::super::should_be_filtered_out;
use crate::common::document::{DataSourceReference, Document};
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::extension::OnOpened;
use crate::util::file::sync_get_file_icon;
use camino::Utf8Path;
use configparser::ini::Ini;
use configparser::ini::WriteOptions;
use futures::stream::Stream;
use futures::stream::StreamExt;
use std::os::fd::OwnedFd;
use std::path::PathBuf;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio_stream::wrappers::LinesStream;
/// Baloo does not support scoring, use this score for all the documents.
const SCORE: f64 = 1.0;
/// KDE6 updates the binary name to "baloosearch6", but I believe there still have
/// distros using the original name. So we need to check both.
fn cli_tool_lookup() -> Option<PathBuf> {
use which::which;
let res_path = which("baloosearch").or_else(|_| which("baloosearch6"));
res_path.ok()
}
pub(crate) async fn hits(
query_string: &str,
_from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
// Special cases that will make querying faster.
if query_string.is_empty() || size == 0 || config.search_paths.is_empty() {
return Ok(Vec::new());
}
// If the tool is not found, return an empty result as well.
let Some(tool_path) = cli_tool_lookup() else {
return Ok(Vec::new());
};
let (mut iter, _baloosearch_child_process) =
execute_baloosearch_query(tool_path, query_string, size, config)?;
// Convert results to documents
let mut hits: Vec<(Document, f64)> = Vec::new();
while let Some(res_file_path) = iter.next().await {
let file_path = res_file_path.map_err(|io_err| io_err.to_string())?;
let icon = sync_get_file_icon(&file_path);
let file_path_of_type_path = camino::Utf8Path::new(&file_path);
let r#where = file_path_of_type_path
.parent()
.unwrap_or_else(|| {
panic!(
"expect path [{}] to have a parent, but it does not",
file_path
);
})
.to_string();
let file_name = file_path_of_type_path.file_name().unwrap_or_else(|| {
panic!(
"expect path [{}] to have a file name, but it does not",
file_path
);
});
let on_opened = OnOpened::Document {
url: file_path.clone(),
};
let doc = Document {
id: file_path.clone(),
title: Some(file_name.to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(EXTENSION_ID.into()),
id: Some(EXTENSION_ID.into()),
icon: Some(String::from("font_Filesearch")),
}),
category: Some(r#where),
on_opened: Some(on_opened),
url: Some(file_path),
icon: Some(icon.to_string()),
..Default::default()
};
hits.push((doc, SCORE));
}
Ok(hits)
}
/// Return an array containing the `baloosearch` command and its arguments.
fn build_baloosearch_query(
tool_path: PathBuf,
query_string: &str,
config: &FileSearchConfig,
) -> Vec<String> {
let tool_path = tool_path
.into_os_string()
.into_string()
.expect("binary path should be UTF-8 encoded");
let mut args = vec![tool_path];
match config.search_by {
SearchBy::Name => {
args.push(format!("filename:{query_string}"));
}
SearchBy::NameAndContents => {
args.push(query_string.to_string());
}
}
for search_path in config.search_paths.iter() {
args.extend_from_slice(&["-d".into(), search_path.clone()]);
}
args
}
/// Spawn the `baloosearch` child process and return an async iterator over its output,
/// allowing us to collect the results asynchronously.
///
/// # Return value:
///
/// * impl Stream: an async iterator that will yield the matched files
/// * Child: The handle to the baloosearch process. The child process will be
/// killed when this handle gets dropped so we need to keep it alive util we
/// exhaust the stream.
fn execute_baloosearch_query(
tool_path: PathBuf,
query_string: &str,
size: usize,
config: &FileSearchConfig,
) -> Result<(impl Stream<Item = std::io::Result<String>>, Child), String> {
let args = build_baloosearch_query(tool_path, query_string, config);
let (rx, tx) = std::io::pipe().unwrap();
let rx_owned = OwnedFd::from(rx);
let async_rx = tokio::net::unix::pipe::Receiver::from_owned_fd(rx_owned).unwrap();
let buffered_rx = BufReader::new(async_rx);
let lines = LinesStream::new(buffered_rx.lines());
let child = Command::new(&args[0])
.args(&args[1..])
.stdout(tx)
.stderr(std::process::Stdio::null())
// The child process will be killed when the Child instance gets dropped.
.kill_on_drop(true)
.spawn()
.map_err(|e| format!("Failed to spawn baloosearch: {e}"))?;
let config_clone = config.clone();
let iter = lines
.filter(move |res_path| {
std::future::ready({
match res_path {
Ok(path) => !should_be_filtered_out(&config_clone, path, false, true, true),
Err(_) => {
// Don't filter out Err() values
true
}
}
})
})
.take(size);
Ok((iter, child))
}
pub(crate) fn apply_config(config: &FileSearchConfig) -> Result<(), String> {
// Users can tweak Baloo via its configuration file, below are the fields that
// we need to modify:
//
// * Indexing-Enabled: turn indexing on or off
// * only basic indexing: If true, Baloo only indexes file names
// * folders: directories to index
// * exclude folders: directories to skip
//
// ```ini
// [Basic Settings]
// Indexing-Enabled=true
//
// [General]
// only basic indexing=true
// folders[$e]=$HOME/
// exclude folders[$e]=$HOME/FolderA/,$HOME/FolderB/
// ```
const SECTION_GENERAL: &str = "General";
const KEY_INCLUDE_FOLDERS: &str = "folders[$e]";
const KEY_EXCLUDE_FOLDERS: &str = "exclude folders[$e]";
const FOLDERS_SEPARATOR: &str = ",";
let rc_file_path = {
let mut home = dirs::home_dir()
.expect("cannot find the home directory, Coco should never run in such a environment");
home.push(".config/baloofilerc");
home
};
// Parse and load the rc file, it is in format INI
//
// Use `new_cs()`, the case-sensitive version of constructor as the config
// file contains uppercase letters, so it is case-sensitive.
let mut baloo_config = Ini::new_cs();
if rc_file_path.try_exists().map_err(|e| e.to_string())? {
let _ = baloo_config.load(rc_file_path.as_path())?;
}
// Ensure indexing is enabled
let _ = baloo_config.setstr("Basic Settings", "Indexing-Enabled", Some("true"));
// Let baloo index file content if we need that
if config.search_by == SearchBy::NameAndContents {
let _ = baloo_config.setstr(SECTION_GENERAL, "only basic indexing", Some("false"));
}
let mut include_folders = {
match baloo_config.get(SECTION_GENERAL, KEY_INCLUDE_FOLDERS) {
Some(str) => str
.split(FOLDERS_SEPARATOR)
.map(|str| str.to_string())
.collect::<Vec<String>>(),
None => Vec::new(),
}
};
let mut exclude_folders = {
match baloo_config.get(SECTION_GENERAL, KEY_EXCLUDE_FOLDERS) {
Some(str) => str
.split(FOLDERS_SEPARATOR)
.map(|str| str.to_string())
.collect::<Vec<String>>(),
None => Vec::new(),
}
};
fn ensure_path_included_include_folders(
include_folders: &mut Vec<String>,
search_path: &Utf8Path,
) {
for include_folder in include_folders.iter() {
let include_folder = Utf8Path::new(include_folder.as_str());
if search_path.starts_with(include_folder) {
return;
}
}
include_folders.push(search_path.as_str().to_string());
}
fn ensure_path_and_descendants_not_excluded(
exclude_folders: &mut Vec<String>,
search_path: &Utf8Path,
) {
let mut items_to_remove = Vec::new();
for (idx, exclude_folder) in exclude_folders.iter().enumerate() {
let exclude_folder = Utf8Path::new(exclude_folder);
if exclude_folder.starts_with(search_path) {
items_to_remove.push(idx);
}
}
for idx in items_to_remove.into_iter().rev() {
exclude_folders.remove(idx);
}
}
for search_path in config.search_paths.iter() {
let search_path = Utf8Path::new(search_path.as_str());
ensure_path_included_include_folders(&mut include_folders, search_path);
ensure_path_and_descendants_not_excluded(&mut exclude_folders, search_path);
}
let include_folders_str: String = include_folders.as_slice().join(FOLDERS_SEPARATOR);
let exclude_folders_str: String = exclude_folders.as_slice().join(FOLDERS_SEPARATOR);
let _ = baloo_config.set(
SECTION_GENERAL,
KEY_INCLUDE_FOLDERS,
Some(include_folders_str),
);
let _ = baloo_config.set(
SECTION_GENERAL,
KEY_EXCLUDE_FOLDERS,
Some(exclude_folders_str),
);
baloo_config
.pretty_write(rc_file_path.as_path(), &WriteOptions::new())
.map_err(|e| e.to_string())?;
Ok(())
}

View File

@@ -1,50 +0,0 @@
mod gnome;
mod kde;
use super::super::config::FileSearchConfig;
use crate::common::document::Document;
use crate::util::LinuxDesktopEnvironment;
use crate::util::get_linux_desktop_environment;
use std::ops::Deref;
use std::sync::LazyLock;
static DESKTOP_ENVIRONMENT: LazyLock<Option<LinuxDesktopEnvironment>> =
LazyLock::new(|| get_linux_desktop_environment());
/// Dispatch to implementations powered by different backends.
pub(crate) async fn hits(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
let de = DESKTOP_ENVIRONMENT.deref();
match de {
Some(LinuxDesktopEnvironment::Gnome) => gnome::hits(query_string, from, size, config).await,
Some(LinuxDesktopEnvironment::Kde) => kde::hits(query_string, from, size, config).await,
Some(LinuxDesktopEnvironment::Unsupported {
xdg_current_desktop: _,
}) => {
return Err("file search is not supported on this desktop environment".into());
}
None => {
return Err("could not determine Linux desktop environment".into());
}
}
}
pub(crate) fn apply_config(config: &FileSearchConfig) -> Result<(), String> {
let de = DESKTOP_ENVIRONMENT.deref();
match de {
Some(LinuxDesktopEnvironment::Gnome) => gnome::apply_config(config),
Some(LinuxDesktopEnvironment::Kde) => kde::apply_config(config),
Some(LinuxDesktopEnvironment::Unsupported {
xdg_current_desktop: _,
}) => {
return Err("file search is not supported on this desktop environment".into());
}
None => {
return Err("could not determine Linux desktop environment".into());
}
}
}

View File

@@ -1,190 +0,0 @@
use super::super::EXTENSION_ID;
use super::super::config::FileSearchConfig;
use super::super::config::SearchBy;
use super::should_be_filtered_out;
use crate::common::document::{DataSourceReference, Document};
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::extension::OnOpened;
use crate::util::file::sync_get_file_icon;
use futures::stream::Stream;
use futures::stream::StreamExt;
use std::os::fd::OwnedFd;
use std::path::Path;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio_stream::wrappers::LinesStream;
/// `mdfind` won't return scores, we use this score for all the documents.
const SCORE: f64 = 1.0;
pub(crate) async fn hits(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
let (mut iter, _mdfind_child_process) =
execute_mdfind_query(&query_string, from, size, &config)?;
// Convert results to documents
let mut hits: Vec<(Document, f64)> = Vec::new();
while let Some(res_file_path) = iter.next().await {
let file_path = res_file_path.map_err(|io_err| io_err.to_string())?;
let icon = sync_get_file_icon(&file_path);
let file_path_of_type_path = camino::Utf8Path::new(&file_path);
let r#where = file_path_of_type_path
.parent()
.unwrap_or_else(|| {
panic!(
"expect path [{}] to have a parent, but it does not",
file_path
);
})
.to_string();
let file_name = file_path_of_type_path.file_name().unwrap_or_else(|| {
panic!(
"expect path [{}] to have a file name, but it does not",
file_path
);
});
let on_opened = OnOpened::Document {
url: file_path.clone(),
};
let doc = Document {
id: file_path.clone(),
title: Some(file_name.to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(EXTENSION_ID.into()),
id: Some(EXTENSION_ID.into()),
icon: Some(String::from("font_Filesearch")),
}),
category: Some(r#where),
on_opened: Some(on_opened),
url: Some(file_path),
icon: Some(icon.to_string()),
..Default::default()
};
hits.push((doc, SCORE));
}
Ok(hits)
}
/// Return an array containing the `mdfind` command and its arguments.
fn build_mdfind_query(query_string: &str, config: &FileSearchConfig) -> Vec<String> {
let mut args = vec!["mdfind".to_string()];
match config.search_by {
SearchBy::Name => {
// The tailing char 'c' makes the search case-insensitive.
//
// According to [1], we should use this syntax "kMDItemFSName ==[c] '*{}*'",
// but it does not work on my machine (macOS 26 beta 7), and you
// can find similar complaints as well [2].
//
// [1]: https://developer.apple.com/library/archive/documentation/Carbon/Conceptual/SpotlightQuery/Concepts/QueryFormat.html
// [2]: https://apple.stackexchange.com/q/263671/394687
args.push(format!("kMDItemFSName == '*{}*'c", query_string));
}
SearchBy::NameAndContents => {
// Do not specify any File System Metadata Attribute Keys to search
// all of them, it is case-insensitive by default.
//
// Previously, we use:
//
// "kMDItemFSName == '*{}*' || kMDItemTextContent == '{}'"
//
// But the kMDItemTextContent attribute does not work as expected.
// For example, if a PDF document contains both "Waterloo" and
// "waterloo", it is only matched by "Waterloo".
args.push(query_string.to_string());
}
}
// Add search paths using -onlyin
for path in &config.search_paths {
if Path::new(path).exists() {
args.extend_from_slice(&["-onlyin".to_string(), path.to_string()]);
}
}
args
}
/// Spawn the `mdfind` child process and return an async iterator over its output,
/// allowing us to collect the results asynchronously.
///
/// # Return value:
///
/// * impl Stream: an async iterator that will yield the matched files
/// * Child: The handle to the mdfind process. The child process will be killed
/// when this handle gets dropped, we need to keep it alive until we exhaust
/// all the query results.
fn execute_mdfind_query(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<(impl Stream<Item = std::io::Result<String>>, Child), String> {
let args = build_mdfind_query(query_string, &config);
let (rx, tx) = std::io::pipe().unwrap();
let rx_owned = OwnedFd::from(rx);
let async_rx = tokio::net::unix::pipe::Receiver::from_owned_fd(rx_owned).unwrap();
let buffered_rx = BufReader::new(async_rx);
let lines = LinesStream::new(buffered_rx.lines());
let child = Command::new(&args[0])
.args(&args[1..])
.stdout(tx)
.stderr(std::process::Stdio::null())
.kill_on_drop(true)
.spawn()
.map_err(|e| format!("Failed to spawn mdfind: {}", e))?;
let config_clone = config.clone();
let iter = lines
.filter(move |res_path| {
std::future::ready({
match res_path {
Ok(path) => !should_be_filtered_out(&config_clone, path, false, true, true),
Err(_) => {
// Don't filter out Err() values
true
}
}
})
})
.skip(from)
.take(size);
Ok((iter, child))
}
pub(crate) fn apply_config(_: &FileSearchConfig) -> Result<(), String> {
// By default, macOS indexes all the files within a volume if indexing is
// enabled. So, to ensure our search paths are indexed by Spotlight,
// theoretically, we can do the following things:
//
// 1. Ensure indexing is enabled on the volumes where our search paths reside.
// However, we cannot do this as doing so requires `sudo`.
//
// 2. Ensure the search paths are not excluded from indexing scope. Users can
// stop Spotlight from indexing a directory by:
// 1. adding it to the "Privacy" list in 'System Settings'. Coco cannot
// modify this list, since the only way to change it is manually
// through System Settings.
// 2. Renaming directory name, adding a `.noindex` file extension to it.
// I don't want to use this trick, users won't feel comfortable and it
// could break at any time.
// 3. Creating a `.metadata_never_index` file within the directory (no longer works
// since macOS Mojave)
//
// There is nothing we can do.
Ok(())
}

View File

@@ -1,396 +0,0 @@
use cfg_if::cfg_if;
// * hits: the implementation of search
//
// * apply_config: Routines that should be performed to keep "other things"
// synchronous with the passed configuration.
// Currently, "other things" only include system indexer's setting entries.
cfg_if! {
if #[cfg(target_os = "linux")] {
mod linux;
pub(crate) use linux::hits;
pub(crate) use linux::apply_config;
} else if #[cfg(target_os = "macos")] {
mod macos;
pub(crate) use macos::hits;
pub(crate) use macos::apply_config;
} else if #[cfg(target_os = "windows")] {
mod windows;
pub(crate) use windows::hits;
pub(crate) use windows::apply_config;
}
}
cfg_if! {
if #[cfg(not(target_os = "windows"))] {
use super::config::FileSearchConfig;
use camino::Utf8Path;
}
}
/// If `file_path` should be removed from the search results given the filter
/// conditions specified in `config`.
#[cfg(not(target_os = "windows"))] // Not used on Windows
pub(crate) fn should_be_filtered_out(
config: &FileSearchConfig,
file_path: &str,
check_search_paths: bool,
check_exclude_paths: bool,
check_file_type: bool,
) -> bool {
let file_path = Utf8Path::new(file_path);
if check_search_paths {
// search path
let in_search_paths = config.search_paths.iter().any(|search_path| {
let search_path = Utf8Path::new(search_path);
file_path.starts_with(search_path)
});
if !in_search_paths {
return true;
}
}
if check_exclude_paths {
// exclude path
let is_excluded = config
.exclude_paths
.iter()
.any(|exclude_path| file_path.starts_with(exclude_path));
if is_excluded {
return true;
}
}
if check_file_type {
// file type
let matches_file_type = if config.file_types.is_empty() {
true
} else {
let path_obj = camino::Utf8Path::new(&file_path);
if let Some(extension) = path_obj.extension() {
config
.file_types
.iter()
.any(|file_type| file_type == extension)
} else {
// `config.file_types` is not empty, the hit files should have extensions.
false
}
};
if !matches_file_type {
return true;
}
}
false
}
// should_be_filtered_out() is not defined for Windows
#[cfg(all(test, not(target_os = "windows")))]
mod tests {
use super::super::config::SearchBy;
use super::*;
#[test]
fn test_should_be_filtered_out_with_no_check() {
let config = FileSearchConfig {
search_paths: vec!["/home/user/Documents".to_string()],
exclude_paths: vec![],
file_types: vec!["fffffff".into()],
search_by: SearchBy::Name,
};
assert!(!should_be_filtered_out(
&config, "abbc", false, false, false
));
}
#[test]
fn test_should_be_filtered_out_search_paths() {
let config = FileSearchConfig {
search_paths: vec![
"/home/user/Documents".to_string(),
"/home/user/Downloads".to_string(),
],
exclude_paths: vec![],
file_types: vec![],
search_by: SearchBy::Name,
};
// Files in search paths should not be filtered
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/file.txt",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Downloads/image.jpg",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/folder/file.txt",
true,
true,
true
));
// Files not in search paths should be filtered
assert!(should_be_filtered_out(
&config,
"/home/user/Pictures/photo.jpg",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/tmp/tempfile",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/usr/bin/ls",
true,
true,
true
));
}
#[test]
fn test_should_be_filtered_out_exclude_paths() {
let config = FileSearchConfig {
search_paths: vec!["/home/user".to_string()],
exclude_paths: vec![
"/home/user/Trash".to_string(),
"/home/user/.cache".to_string(),
],
file_types: vec![],
search_by: SearchBy::Name,
};
// Files in search paths but not excluded should not be filtered
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/file.txt",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Downloads/image.jpg",
true,
true,
true
));
// Files in excluded paths should be filtered
assert!(should_be_filtered_out(
&config,
"/home/user/Trash/deleted_file",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/home/user/.cache/temp",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/home/user/Trash/folder/file.txt",
true,
true,
true
));
}
#[test]
fn test_should_be_filtered_out_file_types() {
let config = FileSearchConfig {
search_paths: vec!["/home/user/Documents".to_string()],
exclude_paths: vec![],
file_types: vec!["txt".to_string(), "md".to_string()],
search_by: SearchBy::Name,
};
// Files with allowed extensions should not be filtered
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/notes.txt",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/readme.md",
true,
true,
true
));
// Files with disallowed extensions should be filtered
assert!(should_be_filtered_out(
&config,
"/home/user/Documents/image.jpg",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/home/user/Documents/document.pdf",
true,
true,
true
));
// Files without extensions should be filtered when file_types is not empty
assert!(should_be_filtered_out(
&config,
"/home/user/Documents/file",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/home/user/Documents/folder",
true,
true,
true
));
}
#[test]
fn test_should_be_filtered_out_empty_file_types() {
let config = FileSearchConfig {
search_paths: vec!["/home/user/Documents".to_string()],
exclude_paths: vec![],
file_types: vec![],
search_by: SearchBy::Name,
};
// When file_types is empty, all file types should be allowed
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/file.txt",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/image.jpg",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/document",
true,
true,
true
));
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/folder/",
true,
true,
true
));
}
#[test]
fn test_should_be_filtered_out_combined_filters() {
let config = FileSearchConfig {
search_paths: vec!["/home/user".to_string()],
exclude_paths: vec!["/home/user/Trash".to_string()],
file_types: vec!["txt".to_string()],
search_by: SearchBy::Name,
};
// Should pass all filters: in search path, not excluded, and correct file type
assert!(!should_be_filtered_out(
&config,
"/home/user/Documents/notes.txt",
true,
true,
true
));
// Fails file type filter
assert!(should_be_filtered_out(
&config,
"/home/user/Documents/image.jpg",
true,
true,
true
));
// Fails exclude path filter
assert!(should_be_filtered_out(
&config,
"/home/user/Trash/deleted.txt",
true,
true,
true
));
// Fails search path filter
assert!(should_be_filtered_out(
&config,
"/tmp/temp.txt",
true,
true,
true
));
}
#[test]
fn test_should_be_filtered_out_edge_cases() {
let config = FileSearchConfig {
search_paths: vec!["/home/user".to_string()],
exclude_paths: vec![],
file_types: vec!["txt".to_string()],
search_by: SearchBy::Name,
};
// Empty path
assert!(should_be_filtered_out(&config, "", true, true, true));
// Root path
assert!(should_be_filtered_out(&config, "/", true, true, true));
// Path that starts with search path but continues differently
assert!(!should_be_filtered_out(
&config,
"/home/user/document.txt",
true,
true,
true
));
assert!(should_be_filtered_out(
&config,
"/home/user_other/file.txt",
true,
true,
true
));
}
}

View File

@@ -1,234 +0,0 @@
//! Wraps Windows `ISearchCrawlScopeManager`
mod searchapi_h_bindings;
use searchapi_h_bindings::CLSID_CSEARCH_MANAGER;
use searchapi_h_bindings::IID_ISEARCH_MANAGER;
use searchapi_h_bindings::{
HRESULT, ISearchCatalogManager, ISearchCatalogManagerVtbl, ISearchCrawlScopeManager,
ISearchCrawlScopeManagerVtbl, ISearchManager,
};
use std::ffi::OsStr;
use std::ffi::OsString;
use std::os::windows::ffi::OsStrExt;
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null_mut;
use windows::core::w;
use windows_sys::Win32::Foundation::S_OK;
use windows_sys::Win32::System::Com::{
CLSCTX_LOCAL_SERVER, COINIT_APARTMENTTHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize,
};
#[derive(Debug, thiserror::Error)]
#[error("{msg}, function [{function}], HRESULT [{hresult}]")]
pub(crate) struct WindowSearchApiError {
function: &'static str,
hresult: HRESULT,
msg: String,
}
/// See doc of [`Rule`].
#[derive(Debug, PartialEq)]
pub(crate) enum RuleMode {
Inclusion,
Exclusion,
}
/// A rule adds or removes one or more paths to/from the Windows Search index.
#[derive(Debug)]
pub(crate) struct Rule {
/// A path or path pattern (wildcard supported, only for exclusion rule) that
/// specifies the paths that this rule applies to.
///
/// The rules used by Windows Search actually specify URLs rather than paths,
/// but we only care about paths, i.e., URLs with schema `file://`
pub(crate) paths: PathBuf,
/// Add or remove paths to/from the index.
pub(crate) mode: RuleMode,
}
/// A wrapper around Window's `ISearchCrawlScopeManager` type
pub(crate) struct CrawlScopeManager {
i_search_crawl_scope_manager: *mut ISearchCrawlScopeManager,
}
impl CrawlScopeManager {
fn vtable(&self) -> *mut ISearchCrawlScopeManagerVtbl {
unsafe { (*self.i_search_crawl_scope_manager).lpVtbl }
}
pub(crate) fn new() -> Result<Self, WindowSearchApiError> {
unsafe {
// 1. Initialize the COM library, use Apartment-threading as Self is not Send/Sync
let hr = CoInitializeEx(null_mut(), COINIT_APARTMENTTHREADED as u32);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "CoInitializeEx()",
hresult: hr,
msg: "failed to initialize the COM library".into(),
});
}
// 2. Create an instance of the CSearchManager.
let mut search_manager: *mut ISearchManager = null_mut();
let hr = CoCreateInstance(
&CLSID_CSEARCH_MANAGER, // CLSID of the object
null_mut(), // No outer unknown
CLSCTX_LOCAL_SERVER, // Server context
&IID_ISEARCH_MANAGER, // IID of the interface we want
&mut search_manager as *mut _ as *mut _, // Pointer to receive the interface
);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "CoCreateInstance()",
hresult: hr,
msg: "failed to initialize ISearchManager".into(),
});
}
assert!(!search_manager.is_null());
let search_manger_vtable = (*search_manager).lpVtbl;
let search_manager_fn_get_catalog = (*search_manger_vtable).GetCatalog.unwrap();
let mut search_catalog_manager: *mut ISearchCatalogManager = null_mut();
let string_literal_system_index = w!("SystemIndex");
let hr: HRESULT = search_manager_fn_get_catalog(
search_manager,
string_literal_system_index.0,
&mut search_catalog_manager as *mut *mut ISearchCatalogManager,
);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "ISearchManager::GetCatalog()",
hresult: hr,
msg: "failed to initialize ISearchCatalogManager".into(),
});
}
assert!(!search_catalog_manager.is_null());
let search_catalog_manager_vtable: *mut ISearchCatalogManagerVtbl =
(*search_catalog_manager).lpVtbl;
let fn_get_crawl_scope_manager = (*search_catalog_manager_vtable)
.GetCrawlScopeManager
.unwrap();
let mut search_crawl_scope_manager: *mut ISearchCrawlScopeManager = null_mut();
let hr =
fn_get_crawl_scope_manager(search_catalog_manager, &mut search_crawl_scope_manager);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "ISearchCatalogManager::GetCrawlScopeManager()",
hresult: hr,
msg: "failed to initialize ISearchCrawlScopeManager".into(),
});
}
assert!(!search_crawl_scope_manager.is_null());
Ok(Self {
i_search_crawl_scope_manager: search_crawl_scope_manager,
})
}
}
/// Does nothing unless you [`commit()`] the changes.
pub(crate) fn add_rule(&mut self, rule: Rule) -> Result<(), WindowSearchApiError> {
unsafe {
let vtable = self.vtable();
let fn_add_rule = (*vtable).AddUserScopeRule.unwrap();
let url: Vec<u16> = encode_path(&rule.paths);
let inclusion = (rule.mode == RuleMode::Inclusion) as i32;
let override_child_rules = true as i32;
let follow_flag = 0x1_u32; /* FF_INDEXCOMPLEXURLS */
let hr = fn_add_rule(
self.i_search_crawl_scope_manager,
url.as_ptr(),
inclusion,
override_child_rules,
follow_flag,
);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "ISearchCrawlScopeManager::AddUserScopeRule()",
hresult: hr,
msg: "failed to add scope rule".into(),
});
}
Ok(())
}
}
pub(crate) fn is_path_included<P: AsRef<Path> + ?Sized>(
&self,
path: &P,
) -> Result<bool, WindowSearchApiError> {
unsafe {
let vtable = self.vtable();
let fn_included_in_crawl_scope = (*vtable).IncludedInCrawlScope.unwrap();
let path: Vec<u16> = encode_path(path);
let mut included: i32 = 0 /* false */;
let hr = fn_included_in_crawl_scope(
self.i_search_crawl_scope_manager,
path.as_ptr(),
&mut included,
);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "ISearchCrawlScopeManager::IncludedInCrawlScope()",
hresult: hr,
msg: "failed to call IncludedInCrawlScope()".into(),
});
}
Ok(included == 1)
}
}
pub(crate) fn commit(&self) -> Result<(), WindowSearchApiError> {
unsafe {
let vtable = self.vtable();
let fn_commit = (*vtable).SaveAll.unwrap();
let hr = fn_commit(self.i_search_crawl_scope_manager);
if hr != S_OK {
return Err(WindowSearchApiError {
function: "ISearchCrawlScopeManager::SaveAll()",
hresult: hr,
msg: "failed to commit the changes".into(),
});
}
Ok(())
}
}
}
impl Drop for CrawlScopeManager {
fn drop(&mut self) {
unsafe {
CoUninitialize();
}
}
}
fn encode_path<P: AsRef<Path> + ?Sized>(path: &P) -> Vec<u16> {
let mut buffer = OsString::new();
// schema
buffer.push("file:///");
buffer.push(path.as_ref().as_os_str());
osstr_to_wstr(&buffer)
}
fn osstr_to_wstr<S: AsRef<OsStr> + ?Sized>(str: &S) -> Vec<u16> {
let os_str: &OsStr = str.as_ref();
let mut chars = os_str.encode_wide().collect::<Vec<u16>>();
chars.push(0 /* NUL */);
chars
}

View File

@@ -1,30 +0,0 @@
//! Rust binding of the types and functions declared in 'searchapi.h'
#![allow(unused)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]
#![allow(unsafe_op_in_unsafe_fn)]
#![allow(unnecessary_transmutes)]
include!(concat!(env!("OUT_DIR"), "/searchapi_bindings.rs"));
// The bindings.rs contains a GUID type as well, we use the one provided by
// the windows_sys crate here.
use windows_sys::core::GUID as WIN_SYS_GUID;
// https://github.com/search?q=CLSID_CSearchManager+language%3AC&type=code&l=C
pub(crate) static CLSID_CSEARCH_MANAGER: WIN_SYS_GUID = WIN_SYS_GUID {
data1: 0x7d096c5f,
data2: 0xac08,
data3: 0x4f1f,
data4: [0xbe, 0xb7, 0x5c, 0x22, 0xc5, 0x17, 0xce, 0x39],
};
// https://github.com/search?q=IID_ISearchManager+language%3AC&type=code
pub(crate) static IID_ISEARCH_MANAGER: WIN_SYS_GUID = WIN_SYS_GUID {
data1: 0xAB310581,
data2: 0xac80,
data3: 0x11d1,
data4: [0x8d, 0xf3, 0x00, 0xc0, 0x4f, 0xb6, 0xef, 0x69],
};

View File

@@ -1,834 +0,0 @@
//! # Credits
//!
//! https://github.com/IRONAGE-Park/rag-sample/blob/3f0ad8c8012026cd3a7e453d08f041609426cb91/src/native/windows.rs
//! is the starting point of this implementation.
mod crawl_scope_manager;
use super::super::EXTENSION_ID;
use super::super::config::FileSearchConfig;
use super::super::config::SearchBy;
use crate::common::document::{DataSourceReference, Document};
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::extension::OnOpened;
use crate::util::file::sync_get_file_icon;
use std::borrow::Borrow;
use std::path::PathBuf;
use windows::{
Win32::System::{
Com::{CLSCTX_INPROC_SERVER, CoCreateInstance},
Ole::{OleInitialize, OleUninitialize},
Search::{
DB_NULL_HCHAPTER, DBACCESSOR_ROWDATA, DBBINDING, DBMEMOWNER_CLIENTOWNED,
DBPARAMIO_NOTPARAM, DBPART_VALUE, DBTYPE_WSTR, HACCESSOR, IAccessor, ICommand,
ICommandText, IDBCreateCommand, IDBCreateSession, IDBInitialize, IDataInitialize,
IRowset, MSDAINITIALIZE,
},
},
core::{GUID, IUnknown, Interface, PWSTR, w},
};
/// Owned version of `PWSTR` that holds the heap memory.
///
/// Use `as_pwstr()` to convert it to a raw pointer.
struct PwStrOwned(Vec<u16>);
impl PwStrOwned {
/// # SAFETY
///
/// The returned `PWSTR` is basically a raw pointer, it is only valid within the
/// lifetime of `PwStrOwned`.
unsafe fn as_pwstr(&mut self) -> PWSTR {
let raw_ptr = self.0.as_mut_ptr();
PWSTR::from_raw(raw_ptr)
}
}
/// Construct `PwStrOwned` from any `str`.
impl<S: AsRef<str> + ?Sized> From<&S> for PwStrOwned {
fn from(value: &S) -> Self {
let mut utf16_bytes = value.as_ref().encode_utf16().collect::<Vec<u16>>();
utf16_bytes.push(0); // the tailing NULL
PwStrOwned(utf16_bytes)
}
}
/// Helper function to replace unsupported characters with whitespace.
///
/// Windows search will error out if it encounters these characters.
///
/// The complete list of unsupported characters is unknown and we don't know how
/// to escape them, so let's replace them.
fn query_string_cleanup(old: &str) -> String {
const UNSUPPORTED_CHAR: [char; 2] = ['\'', '\n'];
// Using len in bytes is ok
let mut chars = Vec::with_capacity(old.len());
for char in old.chars() {
if UNSUPPORTED_CHAR.contains(&char) {
chars.push(' ');
} else {
chars.push(char);
}
}
chars.into_iter().collect()
}
/// Helper function to construct the Windows Search SQL.
///
/// Paging is not natively supported by windows Search SQL, it only supports `size`
/// via the `TOP` keyword ("SELECT TOP {n} {columns}"). The SQL returned by this
/// function will have `{n}` set to `from + size`, then we will manually implement
/// paging.
fn query_sql(query_string: &str, from: usize, size: usize, config: &FileSearchConfig) -> String {
let top_n = from
.checked_add(size)
.expect("[from + size] cannot fit into an [usize]");
// System.ItemUrl is a column that contains the file path
// example: "file:C:/Users/desktop.ini"
//
// System.Search.Rank is the relevance score
let mut sql = format!(
"SELECT TOP {} System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE",
top_n
);
let query_string = query_string_cleanup(query_string);
let search_by_predicate = match config.search_by {
SearchBy::Name => {
// `contains(System.FileName, '{query_string}')` would be faster
// because it uses inverted index, but that's not what we want
// due to the limitation of tokenization. For example, suppose "Coco AI.rs"
// will be tokenized to `["Coco", "AI", "rs"]`, then if users search
// via `Co`, this file won't be returned because term `Co` does not
// exist in the index.
//
// So we use wildcard instead even though it is slower.
format!("(System.FileName LIKE '%{query_string}%')")
}
SearchBy::NameAndContents => {
// Windows File Search does not support searching by file content.
//
// `CONTAINS('query_string')` would search all columns for `query_string`,
// this is the closest solution we have.
format!("((System.FileName LIKE '%{query_string}%') OR CONTAINS('{query_string}'))")
}
};
let search_paths_predicate: Option<String> = {
if config.search_paths.is_empty() {
None
} else {
let mut output = String::from("(");
for (idx, search_path) in config.search_paths.iter().enumerate() {
if idx != 0 {
output.push_str(" OR ");
}
output.push_str("SCOPE = 'file:");
output.push_str(&search_path);
output.push('\'');
}
output.push(')');
Some(output)
}
};
let exclude_paths_predicate: Option<String> = {
if config.exclude_paths.is_empty() {
None
} else {
let mut output = String::from("(");
for (idx, exclude_path) in config.exclude_paths.iter().enumerate() {
if idx != 0 {
output.push_str(" AND ");
}
output.push_str("(NOT SCOPE = 'file:");
output.push_str(&exclude_path);
output.push('\'');
output.push(')');
}
output.push(')');
Some(output)
}
};
let file_types_predicate: Option<String> = {
if config.file_types.is_empty() {
None
} else {
let mut output = String::from("(");
for (idx, file_type) in config.file_types.iter().enumerate() {
if idx != 0 {
output.push_str(" OR ");
}
// NOTE that this column contains a starting dot
output.push_str("System.FileExtension = '.");
output.push_str(&file_type);
output.push('\'');
}
output.push(')');
Some(output)
}
};
sql.push(' ');
sql.push_str(search_by_predicate.as_str());
if let Some(search_paths_predicate) = search_paths_predicate {
sql.push_str(" AND ");
sql.push_str(search_paths_predicate.as_str());
}
if let Some(exclude_paths_predicate) = exclude_paths_predicate {
sql.push_str(" AND ");
sql.push_str(exclude_paths_predicate.as_str());
}
if let Some(file_types_predicate) = file_types_predicate {
sql.push_str(" AND ");
sql.push_str(file_types_predicate.as_str());
}
sql
}
/// Default GUID for Search.CollatorDSO.1
const DBGUID_DEFAULT: GUID = GUID {
data1: 0xc8b521fb,
data2: 0x5cf3,
data3: 0x11ce,
data4: [0xad, 0xe5, 0x00, 0xaa, 0x00, 0x44, 0x77, 0x3d],
};
unsafe fn create_accessor_handle(accessor: &IAccessor, index: usize) -> Result<HACCESSOR, String> {
let bindings = DBBINDING {
iOrdinal: index,
obValue: 0,
obStatus: 0,
obLength: 0,
dwPart: DBPART_VALUE.0 as u32,
dwMemOwner: DBMEMOWNER_CLIENTOWNED.0 as u32,
eParamIO: DBPARAMIO_NOTPARAM.0 as u32,
cbMaxLen: 512,
dwFlags: 0,
wType: DBTYPE_WSTR.0 as u16,
bPrecision: 0,
bScale: 0,
..Default::default()
};
let mut status = 0;
let mut accessor_handle = HACCESSOR::default();
unsafe {
accessor
.CreateAccessor(
DBACCESSOR_ROWDATA.0 as u32,
1,
&bindings,
0,
&mut accessor_handle,
Some(&mut status),
)
.map_err(|e| e.to_string())?;
}
Ok(accessor_handle)
}
fn create_db_initialize() -> Result<IDBInitialize, String> {
unsafe {
let data_init: IDataInitialize =
CoCreateInstance(&MSDAINITIALIZE, None, CLSCTX_INPROC_SERVER)
.map_err(|e| e.to_string())?;
let mut unknown: Option<IUnknown> = None;
data_init
.GetDataSource(
None,
CLSCTX_INPROC_SERVER.0,
w!("provider=Search.CollatorDSO.1;EXTENDED PROPERTIES=\"Application=Windows\""),
&IDBInitialize::IID,
&mut unknown as *mut _ as *mut _,
)
.map_err(|e| e.to_string())?;
Ok(unknown.unwrap().cast().map_err(|e| e.to_string())?)
}
}
fn create_command(db_init: IDBInitialize) -> Result<ICommandText, String> {
unsafe {
let db_create_session: IDBCreateSession = db_init.cast().map_err(|e| e.to_string())?;
let session: IUnknown = db_create_session
.CreateSession(None, &IUnknown::IID)
.map_err(|e| e.to_string())?;
let db_create_command: IDBCreateCommand = session.cast().map_err(|e| e.to_string())?;
Ok(db_create_command
.CreateCommand(None, &ICommand::IID)
.map_err(|e| e.to_string())?
.cast()
.map_err(|e| e.to_string())?)
}
}
fn execute_windows_search_sql(sql_query: &str) -> Result<Vec<(String, String)>, String> {
unsafe {
let mut pwstr_owned_sql = PwStrOwned::from(sql_query);
// SAFETY: pwstr_owned_sql will live for the whole lifetime of this function.
let sql_query = pwstr_owned_sql.as_pwstr();
let db_init = create_db_initialize()?;
db_init.Initialize().map_err(|e| e.to_string())?;
let command = create_command(db_init)?;
// Set the command text
command
.SetCommandText(&DBGUID_DEFAULT, sql_query)
.map_err(|e| e.to_string())?;
// Execute the command
let mut rowset: Option<IRowset> = None;
command
.Execute(
None,
&IRowset::IID,
None,
None,
Some(&mut rowset as *mut _ as *mut _),
)
.map_err(|e| e.to_string())?;
let rowset = rowset.ok_or_else(|| {
format!(
"No rowset returned for query: {}",
// SAFETY: the raw pointer is not dangling
sql_query
.to_string()
.expect("the conversion should work as `sql_query` was created from a String",)
)
})?;
let accessor: IAccessor = rowset
.cast()
.map_err(|e| format!("Failed to cast to IAccessor: {}", e.to_string()))?;
let mut output = Vec::new();
let mut count = 0;
loop {
let mut rows_fetched = 0;
let mut row_handles = [std::ptr::null_mut(); 1];
let result = rowset.GetNextRows(
DB_NULL_HCHAPTER as usize,
0,
&mut rows_fetched,
&mut row_handles,
);
if result.is_err() {
break;
}
if rows_fetched == 0 {
break;
}
let mut data = Vec::new();
for i in 0..2 {
let mut item_name = [0u16; 512];
let accessor_handle = create_accessor_handle(&accessor, i + 1)?;
rowset
.GetData(
*row_handles[0],
accessor_handle,
item_name.as_mut_ptr() as *mut _,
)
.map_err(|e| {
format!(
"Failed to get data at count {}, index {}: {}",
count,
i,
e.to_string()
)
})?;
let name = String::from_utf16_lossy(&item_name);
// Remove null characters
data.push(name.trim_end_matches('\u{0000}').to_string());
accessor
.ReleaseAccessor(accessor_handle, None)
.map_err(|e| {
format!(
"Failed to release accessor at count {}, index {}: {}",
count,
i,
e.to_string()
)
})?;
}
output.push((data[0].clone(), data[1].clone()));
count += 1;
rowset
.ReleaseRows(
1,
row_handles[0],
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
.map_err(|e| {
format!(
"Failed to release rows at count {}: {}",
count,
e.to_string()
)
})?;
}
Ok(output)
}
}
pub(crate) async fn hits(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
let sql = query_sql(query_string, from, size, config);
unsafe { OleInitialize(None).map_err(|e| e.to_string())? };
let result = execute_windows_search_sql(&sql)?;
unsafe { OleUninitialize() };
// .take(size) is not needed as `result` will contain `from+size` files at most
let result_with_paging = result.into_iter().skip(from);
// result_with_paging won't contain more than `size` entries
let mut hits = Vec::with_capacity(size);
const ITEM_URL_PREFIX: &str = "file:";
const ITEM_URL_PREFIX_LEN: usize = ITEM_URL_PREFIX.len();
for (item_url, score_str) in result_with_paging {
// path returned from Windows Search contains a prefix, we need to trim it.
//
// "file:C:/Users/desktop.ini" => "C:/Users/desktop.ini"
let file_path = &item_url[ITEM_URL_PREFIX_LEN..];
let icon = sync_get_file_icon(file_path);
let file_path_of_type_path = camino::Utf8Path::new(&file_path);
let r#where = file_path_of_type_path
.parent()
.unwrap_or_else(|| {
panic!(
"expect path [{}] to have a parent, but it does not",
file_path
);
})
.to_string();
let file_name = file_path_of_type_path.file_name().unwrap_or_else(|| {
panic!(
"expect path [{}] to have a file name, but it does not",
file_path
);
});
let on_opened = OnOpened::Document {
url: file_path.to_string(),
};
let doc = Document {
id: file_path.to_string(),
title: Some(file_name.to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(EXTENSION_ID.into()),
id: Some(EXTENSION_ID.into()),
icon: Some(String::from("font_Filesearch")),
}),
category: Some(r#where),
on_opened: Some(on_opened),
url: Some(file_path.into()),
icon: Some(icon.to_string()),
..Default::default()
};
let score: f64 = score_str.parse().expect(
"System.Search.Rank should be in range [0, 1000], which should be valid for [f64]",
);
hits.push((doc, score));
}
Ok(hits)
}
pub(crate) fn apply_config(config: &FileSearchConfig) -> Result<(), String> {
// To ensure Windows Search indexer index the paths we specified in the
// config, we will:
//
// 1. Add an inclusion rule for every search path to ensure indexer index
// them
// 2. For the exclude paths, we exclude them from the crawl scope if they
// were not included in the scope before we update the scope. Otherwise,
// we cannot exclude them as doing that could potentially break other
// apps (by removing the indexes they rely on).
//
// Windows APIs are pretty smart. They won't blindly add an inclusion rule if
// the path you are trying to include is already included. The same applies
// to exclusion rules as well. Since Windows APIs handle these checks for us,
// we don't need to worry about them.
use crawl_scope_manager::CrawlScopeManager;
use crawl_scope_manager::Rule;
use crawl_scope_manager::RuleMode;
use std::borrow::Cow;
/// Windows APIs need the path to contain a tailing '\'
fn add_tailing_backslash(path: &str) -> Cow<'_, str> {
if path.ends_with(r#"\"#) {
Cow::Borrowed(path)
} else {
let mut owned = path.to_string();
owned.push_str(r#"\"#);
Cow::Owned(owned)
}
}
let mut manager = CrawlScopeManager::new().map_err(|e| e.to_string())?;
let search_paths = &config.search_paths;
let exclude_paths = &config.exclude_paths;
// indexes to `exclude_paths` of the paths we need to exclude
let mut paths_to_exclude: Vec<usize> = Vec::new();
for (idx, exclude_path) in exclude_paths.into_iter().enumerate() {
let exclude_path = add_tailing_backslash(&exclude_path);
let exclude_path: &str = exclude_path.borrow();
if !manager
.is_path_included(exclude_path)
.map_err(|e| e.to_string())?
{
paths_to_exclude.push(idx);
}
}
for search_path in search_paths {
let inclusion_rule = Rule {
paths: PathBuf::from(add_tailing_backslash(&search_path).into_owned()),
mode: RuleMode::Inclusion,
};
manager
.add_rule(inclusion_rule)
.map_err(|e| e.to_string())?;
}
for idx in paths_to_exclude {
let exclusion_rule = Rule {
paths: PathBuf::from(add_tailing_backslash(&exclude_paths[idx]).into_owned()),
mode: RuleMode::Exclusion,
};
manager
.add_rule(exclusion_rule)
.map_err(|e| e.to_string())?;
}
manager.commit().map_err(|e| e.to_string())?;
Ok(())
}
// Skip these tests in our CI, they fail with the following error
// "SQL is invalid: "0x80041820""
//
// I have no idea about the underlying root cause
#[cfg(all(test, not(ci)))]
mod test_windows_search {
use super::*;
/// Helper function for ensuring `sql` is valid SQL by actually executing it.
fn ensure_it_is_valid_sql(sql: &str) {
unsafe { OleInitialize(None).unwrap() };
execute_windows_search_sql(&sql).expect("SQL is invalid");
unsafe { OleUninitialize() };
}
#[test]
fn test_query_sql_empty_config_search_by_name() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("coco", 0, 10, &config);
assert_eq!(
sql,
"SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%coco%')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_empty_config_search_by_name_and_content() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::NameAndContents,
};
let sql = query_sql("coco", 0, 10, &config);
assert_eq!(
sql,
"SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE ((System.FileName LIKE '%coco%') OR CONTAINS('coco'))"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_search_paths() {
let config = FileSearchConfig {
search_paths: vec!["C:/Users/".into()],
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("coco", 0, 10, &config);
assert_eq!(
sql,
"SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%coco%') AND (SCOPE = 'file:C:/Users/')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_multiple_search_paths() {
let config = FileSearchConfig {
search_paths: vec![
"C:/Users/".into(),
"D:/Projects/".into(),
"E:/Documents/".into(),
],
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("test", 0, 5, &config);
assert_eq!(
sql,
"SELECT TOP 5 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%test%') AND (SCOPE = 'file:C:/Users/' OR SCOPE = 'file:D:/Projects/' OR SCOPE = 'file:E:/Documents/')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_exclude_paths() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: vec!["C:/Windows/".into()],
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("file", 0, 20, &config);
assert_eq!(
sql,
"SELECT TOP 20 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%file%') AND ((NOT SCOPE = 'file:C:/Windows/'))"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_multiple_exclude_paths() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: vec!["C:/Windows/".into(), "C:/System/".into(), "C:/Temp/".into()],
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("data", 5, 15, &config);
assert_eq!(
sql,
"SELECT TOP 20 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%data%') AND ((NOT SCOPE = 'file:C:/Windows/') AND (NOT SCOPE = 'file:C:/System/') AND (NOT SCOPE = 'file:C:/Temp/'))"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_file_types() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: vec!["txt".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("readme", 0, 10, &config);
assert_eq!(
sql,
"SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%readme%') AND (System.FileExtension = '.txt')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_multiple_file_types() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: vec!["rs".into(), "toml".into(), "md".into(), "json".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("config", 0, 50, &config);
assert_eq!(
sql,
"SELECT TOP 50 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%config%') AND (System.FileExtension = '.rs' OR System.FileExtension = '.toml' OR System.FileExtension = '.md' OR System.FileExtension = '.json')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_all_fields_combined() {
let config = FileSearchConfig {
search_paths: vec!["C:/Projects/".into(), "D:/Code/".into()],
exclude_paths: vec!["C:/Projects/temp/".into()],
file_types: vec!["rs".into(), "ts".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("main", 10, 25, &config);
assert_eq!(
sql,
"SELECT TOP 35 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%main%') AND (SCOPE = 'file:C:/Projects/' OR SCOPE = 'file:D:/Code/') AND ((NOT SCOPE = 'file:C:/Projects/temp/')) AND (System.FileExtension = '.rs' OR System.FileExtension = '.ts')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_special_characters() {
let config = FileSearchConfig {
search_paths: vec!["C:/Users/John Doe/".into()],
exclude_paths: Vec::new(),
file_types: vec!["c++".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("hello-world", 0, 10, &config);
assert_eq!(
sql,
"SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%hello-world%') AND (SCOPE = 'file:C:/Users/John Doe/') AND (System.FileExtension = '.c++')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_edge_case_large_offset() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("test", 100, 50, &config);
assert_eq!(
sql,
"SELECT TOP 150 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%test%')"
);
ensure_it_is_valid_sql(&sql);
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_query_string_cleanup_no_unsupported_chars() {
let input = "hello world";
let result = query_string_cleanup(input);
assert_eq!(result, input);
}
#[test]
fn test_query_string_cleanup_single_quote() {
let input = "don't worry";
let result = query_string_cleanup(input);
assert_eq!(result, "don t worry");
}
#[test]
fn test_query_string_cleanup_newline() {
let input = "line1\nline2";
let result = query_string_cleanup(input);
assert_eq!(result, "line1 line2");
}
#[test]
fn test_query_string_cleanup_both_unsupported_chars() {
let input = "don't\nworry";
let result = query_string_cleanup(input);
assert_eq!(result, "don t worry");
}
#[test]
fn test_query_string_cleanup_multiple_single_quotes() {
let input = "it's a 'test' string";
let result = query_string_cleanup(input);
assert_eq!(result, "it s a test string");
}
#[test]
fn test_query_string_cleanup_multiple_newlines() {
let input = "line1\n\nline2\nline3";
let result = query_string_cleanup(input);
assert_eq!(result, "line1 line2 line3");
}
#[test]
fn test_query_string_cleanup_empty_string() {
let input = "";
let result = query_string_cleanup(input);
assert_eq!(result, input);
}
#[test]
fn test_query_string_cleanup_only_unsupported_chars() {
let input = "'\n'";
let result = query_string_cleanup(input);
assert_eq!(result, " ");
}
#[test]
fn test_query_string_cleanup_unicode_characters() {
let input = "héllo wörld's\nfile";
let result = query_string_cleanup(input);
assert_eq!(result, "héllo wörld s file");
}
#[test]
fn test_query_string_cleanup_special_chars_preserved() {
let input = "test@file#name$with%symbols";
let result = query_string_cleanup(input);
assert_eq!(result, input);
}
}

View File

@@ -1,97 +0,0 @@
pub(crate) mod config;
pub(crate) mod implementation;
use super::super::LOCAL_QUERY_SOURCE_TYPE;
use crate::common::{
error::SearchError,
search::{QueryResponse, QuerySource, SearchQuery},
traits::SearchSource,
};
use async_trait::async_trait;
use config::FileSearchConfig;
use hostname;
use tauri::AppHandle;
pub(crate) const EXTENSION_ID: &str = "File Search";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "File Search",
"name": "File Search",
"platforms": ["macos", "windows", "linux"],
"description": "Search files on your system",
"icon": "font_Filesearch",
"type": "extension"
}
"#;
pub struct FileSearchExtensionSearchSource;
#[async_trait]
impl SearchSource for FileSearchExtensionSearchSource {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(EXTENSION_ID.into())
.to_string_lossy()
.into(),
id: EXTENSION_ID.into(),
}
}
async fn search(
&self,
tauri_app_handle: AppHandle,
query: SearchQuery,
) -> Result<QueryResponse, SearchError> {
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
let from = usize::try_from(query.from).expect("from too big");
let size = usize::try_from(query.size).expect("size too big");
let query_string = query_string.trim();
if query_string.is_empty() {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
// Get configuration from tauri store
let config = FileSearchConfig::get(&tauri_app_handle);
// If search paths are empty, then the hit should be empty.
//
// Without this, empty search paths will result in a mdfind that has no `-onlyin`
// option, which will in turn query the whole disk volume.
if config.search_paths.is_empty() {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
// Execute search in a blocking task
let query_source = self.get_type();
let hits = implementation::hits(&query_string, from, size, &config)
.await
.map_err(SearchError::InternalError)?;
let total_hits = hits.len();
Ok(QueryResponse {
source: query_source,
hits,
total_hits,
})
}
}

View File

@@ -1,724 +0,0 @@
//! Built-in extensions and related stuff.
pub mod ai_overview;
pub mod application;
pub mod calculator;
pub mod file_search;
pub mod pizza_engine_runtime;
pub mod quick_ai_access;
#[cfg(target_os = "macos")]
pub mod window_management;
use super::Extension;
use crate::SearchSourceRegistry;
use crate::extension::built_in::application::{set_apps_hotkey, unset_apps_hotkey};
use crate::extension::{
ExtensionBundleIdBorrowed, PLUGIN_JSON_FILE_NAME, alter_extension_json_file,
};
use anyhow::Context;
use file_search::config::FileSearchConfig;
use file_search::implementation::apply_config as file_search_apply_config;
use std::path::{Path, PathBuf};
use tauri::{AppHandle, Manager};
pub(crate) fn get_built_in_extension_directory(tauri_app_handle: &AppHandle) -> PathBuf {
let mut resource_dir = tauri_app_handle.path().app_data_dir().expect(
"User home directory not found, which should be impossible on desktop environments",
);
resource_dir.push("built_in_extensions");
resource_dir
}
/// Helper function to load the built-in extension specified by `extension_id`, used
/// in `list_built_in_extensions()`.
///
/// For built-in extensions, users are only allowed to edit these fields:
///
/// 1. alias (if this extension supports alias)
/// 2. hotkey (if this extension supports hotkey)
/// 3. enabled
///
/// If
///
/// 1. The above fields have invalid value
/// 2. Other fields are modified
///
/// we ignore and reset them to the default value.
async fn load_built_in_extension(
built_in_extensions_dir: &Path,
extension_id: &str,
default_plugin_json_file: &str,
) -> Result<Extension, String> {
let mut extension_dir = built_in_extensions_dir.join(extension_id);
let mut default_plugin_json = serde_json::from_str::<Extension>(&default_plugin_json_file).unwrap_or_else( |e| {
panic!("the default extension {} file of built-in extension [{}] cannot be parsed as a valid [struct Extension], error [{}]", PLUGIN_JSON_FILE_NAME, extension_id, e);
});
if !extension_dir.try_exists().map_err(|e| e.to_string())? {
tokio::fs::create_dir_all(extension_dir.as_path())
.await
.map_err(|e| e.to_string())?;
}
let plugin_json_file_path = {
extension_dir.push(PLUGIN_JSON_FILE_NAME);
extension_dir
};
// If the JSON file does not exist, create a file with the default template and return.
if !plugin_json_file_path
.try_exists()
.map_err(|e| e.to_string())?
{
tokio::fs::write(plugin_json_file_path, default_plugin_json_file)
.await
.map_err(|e| e.to_string())?;
return Ok(default_plugin_json);
}
let plugin_json_file_content = tokio::fs::read_to_string(plugin_json_file_path.as_path())
.await
.map_err(|e| e.to_string())?;
let res_plugin_json = serde_json::from_str::<Extension>(&plugin_json_file_content);
let Ok(plugin_json) = res_plugin_json else {
log::warn!(
"user invalidated built-in extension [{}] file, overwriting it with the default template",
extension_id
);
// If the JSON file cannot be parsed as `struct Extension`, overwrite it with the default template and return.
tokio::fs::write(plugin_json_file_path, default_plugin_json_file)
.await
.map_err(|e| e.to_string())?;
return Ok(default_plugin_json);
};
// Users are only allowed to edit the below fields
// 1. alias (if this extension supports alias)
// 2. hotkey (if this extension supports hotkey)
// 3. enabled
// so we ignore all other fields.
let alias = if default_plugin_json.supports_alias_hotkey() {
plugin_json.alias.clone()
} else {
None
};
let hotkey = if default_plugin_json.supports_alias_hotkey() {
plugin_json.hotkey.clone()
} else {
None
};
let enabled = plugin_json.enabled;
default_plugin_json.alias = alias;
default_plugin_json.hotkey = hotkey;
default_plugin_json.enabled = enabled;
let final_plugin_json_file_content = serde_json::to_string_pretty(&default_plugin_json)
.expect("failed to serialize `struct Extension`");
tokio::fs::write(plugin_json_file_path, final_plugin_json_file_content)
.await
.map_err(|e| e.to_string())?;
Ok(default_plugin_json)
}
/// Return the built-in extension list.
///
/// Will create extension files when they are not found.
///
/// Users may put extension files in the built-in extension directory, but
/// we do not care and will ignore them.
///
/// We only read alias/hotkey/enabled from the JSON file, we have ensured that if
/// alias/hotkey is not supported, then it will be `None`. Besides that, no further
/// validation is needed because nothing could go wrong.
pub(crate) async fn list_built_in_extensions(
tauri_app_handle: &AppHandle,
) -> Result<Vec<Extension>, String> {
let dir = get_built_in_extension_directory(tauri_app_handle);
let mut built_in_extensions = Vec::new();
built_in_extensions.push(
load_built_in_extension(
&dir,
application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME,
application::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
&dir,
calculator::DATA_SOURCE_ID,
calculator::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
&dir,
ai_overview::EXTENSION_ID,
ai_overview::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
&dir,
quick_ai_access::EXTENSION_ID,
quick_ai_access::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
&dir,
file_search::EXTENSION_ID,
file_search::PLUGIN_JSON_FILE,
)
.await?,
);
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
built_in_extensions.push(
load_built_in_extension(
&dir,
window_management::EXTENSION_ID,
window_management::PLUGIN_JSON_FILE,
)
.await?,
);
}
}
Ok(built_in_extensions)
}
pub(super) async fn init_built_in_extension(
tauri_app_handle: &AppHandle,
extension: &Extension,
search_source_registry: &SearchSourceRegistry,
) -> Result<(), String> {
log::trace!("initializing built-in extensions [{}]", extension.id);
if extension.id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
search_source_registry
.register_source(application::ApplicationSearchSource)
.await;
set_apps_hotkey(&tauri_app_handle)?;
log::debug!("built-in extension [{}] initialized", extension.id);
}
if extension.id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
search_source_registry
.register_source(calculator_search)
.await;
log::debug!("built-in extension [{}] initialized", extension.id);
}
if extension.id == file_search::EXTENSION_ID {
let file_system_search = file_search::FileSearchExtensionSearchSource;
search_source_registry
.register_source(file_system_search)
.await;
let file_search_config = FileSearchConfig::get(tauri_app_handle);
file_search_apply_config(&file_search_config)?;
log::debug!("built-in extension [{}] initialized", extension.id);
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
if extension.id == window_management::EXTENSION_ID {
let file_system_search = window_management::search_source::WindowManagementSearchSource;
search_source_registry
.register_source(file_system_search)
.await;
window_management::set_up_commands_hotkeys(tauri_app_handle, extension)?;
log::debug!("built-in extension [{}] initialized", extension.id);
}
}
}
Ok(())
}
pub(crate) fn is_extension_built_in(bundle_id: &ExtensionBundleIdBorrowed<'_>) -> bool {
bundle_id.developer.is_none()
}
pub(crate) async fn enable_built_in_extension(
tauri_app_handle: &AppHandle,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
let update_extension = |extension: &mut Extension| -> Result<(), String> {
extension.enabled = true;
Ok(())
};
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_none()
{
search_source_registry_tauri_state
.register_source(application::ApplicationSearchSource)
.await;
set_apps_hotkey(tauri_app_handle)?;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
// Check if this is an application
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_some()
{
let app_path = bundle_id.sub_extension_id.expect("just checked it is Some");
application::enable_app_search(tauri_app_handle, app_path)?;
return Ok(());
}
if bundle_id.extension_id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
search_source_registry_tauri_state
.register_source(calculator_search)
.await;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == quick_ai_access::EXTENSION_ID {
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == ai_overview::EXTENSION_ID {
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == file_search::EXTENSION_ID {
let file_system_search = file_search::FileSearchExtensionSearchSource;
search_source_registry_tauri_state
.register_source(file_system_search)
.await;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
let file_search_config = FileSearchConfig::get(tauri_app_handle);
file_search_apply_config(&file_search_config)?;
return Ok(());
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
if bundle_id.extension_id == window_management::EXTENSION_ID
&& bundle_id.sub_extension_id.is_none()
{
let built_in_extension_dir = get_built_in_extension_directory(tauri_app_handle);
let file_system_search = window_management::search_source::WindowManagementSearchSource;
search_source_registry_tauri_state
.register_source(file_system_search)
.await;
let extension =
load_extension_from_json_file(&built_in_extension_dir, bundle_id.extension_id)?;
window_management::set_up_commands_hotkeys(tauri_app_handle, &extension)?;
alter_extension_json_file(&built_in_extension_dir, bundle_id, update_extension)?;
return Ok(());
}
if bundle_id.extension_id == window_management::EXTENSION_ID {
if let Some(command_id) = bundle_id.sub_extension_id {
let built_in_extension_dir = get_built_in_extension_directory(tauri_app_handle);
alter_extension_json_file(&built_in_extension_dir, bundle_id, update_extension)?;
let extension =
load_extension_from_json_file(&built_in_extension_dir, bundle_id.extension_id)?;
window_management::set_up_command_hotkey(tauri_app_handle, &extension, command_id)?;
}
}
}
}
Ok(())
}
pub(crate) async fn disable_built_in_extension(
tauri_app_handle: &AppHandle,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
let update_extension = |extension: &mut Extension| -> Result<(), String> {
extension.enabled = false;
Ok(())
};
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_none()
{
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
unset_apps_hotkey(tauri_app_handle)?;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
// Check if this is an application
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_some()
{
let app_path = bundle_id.sub_extension_id.expect("just checked it is Some");
application::disable_app_search(tauri_app_handle, app_path)?;
return Ok(());
}
if bundle_id.extension_id == calculator::DATA_SOURCE_ID {
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == quick_ai_access::EXTENSION_ID {
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == ai_overview::EXTENSION_ID {
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == file_search::EXTENSION_ID {
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_extension,
)?;
return Ok(());
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
if bundle_id.extension_id == window_management::EXTENSION_ID
&& bundle_id.sub_extension_id.is_none()
{
let built_in_extension_dir = get_built_in_extension_directory(tauri_app_handle);
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
alter_extension_json_file(&built_in_extension_dir, bundle_id, update_extension)?;
let extension =
load_extension_from_json_file(&built_in_extension_dir, bundle_id.extension_id)?;
window_management::unset_commands_hotkeys(tauri_app_handle, &extension)?;
}
if bundle_id.extension_id == window_management::EXTENSION_ID {
if let Some(command_id) = bundle_id.sub_extension_id {
let built_in_extension_dir = get_built_in_extension_directory(tauri_app_handle);
alter_extension_json_file(&built_in_extension_dir, bundle_id, update_extension)?;
let extension =
load_extension_from_json_file(&built_in_extension_dir, bundle_id.extension_id)?;
window_management::unset_command_hotkey(tauri_app_handle, &extension, command_id)?;
}
}
}
}
Ok(())
}
pub(crate) fn set_built_in_extension_alias(
tauri_app_handle: &AppHandle,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
alias: &str,
) -> Result<(), String> {
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
application::set_app_alias(tauri_app_handle, app_path, alias);
}
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
if bundle_id.extension_id == window_management::EXTENSION_ID
&& bundle_id.sub_extension_id.is_some()
{
let update_function = |ext: &mut Extension| {
ext.alias = Some(alias.to_string());
Ok(())
};
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_function,
)?;
}
}
}
Ok(())
}
pub(crate) fn register_built_in_extension_hotkey(
tauri_app_handle: &AppHandle,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
hotkey: &str,
) -> Result<(), String> {
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
application::register_app_hotkey(&tauri_app_handle, app_path, hotkey)?;
}
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
let update_function = |ext: &mut Extension| {
ext.hotkey = Some(hotkey.into());
Ok(())
};
if bundle_id.extension_id == window_management::EXTENSION_ID {
if let Some(command_id) = bundle_id.sub_extension_id {
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_function,
)?;
window_management::register_command_hotkey(tauri_app_handle, command_id, hotkey)?;
}
}
}
}
Ok(())
}
pub(crate) fn unregister_built_in_extension_hotkey(
tauri_app_handle: &AppHandle,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
application::unregister_app_hotkey(&tauri_app_handle, app_path)?;
}
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
let update_function = |ext: &mut Extension| {
ext.hotkey = None;
Ok(())
};
if bundle_id.extension_id == window_management::EXTENSION_ID {
if let Some(command_id) = bundle_id.sub_extension_id {
let extension = load_extension_from_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id.extension_id,
)
.unwrap();
window_management::unregister_command_hotkey(tauri_app_handle, &extension, command_id)?;
alter_extension_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id,
update_function,
)
.unwrap();
}
}
}
}
Ok(())
}
fn split_extension_id(extension_id: &str) -> (&str, Option<&str>) {
match extension_id.find('.') {
Some(idx) => (&extension_id[..idx], Some(&extension_id[idx + 1..])),
None => (extension_id, None),
}
}
fn load_extension_from_json_file(
extension_directory: &Path,
extension_id: &str,
) -> Result<Extension, String> {
let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id);
let json_file_path = {
let mut extension_directory_path = extension_directory.join(parent_extension_id);
extension_directory_path.push(PLUGIN_JSON_FILE_NAME);
extension_directory_path
};
let mut extension = serde_json::from_reader::<_, Extension>(
std::fs::File::open(&json_file_path)
.with_context(|| {
format!(
"the [{}] file for extension [{}] is missing or broken",
PLUGIN_JSON_FILE_NAME, parent_extension_id
)
})
.map_err(|e| e.to_string())?,
)
.map_err(|e| e.to_string())?;
super::canonicalize_relative_icon_path(extension_directory, &mut extension)?;
Ok(extension)
}
#[allow(unused_macros)] // #[function_name::named] only used on macOS
#[function_name::named]
pub(crate) async fn is_built_in_extension_enabled(
tauri_app_handle: &AppHandle,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<bool, String> {
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_none()
{
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
// Check if this is an application
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
return Ok(application::is_app_search_enabled(app_path));
}
}
if bundle_id.extension_id == calculator::DATA_SOURCE_ID {
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
if bundle_id.extension_id == quick_ai_access::EXTENSION_ID {
let extension = load_extension_from_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id.extension_id,
)?;
return Ok(extension.enabled);
}
if bundle_id.extension_id == ai_overview::EXTENSION_ID {
let extension = load_extension_from_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id.extension_id,
)?;
return Ok(extension.enabled);
}
if bundle_id.extension_id == file_search::EXTENSION_ID && bundle_id.sub_extension_id.is_none() {
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
// Window Management
if bundle_id.extension_id == window_management::EXTENSION_ID
&& bundle_id.sub_extension_id.is_none()
{
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
// Window Management commands
if bundle_id.extension_id == window_management::EXTENSION_ID
&& let Some(command_id) = bundle_id.sub_extension_id
{
let extension = load_extension_from_json_file(
&get_built_in_extension_directory(tauri_app_handle),
bundle_id.extension_id,
)?;
let commands = extension
.commands
.expect("window management extension has commands");
let extension = commands.iter().find( |cmd| cmd.id == command_id).unwrap_or_else(|| {
panic!("function [{}()] invoked with a Window Management command that does not exist, extension ID [{}] ", function_name!(), command_id)
});
return Ok(extension.enabled);
}
}
}
unreachable!("extension [{:?}] is not a built-in extension", bundle_id)
}

View File

@@ -1,76 +0,0 @@
//! We use Pizza Engine to index applications and local files. The engine will be
//! run in the thread/runtime defined in this file.
//!
//! # Why such a thread/runtime is needed
//!
//! Generally, Tokio async runtime requires all the async tasks running on it to be
//! `Send` and `Sync`, but the async tasks created by Pizza Engine are not,
//! which forces us to create a dedicated thread/runtime to execute them.
use std::any::Any;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::OnceLock;
pub(crate) trait SearchSourceState {
#[cfg_attr(not(feature = "use_pizza_engine"), allow(unused))]
fn as_mut_any(&mut self) -> &mut dyn Any;
}
#[async_trait::async_trait(?Send)]
pub(crate) trait Task: Send + Sync {
fn search_source_id(&self) -> &'static str;
async fn exec(&mut self, state: &mut Option<Box<dyn SearchSourceState>>);
}
pub(crate) static RUNTIME_TX: OnceLock<tokio::sync::mpsc::UnboundedSender<Box<dyn Task>>> =
OnceLock::new();
/// This function blocks until the runtime thread is ready for accepting tasks.
pub(crate) async fn start_pizza_engine_runtime() {
const THREAD_NAME: &str = "Pizza engine runtime thread";
log::trace!("starting Pizza engine runtime");
let (engine_start_signal_tx, engine_start_signal_rx) = tokio::sync::oneshot::channel();
std::thread::Builder::new()
.name(THREAD_NAME.into())
.spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
let main = async {
let mut states: HashMap<String, Option<Box<dyn SearchSourceState>>> =
HashMap::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
RUNTIME_TX.set(tx).unwrap();
engine_start_signal_tx
.send(())
.expect("engine_start_signal_rx dropped");
while let Some(mut task) = rx.recv().await {
let opt_search_source_state = match states.entry(task.search_source_id().into())
{
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => v.insert(None),
};
task.exec(opt_search_source_state).await;
}
};
rt.block_on(main);
})
.unwrap_or_else(|e| {
panic!(
"failed to start thread [{}] due to error [{}]",
THREAD_NAME, e
);
});
engine_start_signal_rx
.await
.expect("engine_start_signal_tx dropped, the runtime thread could be dead");
log::trace!("Pizza engine runtime started");
}

View File

@@ -1,12 +0,0 @@
pub(super) const EXTENSION_ID: &str = "QuickAIAccess";
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "QuickAIAccess",
"name": "Quick AI Access",
"description": "...",
"icon": "font_a-QuickAIAccess",
"type": "ai_extension",
"enabled": true
}
"#;

View File

@@ -1,134 +0,0 @@
#[derive(Debug, Clone, PartialEq, Copy, Hash, serde::Serialize, serde::Deserialize)]
pub enum Action {
/// Move the window to fill left half of the screen.
TopHalf,
/// Move the window to fill bottom half of the screen.
BottomHalf,
/// Move the window to fill left half of the screen.
LeftHalf,
/// Move the window to fill right half of the screen.
RightHalf,
/// Move the window to fill center half of the screen.
CenterHalf,
/// Resize window to the top left quarter of the screen.
TopLeftQuarter,
/// Resize window to the top right quarter of the screen.
TopRightQuarter,
/// Resize window to the bottom left quarter of the screen.
BottomLeftQuarter,
/// Resize window to the bottom right quarter of the screen.
BottomRightQuarter,
/// Resize window to the top left sixth of the screen.
TopLeftSixth,
/// Resize window to the top center sixth of the screen.
TopCenterSixth,
/// Resize window to the top right sixth of the screen.
TopRightSixth,
/// Resize window to the bottom left sixth of the screen.
BottomLeftSixth,
/// Resize window to the bottom center sixth of the screen.
BottomCenterSixth,
/// Resize window to the bottom right sixth of the screen.
BottomRightSixth,
/// Resize window to the top third of the screen.
TopThird,
/// Resize window to the middle third of the screen.
MiddleThird,
/// Resize window to the bottom third of the screen.
BottomThird,
/// Center window in the screen.
Center,
/// Resize window to the first fourth of the screen.
FirstFourth,
/// Resize window to the second fourth of the screen.
SecondFourth,
/// Resize window to the third fourth of the screen.
ThirdFourth,
/// Resize window to the last fourth of the screen.
LastFourth,
/// Resize window to the first third of the screen.
FirstThird,
/// Resize window to the center third of the screen.
CenterThird,
/// Resize window to the last third of the screen.
LastThird,
/// Resize window to the first two thirds of the screen.
FirstTwoThirds,
/// Resize window to the center two thirds of the screen.
CenterTwoThirds,
/// Resize window to the last two thirds of the screen.
LastTwoThirds,
/// Resize window to the first three fourths of the screen.
FirstThreeFourths,
/// Resize window to the center three fourths of the screen.
CenterThreeFourths,
/// Resize window to the last three fourths of the screen.
LastThreeFourths,
/// Resize window to the top three fourths of the screen.
TopThreeFourths,
/// Resize window to the bottom three fourths of the screen.
BottomThreeFourths,
/// Resize window to the top two thirds of the screen.
TopTwoThirds,
/// Resize window to the bottom two thirds of the screen.
BottomTwoThirds,
/// Resize window to the top center two thirds of the screen.
TopCenterTwoThirds,
/// Resize window to the top first fourth of the screen.
TopFirstFourth,
/// Resize window to the top second fourth of the screen.
TopSecondFourth,
/// Resize window to the top third fourth of the screen.
TopThirdFourth,
/// Resize window to the top last fourth of the screen.
TopLastFourth,
/// Increase the window until it reaches the screen size.
MakeLarger,
/// Decrease the window until it reaches its minimal size.
MakeSmaller,
/// Maximize window to almost fit the screen.
AlmostMaximize,
/// Maximize window to fit the screen.
Maximize,
/// Maximize width of window to fit the screen.
MaximizeWidth,
/// Maximize height of window to fit the screen.
MaximizeHeight,
/// Move window to the top edge of the screen.
MoveUp,
/// Move window to the bottom of the screen.
MoveDown,
/// Move window to the left edge of the screen.
MoveLeft,
/// Move window to the right edge of the screen.
MoveRight,
/// Move window to the next desktop.
NextDesktop,
/// Move window to the previous desktop.
PreviousDesktop,
/// Move window to the next display.
NextDisplay,
/// Move window to the previous display.
PreviousDisplay,
/// Restore window to its last position.
Restore,
/// Toggle fullscreen mode.
ToggleFullscreen,
}

View File

@@ -1,796 +0,0 @@
//! This module calls macOS APIs to implement various helper functions needed by
//! to perform the defined actions.
mod private;
use std::ffi::c_uint;
use std::ffi::c_ushort;
use std::ffi::c_void;
use std::ops::Deref;
use std::ptr::NonNull;
use std::time::Duration;
use objc2::MainThreadMarker;
use objc2_app_kit::NSEvent;
use objc2_app_kit::NSScreen;
use objc2_app_kit::NSWorkspace;
use objc2_application_services::AXError;
use objc2_application_services::AXUIElement;
use objc2_application_services::AXValue;
use objc2_application_services::AXValueType;
use objc2_core_foundation::CFBoolean;
use objc2_core_foundation::CFRetained;
use objc2_core_foundation::CFString;
use objc2_core_foundation::CFType;
use objc2_core_foundation::CGPoint;
use objc2_core_foundation::CGRect;
use objc2_core_foundation::CGSize;
use objc2_core_foundation::Type;
use objc2_core_foundation::{CFArray, CFDictionary, CFNumber};
use objc2_core_graphics::CGError;
use objc2_core_graphics::CGEvent;
use objc2_core_graphics::CGEventFlags;
use objc2_core_graphics::CGEventTapLocation;
use objc2_core_graphics::CGEventType;
use objc2_core_graphics::CGMouseButton;
use objc2_core_graphics::CGRectGetMidX;
use objc2_core_graphics::CGRectGetMinY;
use objc2_core_graphics::CGRectIntersectsRect;
use objc2_core_graphics::CGWindowID;
use super::error::Error;
use private::CGSCopyManagedDisplaySpaces;
use private::CGSGetActiveSpace;
use private::CGSMainConnectionID;
use private::CGSSpaceID;
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
fn intersects(r1: CGRect, r2: CGRect) -> bool {
unsafe { CGRectIntersectsRect(r1, r2) }
}
/// Core graphics APIs use flipped coordinate system, while AppKit uses the
/// unflippled version, they differ in the y-axis. We need to do the conversion
/// (to `CGPoint.y`) manually.
fn flip_frame_y(main_screen_height: f64, frame_height: f64, frame_unflipped_y: f64) -> f64 {
main_screen_height - (frame_unflipped_y + frame_height)
}
/// Helper function to extract an UI element's origin.
fn get_ui_element_origin(ui_element: &CFRetained<AXUIElement>) -> Result<CGPoint, Error> {
let mut position_value: *const CFType = std::ptr::null();
let ptr_to_position_value = NonNull::new(&mut position_value).unwrap();
let position_attr = CFString::from_static_str("AXPosition");
let error = unsafe { ui_element.copy_attribute_value(&position_attr, ptr_to_position_value) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
assert!(!position_value.is_null());
let position: CFRetained<AXValue> =
unsafe { CFRetained::from_raw(NonNull::new(position_value.cast_mut().cast()).unwrap()) };
let mut position_cg_point = CGPoint::ZERO;
let ptr_to_position_cg_point =
NonNull::new((&mut position_cg_point as *mut CGPoint).cast()).unwrap();
let result = unsafe { position.value(AXValueType::CGPoint, ptr_to_position_cg_point) };
assert!(result, "type mismatched");
Ok(position_cg_point)
}
/// Send a set origin request to the `ui_element`, return once request is sent.
fn set_ui_element_origin_oneshot(
ui_element: &CFRetained<AXUIElement>,
mut origin: CGPoint,
) -> Result<(), Error> {
let ptr_to_origin = NonNull::new((&mut origin as *mut CGPoint).cast::<c_void>()).unwrap();
let pos_value = unsafe { AXValue::new(AXValueType::CGPoint, ptr_to_origin) }.unwrap();
let pos_attr = CFString::from_static_str("AXPosition");
let error = unsafe { ui_element.set_attribute_value(&pos_attr, pos_value.deref()) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
Ok(())
}
/// Helper function to extract an UI element's size.
fn get_ui_element_size(ui_element: &CFRetained<AXUIElement>) -> Result<CGSize, Error> {
let mut size_value: *const CFType = std::ptr::null();
let ptr_to_size_value = NonNull::new(&mut size_value).unwrap();
let size_attr = CFString::from_static_str("AXSize");
let error = unsafe { ui_element.copy_attribute_value(&size_attr, ptr_to_size_value) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
assert!(!size_value.is_null());
let size: CFRetained<AXValue> =
unsafe { CFRetained::from_raw(NonNull::new(size_value.cast_mut().cast()).unwrap()) };
let mut size_cg_size = CGSize::ZERO;
let ptr_to_size_cg_size = NonNull::new((&mut size_cg_size as *mut CGSize).cast()).unwrap();
let result = unsafe { size.value(AXValueType::CGSize, ptr_to_size_cg_size) };
assert!(result, "type mismatched");
Ok(size_cg_size)
}
/// Send a set size request to the `ui_element`, return once request is sent.
fn set_ui_element_size_oneshot(
ui_element: &CFRetained<AXUIElement>,
mut size: CGSize,
) -> Result<(), Error> {
let ptr_to_size = NonNull::new((&mut size as *mut CGSize).cast::<c_void>()).unwrap();
let size_value = unsafe { AXValue::new(AXValueType::CGSize, ptr_to_size) }.unwrap();
let size_attr = CFString::from_static_str("AXSize");
let error = unsafe { ui_element.set_attribute_value(&size_attr, size_value.deref()) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
Ok(())
}
/// Get the frontmost/focused window (as an UI element).
fn get_frontmost_window() -> Result<CFRetained<AXUIElement>, Error> {
let workspace = unsafe { NSWorkspace::sharedWorkspace() };
let frontmost_app =
unsafe { workspace.frontmostApplication() }.ok_or(Error::CannotFindFocusWindow)?;
let pid = unsafe { frontmost_app.processIdentifier() };
let app_element = unsafe { AXUIElement::new_application(pid) };
let mut window_element: *const CFType = std::ptr::null();
let ptr_to_window_element = NonNull::new(&mut window_element).unwrap();
let focused_window_attr = CFString::from_static_str("AXFocusedWindow");
let error =
unsafe { app_element.copy_attribute_value(&focused_window_attr, ptr_to_window_element) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
assert!(!window_element.is_null());
let window_element: *mut AXUIElement = window_element.cast::<AXUIElement>().cast_mut();
let window = unsafe { CFRetained::from_raw(NonNull::new(window_element).unwrap()) };
Ok(window)
}
/// Get the CGWindowID of the frontmost/focused window.
#[allow(unused)] // In case we need it in the future
pub(crate) fn get_frontmost_window_id() -> Result<CGWindowID, Error> {
let element = get_frontmost_window()?;
let ptr: NonNull<AXUIElement> = CFRetained::as_ptr(&element);
let mut window_id_buffer: CGWindowID = 0;
let error =
unsafe { private::_AXUIElementGetWindow(ptr.as_ptr(), &mut window_id_buffer as *mut _) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
Ok(window_id_buffer)
}
/// Returns the workspace ID list grouped by display. For example, suppose you
/// have 2 displays and 10 workspaces (5 workspaces per display), then this
/// function might return something like:
///
/// ```text
/// [
/// [8, 11, 12, 13, 24],
/// [519, 77, 15, 249, 414]
/// ]
/// ```
///
/// Even though this function return macOS internal space IDs, they should correspond
/// to the logical workspace that users are familiar with. The display that contains
/// workspaces `[8, 11, 12, 13, 24]` should be your main display; workspace 8 represents
/// Desktop 1, and workspace 414 represents Desktop 10.
fn workspace_ids_grouped_by_display() -> Vec<Vec<CGSSpaceID>> {
unsafe {
let mut ret = Vec::new();
let conn = CGSMainConnectionID();
let display_spaces_raw = CGSCopyManagedDisplaySpaces(conn);
let display_spaces: CFRetained<CFArray> =
CFRetained::from_raw(NonNull::new(display_spaces_raw).unwrap());
let key_spaces: CFRetained<CFString> = CFString::from_static_str("Spaces");
let key_spaces_ptr: NonNull<CFString> = CFRetained::as_ptr(&key_spaces);
let key_id64: CFRetained<CFString> = CFString::from_static_str("id64");
let key_id64_ptr: NonNull<CFString> = CFRetained::as_ptr(&key_id64);
for i in 0..display_spaces.count() {
let mut workspaces_of_this_display = Vec::new();
let dict_ref = display_spaces.value_at_index(i);
let dict: &CFDictionary = &*(dict_ref as *const CFDictionary);
let mut ptr_to_value_buffer: *const c_void = std::ptr::null();
let key_exists = dict.value_if_present(
key_spaces_ptr.as_ptr().cast::<c_void>().cast_const(),
&mut ptr_to_value_buffer as *mut _,
);
assert!(key_exists);
assert!(!ptr_to_value_buffer.is_null());
let spaces_raw: *const CFArray = ptr_to_value_buffer.cast::<CFArray>();
let spaces = &*spaces_raw;
for idx in 0..spaces.count() {
let workspace_dictionary: &CFDictionary =
&*spaces.value_at_index(idx).cast::<CFDictionary>();
let mut ptr_to_value_buffer: *const c_void = std::ptr::null();
let key_exists = workspace_dictionary.value_if_present(
key_id64_ptr.as_ptr().cast::<c_void>().cast_const(),
&mut ptr_to_value_buffer as *mut _,
);
assert!(key_exists);
assert!(!ptr_to_value_buffer.is_null());
let ptr_workspace_id = ptr_to_value_buffer.cast::<CFNumber>();
let workspace_id = (&*ptr_workspace_id).as_i32().unwrap();
workspaces_of_this_display.push(workspace_id);
}
ret.push(workspaces_of_this_display);
}
ret
}
}
/// Get the next workspace's logical ID. By logical ID, we mean the ID that
/// users are familiar with, workspace 1/2/3 and so on, rather than the internal
/// `CGSSpaceID`.
///
/// NOTE that this function returns None when the current workspace is the last
/// workspace in the current display.
pub(crate) fn get_next_workspace_logical_id() -> Option<usize> {
let window_server_connection = unsafe { CGSMainConnectionID() };
let current_workspace_id = unsafe { CGSGetActiveSpace(window_server_connection) };
// Logical ID starts from 1
let mut logical_id = 1_usize;
for workspaces_in_a_display in workspace_ids_grouped_by_display() {
for (idx, workspace_raw_id) in workspaces_in_a_display.iter().enumerate() {
if *workspace_raw_id == current_workspace_id {
// We found it, now check if it is the last workspace in this display
if idx == workspaces_in_a_display.len() - 1 {
return None;
} else {
return Some(logical_id + 1);
}
} else {
logical_id += 1;
continue;
}
}
}
unreachable!(
"unless the private API CGSGetActiveSpace() is broken, it should return an ID that is in the workspace ID list"
)
}
/// Get the previous workspace's logical ID.
///
/// See [`get_next_workspace_logical_id`] for the doc.
pub(crate) fn get_previous_workspace_logical_id() -> Option<usize> {
let window_server_connection = unsafe { CGSMainConnectionID() };
let current_workspace_id = unsafe { CGSGetActiveSpace(window_server_connection) };
// Logical ID starts from 1
let mut logical_id = 1_usize;
for workspaces_in_a_display in workspace_ids_grouped_by_display() {
for (idx, workspace_raw_id) in workspaces_in_a_display.iter().enumerate() {
if *workspace_raw_id == current_workspace_id {
// We found it, now check if it is the first workspace in this display
if idx == 0 {
return None;
} else {
// this sub operation is safe, logical_id is at least 2
return Some(logical_id - 1);
}
} else {
logical_id += 1;
continue;
}
}
}
unreachable!(
"unless the private API CGSGetActiveSpace() is broken, it should return an ID that is in the workspace ID list"
)
}
/// Move the frontmost window to the specified workspace.
///
/// Credits to the Silica library
///
/// * https://github.com/ianyh/Silica/blob/b91a18dbb822e99ce6b487d1cb4841e863139b2a/Silica/Sources/SIWindow.m#L215-L260
/// * https://github.com/ianyh/Silica/blob/b91a18dbb822e99ce6b487d1cb4841e863139b2a/Silica/Sources/SISystemWideElement.m#L29-L65
pub(crate) fn move_frontmost_window_to_workspace(space: usize) -> Result<(), Error> {
assert!(space >= 1);
if space > 16 {
return Err(Error::TooManyWorkspace);
}
let window_frame = get_frontmost_window_frame()?;
let close_button_frame = get_frontmost_window_close_button_frame()?;
let prev_mouse_position = unsafe {
let event = CGEvent::new(None);
CGEvent::location(event.as_deref())
};
let mouse_cursor_point = CGPoint::new(
unsafe { CGRectGetMidX(close_button_frame) },
window_frame.origin.y
+ (window_frame.origin.y - unsafe { CGRectGetMinY(close_button_frame) }).abs() / 2.0,
);
let mouse_move_event = unsafe {
CGEvent::new_mouse_event(
None,
CGEventType::MouseMoved,
mouse_cursor_point,
CGMouseButton::Left,
)
};
let mouse_drag_event = unsafe {
CGEvent::new_mouse_event(
None,
CGEventType::LeftMouseDragged,
mouse_cursor_point,
CGMouseButton::Left,
)
};
let mouse_down_event = unsafe {
CGEvent::new_mouse_event(
None,
CGEventType::LeftMouseDown,
mouse_cursor_point,
CGMouseButton::Left,
)
};
let mouse_up_event = unsafe {
CGEvent::new_mouse_event(
None,
CGEventType::LeftMouseUp,
mouse_cursor_point,
CGMouseButton::Left,
)
};
unsafe {
CGEvent::set_flags(mouse_move_event.as_deref(), CGEventFlags(0));
CGEvent::set_flags(mouse_down_event.as_deref(), CGEventFlags(0));
CGEvent::set_flags(mouse_up_event.as_deref(), CGEventFlags(0));
// Move the mouse into place at the window's toolbar
CGEvent::post(CGEventTapLocation::HIDEventTap, mouse_move_event.as_deref());
// Mouse down to set up the drag
CGEvent::post(CGEventTapLocation::HIDEventTap, mouse_down_event.as_deref());
// Drag event to grab hold of the window
CGEvent::post(CGEventTapLocation::HIDEventTap, mouse_drag_event.as_deref());
}
// Make a slight delay to make sure the window is grabbed
std::thread::sleep(Duration::from_millis(50));
// cast is safe as space is in range [1, 16]
let hot_key: c_ushort = 118 + space as c_ushort - 1;
let mut flags: c_uint = 0;
let mut key_code: c_ushort = 0;
let error = unsafe {
private::CGSGetSymbolicHotKeyValue(hot_key, std::ptr::null_mut(), &mut key_code, &mut flags)
};
if error != CGError::Success {
return Err(Error::CGError(error));
}
unsafe {
// If the hotkey is disabled, enable it.
if !private::CGSIsSymbolicHotKeyEnabled(hot_key) {
if private::CGSSetSymbolicHotKeyEnabled(hot_key, true) != CGError::Success {
return Err(Error::CGError(error));
}
}
}
let opt_keyboard_event = unsafe { CGEvent::new_keyboard_event(None, key_code, true) };
unsafe {
// cast is safe (uint -> u64)
CGEvent::set_flags(opt_keyboard_event.as_deref(), CGEventFlags(flags as u64));
}
let keyboard_event = opt_keyboard_event.unwrap();
let event = unsafe { NSEvent::eventWithCGEvent(&keyboard_event) }.unwrap();
let keyboard_event_up = unsafe { CGEvent::new_keyboard_event(None, event.keyCode(), false) };
unsafe {
CGEvent::set_flags(keyboard_event_up.as_deref(), CGEventFlags(0));
// Send the shortcut command to get Mission Control to switch spaces from under the window.
CGEvent::post(CGEventTapLocation::HIDEventTap, event.CGEvent().as_deref());
CGEvent::post(
CGEventTapLocation::HIDEventTap,
keyboard_event_up.as_deref(),
);
}
// Make a slight delay to finish the space transition animation
std::thread::sleep(Duration::from_millis(50));
/*
* Cleanup
*/
unsafe {
// Let go of the window.
CGEvent::post(CGEventTapLocation::HIDEventTap, mouse_up_event.as_deref());
// Reset mouse position
let mouse_reset_event = {
CGEvent::new_mouse_event(
None,
CGEventType::MouseMoved,
prev_mouse_position,
CGMouseButton::Left,
)
};
CGEvent::set_flags(mouse_reset_event.as_deref(), CGEventFlags(0));
CGEvent::post(
CGEventTapLocation::HIDEventTap,
mouse_reset_event.as_deref(),
);
}
Ok(())
}
pub(crate) fn get_frontmost_window_origin() -> Result<CGPoint, Error> {
let frontmost_window = get_frontmost_window()?;
get_ui_element_origin(&frontmost_window)
}
pub(crate) fn get_frontmost_window_size() -> Result<CGSize, Error> {
let frontmost_window = get_frontmost_window()?;
get_ui_element_size(&frontmost_window)
}
pub(crate) fn get_frontmost_window_frame() -> Result<CGRect, Error> {
let origin = get_frontmost_window_origin()?;
let size = get_frontmost_window_size()?;
Ok(CGRect { origin, size })
}
/// Get the frontmost window's close button, then extract its frame.
fn get_frontmost_window_close_button_frame() -> Result<CGRect, Error> {
let window = get_frontmost_window()?;
let mut ptr_to_close_button: *const CFType = std::ptr::null();
let ptr_to_buffer = NonNull::new(&mut ptr_to_close_button).unwrap();
let close_button_attribute = CFString::from_static_str("AXCloseButton");
let error = unsafe { window.copy_attribute_value(&close_button_attribute, ptr_to_buffer) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
assert!(!ptr_to_close_button.is_null());
let close_button_element = ptr_to_close_button.cast::<AXUIElement>().cast_mut();
let close_button = unsafe { CFRetained::from_raw(NonNull::new(close_button_element).unwrap()) };
let origin = get_ui_element_origin(&close_button)?;
let size = get_ui_element_size(&close_button)?;
Ok(CGRect { origin, size })
}
/// This function returns the "visible frame" [^1] of all the screens.
///
/// FIXME: This function relies on the [`visibleFrame()`][vf_doc] API, which
/// has 2 bugs we need to work around:
///
/// 1. It assumes the Dock is on the main display, which in reality depends on
/// how users arrange their displays and the "Dock position on screen" setting
/// entry.
/// 2. For non-main displays, it assumes that they don't have a menu bar, but macOS
/// puts a menu bar on every display.
///
/// Update: This could be wrong, but looks like Apple fixed these 2 bugs in macOS
/// 26. At least the buggy behaviors disappear in my test.
///
///
/// [^1]: Visible frame: a rectangle defines the portion of the screen in which it
/// is currently safe to draw your apps content.
///
/// [vf_doc]: https://developer.apple.com/documentation/AppKit/NSScreen/visibleFrame
pub(crate) fn list_visible_frame_of_all_screens() -> Result<Vec<CGRect>, Error> {
let main_thread_marker = MainThreadMarker::new().ok_or(Error::NotInMainThread)?;
let screens = NSScreen::screens(main_thread_marker).to_vec();
if screens.is_empty() {
return Ok(Vec::new());
}
let main_screen = screens.first().expect("screens is not empty");
let frames = screens
.iter()
.map(|ns_screen| {
// NSScreen is an AppKit API, which uses unflipped coordinate
// system, flip it
let mut unflipped_frame = ns_screen.visibleFrame();
let flipped_frame_origin_y = flip_frame_y(
main_screen.frame().size.height,
unflipped_frame.size.height,
unflipped_frame.origin.y,
);
unflipped_frame.origin.y = flipped_frame_origin_y;
unflipped_frame
})
.collect();
Ok(frames)
}
/// Get the Visible frame of the "active screen"[^1].
///
///
/// [^1]: the screen which the frontmost window is on.
pub(crate) fn get_active_screen_visible_frame() -> Result<CGRect, Error> {
let main_thread_marker = MainThreadMarker::new().ok_or(Error::NotInMainThread)?;
let frontmost_window_frame = get_frontmost_window_frame()?;
let screens = NSScreen::screens(main_thread_marker)
.into_iter()
.collect::<Vec<_>>();
if screens.is_empty() {
return Err(Error::NoDisplay);
}
let main_screen_height = screens[0].frame().size.height;
// AppKit uses Unflipped Coordinate System, but Accessibility APIs use
// Flipped Coordinate System, we need to flip the origin of these screens.
for screen in screens {
let mut screen_frame = screen.frame();
let unflipped_y = screen_frame.origin.y;
let flipped_y = flip_frame_y(main_screen_height, screen_frame.size.height, unflipped_y);
screen_frame.origin.y = flipped_y;
if intersects(screen_frame, frontmost_window_frame) {
let mut visible_frame = screen.visibleFrame();
let flipped_y = flip_frame_y(
main_screen_height,
visible_frame.size.height,
visible_frame.origin.y,
);
visible_frame.origin.y = flipped_y;
return Ok(visible_frame);
}
}
unreachable!()
}
/// Move the frontmost window's origin to the point specified by `x` and `y`.
pub fn move_frontmost_window(x: f64, y: f64) -> Result<(), Error> {
let frontmost_window = get_frontmost_window()?;
let mut point = CGPoint::new(x, y);
let ptr_to_point = NonNull::new((&mut point as *mut CGPoint).cast::<c_void>()).unwrap();
let pos_value = unsafe { AXValue::new(AXValueType::CGPoint, ptr_to_point) }.unwrap();
let pos_attr = CFString::from_static_str("AXPosition");
let error = unsafe { frontmost_window.set_attribute_value(&pos_attr, pos_value.deref()) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
Ok(())
}
/// Set the frontmost window's frame to the specified frame - adjust size and
/// location at the same time.
///
/// This function **retries** up to `RETRY` times until the set operations
/// successfully get performed.
///
/// # Retry
///
/// Retry is added because I encountered a case where `AXUIElementSetAttributeValue()`
/// does not work in the expected way. When I execute the `NextDisplay` command
/// to move the focused window from a big display (2560x1440) to a small display
/// (1440*900), the window size could be set to 1460 sometimes. No idea if this
/// is a bug of the Accessibility APIs or due to the improper API uses. So we
/// retry for `RETRY` times at most to try our beest make it behave correctly.
pub fn set_frontmost_window_frame(frame: CGRect) -> Result<(), Error> {
const RETRY: usize = 5;
/// Sleep for 50ms as I don't want to send too many requests to the focused
/// app and WindowServer because doing that could make them busy and then
/// they won't process my set requests.
///
/// The above is simply my observation, I don't know how the messaging really
/// works under the hood.
const SLEEP: Duration = Duration::from_millis(50);
let frontmost_window = get_frontmost_window()?;
/*
* Set window origin
*/
set_ui_element_origin_oneshot(&frontmost_window, frame.origin)?;
for _ in 0..RETRY {
std::thread::sleep(SLEEP);
let current = get_ui_element_origin(&frontmost_window)?;
if current == frame.origin {
break;
} else {
set_ui_element_origin_oneshot(&frontmost_window, frame.origin)?;
}
}
/*
* Set window size
*/
set_ui_element_size_oneshot(&frontmost_window, frame.size)?;
for _ in 0..RETRY {
std::thread::sleep(SLEEP);
let current = get_ui_element_size(&frontmost_window)?;
// For size, we do not check if `current` has the exact same value as
// `frame.size` as I have encountered a case where I ask macOS to set
// the height to 1550, but the height gets set to 1551.
if cgsize_roughly_equal(current, frame.size, 3.0) {
break;
} else {
set_ui_element_size_oneshot(&frontmost_window, frame.size)?;
}
}
Ok(())
}
pub fn toggle_fullscreen() -> Result<(), Error> {
let frontmost_window = get_frontmost_window()?;
let fullscreen_attr = CFString::from_static_str("AXFullScreen");
let mut current_value_ref: *const CFType = std::ptr::null();
let error = unsafe {
frontmost_window.copy_attribute_value(
&fullscreen_attr,
NonNull::new(&mut current_value_ref).unwrap(),
)
};
// TODO: If the attribute doesn't exist, error won't be Success as well.
// Before we handle that, we need to know the error case that will be
// returned in that case.
if error != AXError::Success {
return Err(Error::AXError(error));
}
assert!(!current_value_ref.is_null());
let current_value = unsafe {
let retained_boolean: CFRetained<CFBoolean> = CFRetained::from_raw(
NonNull::new(current_value_ref.cast::<CFBoolean>().cast_mut()).unwrap(),
);
retained_boolean.as_bool()
};
let new_value = !current_value;
let new_value_ref: CFRetained<CFBoolean> = CFBoolean::new(new_value).retain();
let error =
unsafe { frontmost_window.set_attribute_value(&fullscreen_attr, new_value_ref.deref()) };
if error != AXError::Success {
return Err(Error::AXError(error));
}
Ok(())
}
/// Check if `lhs` roughly equals to `rhs`. The Roughness can be controlled by
/// argument `tolerance`.
fn cgsize_roughly_equal(lhs: CGSize, rhs: CGSize, tolerance: f64) -> bool {
let width_diff = (lhs.width - rhs.width).abs();
let height_diff = (lhs.height - rhs.height).abs();
width_diff <= tolerance && height_diff <= tolerance
}
static LAST_FRAME: LazyLock<Mutex<HashMap<CGWindowID, CGRect>>> =
LazyLock::new(|| Mutex::new(HashMap::new()));
pub(crate) fn set_frontmost_window_last_frame(window_id: CGWindowID, frame: CGRect) {
let mut map = LAST_FRAME.lock().unwrap();
map.insert(window_id, frame);
}
pub(crate) fn get_frontmost_window_last_frame(window_id: CGWindowID) -> Option<CGRect> {
let map = LAST_FRAME.lock().unwrap();
map.get(&window_id).cloned()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_intersects_adjacent_rects_x() {
let r1 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
let r2 = CGRect::new(CGPoint::new(100.0, 0.0), CGSize::new(100.0, 100.0));
assert!(
!intersects(r1, r2),
"Adjacent rects on X should not intersect"
);
}
#[test]
fn test_intersects_adjacent_rects_y() {
let r1 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
let r2 = CGRect::new(CGPoint::new(0.0, 100.0), CGSize::new(100.0, 100.0));
assert!(
!intersects(r1, r2),
"Adjacent rects on Y should not intersect"
);
}
#[test]
fn test_intersects_overlapping_rects() {
let r1 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
let r2 = CGRect::new(CGPoint::new(50.0, 50.0), CGSize::new(100.0, 100.0));
assert!(intersects(r1, r2), "Overlapping rects should intersect");
}
#[test]
fn test_intersects_separate_rects() {
let r1 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
let r2 = CGRect::new(CGPoint::new(101.0, 101.0), CGSize::new(100.0, 100.0));
assert!(!intersects(r1, r2), "Separate rects should not intersect");
}
#[test]
fn test_intersects_contained_rect() {
let r1 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
let r2 = CGRect::new(CGPoint::new(10.0, 10.0), CGSize::new(50.0, 50.0));
assert!(intersects(r1, r2), "Contained rect should intersect");
}
#[test]
fn test_intersects_identical_rects() {
let r1 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
let r2 = CGRect::new(CGPoint::new(0.0, 0.0), CGSize::new(100.0, 100.0));
assert!(intersects(r1, r2), "Identical rects should intersect");
}
}

View File

@@ -1,70 +0,0 @@
//! Private macOS APIs.
use bitflags::bitflags;
use objc2_application_services::AXError;
use objc2_application_services::AXUIElement;
use objc2_core_foundation::CFArray;
use objc2_core_graphics::CGError;
use objc2_core_graphics::CGWindowID;
use std::ffi::c_int;
use std::ffi::c_uint;
use std::ffi::c_ushort;
pub(crate) type CGSConnectionID = u32;
pub(crate) type CGSSpaceID = c_int;
bitflags! {
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct CGSSpaceMask: c_int {
const INCLUDE_CURRENT = 1 << 0;
const INCLUDE_OTHERS = 1 << 1;
const INCLUDE_USER = 1 << 2;
const INCLUDE_OS = 1 << 3;
const VISIBLE = 1 << 16;
const CURRENT_SPACES = Self::INCLUDE_USER.bits() | Self::INCLUDE_CURRENT.bits();
const OTHER_SPACES = Self::INCLUDE_USER.bits() | Self::INCLUDE_OTHERS.bits();
const ALL_SPACES =
Self::INCLUDE_USER.bits() | Self::INCLUDE_OTHERS.bits() | Self::INCLUDE_CURRENT.bits();
const ALL_VISIBLE_SPACES = Self::ALL_SPACES.bits() | Self::VISIBLE.bits();
const CURRENT_OS_SPACES = Self::INCLUDE_OS.bits() | Self::INCLUDE_CURRENT.bits();
const OTHER_OS_SPACES = Self::INCLUDE_OS.bits() | Self::INCLUDE_OTHERS.bits();
const ALL_OS_SPACES =
Self::INCLUDE_OS.bits() | Self::INCLUDE_OTHERS.bits() | Self::INCLUDE_CURRENT.bits();
}
}
unsafe extern "C" {
/// Extract `window_id` from an AXUIElement.
pub(crate) fn _AXUIElementGetWindow(
elem: *mut AXUIElement,
window_id: *mut CGWindowID,
) -> AXError;
/// Connect to the WindowServer and get a connection descriptor.
pub(crate) fn CGSMainConnectionID() -> CGSConnectionID;
/// It returns a CFArray of dictionaries. Each dictionary contains information
/// about a display, including a list of all the spaces (CGSSpaceID) on that display.
pub(crate) fn CGSCopyManagedDisplaySpaces(cid: CGSConnectionID) -> *mut CFArray;
/// Gets the ID of the space currently visible to the user.
pub(crate) fn CGSGetActiveSpace(cid: CGSConnectionID) -> CGSSpaceID;
/// Returns the values the symbolic hot key represented by the given UID is configured with.
pub(crate) fn CGSGetSymbolicHotKeyValue(
hotKey: c_ushort,
outKeyEquivalent: *mut c_ushort,
outVirtualKeyCode: *mut c_ushort,
outModifiers: *mut c_uint,
) -> CGError;
/// Returns whether the symbolic hot key represented by the given UID is enabled.
pub(crate) fn CGSIsSymbolicHotKeyEnabled(hotKey: c_ushort) -> bool;
/// Sets whether the symbolic hot key represented by the given UID is enabled.
pub(crate) fn CGSSetSymbolicHotKeyEnabled(hotKey: c_ushort, isEnabled: bool) -> CGError;
}

View File

@@ -1,25 +0,0 @@
use objc2_application_services::AXError;
use objc2_core_graphics::CGError;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum Error {
/// Cannot find the focused window.
#[error("Cannot find the focused window.")]
CannotFindFocusWindow,
/// Error code from the macOS Accessibility APIs.
#[error("Error code from the macOS Accessibility APIs: {0:?}")]
AXError(AXError),
/// Function should be in called from the main thread, but it is not.
#[error("Function should be in called from the main thread, but it is not.")]
NotInMainThread,
/// No monitor detected.
#[error("No monitor detected.")]
NoDisplay,
/// Can only handle 16 Workspaces at most.
#[error("libwmgr can only handle 16 Workspaces at most.")]
TooManyWorkspace,
/// Error code from the macOS Core Graphics APIs.
#[error("Error code from the macOS Core Graphics APIs: {0:?}")]
CGError(CGError),
}

View File

@@ -1,974 +0,0 @@
pub(crate) mod actions;
mod backend;
mod error;
pub(crate) mod on_opened;
pub(crate) mod search_source;
use crate::common::document::open;
use crate::extension::Extension;
use actions::Action;
use backend::get_active_screen_visible_frame;
use backend::get_frontmost_window_frame;
use backend::get_frontmost_window_id;
use backend::get_frontmost_window_last_frame;
use backend::get_next_workspace_logical_id;
use backend::get_previous_workspace_logical_id;
use backend::list_visible_frame_of_all_screens;
use backend::move_frontmost_window;
use backend::move_frontmost_window_to_workspace;
use backend::set_frontmost_window_frame;
use backend::set_frontmost_window_last_frame;
use backend::toggle_fullscreen;
use error::Error;
use objc2_core_foundation::{CGPoint, CGRect, CGSize};
use oneshot::channel as oneshot_channel;
use tauri::AppHandle;
use tauri::async_runtime;
use tauri_plugin_global_shortcut::GlobalShortcutExt;
use tauri_plugin_global_shortcut::ShortcutState;
pub(crate) const EXTENSION_ID: &str = "Window Management";
pub(crate) const EXTENSION_NAME_LOWERCASE: &str = "window management";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = include_str!("./plugin.json");
pub(crate) fn perform_action_on_main_thread(
tauri_app_handle: &AppHandle,
action: Action,
) -> Result<(), String> {
let (tx, rx) = oneshot_channel();
tauri_app_handle
.run_on_main_thread(move || {
let res = perform_action(action).map_err(|e| e.to_string());
tx.send(res)
.expect("oneshot channel receiver unexpectedly dropped");
})
.expect("tauri internal bug, channel receiver dropped");
rx.recv()
.expect("oneshot channel sender unexpectedly dropped before sending function return value")
}
/// Perform this action to the focused window.
fn perform_action(action: Action) -> Result<(), Error> {
let visible_frame = get_active_screen_visible_frame()?;
let frontmost_window_id = get_frontmost_window_id()?;
let frontmost_window_frame = get_frontmost_window_frame()?;
set_frontmost_window_last_frame(frontmost_window_id, frontmost_window_frame);
match action {
Action::TopHalf => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomHalf => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 2.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::LeftHalf => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::RightHalf => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 2.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::CenterHalf => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 4.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopLeftQuarter => {
let origin = visible_frame.origin;
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopRightQuarter => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 2.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomLeftQuarter => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 2.0,
};
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomRightQuarter => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 2.0,
y: visible_frame.origin.y + visible_frame.size.height / 2.0,
};
let size = CGSize {
width: visible_frame.size.width / 2.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopLeftSixth => {
let origin = visible_frame.origin;
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopCenterSixth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 3.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopRightSixth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width * 2.0 / 3.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomLeftSixth => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 2.0,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomCenterSixth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 3.0,
y: visible_frame.origin.y + visible_frame.size.height / 2.0,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomRightSixth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width * 2.0 / 3.0,
y: visible_frame.origin.y + visible_frame.size.height / 2.0,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height / 2.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopThird => {
let origin = visible_frame.origin;
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 3.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::MiddleThird => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 3.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 3.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomThird => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height * 2.0 / 3.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 3.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::Center => {
let window_size = frontmost_window_frame.size;
let origin = CGPoint {
x: visible_frame.origin.x + (visible_frame.size.width - window_size.width) / 2.0,
y: visible_frame.origin.y + (visible_frame.size.height - window_size.height) / 2.0,
};
move_frontmost_window(origin.x, origin.y)
}
Action::FirstFourth => {
let origin = visible_frame.origin;
let size = CGSize {
width: visible_frame.size.width / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::SecondFourth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 4.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::ThirdFourth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width * 2.0 / 4.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::LastFourth => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width * 3.0 / 4.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::FirstThird => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::CenterThird => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 3.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::LastThird => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width * 2.0 / 3.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width / 3.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::FirstTwoThirds => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 2.0 / 3.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::CenterTwoThirds => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 6.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 2.0 / 3.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::LastTwoThirds => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 3.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 2.0 / 3.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::FirstThreeFourths => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 3.0 / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::CenterThreeFourths => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 8.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 3.0 / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::LastThreeFourths => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 4.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 3.0 / 4.0,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopThreeFourths => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height * 3.0 / 4.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomThreeFourths => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 4.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height * 3.0 / 4.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopTwoThirds => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height * 2.0 / 3.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::BottomTwoThirds => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 3.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height * 2.0 / 3.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopCenterTwoThirds => {
let origin = CGPoint {
x: visible_frame.origin.x + visible_frame.size.width / 6.0,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width * 2.0 / 3.0,
height: visible_frame.size.height * 2.0 / 3.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopFirstFourth => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 4.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopSecondFourth => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height / 4.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 4.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopThirdFourth => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height * 2.0 / 4.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 4.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::TopLastFourth => {
let origin = CGPoint {
x: visible_frame.origin.x,
y: visible_frame.origin.y + visible_frame.size.height * 3.0 / 4.0,
};
let size = CGSize {
width: visible_frame.size.width,
height: visible_frame.size.height / 4.0,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::MakeLarger => {
let window_origin = frontmost_window_frame.origin;
let window_size = frontmost_window_frame.size;
let delta_width = 20_f64;
let delta_height = window_size.height / window_size.width * delta_width;
let delta_origin_x = delta_width / 2.0;
let delta_origin_y = delta_height / 2.0;
let new_width = {
let possible_value = window_size.width + delta_width;
if possible_value > visible_frame.size.width {
visible_frame.size.width
} else {
possible_value
}
};
let new_height = {
let possible_value = window_size.height + delta_height;
if possible_value > visible_frame.size.height {
visible_frame.size.height
} else {
possible_value
}
};
let new_origin_x = {
let possible_value = window_origin.x - delta_origin_x;
if possible_value < visible_frame.origin.x {
visible_frame.origin.x
} else {
possible_value
}
};
let new_origin_y = {
let possible_value = window_origin.y - delta_origin_y;
if possible_value < visible_frame.origin.y {
visible_frame.origin.y
} else {
possible_value
}
};
let origin = CGPoint {
x: new_origin_x,
y: new_origin_y,
};
let size = CGSize {
width: new_width,
height: new_height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::MakeSmaller => {
let window_origin = frontmost_window_frame.origin;
let window_size = frontmost_window_frame.size;
let delta_width = 20_f64;
let delta_height = window_size.height / window_size.width * delta_width;
let delta_origin_x = delta_width / 2.0;
let delta_origin_y = delta_height / 2.0;
let origin = CGPoint {
x: window_origin.x + delta_origin_x,
y: window_origin.y + delta_origin_y,
};
let size = CGSize {
width: window_size.width - delta_width,
height: window_size.height - delta_height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::AlmostMaximize => {
let new_size = CGSize {
width: visible_frame.size.width * 0.9,
height: visible_frame.size.height * 0.9,
};
let new_origin = CGPoint {
x: visible_frame.origin.x + (visible_frame.size.width * 0.1),
y: visible_frame.origin.y + (visible_frame.size.height * 0.1),
};
let new_frame = CGRect {
origin: new_origin,
size: new_size,
};
set_frontmost_window_frame(new_frame)
}
Action::Maximize => {
let new_frame = CGRect {
origin: visible_frame.origin,
size: visible_frame.size,
};
set_frontmost_window_frame(new_frame)
}
Action::MaximizeWidth => {
let window_origin = frontmost_window_frame.origin;
let window_size = frontmost_window_frame.size;
let origin = CGPoint {
x: visible_frame.origin.x,
y: window_origin.y,
};
let size = CGSize {
width: visible_frame.size.width,
height: window_size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::MaximizeHeight => {
let window_origin = frontmost_window_frame.origin;
let window_size = frontmost_window_frame.size;
let origin = CGPoint {
x: window_origin.x,
y: visible_frame.origin.y,
};
let size = CGSize {
width: window_size.width,
height: visible_frame.size.height,
};
let new_frame = CGRect { origin, size };
set_frontmost_window_frame(new_frame)
}
Action::MoveUp => {
let window_origin = frontmost_window_frame.origin;
let new_y = (window_origin.y - 10.0).max(visible_frame.origin.y);
move_frontmost_window(window_origin.x, new_y)
}
Action::MoveDown => {
let window_origin = frontmost_window_frame.origin;
let window_size = frontmost_window_frame.size;
let new_y = (window_origin.y + 10.0)
.min(visible_frame.origin.y + visible_frame.size.height - window_size.height);
move_frontmost_window(window_origin.x, new_y)
}
Action::MoveLeft => {
let window_origin = frontmost_window_frame.origin;
let new_x = (window_origin.x - 10.0).max(visible_frame.origin.x);
move_frontmost_window(new_x, window_origin.y)
}
Action::MoveRight => {
let window_origin = frontmost_window_frame.origin;
let window_size = frontmost_window_frame.size;
let new_x = (window_origin.x + 10.0)
.min(visible_frame.origin.x + visible_frame.size.width - window_size.width);
move_frontmost_window(new_x, window_origin.y)
}
Action::NextDesktop => {
let Some(next_workspace_logical_id) = get_next_workspace_logical_id() else {
// nothing to do
return Ok(());
};
move_frontmost_window_to_workspace(next_workspace_logical_id)
}
Action::PreviousDesktop => {
let Some(previous_workspace_logical_id) = get_previous_workspace_logical_id() else {
// nothing to do
return Ok(());
};
// Now let's switch the workspace
move_frontmost_window_to_workspace(previous_workspace_logical_id)
}
Action::NextDisplay => {
const TOO_MANY_MONITORS: &str = "I don't think you can have so many monitors";
let frames = list_visible_frame_of_all_screens()?;
let n_frames = frames.len();
if n_frames == 0 {
return Err(Error::NoDisplay);
}
if n_frames == 1 {
return Ok(());
}
let index = frames
.iter()
.position(|fr| fr == &visible_frame)
.expect("active screen should be in the list");
let new_index: usize = {
let index_i32: i32 = index.try_into().expect(TOO_MANY_MONITORS);
let index_i32_plus_one = index_i32.checked_add(1).expect(TOO_MANY_MONITORS);
let final_value = index_i32_plus_one % n_frames as i32;
final_value
.try_into()
.expect("final value should be positive")
};
let new_frame = frames[new_index];
set_frontmost_window_frame(new_frame)
}
Action::PreviousDisplay => {
const TOO_MANY_MONITORS: &str = "I don't think you can have so many monitors";
let frames = list_visible_frame_of_all_screens()?;
let n_frames = frames.len();
if n_frames == 0 {
return Err(Error::NoDisplay);
}
if n_frames == 1 {
return Ok(());
}
let index = frames
.iter()
.position(|fr| fr == &visible_frame)
.expect("active screen should be in the list");
let new_index: usize = {
let index_i32: i32 = index.try_into().expect(TOO_MANY_MONITORS);
let index_i32_minus_one = index_i32 - 1;
let n_frames_i32: i32 = n_frames.try_into().expect(TOO_MANY_MONITORS);
let final_value = (index_i32_minus_one + n_frames_i32) % n_frames_i32;
final_value
.try_into()
.expect("final value should be positive")
};
let new_frame = frames[new_index];
set_frontmost_window_frame(new_frame)
}
Action::Restore => {
let Some(previous_frame) = get_frontmost_window_last_frame(frontmost_window_id) else {
// Previous frame found, Nothing to do
return Ok(());
};
set_frontmost_window_frame(previous_frame)
}
Action::ToggleFullscreen => toggle_fullscreen(),
}
}
pub(crate) fn set_up_commands_hotkeys(
tauri_app_handle: &AppHandle,
wm_extension: &Extension,
) -> Result<(), String> {
for command in wm_extension
.commands
.as_ref()
.expect("Window Management extension has commands")
.iter()
.filter(|cmd| cmd.enabled)
{
if let Some(ref hotkey) = command.hotkey {
let on_opened = on_opened::on_opened(&command.id);
let extension_id_clone = command.id.clone();
tauri_app_handle
.global_shortcut()
.on_shortcut(hotkey.as_str(), move |tauri_app_handle, _hotkey, event| {
let on_opened_clone = on_opened.clone();
let extension_id_clone = extension_id_clone.clone();
let app_handle_clone = tauri_app_handle.clone();
if event.state() == ShortcutState::Pressed {
async_runtime::spawn(async move {
let result = open(app_handle_clone, on_opened_clone, None).await;
if let Err(msg) = result {
log::warn!(
"failed to open extension [{}], error [{}]",
extension_id_clone,
msg
);
}
});
}
})
.map_err(|e| e.to_string())?;
}
}
Ok(())
}
pub(crate) fn unset_commands_hotkeys(
tauri_app_handle: &AppHandle,
wm_extension: &Extension,
) -> Result<(), String> {
for command in wm_extension
.commands
.as_ref()
.expect("Window Management extension has commands")
.iter()
.filter(|cmd| cmd.enabled)
{
if let Some(ref hotkey) = command.hotkey {
tauri_app_handle
.global_shortcut()
.unregister(hotkey.as_str())
.map_err(|e| e.to_string())?;
}
}
Ok(())
}
pub(crate) fn set_up_command_hotkey(
tauri_app_handle: &AppHandle,
wm_extension: &Extension,
command_id: &str,
) -> Result<(), String> {
let commands = wm_extension
.commands
.as_ref()
.expect("Window Management has commands");
let opt_command = commands.iter().find(|ext| ext.id == command_id);
let Some(command) = opt_command else {
panic!("Window Management command does not exist {}", command_id);
};
if let Some(ref hotkey) = command.hotkey {
let on_opened = on_opened::on_opened(&command.id);
let extension_id_clone = command.id.clone();
tauri_app_handle
.global_shortcut()
.on_shortcut(hotkey.as_str(), move |tauri_app_handle, _hotkey, event| {
let on_opened_clone = on_opened.clone();
let extension_id_clone = extension_id_clone.clone();
let app_handle_clone = tauri_app_handle.clone();
if event.state() == ShortcutState::Pressed {
async_runtime::spawn(async move {
let result = open(app_handle_clone, on_opened_clone, None).await;
if let Err(msg) = result {
log::warn!(
"failed to open extension [{}], error [{}]",
extension_id_clone,
msg
);
}
});
}
})
.map_err(|e| e.to_string())?;
}
Ok(())
}
pub(crate) fn unset_command_hotkey(
tauri_app_handle: &AppHandle,
wm_extension: &Extension,
command_id: &str,
) -> Result<(), String> {
let commands = wm_extension
.commands
.as_ref()
.expect("Window Management has commands");
let opt_command = commands.iter().find(|ext| ext.id == command_id);
let Some(command) = opt_command else {
panic!("Window Management command does not exist {}", command_id);
};
if let Some(ref hotkey) = command.hotkey {
tauri_app_handle
.global_shortcut()
.unregister(hotkey.as_str())
.map_err(|e| e.to_string())?;
}
Ok(())
}
pub(crate) fn register_command_hotkey(
tauri_app_handle: &AppHandle,
command_id: &str,
hotkey: &str,
) -> Result<(), String> {
let on_opened = on_opened::on_opened(&command_id);
let extension_id_clone = command_id.to_string();
tauri_app_handle
.global_shortcut()
.on_shortcut(hotkey, move |tauri_app_handle, _hotkey, event| {
let on_opened_clone = on_opened.clone();
let extension_id_clone = extension_id_clone.clone();
let app_handle_clone = tauri_app_handle.clone();
if event.state() == ShortcutState::Pressed {
async_runtime::spawn(async move {
let result = open(app_handle_clone, on_opened_clone, None).await;
if let Err(msg) = result {
log::warn!(
"failed to open extension [{}], error [{}]",
extension_id_clone,
msg
);
}
});
}
})
.map_err(|e| e.to_string())?;
Ok(())
}
pub(crate) fn unregister_command_hotkey(
tauri_app_handle: &AppHandle,
wm_extension: &Extension,
command_id: &str,
) -> Result<(), String> {
let commands = wm_extension
.commands
.as_ref()
.expect("Window Management has commands");
let opt_command = commands.iter().find(|ext| ext.id == command_id);
let Some(command) = opt_command else {
panic!("Window Management command does not exist {}", command_id);
};
let Some(ref hotkey) = command.hotkey else {
return Ok(());
};
tauri_app_handle
.global_shortcut()
.unregister(hotkey.as_str())
.map_err(|e| e.to_string())?;
Ok(())
}

View File

@@ -1,10 +0,0 @@
use super::actions::Action;
use crate::common::document::OnOpened;
use serde_plain;
pub(crate) fn on_opened(command_id: &str) -> OnOpened {
let action: Action = serde_plain::from_str(command_id).unwrap_or_else(|_| {
panic!("Window Management commands IDs should be valid for `enum Action`, someone corrupts the JSON file");
});
OnOpened::WindowManagementAction { action }
}

View File

@@ -1,415 +0,0 @@
{
"id": "Window Management",
"name": "Window Management",
"platforms": [
"macos"
],
"description": "Resize, reorganize and move your focused window effortlessly",
"icon": "font_a-Windowmanagement",
"type": "extension",
"category": "Utilities",
"tags": [
"Productivity"
],
"commands": [
{
"id": "TopHalf",
"name": "Top Half",
"description": "Move the focused window to fill left half of the screen.",
"icon": "font_a-TopHalf",
"type": "command"
},
{
"id": "BottomHalf",
"name": "Bottom Half",
"description": "Move the focused window to fill bottom half of the screen.",
"icon": "font_a-BottomHalf",
"type": "command"
},
{
"id": "LeftHalf",
"name": "Left Half",
"description": "Move the focused window to fill left half of the screen.",
"icon": "font_a-LeftHalf",
"type": "command"
},
{
"id": "RightHalf",
"name": "Right Half",
"description": "Move the focused window to fill right half of the screen.",
"icon": "font_a-RightHalf",
"type": "command"
},
{
"id": "CenterHalf",
"name": "Center Half",
"description": "Move the focused window to fill center half of the screen.",
"icon": "font_a-CenterHalf",
"type": "command"
},
{
"id": "Maximize",
"name": "Maximize",
"description": "Maximize the focused window to fit the screen.",
"icon": "font_Maximize",
"type": "command"
},
{
"id": "TopLeftQuarter",
"name": "Top Left Quarter",
"description": "Resize the focused window to the top left quarter of the screen.",
"icon": "font_a-TopLeftQuarter",
"type": "command"
},
{
"id": "TopRightQuarter",
"name": "Top Right Quarter",
"description": "Resize the focused window to the top right quarter of the screen.",
"icon": "font_a-TopRightQuarter",
"type": "command"
},
{
"id": "BottomLeftQuarter",
"name": "Bottom Left Quarter",
"description": "Resize the focused window to the bottom left quarter of the screen.",
"icon": "font_a-BottomLeftQuarter",
"type": "command"
},
{
"id": "BottomRightQuarter",
"name": "Bottom Right Quarter",
"description": "Resize the focused window to the bottom right quarter of the screen.",
"icon": "font_a-BottomRightQuarter",
"type": "command"
},
{
"id": "TopLeftSixth",
"name": "Top Left Sixth",
"description": "Resize the focused window to the top left sixth of the screen.",
"icon": "font_a-TopLeftSixth",
"type": "command"
},
{
"id": "TopCenterSixth",
"name": "Top Center Sixth",
"description": "Resize the focused window to the top center sixth of the screen.",
"icon": "font_a-TopCenterSixth",
"type": "command"
},
{
"id": "TopRightSixth",
"name": "Top Right Sixth",
"description": "Resize the focused window to the top right sixth of the screen.",
"icon": "font_a-TopRightSixth",
"type": "command"
},
{
"id": "BottomLeftSixth",
"name": "Bottom Left Sixth",
"description": "Resize the focused window to the bottom left sixth of the screen.",
"icon": "font_a-BottomLeftSixth",
"type": "command"
},
{
"id": "BottomCenterSixth",
"name": "Bottom Center Sixth",
"description": "Resize the focused window to the bottom center sixth of the screen.",
"icon": "font_a-BottomCenterSixth",
"type": "command"
},
{
"id": "BottomRightSixth",
"name": "Bottom Right Sixth",
"description": "Resize the focused window to the bottom right sixth of the screen.",
"icon": "font_a-BottomRightSixth",
"type": "command"
},
{
"id": "TopThird",
"name": "Top Third",
"description": "Resize the focused window to the top third of the screen.",
"icon": "font_a-TopThirdFourth",
"type": "command"
},
{
"id": "MiddleThird",
"name": "Middle Third",
"description": "Resize the focused window to the middle third of the screen.",
"icon": "font_a-MiddleThird",
"type": "command"
},
{
"id": "BottomThird",
"name": "Bottom Third",
"description": "Resize the focused window to the bottom third of the screen.",
"icon": "font_a-BottomThird",
"type": "command"
},
{
"id": "Center",
"name": "Center",
"description": "Center the focused window in the screen.",
"icon": "font_Center",
"type": "command"
},
{
"id": "FirstFourth",
"name": "First Fourth",
"description": "Resize the focused window to the first fourth of the screen.",
"icon": "font_a-FirstFourth",
"type": "command"
},
{
"id": "SecondFourth",
"name": "Second Fourth",
"description": "Resize the focused window to the second fourth of the screen.",
"icon": "font_a-SecondFourth",
"type": "command"
},
{
"id": "ThirdFourth",
"name": "Third Fourth",
"description": "Resize the focused window to the third fourth of the screen.",
"icon": "font_a-ThirdFourth",
"type": "command"
},
{
"id": "LastFourth",
"name": "Last Fourth",
"description": "Resize the focused window to the last fourth of the screen.",
"icon": "font_a-LastFourth",
"type": "command"
},
{
"id": "FirstThird",
"name": "First Third",
"description": "Resize the focused window to the first third of the screen.",
"icon": "font_a-FirstThird",
"type": "command"
},
{
"id": "CenterThird",
"name": "Center Third",
"description": "Resize the focused window to the center third of the screen.",
"icon": "font_a-CenterThird",
"type": "command"
},
{
"id": "LastThird",
"name": "Last Third",
"description": "Resize the focused window to the last third of the screen.",
"icon": "font_a-LastThird",
"type": "command"
},
{
"id": "FirstTwoThirds",
"name": "First Two Thirds",
"description": "Resize the focused window to the first two thirds of the screen.",
"icon": "font_a-FirstTwoThirds",
"type": "command"
},
{
"id": "CenterTwoThirds",
"name": "Center Two Thirds",
"description": "Resize the focused window to the center two thirds of the screen.",
"icon": "font_a-CenterTwoThirds",
"type": "command"
},
{
"id": "LastTwoThirds",
"name": "Last Two Thirds",
"description": "Resize the focused window to the last two thirds of the screen.",
"icon": "font_a-LastTwoThirds",
"type": "command"
},
{
"id": "FirstThreeFourths",
"name": "First Three Fourths",
"description": "Resize the focused window to the first three fourths of the screen.",
"icon": "font_a-FirstThreeFourths",
"type": "command"
},
{
"id": "CenterThreeFourths",
"name": "Center Three Fourths",
"description": "Resize the focused window to the center three fourths of the screen.",
"icon": "font_a-CenterThreeFourths",
"type": "command"
},
{
"id": "LastThreeFourths",
"name": "Last Three Fourths",
"description": "Resize the focused window to the last three fourths of the screen.",
"icon": "font_a-LastThreeFourths",
"type": "command"
},
{
"id": "TopThreeFourths",
"name": "Top Three Fourths",
"description": "Resize the focused window to the top three fourths of the screen.",
"icon": "font_a-TopThreeFourths",
"type": "command"
},
{
"id": "BottomThreeFourths",
"name": "Bottom Three Fourths",
"description": "Resize the focused window to the bottom three fourths of the screen.",
"icon": "font_a-BottomThreeFourths",
"type": "command"
},
{
"id": "TopTwoThirds",
"name": "Top Two Thirds",
"description": "Resize the focused window to the top two thirds of the screen.",
"icon": "font_a-TopTwoThirds",
"type": "command"
},
{
"id": "BottomTwoThirds",
"name": "Bottom Two Thirds",
"description": "Resize the focused window to the bottom two thirds of the screen.",
"icon": "font_a-BottomTwoThirds",
"type": "command"
},
{
"id": "TopCenterTwoThirds",
"name": "Top Center Two Thirds",
"description": "Resize the focused window to the top center two thirds of the screen.",
"icon": "font_a-TopCenterTwoThirds",
"type": "command"
},
{
"id": "TopFirstFourth",
"name": "Top First Fourth",
"description": "Resize the focused window to the top first fourth of the screen.",
"icon": "font_a-TopFirstFourth",
"type": "command"
},
{
"id": "TopSecondFourth",
"name": "Top Second Fourth",
"description": "Resize the focused window to the top second fourth of the screen.",
"icon": "font_a-TopSecondFourth",
"type": "command"
},
{
"id": "TopThirdFourth",
"name": "Top Third Fourth",
"description": "Resize the focused window to the top third fourth of the screen.",
"icon": "font_a-TopThirdFourth",
"type": "command"
},
{
"id": "TopLastFourth",
"name": "Top Last Fourth",
"description": "Resize the focused window to the top last fourth of the screen.",
"icon": "font_a-TopLastFourth",
"type": "command"
},
{
"id": "MakeLarger",
"name": "Make Larger",
"description": "Increase the focused window until it reaches the screen size.",
"icon": "font_a-MakeLarger",
"type": "command"
},
{
"id": "MakeSmaller",
"name": "Make Smaller",
"description": "Decrease the focused window until it reaches its minimal size.",
"icon": "font_a-MakeSmaller",
"type": "command"
},
{
"id": "AlmostMaximize",
"name": "Almost Maximize",
"description": "Maximize the focused window to almost fit the screen.",
"icon": "font_a-AlmostMaximize",
"type": "command"
},
{
"id": "MaximizeWidth",
"name": "Maximize Width",
"description": "Maximize width of the focused window to fit the screen.",
"icon": "font_a-MaximizeWidth",
"type": "command"
},
{
"id": "MaximizeHeight",
"name": "Maximize Height",
"description": "Maximize height of the focused window to fit the screen.",
"icon": "font_a-MaximizeHeight",
"type": "command"
},
{
"id": "MoveUp",
"name": "Move Up",
"description": "Move the focused window to the top edge of the screen.",
"icon": "font_a-MoveUp",
"type": "command"
},
{
"id": "MoveDown",
"name": "Move Down",
"description": "Move the focused window to the bottom of the screen.",
"icon": "font_a-MoveDown",
"type": "command"
},
{
"id": "MoveLeft",
"name": "Move Left",
"description": "Move the focused window to the left edge of the screen.",
"icon": "font_a-MoveLeft",
"type": "command"
},
{
"id": "MoveRight",
"name": "Move Right",
"description": "Move the focused window to the right edge of the screen.",
"icon": "font_a-MoveRight",
"type": "command"
},
{
"id": "NextDesktop",
"name": "Next Desktop",
"description": "Move the focused window to the next desktop.",
"icon": "font_a-NextDesktop",
"type": "command"
},
{
"id": "PreviousDesktop",
"name": "Previous Desktop",
"description": "Move the focused window to the previous desktop.",
"icon": "font_a-PreviousDesktop",
"type": "command"
},
{
"id": "NextDisplay",
"name": "Next Display",
"description": "Move the focused window to the next display.",
"icon": "font_a-NextDisplay",
"type": "command"
},
{
"id": "PreviousDisplay",
"name": "Previous Display",
"description": "Move the focused window to the previous display.",
"icon": "font_a-PreviousDisplay",
"type": "command"
},
{
"id": "Restore",
"name": "Restore",
"description": "Restore the focused window to its last position.",
"icon": "font_Restore",
"type": "command"
},
{
"id": "ToggleFullscreen",
"name": "Toggle Fullscreen",
"description": "Toggle fullscreen mode.",
"icon": "font_a-ToggleFullscreen",
"type": "command"
}
]
}

View File

@@ -1,138 +0,0 @@
use super::EXTENSION_ID;
use super::EXTENSION_NAME_LOWERCASE;
use crate::common::document::{DataSourceReference, Document};
use crate::common::{
error::SearchError,
search::{QueryResponse, QuerySource, SearchQuery},
traits::SearchSource,
};
use crate::extension::built_in::{get_built_in_extension_directory, load_extension_from_json_file};
use crate::extension::{ExtensionType, LOCAL_QUERY_SOURCE_TYPE, calculate_text_similarity};
use async_trait::async_trait;
use hostname;
use tauri::AppHandle;
/// A search source to allow users to search WM actions.
pub(crate) struct WindowManagementSearchSource;
#[async_trait]
impl SearchSource for WindowManagementSearchSource {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(EXTENSION_ID.into())
.to_string_lossy()
.into(),
id: EXTENSION_ID.into(),
}
}
async fn search(
&self,
tauri_app_handle: AppHandle,
query: SearchQuery,
) -> Result<QueryResponse, SearchError> {
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
let from = usize::try_from(query.from).expect("from too big");
let size = usize::try_from(query.size).expect("size too big");
let query_string = query_string.trim();
if query_string.is_empty() {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
let query_string_lowercase = query_string.to_lowercase();
let extension = load_extension_from_json_file(
&get_built_in_extension_directory(&tauri_app_handle),
super::EXTENSION_ID,
)
.map_err(SearchError::InternalError)?;
let commands = extension.commands.expect("this extension has commands");
let mut hits: Vec<(Document, f64)> = Vec::new();
// We know they are all commands
let command_type_string = ExtensionType::Command.to_string();
for command in commands.iter().filter(|ext| ext.enabled) {
let score = {
let mut score = 0_f64;
if let Some(name_score) =
calculate_text_similarity(&query_string_lowercase, &command.name.to_lowercase())
{
score += name_score;
}
if let Some(ref alias) = command.alias {
if let Some(alias_score) =
calculate_text_similarity(&query_string_lowercase, &alias.to_lowercase())
{
score += alias_score;
}
}
// An "extension" type extension should return all its
// sub-extensions when the query string matches its name.
// To do this, we score the extension name and take that
// into account.
if let Some(main_extension_score) =
calculate_text_similarity(&query_string_lowercase, &EXTENSION_NAME_LOWERCASE)
{
score += main_extension_score;
}
score
};
if score > 0.0 {
let on_opened = super::on_opened::on_opened(&command.id);
let url = on_opened.url();
let document = Document {
id: command.id.clone(),
title: Some(command.name.clone()),
icon: Some(command.icon.clone()),
on_opened: Some(on_opened),
url: Some(url),
category: Some(command_type_string.clone()),
source: Some(DataSourceReference {
id: Some(command_type_string.clone()),
name: Some(command_type_string.clone()),
icon: None,
r#type: Some(command_type_string.clone()),
}),
..Default::default()
};
hits.push((document, score));
}
}
hits.sort_by(|(_, score_a), (_, score_b)| {
score_a
.partial_cmp(&score_b)
.expect("expect no NAN/INFINITY/...")
});
let total_hits = hits.len();
let from_size_applied = hits.into_iter().skip(from).take(size).collect();
Ok(QueryResponse {
source: self.get_type(),
hits: from_size_applied,
total_hits,
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,822 +0,0 @@
//! Coco has 4 sources of `plugin.json` to check and validate:
//!
//! 1. From coco-extensions repository
//!
//! Granted, Coco APP won't check these files directly, but the code here
//! will run in that repository's CI to prevent errors in the first place.
//!
//! 2. From the "<data directory>/third_party_extensions" directory
//! 3. Imported via "Import Local Extension"
//! 4. Downloaded from the "store/extension/<extension ID>/_download" API
//!
//! This file contains the checks that are general enough to be applied to all
//! these 4 sources
use crate::extension::Extension;
use crate::extension::ExtensionType;
use crate::extension::PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION;
use crate::util::platform::Platform;
use std::collections::HashSet;
pub(crate) fn general_check(extension: &Extension) -> Result<(), String> {
// Check main extension
check_main_extension_only(extension)?;
check_main_extension_or_sub_extension(extension, &format!("extension [{}]", extension.id))?;
// `None` if `extension` is compatible with all the platforms. Otherwise `Some(limited_platforms)`
let limited_supported_platforms = match extension.platforms.as_ref() {
Some(platforms) => {
if platforms.len() == Platform::num_of_supported_platforms() {
None
} else {
Some(platforms)
}
}
None => None,
};
// Check sub extensions
let commands = match extension.commands {
Some(ref v) => v.as_slice(),
None => &[],
};
let scripts = match extension.scripts {
Some(ref v) => v.as_slice(),
None => &[],
};
let quicklinks = match extension.quicklinks {
Some(ref v) => v.as_slice(),
None => &[],
};
let views = match extension.views {
Some(ref v) => v.as_slice(),
None => &[],
};
let sub_extensions = [commands, scripts, quicklinks, views].concat();
let mut sub_extension_ids = HashSet::new();
for sub_extension in sub_extensions.iter() {
check_sub_extension_only(&extension.id, sub_extension, limited_supported_platforms)?;
check_main_extension_or_sub_extension(
extension,
&format!("sub-extension [{}-{}]", extension.id, sub_extension.id),
)?;
if !sub_extension_ids.insert(sub_extension.id.as_str()) {
// extension ID already exists
return Err(format!(
"sub-extension with ID [{}] already exists",
sub_extension.id
));
}
}
Ok(())
}
/// This checks the main extension only, it won't check sub-extensions.
fn check_main_extension_only(extension: &Extension) -> Result<(), String> {
// Group and Extension cannot have alias
if extension.alias.is_some() {
if extension.r#type == ExtensionType::Group || extension.r#type == ExtensionType::Extension
{
return Err(format!(
"invalid extension [{}], extension of type [{:?}] cannot have alias",
extension.id, extension.r#type
));
}
}
// Group and Extension cannot have hotkey
if extension.hotkey.is_some() {
if extension.r#type == ExtensionType::Group || extension.r#type == ExtensionType::Extension
{
return Err(format!(
"invalid extension [{}], extension of type [{:?}] cannot have hotkey",
extension.id, extension.r#type
));
}
}
if extension.commands.is_some()
|| extension.scripts.is_some()
|| extension.quicklinks.is_some()
|| extension.views.is_some()
{
if extension.r#type != ExtensionType::Group && extension.r#type != ExtensionType::Extension
{
return Err(format!(
"invalid extension [{}], only extension of type [Group] and [Extension] can have sub-extensions",
extension.id,
));
}
}
if extension.settings.is_some() {
// Sub-extensions are all searchable, so this check is only for main extensions.
if !extension.searchable() {
return Err(format!(
"invalid extension {}, field [settings] is currently only allowed in searchable extension, this type of extension is not searchable [{}]",
extension.id, extension.r#type
));
}
}
Ok(())
}
fn check_sub_extension_only(
extension_id: &str,
sub_extension: &Extension,
limited_platforms: Option<&HashSet<Platform>>,
) -> Result<(), String> {
if sub_extension.r#type == ExtensionType::Group
|| sub_extension.r#type == ExtensionType::Extension
{
return Err(format!(
"invalid sub-extension [{}-{}]: sub-extensions should not be of type [Group] or [Extension]",
extension_id, sub_extension.id
));
}
if sub_extension.commands.is_some()
|| sub_extension.scripts.is_some()
|| sub_extension.quicklinks.is_some()
|| sub_extension.views.is_some()
{
return Err(format!(
"invalid sub-extension [{}-{}]: fields [commands/scripts/quicklinks/views] should not be set in sub-extensions",
extension_id, sub_extension.id
));
}
if sub_extension.developer.is_some() {
return Err(format!(
"invalid sub-extension [{}-{}]: field [developer] should not be set in sub-extensions",
extension_id, sub_extension.id
));
}
if let Some(platforms_supported_by_main_extension) = limited_platforms {
match sub_extension.platforms {
Some(ref platforms_supported_by_sub_extension) => {
let diff = platforms_supported_by_sub_extension
.difference(&platforms_supported_by_main_extension)
.into_iter()
.map(|p| p.to_string())
.collect::<Vec<String>>();
if !diff.is_empty() {
return Err(format!(
"invalid sub-extension [{}-{}]: it supports platforms {:?} that are not supported by the main extension",
extension_id, sub_extension.id, diff
));
}
}
None => {
// if `sub_extension.platform` is None, it means it has the same value
// as main extension's `platforms` field, so we don't need to check it.
}
}
}
if sub_extension.minimum_coco_version.is_some() {
return Err(format!(
"invalid sub-extension [{}-{}]: [{}] cannot be set for sub-extensions",
extension_id, sub_extension.id, PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION
));
}
Ok(())
}
fn check_main_extension_or_sub_extension(
extension: &Extension,
identifier: &str,
) -> Result<(), String> {
// If field `action` is Some, then it should be a Command
if extension.action.is_some() && extension.r#type != ExtensionType::Command {
return Err(format!(
"invalid {}, field [action] is set for a non-Command extension",
identifier
));
}
if extension.r#type == ExtensionType::Command && extension.action.is_none() {
return Err(format!(
"invalid {}, field [action] should be set for a Command extension",
identifier
));
}
// If field `quicklink` is Some, then it should be a Quicklink
if extension.quicklink.is_some() && extension.r#type != ExtensionType::Quicklink {
return Err(format!(
"invalid {}, field [quicklink] is set for a non-Quicklink extension",
identifier
));
}
if extension.r#type == ExtensionType::Quicklink && extension.quicklink.is_none() {
return Err(format!(
"invalid {}, field [quicklink] should be set for a Quicklink extension",
identifier
));
}
// If field `page` is Some, then it should be a View
if extension.page.is_some() && extension.r#type != ExtensionType::View {
return Err(format!(
"invalid {}, field [page] is set for a non-View extension",
identifier
));
}
if extension.r#type == ExtensionType::View && extension.page.is_none() {
return Err(format!(
"invalid {}, field [page] should be set for a View extension",
identifier
));
}
// If field `ui` is Some, then it should be a View
if extension.ui.is_some() && extension.r#type != ExtensionType::View {
return Err(format!(
"invalid {}, field [ui] is set for a non-View extension",
identifier
));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::extension::{
CommandAction, ExtensionSettings, Quicklink, QuicklinkLink, QuicklinkLinkComponent,
};
/// Helper function to create a basic valid extension
fn create_basic_extension(id: &str, extension_type: ExtensionType) -> Extension {
let page = if extension_type == ExtensionType::View {
Some("index.html".into())
} else {
None
};
Extension {
id: id.to_string(),
name: "Test Extension".to_string(),
developer: None,
platforms: None,
description: "Test description".to_string(),
icon: "test-icon.png".to_string(),
r#type: extension_type,
action: None,
quicklink: None,
commands: None,
scripts: None,
quicklinks: None,
views: None,
alias: None,
hotkey: None,
enabled: true,
page,
ui: None,
permission: None,
settings: None,
minimum_coco_version: None,
screenshots: None,
url: None,
version: None,
}
}
/// Helper function to create a command action
fn create_command_action() -> CommandAction {
CommandAction {
exec: "echo".to_string(),
args: Some(vec!["test".to_string()]),
}
}
/// Helper function to create a quicklink
fn create_quicklink() -> Quicklink {
Quicklink {
link: QuicklinkLink {
components: vec![QuicklinkLinkComponent::StaticStr(
"https://example.com".to_string(),
)],
},
open_with: None,
}
}
/* test_check_main_extension_only */
#[test]
fn test_group_cannot_have_alias() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
extension.alias = Some("group-alias".to_string());
let result = general_check(&extension);
assert!(result.is_err());
assert!(result.unwrap_err().contains("cannot have alias"));
}
#[test]
fn test_extension_cannot_have_alias() {
let mut extension = create_basic_extension("test-ext", ExtensionType::Extension);
extension.alias = Some("ext-alias".to_string());
let result = general_check(&extension);
assert!(result.is_err());
assert!(result.unwrap_err().contains("cannot have alias"));
}
#[test]
fn test_group_cannot_have_hotkey() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
extension.hotkey = Some("cmd+g".to_string());
let result = general_check(&extension);
assert!(result.is_err());
assert!(result.unwrap_err().contains("cannot have hotkey"));
}
#[test]
fn test_extension_cannot_have_hotkey() {
let mut extension = create_basic_extension("test-ext", ExtensionType::Extension);
extension.hotkey = Some("cmd+e".to_string());
let result = general_check(&extension);
assert!(result.is_err());
assert!(result.unwrap_err().contains("cannot have hotkey"));
}
#[test]
fn test_non_container_types_cannot_have_sub_extensions() {
let mut extension = create_basic_extension("test-cmd", ExtensionType::Command);
extension.action = Some(create_command_action());
extension.commands = Some(vec![create_basic_extension(
"sub-cmd",
ExtensionType::Command,
)]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("only extension of type [Group] and [Extension] can have sub-extensions")
);
}
#[test]
fn test_non_searchable_extension_set_field_settings() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
extension.settings = Some(ExtensionSettings {
hide_before_open: None,
});
let error_msg = general_check(&extension).unwrap_err();
assert!(
error_msg
.contains("field [settings] is currently only allowed in searchable extension")
);
let mut extension = create_basic_extension("test-extension", ExtensionType::Extension);
extension.settings = Some(ExtensionSettings {
hide_before_open: None,
});
let error_msg = general_check(&extension).unwrap_err();
assert!(
error_msg
.contains("field [settings] is currently only allowed in searchable extension")
);
}
/* test_check_main_extension_only */
/* test check_main_extension_or_sub_extension */
#[test]
fn test_command_must_have_action() {
let extension = create_basic_extension("test-cmd", ExtensionType::Command);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [action] should be set for a Command extension")
);
}
#[test]
fn test_non_command_cannot_have_action() {
let mut extension = create_basic_extension("test-script", ExtensionType::Script);
extension.action = Some(create_command_action());
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [action] is set for a non-Command extension")
);
}
#[test]
fn test_quicklink_must_have_quicklink_field() {
let extension = create_basic_extension("test-quicklink", ExtensionType::Quicklink);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [quicklink] should be set for a Quicklink extension")
);
}
#[test]
fn test_non_quicklink_cannot_have_quicklink_field() {
let mut extension = create_basic_extension("test-cmd", ExtensionType::Command);
extension.action = Some(create_command_action());
extension.quicklink = Some(create_quicklink());
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [quicklink] is set for a non-Quicklink extension")
);
}
#[test]
fn test_view_must_have_page_field() {
let mut extension = create_basic_extension("test-view", ExtensionType::View);
// create_basic_extension() will set its page field if type is View, clear it
extension.page = None;
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [page] should be set for a View extension")
);
}
#[test]
fn test_non_view_cannot_have_page_field() {
let mut extension = create_basic_extension("test-cmd", ExtensionType::Command);
extension.action = Some(create_command_action());
extension.page = Some("index.html".into());
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [page] is set for a non-View extension")
);
}
/* test check_main_extension_or_sub_extension */
/* Test check_sub_extension_only */
#[test]
fn test_sub_extension_cannot_be_group() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
let sub_group = create_basic_extension("sub-group", ExtensionType::Group);
extension.commands = Some(vec![sub_group]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("sub-extensions should not be of type [Group] or [Extension]")
);
}
#[test]
fn test_sub_extension_cannot_be_extension() {
let mut extension = create_basic_extension("test-ext", ExtensionType::Extension);
let sub_ext = create_basic_extension("sub-ext", ExtensionType::Extension);
extension.scripts = Some(vec![sub_ext]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("sub-extensions should not be of type [Group] or [Extension]")
);
}
#[test]
fn test_sub_extension_cannot_have_developer() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.developer = Some("test-dev".to_string());
extension.commands = Some(vec![sub_cmd]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("field [developer] should not be set in sub-extensions")
);
}
#[test]
fn test_sub_extension_cannot_have_sub_extensions() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.commands = Some(vec![create_basic_extension(
"nested-cmd",
ExtensionType::Command,
)]);
extension.commands = Some(vec![sub_cmd]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(result.unwrap_err().contains(
"fields [commands/scripts/quicklinks/views] should not be set in sub-extensions"
));
}
#[test]
fn test_sub_extension_cannot_set_minimum_coco_version() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.minimum_coco_version = Some(semver::Version::new(0, 8, 0));
extension.commands = Some(vec![sub_cmd]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(result.unwrap_err().contains(&format!(
"[{}] cannot be set for sub-extensions",
PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION
)));
}
/* Test check_sub_extension_only */
#[test]
fn test_duplicate_sub_extension_ids() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
let mut cmd1 = create_basic_extension("duplicate-id", ExtensionType::Command);
cmd1.action = Some(create_command_action());
let mut cmd2 = create_basic_extension("duplicate-id", ExtensionType::Command);
cmd2.action = Some(create_command_action());
extension.commands = Some(vec![cmd1, cmd2]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("sub-extension with ID [duplicate-id] already exists")
);
}
#[test]
fn test_duplicate_ids_across_different_sub_extension_types() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
let mut cmd = create_basic_extension("same-id", ExtensionType::Command);
cmd.action = Some(create_command_action());
let script = create_basic_extension("same-id", ExtensionType::Script);
extension.commands = Some(vec![cmd]);
extension.scripts = Some(vec![script]);
let result = general_check(&extension);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.contains("sub-extension with ID [same-id] already exists")
);
}
#[test]
fn test_valid_group_extension() {
let mut extension = create_basic_extension("test-group", ExtensionType::Group);
extension.commands = Some(vec![create_basic_extension("cmd1", ExtensionType::Command)]);
assert!(general_check(&extension).is_ok());
}
#[test]
fn test_valid_extension_type() {
let mut extension = create_basic_extension("test-ext", ExtensionType::Extension);
extension.scripts = Some(vec![create_basic_extension(
"script1",
ExtensionType::Script,
)]);
assert!(general_check(&extension).is_ok());
}
#[test]
fn test_valid_command_extension() {
let mut extension = create_basic_extension("test-cmd", ExtensionType::Command);
extension.action = Some(create_command_action());
assert!(general_check(&extension).is_ok());
}
#[test]
fn test_valid_quicklink_extension() {
let mut extension = create_basic_extension("test-quicklink", ExtensionType::Quicklink);
extension.quicklink = Some(create_quicklink());
assert!(general_check(&extension).is_ok());
}
#[test]
fn test_valid_complex_extension() {
let mut extension = create_basic_extension("spotify-controls", ExtensionType::Extension);
// Add valid commands
let mut play_pause = create_basic_extension("play-pause", ExtensionType::Command);
play_pause.action = Some(create_command_action());
let mut next_track = create_basic_extension("next-track", ExtensionType::Command);
next_track.action = Some(create_command_action());
let mut prev_track = create_basic_extension("prev-track", ExtensionType::Command);
prev_track.action = Some(create_command_action());
extension.commands = Some(vec![play_pause, next_track, prev_track]);
assert!(general_check(&extension).is_ok());
}
#[test]
fn test_valid_single_layer_command() {
let mut extension = create_basic_extension("empty-trash", ExtensionType::Command);
extension.action = Some(create_command_action());
assert!(general_check(&extension).is_ok());
}
#[test]
fn test_command_alias_and_hotkey_allowed() {
let mut extension = create_basic_extension("test-cmd", ExtensionType::Command);
extension.action = Some(create_command_action());
extension.alias = Some("cmd-alias".to_string());
extension.hotkey = Some("cmd+t".to_string());
assert!(general_check(&extension).is_ok());
}
/*
* Tests for check that sub extension cannot support extensions that are not
* supported by the main extension
*
* Start here
*/
#[test]
fn test_platform_validation_both_none() {
// Case 1: main extension's platforms = None, sub extension's platforms = None
// Should return Ok(())
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = None;
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = None;
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_ok());
}
#[test]
fn test_platform_validation_main_all_sub_none() {
// Case 2: main extension's platforms = Some(all platforms), sub extension's platforms = None
// Should return Ok(())
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = Some(Platform::all());
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = None;
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_ok());
}
#[test]
fn test_platform_validation_main_none_sub_some() {
// Case 3: main extension's platforms = None, sub extension's platforms = Some([Platform::Macos])
// Should return Ok(()) because None means supports all platforms
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = None;
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = Some(HashSet::from([Platform::Macos]));
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_ok());
}
#[test]
fn test_platform_validation_main_all_sub_subset() {
// Case 4: main extension's platforms = Some(all platforms), sub extension's platforms = Some([Platform::Macos])
// Should return Ok(()) because sub extension supports a subset of main extension's platforms
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = Some(Platform::all());
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = Some(HashSet::from([Platform::Macos]));
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_ok());
}
#[test]
fn test_platform_validation_main_limited_sub_unsupported() {
// Case 5: main extension's platforms = Some([Platform::Macos]), sub extension's platforms = Some([Platform::Linux])
// Should return Err because sub extension supports a platform not supported by main extension
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = Some(HashSet::from([Platform::Macos]));
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = Some(HashSet::from([Platform::Linux]));
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_err());
let error_msg = result.unwrap_err();
assert!(error_msg.contains("it supports platforms"));
assert!(error_msg.contains("that are not supported by the main extension"));
assert!(error_msg.contains("Linux")); // Should mention the unsupported platform
}
#[test]
fn test_platform_validation_main_partial_sub_unsupported() {
// Case 6: main extension's platforms = Some([Platform::Macos, Platform::Windows]), sub extension's platforms = Some([Platform::Linux])
// Should return Err because sub extension supports a platform not supported by main extension
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = Some(HashSet::from([Platform::Macos, Platform::Windows]));
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = Some(HashSet::from([Platform::Linux]));
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_err());
let error_msg = result.unwrap_err();
assert!(error_msg.contains("it supports platforms"));
assert!(error_msg.contains("that are not supported by the main extension"));
assert!(error_msg.contains("Linux")); // Should mention the unsupported platform
}
#[test]
fn test_platform_validation_main_limited_sub_none() {
// Case 7: main extension's platforms = Some([Platform::Macos]), sub extension's platforms = None
// Should return Ok(()) because when sub extension's platforms is None, it inherits main extension's platforms
let mut main_extension = create_basic_extension("main-ext", ExtensionType::Group);
main_extension.platforms = Some(HashSet::from([Platform::Macos]));
let mut sub_cmd = create_basic_extension("sub-cmd", ExtensionType::Command);
sub_cmd.action = Some(create_command_action());
sub_cmd.platforms = None;
main_extension.commands = Some(vec![sub_cmd]);
let result = general_check(&main_extension);
assert!(result.is_ok());
}
/*
* Tests for check that sub extension cannot support extensions that are not
* supported by the main extension
*
* End here
*/
}

View File

@@ -1,261 +0,0 @@
use super::check_compatibility_via_mcv;
use crate::extension::PLUGIN_JSON_FILE_NAME;
use crate::extension::third_party::check::general_check;
use crate::extension::third_party::install::{
filter_out_incompatible_sub_extensions, is_extension_installed,
};
use crate::extension::third_party::{
THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE, get_third_party_extension_directory,
};
use crate::extension::{
Extension, canonicalize_relative_icon_path, canonicalize_relative_page_path,
};
use crate::util::platform::Platform;
use serde_json::Value as Json;
use std::path::Path;
use std::path::PathBuf;
use tauri::AppHandle;
use tokio::fs;
/// All the extensions installed from local file will belong to a special developer
/// "__local__".
const DEVELOPER_ID_LOCAL: &str = "__local__";
/// Install the extension specified by `path`.
///
/// `path` should point to a directory with the following structure:
///
/// ```text
/// extension-directory/
/// ├── assets/
/// │ ├── icon.png
/// │ └── other-assets...
/// └── plugin.json
/// ```
#[tauri::command]
pub(crate) async fn install_local_extension(
tauri_app_handle: AppHandle,
path: PathBuf,
) -> Result<(), String> {
let extension_dir_name = path
.file_name()
.ok_or_else(|| "Invalid extension: no directory name".to_string())?
.to_str()
.ok_or_else(|| "Invalid extension: non-UTF8 extension id".to_string())?;
// we use extension directory name as the extension ID.
let extension_id = extension_dir_name;
if is_extension_installed(DEVELOPER_ID_LOCAL, extension_id).await {
// The frontend code uses this string to distinguish between 2 error cases:
//
// 1. This extension is already imported
// 2. This extension is incompatible with the current platform
// 3. The selected directory does not contain a valid extension
//
// do NOT edit this without updating the frontend code.
//
// ```ts
// if (errorMessage === "already imported") {
// addError(t("settings.extensions.hints.extensionAlreadyImported"));
// } else if (errorMessage === "incompatible") {
// addError(t("settings.extensions.hints.incompatibleExtension"));
// } else {
// addError(t("settings.extensions.hints.importFailed"));
// }
// ```
//
// This is definitely error-prone, but we have to do this until we have
// structured error type
return Err("already imported".into());
}
let plugin_json_path = path.join(PLUGIN_JSON_FILE_NAME);
let plugin_json_content = fs::read_to_string(&plugin_json_path)
.await
.map_err(|e| e.to_string())?;
// Parse as JSON first as it is not valid for `struct Extension`, we need to
// correct it (set fields `id` and `developer`) before converting it to `struct Extension`:
let mut extension_json: Json =
serde_json::from_str(&plugin_json_content).map_err(|e| e.to_string())?;
if !check_compatibility_via_mcv(&extension_json)? {
return Err("app_incompatible".into());
}
// Set the main extension ID to the directory name
let extension_obj = extension_json
.as_object_mut()
.expect("extension_json should be an object");
extension_obj.insert("id".to_string(), Json::String(extension_id.to_string()));
extension_obj.insert(
"developer".to_string(),
Json::String(DEVELOPER_ID_LOCAL.to_string()),
);
// Counter for sub-extension IDs
let mut counter = 1u32;
// Set IDs for commands
if let Some(commands) = extension_obj.get_mut("commands") {
if let Some(commands_array) = commands.as_array_mut() {
for command in commands_array {
if let Some(command_obj) = command.as_object_mut() {
command_obj.insert("id".to_string(), Json::String(counter.to_string()));
counter += 1;
}
}
}
}
// Set IDs for quicklinks
if let Some(quicklinks) = extension_obj.get_mut("quicklinks") {
if let Some(quicklinks_array) = quicklinks.as_array_mut() {
for quicklink in quicklinks_array {
if let Some(quicklink_obj) = quicklink.as_object_mut() {
quicklink_obj.insert("id".to_string(), Json::String(counter.to_string()));
counter += 1;
}
}
}
}
// Set IDs for scripts
if let Some(scripts) = extension_obj.get_mut("scripts") {
if let Some(scripts_array) = scripts.as_array_mut() {
for script in scripts_array {
if let Some(script_obj) = script.as_object_mut() {
script_obj.insert("id".to_string(), Json::String(counter.to_string()));
counter += 1;
}
}
}
}
// Now we can convert JSON to `struct Extension`
let mut extension: Extension =
serde_json::from_value(extension_json).map_err(|e| e.to_string())?;
let current_platform = Platform::current();
/* Check begins here */
general_check(&extension)?;
if let Some(ref platforms) = extension.platforms {
if !platforms.contains(&current_platform) {
// The frontend code uses this string to distinguish between 3 error cases:
//
// 1. This extension is already imported
// 2. This extension is incompatible with the current platform
// 3. The selected directory does not contain a valid extension
//
// do NOT edit this without updating the frontend code.
//
// ```ts
// if (errorMessage === "already imported") {
// addError(t("settings.extensions.hints.extensionAlreadyImported"));
// } else if (errorMessage === "incompatible") {
// addError(t("settings.extensions.hints.incompatibleExtension"));
// } else {
// addError(t("settings.extensions.hints.importFailed"));
// }
// ```
//
// This is definitely error-prone, but we have to do this until we have
// structured error type
return Err("platform_incompatible".into());
}
}
/* Check ends here */
// Extension is compatible with current platform, but it could contain sub
// extensions that are not, filter them out.
filter_out_incompatible_sub_extensions(&mut extension, current_platform);
// We are going to modify our third-party extension list, grab the write lock
// to ensure exclusive access.
let mut third_party_ext_list_write_lock = THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE
.get()
.expect("global third party search source not set")
.write_lock()
.await;
// Create destination directory
let dest_dir = get_third_party_extension_directory(&tauri_app_handle)
.join(DEVELOPER_ID_LOCAL)
.join(extension_dir_name);
fs::create_dir_all(&dest_dir)
.await
.map_err(|e| e.to_string())?;
// Copy all files except plugin.json
let mut entries = fs::read_dir(&path).await.map_err(|e| e.to_string())?;
while let Some(entry) = entries.next_entry().await.map_err(|e| e.to_string())? {
let file_name = entry.file_name();
let file_name_str = file_name
.to_str()
.ok_or_else(|| "Invalid filename: non-UTF8".to_string())?;
// plugin.json will be handled separately.
if file_name_str == PLUGIN_JSON_FILE_NAME {
continue;
}
let src_path = entry.path();
let dest_path = dest_dir.join(&file_name);
if src_path.is_dir() {
// Recursively copy directory
copy_dir_recursively(&src_path, &dest_path).await?;
} else {
// Copy file
fs::copy(&src_path, &dest_path)
.await
.map_err(|e| e.to_string())?;
}
}
// Write the corrected plugin.json file
let corrected_plugin_json =
serde_json::to_string_pretty(&extension).map_err(|e| e.to_string())?;
let dest_plugin_json_path = dest_dir.join(PLUGIN_JSON_FILE_NAME);
fs::write(&dest_plugin_json_path, corrected_plugin_json)
.await
.map_err(|e| e.to_string())?;
// Canonicalize relative icon and page paths
canonicalize_relative_icon_path(&dest_dir, &mut extension)?;
canonicalize_relative_page_path(&dest_dir, &mut extension)?;
// Add extension to the search source
third_party_ext_list_write_lock.push(extension);
Ok(())
}
/// Helper function to recursively copy directories.
#[async_recursion::async_recursion]
async fn copy_dir_recursively(src: &Path, dest: &Path) -> Result<(), String> {
tokio::fs::create_dir_all(dest)
.await
.map_err(|e| e.to_string())?;
let mut read_dir = tokio::fs::read_dir(src).await.map_err(|e| e.to_string())?;
while let Some(entry) = read_dir.next_entry().await.map_err(|e| e.to_string())? {
let src_path = entry.path();
let dest_path = dest.join(entry.file_name());
if src_path.is_dir() {
copy_dir_recursively(&src_path, &dest_path).await?;
} else {
tokio::fs::copy(&src_path, &dest_path)
.await
.map_err(|e| e.to_string())?;
}
}
Ok(())
}

View File

@@ -1,353 +0,0 @@
//! This module contains the code of extension installation.
//!
//!
//! # How
//!
//! Technically, installing an extension involves the following steps. The order
//! varies between 2 implementations.
//!
//! 1. Check if it is already installed, if so, return
//!
//! 2. Check if it is compatible by inspecting the "minimum_coco_version"
//! field. If it is incompatible, reject and error out.
//!
//! This should be done before convert `plugin.json` JSON to `struct Extension`
//! as the definition of `struct Extension` could change in the future, in this
//! case, we want to tell users that "it is an incompatible extension" rather
//! than "this extension is invalid".
//!
//! 3. Correct the `plugin.json` JSON if it does not conform to our `struct
//! Extension` definition. This can happen because the JSON written by
//! developers is in a simplified form for a better developer experience.
//!
//! 4. Validate the corrected `plugin.json`
//! 1. misc checks
//! 2. Platform compatibility check
//!
//! 5. Write the extension files to the corresponding location
//!
//! * developer directory
//! * extension directory
//! * assets directory
//! * various assets files, e.g., "icon.png"
//! * plugin.json file
//! * View pages if exist
//!
//! 6. Canonicalize `Extension.icon` and `Extension.page` fields if they are
//! relative paths
//!
//! * icon: relative to the `assets` directory
//! * page: relative to the extension root directory
//!
//! 7. Add the extension to the in-memory extension list.
pub(crate) mod local_extension;
pub(crate) mod store;
use crate::extension::Extension;
use crate::extension::PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION;
use crate::util::platform::Platform;
use crate::util::version::{COCO_VERSION, parse_coco_semver};
use serde_json::Value as Json;
use std::ops::Deref;
use super::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE;
pub(crate) async fn is_extension_installed(developer: &str, extension_id: &str) -> bool {
THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE
.get()
.expect("global third party search source not set")
.extension_exists(developer, extension_id)
.await
}
/// Filters out sub-extensions that are not compatible with the current platform.
///
/// We make `current_platform` an argument so that this function is testable.
pub(crate) fn filter_out_incompatible_sub_extensions(
extension: &mut Extension,
current_platform: Platform,
) {
// Only process extensions of type Group or Extension that can have sub-extensions
if !extension.r#type.contains_sub_items() {
return;
}
// For main extensions, None means all.
let main_extension_supported_platforms = extension.platforms.clone().unwrap_or(Platform::all());
// Filter commands
if let Some(ref mut commands) = extension.commands {
commands.retain(|sub_ext| {
if let Some(ref platforms) = sub_ext.platforms {
platforms.contains(&current_platform)
} else {
main_extension_supported_platforms.contains(&current_platform)
}
});
}
// Filter scripts
if let Some(ref mut scripts) = extension.scripts {
scripts.retain(|sub_ext| {
if let Some(ref platforms) = sub_ext.platforms {
platforms.contains(&current_platform)
} else {
main_extension_supported_platforms.contains(&current_platform)
}
});
}
// Filter quicklinks
if let Some(ref mut quicklinks) = extension.quicklinks {
quicklinks.retain(|sub_ext| {
if let Some(ref platforms) = sub_ext.platforms {
platforms.contains(&current_platform)
} else {
main_extension_supported_platforms.contains(&current_platform)
}
});
}
// Filter views
if let Some(ref mut views) = extension.views {
views.retain(|sub_ext| {
if let Some(ref platforms) = sub_ext.platforms {
platforms.contains(&current_platform)
} else {
main_extension_supported_platforms.contains(&current_platform)
}
});
}
}
/// Inspect the "minimum_coco_version" field and see if this extension is
/// compatible with the current Coco app.
fn check_compatibility_via_mcv(plugin_json: &Json) -> Result<bool, String> {
let Some(mcv_json) = plugin_json.get(PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION) else {
return Ok(true);
};
if mcv_json == &Json::Null {
return Ok(true);
}
let Some(mcv_str) = mcv_json.as_str() else {
return Err(format!(
"invalid extension: field [{}] should be a string",
PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION
));
};
let Some(mcv) = parse_coco_semver(mcv_str) else {
return Err(format!(
"invalid extension: [{}] is not a valid version string",
PLUGIN_JSON_FIELD_MINIMUM_COCO_VERSION
));
};
Ok(COCO_VERSION.deref() >= &mcv)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::extension::ExtensionType;
use std::collections::HashSet;
/// Helper function to create a basic extension for testing
/// `filter_out_incompatible_sub_extensions`
fn create_test_extension(
extension_type: ExtensionType,
platforms: Option<HashSet<Platform>>,
) -> Extension {
Extension {
id: "ID".into(),
name: "name".into(),
developer: None,
platforms,
description: "Test extension".to_string(),
icon: "test-icon".to_string(),
r#type: extension_type,
action: None,
quicklink: None,
commands: None,
scripts: None,
quicklinks: None,
views: None,
alias: None,
hotkey: None,
enabled: true,
settings: None,
page: None,
ui: None,
minimum_coco_version: None,
permission: None,
screenshots: None,
url: None,
version: None,
}
}
#[test]
fn test_filter_out_incompatible_sub_extensions_filter_non_group_extension_unchanged() {
// Command
let mut extension = create_test_extension(ExtensionType::Command, None);
let clone = extension.clone();
filter_out_incompatible_sub_extensions(&mut extension, Platform::Linux);
assert_eq!(extension, clone);
// Quicklink
let mut extension = create_test_extension(ExtensionType::Quicklink, None);
let clone = extension.clone();
filter_out_incompatible_sub_extensions(&mut extension, Platform::Linux);
assert_eq!(extension, clone);
}
#[test]
fn test_filter_out_incompatible_sub_extensions() {
let mut main_extension = create_test_extension(ExtensionType::Group, None);
// init sub extensions, which are macOS-only
let commands = vec![create_test_extension(
ExtensionType::Command,
Some(HashSet::from([Platform::Macos])),
)];
let quicklinks = vec![create_test_extension(
ExtensionType::Quicklink,
Some(HashSet::from([Platform::Macos])),
)];
let scripts = vec![create_test_extension(
ExtensionType::Script,
Some(HashSet::from([Platform::Macos])),
)];
let views = vec![create_test_extension(
ExtensionType::View,
Some(HashSet::from([Platform::Macos])),
)];
// Set sub extensions
main_extension.commands = Some(commands);
main_extension.quicklinks = Some(quicklinks);
main_extension.scripts = Some(scripts);
main_extension.views = Some(views);
// Current platform is Linux, all the sub extensions should be filtered out.
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
// assertions
assert!(main_extension.commands.unwrap().is_empty());
assert!(main_extension.quicklinks.unwrap().is_empty());
assert!(main_extension.scripts.unwrap().is_empty());
assert!(main_extension.views.unwrap().is_empty());
}
/// Sub extensions are compatible with all the platforms, nothing to filter out.
#[test]
fn test_filter_out_incompatible_sub_extensions_all_compatible() {
{
let mut main_extension = create_test_extension(ExtensionType::Group, None);
// init sub extensions, which are compatible with all the platforms
let commands = vec![create_test_extension(
ExtensionType::Command,
Some(Platform::all()),
)];
let quicklinks = vec![create_test_extension(
ExtensionType::Quicklink,
Some(Platform::all()),
)];
let scripts = vec![create_test_extension(
ExtensionType::Script,
Some(Platform::all()),
)];
let views = vec![create_test_extension(
ExtensionType::View,
Some(Platform::all()),
)];
// Set sub extensions
main_extension.commands = Some(commands);
main_extension.quicklinks = Some(quicklinks);
main_extension.scripts = Some(scripts);
main_extension.views = Some(views);
// Current platform is Linux, all the sub extensions should be filtered out.
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
// assertions
assert_eq!(main_extension.commands.unwrap().len(), 1);
assert_eq!(main_extension.quicklinks.unwrap().len(), 1);
assert_eq!(main_extension.scripts.unwrap().len(), 1);
assert_eq!(main_extension.views.unwrap().len(), 1);
}
// main extension is compatible with all platforms, sub extension's platforms
// is None, which means all platforms are supported
{
let mut main_extension = create_test_extension(ExtensionType::Group, None);
// init sub extensions, which are compatible with all the platforms
let commands = vec![create_test_extension(ExtensionType::Command, None)];
let quicklinks = vec![create_test_extension(ExtensionType::Quicklink, None)];
let scripts = vec![create_test_extension(ExtensionType::Script, None)];
let views = vec![create_test_extension(ExtensionType::View, None)];
// Set sub extensions
main_extension.commands = Some(commands);
main_extension.quicklinks = Some(quicklinks);
main_extension.scripts = Some(scripts);
main_extension.views = Some(views);
// Current platform is Linux, all the sub extensions should be filtered out.
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
// assertions
assert_eq!(main_extension.commands.unwrap().len(), 1);
assert_eq!(main_extension.quicklinks.unwrap().len(), 1);
assert_eq!(main_extension.scripts.unwrap().len(), 1);
assert_eq!(main_extension.views.unwrap().len(), 1);
}
}
#[test]
fn test_main_extension_is_incompatible_sub_extension_platforms_none() {
{
let mut main_extension =
create_test_extension(ExtensionType::Group, Some(HashSet::from([Platform::Macos])));
let commands = vec![create_test_extension(ExtensionType::Command, None)];
main_extension.commands = Some(commands);
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
assert_eq!(main_extension.commands.unwrap().len(), 0);
}
{
let mut main_extension =
create_test_extension(ExtensionType::Group, Some(HashSet::from([Platform::Macos])));
let scripts = vec![create_test_extension(ExtensionType::Script, None)];
main_extension.scripts = Some(scripts);
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
assert_eq!(main_extension.scripts.unwrap().len(), 0);
}
{
let mut main_extension =
create_test_extension(ExtensionType::Group, Some(HashSet::from([Platform::Macos])));
let quicklinks = vec![create_test_extension(ExtensionType::Quicklink, None)];
main_extension.quicklinks = Some(quicklinks);
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
assert_eq!(main_extension.quicklinks.unwrap().len(), 0);
}
{
let mut main_extension =
create_test_extension(ExtensionType::Group, Some(HashSet::from([Platform::Macos])));
let views = vec![create_test_extension(ExtensionType::View, None)];
main_extension.views = Some(views);
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Linux);
assert_eq!(main_extension.views.unwrap().len(), 0);
}
}
#[test]
fn test_main_extension_compatible_sub_extension_platforms_none() {
let mut main_extension =
create_test_extension(ExtensionType::Group, Some(HashSet::from([Platform::Macos])));
let views = vec![create_test_extension(ExtensionType::View, None)];
main_extension.views = Some(views);
filter_out_incompatible_sub_extensions(&mut main_extension, Platform::Macos);
assert_eq!(main_extension.views.unwrap().len(), 1);
}
}

View File

@@ -1,418 +0,0 @@
//! Extension store related stuff.
use super::super::LOCAL_QUERY_SOURCE_TYPE;
use super::check_compatibility_via_mcv;
use super::is_extension_installed;
use crate::common::document::DataSourceReference;
use crate::common::document::Document;
use crate::common::error::SearchError;
use crate::common::search::QueryResponse;
use crate::common::search::QuerySource;
use crate::common::search::SearchQuery;
use crate::common::traits::SearchSource;
use crate::extension::Extension;
use crate::extension::PLUGIN_JSON_FILE_NAME;
use crate::extension::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE;
use crate::extension::canonicalize_relative_icon_path;
use crate::extension::canonicalize_relative_page_path;
use crate::extension::third_party::check::general_check;
use crate::extension::third_party::get_third_party_extension_directory;
use crate::extension::third_party::install::filter_out_incompatible_sub_extensions;
use crate::server::http_client::HttpClient;
use crate::util::platform::Platform;
use async_trait::async_trait;
use reqwest::StatusCode;
use serde_json::Map as JsonObject;
use serde_json::Value as Json;
use std::io::Read;
use tauri::AppHandle;
const DATA_SOURCE_ID: &str = "Extension Store";
pub(crate) struct ExtensionStore;
#[async_trait]
impl SearchSource for ExtensionStore {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(DATA_SOURCE_ID.into())
.to_string_lossy()
.into(),
id: DATA_SOURCE_ID.into(),
}
}
async fn search(
&self,
_tauri_app_handle: AppHandle,
query: SearchQuery,
) -> Result<QueryResponse, SearchError> {
const SCORE: f64 = 2000.0;
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
let lowercase_query_string = query_string.to_lowercase();
let expected_str = "extension store";
if expected_str.contains(&lowercase_query_string) {
let doc = Document {
id: DATA_SOURCE_ID.to_string(),
category: Some(DATA_SOURCE_ID.to_string()),
title: Some(DATA_SOURCE_ID.to_string()),
icon: Some("font_Store".to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(DATA_SOURCE_ID.into()),
id: Some(DATA_SOURCE_ID.into()),
icon: Some("font_Store".to_string()),
}),
..Default::default()
};
Ok(QueryResponse {
source: self.get_type(),
hits: vec![(doc, SCORE)],
total_hits: 1,
})
} else {
Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
})
}
}
}
#[tauri::command]
pub(crate) async fn search_extension(
query_params: Option<Vec<String>>,
) -> Result<Vec<Json>, String> {
let response = HttpClient::get(
"default_coco_server",
"store/extension/_search",
query_params,
)
.await
.map_err(|e| format!("Failed to send request: {:?}", e))?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(Vec::new());
}
// The response of a ES style search request
let mut response: JsonObject<String, Json> = response
.json()
.await
.map_err(|e| format!("Failed to parse response: {:?}", e))?;
let hits_json = response.remove("hits").unwrap_or_else(|| {
panic!(
"the JSON response should contain field [hits], response [{:?}]",
response
)
});
let mut hits = match hits_json {
Json::Object(obj) => obj,
_ => panic!(
"field [hits] should be a JSON object, but it is not, value: [{}]",
hits_json
),
};
let Some(hits_hits_json) = hits.remove("hits") else {
return Ok(Vec::new());
};
let hits_hits = match hits_hits_json {
Json::Array(arr) => arr,
_ => panic!(
"field [hits.hits] should be an array, but it is not, value: [{}]",
hits_hits_json
),
};
let mut extensions = Vec::with_capacity(hits_hits.len());
for hit in hits_hits {
let mut hit_obj = match hit {
Json::Object(obj) => obj,
_ => panic!(
"each hit in [hits.hits] should be a JSON object, but it is not, value: [{}]",
hit
),
};
let source = hit_obj
.remove("_source")
.expect("each hit should contain field [_source]");
let mut source_obj = match source {
Json::Object(obj) => obj,
_ => panic!(
"field [_source] should be a JSON object, but it is not, value: [{}]",
source
),
};
let developer_id = source_obj
.get("developer")
.and_then(|dev| dev.get("id"))
.and_then(|id| id.as_str())
.expect("developer.id should exist");
let extension_id = source_obj
.get("id")
.and_then(|id| id.as_str())
.expect("extension id should exist");
let installed = is_extension_installed(developer_id, extension_id).await;
source_obj.insert("installed".to_string(), Json::Bool(installed));
extensions.push(Json::Object(source_obj));
}
Ok(extensions)
}
#[tauri::command]
pub(crate) async fn extension_detail(
id: String,
) -> Result<Option<JsonObject<String, Json>>, String> {
let path = format!("store/extension/{}", id);
let response = HttpClient::get("default_coco_server", path.as_str(), None)
.await
.map_err(|e| format!("Failed to send request: {:?}", e))?;
if response.status() == StatusCode::NOT_FOUND {
return Ok(None);
}
let response_dbg_str = format!("{:?}", response);
// The response of an ES style GET request
let mut response: JsonObject<String, Json> = response.json().await.unwrap_or_else(|_e| {
panic!(
"response body of [/store/extension/<ID>] is not a JSON object, response [{:?}]",
response_dbg_str
)
});
let source_json = response.remove("_source").unwrap_or_else(|| {
panic!("field [_source] not found in the JSON returned from [/store/extension/<ID>]")
});
let mut source_obj = match source_json {
Json::Object(obj) => obj,
_ => panic!(
"field [_source] should be a JSON object, but it is not, value: [{}]",
source_json
),
};
let developer_id = match &source_obj["developer"]["id"] {
Json::String(dev) => dev,
_ => {
panic!(
"field [_source.developer.id] should be a string, but it is not, value: [{}]",
source_obj["developer"]["id"]
)
}
};
let installed = is_extension_installed(developer_id, &id).await;
source_obj.insert("installed".to_string(), Json::Bool(installed));
Ok(Some(source_obj))
}
#[tauri::command]
pub(crate) async fn install_extension_from_store(
tauri_app_handle: AppHandle,
id: String,
) -> Result<(), String> {
let path = format!("store/extension/{}/_download", id);
let response = HttpClient::get("default_coco_server", &path, None)
.await
.map_err(|e| format!("Failed to download extension: {}", e))?;
if response.status() == StatusCode::NOT_FOUND {
return Err(format!("extension [{}] not found", id));
}
let bytes = response
.bytes()
.await
.map_err(|e| format!("Failed to read response bytes: {}", e))?;
let cursor = std::io::Cursor::new(bytes);
let mut archive =
zip::ZipArchive::new(cursor).map_err(|e| format!("Failed to read zip archive: {}", e))?;
// The plugin.json sent from the server does not conform to our `struct Extension` definition:
//
// 1. Its `developer` field is a JSON object, but we need a string
// 2. sub-extensions won't have their `id` fields set
//
// we need to correct it
let mut plugin_json = archive
.by_name(PLUGIN_JSON_FILE_NAME)
.map_err(|e| e.to_string())?;
let mut plugin_json_content = String::new();
std::io::Read::read_to_string(&mut plugin_json, &mut plugin_json_content)
.map_err(|e| e.to_string())?;
let mut extension: Json = serde_json::from_str(&plugin_json_content)
.map_err(|e| format!("Failed to parse plugin.json: {}", e))?;
if !check_compatibility_via_mcv(&extension)? {
return Err("app_incompatible".into());
}
let mut_ref_to_developer_object: &mut Json = extension
.as_object_mut()
.expect("plugin.json should be an object")
.get_mut("developer")
.expect("plugin.json should contain field [developer]");
let developer_id = mut_ref_to_developer_object
.get("id")
.expect("plugin.json should contain [developer.id]")
.as_str()
.expect("plugin.json field [developer.id] should be a string");
*mut_ref_to_developer_object = Json::String(developer_id.into());
// Set IDs for sub-extensions (commands, quicklinks, scripts)
let mut counter = 0;
// Helper function to set IDs for array fields
fn set_ids_for_field(extension: &mut Json, field_name: &str, counter: &mut i32) {
if let Some(field) = extension.as_object_mut().unwrap().get_mut(field_name) {
if let Some(array) = field.as_array_mut() {
for item in array {
if let Some(item_obj) = item.as_object_mut() {
if !item_obj.contains_key("id") {
item_obj.insert("id".to_string(), Json::String(counter.to_string()));
*counter += 1;
}
}
}
}
}
}
set_ids_for_field(&mut extension, "commands", &mut counter);
set_ids_for_field(&mut extension, "quicklinks", &mut counter);
set_ids_for_field(&mut extension, "scripts", &mut counter);
// Now the extension JSON is valid
let mut extension: Extension = serde_json::from_value(extension).unwrap_or_else(|e| {
panic!(
"cannot parse plugin.json as struct Extension, error [{:?}]",
e
);
});
let developer_id = extension.developer.clone().expect("developer has been set");
drop(plugin_json);
general_check(&extension)?;
let current_platform = Platform::current();
if let Some(ref platforms) = extension.platforms {
if !platforms.contains(&current_platform) {
return Err("platform_incompatible".into());
}
}
if is_extension_installed(&developer_id, &id).await {
return Err("Extension already installed.".into());
}
// Extension is compatible with current platform, but it could contain sub
// extensions that are not, filter them out.
filter_out_incompatible_sub_extensions(&mut extension, current_platform);
// We are going to modify our third-party extension list, grab the write lock
// to ensure exclusive access.
let mut third_party_ext_list_write_lock = THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE
.get()
.expect("global third party search source not set")
.write_lock()
.await;
// Write extension files to the extension directory
let extension_id = extension.id.clone();
let extension_directory = {
let mut path = get_third_party_extension_directory(&tauri_app_handle);
path.push(developer_id);
path.push(extension_id.as_str());
path
};
tokio::fs::create_dir_all(extension_directory.as_path())
.await
.map_err(|e| e.to_string())?;
// Extract all files except plugin.json
for i in 0..archive.len() {
let mut zip_file = archive.by_index(i).map_err(|e| e.to_string())?;
// `.name()` is safe to use in our cases, the cases listed in the below
// page won't happen to us.
//
// https://docs.rs/zip/4.2.0/zip/read/struct.ZipFile.html#method.name
//
// Example names:
//
// * `assets/icon.png`
// * `assets/screenshot.png`
// * `plugin.json`
//
// Yes, the `assets` directory is not a part of it.
let zip_file_name = zip_file.name();
// Skip the plugin.json file as we'll create it from the extension variable
if zip_file_name == PLUGIN_JSON_FILE_NAME {
continue;
}
let dest_file_path = extension_directory.join(zip_file_name);
// For cases like `assets/xxx.png`
if let Some(parent_dir) = dest_file_path.parent()
&& !parent_dir.exists()
{
tokio::fs::create_dir_all(parent_dir)
.await
.map_err(|e| e.to_string())?;
}
let mut dest_file = tokio::fs::File::create(&dest_file_path)
.await
.map_err(|e| e.to_string())?;
let mut src_bytes = Vec::with_capacity(
zip_file
.size()
.try_into()
.expect("we won't have a extension file that is bigger than 4GiB"),
);
zip_file
.read_to_end(&mut src_bytes)
.map_err(|e| e.to_string())?;
tokio::io::copy(&mut src_bytes.as_slice(), &mut dest_file)
.await
.map_err(|e| e.to_string())?;
}
// Create plugin.json from the extension variable
let plugin_json_path = extension_directory.join(PLUGIN_JSON_FILE_NAME);
let extension_json = serde_json::to_string_pretty(&extension).map_err(|e| e.to_string())?;
tokio::fs::write(&plugin_json_path, extension_json)
.await
.map_err(|e| e.to_string())?;
// Canonicalize relative icon and page paths
canonicalize_relative_icon_path(&extension_directory, &mut extension)?;
canonicalize_relative_page_path(&extension_directory, &mut extension)?;
third_party_ext_list_write_lock.push(extension);
Ok(())
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,38 +0,0 @@
//! View extension-related stuff
use actix_files::Files;
use actix_web::{App, HttpServer, dev::ServerHandle};
use std::path::Path;
use tokio::sync::Mutex;
static FILE_SERVER_HANDLE: Mutex<Option<ServerHandle>> = Mutex::const_new(None);
/// Start a static HTTP file server serving the directory specified by `path`.
/// Return the URL of the server.
pub(crate) async fn serve_files_in(path: &Path) -> String {
const ADDR: &str = "127.0.0.1";
let mut guard = FILE_SERVER_HANDLE.lock().await;
if let Some(prev_server_handle) = guard.take() {
prev_server_handle.stop(true).await;
}
let path = path.to_path_buf();
let http_server =
HttpServer::new(move || App::new().service(Files::new("/", &path).show_files_listing()))
// Set port to 0 and let OS assign a port to us
.bind((ADDR, 0))
.unwrap();
let assigned_port = http_server.addrs()[0].port();
let server = http_server.disable_signals().workers(1).run();
let new_handle = server.handle();
tokio::spawn(server);
*guard = Some(new_handle);
format!("http://{}:{}", ADDR, assigned_port)
}

View File

@@ -1,44 +1,40 @@
mod assistant;
mod autostart;
mod common;
mod extension;
mod local;
mod search;
mod server;
mod settings;
mod setup;
mod shortcut;
// We need this in main.rs, so it has to be pub
pub mod util;
mod util;
use crate::common::register::SearchSourceRegistry;
use crate::common::{CHECK_WINDOW_LABEL, MAIN_WINDOW_LABEL, SETTINGS_WINDOW_LABEL};
// use crate::common::traits::SearchSource;
use crate::common::{MAIN_WINDOW_LABEL, SETTINGS_WINDOW_LABEL};
use crate::server::servers::{load_or_insert_default_server, load_servers_token};
use crate::util::logging::set_up_tauri_logger;
use crate::util::prevent_default;
use autostart::change_autostart;
use autostart::{change_autostart, enable_autostart};
use lazy_static::lazy_static;
use std::sync::Mutex;
use std::sync::OnceLock;
use tauri::async_runtime::block_on;
use tauri::plugin::TauriPlugin;
#[cfg(target_os = "macos")]
use tauri::ActivationPolicy;
use tauri::{
AppHandle, Emitter, LogicalPosition, Manager, PhysicalPosition, WebviewWindow, WindowEvent,
AppHandle, Emitter, Manager, PhysicalPosition, Runtime, WebviewWindow, Window, WindowEvent,
};
use tauri_plugin_autostart::MacosLauncher;
/// Tauri store name
pub(crate) const COCO_TAURI_STORE: &str = "coco_tauri_store";
pub(crate) const WINDOW_CENTER_BASELINE_HEIGHT: i32 = 590;
lazy_static! {
static ref PREVIOUS_MONITOR_NAME: Mutex<Option<String>> = Mutex::new(None);
}
/// To allow us to access tauri's `AppHandle` when its context is inaccessible,
/// store it globally. It will be set in `init()`.
///
/// # WARNING
///
/// You may find this work, but the usage is discouraged and should be generally
/// avoided. If you do need it, always be careful that it may not be set() when
/// you access it.
pub(crate) static GLOBAL_TAURI_APP_HANDLE: OnceLock<AppHandle> = OnceLock::new();
#[tauri::command]
@@ -48,26 +44,6 @@ async fn change_window_height(handle: AppHandle, height: u32) {
let mut size = window.outer_size().unwrap();
size.height = height;
window.set_size(size).unwrap();
// Center the window horizontally and vertically based on the baseline height of 590
let monitor = window.primary_monitor().ok().flatten().or_else(|| {
window
.available_monitors()
.ok()
.and_then(|ms| ms.into_iter().next())
});
if let Some(monitor) = monitor {
let monitor_position = monitor.position();
let monitor_size = monitor.size();
let window_width = window.outer_size().unwrap().width as i32;
let x = monitor_position.x + (monitor_size.width as i32 - window_width) / 2;
let y =
monitor_position.y + (monitor_size.height as i32 - WINDOW_CENTER_BASELINE_HEIGHT) / 2;
let _ = window.set_position(PhysicalPosition::new(x, y));
}
}
#[derive(serde::Deserialize)]
@@ -88,20 +64,20 @@ pub fn run() {
let ctx = tauri::generate_context!();
let mut app_builder = tauri::Builder::default();
// Set up logger first
app_builder = app_builder.plugin(set_up_tauri_logger());
#[cfg(desktop)]
{
app_builder =
app_builder.plugin(tauri_plugin_single_instance::init(|_app, _argv, _cwd| {}));
app_builder = app_builder.plugin(tauri_plugin_single_instance::init(|_app, argv, _cwd| {
println!("a new app instance was opened with {argv:?} and the deep link event was already triggered");
// when defining deep link schemes at runtime, you must also check `argv` here
}));
}
app_builder = app_builder
.plugin(tauri_plugin_http::init())
.plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_autostart::init(
MacosLauncher::LaunchAgent,
MacosLauncher::AppleScript,
None,
))
.plugin(tauri_plugin_deep_link::init())
@@ -111,14 +87,9 @@ pub fn run() {
.plugin(tauri_plugin_macos_permissions::init())
.plugin(tauri_plugin_screenshots::init())
.plugin(tauri_plugin_process::init())
.plugin(
tauri_plugin_updater::Builder::new()
.default_version_comparator(crate::util::version::custom_version_comparator)
.build(),
)
.plugin(tauri_plugin_updater::Builder::new().build())
.plugin(tauri_plugin_windows_version::init())
.plugin(tauri_plugin_opener::init())
.plugin(prevent_default::init());
.plugin(set_up_tauri_logger());
// Conditional compilation for macOS
#[cfg(target_os = "macos")]
@@ -136,8 +107,7 @@ pub fn run() {
show_coco,
hide_coco,
show_settings,
show_check,
hide_check,
server::servers::get_server_token,
server::servers::add_coco_server,
server::servers::remove_coco_server,
server::servers::list_coco_servers,
@@ -152,8 +122,8 @@ pub fn run() {
server::connector::get_connectors_by_server,
search::query_coco_fusion,
assistant::chat_history,
assistant::chat_create,
assistant::chat_chat,
assistant::new_chat,
assistant::send_message,
assistant::session_chat_history,
assistant::open_session_chat,
assistant::close_session_chat,
@@ -161,82 +131,86 @@ pub fn run() {
assistant::delete_session_chat,
assistant::update_session_chat,
assistant::assistant_search,
assistant::assistant_get,
assistant::assistant_get_multi,
// server::get_coco_server_datasources,
// server::get_coco_server_connectors,
server::websocket::connect_to_server,
server::websocket::disconnect,
get_app_search_source,
server::attachment::upload_attachment,
server::attachment::get_attachment_by_ids,
server::attachment::get_attachment,
server::attachment::delete_attachment,
server::transcription::transcription,
util::open,
server::system_settings::get_system_settings,
extension::built_in::application::get_app_list,
extension::built_in::application::get_app_search_path,
extension::built_in::application::get_app_metadata,
extension::built_in::application::add_app_search_path,
extension::built_in::application::remove_app_search_path,
extension::built_in::application::reindex_applications,
extension::quicklink_link_arguments,
extension::list_extensions,
extension::enable_extension,
extension::disable_extension,
extension::set_extension_alias,
extension::register_extension_hotkey,
extension::unregister_extension_hotkey,
extension::is_extension_enabled,
extension::third_party::install::store::search_extension,
extension::third_party::install::store::extension_detail,
extension::third_party::install::store::install_extension_from_store,
extension::third_party::install::local_extension::install_local_extension,
extension::third_party::uninstall_extension,
extension::is_extension_compatible,
extension::api::apis,
extension::api::fs::read_dir,
simulate_mouse_click,
local::get_disabled_local_query_sources,
local::enable_local_query_source,
local::disable_local_query_source,
local::application::get_app_list,
local::application::get_app_search_path,
local::application::get_app_metadata,
local::application::set_app_alias,
local::application::register_app_hotkey,
local::application::unregister_app_hotkey,
local::application::disable_app_search,
local::application::enable_app_search,
local::application::add_app_search_path,
local::application::remove_app_search_path,
settings::set_allow_self_signature,
settings::get_allow_self_signature,
settings::set_local_query_source_weight,
settings::get_local_query_source_weight,
assistant::ask_ai,
crate::common::document::open,
extension::built_in::file_search::config::get_file_system_config,
extension::built_in::file_search::config::set_file_system_config,
server::synthesize::synthesize,
util::file::get_file_icon,
setup::backend_setup,
util::app_lang::update_app_lang,
util::path::path_absolute,
util::logging::app_log_dir
])
.setup(|app| {
let app_handle = app.handle().clone();
GLOBAL_TAURI_APP_HANDLE
.set(app_handle.clone())
.expect("variable already initialized");
let registry = SearchSourceRegistry::default();
app.manage(registry); // Store registry in Tauri's app state
app.manage(server::websocket::WebSocketManager::default());
block_on(async {
init(app.handle()).await;
});
shortcut::enable_shortcut(app);
enable_autostart(app);
#[cfg(target_os = "macos")]
app.set_activation_policy(ActivationPolicy::Accessory);
// app.listen("theme-changed", move |event| {
// if let Ok(payload) = serde_json::from_str::<ThemeChangedPayload>(event.payload()) {
// // switch_tray_icon(app.app_handle(), payload.is_dark_mode);
// println!("Theme changed: is_dark_mode = {}", payload.is_dark_mode);
// }
// });
#[cfg(desktop)]
{
log::trace!("hiding Dock icon on macOS");
app.set_activation_policy(tauri::ActivationPolicy::Accessory);
log::trace!("Dock icon should be hidden now");
#[cfg(any(windows, target_os = "linux"))]
{
app.deep_link().register("coco")?;
use tauri_plugin_deep_link::DeepLinkExt;
app.deep_link().register_all()?;
}
}
// app.deep_link().on_open_url(|event| {
// dbg!(event.urls());
// });
/* ----------- This code must be executed on the main thread and must not be relocated. ----------- */
let app_handle = app.app_handle();
let main_window = app_handle.get_webview_window(MAIN_WINDOW_LABEL).unwrap();
let settings_window = app_handle
.get_webview_window(SETTINGS_WINDOW_LABEL)
.unwrap();
let check_window = app_handle.get_webview_window(CHECK_WINDOW_LABEL).unwrap();
setup::default(
app_handle,
main_window.clone(),
settings_window.clone(),
check_window.clone(),
);
/* ----------- This code must be executed on the main thread and must not be relocated. ----------- */
let main_window = app.get_webview_window(MAIN_WINDOW_LABEL).unwrap();
let settings_window = app.get_webview_window(SETTINGS_WINDOW_LABEL).unwrap();
setup::default(app, main_window.clone(), settings_window.clone());
Ok(())
})
.on_window_event(|window, event| match event {
WindowEvent::CloseRequested { api, .. } => {
//dbg!("Close requested event received");
dbg!("Close requested event received");
window.hide().unwrap();
api.prevent_close();
}
@@ -251,10 +225,10 @@ pub fn run() {
has_visible_windows,
..
} => {
// dbg!(
// "Reopen event received: has_visible_windows = {}",
// has_visible_windows
// );
dbg!(
"Reopen event received: has_visible_windows = {}",
has_visible_windows
);
if has_visible_windows {
return;
}
@@ -265,17 +239,17 @@ pub fn run() {
});
}
pub async fn init(app_handle: &AppHandle) {
pub async fn init<R: Runtime>(app_handle: &AppHandle<R>) {
// Await the async functions to load the servers and tokens
if let Err(err) = load_or_insert_default_server(app_handle).await {
log::error!("Failed to load servers: {}", err);
eprintln!("Failed to load servers: {}", err);
}
if let Err(err) = load_servers_token(app_handle).await {
log::error!("Failed to load server tokens: {}", err);
eprintln!("Failed to load server tokens: {}", err);
}
let coco_servers = server::servers::get_all_servers().await;
let coco_servers = server::servers::get_all_servers();
// Get the registry from Tauri's state
// let registry: State<SearchSourceRegistry> = app_handle.state::<SearchSourceRegistry>();
@@ -285,125 +259,163 @@ pub async fn init(app_handle: &AppHandle) {
.await;
}
extension::built_in::pizza_engine_runtime::start_pizza_engine_runtime().await;
local::start_pizza_engine_runtime();
}
#[tauri::command]
async fn show_coco(app_handle: AppHandle) {
if let Some(window) = app_handle.get_webview_window(MAIN_WINDOW_LABEL) {
async fn show_coco<R: Runtime>(app_handle: AppHandle<R>) {
if let Some(window) = app_handle.get_window(MAIN_WINDOW_LABEL) {
move_window_to_active_monitor(&window);
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
use tauri_nspanel::ManagerExt;
let app_handle_clone = app_handle.clone();
app_handle.run_on_main_thread(move || {
let panel = app_handle_clone.get_webview_panel(MAIN_WINDOW_LABEL).unwrap();
panel.show_and_make_key();
}).unwrap();
} else {
let _ = window.show();
let _ = window.unminimize();
// The Window Management (WM) extension (macOS-only) controls the
// frontmost window. Setting focus on macOS makes Coco the frontmost
// window, which means the WM extension would control Coco instead of other
// windows, which is not what we want.
//
// On Linux/Windows, however, setting focus is a necessity to ensure that
// users open Coco's window, then they can start typing, without needing
// to click on the window.
let _ = window.set_focus();
}
};
let _ = window.show();
let _ = window.unminimize();
let _ = window.set_focus();
let _ = app_handle.emit("show-coco", ());
}
}
#[tauri::command]
async fn hide_coco(app_handle: AppHandle) {
cfg_if::cfg_if! {
if #[cfg(target_os = "macos")] {
use tauri_nspanel::ManagerExt;
let app_handle_clone = app_handle.clone();
app_handle.run_on_main_thread(move || {
let panel = app_handle_clone.get_webview_panel(MAIN_WINDOW_LABEL).expect("cannot find the main window/panel");
panel.hide();
}).unwrap();
async fn hide_coco<R: Runtime>(app: AppHandle<R>) {
if let Some(window) = app.get_window(MAIN_WINDOW_LABEL) {
if let Err(err) = window.hide() {
eprintln!("Failed to hide the window: {}", err);
} else {
let window = app_handle.get_webview_window(MAIN_WINDOW_LABEL).expect("cannot find the main window");
if let Err(err) = window.hide() {
log::error!("Failed to hide the window: {}", err);
} else {
log::debug!("Window successfully hidden.");
}
println!("Window successfully hidden.");
}
};
} else {
eprintln!("Main window not found.");
}
}
fn move_window_to_active_monitor(window: &WebviewWindow) {
let scale_factor = window.scale_factor().unwrap();
fn move_window_to_active_monitor<R: Runtime>(window: &Window<R>) {
dbg!("Moving window to active monitor");
// Try to get the available monitors, handle failure gracefully
let available_monitors = match window.available_monitors() {
Ok(monitors) => monitors,
Err(e) => {
eprintln!("Failed to get monitors: {}", e);
return;
}
};
let point = window.cursor_position().unwrap();
// Attempt to get the cursor position, handle failure gracefully
let cursor_position = match window.cursor_position() {
Ok(pos) => Some(pos),
Err(e) => {
eprintln!("Failed to get cursor position: {}", e);
None
}
};
let LogicalPosition { x, y } = point.to_logical(scale_factor);
match window.monitor_from_point(x, y) {
Ok(Some(monitor)) => {
if let Some(name) = monitor.name() {
let previous_monitor_name = PREVIOUS_MONITOR_NAME.lock().unwrap();
if let Some(ref prev_name) = *previous_monitor_name {
if name.to_string() == *prev_name {
log::debug!("Currently on the same monitor");
return;
}
}
}
// Find the monitor that contains the cursor or default to the primary monitor
let target_monitor = if let Some(cursor_position) = cursor_position {
// Convert cursor position to integers
let cursor_x = cursor_position.x.round() as i32;
let cursor_y = cursor_position.y.round() as i32;
// Find the monitor that contains the cursor
available_monitors.into_iter().find(|monitor| {
let monitor_position = monitor.position();
let monitor_size = monitor.size();
// Current window size for horizontal centering
let window_size = match window.inner_size() {
Ok(size) => size,
Err(e) => {
log::error!("Failed to get window size: {}", e);
return;
}
};
let window_width = window_size.width as i32;
cursor_x >= monitor_position.x
&& cursor_x <= monitor_position.x + monitor_size.width as i32
&& cursor_y >= monitor_position.y
&& cursor_y <= monitor_position.y + monitor_size.height as i32
})
} else {
None
};
// Horizontal center uses actual width, vertical center uses 590 baseline
let window_x = monitor_position.x + (monitor_size.width as i32 - window_width) / 2;
let window_y = monitor_position.y
+ (monitor_size.height as i32 - WINDOW_CENTER_BASELINE_HEIGHT) / 2;
// Use the target monitor or default to the primary monitor
let monitor = match target_monitor.or_else(|| window.primary_monitor().ok().flatten()) {
Some(monitor) => monitor,
None => {
eprintln!("No monitor found!");
return;
}
};
if let Err(e) = window.set_position(PhysicalPosition::new(window_x, window_y)) {
log::error!("Failed to move window: {}", e);
}
if let Some(name) = monitor.name() {
let previous_monitor_name = PREVIOUS_MONITOR_NAME.lock().unwrap();
if let Some(name) = monitor.name() {
log::debug!("Window moved to monitor: {}", name);
let mut previous_monitor = PREVIOUS_MONITOR_NAME.lock().unwrap();
*previous_monitor = Some(name.to_string());
if let Some(ref prev_name) = *previous_monitor_name {
if name.to_string() == *prev_name {
println!("Currently on the same monitor");
return;
}
}
Ok(None) => {
log::error!("No monitor found at the specified point");
}
}
let monitor_position = monitor.position();
let monitor_size = monitor.size();
// Get the current size of the window
let window_size = match window.inner_size() {
Ok(size) => size,
Err(e) => {
log::error!("Failed to get monitor from point: {}", e);
eprintln!("Failed to get window size: {}", e);
return;
}
};
let window_width = window_size.width as i32;
let window_height = window_size.height as i32;
// Calculate the new position to center the window on the monitor
let window_x = monitor_position.x + (monitor_size.width as i32 - window_width) / 2;
let window_y = monitor_position.y + (monitor_size.height as i32 - window_height) / 2;
// Move the window to the new position
if let Err(e) = window.set_position(PhysicalPosition::new(window_x, window_y)) {
eprintln!("Failed to move window: {}", e);
}
if let Some(name) = monitor.name() {
println!("Window moved to monitor: {}", name);
let mut previous_monitor = PREVIOUS_MONITOR_NAME.lock().unwrap();
*previous_monitor = Some(name.to_string());
}
}
#[allow(dead_code)]
fn open_settings(app: &tauri::AppHandle) {
use tauri::webview::WebviewBuilder;
println!("settings menu item was clicked");
let window = app.get_webview_window("settings");
if let Some(window) = window {
let _ = window.show();
let _ = window.unminimize();
let _ = window.set_focus();
} else {
let window = tauri::window::WindowBuilder::new(app, "settings")
.title("Settings Window")
.fullscreen(false)
.resizable(false)
.minimizable(false)
.maximizable(false)
.inner_size(800.0, 600.0)
.build()
.unwrap();
let webview_builder =
WebviewBuilder::new("settings", tauri::WebviewUrl::App("/ui/settings".into()));
let _webview = window
.add_child(
webview_builder,
tauri::LogicalPosition::new(0, 0),
window.inner_size().unwrap(),
)
.unwrap();
}
}
#[tauri::command]
async fn get_app_search_source(app_handle: AppHandle) -> Result<(), String> {
async fn get_app_search_source<R: Runtime>(app_handle: AppHandle<R>) -> Result<(), String> {
local::init_local_search_source(&app_handle).await?;
let _ = server::connector::refresh_all_connectors(&app_handle).await;
let _ = server::datasource::refresh_all_datasources(&app_handle).await;
@@ -412,34 +424,100 @@ async fn get_app_search_source(app_handle: AppHandle) -> Result<(), String> {
#[tauri::command]
async fn show_settings(app_handle: AppHandle) {
log::debug!("settings menu item was clicked");
let window = app_handle
.get_webview_window(SETTINGS_WINDOW_LABEL)
.expect("we have a settings window");
window.show().unwrap();
window.unminimize().unwrap();
window.set_focus().unwrap();
open_settings(&app_handle);
}
#[tauri::command]
async fn show_check(app_handle: AppHandle) {
log::debug!("check menu item was clicked");
let window = app_handle
.get_webview_window(CHECK_WINDOW_LABEL)
.expect("we have a check window");
async fn simulate_mouse_click<R: Runtime>(window: WebviewWindow<R>, is_chat_mode: bool) {
#[cfg(target_os = "windows")]
{
use enigo::{Button, Coordinate, Direction, Enigo, Mouse, Settings};
use std::{thread, time::Duration};
window.show().unwrap();
window.unminimize().unwrap();
window.set_focus().unwrap();
if let Ok(mut enigo) = Enigo::new(&Settings::default()) {
// Save the current mouse position
if let Ok((original_x, original_y)) = enigo.location() {
// Retrieve the window's outer position (top-left corner)
if let Ok(position) = window.outer_position() {
// Retrieve the window's inner size (client area)
if let Ok(size) = window.inner_size() {
// Calculate the center position of the title bar
let x = position.x + (size.width as i32 / 2);
let y = if is_chat_mode {
position.y + size.height as i32 - 50
} else {
position.y + 30
};
// Move the mouse cursor to the calculated position
if enigo.move_mouse(x, y, Coordinate::Abs).is_ok() {
// // Simulate a left mouse click
let _ = enigo.button(Button::Left, Direction::Click);
// let _ = enigo.button(Button::Left, Direction::Release);
thread::sleep(Duration::from_millis(100));
// Move the mouse cursor back to the original position
let _ = enigo.move_mouse(original_x, original_y, Coordinate::Abs);
}
}
}
}
}
}
#[cfg(not(target_os = "windows"))]
{
let _ = window;
let _ = is_chat_mode;
}
}
#[tauri::command]
async fn hide_check(app_handle: AppHandle) {
log::debug!("check window was closed");
let window = &app_handle
.get_webview_window(CHECK_WINDOW_LABEL)
.expect("we have a check window");
/// Log format:
///
/// ```text
/// [time] [log level] [file module:line] message
/// ```
///
/// Example:
///
///
/// ```text
/// [05-11 17:00:00] [INF] [coco_lib:625] Coco-AI started
/// ```
fn set_up_tauri_logger() -> TauriPlugin<tauri::Wry> {
use log::Level;
window.hide().unwrap();
fn format_log_level(level: Level) -> &'static str {
match level {
Level::Trace => "TRC",
Level::Debug => "DBG",
Level::Info => "INF",
Level::Warn => "WAR",
Level::Error => "ERR",
}
}
fn format_target_and_line(record: &log::Record) -> String {
let mut str = record.target().to_string();
if let Some(line) = record.line() {
str.push(':');
str.push_str(&line.to_string());
}
str
}
tauri_plugin_log::Builder::new()
.format(|out, message, record| {
let now = chrono::Local::now().format("%m-%d %H:%M:%S");
let level = format_log_level(record.level());
let target_and_line = format_target_and_line(record);
out.finish(format_args!(
"[{}] [{}] [{}] {}",
now, level, target_and_line, message
));
})
.level(log::LevelFilter::Debug)
.build()
}

View File

@@ -12,10 +12,9 @@ pub use with_feature::*;
#[cfg(not(feature = "use_pizza_engine"))]
pub use without_feature::*;
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct AppEntry {
path: String,
name: String,
@@ -25,26 +24,15 @@ pub struct AppEntry {
is_disabled: bool,
}
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AppMetadata {
name: String,
r#where: String,
size: u64,
icon: String,
created: u128,
modified: u128,
last_opened: u128,
}
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "Applications",
"platforms": ["macos", "linux", "windows"],
"name": "Applications",
"description": "Application search",
"icon": "font_Application",
"type": "group",
"enabled": true
}
"#;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,18 @@
use super::super::Extension;
use super::AppMetadata;
use crate::common::error::SearchError;
use crate::common::search::{QueryResponse, QuerySource, SearchQuery};
use crate::common::traits::SearchSource;
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::local::LOCAL_QUERY_SOURCE_TYPE;
use async_trait::async_trait;
use tauri::AppHandle;
use tauri::{AppHandle, Runtime};
use super::AppEntry;
use super::AppMetadata;
pub(crate) const QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME: &str = "Applications";
pub struct ApplicationSearchSource;
impl ApplicationSearchSource {
pub async fn prepare_index_and_store(_app_handle: AppHandle) -> Result<(), String> {
pub async fn init<R: Runtime>(_app_handle: AppHandle<R>) -> Result<(), String> {
Ok(())
}
}
@@ -30,11 +30,7 @@ impl SearchSource for ApplicationSearchSource {
}
}
async fn search(
&self,
_tauri_app_handle: AppHandle,
_query: SearchQuery,
) -> Result<QueryResponse, SearchError> {
async fn search(&self, _query: SearchQuery) -> Result<QueryResponse, SearchError> {
Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
@@ -43,39 +39,49 @@ impl SearchSource for ApplicationSearchSource {
}
}
pub fn set_app_alias(_tauri_app_handle: &AppHandle, _app_path: &str, _alias: &str) {
#[tauri::command]
pub async fn set_app_alias(_app_path: String, _alias: String) -> Result<(), String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub fn register_app_hotkey(
_tauri_app_handle: &AppHandle,
_app_path: &str,
_hotkey: &str,
#[tauri::command]
pub async fn register_app_hotkey<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
_hotkey: String,
) -> Result<(), String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub fn unregister_app_hotkey(_tauri_app_handle: &AppHandle, _app_path: &str) -> Result<(), String> {
#[tauri::command]
pub async fn unregister_app_hotkey<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<(), String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub fn disable_app_search(_tauri_app_handle: &AppHandle, _app_path: &str) -> Result<(), String> {
#[tauri::command]
pub async fn disable_app_search<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<(), String> {
// no-op
Ok(())
}
pub fn enable_app_search(_tauri_app_handle: &AppHandle, _app_path: &str) -> Result<(), String> {
// no-op
Ok(())
}
pub fn is_app_search_enabled(_app_path: &str) -> bool {
false
}
#[tauri::command]
pub async fn add_app_search_path(
_tauri_app_handle: AppHandle,
pub async fn enable_app_search<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<(), String> {
// no-op
Ok(())
}
#[tauri::command]
pub async fn add_app_search_path<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_search_path: String,
) -> Result<(), String> {
// no-op
@@ -83,8 +89,8 @@ pub async fn add_app_search_path(
}
#[tauri::command]
pub async fn remove_app_search_path(
_tauri_app_handle: AppHandle,
pub async fn remove_app_search_path<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_search_path: String,
) -> Result<(), String> {
// no-op
@@ -92,37 +98,24 @@ pub async fn remove_app_search_path(
}
#[tauri::command]
pub async fn get_app_search_path(_tauri_app_handle: AppHandle) -> Vec<String> {
pub async fn get_app_search_path<R: Runtime>(_tauri_app_handle: AppHandle<R>) -> Vec<String> {
// Return an empty list
Vec::new()
}
#[tauri::command]
pub async fn get_app_list(_tauri_app_handle: AppHandle) -> Result<Vec<Extension>, String> {
pub async fn get_app_list<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
) -> Result<Vec<AppEntry>, String> {
// Return an empty list
Ok(Vec::new())
}
#[tauri::command]
pub async fn get_app_metadata(
_tauri_app_handle: AppHandle,
pub async fn get_app_metadata<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<AppMetadata, String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub(crate) fn set_apps_hotkey(_tauri_app_handle: &AppHandle) -> Result<(), String> {
// no-op
Ok(())
}
pub(crate) fn unset_apps_hotkey(_tauri_app_handle: &AppHandle) -> Result<(), String> {
// no-op
Ok(())
}
#[tauri::command]
pub async fn reindex_applications(_tauri_app_handle: AppHandle) -> Result<(), String> {
// no-op
Ok(())
}

View File

@@ -0,0 +1,163 @@
use super::LOCAL_QUERY_SOURCE_TYPE;
use crate::common::{
document::{DataSourceReference, Document},
error::SearchError,
search::{QueryResponse, QuerySource, SearchQuery},
traits::SearchSource,
};
use async_trait::async_trait;
use chinese_number::{ChineseCase, ChineseCountMethod, ChineseVariant, NumberToChinese};
use num2words::Num2Words;
use serde_json::Value;
use std::collections::HashMap;
pub(crate) const DATA_SOURCE_ID: &str = "Calculator";
pub struct CalculatorSource {
base_score: f64,
}
impl CalculatorSource {
pub fn new(base_score: f64) -> Self {
CalculatorSource { base_score }
}
}
fn parse_query(query: String) -> Value {
let mut query_json = serde_json::Map::new();
let operators = ["+", "-", "*", "/", "%"];
let found_operators: Vec<_> = query
.chars()
.filter(|c| operators.contains(&c.to_string().as_str()))
.collect();
if found_operators.len() == 1 {
let operation = match found_operators[0] {
'+' => "sum",
'-' => "subtract",
'*' => "multiply",
'/' => "divide",
'%' => "remainder",
_ => "expression",
};
query_json.insert("type".to_string(), Value::String(operation.to_string()));
} else {
query_json.insert("type".to_string(), Value::String("expression".to_string()));
}
query_json.insert("value".to_string(), Value::String(query));
Value::Object(query_json)
}
fn parse_result(num: f64) -> Value {
let mut result_json = serde_json::Map::new();
let to_zh = num
.to_chinese(
ChineseVariant::Simple,
ChineseCase::Upper,
ChineseCountMethod::TenThousand,
)
.unwrap_or(num.to_string());
let to_en = Num2Words::new(num)
.to_words()
.map(|s| {
let mut chars = s.chars();
let mut result = String::new();
let mut capitalize = true;
while let Some(c) = chars.next() {
if c == ' ' || c == '-' {
result.push(c);
capitalize = true;
} else if capitalize {
result.extend(c.to_uppercase());
capitalize = false;
} else {
result.push(c);
}
}
result
})
.unwrap_or(num.to_string());
result_json.insert("value".to_string(), Value::String(num.to_string()));
result_json.insert("toZh".to_string(), Value::String(to_zh));
result_json.insert("toEn".to_string(), Value::String(to_en));
Value::Object(result_json)
}
#[async_trait]
impl SearchSource for CalculatorSource {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(DATA_SOURCE_ID.into())
.to_string_lossy()
.into(),
id: DATA_SOURCE_ID.into(),
}
}
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError> {
let query_string = query
.query_strings
.get("query")
.unwrap_or(&"".to_string())
.to_string();
if query_string.is_empty() || query_string.len() == 1 {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
match meval::eval_str(&query_string) {
Ok(num) => {
let mut payload: HashMap<String, Value> = HashMap::new();
let payload_query = parse_query(query_string);
let payload_result = parse_result(num);
payload.insert("query".to_string(), payload_query);
payload.insert("result".to_string(), payload_result);
let doc = Document {
id: DATA_SOURCE_ID.to_string(),
category: Some(DATA_SOURCE_ID.to_string()),
payload: Some(payload),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(DATA_SOURCE_ID.into()),
id: Some(DATA_SOURCE_ID.into()),
icon: None,
}),
..Default::default()
};
return Ok(QueryResponse {
source: self.get_type(),
hits: vec![(doc, self.base_score)],
total_hits: 1,
});
}
Err(_) => {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
};
}
}

164
src-tauri/src/local/mod.rs Normal file
View File

@@ -0,0 +1,164 @@
pub mod application;
pub mod calculator;
pub mod file_system;
use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::OnceLock;
use crate::common::register::SearchSourceRegistry;
use serde_json::Value as Json;
use tauri::{AppHandle, Manager, Runtime};
use tauri_plugin_store::StoreExt;
pub const LOCAL_QUERY_SOURCE_TYPE: &str = "local";
pub const TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE: &str = "local_query_source_enabled_state";
trait SearchSourceState {
#[cfg_attr(not(feature = "use_pizza_engine"), allow(unused))]
fn as_mut_any(&mut self) -> &mut dyn Any;
}
#[async_trait::async_trait(?Send)]
trait Task: Send + Sync {
fn search_source_id(&self) -> &'static str;
async fn exec(&mut self, state: &mut Option<Box<dyn SearchSourceState>>);
}
static RUNTIME_TX: OnceLock<tokio::sync::mpsc::UnboundedSender<Box<dyn Task>>> = OnceLock::new();
pub(crate) fn start_pizza_engine_runtime() {
std::thread::spawn(|| {
let rt = tokio::runtime::Runtime::new().unwrap();
let main = async {
let mut states: HashMap<String, Option<Box<dyn SearchSourceState>>> = HashMap::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
RUNTIME_TX.set(tx).unwrap();
while let Some(mut task) = rx.recv().await {
let opt_search_source_state = match states.entry(task.search_source_id().into()) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => v.insert(None),
};
task.exec(opt_search_source_state).await;
}
};
rt.block_on(main);
});
}
pub(crate) async fn init_local_search_source<R: Runtime>(
app_handle: &AppHandle<R>,
) -> Result<(), String> {
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.map_err(|e| e.to_string())?;
if enabled_status_store.is_empty() {
enabled_status_store.set(
application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME,
Json::Bool(true),
);
enabled_status_store.set(calculator::DATA_SOURCE_ID, Json::Bool(true));
}
let registry = app_handle.state::<SearchSourceRegistry>();
application::ApplicationSearchSource::init(app_handle.clone()).await?;
for (id, enabled) in enabled_status_store.entries() {
let enabled = match enabled {
Json::Bool(b) => b,
_ => unreachable!("enabled state should be stored as a boolean"),
};
if enabled {
if id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
registry
.register_source(application::ApplicationSearchSource)
.await;
}
if id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
registry.register_source(calculator_search).await;
}
}
}
Ok(())
}
#[tauri::command]
pub async fn get_disabled_local_query_sources<R: Runtime>(app_handle: AppHandle<R>) -> Vec<String> {
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.unwrap_or_else(|e| {
panic!(
"tauri store [{}] should exist and be loaded, but that's not true due to error [{}]",
TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e
)
});
let mut disabled_local_query_sources = Vec::new();
for (id, enabled) in enabled_status_store.entries() {
let enabled = match enabled {
Json::Bool(b) => b,
_ => unreachable!("enabled state should be stored as a boolean"),
};
if !enabled {
disabled_local_query_sources.push(id);
}
}
disabled_local_query_sources
}
#[tauri::command]
pub async fn enable_local_query_source<R: Runtime>(
app_handle: AppHandle<R>,
query_source_id: String,
) {
let registry = app_handle.state::<SearchSourceRegistry>();
if query_source_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
let application_search = application::ApplicationSearchSource;
registry.register_source(application_search).await;
}
if query_source_id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
registry.register_source(calculator_search).await;
}
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.unwrap_or_else(|e| {
panic!(
"tauri store [{}] should exist and be loaded, but that's not true due to error [{}]",
TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e
)
});
enabled_status_store.set(query_source_id, Json::Bool(true));
}
#[tauri::command]
pub async fn disable_local_query_source<R: Runtime>(
app_handle: AppHandle<R>,
query_source_id: String,
) {
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(&query_source_id).await;
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.unwrap_or_else(|e| {
panic!(
"tauri store [{}] should exist and be loaded, but that's not true due to error [{}]",
TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e
)
});
enabled_status_store.set(query_source_id, Json::Bool(false));
}

View File

@@ -1,79 +1,5 @@
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
use coco_lib::util::logging::app_log_dir;
use std::fs::OpenOptions;
use std::io::Write;
/// Set up panic hook to log panic information to a file
fn setup_panic_hook() {
std::panic::set_hook(Box::new(|panic_info| {
let timestamp = chrono::Local::now();
// "%Y-%m-%d %H:%M:%S"
//
// I would like to use the above format, but Windows does not allow that
// and complains with OS error 123.
let datetime_str = timestamp.format("%Y-%m-%d-%H-%M-%S").to_string();
let log_dir = app_log_dir();
// Ensure the log directory exists
if let Err(e) = std::fs::create_dir_all(&log_dir) {
eprintln!("Panic hook error: failed to create log directory: {}", e);
return;
}
let panic_file = log_dir.join(format!("{}_rust_panic.log", datetime_str));
// Prepare panic information
let panic_message = if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = panic_info.payload().downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic message".to_string()
};
let location = if let Some(location) = panic_info.location() {
format!(
"{}:{}:{}",
location.file(),
location.line(),
location.column()
)
} else {
"Unknown location".to_string()
};
// Use `force_capture()` instead of `capture()` as we want backtrace
// regardless of whether the corresponding env vars are set or not.
let backtrace = std::backtrace::Backtrace::force_capture();
let panic_log = format!(
"Time: [{}]\nLocation: [{}]\nMessage: [{}]\nBacktrace: \n{}",
datetime_str, location, panic_message, backtrace
);
// Write to panic file
match OpenOptions::new()
.create(true)
.append(true)
.open(&panic_file)
{
Ok(mut file) => {
if let Err(e) = writeln!(file, "{}", panic_log) {
eprintln!("Panic hook error: Failed to write panic to file: {}", e);
}
}
Err(e) => {
eprintln!("Panic hook error: Failed to open panic log file: {}", e);
}
}
}));
}
fn main() {
// Panic hook setup should be the first thing to do, everything could panic!
setup_panic_hook();
coco_lib::run();
}

View File

@@ -3,382 +3,173 @@ use crate::common::register::SearchSourceRegistry;
use crate::common::search::{
FailedRequest, MultiSourceQueryResponse, QueryHits, QuerySource, SearchQuery,
};
use crate::common::traits::SearchSource;
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::server::servers::logout_coco_server;
use crate::server::servers::mark_server_as_offline;
use crate::settings::get_local_query_source_weight;
use function_name::named;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use reqwest::StatusCode;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tauri::{AppHandle, Manager};
use tokio::time::{Duration, timeout};
#[named]
use std::collections::HashSet;
use tauri::{AppHandle, Manager, Runtime};
use tokio::time::{timeout, Duration};
#[tauri::command]
pub async fn query_coco_fusion(
tauri_app_handle: AppHandle,
pub async fn query_coco_fusion<R: Runtime>(
app_handle: AppHandle<R>,
from: u64,
size: u64,
query_strings: HashMap<String, String>,
query_timeout: u64,
) -> Result<MultiSourceQueryResponse, SearchError> {
let opt_query_source_id = query_strings.get("querysource");
let search_sources = tauri_app_handle.state::<SearchSourceRegistry>();
let query_source_list = search_sources.get_sources().await;
let timeout_duration = Duration::from_millis(query_timeout);
let search_query = SearchQuery::new(from, size, query_strings.clone());
let query_source_to_search = query_strings.get("querysource");
log::debug!(
"{}() invoked with parameters: from: [{}], size: [{}], query_strings: [{:?}], timeout: [{:?}]",
function_name!(),
from,
size,
query_strings,
timeout_duration
);
// Dispatch to different `query_coco_fusion_xxx()` functions.
if let Some(query_source_id) = opt_query_source_id {
query_coco_fusion_single_query_source(
tauri_app_handle,
query_source_list,
query_source_id.clone(),
timeout_duration,
search_query,
)
.await
} else {
query_coco_fusion_multi_query_sources(
tauri_app_handle,
query_source_list,
timeout_duration,
search_query,
)
.await
}
}
/// Query only 1 query source.
///
/// The logic here is much simpler than `query_coco_fusion_multi_query_sources()`
/// as we don't need to re-rank due to fact that this does not involve multiple
/// query sources.
async fn query_coco_fusion_single_query_source(
tauri_app_handle: AppHandle,
mut query_source_list: Vec<Arc<dyn SearchSource>>,
id_of_query_source_to_query: String,
timeout_duration: Duration,
search_query: SearchQuery,
) -> Result<MultiSourceQueryResponse, SearchError> {
// If this query source ID is specified, we only query this query source.
log::debug!(
"parameter [querysource={}] specified, will only query this query source",
id_of_query_source_to_query
);
let opt_query_source_trait_object_index = query_source_list
.iter()
.position(|query_source| query_source.get_type().id == id_of_query_source_to_query);
let Some(query_source_trait_object_index) = opt_query_source_trait_object_index else {
// It is possible (an edge case) that the frontend invokes `query_coco_fusion()`
// with a querysource that does not exist in the source list:
//
// 1. Search applications
// 2. Navigate to the application sub page
// 3. Disable the application extension in settings, which removes this
// query source from the list
// 4. hide the search window
// 5. Re-open the search window, you will still be in the sub page, type to search
// something
//
// The application query source is not in the source list because the extension
// was disabled and thus removed from the query sources, but the last
// search is indeed invoked with parameter `querysource=application`.
return Ok(MultiSourceQueryResponse {
failed: Vec::new(),
hits: Vec::new(),
total_hits: 0,
});
};
let query_source_trait_object = query_source_list.remove(query_source_trait_object_index);
let query_source = query_source_trait_object.get_type();
let search_fut = query_source_trait_object.search(tauri_app_handle.clone(), search_query);
let timeout_result = timeout(timeout_duration, search_fut).await;
let mut failed_requests: Vec<FailedRequest> = Vec::new();
let mut hits = Vec::new();
let mut total_hits = 0;
match timeout_result {
// Ignore the `_timeout` variable as it won't provide any useful debugging information.
Err(_timeout) => {
log::warn!(
"searching query source [{}] timed out, skip this request",
query_source.id
);
}
Ok(query_result) => match query_result {
Ok(response) => {
total_hits = response.total_hits;
for (document, score) in response.hits {
log::debug!(
"document from query source [{}]: ID [{}], title [{:?}], score [{}]",
response.source.id,
document.id,
document.title,
score
);
let query_hit = QueryHits {
source: Some(response.source.clone()),
score,
document,
};
hits.push(query_hit);
}
}
Err(search_error) => {
query_coco_fusion_handle_failed_request(
tauri_app_handle.clone(),
&mut failed_requests,
query_source,
search_error,
)
.await;
}
},
}
Ok(MultiSourceQueryResponse {
failed: failed_requests,
hits,
total_hits,
})
}
async fn query_coco_fusion_multi_query_sources(
tauri_app_handle: AppHandle,
query_source_trait_object_list: Vec<Arc<dyn SearchSource>>,
timeout_duration: Duration,
search_query: SearchQuery,
) -> Result<MultiSourceQueryResponse, SearchError> {
log::debug!(
"will query query sources {:?}",
query_source_trait_object_list
.iter()
.map(|search_source| search_source.get_type().id.clone())
.collect::<Vec<String>>()
);
let query_keyword = search_query
.query_strings
.get("query")
.unwrap_or(&"".to_string())
.clone();
let size = search_query.size;
let search_sources = app_handle.state::<SearchSourceRegistry>();
let sources_future = search_sources.get_sources();
let mut futures = FuturesUnordered::new();
let mut sources = HashMap::new();
for query_source_trait_object in query_source_trait_object_list {
let query_source = query_source_trait_object.get_type().clone();
let tauri_app_handle_clone = tauri_app_handle.clone();
let search_query_clone = search_query.clone();
let sources_list = sources_future.await;
futures.push(async move {
(
// Store `query_source` as part of future for debugging purposes.
query_source,
timeout(timeout_duration, async {
query_source_trait_object
.search(tauri_app_handle_clone, search_query_clone)
.await
})
.await,
)
});
// Time limit for each query
let timeout_duration = Duration::from_millis(query_timeout);
// Push all queries into futures
for query_source in sources_list {
let query_source_type = query_source.get_type().clone();
if let Some(query_source_to_search) = query_source_to_search {
// We should not search this data source
if &query_source_type.id != query_source_to_search {
continue;
}
}
sources.insert(query_source_type.id.clone(), query_source_type);
let query = SearchQuery::new(from, size, query_strings.clone());
let query_source_clone = query_source.clone(); // Clone Arc to avoid ownership issues
futures.push(tokio::spawn(async move {
// Timeout each query execution
timeout(timeout_duration, async {
query_source_clone.search(query).await
})
.await
}));
}
let mut total_hits = 0;
let mut failed_requests = Vec::new();
let mut all_hits_grouped_by_query_source: HashMap<QuerySource, Vec<QueryHits>> = HashMap::new();
let mut all_hits: Vec<(String, QueryHits, f64)> = Vec::new();
let mut hits_per_source: HashMap<String, Vec<(QueryHits, f64)>> = HashMap::new();
while let Some((query_source, timeout_result)) = futures.next().await {
match timeout_result {
// Ignore the `_timeout` variable as it won't provide any useful debugging information.
Err(_timeout) => {
log::warn!(
"searching query source [{}] timed out, skip this request",
query_source.id
);
while let Some(result) = futures.next().await {
match result {
Ok(Ok(Ok(response))) => {
total_hits += response.total_hits;
let source_id = response.source.id.clone();
for (doc, score) in response.hits {
let query_hit = QueryHits {
source: Some(response.source.clone()),
score,
document: doc,
};
all_hits.push((source_id.clone(), query_hit.clone(), score));
hits_per_source
.entry(source_id.clone())
.or_insert_with(Vec::new)
.push((query_hit, score));
}
}
Ok(query_result) => match query_result {
Ok(response) => {
total_hits += response.total_hits;
for (document, score) in response.hits {
log::debug!(
"document from query source [{}]: ID [{}], title [{:?}], score [{}]",
response.source.id,
document.id,
document.title,
score
);
let query_hit = QueryHits {
source: Some(response.source.clone()),
score,
document,
};
all_hits_grouped_by_query_source
.entry(query_source.clone())
.or_insert_with(Vec::new)
.push(query_hit);
}
}
Err(search_error) => {
query_coco_fusion_handle_failed_request(
tauri_app_handle.clone(),
&mut failed_requests,
query_source,
search_error,
)
.await;
}
},
}
}
let n_sources = all_hits_grouped_by_query_source.len();
if n_sources == 0 {
return Ok(MultiSourceQueryResponse {
failed: Vec::new(),
hits: Vec::new(),
total_hits: 0,
});
}
/*
* Apply settings: local query source weight
*/
let local_query_source_weight: f64 = get_local_query_source_weight(tauri_app_handle);
// Scores remain unchanged if it is 1.0
if local_query_source_weight != 1.0 {
for (query_source, hits) in all_hits_grouped_by_query_source.iter_mut() {
if query_source.r#type == LOCAL_QUERY_SOURCE_TYPE {
hits.iter_mut()
.for_each(|hit| hit.score = hit.score * local_query_source_weight);
Ok(Ok(Err(err))) => {
failed_requests.push(FailedRequest {
source: QuerySource {
r#type: "N/A".into(),
name: "N/A".into(),
id: "N/A".into(),
},
status: 0,
error: Some(err.to_string()),
reason: None,
});
}
Ok(Err(err)) => {
failed_requests.push(FailedRequest {
source: QuerySource {
r#type: "N/A".into(),
name: "N/A".into(),
id: "N/A".into(),
},
status: 0,
error: Some(err.to_string()),
reason: None,
});
}
// Timeout reached, skip this request
_ => {
failed_requests.push(FailedRequest {
source: QuerySource {
r#type: "N/A".into(),
name: "N/A".into(),
id: "N/A".into(),
},
status: 0,
error: Some(format!("{:?}", &result)),
reason: None,
});
}
}
}
/*
* Sort hits within each source by score (descending) in case data sources
* do not sort them
*/
for hits in all_hits_grouped_by_query_source.values_mut() {
hits.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Greater)
});
// Sort hits within each source by score (descending)
for hits in hits_per_source.values_mut() {
hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
/*
* Collect hits evenly across sources, to ensure:
*
* 1. All sources have hits returned
* 2. Query sources with many hits won't dominate
*/
let mut final_hits_grouped_by_source_id: HashMap<String, Vec<QueryHits>> = HashMap::new();
let mut pruned: HashMap<&str, &[QueryHits]> = HashMap::new();
// Include at least 2 hits from each query source
let max_hits_per_source = (size as usize / n_sources).max(2);
for (query_source, hits) in all_hits_grouped_by_query_source.iter() {
let hits_taken = if hits.len() > max_hits_per_source {
pruned.insert(&query_source.id, &hits[max_hits_per_source..]);
hits[0..max_hits_per_source].to_vec()
} else {
hits.clone()
};
final_hits_grouped_by_source_id.insert(query_source.id.clone(), hits_taken);
}
let final_hits_len = final_hits_grouped_by_source_id
.iter()
.fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
let pruned_len = pruned
.iter()
.fold(0, |acc: usize, (_source_id, hits)| acc + hits.len());
/*
* If we still need more hits, take the highest-scoring from `pruned`
*
* `pruned` contains sorted arrays, we scan it in a way similar to
* how n-way-merge-sort extracts the element with the greatest value.
*/
if final_hits_len < size as usize {
let n_need = size as usize - final_hits_len;
let n_have = pruned_len;
let n_take = n_have.min(n_need);
for _ in 0..n_take {
let mut highest_score_hit: Option<(&str, &QueryHits)> = None;
for (source_id, sorted_hits) in pruned.iter_mut() {
if sorted_hits.is_empty() {
continue;
}
let hit = &sorted_hits[0];
let have_higher_score_hit = match highest_score_hit {
Some((_, current_highest_score_hit)) => {
hit.score > current_highest_score_hit.score
}
None => true,
};
if have_higher_score_hit {
highest_score_hit = Some((*source_id, hit));
// Advance sorted_hits by 1 element, if have
if sorted_hits.len() == 1 {
*sorted_hits = &[];
} else {
*sorted_hits = &sorted_hits[1..];
}
}
}
let (source_id, hit) = highest_score_hit.expect("`pruned` should contain at least `n_take` elements so `highest_score_hit` should be set");
final_hits_grouped_by_source_id
.get_mut(source_id)
.expect("all the source_ids stored in `pruned` come from `final_hits_grouped_by_source_id`, so it should exist")
.push(hit.clone());
}
}
/*
* Re-rank the final hits
*/
if n_sources > 1 {
boosted_levenshtein_rerank(&query_keyword, &mut final_hits_grouped_by_source_id);
}
let total_sources = hits_per_source.len();
let max_hits_per_source = if total_sources > 0 {
size as usize / total_sources
} else {
size as usize
};
let mut final_hits = Vec::new();
for (_source_id, hits) in final_hits_grouped_by_source_id {
final_hits.extend(hits);
let mut seen_docs = HashSet::new(); // To track documents we've already added
// Distribute hits fairly across sources
for (_source_id, hits) in &mut hits_per_source {
let take_count = hits.len().min(max_hits_per_source);
for (doc, _) in hits.drain(0..take_count) {
if !seen_docs.contains(&doc.document.id) {
seen_docs.insert(doc.document.id.clone());
final_hits.push(doc);
}
}
}
// If we still need more hits, take the highest-scoring remaining ones
if final_hits.len() < size as usize {
let remaining_needed = size as usize - final_hits.len();
// Sort all hits by score descending, removing duplicates by document ID
all_hits.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
let extra_hits = all_hits
.into_iter()
.filter(|(source_id, _, _)| hits_per_source.contains_key(source_id)) // Only take from known sources
.filter_map(|(_, doc, _)| {
if !seen_docs.contains(&doc.document.id) {
seen_docs.insert(doc.document.id.clone());
Some(doc)
} else {
None
}
})
.take(remaining_needed)
.collect::<Vec<_>>();
final_hits.extend(extra_hits);
}
// **Sort final hits by score descending**
@@ -388,154 +179,9 @@ async fn query_coco_fusion_multi_query_sources(
.unwrap_or(std::cmp::Ordering::Equal)
});
// Truncate `final_hits` in case it contains more than `size` hits
final_hits.truncate(size as usize);
if final_hits.len() < 5 {
//TODO: Add a recommendation system to suggest more sources
log::info!(
"Less than 5 hits found, consider using recommendation to find more suggestions."
);
//local: recent history, local extensions
//remote: ai agents, quick links, other tasks, managed by server
}
Ok(MultiSourceQueryResponse {
failed: failed_requests,
hits: final_hits,
total_hits,
})
}
use std::collections::HashSet;
use strsim::levenshtein;
fn boosted_levenshtein_rerank(
query: &str,
all_hits_grouped_by_source_id: &mut HashMap<String, Vec<QueryHits>>,
) {
let query_lower = query.to_lowercase();
for (source_id, hits) in all_hits_grouped_by_source_id.iter_mut() {
// Skip special sources like calculator
if source_id == crate::extension::built_in::calculator::DATA_SOURCE_ID {
continue;
}
for hit in hits.iter_mut() {
let document_title = hit.document.title.as_deref().unwrap_or("");
let document_title_lowercase = document_title.to_lowercase();
let new_score = {
let mut score = 0.0;
// --- Exact or substring boost ---
if document_title.contains(query) {
score += 0.4;
} else if document_title_lowercase.contains(&query_lower) {
score += 0.2;
}
// --- Levenshtein distance (character similarity) ---
let dist = levenshtein(&query_lower, &document_title_lowercase);
let max_len = query_lower.len().max(document_title.len());
let levenshtein_score = if max_len > 0 {
(1.0 - (dist as f64 / max_len as f64)) as f32
} else {
0.0
};
// --- Jaccard similarity (token overlap) ---
let jaccard_score = jaccard_similarity(&query_lower, &document_title_lowercase);
// --- Combine scores (weights adjustable) ---
// Levenshtein emphasizes surface similarity
// Jaccard emphasizes term overlap (semantic hint)
let hybrid_score = 0.7 * levenshtein_score + 0.3 * jaccard_score;
// --- Apply hybrid score ---
score += hybrid_score;
// --- Limit score range ---
score.min(1.0) as f64
};
hit.score = new_score;
}
}
}
/// Compute token-based Jaccard similarity
fn jaccard_similarity(a: &str, b: &str) -> f32 {
let a_tokens: HashSet<_> = tokenize(a).into_iter().collect();
let b_tokens: HashSet<_> = tokenize(b).into_iter().collect();
if a_tokens.is_empty() || b_tokens.is_empty() {
return 0.0;
}
let intersection = a_tokens.intersection(&b_tokens).count() as f32;
let union = a_tokens.union(&b_tokens).count() as f32;
intersection / union
}
/// Basic tokenizer (case-insensitive, alphanumeric words only)
fn tokenize(text: &str) -> Vec<String> {
text.to_lowercase()
.split(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.map(|s| s.to_string())
.collect()
}
/// Helper function to handle a failed request.
///
/// Extracted as a function because `query_coco_fusion_single_query_source()` and
/// `query_coco_fusion_multi_query_sources()` share the same error handling logic.
async fn query_coco_fusion_handle_failed_request(
tauri_app_handle: AppHandle,
failed_requests: &mut Vec<FailedRequest>,
query_source: QuerySource,
search_error: SearchError,
) {
log::error!(
"searching query source [{}] failed, error [{}]",
query_source.id,
search_error
);
let mut status_code_num: u16 = 0;
if let SearchError::HttpError {
status_code: opt_status_code,
msg: _,
} = search_error
{
if let Some(status_code) = opt_status_code {
status_code_num = status_code.as_u16();
if status_code != StatusCode::OK {
if status_code == StatusCode::UNAUTHORIZED {
// This Coco server is unavailable. In addition to marking it as
// unavailable, we need to log out because the status code is 401.
logout_coco_server(tauri_app_handle.clone(), query_source.id.to_string()).await.unwrap_or_else(|e| {
panic!(
"the search request to Coco server [id {}, name {}] failed with status code {}, the login token is invalid, we are trying to log out, but failed with error [{}]",
query_source.id, query_source.name, StatusCode::UNAUTHORIZED, e
);
})
} else {
// This Coco server is unavailable
mark_server_as_offline(tauri_app_handle.clone(), &query_source.id).await;
}
}
}
}
failed_requests.push(FailedRequest {
source: query_source,
status: status_code_num,
error: Some(search_error.to_string()),
reason: None,
});
}

View File

@@ -15,6 +15,42 @@ pub struct UploadAttachmentResponse {
pub attachments: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AttachmentSource {
pub id: String,
pub created: String,
pub updated: String,
pub session: String,
pub name: String,
pub icon: String,
pub url: String,
pub size: u64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AttachmentHit {
pub _index: String,
pub _type: Option<String>,
pub _id: String,
pub _score: Option<f64>,
pub _source: AttachmentSource,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AttachmentHits {
pub total: Value,
pub max_score: Option<f64>,
pub hits: Option<Vec<AttachmentHit>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GetAttachmentResponse {
pub took: u32,
pub timed_out: bool,
pub _shards: Option<Value>,
pub hits: AttachmentHits,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DeleteAttachmentResponse {
pub _id: String,
@@ -24,6 +60,7 @@ pub struct DeleteAttachmentResponse {
#[command]
pub async fn upload_attachment(
server_id: String,
session_id: String,
file_paths: Vec<PathBuf>,
) -> Result<UploadAttachmentResponse, String> {
let mut form = Form::new();
@@ -45,12 +82,10 @@ pub async fn upload_attachment(
form = form.part("files", part);
}
let server = get_server_by_id(&server_id)
.await
.ok_or("Server not found")?;
let url = HttpClient::join_url(&server.endpoint, &format!("attachment/_upload"));
let server = get_server_by_id(&server_id).ok_or("Server not found")?;
let url = HttpClient::join_url(&server.endpoint, &format!("chat/{}/_upload", session_id));
let token = get_server_token(&server_id).await;
let token = get_server_token(&server_id).await?;
let mut headers = HashMap::new();
if let Some(token) = token {
headers.insert("X-API-TOKEN".to_string(), token.access_token);
@@ -72,25 +107,20 @@ pub async fn upload_attachment(
}
#[command]
pub async fn get_attachment_by_ids(
pub async fn get_attachment(
server_id: String,
attachments: Vec<String>,
) -> Result<Value, String> {
println!("get_attachment_by_ids server_id: {}", server_id);
println!("get_attachment_by_ids attachments: {:?}", attachments);
session_id: String,
) -> Result<GetAttachmentResponse, String> {
let mut query_params = HashMap::new();
query_params.insert("session".to_string(), serde_json::Value::String(session_id));
let request_body = serde_json::json!({
"attachments": attachments
});
let body = reqwest::Body::from(serde_json::to_string(&request_body).unwrap());
let response = HttpClient::post(&server_id, "/attachment/_search", None, Some(body))
let response = HttpClient::get(&server_id, "/attachment/_search", Some(query_params))
.await
.map_err(|e| format!("Request error: {}", e))?;
let body = get_response_body_text(response).await?;
serde_json::from_str::<Value>(&body)
serde_json::from_str::<GetAttachmentResponse>(&body)
.map_err(|e| format!("Failed to parse attachment response: {}", e))
}

View File

@@ -4,31 +4,31 @@ use crate::server::servers::{
get_server_by_id, persist_servers, persist_servers_token, save_access_token, save_server,
try_register_server_to_search_source,
};
use tauri::AppHandle;
use tauri::{AppHandle, Runtime};
#[allow(dead_code)]
fn request_access_token_url(request_id: &str) -> String {
// Remove the endpoint part and keep just the path for the request
format!("/auth/access_token?request_id={}", request_id)
format!("/auth/request_access_token?request_id={}", request_id)
}
#[tauri::command]
pub async fn handle_sso_callback(
app_handle: AppHandle,
pub async fn handle_sso_callback<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
request_id: String,
code: String,
) -> Result<(), String> {
// Retrieve the server details using the server ID
let server = get_server_by_id(&server_id).await;
let server = get_server_by_id(&server_id);
let expire_in = 3600; // TODO, need to update to actual expire_in value
if let Some(mut server) = server {
// Save the access token for the server
let access_token = ServerAccessToken::new(server_id.clone(), code.clone(), expire_in);
// dbg!(&server_id, &request_id, &code, &token);
save_access_token(server_id.clone(), access_token).await;
persist_servers_token(&app_handle).await?;
save_access_token(server_id.clone(), access_token);
persist_servers_token(&app_handle)?;
// Register the server to the search source
try_register_server_to_search_source(app_handle.clone(), &server).await;
@@ -41,7 +41,7 @@ pub async fn handle_sso_callback(
Ok(p) => {
server.profile = Some(p);
server.available = true;
save_server(&server).await;
save_server(&server);
persist_servers(&app_handle).await?;
Ok(())
}

View File

@@ -1,12 +1,11 @@
use crate::common::connector::Connector;
use crate::common::search::parse_search_results;
use crate::server::http_client::{HttpClient, status_code_check};
use crate::server::http_client::HttpClient;
use crate::server::servers::get_all_servers;
use http::StatusCode;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tauri::AppHandle;
use tauri::{AppHandle, Runtime};
lazy_static! {
static ref CONNECTOR_CACHE: Arc<RwLock<HashMap<String, HashMap<String, Connector>>>> =
@@ -29,8 +28,8 @@ pub fn get_connector_by_id(server_id: &str, connector_id: &str) -> Option<Connec
Some(connector.clone())
}
pub async fn refresh_all_connectors(app_handle: &AppHandle) -> Result<(), String> {
let servers = get_all_servers().await;
pub async fn refresh_all_connectors<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
let servers = get_all_servers();
// Collect all the tasks for fetching and refreshing connectors
let mut server_map = HashMap::new();
@@ -108,7 +107,6 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, Stri
// dbg!("Error fetching connector for id {}: {}", &id, &e);
format!("Error fetching connector: {}", e)
})?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results directly from the response body
let datasource: Vec<Connector> = parse_search_results(resp)
@@ -122,8 +120,8 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, Stri
}
#[tauri::command]
pub async fn get_connectors_by_server(
_app_handle: AppHandle,
pub async fn get_connectors_by_server<R: Runtime>(
_app_handle: AppHandle<R>,
id: String,
) -> Result<Vec<Connector>, String> {
let connectors = fetch_connectors_by_server(&id).await?;

View File

@@ -1,13 +1,19 @@
use crate::common::datasource::DataSource;
use crate::common::search::parse_search_results;
use crate::server::connector::get_connector_by_id;
use crate::server::http_client::{HttpClient, status_code_check};
use crate::server::http_client::HttpClient;
use crate::server::servers::get_all_servers;
use http::StatusCode;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tauri::AppHandle;
use tauri::{AppHandle, Runtime};
#[derive(serde::Deserialize, Debug)]
pub struct GetDatasourcesByServerOptions {
pub from: Option<u32>,
pub size: Option<u32>,
pub query: Option<String>,
}
lazy_static! {
static ref DATASOURCE_CACHE: Arc<RwLock<HashMap<String, HashMap<String, DataSource>>>> =
@@ -31,10 +37,10 @@ pub fn get_datasources_from_cache(server_id: &str) -> Option<HashMap<String, Dat
Some(server_cache.clone())
}
pub async fn refresh_all_datasources(_app_handle: &AppHandle) -> Result<(), String> {
pub async fn refresh_all_datasources<R: Runtime>(_app_handle: &AppHandle<R>) -> Result<(), String> {
// dbg!("Attempting to refresh all datasources");
let servers = get_all_servers().await;
let servers = get_all_servers();
let mut server_map = HashMap::new();
@@ -90,17 +96,50 @@ pub async fn refresh_all_datasources(_app_handle: &AppHandle) -> Result<(), Stri
#[tauri::command]
pub async fn datasource_search(
id: &str,
query_params: Option<Vec<String>>, //["query=abc", "filter=er", "filter=efg", "from=0", "size=5"],
options: Option<GetDatasourcesByServerOptions>,
) -> Result<Vec<DataSource>, String> {
let from = options.as_ref().and_then(|opt| opt.from).unwrap_or(0);
let size = options.as_ref().and_then(|opt| opt.size).unwrap_or(10000);
let query = options
.and_then(|opt| opt.query)
.unwrap_or(String::default());
let mut body = serde_json::json!({
"from": from,
"size": size,
});
if !query.is_empty() {
body["query"] = serde_json::json!({
"bool": {
"must": [{
"query_string": {
"fields": ["combined_fulltext"],
"query": query,
"fuzziness": "AUTO",
"fuzzy_prefix_length": 2,
"fuzzy_max_expansions": 10,
"fuzzy_transpositions": true,
"allow_leading_wildcard": false
}
}]
}
});
}
// Perform the async HTTP request outside the cache lock
let resp = HttpClient::post(id, "/datasource/_search", query_params, None)
let resp = HttpClient::post(
id,
"/datasource/_search",
None,
Some(reqwest::Body::from(body.to_string())),
)
.await
.map_err(|e| format!("Error fetching datasource: {}", e))?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results from the response
let datasources: Vec<DataSource> = parse_search_results(resp).await.map_err(|e| {
//dbg!("Error parsing search results: {}", &e);
dbg!("Error parsing search results: {}", &e);
e.to_string()
})?;
@@ -113,17 +152,50 @@ pub async fn datasource_search(
#[tauri::command]
pub async fn mcp_server_search(
id: &str,
query_params: Option<Vec<String>>,
options: Option<GetDatasourcesByServerOptions>,
) -> Result<Vec<DataSource>, String> {
let from = options.as_ref().and_then(|opt| opt.from).unwrap_or(0);
let size = options.as_ref().and_then(|opt| opt.size).unwrap_or(10000);
let query = options
.and_then(|opt| opt.query)
.unwrap_or(String::default());
let mut body = serde_json::json!({
"from": from,
"size": size,
});
if !query.is_empty() {
body["query"] = serde_json::json!({
"bool": {
"must": [{
"query_string": {
"fields": ["combined_fulltext"],
"query": query,
"fuzziness": "AUTO",
"fuzzy_prefix_length": 2,
"fuzzy_max_expansions": 10,
"fuzzy_transpositions": true,
"allow_leading_wildcard": false
}
}]
}
});
}
// Perform the async HTTP request outside the cache lock
let resp = HttpClient::post(id, "/mcp_server/_search", query_params, None)
let resp = HttpClient::post(
id,
"/mcp_server/_search",
None,
Some(reqwest::Body::from(body.to_string())),
)
.await
.map_err(|e| format!("Error fetching datasource: {}", e))?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results from the response
let mcp_server: Vec<DataSource> = parse_search_results(resp).await.map_err(|e| {
//dbg!("Error parsing search results: {}", &e);
dbg!("Error parsing search results: {}", &e);
e.to_string()
})?;

View File

@@ -1,19 +1,17 @@
use crate::server::servers::{get_server_by_id, get_server_token};
use crate::util::app_lang::get_app_lang;
use crate::util::platform::Platform;
use http::{HeaderName, HeaderValue, StatusCode};
use http::{HeaderName, HeaderValue};
use once_cell::sync::Lazy;
use reqwest::{Client, Method, RequestBuilder};
use std::collections::HashMap;
use std::sync::LazyLock;
use std::time::Duration;
use tauri_plugin_store::JsonValue;
use tokio::sync::Mutex;
pub(crate) fn new_reqwest_http_client(accept_invalid_certs: bool) -> Client {
Client::builder()
.read_timeout(Duration::from_secs(60)) // Set a timeout of 60 second
.connect_timeout(Duration::from_secs(30)) // Set a timeout of 30 second
.timeout(Duration::from_secs(5 * 60)) // Set a timeout of 5 minute
.read_timeout(Duration::from_secs(3)) // Set a timeout of 3 second
.connect_timeout(Duration::from_secs(3)) // Set a timeout of 3 second
.timeout(Duration::from_secs(10)) // Set a timeout of 10 seconds
.danger_accept_invalid_certs(accept_invalid_certs) // allow self-signed certificates
.build()
.expect("Failed to build client")
@@ -29,26 +27,6 @@ pub static HTTP_CLIENT: Lazy<Mutex<Client>> = Lazy::new(|| {
Mutex::new(new_reqwest_http_client(allow_self_signature))
});
/// These header values won't change during a process's lifetime.
static STATIC_HEADERS: LazyLock<HashMap<String, String>> = LazyLock::new(|| {
HashMap::from([
(
"X-OS-NAME".into(),
Platform::current()
.to_os_name_http_header_str()
.into_owned(),
),
(
"X-OS-VER".into(),
sysinfo::System::os_version()
.expect("sysinfo::System::os_version() should be Some on major systems"),
),
("X-OS-ARCH".into(), sysinfo::System::cpu_arch()),
("X-APP-NAME".into(), "coco-app".into()),
("X-APP-VER".into(), env!("CARGO_PKG_VERSION").into()),
])
});
pub struct HttpClient;
impl HttpClient {
@@ -62,7 +40,7 @@ impl HttpClient {
pub async fn send_raw_request(
method: Method,
url: &str,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>,
headers: Option<HashMap<String, String>>,
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
@@ -78,7 +56,7 @@ impl HttpClient {
Self::get_request_builder(method, url, headers, query_params, body).await;
let response = request_builder.send().await.map_err(|e| {
//dbg!("Failed to send request: {}", &e);
dbg!("Failed to send request: {}", &e);
format!("Failed to send request: {}", e)
})?;
@@ -96,7 +74,7 @@ impl HttpClient {
method: Method,
url: &str,
headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>, // Add query parameters
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> RequestBuilder {
let client = HTTP_CLIENT.lock().await; // Acquire the lock on HTTP_CLIENT
@@ -104,32 +82,8 @@ impl HttpClient {
// Build the request
let mut request_builder = client.request(method.clone(), url);
// Populate the headers defined by us
let mut req_headers = reqwest::header::HeaderMap::new();
for (key, value) in STATIC_HEADERS.iter() {
let key = HeaderName::from_bytes(key.as_bytes())
.expect("headers defined by us should be valid");
let value = HeaderValue::from_str(value.trim()).unwrap_or_else(|e| {
panic!(
"header value [{}] is invalid, error [{}], this should be unreachable",
value, e
);
});
req_headers.insert(key, value);
}
let app_lang = get_app_lang().await.to_string();
req_headers.insert(
"X-APP-LANG",
HeaderValue::from_str(&app_lang).unwrap_or_else(|e| {
panic!(
"header value [{}] is invalid, error [{}], this should be unreachable",
app_lang, e
);
}),
);
// Headers from the function parameter
if let Some(h) = headers {
let mut req_headers = reqwest::header::HeaderMap::new();
for (key, value) in h.into_iter() {
match (
HeaderName::from_bytes(key.as_bytes()),
@@ -152,9 +106,24 @@ impl HttpClient {
request_builder = request_builder.headers(req_headers);
}
if let Some(params) = query_params {
let query: Vec<(&str, &str)> =
params.iter().filter_map(|s| s.split_once('=')).collect();
if let Some(query) = query_params {
// Convert only supported value types into strings
let query: HashMap<String, String> = query
.into_iter()
.filter_map(|(k, v)| {
match v {
JsonValue::String(s) => Some((k, s)),
JsonValue::Number(n) => Some((k, n.to_string())),
JsonValue::Bool(b) => Some((k, b.to_string())),
_ => {
dbg!(
"Unsupported query parameter type. Only strings, numbers, and booleans are supported.",k,v,
);
None
} // skip arrays, objects, nulls
}
})
.collect();
request_builder = request_builder.query(&query);
}
@@ -171,18 +140,18 @@ impl HttpClient {
method: Method,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>,
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
// Fetch the server using the server_id
let server = get_server_by_id(server_id).await;
let server = get_server_by_id(server_id);
if let Some(s) = server {
// Construct the URL
let url = HttpClient::join_url(&s.endpoint, path);
// Retrieve the token for the server (token is optional)
let token = get_server_token(server_id)
.await
.await?
.map(|t| t.access_token.clone());
let mut headers = if let Some(custom_headers) = custom_headers {
@@ -196,16 +165,16 @@ impl HttpClient {
headers.insert("X-API-TOKEN".to_string(), t);
}
// log::debug!(
// "Sending request to server: {}, url: {}, headers: {:?}",
// &server_id,
// &url,
// &headers
// );
log::debug!(
"Sending request to server: {}, url: {}, headers: {:?}",
&server_id,
&url,
&headers
);
Self::send_raw_request(method, &url, query_params, Some(headers), body).await
} else {
Err(format!("Server [{}] not found", server_id))
Err("Server not found".to_string())
}
}
@@ -213,7 +182,7 @@ impl HttpClient {
pub async fn get(
server_id: &str,
path: &str,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::GET, path, None, query_params, None).await
}
@@ -222,7 +191,7 @@ impl HttpClient {
pub async fn post(
server_id: &str,
path: &str,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::POST, path, None, query_params, body).await
@@ -232,7 +201,7 @@ impl HttpClient {
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(
@@ -252,7 +221,7 @@ impl HttpClient {
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(
@@ -272,7 +241,7 @@ impl HttpClient {
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
) -> Result<reqwest::Response, String> {
HttpClient::send_request(
server_id,
@@ -285,30 +254,3 @@ impl HttpClient {
.await
}
}
/// Helper function to check status code.
///
/// If the status code is not in the `allowed_status_codes` list, return an error.
pub(crate) fn status_code_check(
response: &reqwest::Response,
allowed_status_codes: &[StatusCode],
) -> Result<(), String> {
let status_code = response.status();
if !allowed_status_codes.contains(&status_code) {
let msg = format!(
"Response of request [{}] status code failed: status code [{}], which is not in the 'allow' list {:?}",
response.url(),
status_code,
allowed_status_codes
.iter()
.map(|status| status.to_string())
.collect::<Vec<String>>()
);
log::warn!("{}", msg);
Err(msg)
} else {
Ok(())
}
}

View File

@@ -8,6 +8,6 @@ pub mod http_client;
pub mod profile;
pub mod search;
pub mod servers;
pub mod synthesize;
pub mod system_settings;
pub mod transcription;
pub mod websocket;

View File

@@ -1,11 +1,11 @@
use crate::common::http::get_response_body_text;
use crate::common::profile::UserProfile;
use crate::server::http_client::HttpClient;
use tauri::AppHandle;
use tauri::{AppHandle, Runtime};
#[tauri::command]
pub async fn get_user_profiles(
_app_handle: AppHandle,
pub async fn get_user_profiles<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
) -> Result<UserProfile, String> {
// Use the generic GET method from HttpClient

View File

@@ -1,4 +1,4 @@
use crate::common::document::{Document, OnOpened};
use crate::common::document::Document;
use crate::common::error::SearchError;
use crate::common::http::get_response_body_text;
use crate::common::search::{QueryHits, QueryResponse, QuerySource, SearchQuery, SearchResponse};
@@ -6,10 +6,11 @@ use crate::common::server::Server;
use crate::common::traits::SearchSource;
use crate::server::http_client::HttpClient;
use async_trait::async_trait;
// use futures::stream::StreamExt;
use ordered_float::OrderedFloat;
use reqwest::StatusCode;
use std::collections::HashMap;
use tauri::AppHandle;
use tauri_plugin_store::JsonValue;
// use std::hash::Hash;
#[allow(dead_code)]
pub(crate) struct DocumentsSizedCollector {
@@ -44,7 +45,7 @@ impl DocumentsSizedCollector {
}
}
fn documents(self) -> impl ExactSizeIterator<Item = Document> {
fn documents(self) -> impl ExactSizeIterator<Item=Document> {
self.docs.into_iter().map(|(_, doc, _)| doc)
}
@@ -90,74 +91,41 @@ impl SearchSource for CocoSearchSource {
}
}
async fn search(
&self,
_tauri_app_handle: AppHandle,
query: SearchQuery,
) -> Result<QueryResponse, SearchError> {
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError> {
let url = "/query/_search";
let mut total_hits = 0;
let mut hits: Vec<(Document, f64)> = Vec::new();
let mut query_params = Vec::new();
// Add from/size as number values
query_params.push(format!("from={}", query.from));
query_params.push(format!("size={}", query.size));
// Add query strings
let mut query_args: HashMap<String, JsonValue> = HashMap::new();
query_args.insert("from".into(), JsonValue::Number(query.from.into()));
query_args.insert("size".into(), JsonValue::Number(query.size.into()));
for (key, value) in query.query_strings {
query_params.push(format!("{}={}", key, value));
query_args.insert(key, JsonValue::String(value));
}
let response = HttpClient::get(&self.server.id, &url, Some(query_params))
let response = HttpClient::get(
&self.server.id,
&url,
Some(query_args),
)
.await
.map_err(|e| SearchError::HttpError {
status_code: None,
msg: format!("{}", e),
})?;
let status_code = response.status();
if ![StatusCode::OK, StatusCode::CREATED].contains(&status_code) {
return Err(SearchError::HttpError {
status_code: Some(status_code),
msg: format!("Request failed with status code [{}]", status_code),
});
}
.map_err(|e| SearchError::HttpError(format!("Error to send search request: {}", e)))?;
// Use the helper function to parse the response body
let response_body = get_response_body_text(response)
.await
.map_err(|e| SearchError::ParseError(e))?;
.map_err(|e| SearchError::ParseError(format!("Failed to read response body: {}", e)))?;
// Check if the response body is empty
if !response_body.is_empty() {
// log::info!("Search response body: {}", &response_body);
// Parse the search response from the body text
let parsed: SearchResponse<Document> = serde_json::from_str(&response_body)
.map_err(|e| SearchError::ParseError(format!("Failed to parse search response: {}", e)))?;
// Parse the search response from the body text
let parsed: SearchResponse<Document> = serde_json::from_str(&response_body)
.map_err(|e| SearchError::ParseError(format!("{}", e)))?;
// Process the parsed response
total_hits = parsed.hits.total.value as usize;
if let Some(items) = parsed.hits.hits {
for hit in items {
let mut document = hit._source;
// Default _score to 0.0 if None
let score = hit._score.unwrap_or(0.0);
let on_opened = document
.url
.as_ref()
.map(|url| OnOpened::Document { url: url.clone() });
// Set the `on_opened` field as it won't be returned from Coco server
document.on_opened = on_opened;
hits.push((document, score));
}
}
}
// Process the parsed response
let total_hits = parsed.hits.total.value as usize;
let hits: Vec<(Document, f64)> = parsed
.hits
.hits
.into_iter()
.map(|hit| (hit._source, hit._score.unwrap_or(0.0))) // Default _score to 0.0 if None
.collect();
// Return the final result
Ok(QueryResponse {

View File

@@ -1,4 +1,3 @@
use crate::COCO_TAURI_STORE;
use crate::common::http::get_response_body_text;
use crate::common::register::SearchSourceRegistry;
use crate::common::server::{AuthProvider, Provider, Server, ServerAccessToken, Sso, Version};
@@ -6,71 +5,68 @@ use crate::server::connector::fetch_connectors_by_server;
use crate::server::datasource::datasource_search;
use crate::server::http_client::HttpClient;
use crate::server::search::CocoSearchSource;
use function_name;
use http::StatusCode;
use crate::COCO_TAURI_STORE;
use lazy_static::lazy_static;
use reqwest::Method;
use serde_json::Value as JsonValue;
use serde_json::from_value;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::sync::LazyLock;
use std::sync::Arc;
use std::sync::RwLock;
use tauri::Runtime;
use tauri::{AppHandle, Manager};
use tauri_plugin_store::StoreExt;
use tokio::sync::RwLock;
// Assuming you're using serde_json
/// Coco sever list
static SERVER_LIST_CACHE: LazyLock<RwLock<HashMap<String, Server>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
lazy_static! {
static ref SERVER_CACHE: Arc<RwLock<HashMap<String, Server>>> =
Arc::new(RwLock::new(HashMap::new()));
static ref SERVER_TOKEN: Arc<RwLock<HashMap<String, ServerAccessToken>>> =
Arc::new(RwLock::new(HashMap::new()));
}
/// If a server has a token stored here that has not expired, it is considered logged in.
///
/// Since the `expire_at` field of `struct ServerAccessToken` is currently unused,
/// all servers stored here are treated as logged in.
static SERVER_TOKEN_LIST_CACHE: LazyLock<RwLock<HashMap<String, ServerAccessToken>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
#[allow(dead_code)]
fn check_server_exists(id: &str) -> bool {
let cache = SERVER_CACHE.read().unwrap(); // Acquire read lock
cache.contains_key(id)
}
/// `SERVER_LIST_CACHE` will be stored in KV store COCO_TAURI_STORE, under this key.
pub const COCO_SERVERS: &str = "coco_servers";
/// `SERVER_TOKEN_LIST_CACHE` will be stored in KV store COCO_TAURI_STORE, under this key.
const COCO_SERVER_TOKENS: &str = "coco_server_tokens";
pub async fn get_server_by_id(id: &str) -> Option<Server> {
let cache = SERVER_LIST_CACHE.read().await;
pub fn get_server_by_id(id: &str) -> Option<Server> {
let cache = SERVER_CACHE.read().unwrap(); // Acquire read lock
cache.get(id).cloned()
}
pub async fn get_server_token(id: &str) -> Option<ServerAccessToken> {
let cache = SERVER_TOKEN_LIST_CACHE.read().await;
#[tauri::command]
pub async fn get_server_token(id: &str) -> Result<Option<ServerAccessToken>, String> {
let cache = SERVER_TOKEN.read().map_err(|err| err.to_string())?;
cache.get(id).cloned()
Ok(cache.get(id).cloned())
}
pub async fn save_access_token(server_id: String, token: ServerAccessToken) -> bool {
let mut cache = SERVER_TOKEN_LIST_CACHE.write().await;
pub fn save_access_token(server_id: String, token: ServerAccessToken) -> bool {
let mut cache = SERVER_TOKEN.write().unwrap();
cache.insert(server_id, token).is_none()
}
async fn check_endpoint_exists(endpoint: &str) -> bool {
let cache = SERVER_LIST_CACHE.read().await;
fn check_endpoint_exists(endpoint: &str) -> bool {
let cache = SERVER_CACHE.read().unwrap();
cache.values().any(|server| server.endpoint == endpoint)
}
/// Return true if `server` does not exists in the server list, i.e., it is a newly-added
/// server.
pub async fn save_server(server: &Server) -> bool {
let mut cache = SERVER_LIST_CACHE.write().await;
cache.insert(server.id.clone(), server.clone()).is_none()
pub fn save_server(server: &Server) -> bool {
let mut cache = SERVER_CACHE.write().unwrap();
cache.insert(server.id.clone(), server.clone()).is_none() // If the server id did not exist, `insert` will return `None`
}
/// Return the removed `Server` if it exists in the server list.
async fn remove_server_by_id(id: &str) -> Option<Server> {
log::debug!("remove server by id: {}", &id);
let mut cache = SERVER_LIST_CACHE.write().await;
cache.remove(id)
fn remove_server_by_id(id: String) -> bool {
dbg!("remove server by id:", &id);
let mut cache = SERVER_CACHE.write().unwrap();
let deleted = cache.remove(id.as_str());
deleted.is_some()
}
pub async fn persist_servers(app_handle: &AppHandle) -> Result<(), String> {
let cache = SERVER_LIST_CACHE.read().await;
pub async fn persist_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
let cache = SERVER_CACHE.read().unwrap(); // Acquire a read lock, not a write lock, since you're not modifying the cache
// Convert HashMap to Vec for serialization (iterating over values of HashMap)
let servers: Vec<Server> = cache.values().cloned().collect();
@@ -90,16 +86,14 @@ pub async fn persist_servers(app_handle: &AppHandle) -> Result<(), String> {
Ok(())
}
/// Return true if the server token of the server specified by `id` exists in
/// the token list and gets deleted.
pub async fn remove_server_token(id: &str) -> bool {
log::debug!("remove server token by id: {}", &id);
let mut cache = SERVER_TOKEN_LIST_CACHE.write().await;
pub fn remove_server_token(id: &str) -> bool {
dbg!("remove server token by id:", &id);
let mut cache = SERVER_TOKEN.write().unwrap();
cache.remove(id).is_some()
}
pub async fn persist_servers_token(app_handle: &AppHandle) -> Result<(), String> {
let cache = SERVER_TOKEN_LIST_CACHE.read().await;
pub fn persist_servers_token<R: Runtime>(app_handle: &AppHandle<R>) -> Result<(), String> {
let cache = SERVER_TOKEN.read().unwrap(); // Acquire a read lock, not a write lock, since you're not modifying the cache
// Convert HashMap to Vec for serialization (iterating over values of HashMap)
let servers: Vec<ServerAccessToken> = cache.values().cloned().collect();
@@ -110,7 +104,7 @@ pub async fn persist_servers_token(app_handle: &AppHandle) -> Result<(), String>
.map(|server| serde_json::to_value(server).expect("Failed to serialize access_tokens")) // Automatically serialize all fields
.collect();
log::debug!("persist servers token: {:?}", &json_servers);
dbg!(format!("persist servers token: {:?}", &json_servers));
// Save the serialized servers to Tauri's store
app_handle
@@ -149,16 +143,17 @@ fn get_default_server() -> Server {
profile: None,
auth_provider: AuthProvider {
sso: Sso {
url: "https://coco.infini.cloud/sso/login/cloud?provider=coco-cloud&product=coco".to_string(),
url: "https://coco.infini.cloud/sso/login/".to_string(),
},
},
priority: 0,
stats: None,
}
}
pub async fn load_servers_token(app_handle: &AppHandle) -> Result<Vec<ServerAccessToken>, String> {
log::debug!("Attempting to load servers token");
pub async fn load_servers_token<R: Runtime>(
app_handle: &AppHandle<R>,
) -> Result<Vec<ServerAccessToken>, String> {
dbg!("Attempting to load servers token");
let store = app_handle
.store(COCO_TAURI_STORE)
@@ -177,46 +172,33 @@ pub async fn load_servers_token(app_handle: &AppHandle) -> Result<Vec<ServerAcce
servers.ok_or_else(|| "Failed to read servers from store: No servers found".to_string())?;
// Convert each item in the JsonValue array to a Server
match servers {
JsonValue::Array(servers_array) => {
let mut deserialized_tokens: Vec<ServerAccessToken> =
Vec::with_capacity(servers_array.len());
for server_json in servers_array {
match from_value(server_json.clone()) {
Ok(token) => {
deserialized_tokens.push(token);
}
Err(e) => {
panic!(
"failed to deserialize JSON [{}] to [struct ServerAccessToken], error [{}], store [{}] key [{}] is possibly corrupted!",
server_json, e, COCO_TAURI_STORE, COCO_SERVER_TOKENS
);
}
}
}
if let JsonValue::Array(servers_array) = servers {
// Deserialize each JsonValue into Server, filtering out any errors
let deserialized_tokens: Vec<ServerAccessToken> = servers_array
.into_iter()
.filter_map(|server_json| from_value(server_json).ok()) // Only keep valid Server instances
.collect();
if deserialized_tokens.is_empty() {
return Err("Failed to deserialize any servers from the store.".to_string());
}
for server in deserialized_tokens.iter() {
save_access_token(server.id.clone(), server.clone()).await;
}
log::debug!("loaded {:?} servers's token", &deserialized_tokens.len());
Ok(deserialized_tokens)
if deserialized_tokens.is_empty() {
return Err("Failed to deserialize any servers from the store.".to_string());
}
_ => {
unreachable!(
"coco server tokens should be stored in an array under store [{}] key [{}], but it is not",
COCO_TAURI_STORE, COCO_SERVER_TOKENS
);
for server in deserialized_tokens.iter() {
save_access_token(server.id.clone(), server.clone());
}
dbg!(format!(
"loaded {:?} servers's token",
&deserialized_tokens.len()
));
Ok(deserialized_tokens)
} else {
Err("Failed to read servers from store: Invalid format".to_string())
}
}
pub async fn load_servers(app_handle: &AppHandle) -> Result<Vec<Server>, String> {
pub async fn load_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<Vec<Server>, String> {
let store = app_handle
.store(COCO_TAURI_STORE)
.expect("create or load a store should not fail");
@@ -234,89 +216,91 @@ pub async fn load_servers(app_handle: &AppHandle) -> Result<Vec<Server>, String>
servers.ok_or_else(|| "Failed to read servers from store: No servers found".to_string())?;
// Convert each item in the JsonValue array to a Server
match servers {
JsonValue::Array(servers_array) => {
let mut deserialized_servers = Vec::with_capacity(servers_array.len());
for server_json in servers_array {
match from_value(server_json.clone()) {
Ok(server) => {
deserialized_servers.push(server);
}
Err(e) => {
panic!(
"failed to deserialize JSON [{}] to [struct Server], error [{}], store [{}] key [{}] is possibly corrupted!",
server_json, e, COCO_TAURI_STORE, COCO_SERVERS
);
}
}
}
if let JsonValue::Array(servers_array) = servers {
// Deserialize each JsonValue into Server, filtering out any errors
let deserialized_servers: Vec<Server> = servers_array
.into_iter()
.filter_map(|server_json| from_value(server_json).ok()) // Only keep valid Server instances
.collect();
if deserialized_servers.is_empty() {
return Err("Failed to deserialize any servers from the store.".to_string());
}
for server in deserialized_servers.iter() {
save_server(&server).await;
}
log::debug!("load servers: {:?}", &deserialized_servers);
Ok(deserialized_servers)
if deserialized_servers.is_empty() {
return Err("Failed to deserialize any servers from the store.".to_string());
}
_ => {
unreachable!(
"coco servers should be stored in an array under store [{}] key [{}], but it is not",
COCO_TAURI_STORE, COCO_SERVERS
);
for server in deserialized_servers.iter() {
save_server(&server);
}
// dbg!(format!("load servers: {:?}", &deserialized_servers));
Ok(deserialized_servers)
} else {
Err("Failed to read servers from store: Invalid format".to_string())
}
}
/// Function to load servers or insert a default one if none exist
pub async fn load_or_insert_default_server(app_handle: &AppHandle) -> Result<Vec<Server>, String> {
log::debug!("Attempting to load or insert default server");
pub async fn load_or_insert_default_server<R: Runtime>(
app_handle: &AppHandle<R>,
) -> Result<Vec<Server>, String> {
dbg!("Attempting to load or insert default server");
let exists_servers = load_servers(&app_handle).await;
if exists_servers.is_ok() && !exists_servers.as_ref()?.is_empty() {
log::debug!("loaded {} servers", &exists_servers.clone()?.len());
dbg!(format!("loaded {} servers", &exists_servers.clone()?.len()));
return exists_servers;
}
let default = get_default_server();
save_server(&default).await;
save_server(&default);
log::debug!("loaded default servers");
dbg!("loaded default servers");
Ok(vec![default])
}
#[tauri::command]
pub async fn list_coco_servers(app_handle: AppHandle) -> Result<Vec<Server>, String> {
pub async fn list_coco_servers<R: Runtime>(
_app_handle: AppHandle<R>,
) -> Result<Vec<Server>, String> {
//hard fresh all server's info, in order to get the actual health
refresh_all_coco_server_info(app_handle.clone()).await;
let servers: Vec<Server> = get_all_servers().await;
refresh_all_coco_server_info(_app_handle.clone()).await;
let servers: Vec<Server> = get_all_servers();
Ok(servers)
}
pub async fn get_all_servers() -> Vec<Server> {
let cache = SERVER_LIST_CACHE.read().await;
#[allow(dead_code)]
pub fn get_servers_as_hashmap() -> HashMap<String, Server> {
let cache = SERVER_CACHE.read().unwrap();
cache.clone()
}
pub fn get_all_servers() -> Vec<Server> {
let cache = SERVER_CACHE.read().unwrap();
cache.values().cloned().collect()
}
pub async fn refresh_all_coco_server_info(app_handle: AppHandle) {
let servers = get_all_servers().await;
/// We store added Coco servers in the Tauri store using this key.
pub const COCO_SERVERS: &str = "coco_servers";
const COCO_SERVER_TOKENS: &str = "coco_server_tokens";
pub async fn refresh_all_coco_server_info<R: Runtime>(app_handle: AppHandle<R>) {
let servers = get_all_servers();
for server in servers {
let _ = refresh_coco_server_info(app_handle.clone(), server.id.clone()).await;
}
}
#[tauri::command]
pub async fn refresh_coco_server_info(app_handle: AppHandle, id: String) -> Result<Server, String> {
pub async fn refresh_coco_server_info<R: Runtime>(
app_handle: AppHandle<R>,
id: String,
) -> Result<Server, String> {
// Retrieve the server from the cache
let cached_server = {
let cache = SERVER_LIST_CACHE.read().await;
let cache = SERVER_CACHE.read().unwrap();
cache.get(&id).cloned()
};
@@ -331,16 +315,12 @@ pub async fn refresh_coco_server_info(app_handle: AppHandle, id: String) -> Resu
let profile = server.profile;
// Send request to fetch updated server info
let response = match HttpClient::get(&id, "/provider/_info", None).await {
Ok(response) => response,
Err(e) => {
mark_server_as_offline(app_handle, &id).await;
return Err(e);
}
};
let response = HttpClient::get(&id, "/provider/_info", None)
.await
.map_err(|e| format!("Failed to contact the server: {}", e))?;
if !response.status().is_success() {
mark_server_as_offline(app_handle, &id).await;
mark_server_as_offline(&id).await;
return Err(format!("Request failed with status: {}", response.status()));
}
@@ -355,22 +335,12 @@ pub async fn refresh_coco_server_info(app_handle: AppHandle, id: String) -> Resu
updated_server.id = id.clone();
updated_server.builtin = is_builtin;
updated_server.enabled = is_enabled;
updated_server.available = {
if server.public {
// Public Coco servers are available as long as they are online.
true
} else {
// For non-public Coco servers, we still need to check if it is
// logged in, i.e., has a token stored in `SERVER_TOKEN_LIST_CACHE`.
get_server_token(&id).await.is_some()
}
};
updated_server.available = true;
updated_server.profile = profile;
trim_endpoint_last_forward_slash(&mut updated_server);
// Save and persist
save_server(&updated_server).await;
try_register_server_to_search_source(app_handle.clone(), &updated_server).await;
save_server(&updated_server);
persist_servers(&app_handle)
.await
.map_err(|e| format!("Failed to persist servers: {}", e))?;
@@ -383,18 +353,21 @@ pub async fn refresh_coco_server_info(app_handle: AppHandle, id: String) -> Resu
}
#[tauri::command]
pub async fn add_coco_server(app_handle: AppHandle, endpoint: String) -> Result<Server, String> {
pub async fn add_coco_server<R: Runtime>(
app_handle: AppHandle<R>,
endpoint: String,
) -> Result<Server, String> {
load_or_insert_default_server(&app_handle)
.await
.map_err(|e| format!("Failed to load default servers: {}", e))?;
let endpoint = endpoint.trim_end_matches('/');
if check_endpoint_exists(endpoint).await {
log::debug!(
"trying to register a Coco server [{}] that has already been registered",
endpoint
);
if check_endpoint_exists(endpoint) {
dbg!(format!(
"This Coco server has already been registered: {:?}",
&endpoint
));
return Err("This Coco server has already been registered.".into());
}
@@ -403,16 +376,7 @@ pub async fn add_coco_server(app_handle: AppHandle, endpoint: String) -> Result<
.await
.map_err(|e| format!("Failed to send request to the server: {}", e))?;
log::debug!("Get provider info response: {:?}", &response);
if response.status() != StatusCode::OK {
log::debug!(
"trying to register a Coco server [{}] that is possibly down",
endpoint
);
return Err("This Coco server is possibly down".into());
}
dbg!(format!("Get provider info response: {:?}", &response));
let body = get_response_body_text(response).await?;
@@ -421,255 +385,158 @@ pub async fn add_coco_server(app_handle: AppHandle, endpoint: String) -> Result<
trim_endpoint_last_forward_slash(&mut server);
// The JSON returned from `provider/_info` won't have this field, serde will set
// it to an empty string during deserialization, we need to set a valid value here.
if server.id.is_empty() {
server.id = pizza_common::utils::uuid::Uuid::new().to_string();
}
// Use the default name, if it is not set.
if server.name.is_empty() {
server.name = "Coco Server".to_string();
}
// Update the `available` field
if server.public {
// Serde already sets this to true, but just to make the code clear, do it again.
server.available = true;
} else {
let opt_token = get_server_token(&server.id).await;
assert!(
opt_token.is_none(),
"this Coco server is newly-added, we should have no token stored for it!"
);
// This is a non-public Coco server, and it is not logged in, so it is unavailable.
server.available = false;
}
save_server(&server).await;
save_server(&server);
try_register_server_to_search_source(app_handle.clone(), &server).await;
persist_servers(&app_handle)
.await
.map_err(|e| format!("Failed to persist Coco servers: {}", e))?;
log::debug!("Successfully registered server: {:?}", &endpoint);
dbg!(format!("Successfully registered server: {:?}", &endpoint));
Ok(server)
}
#[tauri::command]
#[function_name::named]
pub async fn remove_coco_server(app_handle: AppHandle, id: String) -> Result<(), ()> {
pub async fn remove_coco_server<R: Runtime>(
app_handle: AppHandle<R>,
id: String,
) -> Result<(), ()> {
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(id.as_str()).await;
let opt_server = remove_server_by_id(id.as_str()).await;
let Some(server) = opt_server else {
panic!(
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
function_name!(),
id
);
};
remove_server_token(id.as_str());
remove_server_by_id(id);
persist_servers(&app_handle)
.await
.expect("failed to save servers");
persist_servers_token(&app_handle).expect("failed to save server tokens");
Ok(())
}
// Only non-public Coco servers require tokens
if !server.public {
// If is logged in, clear the token as well.
let deleted = remove_server_token(id.as_str()).await;
if deleted {
persist_servers_token(&app_handle)
.await
.expect("failed to save server tokens");
}
#[tauri::command]
pub async fn enable_server<R: Runtime>(app_handle: AppHandle<R>, id: String) -> Result<(), ()> {
println!("enable_server: {}", id);
let server = get_server_by_id(id.as_str());
if let Some(mut server) = server {
server.enabled = true;
save_server(&server);
// Register the server to the search source
try_register_server_to_search_source(app_handle.clone(), &server).await;
persist_servers(&app_handle)
.await
.expect("failed to save servers");
}
Ok(())
}
#[tauri::command]
#[function_name::named]
pub async fn enable_server(app_handle: AppHandle, id: String) -> Result<(), ()> {
let opt_server = get_server_by_id(id.as_str()).await;
let Some(mut server) = opt_server else {
panic!(
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
function_name!(),
id
);
};
server.enabled = true;
save_server(&server).await;
// Register the server to the search source
try_register_server_to_search_source(app_handle.clone(), &server).await;
persist_servers(&app_handle)
.await
.expect("failed to save servers");
Ok(())
}
#[tauri::command]
#[function_name::named]
pub async fn disable_server(app_handle: AppHandle, id: String) -> Result<(), ()> {
let opt_server = get_server_by_id(id.as_str()).await;
let Some(mut server) = opt_server else {
panic!(
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
function_name!(),
id
);
};
server.enabled = false;
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(id.as_str()).await;
save_server(&server).await;
persist_servers(&app_handle)
.await
.expect("failed to save servers");
Ok(())
}
/// For non-public Coco servers, we add it to the search source as long as it is
/// enabled.
///
/// For public Coco server, an extra token is required.
pub async fn try_register_server_to_search_source(app_handle: AppHandle, server: &Server) {
pub async fn try_register_server_to_search_source(
app_handle: AppHandle<impl Runtime>,
server: &Server,
) {
if server.enabled {
log::trace!(
"Server [name: {}, id: {}] is public: {} and available: {}",
&server.name,
&server.id,
&server.public,
&server.available
);
if !server.public {
let opt_token = get_server_token(&server.id).await;
if opt_token.is_none() {
log::debug!("Server {} is not public and no token was found", &server.id);
return;
}
}
let registry = app_handle.state::<SearchSourceRegistry>();
let source = CocoSearchSource::new(server.clone());
registry.register_source(source).await;
}
}
#[function_name::named]
#[allow(unused)]
async fn mark_server_as_online(app_handle: AppHandle, id: &str) {
let server = get_server_by_id(id).await;
if let Some(mut server) = server {
server.available = true;
server.health = None;
save_server(&server).await;
try_register_server_to_search_source(app_handle.clone(), &server).await;
} else {
log::warn!(
"[{}()] invoked with a server [{}] that does not exist!",
function_name!(),
id
);
}
}
#[function_name::named]
pub(crate) async fn mark_server_as_offline(app_handle: AppHandle, id: &str) {
let server = get_server_by_id(id).await;
pub async fn mark_server_as_offline(id: &str) {
// println!("server_is_offline: {}", id);
let server = get_server_by_id(id);
if let Some(mut server) = server {
server.available = false;
server.health = None;
save_server(&server).await;
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(id).await;
} else {
log::warn!(
"[{}()] invoked with a server [{}] that does not exist!",
function_name!(),
id
);
save_server(&server);
}
}
#[tauri::command]
#[function_name::named]
pub async fn logout_coco_server(app_handle: AppHandle, id: String) -> Result<(), String> {
log::debug!("Attempting to log out server by id: {}", &id);
pub async fn disable_server<R: Runtime>(app_handle: AppHandle<R>, id: String) -> Result<(), ()> {
println!("disable_server: {}", id);
// Check if the server exists
let Some(mut server) = get_server_by_id(id.as_str()).await else {
panic!(
"[{}()] invoked with a server [{}] that does not exist! Mismatched states between frontend and backend!",
function_name!(),
id
);
};
let server = get_server_by_id(id.as_str());
if let Some(mut server) = server {
server.enabled = false;
// Clear server profile
server.profile = None;
// Logging out from a non-public Coco server makes it unavailable
if !server.public {
server.available = false;
}
// Save the updated server data
save_server(&server).await;
// Persist the updated server data
if let Err(e) = persist_servers(&app_handle).await {
log::debug!("Failed to save server for id: {}. Error: {:?}", &id, &e);
return Err(format!("Failed to save server: {}", &e));
}
let has_token = get_server_token(id.as_str()).await.is_some();
if server.public {
if has_token {
panic!("Public Coco server won't have token")
}
} else {
assert!(
has_token,
"This is a non-public Coco server, and it is logged in, we should have a token"
);
// Remove the server token from cache
remove_server_token(id.as_str()).await;
// Persist the updated tokens
if let Err(e) = persist_servers_token(&app_handle).await {
log::debug!("Failed to save tokens for id: {}. Error: {:?}", &id, &e);
return Err(format!("Failed to save tokens: {}", &e));
}
}
// Remove it from the search source if it becomes unavailable
if !server.available {
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(id.as_str()).await;
}
log::debug!("Successfully logged out server with id: {}", &id);
save_server(&server);
persist_servers(&app_handle)
.await
.expect("failed to save servers");
}
Ok(())
}
/// Helper function to remove the trailing slash from the server's endpoint if present.
#[tauri::command]
pub async fn logout_coco_server<R: Runtime>(
app_handle: AppHandle<R>,
id: String,
) -> Result<(), String> {
dbg!("Attempting to log out server by id:", &id);
// Check if server token exists
if let Some(_token) = get_server_token(id.as_str()).await? {
dbg!("Found server token for id:", &id);
// Remove the server token from cache
remove_server_token(id.as_str());
// Persist the updated tokens
if let Err(e) = persist_servers_token(&app_handle) {
dbg!("Failed to save tokens for id: {}. Error: {:?}", &id, &e);
return Err(format!("Failed to save tokens: {}", &e));
}
} else {
// Log the case where server token is not found
dbg!("No server token found for id: {}", &id);
}
// Check if the server exists
if let Some(mut server) = get_server_by_id(id.as_str()) {
dbg!("Found server for id:", &id);
// Clear server profile
server.profile = None;
// Save the updated server data
save_server(&server);
// Persist the updated server data
if let Err(e) = persist_servers(&app_handle).await {
dbg!("Failed to save server for id: {}. Error: {:?}", &id, &e);
return Err(format!("Failed to save server: {}", &e));
}
} else {
// Log the case where server is not found
dbg!("No server found for id: {}", &id);
return Err(format!("No server found for id: {}", id));
}
dbg!("Successfully logged out server with id:", &id);
Ok(())
}
/// Removes the trailing slash from the server's endpoint if present.
fn trim_endpoint_last_forward_slash(server: &mut Server) {
let endpoint = &mut server.endpoint;
while endpoint.ends_with('/') {
endpoint.pop();
if server.endpoint.ends_with('/') {
server.endpoint.pop(); // Remove the last character
while server.endpoint.ends_with('/') {
server.endpoint.pop();
}
}
}
@@ -678,47 +545,41 @@ fn provider_info_url(endpoint: &str) -> String {
format!("{endpoint}/provider/_info")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_trim_endpoint_last_forward_slash() {
let mut server = Server {
id: "test".to_string(),
builtin: false,
enabled: true,
#[test]
fn test_trim_endpoint_last_forward_slash() {
let mut server = Server {
id: "test".to_string(),
builtin: false,
enabled: true,
name: "".to_string(),
endpoint: "https://example.com///".to_string(),
provider: Provider {
name: "".to_string(),
endpoint: "https://example.com///".to_string(),
provider: Provider {
name: "".to_string(),
icon: "".to_string(),
website: "".to_string(),
eula: "".to_string(),
privacy_policy: "".to_string(),
banner: "".to_string(),
description: "".to_string(),
icon: "".to_string(),
website: "".to_string(),
eula: "".to_string(),
privacy_policy: "".to_string(),
banner: "".to_string(),
description: "".to_string(),
},
version: Version {
number: "".to_string(),
},
minimal_client_version: None,
updated: "".to_string(),
public: false,
available: false,
health: None,
profile: None,
auth_provider: AuthProvider {
sso: Sso {
url: "".to_string(),
},
version: Version {
number: "".to_string(),
},
minimal_client_version: None,
updated: "".to_string(),
public: false,
available: false,
health: None,
profile: None,
auth_provider: AuthProvider {
sso: Sso {
url: "".to_string(),
},
},
priority: 0,
stats: None,
};
},
priority: 0,
};
trim_endpoint_last_forward_slash(&mut server);
trim_endpoint_last_forward_slash(&mut server);
assert_eq!(server.endpoint, "https://example.com");
}
assert_eq!(server.endpoint, "https://example.com");
}

View File

@@ -1,57 +0,0 @@
use crate::server::http_client::HttpClient;
use futures_util::StreamExt;
use http::Method;
use serde_json::json;
use tauri::{AppHandle, Emitter, command};
#[command]
pub async fn synthesize(
app_handle: AppHandle,
client_id: String,
server_id: String,
voice: String,
content: String,
) -> Result<(), String> {
let body = json!({
"voice": voice,
"content": content,
})
.to_string();
let response = HttpClient::send_request(
server_id.as_str(),
Method::POST,
"/services/audio/synthesize",
None,
None,
Some(reqwest::Body::from(body.to_string())),
)
.await?;
log::info!("Synthesize response status: {}", response.status());
if response.status() == 429 {
return Ok(());
}
if !response.status().is_success() {
return Err(format!("Request Failed: {}", response.status()));
}
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
if let Err(err) = app_handle.emit(&client_id, bytes.to_vec()) {
log::error!("Emit error: {:?}", err);
}
}
Err(e) => {
log::error!("Stream error: {:?}", e);
break;
}
}
}
Ok(())
}

View File

@@ -1,96 +1,43 @@
use crate::common::http::get_response_body_text;
use crate::server::http_client::HttpClient;
use serde::{Deserialize, Serialize};
use serde_json::{Value, from_str};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use tauri::command;
#[derive(Debug, Serialize, Deserialize)]
pub struct TranscriptionResponse {
task_id: String,
results: Vec<Value>,
pub text: String,
}
#[command]
pub async fn transcription(
server_id: String,
audio_type: String,
audio_content: String,
) -> Result<TranscriptionResponse, String> {
// Send request to initiate transcription task
let init_response = HttpClient::post(
let mut query_params = HashMap::new();
query_params.insert("type".to_string(), JsonValue::String(audio_type));
query_params.insert("content".to_string(), JsonValue::String(audio_content));
// Send the HTTP POST request
let response = HttpClient::post(
&server_id,
"/services/audio/transcription",
Some(query_params),
None,
Some(audio_content.into()),
)
.await
.map_err(|e| format!("Failed to initiate transcription: {}", e))?;
// Extract response body as text
let init_response_text = get_response_body_text(init_response)
.await
.map_err(|e| format!("Failed to read initial response body: {}", e))?;
.map_err(|e| format!("Error sending transcription request: {}", e))?;
// Parse response JSON to extract task ID
let init_response_json: Value = from_str(&init_response_text).map_err(|e| {
format!(
"Failed to parse initial response JSON: {}. Raw response: {}",
e, init_response_text
)
})?;
let transcription_task_id = init_response_json["task_id"]
.as_str()
.ok_or_else(|| {
format!(
"Missing or invalid task_id in initial response: {}",
init_response_text
)
})?
.to_string();
// Set up polling with timeout
let polling_start = std::time::Instant::now();
const POLLING_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
const POLLING_INTERVAL: std::time::Duration = std::time::Duration::from_millis(200);
let mut transcription_response: TranscriptionResponse;
loop {
// Poll for transcription results
let poll_response = HttpClient::get(
&server_id,
&format!("/services/audio/task/{}", transcription_task_id),
None,
)
// Use get_response_body_text to extract the response body as text
let response_body = get_response_body_text(response)
.await
.map_err(|e| format!("Failed to poll transcription task: {}", e))?;
.map_err(|e| format!("Failed to read response body: {}", e))?;
// Extract poll response body
let poll_response_text = get_response_body_text(poll_response)
.await
.map_err(|e| format!("Failed to read poll response body: {}", e))?;
// Parse poll response JSON
transcription_response = from_str(&poll_response_text).map_err(|e| {
format!(
"Failed to parse poll response JSON: {}. Raw response: {}",
e, poll_response_text
)
})?;
// Check if transcription results are available
if !transcription_response.results.is_empty() {
break;
}
// Check for timeout
if polling_start.elapsed() >= POLLING_TIMEOUT {
return Err("Transcription task timed out after 30 seconds".to_string());
}
// Wait before next poll
tokio::time::sleep(POLLING_INTERVAL).await;
}
// Deserialize the response body into TranscriptionResponse
let transcription_response: TranscriptionResponse = serde_json::from_str(&response_body)
.map_err(|e| format!("Failed to parse transcription response: {}", e))?;
Ok(transcription_response)
}

View File

@@ -0,0 +1,168 @@
use crate::server::servers::{get_server_by_id, get_server_token};
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tauri::{AppHandle, Emitter, Runtime};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::handshake::client::generate_key;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::{connect_async_tls_with_config, Connector};
#[derive(Default)]
pub struct WebSocketManager {
connections: Arc<Mutex<HashMap<String, Arc<WebSocketInstance>>>>,
}
struct WebSocketInstance {
ws_connection: Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>, // No need to lock the entire map
cancel_tx: mpsc::Sender<()>,
}
fn convert_to_websocket(endpoint: &str) -> Result<String, String> {
let url = url::Url::parse(endpoint).map_err(|e| format!("Invalid URL: {}", e))?;
let ws_protocol = if url.scheme() == "https" {
"wss://"
} else {
"ws://"
};
let host = url.host_str().ok_or("No host found in URL")?;
let port = url
.port_or_known_default()
.unwrap_or(if url.scheme() == "https" { 443 } else { 80 });
let ws_endpoint = if port == 80 || port == 443 {
format!("{}{}{}", ws_protocol, host, "/ws")
} else {
format!("{}{}:{}/ws", ws_protocol, host, port)
};
Ok(ws_endpoint)
}
#[tauri::command]
pub async fn connect_to_server<R: Runtime>(
tauri_app_handle: AppHandle<R>,
id: String,
client_id: String,
state: tauri::State<'_, WebSocketManager>,
app_handle: AppHandle,
) -> Result<(), String> {
let connections_clone = state.connections.clone();
// Disconnect old connection first
disconnect(client_id.clone(), state.clone()).await.ok();
let server = get_server_by_id(&id).ok_or(format!("Server with ID {} not found", id))?;
let endpoint = convert_to_websocket(&server.endpoint)?;
let token = get_server_token(&id).await?.map(|t| t.access_token.clone());
let mut request =
tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(&endpoint)
.map_err(|e| format!("Failed to create WebSocket request: {}", e))?;
request
.headers_mut()
.insert("Connection", "Upgrade".parse().unwrap());
request
.headers_mut()
.insert("Upgrade", "websocket".parse().unwrap());
request
.headers_mut()
.insert("Sec-WebSocket-Version", "13".parse().unwrap());
request
.headers_mut()
.insert("Sec-WebSocket-Key", generate_key().parse().unwrap());
if let Some(token) = token {
request
.headers_mut()
.insert("X-API-TOKEN", token.parse().unwrap());
}
let allow_self_signature =
crate::settings::get_allow_self_signature(tauri_app_handle.clone()).await;
let tls_connector = tokio_native_tls::native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(allow_self_signature)
.build()
.map_err(|e| format!("TLS build error: {:?}", e))?;
let connector = Connector::NativeTls(tls_connector.into());
let (ws_stream, _) = connect_async_tls_with_config(
request,
None, // WebSocketConfig
true, // disable_nagle
Some(connector), // Connector
)
.await
.map_err(|e| format!("WebSocket TLS error: {:?}", e))?;
let (cancel_tx, mut cancel_rx) = mpsc::channel(1);
let instance = Arc::new(WebSocketInstance {
ws_connection: Mutex::new(ws_stream),
cancel_tx,
});
// Insert connection into the map (lock is held briefly)
{
let mut connections = connections_clone.lock().await;
connections.insert(client_id.clone(), instance.clone());
}
// Spawn WebSocket handler in a separate task
let app_handle_clone = app_handle.clone();
let client_id_clone = client_id.clone();
tokio::spawn(async move {
let ws = &mut *instance.ws_connection.lock().await;
loop {
tokio::select! {
msg = ws.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
let _ = app_handle_clone.emit(&format!("ws-message-{}", client_id_clone), text);
},
Some(Err(_)) | None => {
let _ = app_handle_clone.emit(&format!("ws-error-{}", client_id_clone), id.clone());
break;
}
_ => {}
}
}
_ = cancel_rx.recv() => {
let _ = app_handle_clone.emit(&format!("ws-error-{}", client_id_clone), id.clone());
break;
}
}
}
// Remove connection after it closes
let mut connections = connections_clone.lock().await;
connections.remove(&client_id_clone);
});
Ok(())
}
#[tauri::command]
pub async fn disconnect(
client_id: String,
state: tauri::State<'_, WebSocketManager>,
) -> Result<(), String> {
let instance = {
let mut connections = state.connections.lock().await;
connections.remove(&client_id)
};
if let Some(instance) = instance {
let _ = instance.cancel_tx.send(()).await;
// Close WebSocket (lock only the connection, not the whole map)
let mut ws = instance.ws_connection.lock().await;
let _ = ws.close(None).await;
}
Ok(())
}

View File

@@ -1,13 +1,12 @@
use crate::COCO_TAURI_STORE;
use serde_json::Value as Json;
use tauri::AppHandle;
use tauri::{AppHandle, Runtime};
use tauri_plugin_store::StoreExt;
const SETTINGS_ALLOW_SELF_SIGNATURE: &str = "settings_allow_self_signature";
const LOCAL_QUERY_SOURCE_WEIGHT: &str = "local_query_source_weight";
#[tauri::command]
pub async fn set_allow_self_signature(tauri_app_handle: AppHandle, value: bool) {
pub async fn set_allow_self_signature<R: Runtime>(tauri_app_handle: AppHandle<R>, value: bool) {
use crate::server::http_client;
let store = tauri_app_handle
@@ -41,7 +40,7 @@ pub async fn set_allow_self_signature(tauri_app_handle: AppHandle, value: bool)
}
/// Synchronous version of `async get_allow_self_signature()`.
pub fn _get_allow_self_signature(tauri_app_handle: AppHandle) -> bool {
pub fn _get_allow_self_signature<R: Runtime>(tauri_app_handle: AppHandle<R>) -> bool {
let store = tauri_app_handle
.store(COCO_TAURI_STORE)
.unwrap_or_else(|e| {
@@ -68,48 +67,6 @@ pub fn _get_allow_self_signature(tauri_app_handle: AppHandle) -> bool {
}
#[tauri::command]
pub async fn get_allow_self_signature(tauri_app_handle: AppHandle) -> bool {
pub async fn get_allow_self_signature<R: Runtime>(tauri_app_handle: AppHandle<R>) -> bool {
_get_allow_self_signature(tauri_app_handle)
}
#[tauri::command]
pub async fn set_local_query_source_weight(tauri_app_handle: AppHandle, value: f64) {
let store = tauri_app_handle
.store(COCO_TAURI_STORE)
.unwrap_or_else(|e| {
panic!(
"store [{}] not found/loaded, error [{}]",
COCO_TAURI_STORE, e
)
});
store.set(LOCAL_QUERY_SOURCE_WEIGHT, value);
}
#[tauri::command]
pub fn get_local_query_source_weight(tauri_app_handle: AppHandle) -> f64 {
// default to 1.0
const DEFAULT: f64 = 1.0;
let store = tauri_app_handle
.store(COCO_TAURI_STORE)
.unwrap_or_else(|e| {
panic!(
"store [{}] not found/loaded, error [{}]",
COCO_TAURI_STORE, e
)
});
if !store.has(LOCAL_QUERY_SOURCE_WEIGHT) {
store.set(LOCAL_QUERY_SOURCE_WEIGHT, DEFAULT);
}
match store
.get(LOCAL_QUERY_SOURCE_WEIGHT)
.expect("should be Some")
{
Json::Number(n) => n
.as_f64()
.unwrap_or_else(|| panic!("setting [{}] should be a f64", LOCAL_QUERY_SOURCE_WEIGHT)),
_ => unreachable!("{} should be stored as a number", LOCAL_QUERY_SOURCE_WEIGHT),
}
}

View File

@@ -1,9 +1,3 @@
use tauri::{AppHandle, WebviewWindow};
use tauri::{App, WebviewWindow};
pub fn platform(
_tauri_app_handle: &AppHandle,
_main_window: WebviewWindow,
_settings_window: WebviewWindow,
_check_window: WebviewWindow,
) {
}
pub fn platform(_app: &mut App, _main_window: WebviewWindow, _settings_window: WebviewWindow) {}

Some files were not shown because too many files have changed in this diff Show More