2 Commits

Author | SHA1       | Message                                                   | Date
ayang  | 9ee6b9a6c9 | feat: add file upload failure handling and alert message | 2025-05-16 14:32:11 +08:00
ayang  | 24b1758b11 | refactor: enabling the InputExtra component               | 2025-05-15 15:50:03 +08:00

257 changed files with 6641 additions and 17631 deletions


@@ -1,18 +0,0 @@
name: Enforce no dependency pizza-engine
on:
pull_request:
jobs:
main:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name:
working-directory: ./src-tauri
run: |
# if cargo remove pizza-engine succeeds, then it is in our dependency list, fail the CI pipeline.
if cargo remove pizza-engine; then exit 1; fi


@@ -9,16 +9,10 @@ on:
jobs:
create-release:
runs-on: ubuntu-latest
outputs:
APP_VERSION: ${{ steps.get-version.outputs.APP_VERSION }}
RELEASE_BODY: ${{ steps.get-changelog.outputs.RELEASE_BODY }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set output
id: vars
run: echo "tag=${GITHUB_REF#refs/*/}" >> $GITHUB_OUTPUT
@@ -28,28 +22,11 @@ jobs:
with:
node-version: 20
- name: Get build version
shell: bash
id: get-version
run: |
PACKAGE_VERSION=$(jq -r '.version' package.json)
CARGO_VERSION=$(grep -m 1 '^version =' src-tauri/Cargo.toml | sed -E 's/.*"([^"]+)".*/\1/')
if [ "$PACKAGE_VERSION" != "$CARGO_VERSION" ]; then
echo "::error::Version mismatch!"
else
echo "Version match: $PACKAGE_VERSION"
fi
echo "APP_VERSION=$PACKAGE_VERSION" >> $GITHUB_OUTPUT
- name: Generate changelog
id: get-changelog
run: |
CHANGELOG_BODY=$(npx changelogithub --draft --name ${{ steps.vars.outputs.tag }})
echo "RELEASE_BODY<<EOF" >> $GITHUB_OUTPUT
echo "$CHANGELOG_BODY" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
id: create_release
run: npx changelogithub --draft --name ${{ steps.vars.outputs.tag }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
build-app:
needs: create-release
@@ -75,23 +52,11 @@ jobs:
target: "x86_64-unknown-linux-gnu"
- platform: "ubuntu-22.04-arm"
target: "aarch64-unknown-linux-gnu"
env:
APP_VERSION: ${{ needs.create-release.outputs.APP_VERSION }}
runs-on: ${{ matrix.platform }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Checkout dependency repository
uses: actions/checkout@v4
with:
repository: 'infinilabs/pizza'
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
submodules: recursive
ref: main
path: pizza
- name: Setup node
uses: actions/setup-node@v4
with:
@@ -100,31 +65,17 @@ jobs:
with:
version: latest
- name: Install rust target
run: rustup target add ${{ matrix.target }}
- name: Install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-22.04')
run: |
sudo apt-get update
sudo apt-get install -y libwebkit2gtk-4.1-dev libappindicator3-dev librsvg2-dev patchelf xdg-utils
- name: Add Rust build target
working-directory: src-tauri
shell: bash
run: |
rustup target add ${{ matrix.target }} || true
- name: Add pizza engine as a dependency
working-directory: src-tauri
shell: bash
run: |
BUILD_ARGS="--target ${{ matrix.target }}"
if [[ "${{matrix.target }}" != "i686-pc-windows-msvc" ]]; then
echo "Adding pizza engine as a dependency for ${{matrix.platform }}-${{matrix.target }}"
( cargo add --path ../pizza/lib/engine --features query_string_parser,persistence )
BUILD_ARGS+=" --features use_pizza_engine"
else
echo "Skipping pizza engine dependency for ${{matrix.platform }}-${{matrix.target }}"
fi
echo "BUILD_ARGS=${BUILD_ARGS}" >> $GITHUB_ENV
- name: Install Rust stable
run: rustup toolchain install stable
- name: Rust cache
uses: swatinem/rust-cache@v2
@@ -139,8 +90,8 @@ jobs:
- name: Install app dependencies and build web
run: pnpm install --frozen-lockfile
- name: Build the coco at ${{ matrix.platform}} for ${{ matrix.target }} @ ${{ env.APP_VERSION }}
- name: Build the app
uses: tauri-apps/tauri-action@v0
env:
CI: false
@@ -156,8 +107,8 @@ jobs:
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
with:
tagName: ${{ github.ref_name }}
releaseName: Coco ${{ env.APP_VERSION }}
releaseBody: "${{ needs.create-release.outputs.RELEASE_BODY }}"
releaseName: Coco ${{ needs.create-release.outputs.APP_VERSION }}
releaseBody: ""
releaseDraft: true
prerelease: false
args: ${{ env.BUILD_ARGS }}
args: --target ${{ matrix.target }}


@@ -1,54 +0,0 @@
name: Rust Code Compile Check
on:
pull_request:
# Only run it when Rust code changes
paths:
- 'src-tauri/**'
jobs:
compile-check:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Checkout dependency (pizza-engine) repository
uses: actions/checkout@v4
with:
repository: 'infinilabs/pizza'
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
submodules: recursive
ref: main
path: pizza
- name: Install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-latest')
run: |
sudo apt-get update
sudo apt-get install -y libwebkit2gtk-4.1-dev libappindicator3-dev librsvg2-dev patchelf xdg-utils
- name: Add pizza engine as a dependency
working-directory: src-tauri
shell: bash
run: cargo add --path ../pizza/lib/engine --features query_string_parser,persistence
- name: Check compilation (Without Pizza engine enabled)
working-directory: ./src-tauri
run: cargo check
- name: Check compilation (With Pizza engine enabled)
working-directory: ./src-tauri
run: cargo check --features use_pizza_engine
- name: Run tests (Without Pizza engine)
working-directory: ./src-tauri
run: cargo test
- name: Run tests (With Pizza engine)
working-directory: ./src-tauri
run: cargo test --features use_pizza_engine


@@ -8,14 +8,11 @@
"clsx",
"codegen",
"dataurl",
"deeplink",
"deepthink",
"dtolnay",
"dyld",
"elif",
"errmsg",
"fullscreen",
"fulltext",
"headlessui",
"Icdbb",
"icns",
@@ -32,8 +29,6 @@
"localstorage",
"lucide",
"maximizable",
"mdast",
"meval",
"Minimizable",
"msvc",
"nord",
@@ -43,11 +38,9 @@
"overscan",
"partialize",
"patchelf",
"Quicklink",
"Raycast",
"rehype",
"reqwest",
"rerank",
"rgba",
"rustup",
"screenshotable",


@@ -78,8 +78,4 @@ clean-rebuild:
$(MAKE) dev-build
add-dep-pizza-engine:
cd src-tauri && cargo add --git ssh://git@github.com/infinilabs/pizza.git pizza-engine --features query_string_parser,persistence
dev-build-with-pizza: add-dep-pizza-engine
@echo "Starting desktop development with Pizza Engine pulled in..."
RUST_BACKTRACE=1 pnpm tauri dev --features use_pizza_engine
cd src-tauri && cargo add --git ssh://git@github.com/infinilabs/pizza.git pizza-engine --features query_string_parser,persistence


@@ -91,8 +91,6 @@ pnpm tauri build
- [Coco App Documentation](https://docs.infinilabs.com/coco-app/main/)
- [Coco Server Documentation](https://docs.infinilabs.com/coco-server/main/)
- [DeepWiki Coco App](https://deepwiki.com/infinilabs/coco-app)
- [DeepWiki Coco Server](https://deepwiki.com/infinilabs/coco-server)
- [Tauri Documentation](https://tauri.app/)
## Contributors


@@ -1,35 +1,21 @@
---
weight: 10
title: "macOS"
title: "Mac OS"
asciinema: true
---
# macOS
# Mac OS
## Download Coco AI
Go to [coco.rs](https://coco.rs/) and download the package of your architecture:
Goto [https://coco.rs/](https://coco.rs/)
{{% load-img "/img/macos/mac-download-app.png" "" %}}
It should be placed in your "Downloads" folder:
{{% load-img "/img/macos/mac-zip-file.png" "" %}}
{{% load-img "/img/download-mac-app.png" "" %}}
## Unzip DMG file
Unzip the file:
{{% load-img "/img/macos/mac-unzip-zip-file.png" "" %}}
You will get a `dmg` file:
{{% load-img "/img/macos/mac-dmg.png" "" %}}
{{% load-img "/img/unzip-dmg-file.png" "" %}}
## Drag to Application Folder
Double click the `dmg` file, a window will pop up. Then drag the "Coco-AI" app to
your "Applications" folder:
{{% load-img "/img/macos/drag-to-app-folder.png" "" %}}
{{% load-img "/img/drag-to-application-folder.png" "" %}}


@@ -14,9 +14,7 @@ asciinema: true
[if_x11]: https://unix.stackexchange.com/q/202891/498440
## Go to the download page
Download page: [link](https://coco.rs/#install)
## Goto [https://coco.rs/](https://coco.rs/)
## Download the package


@@ -13,79 +13,6 @@ Information about release notes of Coco Server is provided here.
### 🚀 Features
- feat: file search using spotlight #705
- feat: voice input support in both search and chat modes #732
- feat: text to speech now powered by LLM #750
- feat: file search for Windows #762
### 🐛 Bug fix
- fix(file search): apply filters before from/size parameters #741
- fix(file search): searching by name&content does not search file name #743
- fix: prevent window from hiding when moved on Windows #748
- fix: unregister ext hotkey when it gets deleted #770
- fix: indexing apps does not respect search scope config #773
- fix: restore missing category titles on subpages #772
- fix: correct incorrect assistant display when quick ai access #779
- fix: resolved minor issues with voice playback #780
- fix: fixed incorrect taskbar icon display on linux #783
- fix: fix data inconsistency issue on secondary pages #784
### ✈️ Improvements
- refactor: prioritize stat(2) when checking if a file is dir #737
- refactor: change File Search ext type to extension #738
- refactor: create chat & send chat api #739
- chore: icon support for more file types #740
- chore: replace meval-rs with our fork to clear dep warning #745
- refactor: adjusted assistant, datasource, mcp_server interface parameters #746
- refactor: adjust extension code hierarchy #747
- chore: bump dep applications-rs #751
- chore: rename QuickLink/quick_link to Quicklink/quicklink #752
- chore: assistant params & styles #753
- chore: make optional fields optional #758
- chore: search-chat components add formatUrl & think data & icons url #765
- chore: Coco app http request headers #744
- refactor: do status code check before deserializing response #767
- style: splash adapts to the width of mobile phones #768
- chore: search-chat add language and formatUrl parameters #775
## 0.6.0 (2025-06-29)
### ❌ Breaking changes
### 🚀 Features
- feat: support `Tab` and `Enter` for delete dialog buttons #700
- feat: add check for updates #701
- feat: impl extension store #699
- feat: support back navigation via delete key #717
### 🐛 Bug fix
- fix: quick ai state synchronous #693
- fix: toggle extension should register/unregister hotkey #691
- fix: take coco server back on refresh #696
- fix: some input fields couldnt accept spaces #709
- fix: context menu search not working #713
- fix: open extension store display #724
### ✈️ Improvements
- refactor: use author/ext_id as extension unique identifier #643
- refactor: refactoring search api #679
- chore: continue to chat page display #690
- chore: improve server list selection with enter key #692
- chore: add message for latest version check #703
- chore: log command execution results #718
- chore: adjust styles and add button reindex #719
## 0.5.0 (2025-06-13)
### ❌ Breaking changes
### 🚀 Features
- feat: check or enter to close the list of assistants #469
- feat: add dimness settings for pinned window #470
- feat: supports Shift + Enter input box line feeds #472
@@ -97,59 +24,17 @@ Information about release notes of Coco Server is provided here.
- feat: the search input box supports multi-line input #501
- feat: websocket support self-signed TLS #504
- feat: add option to allow self-signed certificates #509
- feat: add AI summary component #518
- feat: dynamic log level via env var COCO_LOG #535
- feat: add quick AI access to search mode #556
- feat: rerank search results #561
- feat: ai overview support is enabled with shortcut #597
- feat: add key monitoring during reset #615
- feat: calculator extension add description #623
- feat: support right-click actions after text selection #624
- feat: add ai overview minimum number of search results configuration #625
- feat: add internationalized translations of AI-related extensions #632
- feat: context menu support for secondary pages #680
### 🐛 Bug fix
- fix: solve the problem of modifying the assistant in the chat #476
- fix: several issues around search #502
- fix: fixed the newly created session has no title when it is deleted #511
- fix: loading chat history for potential empty attachments
- fix: datasource & MCP list synchronization update #521
- fix: app icon & category icon #529
- fix: show only enabled datasource & MCP list
- fix: server image loading failure #534
- fix: panic when fetching app metadata on Windows #538
- fix: service switching error #539
- fix: switch server assistant and session unchanged #540
- fix: history list height #550
- fix: secondary page cannot be searched #551
- fix: the scroll button is not displayed by default #552
- fix: suggestion list position #553
- fix: independent chat window has no data #554
- fix: resolved navigation error on continue chat action #558
- fix: make extension search source respect parameter datasource #576
- fix: fixed issue with incorrect login status #600
- fix: new chat assistant id not found #603
- fix: resolve regex error on older macOS versions #605
- fix: fix chat log update and sorting issues #612
- fix: resolved an issue where number keys were not working on the web #616
- fix: do not panic when the datasource specified does not exist #618
- fix: fixed modifier keys not working with continue chat #619
- fix: invalid DSL error if input contains multiple lines #620
- fix: fix ai overview hidden height before message #622
- fix: tab key hides window in chat mode #641
- fix: arrow keys still navigated search when menu opened with Cmd+K #642
- fix: input lost when reopening dialog after search #644
- fix: web page unmount event #645
- fix: fix the problem of local path not opening #650
- fix: number keys not following settings #661
- fix: fix problem with up and down key indexing #676
- fix: arrow inserting escape sequences #683
### ✈️ Improvements
- chore: adjust list error message #475
- fix: solve the problem of modifying the assistant in the chat #476
- chore: refine wording on search failure
- chore: search and MCP show hidden logic #494
- chore: greetings show hidden logic #496
@@ -160,32 +45,6 @@ Information about release notes of Coco Server is provided here.
- refactor: optimized the modification operation of the numeric input box #508
- style: modify the style of the search input box #513
- style: chat input icons show #515
- refactor: refactoring icon component #514
- refactor: optimizing list styles in markdown content #520
- feat: add a component for text reading aloud #522
- style: history component styles #528
- style: search error styles #533
- chore: skip register server that not logged in #536
- refactor: service info related components #537
- chore: chat content can be copied #539
- refactor: refactoring search error #541
- chore: add assistant count #542
- chore: add global login judgment #544
- chore: mark server offline on user logout #546
- chore: logout update server profile #549
- chore: assistant keyboard events and mouse events #559
- chore: web component start page config #560
- chore: assistant chat placeholder & refactor input box components #566
- refactor: input box related components #568
- chore: mark unavailable server to offline on refresh info #569
- chore: only show available servers in chat #570
- refactor: search result related components #571
- chore: initialize current assistant from history #606
- chore: add onContextMenu event #629
- chore: more logs for the setup process #634
- chore: copy supports http protocol #639
- refactor: use author/ext_id as extension unique identifier #643
- chore: add special character filtering #668
## 0.4.0 (2025-04-27)
@@ -215,8 +74,6 @@ Information about release notes of Coco Server is provided here.
- feat: data sources support displaying customized icons #432
- feat: add shortcut key conflict hint and reset function #442
- feat: updated to include error message #465
- feat: support third party extensions #572
- feat: support ai overview #572
### Bug fix

Binary image changes (files not shown): docs/static/img/download-mac-app.png and docs/static/img/unzip-dmg-file.png added; in total three images added (155 KiB, 69 KiB, 121 KiB) and five removed (239 KiB, 586 KiB, 299 KiB, 650 KiB, 441 KiB).


@@ -1,7 +1,7 @@
{
"name": "coco",
"private": true,
"version": "0.6.0",
"version": "0.4.0",
"type": "module",
"scripts": {
"dev": "vite",
@@ -18,6 +18,7 @@
"release-beta": "release-it --preRelease=beta --preReleaseBase=1"
},
"dependencies": {
"@ant-design/icons": "^6.0.0",
"@headlessui/react": "^2.2.2",
"@tauri-apps/api": "^2.5.0",
"@tauri-apps/plugin-autostart": "~2.2.0",
@@ -26,7 +27,6 @@
"@tauri-apps/plugin-global-shortcut": "~2.0.0",
"@tauri-apps/plugin-http": "~2.0.2",
"@tauri-apps/plugin-log": "~2.4.0",
"@tauri-apps/plugin-opener": "^2.2.7",
"@tauri-apps/plugin-os": "^2.2.1",
"@tauri-apps/plugin-process": "^2.2.1",
"@tauri-apps/plugin-shell": "^2.2.1",
@@ -44,7 +44,6 @@
"i18next-browser-languagedetector": "^8.1.0",
"lodash-es": "^4.17.21",
"lucide-react": "^0.461.0",
"mdast-util-gfm-autolink-literal": "2.0.0",
"mermaid": "^11.6.0",
"nanoid": "^5.1.5",
"react": "^18.3.1",
@@ -59,12 +58,10 @@
"remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.1",
"remark-math": "^6.0.0",
"tailwind-merge": "^3.3.1",
"tauri-plugin-fs-pro-api": "^2.4.0",
"tauri-plugin-macos-permissions-api": "^2.3.0",
"tauri-plugin-screenshots-api": "^2.2.0",
"tauri-plugin-windows-version-api": "^2.0.0",
"type-fest": "^4.41.0",
"use-debounce": "^10.0.4",
"uuid": "^11.1.0",
"wavesurfer.js": "^7.9.5",
@@ -92,6 +89,5 @@
"tsx": "^4.19.4",
"typescript": "^5.8.3",
"vite": "^5.4.19"
},
"packageManager": "pnpm@10.11.0+sha512.6540583f41cc5f628eb3d9773ecee802f4f9ef9923cc45b69890fb47991d4b092964694ec3a4f738a420c918a333062c8b925d312f42e4f0c263eb603551f977"
}
}
}

pnpm-lock.yaml (generated, 145 changed lines)

@@ -8,6 +8,9 @@ importers:
.:
dependencies:
'@ant-design/icons':
specifier: ^6.0.0
version: 6.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@headlessui/react':
specifier: ^2.2.2
version: 2.2.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -32,9 +35,6 @@ importers:
'@tauri-apps/plugin-log':
specifier: ~2.4.0
version: 2.4.0
'@tauri-apps/plugin-opener':
specifier: ^2.2.7
version: 2.2.7
'@tauri-apps/plugin-os':
specifier: ^2.2.1
version: 2.2.1
@@ -86,9 +86,6 @@ importers:
lucide-react:
specifier: ^0.461.0
version: 0.461.0(react@18.3.1)
mdast-util-gfm-autolink-literal:
specifier: 2.0.0
version: 2.0.0
mermaid:
specifier: ^11.6.0
version: 11.6.0
@@ -131,9 +128,6 @@ importers:
remark-math:
specifier: ^6.0.0
version: 6.0.0
tailwind-merge:
specifier: ^3.3.1
version: 3.3.1
tauri-plugin-fs-pro-api:
specifier: ^2.4.0
version: 2.4.0
@@ -146,9 +140,6 @@ importers:
tauri-plugin-windows-version-api:
specifier: ^2.0.0
version: 2.0.0
type-fest:
specifier: ^4.41.0
version: 4.41.0
use-debounce:
specifier: ^10.0.4
version: 10.0.4(react@18.3.1)
@@ -191,7 +182,7 @@ importers:
version: 1.8.8
'@vitejs/plugin-react':
specifier: ^4.4.1
version: 4.4.1(vite@5.4.19(@types/node@22.15.17)(sass@1.87.0)(terser@5.40.0))
version: 4.4.1(vite@5.4.19(@types/node@22.15.17)(sass@1.87.0))
autoprefixer:
specifier: ^10.4.21
version: 10.4.21(postcss@8.5.3)
@@ -224,7 +215,7 @@ importers:
version: 5.8.3
vite:
specifier: ^5.4.19
version: 5.4.19(@types/node@22.15.17)(sass@1.87.0)(terser@5.40.0)
version: 5.4.19(@types/node@22.15.17)(sass@1.87.0)
packages:
@@ -236,6 +227,23 @@ packages:
resolution: {integrity: sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==}
engines: {node: '>=6.0.0'}
'@ant-design/colors@8.0.0':
resolution: {integrity: sha512-6YzkKCw30EI/E9kHOIXsQDHmMvTllT8STzjMb4K2qzit33RW2pqCJP0sk+hidBntXxE+Vz4n1+RvCTfBw6OErw==}
'@ant-design/fast-color@3.0.0':
resolution: {integrity: sha512-eqvpP7xEDm2S7dUzl5srEQCBTXZMmY3ekf97zI+M2DHOYyKdJGH0qua0JACHTqbkRnD/KHFQP9J1uMJ/XWVzzA==}
engines: {node: '>=8.x'}
'@ant-design/icons-svg@4.4.2':
resolution: {integrity: sha512-vHbT+zJEVzllwP+CM+ul7reTEfBR0vgxFe7+lREAsAA7YGsYpboiq2sQNeQeRvh09GfQgs/GyFEvZpJ9cLXpXA==}
'@ant-design/icons@6.0.0':
resolution: {integrity: sha512-o0aCCAlHc1o4CQcapAwWzHeaW2x9F49g7P3IDtvtNXgHowtRWYb7kiubt8sQPFvfVIVU/jLw2hzeSlNt0FU+Uw==}
engines: {node: '>=8'}
peerDependencies:
react: '>=16.0.0'
react-dom: '>=16.0.0'
'@antfu/install-pkg@1.1.0':
resolution: {integrity: sha512-MGQsmw10ZyI+EJo45CdSER4zEb+p31LpDAFp2Z3gkSd1yqVZGi0Ebx++YTEMonJy4oChEMLsxZ64j8FH6sSqtQ==}
@@ -805,9 +813,6 @@ packages:
resolution: {integrity: sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==}
engines: {node: '>=6.0.0'}
'@jridgewell/source-map@0.3.6':
resolution: {integrity: sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==}
'@jridgewell/sourcemap-codec@1.5.0':
resolution: {integrity: sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==}
@@ -985,6 +990,12 @@ packages:
resolution: {integrity: sha512-c83qWb22rNRuB0UaVCI0uRPNRr8Z0FWnEIvT47jiHAmOIUHbBOg5XvV7pM5x+rKn9HRpjxquDbXYSXr3fAKFcw==}
engines: {node: '>=12'}
'@rc-component/util@1.2.1':
resolution: {integrity: sha512-AUVu6jO+lWjQnUOOECwu8iR0EdElQgWW5NBv5vP/Uf9dWbAX3udhMutRlkVXjuac2E40ghkFy+ve00mc/3Fymg==}
peerDependencies:
react: '>=18.0.0'
react-dom: '>=18.0.0'
'@react-aria/focus@3.20.2':
resolution: {integrity: sha512-Q3rouk/rzoF/3TuH6FzoAIKrl+kzZi9LHmr8S5EqLAOyP9TXIKG34x2j42dZsAhrw7TbF9gA8tBKwnCNH4ZV+Q==}
peerDependencies:
@@ -1245,9 +1256,6 @@ packages:
'@tauri-apps/plugin-log@2.4.0':
resolution: {integrity: sha512-j7yrDtLNmayCBOO2esl3aZv9jSXy2an8MDLry3Ys9ZXerwUg35n1Y2uD8HoCR+8Ng/EUgx215+qOUfJasjYrHw==}
'@tauri-apps/plugin-opener@2.2.7':
resolution: {integrity: sha512-uduEyvOdjpPOEeDRrhwlCspG/f9EQalHumWBtLBnp3fRp++fKGLqDOyUhSIn7PzX45b/rKep//ZQSAQoIxobLA==}
'@tauri-apps/plugin-os@2.2.1':
resolution: {integrity: sha512-cNYpNri2CCc6BaNeB6G/mOtLvg8dFyFQyCUdf2y0K8PIAKGEWdEcu8DECkydU2B+oj4OJihDPD2de5K6cbVl9A==}
@@ -1575,9 +1583,6 @@ packages:
engines: {node: ^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7}
hasBin: true
buffer-from@1.1.2:
resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==}
bundle-name@4.1.0:
resolution: {integrity: sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==}
engines: {node: '>=18'}
@@ -1653,6 +1658,9 @@ packages:
resolution: {integrity: sha512-cYY9mypksY8NRqgDB1XD1RiJL338v/551niynFTGkZOO2LHuB2OmOYxDIe/ttN9AHwrqdum1360G3ald0W9kCg==}
engines: {node: '>=8'}
classnames@2.5.1:
resolution: {integrity: sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow==}
cli-boxes@3.0.0:
resolution: {integrity: sha512-/lzGpEWL/8PfI0BmBOPRwp0c/wFNX1RdUML3jK/RcSBA9T8mZDdQpqYBKtCFTOfQbwPqWEOpjqW+Fnayc0969g==}
engines: {node: '>=10'}
@@ -1687,9 +1695,6 @@ packages:
comma-separated-tokens@2.0.3:
resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==}
commander@2.20.3:
resolution: {integrity: sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==}
commander@4.1.1:
resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==}
engines: {node: '>= 6'}
@@ -2635,8 +2640,8 @@ packages:
mdast-util-from-markdown@2.0.2:
resolution: {integrity: sha512-uZhTV/8NBuw0WHkPTrCqDOl0zVe1BIng5ZtHoDk49ME1qqcjYmmLmOf0gELgcRMxN4w2iuIeVso5/6QymSrgmA==}
mdast-util-gfm-autolink-literal@2.0.0:
resolution: {integrity: sha512-FyzMsduZZHSc3i0Px3PQcBT4WJY/X/RCtEJKuybiC6sjPqLv7h1yqAkmILZtuxMSsUyaLUWNp71+vQH2zqp5cg==}
mdast-util-gfm-autolink-literal@2.0.1:
resolution: {integrity: sha512-5HVP2MKaP6L+G6YaxPNjuL0BPrq9orG3TsrZ9YXbA3vDw/ACI4MEsnoDpn6ZNm7GnZgtAcONJyPhOP8tNJQavQ==}
mdast-util-gfm-footnote@2.1.0:
resolution: {integrity: sha512-sqpDWlsHn7Ac9GNZQMeUzPQSMzR6Wv0WKRNvQRg0KqHh02fpTz69Qc1QSseNX29bhz1ROIyNyxExfawVKTm1GQ==}
@@ -3132,6 +3137,9 @@ packages:
typescript:
optional: true
react-is@18.3.1:
resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==}
react-markdown@9.1.0:
resolution: {integrity: sha512-xaijuJB0kzGiUdG7nc2MOMDUDBWPyGAjZtUrow9XxUeua8IqeP+VlIfAZ3bphpcLTnSZXz6z9jcVC/TCwbfgdw==}
peerDependencies:
@@ -3338,9 +3346,6 @@ packages:
resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==}
engines: {node: '>=0.10.0'}
source-map-support@0.5.21:
resolution: {integrity: sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==}
source-map@0.6.1:
resolution: {integrity: sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==}
engines: {node: '>=0.10.0'}
@@ -3418,9 +3423,6 @@ packages:
tabbable@6.2.0:
resolution: {integrity: sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==}
tailwind-merge@3.3.1:
resolution: {integrity: sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==}
tailwindcss@3.4.17:
resolution: {integrity: sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==}
engines: {node: '>=14.0.0'}
@@ -3438,11 +3440,6 @@ packages:
tauri-plugin-windows-version-api@2.0.0:
resolution: {integrity: sha512-tty5n4ASYbXpnsD5ws2iTcTTpDCrSbzRTVp5Bo3UTpYGqlN1gBn2Zk8s3oO4w7VIM5WtJhDM9Jr/UgoTk7tFJQ==}
terser@5.40.0:
resolution: {integrity: sha512-cfeKl/jjwSR5ar7d0FGmave9hFGJT8obyo0z+CrQOylLDbk7X81nPU6vq9VORa5jU30SkDnT2FXjLbR8HLP+xA==}
engines: {node: '>=10'}
hasBin: true
thenify-all@1.6.0:
resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==}
engines: {node: '>=0.8'}
@@ -3777,6 +3774,23 @@ snapshots:
'@jridgewell/gen-mapping': 0.3.8
'@jridgewell/trace-mapping': 0.3.25
'@ant-design/colors@8.0.0':
dependencies:
'@ant-design/fast-color': 3.0.0
'@ant-design/fast-color@3.0.0': {}
'@ant-design/icons-svg@4.4.2': {}
'@ant-design/icons@6.0.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@ant-design/colors': 8.0.0
'@ant-design/icons-svg': 4.4.2
'@rc-component/util': 1.2.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
classnames: 2.5.1
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
'@antfu/install-pkg@1.1.0':
dependencies:
package-manager-detector: 1.3.0
@@ -4246,12 +4260,6 @@ snapshots:
'@jridgewell/set-array@1.2.1': {}
'@jridgewell/source-map@0.3.6':
dependencies:
'@jridgewell/gen-mapping': 0.3.8
'@jridgewell/trace-mapping': 0.3.25
optional: true
'@jridgewell/sourcemap-codec@1.5.0': {}
'@jridgewell/trace-mapping@0.3.25':
@@ -4419,6 +4427,12 @@ snapshots:
'@pnpm/network.ca-file': 1.0.2
config-chain: 1.1.13
'@rc-component/util@1.2.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
react-is: 18.3.1
'@react-aria/focus@3.20.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@react-aria/interactions': 3.25.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -4623,10 +4637,6 @@ snapshots:
dependencies:
'@tauri-apps/api': 2.5.0
'@tauri-apps/plugin-opener@2.2.7':
dependencies:
'@tauri-apps/api': 2.5.0
'@tauri-apps/plugin-os@2.2.1':
dependencies:
'@tauri-apps/api': 2.5.0
@@ -4868,14 +4878,14 @@ snapshots:
'@ungap/structured-clone@1.3.0': {}
'@vitejs/plugin-react@4.4.1(vite@5.4.19(@types/node@22.15.17)(sass@1.87.0)(terser@5.40.0))':
'@vitejs/plugin-react@4.4.1(vite@5.4.19(@types/node@22.15.17)(sass@1.87.0))':
dependencies:
'@babel/core': 7.27.1
'@babel/plugin-transform-react-jsx-self': 7.27.1(@babel/core@7.27.1)
'@babel/plugin-transform-react-jsx-source': 7.27.1(@babel/core@7.27.1)
'@types/babel__core': 7.20.5
react-refresh: 0.17.0
vite: 5.4.19(@types/node@22.15.17)(sass@1.87.0)(terser@5.40.0)
vite: 5.4.19(@types/node@22.15.17)(sass@1.87.0)
transitivePeerDependencies:
- supports-color
@@ -5004,9 +5014,6 @@ snapshots:
node-releases: 2.0.19
update-browserslist-db: 1.1.3(browserslist@4.24.5)
buffer-from@1.1.2:
optional: true
bundle-name@4.1.0:
dependencies:
run-applescript: 7.0.0
@@ -5077,6 +5084,8 @@ snapshots:
ci-info@4.2.0: {}
classnames@2.5.1: {}
cli-boxes@3.0.0: {}
cli-cursor@5.0.0:
@@ -5101,9 +5110,6 @@ snapshots:
comma-separated-tokens@2.0.3: {}
commander@2.20.3:
optional: true
commander@4.1.1: {}
commander@7.2.0: {}
@@ -6108,7 +6114,7 @@ snapshots:
transitivePeerDependencies:
- supports-color
mdast-util-gfm-autolink-literal@2.0.0:
mdast-util-gfm-autolink-literal@2.0.1:
dependencies:
'@types/mdast': 4.0.4
ccount: 2.0.1
@@ -6156,7 +6162,7 @@ snapshots:
mdast-util-gfm@3.1.0:
dependencies:
mdast-util-from-markdown: 2.0.2
mdast-util-gfm-autolink-literal: 2.0.0
mdast-util-gfm-autolink-literal: 2.0.1
mdast-util-gfm-footnote: 2.1.0
mdast-util-gfm-strikethrough: 2.0.0
mdast-util-gfm-table: 2.0.0
@@ -6824,6 +6830,8 @@ snapshots:
react-dom: 18.3.1(react@18.3.1)
typescript: 5.8.3
react-is@18.3.1: {}
react-markdown@9.1.0(@types/react@18.3.21)(react@18.3.1):
dependencies:
'@types/hast': 3.0.4
@@ -7113,12 +7121,6 @@ snapshots:
source-map-js@1.2.1: {}
source-map-support@0.5.21:
dependencies:
buffer-from: 1.1.2
source-map: 0.6.1
optional: true
source-map@0.6.1:
optional: true
@@ -7195,8 +7197,6 @@ snapshots:
tabbable@6.2.0: {}
tailwind-merge@3.3.1: {}
tailwindcss@3.4.17:
dependencies:
'@alloc/quick-lru': 5.2.0
@@ -7240,14 +7240,6 @@ snapshots:
dependencies:
'@tauri-apps/api': 2.5.0
terser@5.40.0:
dependencies:
'@jridgewell/source-map': 0.3.6
acorn: 8.14.1
commander: 2.20.3
source-map-support: 0.5.21
optional: true
thenify-all@1.6.0:
dependencies:
thenify: 3.3.1
@@ -7434,7 +7426,7 @@ snapshots:
'@types/unist': 3.0.3
vfile-message: 4.0.2
vite@5.4.19(@types/node@22.15.17)(sass@1.87.0)(terser@5.40.0):
vite@5.4.19(@types/node@22.15.17)(sass@1.87.0):
dependencies:
esbuild: 0.21.5
postcss: 8.5.3
@@ -7443,7 +7435,6 @@ snapshots:
'@types/node': 22.15.17
fsevents: 2.3.3
sass: 1.87.0
terser: 5.40.0
void-elements@3.1.0: {}

File diff suppressed because one or more lines are too long


@@ -1 +0,0 @@
(() => {})();

src-tauri/Cargo.lock (generated, 674 changed lines): diff suppressed because it is too large.


@@ -1,9 +1,9 @@
[package]
name = "coco"
version = "0.6.0"
version = "0.4.0"
description = "Search, connect, collaborate all in one place."
authors = ["INFINI Labs"]
edition = "2024"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
@@ -44,12 +44,12 @@ use_pizza_engine = []
[dependencies]
pizza-common = { git = "https://github.com/infinilabs/pizza-common", branch = "main" }
tauri = { version = "2", features = ["protocol-asset", "macos-private-api", "tray-icon", "image-ico", "image-png"] }
tauri = { version = "2", features = ["protocol-asset", "macos-private-api", "tray-icon", "image-ico", "image-png", "unstable"] }
tauri-plugin-shell = "2"
serde = { version = "1", features = ["derive"] }
# Need `arbitrary_precision` feature to support storing u128
# see: https://docs.rs/serde_json/latest/serde_json/struct.Number.html#method.from_u128
serde_json = { version = "1", features = ["arbitrary_precision", "preserve_order"] }
serde_json = { version = "1", features = ["arbitrary_precision"] }
tauri-plugin-http = "2"
tauri-plugin-websocket = "2"
tauri-plugin-deep-link = "2.0.0"
@@ -62,7 +62,7 @@ tauri-plugin-drag = "2"
tauri-plugin-macos-permissions = "2"
tauri-plugin-fs-pro = "2"
tauri-plugin-screenshots = "2"
applications = { git = "https://github.com/infinilabs/applications-rs", rev = "31b0c030a0f3bc82275fe12debe526153978671d" }
applications = { git = "https://github.com/infinilabs/applications-rs", rev = "7bb507e6b12f73c96f3a52f0578d0246a689f381" }
tokio-native-tls = "0.3" # For wss connections
tokio = { version = "1", features = ["full"] }
tokio-tungstenite = { version = "0.20", features = ["native-tls"] }
@@ -81,31 +81,18 @@ plist = "1.7"
base64 = "0.13"
walkdir = "2"
log = "0.4"
strsim = "0.10"
futures-util = "0.3.31"
url = "2.5.2"
http = "1.1.0"
tungstenite = "0.24.0"
tokio-util = "0.7.14"
tauri-plugin-windows-version = "2"
meval = { git = "https://github.com/infinilabs/meval-rs" }
meval = "0.2"
chinese-number = "0.7"
num2words = "1"
tauri-plugin-log = "2"
chrono = "0.4.41"
serde_plain = "1.0.2"
derive_more = { version = "2.0.1", features = ["display"] }
anyhow = "1.0.98"
function_name = "0.3.0"
regex = "1.11.1"
borrowme = "0.0.15"
tauri-plugin-opener = "2"
async-recursion = "1.1.1"
zip = "4.0.0"
url = "2.5.2"
camino = "1.1.10"
tokio-stream = { version = "0.1.17", features = ["io-util"] }
cfg-if = "1.0.1"
sysinfo = "0.35.2"
[target."cfg(target_os = \"macos\")".dependencies]
tauri-nspanel = { git = "https://github.com/ahkohd/tauri-nspanel", branch = "v2" }
@@ -130,4 +117,3 @@ tauri-plugin-updater = { git = "https://github.com/infinilabs/plugins-workspace"
[target."cfg(target_os = \"windows\")".dependencies]
enigo="0.3"
windows = { version = "0.61.3", features = ["Win32_Foundation", "Win32_System_Com", "Win32_System_Ole", "Win32_System_Search", "Win32_UI_Shell_PropertiesSystem", "Win32_Data"] }


@@ -1,14 +1,3 @@
fn main() {
tauri_build::build();
// If env var `GITHUB_ACTIONS` exists, we are running in CI, set up the `ci`
// attribute
if std::env::var("GITHUB_ACTIONS").is_ok() {
println!("cargo:rustc-cfg=ci");
}
// Notify `rustc` of this `cfg` attribute to suppress unknown attribute warnings.
//
// unexpected condition name: `ci`
println!("cargo::rustc-check-cfg=cfg(ci)");
tauri_build::build()
}
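
A note on the removed build-script lines: they emitted a custom `ci` cfg whenever `GITHUB_ACTIONS` is set and declared it via `cargo::rustc-check-cfg` so the compiler would not warn about an unknown cfg name. A hypothetical consumer of that cfg could look like the sketch below (illustrative only, not code from this diff):

```rust
// Hypothetical use of the `ci` cfg emitted by the removed build-script lines.
// Without the `cargo::rustc-check-cfg=cfg(ci)` declaration, rustc reports an
// "unexpected cfg" warning for this attribute; with it, the build is clean.
#[cfg(ci)]
fn environment() -> &'static str {
    "GitHub Actions"
}

#[cfg(not(ci))]
fn environment() -> &'static str {
    "local machine"
}

fn main() {
    println!("building on: {}", environment());
}
```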


@@ -2,7 +2,7 @@
"$schema": "../gen/schemas/desktop-schema.json",
"identifier": "default",
"description": "Capability for the main window",
"windows": ["main", "chat", "settings", "check"],
"windows": ["main", "chat", "settings"],
"permissions": [
"core:default",
"core:event:allow-emit",
@@ -71,7 +71,6 @@
"process:default",
"updater:default",
"windows-version:default",
"log:default",
"opener:default"
"log:default"
]
}


@@ -1,2 +1,2 @@
[toolchain]
channel = "nightly-2025-06-26"
channel = "nightly-2024-10-29"


@@ -1,16 +1,10 @@
use crate::common;
use crate::common::assistant::ChatRequestMessage;
use crate::common::http::{convert_query_params_to_strings, GetResponse};
use crate::common::register::SearchSourceRegistry;
use crate::common::http::GetResponse;
use crate::server::http_client::HttpClient;
use crate::{common, server::servers::COCO_SERVERS};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use futures_util::TryStreamExt;
use http::Method;
use serde_json::Value;
use std::collections::HashMap;
use tauri::{AppHandle, Emitter, Manager, Runtime};
use tokio::io::AsyncBufReadExt;
use tauri::{AppHandle, Runtime};
#[tauri::command]
pub async fn chat_history<R: Runtime>(
@@ -20,15 +14,17 @@ pub async fn chat_history<R: Runtime>(
size: u32,
query: Option<String>,
) -> Result<String, String> {
let mut query_params = Vec::new();
// Add from/size as number values
query_params.push(format!("from={}", from));
query_params.push(format!("size={}", size));
let mut query_params: HashMap<String, Value> = HashMap::new();
if from > 0 {
query_params.insert("from".to_string(), from.into());
}
if size > 0 {
query_params.insert("size".to_string(), size.into());
}
if let Some(query) = query {
if !query.is_empty() {
query_params.push(format!("query={}", query.to_string()));
query_params.insert("query".to_string(), query.into());
}
}
@@ -50,11 +46,13 @@ pub async fn session_chat_history<R: Runtime>(
from: u32,
size: u32,
) -> Result<String, String> {
let mut query_params = Vec::new();
// Add from/size as number values
query_params.push(format!("from={}", from));
query_params.push(format!("size={}", size));
let mut query_params: HashMap<String, Value> = HashMap::new();
if from > 0 {
query_params.insert("from".to_string(), from.into());
}
if size > 0 {
query_params.insert("size".to_string(), size.into());
}
let path = format!("/chat/{}/_history", session_id);
@@ -71,9 +69,10 @@ pub async fn open_session_chat<R: Runtime>(
server_id: String,
session_id: String,
) -> Result<String, String> {
let query_params = HashMap::new();
let path = format!("/chat/{}/_open", session_id);
let response = HttpClient::post(&server_id, path.as_str(), None, None)
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error open session: {}", e))?;
@@ -86,9 +85,10 @@ pub async fn close_session_chat<R: Runtime>(
server_id: String,
session_id: String,
) -> Result<String, String> {
let query_params = HashMap::new();
let path = format!("/chat/{}/_close", session_id);
let response = HttpClient::post(&server_id, path.as_str(), None, None)
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error close session: {}", e))?;
@@ -100,9 +100,10 @@ pub async fn cancel_session_chat<R: Runtime>(
server_id: String,
session_id: String,
) -> Result<String, String> {
let query_params = HashMap::new();
let path = format!("/chat/{}/_cancel", session_id);
let response = HttpClient::post(&server_id, path.as_str(), None, None)
let response = HttpClient::post(&server_id, path.as_str(), Some(query_params), None)
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
@@ -133,22 +134,15 @@ pub async fn new_chat<R: Runtime>(
let mut headers = HashMap::new();
headers.insert("WEBSOCKET-SESSION-ID".to_string(), websocket_id.into());
let response = HttpClient::advanced_post(
&server_id,
"/chat/_new",
Some(headers),
convert_query_params_to_strings(query_params),
body,
)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
let response =
HttpClient::advanced_post(&server_id, "/chat/_new", Some(headers), query_params, body)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
let body_text = common::http::get_response_body_text(response).await?;
log::debug!("New chat response: {}", &body_text);
let chat_response: GetResponse = serde_json::from_str(&body_text)
.map_err(|e| format!("Failed to parse response JSON: {}", e))?;
let chat_response: GetResponse =
serde_json::from_str(&body_text).map_err(|e| format!("Failed to parse response JSON: {}", e))?;
if chat_response.result != "created" {
return Err(format!("Unexpected result: {}", chat_response.result));
@@ -157,54 +151,6 @@ pub async fn new_chat<R: Runtime>(
Ok(chat_response)
}
#[tauri::command]
pub async fn chat_create<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
message: String,
query_params: Option<HashMap<String, Value>>,
) -> Result<(), String> {
let body = if !message.is_empty() {
let message = ChatRequestMessage {
message: Some(message),
};
Some(
serde_json::to_string(&message)
.map_err(|e| format!("Failed to serialize message: {}", e))?
.into(),
)
} else {
None
};
let response = HttpClient::advanced_post(
&server_id,
"/chat/_create",
None,
convert_query_params_to_strings(query_params),
body,
)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
if response.status() == 429 {
log::warn!("Rate limit exceeded for chat create");
return Err("Rate limited".to_string());
}
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
emit_json_stream_lines(
&app_handle,
"chat-create-stream",
"chat-create-error",
response,
)
.await
}
#[tauri::command]
pub async fn send_message<R: Runtime>(
_app_handle: AppHandle<R>,
@@ -227,115 +173,15 @@ pub async fn send_message<R: Runtime>(
&server_id,
path.as_str(),
Some(headers),
convert_query_params_to_strings(query_params),
query_params,
Some(body),
)
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
.await
.map_err(|e| format!("Error cancel session: {}", e))?;
common::http::get_response_body_text(response).await
}
#[tauri::command]
pub async fn chat_chat<R: Runtime>(
app_handle: AppHandle<R>,
server_id: String,
session_id: String,
message: String,
query_params: Option<HashMap<String, Value>>, //search,deep_thinking
) -> Result<(), String> {
let body = if !message.is_empty() {
let message = ChatRequestMessage {
message: Some(message),
};
Some(
serde_json::to_string(&message)
.map_err(|e| format!("Failed to serialize message: {}", e))?
.into(),
)
} else {
None
};
let path = format!("/chat/{}/_chat", session_id);
let response = HttpClient::advanced_post(
&server_id,
path.as_str(),
None,
convert_query_params_to_strings(query_params),
body,
)
.await
.map_err(|e| format!("Error sending message: {}", e))?;
if response.status() == 429 {
log::warn!("Rate limit exceeded for chat create");
return Err("Rate limited".to_string());
}
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
emit_json_stream_lines(
&app_handle,
"chat-create-stream",
"chat-create-error",
response,
)
.await
}
pub async fn emit_json_stream_lines<R: Runtime>(
app_handle: &AppHandle<R>,
event_name: &str,
error_event_name: &str,
response: reqwest::Response,
) -> Result<(), String> {
let stream = response.bytes_stream();
let reader = tokio_util::io::StreamReader::new(
stream.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
);
let mut lines = tokio::io::BufReader::new(reader).lines();
while let Ok(Some(line)) = lines.next_line().await {
log::debug!("Received stream line: {}", &line);
match serde_json::from_str::<Value>(&line) {
Ok(Value::Array(items)) => {
for item in items {
if let Ok(json_str) = serde_json::to_string(&item) {
if let Err(err) = app_handle.emit(event_name, json_str) {
log::error!("Emit failed: {:?}", err);
let _ = app_handle
.emit(error_event_name, format!("Emit failed: {:?}", err));
}
}
}
}
Ok(obj @ Value::Object(_)) => {
if let Ok(json_str) = serde_json::to_string(&obj) {
if let Err(err) = app_handle.emit(event_name, json_str) {
log::error!("Emit failed: {:?}", err);
let _ =
app_handle.emit(error_event_name, format!("Emit failed: {:?}", err));
}
}
}
Err(err) => {
log::warn!("Invalid JSON line: {} | Error: {}", line, err);
let _ = app_handle.emit(error_event_name, format!("Invalid JSON: {}", err));
}
_ => {
log::warn!("Unexpected JSON type: {}", line);
}
}
}
Ok(())
}
#[tauri::command]
pub async fn delete_session_chat(server_id: String, session_id: String) -> Result<bool, String> {
let response =
@@ -373,8 +219,8 @@ pub async fn update_session_chat(
None,
Some(reqwest::Body::from(serde_json::to_string(&body).unwrap())),
)
.await
.map_err(|e| format!("Error updating session: {}", e))?;
.await
.map_err(|e| format!("Error updating session: {}", e))?;
Ok(response.status().is_success())
}
@@ -383,184 +229,30 @@ pub async fn update_session_chat(
pub async fn assistant_search<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
query_params: Option<Vec<String>>,
from: u32,
size: u32,
query: Option<HashMap<String, Value>>,
) -> Result<Value, String> {
let response = HttpClient::post(&server_id, "/assistant/_search", query_params, None)
.await
.map_err(|e| format!("Error searching assistants: {}", e))?;
let mut body = serde_json::json!({
"from": from,
"size": size,
});
response
.json::<Value>()
.await
.map_err(|err| err.to_string())
}
if let Some(q) = query {
body["query"] = serde_json::to_value(q).map_err(|e| e.to_string())?;
}
#[tauri::command]
pub async fn assistant_get<R: Runtime>(
_app_handle: AppHandle<R>,
server_id: String,
assistant_id: String,
) -> Result<Value, String> {
let response = HttpClient::get(
let response = HttpClient::post(
&server_id,
&format!("/assistant/{}", assistant_id),
None, // headers
)
.await
.map_err(|e| format!("Error getting assistant: {}", e))?;
response
.json::<Value>()
.await
.map_err(|err| err.to_string())
}
/// Gets the information of the assistant specified by `assistant_id` by querying **all**
/// Coco servers.
///
/// Returns as soon as the assistant is found on any Coco server.
#[tauri::command]
pub async fn assistant_get_multi<R: Runtime>(
app_handle: AppHandle<R>,
assistant_id: String,
) -> Result<Value, String> {
let search_sources = app_handle.state::<SearchSourceRegistry>();
let sources_future = search_sources.get_sources();
let sources_list = sources_future.await;
let mut futures = FuturesUnordered::new();
for query_source in &sources_list {
let query_source_type = query_source.get_type();
if query_source_type.r#type != COCO_SERVERS {
// Assistants only exists on Coco servers.
continue;
}
let coco_server_id = query_source_type.id.clone();
let path = format!("/assistant/{}", assistant_id);
let fut = async move {
let res_response = HttpClient::get(
&coco_server_id,
&path,
None, // headers
)
.await;
match res_response {
Ok(response) => response
.json::<serde_json::Value>()
.await
.map_err(|e| e.to_string()),
Err(e) => Err(e),
}
};
futures.push(fut);
}
while let Some(res_response_json) = futures.next().await {
let response_json = match res_response_json {
Ok(json) => json,
Err(e) => return Err(e),
};
// Example response JSON
//
// When assistant is not found:
// ```json
// {
// "_id": "ID",
// "result": "not_found"
// }
// ```
//
// When assistant is found:
// ```json
// {
// "_id": "ID",
// "_source": {...}
// "found": true
// }
// ```
if let Some(found) = response_json.get("found") {
if found == true {
return Ok(response_json);
}
}
}
Err(format!(
"could not find Assistant [{}] on all the Coco servers",
assistant_id
))
}
use regex::Regex;
/// Remove all `"icon": "..."` fields from a JSON string
pub fn remove_icon_fields(json: &str) -> String {
// Regex to match `"icon": "..."` fields, including base64 or escaped strings
let re = Regex::new(r#""icon"\s*:\s*"[^"]*"(,?)"#).unwrap();
// Replace with empty string, or just remove trailing comma if needed
re.replace_all(json, |caps: &regex::Captures| {
if &caps[1] == "," {
"".to_string() // keep comma removal logic safe
} else {
"".to_string()
}
})
.to_string()
}
#[tauri::command]
pub async fn ask_ai<R: Runtime>(
app_handle: AppHandle<R>,
message: String,
server_id: String,
assistant_id: String,
client_id: String,
) -> Result<(), String> {
let cleaned = remove_icon_fields(message.as_str());
let body = serde_json::json!({ "message": cleaned });
let path = format!("/assistant/{}/_ask", assistant_id);
println!("Sending request to {}", &path);
let response = HttpClient::send_request(
server_id.as_str(),
Method::POST,
path.as_str(),
None,
"/assistant/_search",
None,
Some(reqwest::Body::from(body.to_string())),
)
.await?;
.await
.map_err(|e| format!("Error searching assistants: {}", e))?;
if response.status() == 429 {
log::warn!("Rate limit exceeded for assistant: {}", &assistant_id);
return Ok(());
}
if !response.status().is_success() {
return Err(format!("Request Failed: {}", response.status()));
}
let stream = response.bytes_stream();
let reader = tokio_util::io::StreamReader::new(
stream.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
);
let mut lines = tokio::io::BufReader::new(reader).lines();
while let Ok(Some(line)) = lines.next_line().await {
dbg!("Received line: {}", &line);
let _ = app_handle.emit(&client_id, line).map_err(|err| {
println!("Failed to emit: {:?}", err);
});
}
Ok(())
response
.json::<Value>()
.await
.map_err(|err| err.to_string())
}
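
The removed `assistant_get_multi` command above fans a single lookup out to every registered Coco server with `FuturesUnordered` and returns as soon as any server reports the assistant as found. Below is a minimal sketch of that first-hit pattern, with plain strings standing in for the servers and the HTTP call; the names are illustrative and assume the `futures` and `tokio` crates:

```rust
use futures::stream::{FuturesUnordered, StreamExt};

// Stand-in for `HttpClient::get(server, "/assistant/{id}", ...)`.
async fn lookup(server: &str, id: &str) -> Option<String> {
    if server == "server-b" {
        Some(format!("{id} found on {server}"))
    } else {
        None
    }
}

#[tokio::main]
async fn main() {
    let servers = ["server-a", "server-b", "server-c"];
    let mut futures: FuturesUnordered<_> = servers
        .iter()
        .map(|&s| lookup(s, "assistant-42"))
        .collect();

    // Poll in completion order and stop at the first server that knows the id.
    while let Some(result) = futures.next().await {
        if let Some(hit) = result {
            println!("{hit}");
            break;
        }
    }
}
```

Resolving the set in completion order is what lets slow or unreachable servers be ignored once one of them answers.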


@@ -3,43 +3,38 @@ use std::{fs::create_dir, io::Read};
use tauri::{Manager, Runtime};
use tauri_plugin_autostart::ManagerExt;
/// If the state reported from the OS and the state stored by us differ, our state is
/// prioritized and seen as the correct one. Update the OS state to make them consistent.
pub fn ensure_autostart_state_consistent(app: &mut tauri::App) -> Result<(), String> {
// Start or stop according to configuration
pub fn enable_autostart(app: &mut tauri::App) {
use tauri_plugin_autostart::MacosLauncher;
use tauri_plugin_autostart::ManagerExt;
app.handle()
.plugin(tauri_plugin_autostart::init(
MacosLauncher::AppleScript,
None,
))
.unwrap();
let autostart_manager = app.autolaunch();
let os_state = autostart_manager.is_enabled().map_err(|e| e.to_string())?;
let coco_stored_state = current_autostart(app.app_handle()).map_err(|e| e.to_string())?;
// close autostart
// autostart_manager.disable().unwrap();
// return;
if os_state != coco_stored_state {
log::warn!(
"autostart inconsistent states, OS state [{}], Coco state [{}], config file could be deleted or corrupted",
os_state,
coco_stored_state
);
log::info!("trying to correct the inconsistent states");
let result = if coco_stored_state {
autostart_manager.enable()
} else {
autostart_manager.disable()
};
match result {
Ok(_) => {
log::info!("inconsistent autostart states fixed");
}
Err(e) => {
log::error!(
"failed to fix inconsistent autostart state due to error [{}]",
e
);
return Err(e.to_string());
}
}
match (
autostart_manager.is_enabled(),
current_autostart(app.app_handle()),
) {
(Ok(false), Ok(true)) => match autostart_manager.enable() {
Ok(_) => println!("Autostart enabled successfully."),
Err(err) => eprintln!("Failed to enable autostart: {}", err),
},
(Ok(true), Ok(false)) => match autostart_manager.disable() {
Ok(_) => println!("Autostart disable successfully."),
Err(err) => eprintln!("Failed to disable autostart: {}", err),
},
_ => (),
}
Ok(())
}
fn current_autostart<R: Runtime>(app: &tauri::AppHandle<R>) -> Result<bool, String> {
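
The deleted `ensure_autostart_state_consistent` treats Coco's stored preference as authoritative and rewrites the OS state whenever the two disagree, while the restored `enable_autostart` only patches the two mismatch cases it explicitly matches. A minimal sketch of that reconciliation rule, with booleans and closures standing in for `autostart_manager.is_enabled()`, `enable()`/`disable()`, and the stored config (illustrative only):

```rust
// Sketch of the reconciliation rule described in the removed doc comment:
// the app's stored preference wins, and the OS autostart state is updated
// to match it. The closures stand in for the autostart manager calls.
fn reconcile_autostart(
    os_enabled: bool,
    stored_enabled: bool,
    enable: impl FnOnce() -> Result<(), String>,
    disable: impl FnOnce() -> Result<(), String>,
) -> Result<(), String> {
    if os_enabled == stored_enabled {
        return Ok(()); // already consistent, nothing to do
    }
    // States differ: push the stored preference onto the OS.
    if stored_enabled { enable() } else { disable() }
}

fn main() {
    // Example: OS reports "disabled" while the stored config says "enabled",
    // so the enable closure is invoked.
    let result = reconcile_autostart(false, true, || Ok(()), || Err("no-op".into()));
    assert!(result.is_ok());
}
```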


@@ -9,13 +9,13 @@ pub struct ChatRequestMessage {
#[allow(dead_code)]
pub struct NewChatResponse {
pub _id: String,
pub _source: Session,
pub _source: Source,
pub result: String,
pub payload: Option<Value>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Session {
pub struct Source {
pub id: String,
pub created: String,
pub updated: String,
@@ -23,11 +23,4 @@ pub struct Session {
pub title: Option<String>,
pub summary: Option<String>,
pub manually_renamed_title: bool,
pub visible: Option<bool>,
pub context: Option<SessionContext>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionContext {
pub attachments: Option<Vec<String>>,
}


@@ -29,89 +29,6 @@ pub struct EditorInfo {
pub timestamp: Option<String>,
}
/// Defines the action that would be performed when a document gets opened.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum OnOpened {
/// Launch the application
Application { app_path: String },
/// Open the URL.
Document { url: String },
/// Spawn a child process to run the `CommandAction`.
Command {
action: crate::extension::CommandAction,
},
}
impl OnOpened {
pub(crate) fn url(&self) -> String {
match self {
Self::Application { app_path } => app_path.clone(),
Self::Document { url } => url.clone(),
Self::Command { action } => {
const WHITESPACE: &str = " ";
let mut ret = action.exec.clone();
ret.push_str(WHITESPACE);
if let Some(ref args) = action.args {
ret.push_str(args.join(WHITESPACE).as_str());
}
ret
}
}
}
}
#[tauri::command]
pub(crate) async fn open(on_opened: OnOpened) -> Result<(), String> {
log::debug!("open({})", on_opened.url());
use crate::util::open as homemade_tauri_shell_open;
use crate::GLOBAL_TAURI_APP_HANDLE;
use std::process::Command;
let global_tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
match on_opened {
OnOpened::Application { app_path } => {
homemade_tauri_shell_open(global_tauri_app_handle.clone(), app_path).await?
}
OnOpened::Document { url } => {
homemade_tauri_shell_open(global_tauri_app_handle.clone(), url).await?
}
OnOpened::Command { action } => {
let mut cmd = Command::new(action.exec);
if let Some(args) = action.args {
cmd.args(args);
}
let output = cmd.output().map_err(|e| e.to_string())?;
// Sometimes, we wanna see the result in logs even though it doesn't fail.
log::debug!(
"executing open(Command) result, exit code: [{}], stdout: [{}], stderr: [{}]",
output.status,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
if !output.status.success() {
log::warn!(
"executing open(Command) failed, exit code: [{}], stdout: [{}], stderr: [{}]",
output.status,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
return Err(format!(
"Command failed, stderr [{}]",
String::from_utf8_lossy(&output.stderr)
));
}
}
}
Ok(())
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Document {
pub id: String,
@@ -131,8 +48,6 @@ pub struct Document {
pub thumbnail: Option<String>,
pub cover: Option<String>,
pub tags: Option<Vec<String>>,
/// What will happen if we open this document.
pub on_opened: Option<OnOpened>,
pub url: Option<String>,
pub size: Option<i64>,
pub metadata: Option<HashMap<String, serde_json::Value>>,
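
The removed `OnOpened::Command` branch spawns the configured executable, logs its exit code and output, and maps a non-zero exit status to an error. A condensed, standalone sketch of that behaviour follows; `echo` is just a portable stand-in for an extension-provided command:

```rust
use std::process::Command;

// Run a command, report its output, and fail on a non-zero exit status,
// mirroring the removed `OnOpened::Command` handling.
fn run_and_check(exec: &str, args: &[&str]) -> Result<(), String> {
    let output = Command::new(exec)
        .args(args)
        .output()
        .map_err(|e| e.to_string())?;
    println!(
        "exit: {}, stdout: [{}], stderr: [{}]",
        output.status,
        String::from_utf8_lossy(&output.stdout),
        String::from_utf8_lossy(&output.stderr)
    );
    if !output.status.success() {
        return Err(format!(
            "Command failed, stderr [{}]",
            String::from_utf8_lossy(&output.stderr)
        ));
    }
    Ok(())
}

fn main() {
    // "echo" is available on most Unix-like systems; adjust on Windows.
    let _ = run_and_check("echo", &["hello"]);
}
```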


@@ -2,52 +2,32 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ErrorCause {
#[serde(default)]
pub r#type: Option<String>,
#[serde(default)]
pub reason: Option<String>,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ErrorDetail {
#[serde(default)]
pub root_cause: Option<Vec<ErrorCause>>,
#[serde(default)]
pub r#type: Option<String>,
#[serde(default)]
pub reason: Option<String>,
#[serde(default)]
pub caused_by: Option<ErrorCause>,
pub reason: String,
pub status: u16,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ErrorResponse {
#[serde(default)]
pub error: Option<ErrorDetail>,
#[serde(default)]
pub status: Option<u16>,
pub error: ErrorDetail,
}
#[derive(Debug, Error, Serialize)]
pub enum SearchError {
#[error("HttpError: {0}")]
#[error("HTTP request failed: {0}")]
HttpError(String),
#[error("ParseError: {0}")]
#[error("Invalid response format: {0}")]
ParseError(String),
#[error("Timeout occurred")]
Timeout,
#[error("UnknownError: {0}")]
#[error("Unknown error: {0}")]
#[allow(dead_code)]
Unknown(String),
#[error("InternalError: {0}")]
#[error("InternalError error: {0}")]
#[allow(dead_code)]
InternalError(String),
}
@@ -62,4 +42,4 @@ impl From<reqwest::Error> for SearchError {
SearchError::HttpError(err.to_string())
}
}
}
}


@@ -2,8 +2,6 @@ use crate::common;
use reqwest::Response;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tauri_plugin_store::JsonValue;
#[derive(Debug, Serialize, Deserialize)]
pub struct GetResponse {
@@ -38,43 +36,17 @@ pub async fn get_response_body_text(response: Response) -> Result<String, String
return Err(fallback_error);
}
match serde_json::from_str::<common::error::ErrorResponse>(&body) {
Ok(parsed_error) => {
dbg!(&parsed_error);
Err(format!(
"Server error ({}): {:?}",
status, parsed_error.error
"Server error ({}): {}",
parsed_error.error.status, parsed_error.error.reason
))
}
Err(_) => {
log::warn!("Failed to parse error response: {}", &body);
Err(fallback_error)
}
Err(_) => Err(fallback_error),
}
} else {
Ok(body)
}
}
pub fn convert_query_params_to_strings(
query_params: Option<HashMap<String, JsonValue>>,
) -> Option<Vec<String>> {
query_params.map(|map| {
map.into_iter()
.filter_map(|(k, v)| match v {
JsonValue::String(s) => Some(format!("{}={}", k, s)),
JsonValue::Number(n) => Some(format!("{}={}", k, n)),
JsonValue::Bool(b) => Some(format!("{}={}", k, b)),
_ => {
eprintln!(
"Skipping unsupported query value for key '{}': {:?}",
k, v
);
None
}
})
.collect()
})
}
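
The removed `convert_query_params_to_strings` helper flattens a map of JSON values into `key=value` strings and silently drops arrays, objects, and nulls. A self-contained sketch of the same conversion, written directly against `serde_json::Value`:

```rust
use serde_json::{json, Value};
use std::collections::HashMap;

// Scalar values become "key=value" strings; anything else is skipped,
// matching the behaviour of the removed helper.
fn to_query_strings(params: HashMap<String, Value>) -> Vec<String> {
    params
        .into_iter()
        .filter_map(|(k, v)| match v {
            Value::String(s) => Some(format!("{k}={s}")),
            Value::Number(n) => Some(format!("{k}={n}")),
            Value::Bool(b) => Some(format!("{k}={b}")),
            _ => None,
        })
        .collect()
}

fn main() {
    let mut params = HashMap::new();
    params.insert("search".to_string(), json!(true));
    params.insert("from".to_string(), json!(0));
    let mut out = to_query_strings(params);
    out.sort();
    assert_eq!(out, vec!["from=0".to_string(), "search=true".to_string()]);
}
```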


@@ -1,17 +1,16 @@
pub mod assistant;
pub mod auth;
pub mod connector;
pub mod datasource;
pub mod document;
pub mod error;
pub mod health;
pub mod http;
pub mod profile;
pub mod register;
pub mod search;
pub mod server;
pub mod auth;
pub mod datasource;
pub mod connector;
pub mod search;
pub mod document;
pub mod traits;
pub mod register;
pub mod assistant;
pub mod http;
pub mod error;
pub static MAIN_WINDOW_LABEL: &str = "main";
pub static SETTINGS_WINDOW_LABEL: &str = "settings";
pub static CHECK_WINDOW_LABEL: &str = "check";


@@ -7,8 +7,8 @@ use std::error::Error;
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResponse<T> {
pub took: Option<u64>,
pub timed_out: Option<bool>,
pub took: u64,
pub timed_out: bool,
pub _shards: Option<Shards>,
pub hits: Hits<T>,
}
@@ -25,7 +25,7 @@ pub struct Shards {
pub struct Hits<T> {
pub total: Total,
pub max_score: Option<f32>,
pub hits: Option<Vec<SearchHit<T>>>,
pub hits: Vec<SearchHit<T>>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -36,9 +36,9 @@ pub struct Total {
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchHit<T> {
pub _index: Option<String>,
pub _type: Option<String>,
pub _id: Option<String>,
pub _index: String,
pub _type: String,
pub _id: String,
pub _score: Option<f64>,
pub _source: T, // This will hold the type we pass in (e.g., DataSource)
}
@@ -58,18 +58,13 @@ where
Ok(search_response)
}
use serde::de::DeserializeOwned;
pub async fn parse_search_hits<T>(response: Response) -> Result<Vec<SearchHit<T>>, Box<dyn Error>>
where
T: DeserializeOwned + std::fmt::Debug,
T: for<'de> Deserialize<'de> + std::fmt::Debug,
{
let response = parse_search_response(response).await?;
match response.hits.hits {
Some(hits) => Ok(hits),
None => Ok(Vec::new()),
}
Ok(response.hits.hits)
}
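With `hits` no longer wrapped in an `Option`, callers of `parse_search_hits` receive the vector directly. A minimal usage sketch follows; the `DataSource` type and the calling function are invented for illustration and are not part of this change.

// Hypothetical caller; `DataSource` is a placeholder for the `_source` payload.
#[derive(Debug, serde::Deserialize)]
struct DataSource {
    id: String,
    name: String,
}

async fn print_hits(response: reqwest::Response) -> Result<(), Box<dyn std::error::Error>> {
    // The hits vector is returned directly instead of an Option.
    let hits: Vec<SearchHit<DataSource>> = parse_search_hits(response).await?;
    for hit in &hits {
        println!("{} (score {:?}): {}", hit._id, hit._score, hit._source.name);
    }
    Ok(())
}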
pub async fn parse_search_results<T>(response: Response) -> Result<Vec<T>, Box<dyn Error>>

View File

@@ -1,8 +1,6 @@
use crate::common::health::Health;
use crate::common::profile::UserProfile;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -62,7 +60,6 @@ pub struct Server {
pub auth_provider: AuthProvider,
#[serde(default = "default_priority_type")]
pub priority: u32,
pub stats: Option<HashMap<String, Value>>,
}
impl PartialEq for Server {

View File

@@ -1,4 +1,5 @@
use crate::common::error::SearchError;
// use std::{future::Future, pin::Pin};
use crate::common::search::SearchQuery;
use crate::common::search::{QueryResponse, QuerySource};
use async_trait::async_trait;
@@ -9,3 +10,4 @@ pub trait SearchSource: Send + Sync {
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError>;
}

View File

@@ -1,13 +0,0 @@
pub(super) const EXTENSION_ID: &str = "AIOverview";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "AIOverview",
"name": "AI Overview",
"description": "...",
"icon": "font_a-AIOverview",
"type": "ai_extension",
"enabled": true
}
"#;

View File

@@ -1,213 +0,0 @@
//! File Search configuration entry definitions and getter/setter functions.
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::sync::LazyLock;
use tauri_plugin_store::StoreExt;
// Tauri store keys for file system configuration
const TAURI_STORE_FILE_SYSTEM_CONFIG: &str = "file_system_config";
const TAURI_STORE_KEY_SEARCH_BY: &str = "search_by";
const TAURI_STORE_KEY_SEARCH_PATHS: &str = "search_paths";
const TAURI_STORE_KEY_EXCLUDE_PATHS: &str = "exclude_paths";
const TAURI_STORE_KEY_FILE_TYPES: &str = "file_types";
static HOME_DIR: LazyLock<String> = LazyLock::new(|| {
let os_string = dirs::home_dir()
.expect("$HOME should be set")
.into_os_string();
os_string
.into_string()
.expect("User home directory should be encoded with UTF-8")
});
#[derive(Debug, Clone, Serialize, Deserialize, Copy)]
pub enum SearchBy {
Name,
NameAndContents,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileSearchConfig {
pub search_paths: Vec<String>,
pub exclude_paths: Vec<String>,
pub file_types: Vec<String>,
pub search_by: SearchBy,
}
impl Default for FileSearchConfig {
fn default() -> Self {
Self {
search_paths: vec![
format!("{}/Documents", HOME_DIR.as_str()),
format!("{}/Desktop", HOME_DIR.as_str()),
format!("{}/Downloads", HOME_DIR.as_str()),
],
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
}
}
}
impl FileSearchConfig {
pub(crate) fn get() -> Self {
let tauri_app_handle = crate::GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
let store = tauri_app_handle
.store(TAURI_STORE_FILE_SYSTEM_CONFIG)
.unwrap_or_else(|e| {
panic!(
"store [{}] not found/loaded, error [{}]",
TAURI_STORE_FILE_SYSTEM_CONFIG, e
)
});
// Default values, used when specific config entries are not set
let default_config = FileSearchConfig::default();
let search_paths = {
if let Some(search_paths) = store.get(TAURI_STORE_KEY_SEARCH_PATHS) {
match search_paths {
Value::Array(arr) => {
let mut vec = Vec::with_capacity(arr.len());
for v in arr {
match v {
Value::String(s) => vec.push(s),
other => panic!(
"Expected all elements of 'search_paths' to be strings, but found: {:?}",
other
),
}
}
vec
}
other => panic!(
"Expected 'search_paths' to be an array of strings in the file system config store, but got: {:?}",
other
),
}
} else {
store.set(
TAURI_STORE_KEY_SEARCH_PATHS,
default_config.search_paths.as_slice(),
);
default_config.search_paths
}
};
let exclude_paths = {
if let Some(exclude_paths) = store.get(TAURI_STORE_KEY_EXCLUDE_PATHS) {
match exclude_paths {
Value::Array(arr) => {
let mut vec = Vec::with_capacity(arr.len());
for v in arr {
match v {
Value::String(s) => vec.push(s),
other => panic!(
"Expected all elements of 'exclude_paths' to be strings, but found: {:?}",
other
),
}
}
vec
}
other => panic!(
"Expected 'exclude_paths' to be an array of strings in the file system config store, but got: {:?}",
other
),
}
} else {
store.set(
TAURI_STORE_KEY_EXCLUDE_PATHS,
default_config.exclude_paths.as_slice(),
);
default_config.exclude_paths
}
};
let file_types = {
if let Some(file_types) = store.get(TAURI_STORE_KEY_FILE_TYPES) {
match file_types {
Value::Array(arr) => {
let mut vec = Vec::with_capacity(arr.len());
for v in arr {
match v {
Value::String(s) => vec.push(s),
other => panic!(
"Expected all elements of 'file_types' to be strings, but found: {:?}",
other
),
}
}
vec
}
other => panic!(
"Expected 'file_types' to be an array of strings in the file system config store, but got: {:?}",
other
),
}
} else {
store.set(
TAURI_STORE_KEY_FILE_TYPES,
default_config.file_types.as_slice(),
);
default_config.file_types
}
};
let search_by = {
if let Some(search_by) = store.get(TAURI_STORE_KEY_SEARCH_BY) {
serde_json::from_value(search_by.clone()).unwrap_or_else(|e| {
panic!(
"Failed to deserialize 'search_by' from file system config store. Invalid JSON: {:?}, error: {}",
search_by, e
)
})
} else {
store.set(
TAURI_STORE_KEY_SEARCH_BY,
serde_json::to_value(default_config.search_by).unwrap(),
);
default_config.search_by
}
};
Self {
search_by,
search_paths,
exclude_paths,
file_types,
}
}
}
// Tauri commands for managing file system configuration
#[tauri::command]
pub async fn get_file_system_config() -> FileSearchConfig {
FileSearchConfig::get()
}
#[tauri::command]
pub async fn set_file_system_config(config: FileSearchConfig) -> Result<(), String> {
let tauri_app_handle = crate::GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
let store = tauri_app_handle
.store(TAURI_STORE_FILE_SYSTEM_CONFIG)
.map_err(|e| e.to_string())?;
store.set(TAURI_STORE_KEY_SEARCH_PATHS, config.search_paths);
store.set(TAURI_STORE_KEY_EXCLUDE_PATHS, config.exclude_paths);
store.set(TAURI_STORE_KEY_FILE_TYPES, config.file_types);
store.set(
TAURI_STORE_KEY_SEARCH_BY,
serde_json::to_value(config.search_by).unwrap(),
);
Ok(())
}
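As a quick illustration of how the two commands above fit together, the sketch below round-trips a configuration through the store. It assumes the global Tauri app handle has already been initialized, as it would be at runtime; the function itself is hypothetical.

// Hypothetical usage sketch: persist a config and read it back.
async fn round_trip_example() -> Result<(), String> {
    let mut config = FileSearchConfig::default();
    config.file_types = vec!["pdf".to_string(), "md".to_string()];

    set_file_system_config(config).await?;

    let loaded = get_file_system_config().await;
    assert_eq!(loaded.file_types, vec!["pdf".to_string(), "md".to_string()]);
    Ok(())
}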

View File

@@ -1,189 +0,0 @@
use super::super::config::FileSearchConfig;
use super::super::config::SearchBy;
use super::super::EXTENSION_ID;
use crate::common::{
document::{DataSourceReference, Document},
};
use crate::extension::OnOpened;
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::util::file::get_file_icon;
use futures::stream::Stream;
use futures::stream::StreamExt;
use std::os::fd::OwnedFd;
use std::path::Path;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Child;
use tokio::process::Command;
use tokio_stream::wrappers::LinesStream;
/// `mdfind` doesn't return scores, so we use this score for all documents.
const SCORE: f64 = 1.0;
pub(crate) async fn hits(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
let (mut iter, mut mdfind_child_process) =
execute_mdfind_query(&query_string, from, size, &config)?;
// Convert results to documents
let mut hits: Vec<(Document, f64)> = Vec::new();
while let Some(res_file_path) = iter.next().await {
let file_path =
res_file_path.map_err(|io_err| io_err.to_string())?;
let icon = get_file_icon(file_path.clone()).await;
let file_path_of_type_path = camino::Utf8Path::new(&file_path);
let r#where = file_path_of_type_path
.parent()
.unwrap_or_else(|| {
panic!(
"expect path [{}] to have a parent, but it does not",
file_path
);
})
.to_string();
let file_name = file_path_of_type_path.file_name().unwrap_or_else(|| {
panic!(
"expect path [{}] to have a file name, but it does not",
file_path
);
});
let on_opened = OnOpened::Document {
url: file_path.clone(),
};
let doc = Document {
id: file_path.clone(),
title: Some(file_name.to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(EXTENSION_ID.into()),
id: Some(EXTENSION_ID.into()),
icon: Some(String::from("font_Filesearch")),
}),
category: Some(r#where),
on_opened: Some(on_opened),
url: Some(file_path),
icon: Some(icon.to_string()),
..Default::default()
};
hits.push((doc, SCORE));
}
// Kill the mdfind process once we get the needed results to prevent zombie
// processes.
mdfind_child_process
.kill()
.await
.map_err(|e| format!("{:?}", e))?;
Ok(hits)
}
/// Return an array containing the `mdfind` command and its arguments.
fn build_mdfind_query(query_string: &str, config: &FileSearchConfig) -> Vec<String> {
let mut args = vec!["mdfind".to_string()];
match config.search_by {
SearchBy::Name => {
args.push(format!("kMDItemFSName == '*{}*'", query_string));
}
SearchBy::NameAndContents => {
args.push(format!(
"kMDItemFSName == '*{}*' || kMDItemTextContent == '{}'",
query_string, query_string
));
}
}
// Add search paths using -onlyin
for path in &config.search_paths {
if Path::new(path).exists() {
args.extend_from_slice(&["-onlyin".to_string(), path.to_string()]);
}
}
args
}
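For illustration, a name-only search over a single directory produces an argument vector like the one sketched below. The paths and query string are made up, and `-onlyin` is only appended when the path actually exists on disk.

// Hypothetical example of the generated argv; values are invented.
fn mdfind_args_example() -> Vec<String> {
    let config = FileSearchConfig {
        search_paths: vec!["/Users/me/Documents".to_string()],
        exclude_paths: Vec::new(),
        file_types: Vec::new(),
        search_by: SearchBy::Name,
    };
    // Assuming /Users/me/Documents exists, the result is:
    // ["mdfind", "kMDItemFSName == '*report*'", "-onlyin", "/Users/me/Documents"]
    build_mdfind_query("report", &config)
}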
/// Spawn the `mdfind` child process and return an async iterator over its output,
/// allowing us to collect the results asynchronously.
///
/// # Return value:
///
/// * impl Stream: an async iterator that will yield the matched files
/// * Child: the handle to the mdfind process; we need to kill it once we
///   collect all the results to avoid zombie processes.
fn execute_mdfind_query(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<(impl Stream<Item = std::io::Result<String>>, Child), String> {
let args = build_mdfind_query(query_string, &config);
let (rx, tx) = std::io::pipe().unwrap();
let rx_owned = OwnedFd::from(rx);
let async_rx = tokio::net::unix::pipe::Receiver::from_owned_fd(rx_owned).unwrap();
let buffered_rx = BufReader::new(async_rx);
let lines = LinesStream::new(buffered_rx.lines());
let child = Command::new(&args[0])
.args(&args[1..])
.stdout(tx)
.stderr(std::process::Stdio::null())
.spawn()
.map_err(|e| format!("Failed to spawn mdfind: {}", e))?;
let config_clone = config.clone();
let iter = lines
.filter(move |res_path| {
std::future::ready({
match res_path {
Ok(path) => !should_be_filtered_out(&config_clone, path),
Err(_) => {
// Don't filter out Err() values
true
}
}
})
})
.skip(from)
.take(size);
Ok((iter, child))
}
/// If `file_path` should be removed from the search results given the filter
/// conditions specified in `config`.
fn should_be_filtered_out(config: &FileSearchConfig, file_path: &str) -> bool {
let is_excluded = config
.exclude_paths
.iter()
.any(|exclude_path| file_path.starts_with(exclude_path));
if is_excluded {
return true;
}
let matches_file_type = if config.file_types.is_empty() {
true
} else {
let path_obj = camino::Utf8Path::new(&file_path);
if let Some(extension) = path_obj.extension() {
config
.file_types
.iter()
.any(|file_type| file_type == extension)
} else {
// `config.file_types` is not empty, then the search results
// should have extensions.
false
}
};
!matches_file_type
}
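To make the filter semantics concrete, the sketch below uses made-up paths: the exclusion list is checked first, then the extension allow-list is applied.

// Hypothetical values, purely for illustration.
fn filter_examples() {
    let config = FileSearchConfig {
        search_paths: vec!["/Users/me".to_string()],
        exclude_paths: vec!["/Users/me/target".to_string()],
        file_types: vec!["rs".to_string()],
        search_by: SearchBy::Name,
    };
    // Kept: not under an excluded prefix and has an allowed extension.
    assert!(!should_be_filtered_out(&config, "/Users/me/src/main.rs"));
    // Filtered: lives under an excluded prefix.
    assert!(should_be_filtered_out(&config, "/Users/me/target/debug/main.rs"));
    // Filtered: extension "txt" is not in the allow-list.
    assert!(should_be_filtered_out(&config, "/Users/me/notes.txt"));
}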

View File

@@ -1,10 +0,0 @@
#[cfg(target_os = "macos")]
mod macos;
#[cfg(target_os = "windows")]
mod windows;
// The `hits()` function is platform-specific; export the corresponding impl.
#[cfg(target_os = "macos")]
pub(crate) use macos::hits;
#[cfg(target_os = "windows")]
pub(crate) use windows::hits;

View File

@@ -1,630 +0,0 @@
//! # Credits
//!
//! https://github.com/IRONAGE-Park/rag-sample/blob/3f0ad8c8012026cd3a7e453d08f041609426cb91/src/native/windows.rs
//! is the starting point of this implementation.
use super::super::config::FileSearchConfig;
use super::super::config::SearchBy;
use super::super::EXTENSION_ID;
use crate::common::document::{DataSourceReference, Document};
use crate::extension::OnOpened;
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::util::file::get_file_icon;
use windows::{
core::{w, IUnknown, Interface, GUID, PWSTR},
Win32::System::{
Com::{CoCreateInstance, CLSCTX_INPROC_SERVER},
Ole::{OleInitialize, OleUninitialize},
Search::{
IAccessor, ICommand, ICommandText, IDBCreateCommand, IDBCreateSession, IDBInitialize,
IDataInitialize, IRowset, DBACCESSOR_ROWDATA, DBBINDING, DBMEMOWNER_CLIENTOWNED,
DBPARAMIO_NOTPARAM, DBPART_VALUE, DBTYPE_WSTR, DB_NULL_HCHAPTER, HACCESSOR,
MSDAINITIALIZE,
},
},
};
/// Owned version of `PWSTR` that holds the heap memory.
///
/// Use `as_pwstr()` to convert it to a raw pointer.
struct PwStrOwned(Vec<u16>);
impl PwStrOwned {
/// # SAFETY
///
/// The returned `PWSTR` is essentially a raw pointer; it is only valid within the
/// lifetime of `PwStrOwned`.
unsafe fn as_pwstr(&mut self) -> PWSTR {
let raw_ptr = self.0.as_mut_ptr();
PWSTR::from_raw(raw_ptr)
}
}
/// Construct `PwStrOwned` from any `str`.
impl<S: AsRef<str> + ?Sized> From<&S> for PwStrOwned {
fn from(value: &S) -> Self {
let mut utf16_bytes = value.as_ref().encode_utf16().collect::<Vec<u16>>();
utf16_bytes.push(0); // the trailing NULL
PwStrOwned(utf16_bytes)
}
}
/// Helper function to construct the Windows Search SQL.
///
/// Paging is not natively supported by Windows Search SQL; it only supports `size`
/// via the `TOP` keyword ("SELECT TOP {n} {columns}"). The SQL returned by this
/// function will have `{n}` set to `from + size`, and we then implement paging
/// manually.
fn query_sql(query_string: &str, from: usize, size: usize, config: &FileSearchConfig) -> String {
let top_n = from
.checked_add(size)
.expect("[from + size] cannot fit into an [usize]");
// System.ItemUrl is a column that contains the file path
// example: "file:C:/Users/desktop.ini"
//
// System.Search.Rank is the relevance score
let mut sql = format!(
"SELECT TOP {} System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE",
top_n
);
// Use debug print to escape the newline character, which cannot be handled by Windows Search.
let query_string_debug_print = format!("{:?}", query_string);
// Debug print will be double quoted, we need to trim them.
let query_string_debug_print_len = query_string_debug_print.len();
let query_string = &query_string_debug_print[1..(query_string_debug_print_len - 1)];
let search_by_predicate = match config.search_by {
SearchBy::Name => {
// `contains(System.FileName, '{query_string}')` would be faster
// because it uses an inverted index, but that's not what we want
// due to the limitations of tokenization. For example, suppose "Coco AI.rs"
// is tokenized to `["Coco", "AI", "rs"]`; then if users search
// for `Co`, this file won't be returned because the term `Co` does not
// exist in the index.
//
// So we use wildcard instead even though it is slower.
format!("(System.FileName LIKE '%{query_string}%')")
}
SearchBy::NameAndContents => {
// Windows File Search does not support searching by file content.
//
// `CONTAINS('query_string')` would search all columns for `query_string`,
// this is the closest solution we have.
format!("((System.FileName LIKE '%{query_string}%') OR CONTAINS('{query_string}'))")
}
};
let search_paths_predicate: Option<String> = {
if config.search_paths.is_empty() {
None
} else {
let mut output = String::from("(");
for (idx, search_path) in config.search_paths.iter().enumerate() {
if idx != 0 {
output.push_str(" OR ");
}
output.push_str("SCOPE = 'file:");
output.push_str(&search_path);
output.push('\'');
}
output.push(')');
Some(output)
}
};
let exclude_paths_predicate: Option<String> = {
if config.exclude_paths.is_empty() {
None
} else {
let mut output = String::from("(");
for (idx, exclude_path) in config.exclude_paths.iter().enumerate() {
if idx != 0 {
output.push_str(" AND ");
}
output.push_str("(NOT SCOPE = 'file:");
output.push_str(&exclude_path);
output.push('\'');
output.push(')');
}
output.push(')');
Some(output)
}
};
let file_types_predicate: Option<String> = {
if config.file_types.is_empty() {
None
} else {
let mut output = String::from("(");
for (idx, file_type) in config.file_types.iter().enumerate() {
if idx != 0 {
output.push_str(" OR ");
}
// NOTE that this column contains a starting dot
output.push_str("System.FileExtension = '.");
output.push_str(&file_type);
output.push('\'');
}
output.push(')');
Some(output)
}
};
sql.push(' ');
sql.push_str(search_by_predicate.as_str());
if let Some(search_paths_predicate) = search_paths_predicate {
sql.push_str(" AND ");
sql.push_str(search_paths_predicate.as_str());
}
if let Some(exclude_paths_predicate) = exclude_paths_predicate {
sql.push_str(" AND ");
sql.push_str(exclude_paths_predicate.as_str());
}
if let Some(file_types_predicate) = file_types_predicate {
sql.push_str(" AND ");
sql.push_str(file_types_predicate.as_str());
}
sql
}
/// Default GUID for Search.CollatorDSO.1
const DBGUID_DEFAULT: GUID = GUID {
data1: 0xc8b521fb,
data2: 0x5cf3,
data3: 0x11ce,
data4: [0xad, 0xe5, 0x00, 0xaa, 0x00, 0x44, 0x77, 0x3d],
};
unsafe fn create_accessor_handle(accessor: &IAccessor, index: usize) -> Result<HACCESSOR, String> {
let bindings = DBBINDING {
iOrdinal: index,
obValue: 0,
obStatus: 0,
obLength: 0,
dwPart: DBPART_VALUE.0 as u32,
dwMemOwner: DBMEMOWNER_CLIENTOWNED.0 as u32,
eParamIO: DBPARAMIO_NOTPARAM.0 as u32,
cbMaxLen: 512,
dwFlags: 0,
wType: DBTYPE_WSTR.0 as u16,
bPrecision: 0,
bScale: 0,
..Default::default()
};
let mut status = 0;
let mut accessor_handle = HACCESSOR::default();
unsafe {
accessor
.CreateAccessor(
DBACCESSOR_ROWDATA.0 as u32,
1,
&bindings,
0,
&mut accessor_handle,
Some(&mut status),
)
.map_err(|e| e.to_string())?;
}
Ok(accessor_handle)
}
fn create_db_initialize() -> Result<IDBInitialize, String> {
unsafe {
let data_init: IDataInitialize =
CoCreateInstance(&MSDAINITIALIZE, None, CLSCTX_INPROC_SERVER)
.map_err(|e| e.to_string())?;
let mut unknown: Option<IUnknown> = None;
data_init
.GetDataSource(
None,
CLSCTX_INPROC_SERVER.0,
w!("provider=Search.CollatorDSO.1;EXTENDED PROPERTIES=\"Application=Windows\""),
&IDBInitialize::IID,
&mut unknown as *mut _ as *mut _,
)
.map_err(|e| e.to_string())?;
Ok(unknown.unwrap().cast().map_err(|e| e.to_string())?)
}
}
fn create_command(db_init: IDBInitialize) -> Result<ICommandText, String> {
unsafe {
let db_create_session: IDBCreateSession = db_init.cast().map_err(|e| e.to_string())?;
let session: IUnknown = db_create_session
.CreateSession(None, &IUnknown::IID)
.map_err(|e| e.to_string())?;
let db_create_command: IDBCreateCommand = session.cast().map_err(|e| e.to_string())?;
Ok(db_create_command
.CreateCommand(None, &ICommand::IID)
.map_err(|e| e.to_string())?
.cast()
.map_err(|e| e.to_string())?)
}
}
fn execute_windows_search_sql(sql_query: &str) -> Result<Vec<(String, String)>, String> {
unsafe {
let mut pwstr_owned_sql = PwStrOwned::from(sql_query);
// SAFETY: pwstr_owned_sql will live for the whole lifetime of this function.
let sql_query = pwstr_owned_sql.as_pwstr();
let db_init = create_db_initialize()?;
db_init.Initialize().map_err(|e| e.to_string())?;
let command = create_command(db_init)?;
// Set the command text
command
.SetCommandText(&DBGUID_DEFAULT, sql_query)
.map_err(|e| e.to_string())?;
// Execute the command
let mut rowset: Option<IRowset> = None;
command
.Execute(
None,
&IRowset::IID,
None,
None,
Some(&mut rowset as *mut _ as *mut _),
)
.map_err(|e| e.to_string())?;
let rowset = rowset.ok_or_else(|| {
format!(
"No rowset returned for query: {}",
// SAFETY: the raw pointer is not dangling
sql_query
.to_string()
.expect("the conversion should work as `sql_query` was created from a String",)
)
})?;
let accessor: IAccessor = rowset
.cast()
.map_err(|e| format!("Failed to cast to IAccessor: {}", e.to_string()))?;
let mut output = Vec::new();
let mut count = 0;
loop {
let mut rows_fetched = 0;
let mut row_handles = [std::ptr::null_mut(); 1];
let result = rowset.GetNextRows(
DB_NULL_HCHAPTER as usize,
0,
&mut rows_fetched,
&mut row_handles,
);
if result.is_err() {
break;
}
if rows_fetched == 0 {
break;
}
let mut data = Vec::new();
for i in 0..2 {
let mut item_name = [0u16; 512];
let accessor_handle = create_accessor_handle(&accessor, i + 1)?;
rowset
.GetData(
*row_handles[0],
accessor_handle,
item_name.as_mut_ptr() as *mut _,
)
.map_err(|e| {
format!(
"Failed to get data at count {}, index {}: {}",
count,
i,
e.to_string()
)
})?;
let name = String::from_utf16_lossy(&item_name);
// Remove null characters
data.push(name.trim_end_matches('\u{0000}').to_string());
accessor
.ReleaseAccessor(accessor_handle, None)
.map_err(|e| {
format!(
"Failed to release accessor at count {}, index {}: {}",
count,
i,
e.to_string()
)
})?;
}
output.push((data[0].clone(), data[1].clone()));
count += 1;
rowset
.ReleaseRows(
1,
row_handles[0],
std::ptr::null_mut(),
std::ptr::null_mut(),
std::ptr::null_mut(),
)
.map_err(|e| {
format!(
"Failed to release rows at count {}: {}",
count,
e.to_string()
)
})?;
}
Ok(output)
}
}
pub(crate) async fn hits(
query_string: &str,
from: usize,
size: usize,
config: &FileSearchConfig,
) -> Result<Vec<(Document, f64)>, String> {
let sql = query_sql(query_string, from, size, config);
unsafe { OleInitialize(None).map_err(|e| e.to_string())? };
let result = execute_windows_search_sql(&sql)?;
unsafe { OleUninitialize() };
// .take(size) is not needed as `result` will contain `from+size` files at most
let result_with_paging = result.into_iter().skip(from);
// result_with_paging won't contain more than `size` entries
let mut hits = Vec::with_capacity(size);
const ITEM_URL_PREFIX: &str = "file:";
const ITEM_URL_PREFIX_LEN: usize = ITEM_URL_PREFIX.len();
for (item_url, score_str) in result_with_paging {
// The path returned from Windows Search contains a prefix that we need to trim.
//
// "file:C:/Users/desktop.ini" => "C:/Users/desktop.ini"
let file_path = &item_url[ITEM_URL_PREFIX_LEN..];
let icon = get_file_icon(file_path.to_string()).await;
let file_path_of_type_path = camino::Utf8Path::new(&file_path);
let r#where = file_path_of_type_path
.parent()
.unwrap_or_else(|| {
panic!(
"expect path [{}] to have a parent, but it does not",
file_path
);
})
.to_string();
let file_name = file_path_of_type_path.file_name().unwrap_or_else(|| {
panic!(
"expect path [{}] to have a file name, but it does not",
file_path
);
});
let on_opened = OnOpened::Document {
url: file_path.to_string(),
};
let doc = Document {
id: file_path.to_string(),
title: Some(file_name.to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(EXTENSION_ID.into()),
id: Some(EXTENSION_ID.into()),
icon: Some(String::from("font_Filesearch")),
}),
category: Some(r#where),
on_opened: Some(on_opened),
url: Some(file_path.into()),
icon: Some(icon.to_string()),
..Default::default()
};
let score: f64 = score_str.parse().expect(
"System.Search.Rank should be in range [0, 1000], which should be valid for [f64]",
);
hits.push((doc, score));
}
Ok(hits)
}
// Skip these tests in our CI; they fail with the following error:
// "SQL is invalid: "0x80041820""
//
// I have no idea about the underlying root cause
#[cfg(all(test, not(ci)))]
mod test {
use super::*;
/// Helper function for ensuring `sql` is valid SQL by actually executing it.
fn ensure_it_is_valid_sql(sql: &str) {
unsafe { OleInitialize(None).unwrap() };
execute_windows_search_sql(&sql).expect("SQL is invalid");
unsafe { OleUninitialize() };
}
#[test]
fn test_query_sql_empty_config_search_by_name() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("coco", 0, 10, &config);
assert_eq!(
sql,
"SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%coco%')"
);
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_empty_config_search_by_name_and_content() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::NameAndContents,
};
let sql = query_sql("coco", 0, 10, &config);
assert_eq!(sql, "SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE ((System.FileName LIKE '%coco%') OR CONTAINS('coco'))");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_search_paths() {
let config = FileSearchConfig {
search_paths: vec!["C:/Users/".into()],
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("coco", 0, 10, &config);
assert_eq!(sql, "SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%coco%') AND (SCOPE = 'file:C:/Users/')");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_multiple_search_paths() {
let config = FileSearchConfig {
search_paths: vec![
"C:/Users/".into(),
"D:/Projects/".into(),
"E:/Documents/".into(),
],
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("test", 0, 5, &config);
assert_eq!(sql, "SELECT TOP 5 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%test%') AND (SCOPE = 'file:C:/Users/' OR SCOPE = 'file:D:/Projects/' OR SCOPE = 'file:E:/Documents/')");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_exclude_paths() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: vec!["C:/Windows/".into()],
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("file", 0, 20, &config);
assert_eq!(sql, "SELECT TOP 20 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%file%') AND ((NOT SCOPE = 'file:C:/Windows/'))");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_multiple_exclude_paths() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: vec!["C:/Windows/".into(), "C:/System/".into(), "C:/Temp/".into()],
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("data", 5, 15, &config);
assert_eq!(sql, "SELECT TOP 20 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%data%') AND ((NOT SCOPE = 'file:C:/Windows/') AND (NOT SCOPE = 'file:C:/System/') AND (NOT SCOPE = 'file:C:/Temp/'))");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_file_types() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: vec!["txt".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("readme", 0, 10, &config);
assert_eq!(sql, "SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%readme%') AND (System.FileExtension = '.txt')");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_multiple_file_types() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: vec!["rs".into(), "toml".into(), "md".into(), "json".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("config", 0, 50, &config);
assert_eq!(sql, "SELECT TOP 50 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%config%') AND (System.FileExtension = '.rs' OR System.FileExtension = '.toml' OR System.FileExtension = '.md' OR System.FileExtension = '.json')");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_all_fields_combined() {
let config = FileSearchConfig {
search_paths: vec!["C:/Projects/".into(), "D:/Code/".into()],
exclude_paths: vec!["C:/Projects/temp/".into()],
file_types: vec!["rs".into(), "ts".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("main", 10, 25, &config);
assert_eq!(sql, "SELECT TOP 35 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%main%') AND (SCOPE = 'file:C:/Projects/' OR SCOPE = 'file:D:/Code/') AND ((NOT SCOPE = 'file:C:/Projects/temp/')) AND (System.FileExtension = '.rs' OR System.FileExtension = '.ts')");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_with_special_characters() {
let config = FileSearchConfig {
search_paths: vec!["C:/Users/John Doe/".into()],
exclude_paths: Vec::new(),
file_types: vec!["c++".into()],
search_by: SearchBy::Name,
};
let sql = query_sql("hello-world", 0, 10, &config);
assert_eq!(sql, "SELECT TOP 10 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%hello-world%') AND (SCOPE = 'file:C:/Users/John Doe/') AND (System.FileExtension = '.c++')");
ensure_it_is_valid_sql(&sql);
}
#[test]
fn test_query_sql_edge_case_large_offset() {
let config = FileSearchConfig {
search_paths: Vec::new(),
exclude_paths: Vec::new(),
file_types: Vec::new(),
search_by: SearchBy::Name,
};
let sql = query_sql("test", 100, 50, &config);
assert_eq!(
sql,
"SELECT TOP 150 System.ItemUrl, System.Search.Rank FROM SystemIndex WHERE (System.FileName LIKE '%test%')"
);
ensure_it_is_valid_sql(&sql);
}
}

View File

@@ -1,90 +0,0 @@
pub(crate) mod config;
pub(crate) mod implementation;
use super::super::LOCAL_QUERY_SOURCE_TYPE;
use crate::common::{
error::SearchError,
search::{QueryResponse, QuerySource, SearchQuery},
traits::SearchSource,
};
use async_trait::async_trait;
use config::FileSearchConfig;
use hostname;
pub(crate) const EXTENSION_ID: &str = "File Search";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "File Search",
"name": "File Search",
"platforms": ["macos", "windows"],
"description": "Search files on your system",
"icon": "font_Filesearch",
"type": "extension"
}
"#;
pub struct FileSearchExtensionSearchSource;
#[async_trait]
impl SearchSource for FileSearchExtensionSearchSource {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(EXTENSION_ID.into())
.to_string_lossy()
.into(),
id: EXTENSION_ID.into(),
}
}
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError> {
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
let from = usize::try_from(query.from).expect("from too big");
let size = usize::try_from(query.size).expect("size too big");
let query_string = query_string.trim();
if query_string.is_empty() {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
// Get configuration from tauri store
let config = FileSearchConfig::get();
// If the search paths are empty, then the hits should be empty.
//
// Without this check, empty search paths would result in an mdfind invocation with no `-onlyin`
// option, which would in turn query the whole disk volume.
if config.search_paths.is_empty() {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
}
// Execute the search using the platform-specific implementation
let query_source = self.get_type();
let hits = implementation::hits(&query_string, from, size, &config).await.map_err(SearchError::InternalError)?;
let total_hits = hits.len();
Ok(QueryResponse {
source: query_source,
hits,
total_hits,
})
}
}

View File

@@ -1,554 +0,0 @@
//! Built-in extensions and related stuff.
pub mod ai_overview;
pub mod application;
pub mod calculator;
#[cfg(any(target_os = "macos", target_os = "windows"))]
pub mod file_search;
pub mod pizza_engine_runtime;
pub mod quick_ai_access;
use super::Extension;
use crate::extension::built_in::application::{set_apps_hotkey, unset_apps_hotkey};
use crate::extension::{
alter_extension_json_file, ExtensionBundleIdBorrowed, PLUGIN_JSON_FILE_NAME,
};
use crate::{SearchSourceRegistry, GLOBAL_TAURI_APP_HANDLE};
use anyhow::Context;
use std::path::{Path, PathBuf};
use std::sync::LazyLock;
use tauri::{AppHandle, Manager, Runtime};
pub(crate) static BUILT_IN_EXTENSION_DIRECTORY: LazyLock<PathBuf> = LazyLock::new(|| {
let mut resource_dir = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set")
.path()
.app_data_dir()
.expect(
"User home directory not found, which should be impossible on desktop environments",
);
resource_dir.push("built_in_extensions");
resource_dir
});
/// Helper function to load the built-in extension specified by `extension_id`, used
/// in `list_built_in_extensions()`.
///
/// For built-in extensions, users are only allowed to edit these fields:
///
/// 1. alias (if this extension supports alias)
/// 2. hotkey (if this extension supports hotkey)
/// 3. enabled
///
/// If
///
/// 1. the above fields have invalid values, or
/// 2. other fields are modified,
///
/// we ignore the changes and reset them to their default values.
async fn load_built_in_extension(
built_in_extensions_dir: &Path,
extension_id: &str,
default_plugin_json_file: &str,
) -> Result<Extension, String> {
let mut extension_dir = built_in_extensions_dir.join(extension_id);
let mut default_plugin_json = serde_json::from_str::<Extension>(&default_plugin_json_file).unwrap_or_else( |e| {
panic!("the default extension {} file of built-in extension [{}] cannot be parsed as a valid [struct Extension], error [{}]", PLUGIN_JSON_FILE_NAME, extension_id, e);
});
if !extension_dir.try_exists().map_err(|e| e.to_string())? {
tokio::fs::create_dir_all(extension_dir.as_path())
.await
.map_err(|e| e.to_string())?;
}
let plugin_json_file_path = {
extension_dir.push(PLUGIN_JSON_FILE_NAME);
extension_dir
};
// If the JSON file does not exist, create a file with the default template and return.
if !plugin_json_file_path
.try_exists()
.map_err(|e| e.to_string())?
{
tokio::fs::write(plugin_json_file_path, default_plugin_json_file)
.await
.map_err(|e| e.to_string())?;
return Ok(default_plugin_json);
}
let plugin_json_file_content = tokio::fs::read_to_string(plugin_json_file_path.as_path())
.await
.map_err(|e| e.to_string())?;
let res_plugin_json = serde_json::from_str::<Extension>(&plugin_json_file_content);
let Ok(plugin_json) = res_plugin_json else {
log::warn!("user invalidated built-in extension [{}] file, overwriting it with the default template", extension_id);
// If the JSON file cannot be parsed as `struct Extension`, overwrite it with the default template and return.
tokio::fs::write(plugin_json_file_path, default_plugin_json_file)
.await
.map_err(|e| e.to_string())?;
return Ok(default_plugin_json);
};
// Users are only allowed to edit the below fields
// 1. alias (if this extension supports alias)
// 2. hotkey (if this extension supports hotkey)
// 3. enabled
// so we ignore all other fields.
let alias = if default_plugin_json.supports_alias_hotkey() {
plugin_json.alias.clone()
} else {
None
};
let hotkey = if default_plugin_json.supports_alias_hotkey() {
plugin_json.hotkey.clone()
} else {
None
};
let enabled = plugin_json.enabled;
default_plugin_json.alias = alias;
default_plugin_json.hotkey = hotkey;
default_plugin_json.enabled = enabled;
let final_plugin_json_file_content = serde_json::to_string_pretty(&default_plugin_json)
.expect("failed to serialize `struct Extension`");
tokio::fs::write(plugin_json_file_path, final_plugin_json_file_content)
.await
.map_err(|e| e.to_string())?;
Ok(default_plugin_json)
}
/// Return the built-in extension list.
///
/// Will create extension files when they are not found.
///
/// Users may put extension files in the built-in extension directory, but
/// we do not care and will ignore them.
///
/// We only read alias/hotkey/enabled from the JSON file, and we have ensured that if
/// alias/hotkey is not supported, it will be `None`. Beyond that, no further
/// validation is needed because nothing else can go wrong.
pub(crate) async fn list_built_in_extensions() -> Result<Vec<Extension>, String> {
let dir = BUILT_IN_EXTENSION_DIRECTORY.as_path();
let mut built_in_extensions = Vec::new();
built_in_extensions.push(
load_built_in_extension(
dir,
application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME,
application::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
dir,
calculator::DATA_SOURCE_ID,
calculator::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
dir,
ai_overview::EXTENSION_ID,
ai_overview::PLUGIN_JSON_FILE,
)
.await?,
);
built_in_extensions.push(
load_built_in_extension(
dir,
quick_ai_access::EXTENSION_ID,
quick_ai_access::PLUGIN_JSON_FILE,
)
.await?,
);
cfg_if::cfg_if! {
if #[cfg(any(target_os = "macos", target_os = "windows"))] {
built_in_extensions.push(
load_built_in_extension(
dir,
file_search::EXTENSION_ID,
file_search::PLUGIN_JSON_FILE,
)
.await?,
);
}
}
Ok(built_in_extensions)
}
pub(super) async fn init_built_in_extension<R: Runtime>(
tauri_app_handle: &AppHandle<R>,
extension: &Extension,
search_source_registry: &SearchSourceRegistry,
) -> Result<(), String> {
log::trace!("initializing built-in extensions [{}]", extension.id);
if extension.id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
search_source_registry
.register_source(application::ApplicationSearchSource)
.await;
set_apps_hotkey(&tauri_app_handle)?;
log::debug!("built-in extension [{}] initialized", extension.id);
}
if extension.id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
search_source_registry
.register_source(calculator_search)
.await;
log::debug!("built-in extension [{}] initialized", extension.id);
}
cfg_if::cfg_if! {
if #[cfg(any(target_os = "macos", target_os = "windows"))] {
if extension.id == file_search::EXTENSION_ID {
let file_system_search = file_search::FileSearchExtensionSearchSource;
search_source_registry
.register_source(file_system_search)
.await;
log::debug!("built-in extension [{}] initialized", extension.id);
}
}
}
Ok(())
}
pub(crate) fn is_extension_built_in(bundle_id: &ExtensionBundleIdBorrowed<'_>) -> bool {
bundle_id.developer.is_none()
}
pub(crate) async fn enable_built_in_extension(
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
let update_extension = |extension: &mut Extension| -> Result<(), String> {
extension.enabled = true;
Ok(())
};
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_none()
{
search_source_registry_tauri_state
.register_source(application::ApplicationSearchSource)
.await;
set_apps_hotkey(tauri_app_handle)?;
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
// Check if this is an application
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_some()
{
let app_path = bundle_id.sub_extension_id.expect("just checked it is Some");
application::enable_app_search(tauri_app_handle, app_path)?;
return Ok(());
}
if bundle_id.extension_id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
search_source_registry_tauri_state
.register_source(calculator_search)
.await;
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == quick_ai_access::EXTENSION_ID {
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == ai_overview::EXTENSION_ID {
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
cfg_if::cfg_if! {
if #[cfg(any(target_os = "macos", target_os = "windows"))] {
if bundle_id.extension_id == file_search::EXTENSION_ID {
let file_system_search = file_search::FileSearchExtensionSearchSource;
search_source_registry_tauri_state
.register_source(file_system_search)
.await;
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
}
}
Ok(())
}
pub(crate) async fn disable_built_in_extension(
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
let update_extension = |extension: &mut Extension| -> Result<(), String> {
extension.enabled = false;
Ok(())
};
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_none()
{
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
unset_apps_hotkey(tauri_app_handle)?;
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
// Check if this is an application
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_some()
{
let app_path = bundle_id.sub_extension_id.expect("just checked it is Some");
application::disable_app_search(tauri_app_handle, app_path)?;
return Ok(());
}
if bundle_id.extension_id == calculator::DATA_SOURCE_ID {
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == quick_ai_access::EXTENSION_ID {
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
if bundle_id.extension_id == ai_overview::EXTENSION_ID {
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
cfg_if::cfg_if! {
if #[cfg(any(target_os = "macos", target_os = "windows"))] {
if bundle_id.extension_id == file_search::EXTENSION_ID {
search_source_registry_tauri_state
.remove_source(bundle_id.extension_id)
.await;
alter_extension_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id,
update_extension,
)?;
return Ok(());
}
}
}
Ok(())
}
pub(crate) fn set_built_in_extension_alias(bundle_id: &ExtensionBundleIdBorrowed<'_>, alias: &str) {
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
application::set_app_alias(tauri_app_handle, app_path, alias);
}
}
}
pub(crate) fn register_built_in_extension_hotkey(
bundle_id: &ExtensionBundleIdBorrowed<'_>,
hotkey: &str,
) -> Result<(), String> {
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
application::register_app_hotkey(&tauri_app_handle, app_path, hotkey)?;
}
}
Ok(())
}
pub(crate) fn unregister_built_in_extension_hotkey(
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
application::unregister_app_hotkey(&tauri_app_handle, app_path)?;
}
}
Ok(())
}
fn split_extension_id(extension_id: &str) -> (&str, Option<&str>) {
match extension_id.find('.') {
Some(idx) => (&extension_id[..idx], Some(&extension_id[idx + 1..])),
None => (extension_id, None),
}
}
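A tiny illustration of how the ID is split on the first dot; the IDs below are made up.

// Hypothetical IDs for illustration only.
fn split_examples() {
    assert_eq!(split_extension_id("Calculator"), ("Calculator", None));
    assert_eq!(
        split_extension_id("Applications.com.apple.Safari"),
        ("Applications", Some("com.apple.Safari"))
    );
}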
fn load_extension_from_json_file(
extension_directory: &Path,
extension_id: &str,
) -> Result<Extension, String> {
let (parent_extension_id, _opt_sub_extension_id) = split_extension_id(extension_id);
let json_file_path = {
let mut extension_directory_path = extension_directory.join(parent_extension_id);
extension_directory_path.push(PLUGIN_JSON_FILE_NAME);
extension_directory_path
};
let mut extension = serde_json::from_reader::<_, Extension>(
std::fs::File::open(&json_file_path)
.with_context(|| {
format!(
"the [{}] file for extension [{}] is missing or broken",
PLUGIN_JSON_FILE_NAME, parent_extension_id
)
})
.map_err(|e| e.to_string())?,
)
.map_err(|e| e.to_string())?;
super::canonicalize_relative_icon_path(extension_directory, &mut extension)?;
Ok(extension)
}
pub(crate) async fn is_built_in_extension_enabled(
bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<bool, String> {
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
&& bundle_id.sub_extension_id.is_none()
{
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
// Check if this is an application
if bundle_id.extension_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
if let Some(app_path) = bundle_id.sub_extension_id {
return Ok(application::is_app_search_enabled(app_path));
}
}
if bundle_id.extension_id == calculator::DATA_SOURCE_ID {
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
if bundle_id.extension_id == quick_ai_access::EXTENSION_ID {
let extension = load_extension_from_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id.extension_id,
)?;
return Ok(extension.enabled);
}
if bundle_id.extension_id == ai_overview::EXTENSION_ID {
let extension = load_extension_from_json_file(
&BUILT_IN_EXTENSION_DIRECTORY.as_path(),
bundle_id.extension_id,
)?;
return Ok(extension.enabled);
}
cfg_if::cfg_if! {
if #[cfg(any(target_os = "macos", target_os = "windows"))] {
if bundle_id.extension_id == file_search::EXTENSION_ID
&& bundle_id.sub_extension_id.is_none()
{
return Ok(search_source_registry_tauri_state
.get_source(bundle_id.extension_id)
.await
.is_some());
}
}
}
unreachable!("extension [{:?}] is not a built-in extension", bundle_id)
}

View File

@@ -1,76 +0,0 @@
//! We use Pizza Engine to index applications and local files. The engine will be
//! run in the thread/runtime defined in this file.
//!
//! # Why such a thread/runtime is needed
//!
//! Generally, the Tokio async runtime requires the async tasks spawned onto it to be
//! `Send`, but the async tasks created by Pizza Engine are not,
//! which forces us to create a dedicated thread/runtime to execute them.
use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::OnceLock;
pub(crate) trait SearchSourceState {
#[cfg_attr(not(feature = "use_pizza_engine"), allow(unused))]
fn as_mut_any(&mut self) -> &mut dyn Any;
}
#[async_trait::async_trait(?Send)]
pub(crate) trait Task: Send + Sync {
fn search_source_id(&self) -> &'static str;
async fn exec(&mut self, state: &mut Option<Box<dyn SearchSourceState>>);
}
pub(crate) static RUNTIME_TX: OnceLock<tokio::sync::mpsc::UnboundedSender<Box<dyn Task>>> =
OnceLock::new();
/// This function blocks until the runtime thread is ready to accept tasks.
pub(crate) async fn start_pizza_engine_runtime() {
const THREAD_NAME: &str = "Pizza engine runtime thread";
log::trace!("starting Pizza engine runtime");
let (engine_start_signal_tx, engine_start_signal_rx) = tokio::sync::oneshot::channel();
std::thread::Builder::new()
.name(THREAD_NAME.into())
.spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
let main = async {
let mut states: HashMap<String, Option<Box<dyn SearchSourceState>>> =
HashMap::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
RUNTIME_TX.set(tx).unwrap();
engine_start_signal_tx
.send(())
.expect("engine_start_signal_rx dropped");
while let Some(mut task) = rx.recv().await {
let opt_search_source_state = match states.entry(task.search_source_id().into())
{
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => v.insert(None),
};
task.exec(opt_search_source_state).await;
}
};
rt.block_on(main);
})
.unwrap_or_else(|e| {
panic!(
"failed to start thread [{}] due to error [{}]",
THREAD_NAME, e
);
});
engine_start_signal_rx
.await
.expect("engine_start_signal_tx dropped, the runtime thread could be dead");
log::trace!("Pizza engine runtime started");
}
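To show how the pieces above fit together, here is a minimal, hypothetical task submitted through `RUNTIME_TX`. The `IndexTask` type and the submitting function are invented for the example; only `Task`, `SearchSourceState`, and `RUNTIME_TX` come from this module.

// Hypothetical task type; non-Send futures can be awaited inside `exec`
// because it runs on the dedicated runtime thread.
struct IndexTask;

#[async_trait::async_trait(?Send)]
impl Task for IndexTask {
    fn search_source_id(&self) -> &'static str {
        "example_source"
    }

    async fn exec(&mut self, _state: &mut Option<Box<dyn SearchSourceState>>) {
        // e.g., drive a Pizza Engine indexing future here
    }
}

fn submit_example_task() {
    // The sender is only available after the runtime thread has started.
    if let Some(tx) = RUNTIME_TX.get() {
        let _ = tx.send(Box::new(IndexTask));
    }
}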

View File

@@ -1,12 +0,0 @@
pub(super) const EXTENSION_ID: &str = "QuickAIAccess";
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "QuickAIAccess",
"name": "Quick AI Access",
"description": "...",
"icon": "font_a-QuickAIAccess",
"type": "ai_extension",
"enabled": true
}
"#;

View File

@@ -1,760 +0,0 @@
pub(crate) mod built_in;
pub(crate) mod third_party;
use crate::common::document::OnOpened;
use crate::{common::register::SearchSourceRegistry, GLOBAL_TAURI_APP_HANDLE};
use anyhow::Context;
use borrowme::{Borrow, ToOwned};
use derive_more::Display;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value as Json;
use std::collections::HashSet;
use std::path::Path;
use tauri::Manager;
use third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE;
use crate::util::platform::Platform;
pub const LOCAL_QUERY_SOURCE_TYPE: &str = "local";
const PLUGIN_JSON_FILE_NAME: &str = "plugin.json";
const ASSETS_DIRECTORY_FILE_NAME: &str = "assets";
fn default_true() -> bool {
true
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Extension {
/// Extension ID.
///
/// The ID doesn't uniquely identify an extension; its bundle ID (ID & developer) does.
id: String,
/// Extension name.
name: String,
/// ID of the developer.
///
/// * For built-in extensions, this will always be None.
/// * For third-party first-layer extensions, the on-disk plugin.json file
///    won't contain this field, but we set it for them after reading them into memory.
/// * For third-party sub extensions, this field will be None.
developer: Option<String>,
/// Platforms supported by this extension.
///
/// If `None`, then this extension can be used on all the platforms.
#[serde(skip_serializing_if = "Option::is_none")]
platforms: Option<HashSet<Platform>>,
/// Extension description.
description: String,
/// Specify the icon for this extension.
///
/// For the `plugin.json` file, this field can be specified in multiple ways:
///
/// 1. It can be a path to the icon file, the path can be
///
/// * relative (relative to the "assets" directory)
/// * absolute
/// 2. It can be a font class code, e.g., 'font_coco', if you want to use
/// Coco's built-in icons.
///
/// In cases where your icon file is named similarly to a font class code, Coco
/// will treat it as an icon file if it exists, i.e., if file `<extension>/assets/font_coco`
/// exists, then Coco will use this file rather than the built-in 'font_coco' icon.
///
/// For the `struct Extension` loaded into memory, this field should be:
///
/// 1. An absolute path
/// 2. A font code
icon: String,
r#type: ExtensionType,
/// If this is a Command extension, then `action` defines the operation to execute
/// when it is triggered.
#[serde(skip_serializing_if = "Option::is_none")]
action: Option<CommandAction>,
/// The link to open if this is a Quicklink extension.
#[serde(skip_serializing_if = "Option::is_none")]
quicklink: Option<Quicklink>,
// If this extension is of type Group or Extension, then it behaves like a
// directory, i.e., it could contain sub items.
commands: Option<Vec<Extension>>,
scripts: Option<Vec<Extension>>,
quicklinks: Option<Vec<Extension>>,
/// The alias of the extension.
///
/// Extensions of type Group and Extension cannot have an alias.
#[serde(skip_serializing_if = "Option::is_none")]
alias: Option<String>,
/// The hotkey of the extension.
///
/// Extensions of type Group and Extension cannot have a hotkey.
#[serde(skip_serializing_if = "Option::is_none")]
hotkey: Option<String>,
/// Whether this extension is enabled.
#[serde(default = "default_true")]
enabled: bool,
/// Extension settings
#[serde(skip_serializing_if = "Option::is_none")]
settings: Option<Json>,
// We do not care about these fields; accept them regardless of their values.
screenshots: Option<Json>,
url: Option<Json>,
version: Option<Json>,
}
/// Bundle ID uniquely identifies an extension.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub(crate) struct ExtensionBundleId {
developer: Option<String>,
extension_id: String,
sub_extension_id: Option<String>,
}
impl Borrow for ExtensionBundleId {
type Target<'a> = ExtensionBundleIdBorrowed<'a>;
fn borrow(&self) -> Self::Target<'_> {
ExtensionBundleIdBorrowed {
developer: self.developer.as_deref(),
extension_id: &self.extension_id,
sub_extension_id: self.sub_extension_id.as_deref(),
}
}
}
/// Reference version of `ExtensionBundleId`.
#[derive(Debug, Serialize, PartialEq)]
pub(crate) struct ExtensionBundleIdBorrowed<'ext> {
developer: Option<&'ext str>,
extension_id: &'ext str,
sub_extension_id: Option<&'ext str>,
}
impl ToOwned for ExtensionBundleIdBorrowed<'_> {
type Owned = ExtensionBundleId;
fn to_owned(&self) -> Self::Owned {
ExtensionBundleId {
developer: self.developer.map(|s| s.to_string()),
extension_id: self.extension_id.to_string(),
sub_extension_id: self.sub_extension_id.map(|s| s.to_string()),
}
}
}
impl<'ext> PartialEq<ExtensionBundleIdBorrowed<'ext>> for ExtensionBundleId {
fn eq(&self, other: &ExtensionBundleIdBorrowed<'ext>) -> bool {
self.developer.as_deref() == other.developer
&& self.extension_id == other.extension_id
&& self.sub_extension_id.as_deref() == other.sub_extension_id
}
}
impl<'ext> PartialEq<ExtensionBundleId> for ExtensionBundleIdBorrowed<'ext> {
fn eq(&self, other: &ExtensionBundleId) -> bool {
self.developer == other.developer.as_deref()
&& self.extension_id == other.extension_id
&& self.sub_extension_id == other.sub_extension_id.as_deref()
}
}
impl Extension {
/// WARNING: the bundle ID returned from this function always has its `sub_extension_id`
/// set to `None`, which may not be what you want.
pub(crate) fn bundle_id_borrowed(&self) -> ExtensionBundleIdBorrowed<'_> {
ExtensionBundleIdBorrowed {
developer: self.developer.as_deref(),
extension_id: &self.id,
sub_extension_id: None,
}
}
/// Whether this extension could be searched.
pub(crate) fn searchable(&self) -> bool {
self.on_opened().is_some()
}
/// Return what will happen when we open this extension.
///
/// `None` if it cannot be opened.
pub(crate) fn on_opened(&self) -> Option<OnOpened> {
match self.r#type {
ExtensionType::Group => None,
ExtensionType::Extension => None,
ExtensionType::Command => Some(OnOpened::Command {
action: self.action.clone().unwrap_or_else(|| {
panic!(
"Command extension [{}]'s [action] field is not set, something wrong with your extension validity check", self.id
)
}),
}),
ExtensionType::Application => Some(OnOpened::Application {
app_path: self.id.clone(),
}),
ExtensionType::Script => todo!("not supported yet"),
ExtensionType::Quicklink => todo!("not supported yet"),
ExtensionType::Setting => todo!("not supported yet"),
ExtensionType::Calculator => None,
ExtensionType::AiExtension => None,
}
}
pub(crate) fn get_sub_extension(&self, sub_extension_id: &str) -> Option<&Self> {
if !self.r#type.contains_sub_items() {
return None;
}
if let Some(ref commands) = self.commands {
if let Some(sub_ext) = commands.iter().find(|cmd| cmd.id == sub_extension_id) {
return Some(sub_ext);
}
}
if let Some(ref scripts) = self.scripts {
if let Some(sub_ext) = scripts.iter().find(|script| script.id == sub_extension_id) {
return Some(sub_ext);
}
}
if let Some(ref quicklinks) = self.quicklinks {
if let Some(sub_ext) = quicklinks.iter().find(|link| link.id == sub_extension_id) {
return Some(sub_ext);
}
}
None
}
pub(crate) fn get_sub_extension_mut(&mut self, sub_extension_id: &str) -> Option<&mut Self> {
if !self.r#type.contains_sub_items() {
return None;
}
if let Some(ref mut commands) = self.commands {
if let Some(sub_ext) = commands.iter_mut().find(|cmd| cmd.id == sub_extension_id) {
return Some(sub_ext);
}
}
if let Some(ref mut scripts) = self.scripts {
if let Some(sub_ext) = scripts
.iter_mut()
.find(|script| script.id == sub_extension_id)
{
return Some(sub_ext);
}
}
if let Some(ref mut quicklinks) = self.quicklinks {
if let Some(sub_ext) = quicklinks
.iter_mut()
.find(|link| link.id == sub_extension_id)
{
return Some(sub_ext);
}
}
None
}
pub(crate) fn supports_alias_hotkey(&self) -> bool {
let ty = self.r#type;
ty != ExtensionType::Group && ty != ExtensionType::Extension
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub(crate) struct CommandAction {
pub(crate) exec: String,
pub(crate) args: Option<Vec<String>>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Quicklink {
link: String,
}
#[derive(Debug, PartialEq, Deserialize, Serialize, Clone, Display, Copy)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum ExtensionType {
#[display("Group")]
Group,
#[display("Extension")]
Extension,
#[display("Command")]
Command,
#[display("Application")]
Application,
#[display("Script")]
Script,
#[display("Quicklink")]
Quicklink,
#[display("Setting")]
Setting,
#[display("Calculator")]
Calculator,
#[display("AI Extension")]
AiExtension,
}
impl ExtensionType {
pub(crate) fn contains_sub_items(&self) -> bool {
self == &Self::Group || self == &Self::Extension
}
}
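// A tiny check of the rule above, in test form (a sketch; not exhaustive).
#[cfg(test)]
mod contains_sub_items_sketch {
    use super::ExtensionType;

    #[test]
    fn only_folder_types_contain_sub_items() {
        assert!(ExtensionType::Group.contains_sub_items());
        assert!(ExtensionType::Extension.contains_sub_items());
        assert!(!ExtensionType::Command.contains_sub_items());
        assert!(!ExtensionType::Script.contains_sub_items());
    }
}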
/// Helper function to filter out the extensions that do not satisfy the specified conditions.
///
/// used in `list_extensions()`
fn filter_out_extensions(
extensions: &mut Vec<Extension>,
query: Option<&str>,
extension_type: Option<ExtensionType>,
list_enabled: bool,
) {
// apply `list_enabled`
if list_enabled {
extensions.retain(|ext| ext.enabled);
for extension in extensions.iter_mut() {
if extension.r#type.contains_sub_items() {
if let Some(ref mut commands) = extension.commands {
commands.retain(|cmd| cmd.enabled);
}
if let Some(ref mut scripts) = extension.scripts {
scripts.retain(|script| script.enabled);
}
if let Some(ref mut quicklinks) = extension.quicklinks {
quicklinks.retain(|link| link.enabled);
}
}
}
}
// apply extension type filter to non-group/extension extensions
if let Some(extension_type) = extension_type {
assert!(
extension_type != ExtensionType::Group && extension_type != ExtensionType::Extension,
"filtering in folder extensions is pointless"
);
extensions.retain(|ext| {
let ty = ext.r#type;
ty == ExtensionType::Group || ty == ExtensionType::Extension || ty == extension_type
});
// Filter sub-extensions to only include the requested type
for extension in extensions.iter_mut() {
if extension.r#type.contains_sub_items() {
if let Some(ref mut commands) = extension.commands {
commands.retain(|cmd| cmd.r#type == extension_type);
}
if let Some(ref mut scripts) = extension.scripts {
scripts.retain(|script| script.r#type == extension_type);
}
if let Some(ref mut quicklinks) = extension.quicklinks {
quicklinks.retain(|link| link.r#type == extension_type);
}
}
}
// The Applications extension is special: technically, it should never be filtered
// out by this condition. But users would be surprised to pick a non-Application
// type and still see Applications in the results, so we remove it here to avoid
// that confusion.
if let Some(idx) = extensions.iter().position(|ext| {
ext.developer.is_none()
&& ext.id == built_in::application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
}) {
if extension_type != ExtensionType::Application {
extensions.remove(idx);
}
}
}
// apply query filter
if let Some(query) = query {
let match_closure = |ext: &Extension| {
let lowercase_title = ext.name.to_lowercase();
let lowercase_alias = ext.alias.as_ref().map(|alias| alias.to_lowercase());
let lowercase_query = query.to_lowercase();
lowercase_title.contains(&lowercase_query)
|| lowercase_alias.map_or(false, |alias| alias.contains(&lowercase_query))
};
extensions.retain(|ext| {
if ext.r#type.contains_sub_items() {
// Keep all group/extension types
true
} else {
// Apply filter to non-group/extension types
match_closure(ext)
}
});
// Filter sub-extensions in groups and extensions
for extension in extensions.iter_mut() {
if extension.r#type.contains_sub_items() {
if let Some(ref mut commands) = extension.commands {
commands.retain(&match_closure);
}
if let Some(ref mut scripts) = extension.scripts {
scripts.retain(&match_closure);
}
if let Some(ref mut quicklinks) = extension.quicklinks {
quicklinks.retain(&match_closure);
}
}
}
}
}
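// A minimal call sketch combining the three conditions handled above: keep only
// enabled Command-type entries whose title or alias contains "git" (the query is
// matched case-insensitively). The wrapper name and the "git" query are illustrative.
#[allow(dead_code)]
fn keep_enabled_git_commands(extensions: &mut Vec<Extension>) {
    filter_out_extensions(extensions, Some("git"), Some(ExtensionType::Command), true);
}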
/// Return value:
///
/// * boolean: indicates if we found any invalid extensions
/// * Vec<Extension>: loaded extensions
#[tauri::command]
pub(crate) async fn list_extensions(
query: Option<String>,
extension_type: Option<ExtensionType>,
list_enabled: bool,
) -> Result<(bool, Vec<Extension>), String> {
log::trace!("loading extensions");
let third_party_dir = third_party::THIRD_PARTY_EXTENSIONS_DIRECTORY.as_path();
if !third_party_dir.try_exists().map_err(|e| e.to_string())? {
tokio::fs::create_dir_all(third_party_dir)
.await
.map_err(|e| e.to_string())?;
}
let (third_party_found_invalid_extension, mut third_party_extensions) =
third_party::list_third_party_extensions(third_party_dir).await?;
let built_in_extensions = built_in::list_built_in_extensions().await?;
let found_invalid_extension = third_party_found_invalid_extension;
let mut extensions = {
third_party_extensions.extend(built_in_extensions);
third_party_extensions
};
filter_out_extensions(
&mut extensions,
query.as_deref(),
extension_type,
list_enabled,
);
// Clean up after filtering extensions; skip this if no filter was applied.
//
// Remove parent extensions (Group/Extension types) that have no sub-items after filtering
let filter_performed = query.is_some() || extension_type.is_some() || list_enabled;
if filter_performed {
extensions.retain(|ext| {
if !ext.r#type.contains_sub_items() {
return true;
}
// We don't apply this filter to the Applications extension: its sub-item list is always empty because applications are loaded at runtime.
if ext.developer.is_none()
&& ext.id == built_in::application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME
{
return true;
}
let has_commands = ext
.commands
.as_ref()
.map_or(false, |commands| !commands.is_empty());
let has_scripts = ext
.scripts
.as_ref()
.map_or(false, |scripts| !scripts.is_empty());
let has_quicklinks = ext
.quicklinks
.as_ref()
.map_or(false, |quicklinks| !quicklinks.is_empty());
has_commands || has_scripts || has_quicklinks
});
}
Ok((found_invalid_extension, extensions))
}
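// A consumption sketch for the tuple documented above (hypothetical caller; errors
// are simply propagated as the `String` they already are).
#[allow(dead_code)]
async fn log_enabled_scripts() -> Result<(), String> {
    let (found_invalid, extensions) =
        list_extensions(None, Some(ExtensionType::Script), true).await?;
    if found_invalid {
        log::warn!("some extensions were invalid and have been skipped");
    }
    for ext in &extensions {
        log::debug!("extension [{}] is enabled", ext.id);
    }
    Ok(())
}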
pub(crate) async fn init_extensions(mut extensions: Vec<Extension>) -> Result<(), String> {
log::trace!("initializing extensions");
let tauri_app_handle = GLOBAL_TAURI_APP_HANDLE
.get()
.expect("global tauri app handle not set");
let search_source_registry_tauri_state = tauri_app_handle.state::<SearchSourceRegistry>();
built_in::application::ApplicationSearchSource::prepare_index_and_store(
tauri_app_handle.clone(),
)
.await?;
// extension store
search_source_registry_tauri_state
.register_source(third_party::store::ExtensionStore)
.await;
// Init the built-in enabled extensions
for built_in_extension in extensions
.extract_if(.., |ext| {
built_in::is_extension_built_in(&ext.bundle_id_borrowed())
})
.filter(|ext| ext.enabled)
{
built_in::init_built_in_extension(
tauri_app_handle,
&built_in_extension,
&search_source_registry_tauri_state,
)
.await?;
}
// Now the third-party extensions
let third_party_search_source = third_party::ThirdPartyExtensionsSearchSource::new(extensions);
third_party_search_source.init().await?;
let third_party_search_source_clone = third_party_search_source.clone();
// Set the global search source so that we can access it in `#[tauri::command]`s.
// Ignore the result because this function will be invoked twice, which means
// this global variable will be set twice.
let _ = THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.set(third_party_search_source_clone);
search_source_registry_tauri_state
.register_source(third_party_search_source)
.await;
Ok(())
}
#[tauri::command]
pub(crate) async fn enable_extension(bundle_id: ExtensionBundleId) -> Result<(), String> {
let bundle_id_borrowed = bundle_id.borrow();
if built_in::is_extension_built_in(&bundle_id_borrowed) {
built_in::enable_built_in_extension(&bundle_id_borrowed).await?;
return Ok(());
}
third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").enable_extension(&bundle_id_borrowed).await
}
#[tauri::command]
pub(crate) async fn disable_extension(bundle_id: ExtensionBundleId) -> Result<(), String> {
let bundle_id_borrowed = bundle_id.borrow();
if built_in::is_extension_built_in(&bundle_id_borrowed) {
built_in::disable_built_in_extension(&bundle_id_borrowed).await?;
return Ok(());
}
third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").disable_extension(&bundle_id_borrowed).await
}
#[tauri::command]
pub(crate) async fn set_extension_alias(
bundle_id: ExtensionBundleId,
alias: String,
) -> Result<(), String> {
let bundle_id_borrowed = bundle_id.borrow();
if built_in::is_extension_built_in(&bundle_id_borrowed) {
built_in::set_built_in_extension_alias(&bundle_id_borrowed, &alias);
return Ok(());
}
third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").set_extension_alias(&bundle_id_borrowed, &alias).await
}
#[tauri::command]
pub(crate) async fn register_extension_hotkey(
bundle_id: ExtensionBundleId,
hotkey: String,
) -> Result<(), String> {
let bundle_id_borrowed = bundle_id.borrow();
if built_in::is_extension_built_in(&bundle_id_borrowed) {
built_in::register_built_in_extension_hotkey(&bundle_id_borrowed, &hotkey)?;
return Ok(());
}
third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").register_extension_hotkey(&bundle_id_borrowed, &hotkey).await
}
/// NOTE: this function won't error out if the extension specified by `bundle_id`
/// has no hotkey set, because callers need it to behave that way.
#[tauri::command]
pub(crate) async fn unregister_extension_hotkey(
bundle_id: ExtensionBundleId,
) -> Result<(), String> {
let bundle_id_borrowed = bundle_id.borrow();
if built_in::is_extension_built_in(&bundle_id_borrowed) {
built_in::unregister_built_in_extension_hotkey(&bundle_id_borrowed)?;
return Ok(());
}
third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").unregister_extension_hotkey(&bundle_id_borrowed).await?;
Ok(())
}
#[tauri::command]
pub(crate) async fn is_extension_enabled(bundle_id: ExtensionBundleId) -> Result<bool, String> {
let bundle_id_borrowed = bundle_id.borrow();
if built_in::is_extension_built_in(&bundle_id_borrowed) {
return built_in::is_built_in_extension_enabled(&bundle_id_borrowed).await;
}
third_party::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE.get().expect("global third party search source not set, looks like init_extensions() has not been executed").is_extension_enabled(&bundle_id_borrowed).await
}
pub(crate) fn canonicalize_relative_icon_path(
extension_dir: &Path,
extension: &mut Extension,
) -> Result<(), String> {
fn _canonicalize_relative_icon_path(
extension_dir: &Path,
extension: &mut Extension,
) -> Result<(), String> {
let icon_str = &extension.icon;
let icon_path = Path::new(icon_str);
if icon_path.is_relative() {
let absolute_icon_path = {
let mut assets_directory = extension_dir.join(ASSETS_DIRECTORY_FILE_NAME);
assets_directory.push(icon_path);
assets_directory
};
if absolute_icon_path.try_exists().map_err(|e| e.to_string())? {
extension.icon = absolute_icon_path
.into_os_string()
.into_string()
.expect("path should be UTF-8 encoded");
}
}
Ok(())
}
_canonicalize_relative_icon_path(extension_dir, extension)?;
if let Some(commands) = &mut extension.commands {
for command in commands {
_canonicalize_relative_icon_path(extension_dir, command)?;
}
}
if let Some(scripts) = &mut extension.scripts {
for script in scripts {
_canonicalize_relative_icon_path(extension_dir, script)?;
}
}
if let Some(quicklinks) = &mut extension.quicklinks {
for quicklink in quicklinks {
_canonicalize_relative_icon_path(extension_dir, quicklink)?;
}
}
Ok(())
}
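// The path logic above, shown in isolation (a sketch with hypothetical paths; the
// real function only rewrites `extension.icon` when the resolved file exists).
#[allow(dead_code)]
fn resolved_icon_path_example() -> std::path::PathBuf {
    let extension_dir = Path::new("/path/to/extensions/acme/demo");
    let relative_icon = Path::new("icon.png");
    // => "/path/to/extensions/acme/demo/<assets dir>/icon.png"
    extension_dir.join(ASSETS_DIRECTORY_FILE_NAME).join(relative_icon)
}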
fn alter_extension_json_file(
extension_directory: &Path,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
how: impl Fn(&mut Extension) -> Result<(), String>,
) -> Result<(), String> {
/// Perform `how` against the extension specified by `bundle_id`.
///
/// Please note that `bundle_id` could point to a sub-extension if its `sub_extension_id` is `Some`.
pub(crate) fn modify(
root_extension: &mut Extension,
bundle_id: &ExtensionBundleIdBorrowed<'_>,
how: impl FnOnce(&mut Extension) -> Result<(), String>,
) -> Result<(), String> {
let (parent_extension_id, opt_sub_extension_id) =
(bundle_id.extension_id, bundle_id.sub_extension_id);
assert_eq!(
parent_extension_id, root_extension.id,
"modify() should be invoked against a parent extension"
);
let Some(sub_extension_id) = opt_sub_extension_id else {
how(root_extension)?;
return Ok(());
};
// Search in commands
if let Some(ref mut commands) = root_extension.commands {
if let Some(command) = commands.iter_mut().find(|cmd| cmd.id == sub_extension_id) {
how(command)?;
return Ok(());
}
}
// Search in scripts
if let Some(ref mut scripts) = root_extension.scripts {
if let Some(script) = scripts.iter_mut().find(|scr| scr.id == sub_extension_id) {
how(script)?;
return Ok(());
}
}
// Search in quicklinks
if let Some(ref mut quicklinks) = root_extension.quicklinks {
if let Some(link) = quicklinks
.iter_mut()
.find(|lnk| lnk.id == sub_extension_id)
{
how(link)?;
return Ok(());
}
}
Err(format!(
"extension [{:?}] not found in {:?}",
bundle_id, root_extension
))
}
log::debug!(
"altering extension JSON file for extension [{:?}]",
bundle_id
);
let json_file_path = {
let mut path = extension_directory.to_path_buf();
if let Some(developer) = bundle_id.developer {
path.push(developer);
}
path.push(bundle_id.extension_id);
path.push(PLUGIN_JSON_FILE_NAME);
path
};
let mut extension = serde_json::from_reader::<_, Extension>(
std::fs::File::open(&json_file_path)
.with_context(|| {
format!(
"the [{}] file for extension [{:?}] is missing or broken",
PLUGIN_JSON_FILE_NAME, bundle_id
)
})
.map_err(|e| e.to_string())?,
)
.map_err(|e| e.to_string())?;
modify(&mut extension, bundle_id, how)?;
std::fs::write(
&json_file_path,
serde_json::to_string_pretty(&extension).map_err(|e| e.to_string())?,
)
.map_err(|e| e.to_string())?;
Ok(())
}
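// A usage sketch for the helper above (hypothetical call site): disable one
// extension, or one of its sub-extensions, by rewriting its plugin.json on disk.
#[allow(dead_code)]
fn disable_extension_on_disk(
    extension_directory: &Path,
    bundle_id: &ExtensionBundleIdBorrowed<'_>,
) -> Result<(), String> {
    alter_extension_json_file(extension_directory, bundle_id, |ext| {
        ext.enabled = false;
        Ok(())
    })
}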

File diff suppressed because it is too large Load Diff

View File

@@ -1,318 +0,0 @@
//! Extension store related stuff.
use super::LOCAL_QUERY_SOURCE_TYPE;
use crate::common::document::DataSourceReference;
use crate::common::document::Document;
use crate::common::error::SearchError;
use crate::common::search::QueryResponse;
use crate::common::search::QuerySource;
use crate::common::search::SearchQuery;
use crate::common::traits::SearchSource;
use crate::extension::canonicalize_relative_icon_path;
use crate::extension::third_party::THIRD_PARTY_EXTENSIONS_DIRECTORY;
use crate::extension::Extension;
use crate::extension::PLUGIN_JSON_FILE_NAME;
use crate::extension::THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE;
use crate::server::http_client::HttpClient;
use async_trait::async_trait;
use reqwest::StatusCode;
use serde_json::Map as JsonObject;
use serde_json::Value as Json;
use std::io::Read;
const DATA_SOURCE_ID: &str = "Extension Store";
pub(crate) struct ExtensionStore;
#[async_trait]
impl SearchSource for ExtensionStore {
fn get_type(&self) -> QuerySource {
QuerySource {
r#type: LOCAL_QUERY_SOURCE_TYPE.into(),
name: hostname::get()
.unwrap_or(DATA_SOURCE_ID.into())
.to_string_lossy()
.into(),
id: DATA_SOURCE_ID.into(),
}
}
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError> {
const SCORE: f64 = 2000.0;
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
let lowercase_query_string = query_string.to_lowercase();
let expected_str = "extension store";
if expected_str.contains(&lowercase_query_string) {
let doc = Document {
id: DATA_SOURCE_ID.to_string(),
category: Some(DATA_SOURCE_ID.to_string()),
title: Some(DATA_SOURCE_ID.to_string()),
icon: Some("font_Store".to_string()),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(DATA_SOURCE_ID.into()),
id: Some(DATA_SOURCE_ID.into()),
icon: Some("font_Store".to_string()),
}),
..Default::default()
};
Ok(QueryResponse {
source: self.get_type(),
hits: vec![(doc, SCORE)],
total_hits: 1,
})
} else {
Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
})
}
}
}
#[tauri::command]
pub(crate) async fn search_extension(
query_params: Option<Vec<String>>,
) -> Result<Vec<Json>, String> {
let response = HttpClient::get(
"default_coco_server",
"store/extension/_search",
query_params,
)
.await
.map_err(|e| format!("Failed to send request: {:?}", e))?;
// The response of an ES-style search request
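// Roughly of this shape (a sketch with hypothetical values):
//   { "hits": { "hits": [ { "_source": { "id": "demo", "developer": { "id": "acme" } } } ] } }
// Only each `hits.hits[]._source` object is kept, and an `installed` flag is added below.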
let mut response: JsonObject<String, Json> = response
.json()
.await
.map_err(|e| format!("Failed to parse response: {:?}", e))?;
let hits_json = response
.remove("hits")
.expect("the JSON response should contain field [hits]");
let mut hits = match hits_json {
Json::Object(obj) => obj,
_ => panic!(
"field [hits] should be a JSON object, but it is not, value: [{}]",
hits_json
),
};
let Some(hits_hits_json) = hits.remove("hits") else {
return Ok(Vec::new());
};
let hits_hits = match hits_hits_json {
Json::Array(arr) => arr,
_ => panic!(
"field [hits.hits] should be an array, but it is not, value: [{}]",
hits_hits_json
),
};
let mut extensions = Vec::with_capacity(hits_hits.len());
for hit in hits_hits {
let mut hit_obj = match hit {
Json::Object(obj) => obj,
_ => panic!(
"each hit in [hits.hits] should be a JSON object, but it is not, value: [{}]",
hit
),
};
let source = hit_obj
.remove("_source")
.expect("each hit should contain field [_source]");
let mut source_obj = match source {
Json::Object(obj) => obj,
_ => panic!(
"field [_source] should be a JSON object, but it is not, value: [{}]",
source
),
};
let developer_id = source_obj
.get("developer")
.and_then(|dev| dev.get("id"))
.and_then(|id| id.as_str())
.expect("developer.id should exist")
.to_string();
let extension_id = source_obj
.get("id")
.and_then(|id| id.as_str())
.expect("extension id should exist")
.to_string();
let installed = is_extension_installed(developer_id, extension_id).await;
source_obj.insert("installed".to_string(), Json::Bool(installed));
extensions.push(Json::Object(source_obj));
}
Ok(extensions)
}
async fn is_extension_installed(developer: String, extension_id: String) -> bool {
THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE
.get()
.unwrap()
.extension_exists(&developer, &extension_id)
.await
}
#[tauri::command]
pub(crate) async fn install_extension_from_store(id: String) -> Result<(), String> {
let path = format!("store/extension/{}/_download", id);
let response = HttpClient::get("default_coco_server", &path, None)
.await
.map_err(|e| format!("Failed to download extension: {}", e))?;
if response.status() == StatusCode::NOT_FOUND {
return Err(format!("extension [{}] not found", id));
}
let bytes = response
.bytes()
.await
.map_err(|e| format!("Failed to read response bytes: {}", e))?;
let cursor = std::io::Cursor::new(bytes);
let mut archive =
zip::ZipArchive::new(cursor).map_err(|e| format!("Failed to read zip archive: {}", e))?;
// The plugin.json sent from the server does not conform to our `struct Extension` definition:
//
// 1. Its `developer` field is a JSON object, but we need a string
// 2. sub-extensions won't have their `id` fields set
//
// we need to correct it
let mut plugin_json = archive.by_name(PLUGIN_JSON_FILE_NAME).map_err(|e| e.to_string())?;
let mut plugin_json_content = String::new();
std::io::Read::read_to_string(&mut plugin_json, &mut plugin_json_content)
.map_err(|e| e.to_string())?;
let mut extension: Json = serde_json::from_str(&plugin_json_content)
.map_err(|e| format!("Failed to parse plugin.json: {}", e))?;
let mut_ref_to_developer_object: &mut Json = extension
.as_object_mut()
.expect("plugin.json should be an object")
.get_mut("developer")
.expect("plugin.json should contain field [developer]");
let developer_id = mut_ref_to_developer_object
.get("id")
.expect("plugin.json should contain [developer.id]")
.as_str()
.expect("plugin.json field [developer.id] should be a string");
*mut_ref_to_developer_object = Json::String(developer_id.into());
// Set IDs for sub-extensions (commands, quicklinks, scripts)
let mut counter = 0;
// Helper function to set IDs for array fields
fn set_ids_for_field(extension: &mut Json, field_name: &str, counter: &mut i32) {
if let Some(field) = extension.as_object_mut().unwrap().get_mut(field_name) {
if let Some(array) = field.as_array_mut() {
for item in array {
if let Some(item_obj) = item.as_object_mut() {
if !item_obj.contains_key("id") {
item_obj.insert("id".to_string(), Json::String(counter.to_string()));
*counter += 1;
}
}
}
}
}
}
set_ids_for_field(&mut extension, "commands", &mut counter);
set_ids_for_field(&mut extension, "quicklinks", &mut counter);
set_ids_for_field(&mut extension, "scripts", &mut counter);
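// For instance (hypothetical values), a store payload shaped like
//   { "developer": { "id": "acme", "name": "Acme" }, "commands": [ { "name": "Say Hi" } ] }
// has by now been rewritten to
//   { "developer": "acme", "commands": [ { "id": "0", "name": "Say Hi" } ] }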
// Now the extension JSON is valid
let mut extension: Extension = serde_json::from_value(extension).unwrap_or_else(|e| {
panic!(
"cannot parse plugin.json as struct Extension, error [{:?}]",
e
);
});
drop(plugin_json);
// Write extension files to the extension directory
let developer = extension.developer.clone().unwrap_or_default();
let extension_id = extension.id.clone();
let extension_directory = {
let mut path = THIRD_PARTY_EXTENSIONS_DIRECTORY.to_path_buf();
path.push(developer);
path.push(extension_id.as_str());
path
};
tokio::fs::create_dir_all(extension_directory.as_path())
.await
.map_err(|e| e.to_string())?;
// Extract all files except plugin.json
for i in 0..archive.len() {
let mut zip_file = archive.by_index(i).map_err(|e| e.to_string())?;
// `.name()` is safe to use in our case; the problematic cases listed on the page
// below won't happen to us.
//
// https://docs.rs/zip/4.2.0/zip/read/struct.ZipFile.html#method.name
//
// Example names:
//
// * `assets/icon.png`
// * `assets/screenshot.png`
// * `plugin.json`
//
// Yes, the `assets` directory is not a part of it.
let zip_file_name = zip_file.name();
// Skip the plugin.json file as we'll create it from the extension variable
if zip_file_name == PLUGIN_JSON_FILE_NAME {
continue;
}
let dest_file_path = extension_directory.join(zip_file_name);
// For cases like `assets/xxx.png`
if let Some(parent_dir) = dest_file_path.parent() && !parent_dir.exists() {
tokio::fs::create_dir_all(parent_dir).await.map_err(|e| e.to_string())?;
}
let mut dest_file = tokio::fs::File::create(&dest_file_path).await.map_err(|e| e.to_string())?;
let mut src_bytes = Vec::with_capacity(zip_file.size().try_into().expect("we won't have an extension file that is bigger than 4GiB"));
zip_file.read_to_end(&mut src_bytes).map_err(|e| e.to_string())?;
tokio::io::copy(&mut src_bytes.as_slice(), &mut dest_file).await.map_err(|e| e.to_string())?;
}
// Create plugin.json from the extension variable
let plugin_json_path = extension_directory.join(PLUGIN_JSON_FILE_NAME);
let extension_json = serde_json::to_string_pretty(&extension).map_err(|e| e.to_string())?;
tokio::fs::write(&plugin_json_path, extension_json)
.await
.map_err(|e| e.to_string())?;
// Turn the icon path into an absolute path if it is a valid relative path, because the frontend code needs this.
canonicalize_relative_icon_path(&extension_directory, &mut extension)?;
THIRD_PARTY_EXTENSIONS_SEARCH_SOURCE
.get()
.unwrap()
.add_extension(extension)
.await;
Ok(())
}

View File

@@ -1,7 +1,7 @@
mod assistant;
mod autostart;
mod common;
mod extension;
mod local;
mod search;
mod server;
mod settings;
@@ -11,15 +11,19 @@ mod util;
use crate::common::register::SearchSourceRegistry;
// use crate::common::traits::SearchSource;
use crate::common::{CHECK_WINDOW_LABEL, MAIN_WINDOW_LABEL, SETTINGS_WINDOW_LABEL};
use crate::common::{MAIN_WINDOW_LABEL, SETTINGS_WINDOW_LABEL};
use crate::server::servers::{load_or_insert_default_server, load_servers_token};
use autostart::{change_autostart, ensure_autostart_state_consistent};
use autostart::{change_autostart, enable_autostart};
use lazy_static::lazy_static;
use std::sync::Mutex;
use std::sync::OnceLock;
use tauri::async_runtime::block_on;
use tauri::plugin::TauriPlugin;
use tauri::{AppHandle, Emitter, Manager, PhysicalPosition, Runtime, WebviewWindow, WindowEvent};
#[cfg(target_os = "macos")]
use tauri::ActivationPolicy;
use tauri::{
AppHandle, Emitter, Manager, PhysicalPosition, Runtime, WebviewWindow, Window, WindowEvent,
};
use tauri_plugin_autostart::MacosLauncher;
/// Tauri store name
@@ -60,13 +64,11 @@ pub fn run() {
let ctx = tauri::generate_context!();
let mut app_builder = tauri::Builder::default();
// Set up logger first
app_builder = app_builder.plugin(set_up_tauri_logger());
#[cfg(desktop)]
{
app_builder = app_builder.plugin(tauri_plugin_single_instance::init(|_app, argv, _cwd| {
log::debug!("a new app instance was opened with {argv:?} and the deep link event was already triggered");
println!("a new app instance was opened with {argv:?} and the deep link event was already triggered");
// when defining deep link schemes at runtime, you must also check `argv` here
}));
}
@@ -75,7 +77,7 @@ pub fn run() {
.plugin(tauri_plugin_http::init())
.plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_autostart::init(
MacosLauncher::LaunchAgent,
MacosLauncher::AppleScript,
None,
))
.plugin(tauri_plugin_deep_link::init())
@@ -87,7 +89,7 @@ pub fn run() {
.plugin(tauri_plugin_process::init())
.plugin(tauri_plugin_updater::Builder::new().build())
.plugin(tauri_plugin_windows_version::init())
.plugin(tauri_plugin_opener::init());
.plugin(set_up_tauri_logger());
// Conditional compilation for macOS
#[cfg(target_os = "macos")]
@@ -105,8 +107,6 @@ pub fn run() {
show_coco,
hide_coco,
show_settings,
show_check,
hide_check,
server::servers::get_server_token,
server::servers::add_coco_server,
server::servers::remove_coco_server,
@@ -123,9 +123,7 @@ pub fn run() {
search::query_coco_fusion,
assistant::chat_history,
assistant::new_chat,
assistant::chat_create,
assistant::send_message,
assistant::chat_chat,
assistant::session_chat_history,
assistant::open_session_chat,
assistant::close_session_chat,
@@ -133,8 +131,6 @@ pub fn run() {
assistant::delete_session_chat,
assistant::update_session_chat,
assistant::assistant_search,
assistant::assistant_get,
assistant::assistant_get_multi,
// server::get_coco_server_datasources,
// server::get_coco_server_connectors,
server::websocket::connect_to_server,
@@ -144,48 +140,30 @@ pub fn run() {
server::attachment::get_attachment,
server::attachment::delete_attachment,
server::transcription::transcription,
util::open,
server::system_settings::get_system_settings,
extension::built_in::application::get_app_list,
extension::built_in::application::get_app_search_path,
extension::built_in::application::get_app_metadata,
extension::built_in::application::add_app_search_path,
extension::built_in::application::remove_app_search_path,
extension::built_in::application::reindex_applications,
extension::list_extensions,
extension::enable_extension,
extension::disable_extension,
extension::set_extension_alias,
extension::register_extension_hotkey,
extension::unregister_extension_hotkey,
extension::is_extension_enabled,
extension::third_party::store::search_extension,
extension::third_party::store::install_extension_from_store,
extension::third_party::uninstall_extension,
simulate_mouse_click,
local::get_disabled_local_query_sources,
local::enable_local_query_source,
local::disable_local_query_source,
local::application::get_app_list,
local::application::get_app_search_path,
local::application::get_app_metadata,
local::application::set_app_alias,
local::application::register_app_hotkey,
local::application::unregister_app_hotkey,
local::application::disable_app_search,
local::application::enable_app_search,
local::application::add_app_search_path,
local::application::remove_app_search_path,
settings::set_allow_self_signature,
settings::get_allow_self_signature,
assistant::ask_ai,
crate::common::document::open,
#[cfg(any(target_os = "macos", target_os = "windows"))]
extension::built_in::file_search::config::get_file_system_config,
#[cfg(any(target_os = "macos", target_os = "windows"))]
extension::built_in::file_search::config::set_file_system_config,
server::synthesize::synthesize,
util::file::get_file_icon,
util::app_lang::update_app_lang,
])
.setup(|app| {
#[cfg(target_os = "macos")]
{
log::trace!("hiding Dock icon on macOS");
app.set_activation_policy(tauri::ActivationPolicy::Accessory);
log::trace!("Dock icon should be hidden now");
}
let app_handle = app.handle().clone();
GLOBAL_TAURI_APP_HANDLE
.set(app_handle.clone())
.expect("variable already initialized");
log::trace!("global Tauri app handle set");
let registry = SearchSourceRegistry::default();
@@ -198,12 +176,15 @@ pub fn run() {
shortcut::enable_shortcut(app);
ensure_autostart_state_consistent(app)?;
enable_autostart(app);
#[cfg(target_os = "macos")]
app.set_activation_policy(ActivationPolicy::Accessory);
// app.listen("theme-changed", move |event| {
// if let Ok(payload) = serde_json::from_str::<ThemeChangedPayload>(event.payload()) {
// // switch_tray_icon(app.app_handle(), payload.is_dark_mode);
// log::debug!("Theme changed: is_dark_mode = {}", payload.is_dark_mode);
// println!("Theme changed: is_dark_mode = {}", payload.is_dark_mode);
// }
// });
@@ -223,19 +204,13 @@ pub fn run() {
let main_window = app.get_webview_window(MAIN_WINDOW_LABEL).unwrap();
let settings_window = app.get_webview_window(SETTINGS_WINDOW_LABEL).unwrap();
let check_window = app.get_webview_window(CHECK_WINDOW_LABEL).unwrap();
setup::default(
app,
main_window.clone(),
settings_window.clone(),
check_window.clone(),
);
setup::default(app, main_window.clone(), settings_window.clone());
Ok(())
})
.on_window_event(|window, event| match event {
WindowEvent::CloseRequested { api, .. } => {
//dbg!("Close requested event received");
dbg!("Close requested event received");
window.hide().unwrap();
api.prevent_close();
}
@@ -250,10 +225,10 @@ pub fn run() {
has_visible_windows,
..
} => {
// dbg!(
// "Reopen event received: has_visible_windows = {}",
// has_visible_windows
// );
dbg!(
"Reopen event received: has_visible_windows = {}",
has_visible_windows
);
if has_visible_windows {
return;
}
@@ -267,11 +242,11 @@ pub fn run() {
pub async fn init<R: Runtime>(app_handle: &AppHandle<R>) {
// Await the async functions to load the servers and tokens
if let Err(err) = load_or_insert_default_server(app_handle).await {
log::error!("Failed to load servers: {}", err);
eprintln!("Failed to load servers: {}", err);
}
if let Err(err) = load_servers_token(app_handle).await {
log::error!("Failed to load server tokens: {}", err);
eprintln!("Failed to load server tokens: {}", err);
}
let coco_servers = server::servers::get_all_servers();
@@ -284,12 +259,12 @@ pub async fn init<R: Runtime>(app_handle: &AppHandle<R>) {
.await;
}
extension::built_in::pizza_engine_runtime::start_pizza_engine_runtime().await;
local::start_pizza_engine_runtime();
}
#[tauri::command]
async fn show_coco<R: Runtime>(app_handle: AppHandle<R>) {
if let Some(window) = app_handle.get_webview_window(MAIN_WINDOW_LABEL) {
if let Some(window) = app_handle.get_window(MAIN_WINDOW_LABEL) {
move_window_to_active_monitor(&window);
let _ = window.show();
@@ -302,24 +277,24 @@ async fn show_coco<R: Runtime>(app_handle: AppHandle<R>) {
#[tauri::command]
async fn hide_coco<R: Runtime>(app: AppHandle<R>) {
if let Some(window) = app.get_webview_window(MAIN_WINDOW_LABEL) {
if let Some(window) = app.get_window(MAIN_WINDOW_LABEL) {
if let Err(err) = window.hide() {
log::error!("Failed to hide the window: {}", err);
eprintln!("Failed to hide the window: {}", err);
} else {
log::debug!("Window successfully hidden.");
println!("Window successfully hidden.");
}
} else {
log::error!("Main window not found.");
eprintln!("Main window not found.");
}
}
fn move_window_to_active_monitor<R: Runtime>(window: &WebviewWindow<R>) {
//dbg!("Moving window to active monitor");
fn move_window_to_active_monitor<R: Runtime>(window: &Window<R>) {
dbg!("Moving window to active monitor");
// Try to get the available monitors, handle failure gracefully
let available_monitors = match window.available_monitors() {
Ok(monitors) => monitors,
Err(e) => {
log::error!("Failed to get monitors: {}", e);
eprintln!("Failed to get monitors: {}", e);
return;
}
};
@@ -328,7 +303,7 @@ fn move_window_to_active_monitor<R: Runtime>(window: &WebviewWindow<R>) {
let cursor_position = match window.cursor_position() {
Ok(pos) => Some(pos),
Err(e) => {
log::error!("Failed to get cursor position: {}", e);
eprintln!("Failed to get cursor position: {}", e);
None
}
};
@@ -357,7 +332,7 @@ fn move_window_to_active_monitor<R: Runtime>(window: &WebviewWindow<R>) {
let monitor = match target_monitor.or_else(|| window.primary_monitor().ok().flatten()) {
Some(monitor) => monitor,
None => {
log::error!("No monitor found!");
eprintln!("No monitor found!");
return;
}
};
@@ -367,7 +342,7 @@ fn move_window_to_active_monitor<R: Runtime>(window: &WebviewWindow<R>) {
if let Some(ref prev_name) = *previous_monitor_name {
if name.to_string() == *prev_name {
log::debug!("Currently on the same monitor");
println!("Currently on the same monitor");
return;
}
@@ -381,7 +356,7 @@ fn move_window_to_active_monitor<R: Runtime>(window: &WebviewWindow<R>) {
let window_size = match window.inner_size() {
Ok(size) => size,
Err(e) => {
log::error!("Failed to get window size: {}", e);
eprintln!("Failed to get window size: {}", e);
return;
}
};
@@ -395,25 +370,52 @@ fn move_window_to_active_monitor<R: Runtime>(window: &WebviewWindow<R>) {
// Move the window to the new position
if let Err(e) = window.set_position(PhysicalPosition::new(window_x, window_y)) {
log::error!("Failed to move window: {}", e);
eprintln!("Failed to move window: {}", e);
}
if let Some(name) = monitor.name() {
log::debug!("Window moved to monitor: {}", name);
println!("Window moved to monitor: {}", name);
let mut previous_monitor = PREVIOUS_MONITOR_NAME.lock().unwrap();
*previous_monitor = Some(name.to_string());
}
}
#[allow(dead_code)]
fn open_settings(app: &tauri::AppHandle) {
use tauri::webview::WebviewBuilder;
println!("settings menu item was clicked");
let window = app.get_webview_window("settings");
if let Some(window) = window {
let _ = window.show();
let _ = window.unminimize();
let _ = window.set_focus();
} else {
let window = tauri::window::WindowBuilder::new(app, "settings")
.title("Settings Window")
.fullscreen(false)
.resizable(false)
.minimizable(false)
.maximizable(false)
.inner_size(800.0, 600.0)
.build()
.unwrap();
let webview_builder =
WebviewBuilder::new("settings", tauri::WebviewUrl::App("/ui/settings".into()));
let _webview = window
.add_child(
webview_builder,
tauri::LogicalPosition::new(0, 0),
window.inner_size().unwrap(),
)
.unwrap();
}
}
#[tauri::command]
async fn get_app_search_source<R: Runtime>(app_handle: AppHandle<R>) -> Result<(), String> {
// We want all the extensions here, so no filter condition specified.
let (_found_invalid_extensions, extensions) = extension::list_extensions(None, None, false)
.await
.map_err(|e| e.to_string())?;
extension::init_extensions(extensions).await?;
local::init_local_search_source(&app_handle).await?;
let _ = server::connector::refresh_all_connectors(&app_handle).await;
let _ = server::datasource::refresh_all_datasources(&app_handle).await;
@@ -422,36 +424,53 @@ async fn get_app_search_source<R: Runtime>(app_handle: AppHandle<R>) -> Result<(
#[tauri::command]
async fn show_settings(app_handle: AppHandle) {
log::debug!("settings menu item was clicked");
let window = app_handle
.get_webview_window(SETTINGS_WINDOW_LABEL)
.expect("we have a settings window");
window.show().unwrap();
window.unminimize().unwrap();
window.set_focus().unwrap();
open_settings(&app_handle);
}
#[tauri::command]
async fn show_check(app_handle: AppHandle) {
log::debug!("check menu item was clicked");
let window = app_handle
.get_webview_window(CHECK_WINDOW_LABEL)
.expect("we have a check window");
async fn simulate_mouse_click<R: Runtime>(window: WebviewWindow<R>, is_chat_mode: bool) {
#[cfg(target_os = "windows")]
{
use enigo::{Button, Coordinate, Direction, Enigo, Mouse, Settings};
use std::{thread, time::Duration};
window.show().unwrap();
window.unminimize().unwrap();
window.set_focus().unwrap();
}
if let Ok(mut enigo) = Enigo::new(&Settings::default()) {
// Save the current mouse position
if let Ok((original_x, original_y)) = enigo.location() {
// Retrieve the window's outer position (top-left corner)
if let Ok(position) = window.outer_position() {
// Retrieve the window's inner size (client area)
if let Ok(size) = window.inner_size() {
// Calculate the center position of the title bar
let x = position.x + (size.width as i32 / 2);
let y = if is_chat_mode {
position.y + size.height as i32 - 50
} else {
position.y + 30
};
#[tauri::command]
async fn hide_check(app_handle: AppHandle) {
log::debug!("check window was closed");
let window = &app_handle
.get_webview_window(CHECK_WINDOW_LABEL)
.expect("we have a check window");
// Move the mouse cursor to the calculated position
if enigo.move_mouse(x, y, Coordinate::Abs).is_ok() {
// // Simulate a left mouse click
let _ = enigo.button(Button::Left, Direction::Click);
// let _ = enigo.button(Button::Left, Direction::Release);
window.hide().unwrap();
thread::sleep(Duration::from_millis(100));
// Move the mouse cursor back to the original position
let _ = enigo.move_mouse(original_x, original_y, Coordinate::Abs);
}
}
}
}
}
}
#[cfg(not(target_os = "windows"))]
{
let _ = window;
let _ = is_chat_mode;
}
}
/// Log format:
@@ -468,12 +487,6 @@ async fn hide_check(app_handle: AppHandle) {
/// ```
fn set_up_tauri_logger() -> TauriPlugin<tauri::Wry> {
use log::Level;
use log::LevelFilter;
use tauri_plugin_log::Builder;
/// Coco-AI app's default log level.
const DEFAULT_LOG_LEVEL: LevelFilter = LevelFilter::Info;
const LOG_LEVEL_ENV_VAR: &str = "COCO_LOG";
fn format_log_level(level: Level) -> &'static str {
match level {
@@ -495,93 +508,16 @@ fn set_up_tauri_logger() -> TauriPlugin<tauri::Wry> {
str
}
/// Allow us to configure dynamic log levels via environment variable `COCO_LOG`.
///
/// Generally, it mirrors the behavior of `env_logger`. Syntax: `COCO_LOG=[target][=][level][,...]`
///
/// * If this environment variable is not set, use the default log level.
/// * If it is set, respect it:
///
/// * `COCO_LOG=coco_lib` turns on all logging for the `coco_lib` module, which is
/// equivalent to `COCO_LOG=coco_lib=trace`
/// * `COCO_LOG=trace` turns on all logging for the application, regardless of its name
/// * `COCO_LOG=TRACE` turns on all logging for the application, regardless of its name (same as previous)
/// * `COCO_LOG=reqwest=debug` turns on debug logging for `reqwest`
/// * `COCO_LOG=trace,tauri=off` turns on all the logging except for the logs come from `tauri`
/// * `COCO_LOG=off` turns off all logging for the application
/// * `COCO_LOG=` Since the value is empty, turns off all logging for the application as well
fn dynamic_log_level(mut builder: Builder) -> Builder {
let Some(log_levels) = std::env::var_os(LOG_LEVEL_ENV_VAR) else {
return builder.level(DEFAULT_LOG_LEVEL);
};
builder = builder.level(LevelFilter::Off);
let log_levels = log_levels.into_string().unwrap_or_else(|e| {
panic!(
"The value '{}' set in environment varaible '{}' is not UTF-8 encoded",
// Cannot use `.display()` here becuase that requires MSRV 1.87.0
e.to_string_lossy(),
LOG_LEVEL_ENV_VAR
)
});
// COCO_LOG=[target][=][level][,...]
let target_log_levels = log_levels.split(',');
for target_log_level in target_log_levels {
#[allow(clippy::collapsible_else_if)]
if let Some(equal_sign_index) = target_log_level.find('=') {
let (target, equal_sign_and_level) = target_log_level.split_at(equal_sign_index);
// Remove the equal sign, we know it takes 1 byte
let level = &equal_sign_and_level[1..];
if let Ok(level) = level.parse::<LevelFilter>() {
// Here we have to call `.to_string()` because `Cow<'static, str>` requires `&'static str`
builder = builder.level_for(target.to_string(), level);
} else {
panic!(
"log level '{}' set in '{}={}' is invalid",
level, target, level
);
}
} else {
if let Ok(level) = target_log_level.parse::<LevelFilter>() {
// This is a level
builder = builder.level(level);
} else {
// This is a target, enable all the logging
//
// Here we have to call `.to_string()` because `Cow<'static, str>` requires `&'static str`
builder = builder.level_for(target_log_level.to_string(), LevelFilter::Trace);
}
}
}
builder
}
// When running the built binary, set `COCO_LOG` to `coco_lib=trace` to capture all logs
// that come from Coco in the log file, which helps with debugging.
if !tauri::is_dev() {
// We have absolutely no guarantee that nothing will concurrently read/write `envp`:
// we control the Rust code, but have no idea what the libc C code or any of the
// shared objects we link against might do, so just use unsafe.
unsafe {
std::env::set_var("COCO_LOG", "coco_lib=trace");
}
}
let mut builder = tauri_plugin_log::Builder::new();
builder = builder.format(|out, message, record| {
let now = chrono::Local::now().format("%m-%d %H:%M:%S");
let level = format_log_level(record.level());
let target_and_line = format_target_and_line(record);
out.finish(format_args!(
"[{}] [{}] [{}] {}",
now, level, target_and_line, message
));
});
builder = dynamic_log_level(builder);
builder.build()
tauri_plugin_log::Builder::new()
.format(|out, message, record| {
let now = chrono::Local::now().format("%m-%d %H:%M:%S");
let level = format_log_level(record.level());
let target_and_line = format_target_and_line(record);
out.finish(format_args!(
"[{}] [{}] [{}] {}",
now, level, target_and_line, message
));
})
.level(log::LevelFilter::Debug)
.build()
}

View File

@@ -12,10 +12,9 @@ pub use with_feature::*;
#[cfg(not(feature = "use_pizza_engine"))]
pub use without_feature::*;
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct AppEntry {
path: String,
name: String,
@@ -25,26 +24,15 @@ pub struct AppEntry {
is_disabled: bool,
}
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AppMetadata {
name: String,
r#where: String,
size: u64,
icon: String,
created: u128,
modified: u128,
last_opened: u128,
}
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "Applications",
"platforms": ["macos", "linux", "windows"],
"name": "Applications",
"description": "Application search",
"icon": "font_Application",
"type": "group",
"enabled": true
}
"#;
}

View File

@@ -1,20 +1,18 @@
use super::super::Extension;
use super::AppMetadata;
use crate::common::error::SearchError;
use crate::common::search::{QueryResponse, QuerySource, SearchQuery};
use crate::common::traits::SearchSource;
use crate::extension::LOCAL_QUERY_SOURCE_TYPE;
use crate::local::LOCAL_QUERY_SOURCE_TYPE;
use async_trait::async_trait;
use tauri::{AppHandle, Runtime};
use super::AppEntry;
use super::AppMetadata;
pub(crate) const QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME: &str = "Applications";
pub struct ApplicationSearchSource;
impl ApplicationSearchSource {
pub async fn prepare_index_and_store<R: Runtime>(
_app_handle: AppHandle<R>,
) -> Result<(), String> {
pub async fn init<R: Runtime>(_app_handle: AppHandle<R>) -> Result<(), String> {
Ok(())
}
}
@@ -41,45 +39,46 @@ impl SearchSource for ApplicationSearchSource {
}
}
pub fn set_app_alias<R: Runtime>(_tauri_app_handle: &AppHandle<R>, _app_path: &str, _alias: &str) {
#[tauri::command]
pub async fn set_app_alias(_app_path: String, _alias: String) -> Result<(), String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub fn register_app_hotkey<R: Runtime>(
_tauri_app_handle: &AppHandle<R>,
_app_path: &str,
_hotkey: &str,
#[tauri::command]
pub async fn register_app_hotkey<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
_hotkey: String,
) -> Result<(), String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub fn unregister_app_hotkey<R: Runtime>(
_tauri_app_handle: &AppHandle<R>,
_app_path: &str,
#[tauri::command]
pub async fn unregister_app_hotkey<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<(), String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub fn disable_app_search<R: Runtime>(
_tauri_app_handle: &AppHandle<R>,
_app_path: &str,
#[tauri::command]
pub async fn disable_app_search<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<(), String> {
// no-op
Ok(())
}
pub fn enable_app_search<R: Runtime>(
_tauri_app_handle: &AppHandle<R>,
_app_path: &str,
#[tauri::command]
pub async fn enable_app_search<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
_app_path: String,
) -> Result<(), String> {
// no-op
Ok(())
}
pub fn is_app_search_enabled(_app_path: &str) -> bool {
false
}
#[tauri::command]
pub async fn add_app_search_path<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
@@ -104,10 +103,11 @@ pub async fn get_app_search_path<R: Runtime>(_tauri_app_handle: AppHandle<R>) ->
Vec::new()
}
#[tauri::command]
pub async fn get_app_list<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
) -> Result<Vec<Extension>, String> {
) -> Result<Vec<AppEntry>, String> {
// Return an empty list
Ok(Vec::new())
}
@@ -119,23 +119,3 @@ pub async fn get_app_metadata<R: Runtime>(
) -> Result<AppMetadata, String> {
unreachable!("app list should be empty, there is no way this can be invoked")
}
pub(crate) fn set_apps_hotkey<R: Runtime>(_tauri_app_handle: &AppHandle<R>) -> Result<(), String> {
// no-op
Ok(())
}
pub(crate) fn unset_apps_hotkey<R: Runtime>(
_tauri_app_handle: &AppHandle<R>,
) -> Result<(), String> {
// no-op
Ok(())
}
#[tauri::command]
pub async fn reindex_applications<R: Runtime>(
_tauri_app_handle: AppHandle<R>,
) -> Result<(), String> {
// no-op
Ok(())
}

View File

@@ -1,4 +1,4 @@
use super::super::LOCAL_QUERY_SOURCE_TYPE;
use super::LOCAL_QUERY_SOURCE_TYPE;
use crate::common::{
document::{DataSourceReference, Document},
error::SearchError,
@@ -13,19 +13,6 @@ use std::collections::HashMap;
pub(crate) const DATA_SOURCE_ID: &str = "Calculator";
/// JSON file for this extension.
pub(crate) const PLUGIN_JSON_FILE: &str = r#"
{
"id": "Calculator",
"name": "Calculator",
"platforms": ["macos", "linux", "windows"],
"description": "...",
"icon": "font_Calculator",
"type": "calculator",
"enabled": true
}
"#;
pub struct CalculatorSource {
base_score: f64,
}
@@ -36,7 +23,7 @@ impl CalculatorSource {
}
}
fn parse_query(query: &str) -> Value {
fn parse_query(query: String) -> Value {
let mut query_json = serde_json::Map::new();
let operators = ["+", "-", "*", "/", "%"];
@@ -61,7 +48,7 @@ fn parse_query(query: &str) -> Value {
query_json.insert("type".to_string(), Value::String("expression".to_string()));
}
query_json.insert("value".to_string(), Value::String(query.to_string()));
query_json.insert("value".to_string(), Value::String(query));
Value::Object(query_json)
}
@@ -121,17 +108,11 @@ impl SearchSource for CalculatorSource {
}
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError> {
let Some(query_string) = query.query_strings.get("query") else {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
});
};
// Trim the leading and trailing whitespace so that our later if condition
// will only be evaluated against non-whitespace characters.
let query_string = query_string.trim();
let query_string = query
.query_strings
.get("query")
.unwrap_or(&"".to_string())
.to_string();
if query_string.is_empty() || query_string.len() == 1 {
return Ok(QueryResponse {
@@ -141,54 +122,42 @@ impl SearchSource for CalculatorSource {
});
}
let query_string_clone = query_string.to_string();
let query_source = self.get_type();
let base_score = self.base_score;
let closure = move || -> QueryResponse {
let res_num = meval::eval_str(&query_string_clone);
match meval::eval_str(&query_string) {
Ok(num) => {
let mut payload: HashMap<String, Value> = HashMap::new();
match res_num {
Ok(num) => {
let mut payload: HashMap<String, Value> = HashMap::new();
let payload_query = parse_query(query_string);
let payload_result = parse_result(num);
let payload_query = parse_query(&query_string_clone);
let payload_result = parse_result(num);
payload.insert("query".to_string(), payload_query);
payload.insert("result".to_string(), payload_result);
payload.insert("query".to_string(), payload_query);
payload.insert("result".to_string(), payload_result);
let doc = Document {
id: DATA_SOURCE_ID.to_string(),
category: Some(DATA_SOURCE_ID.to_string()),
payload: Some(payload),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(DATA_SOURCE_ID.into()),
id: Some(DATA_SOURCE_ID.into()),
icon: None,
}),
..Default::default()
};
let doc = Document {
id: DATA_SOURCE_ID.to_string(),
category: Some(DATA_SOURCE_ID.to_string()),
payload: Some(payload),
source: Some(DataSourceReference {
r#type: Some(LOCAL_QUERY_SOURCE_TYPE.into()),
name: Some(DATA_SOURCE_ID.into()),
id: Some(DATA_SOURCE_ID.into()),
icon: Some(String::from("font_Calculator")),
}),
..Default::default()
};
QueryResponse {
source: query_source,
hits: vec![(doc, base_score)],
total_hits: 1,
}
}
Err(_) => QueryResponse {
source: query_source,
return Ok(QueryResponse {
source: self.get_type(),
hits: vec![(doc, self.base_score)],
total_hits: 1,
});
}
Err(_) => {
return Ok(QueryResponse {
source: self.get_type(),
hits: Vec::new(),
total_hits: 0,
},
});
}
};
let spawn_result = tokio::task::spawn_blocking(closure).await;
match spawn_result {
Ok(response) => Ok(response),
Err(e) => std::panic::resume_unwind(e.into_panic()),
}
}
}

164
src-tauri/src/local/mod.rs Normal file
View File

@@ -0,0 +1,164 @@
pub mod application;
pub mod calculator;
pub mod file_system;
use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::OnceLock;
use crate::common::register::SearchSourceRegistry;
use serde_json::Value as Json;
use tauri::{AppHandle, Manager, Runtime};
use tauri_plugin_store::StoreExt;
pub const LOCAL_QUERY_SOURCE_TYPE: &str = "local";
pub const TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE: &str = "local_query_source_enabled_state";
trait SearchSourceState {
#[cfg_attr(not(feature = "use_pizza_engine"), allow(unused))]
fn as_mut_any(&mut self) -> &mut dyn Any;
}
#[async_trait::async_trait(?Send)]
trait Task: Send + Sync {
fn search_source_id(&self) -> &'static str;
async fn exec(&mut self, state: &mut Option<Box<dyn SearchSourceState>>);
}
static RUNTIME_TX: OnceLock<tokio::sync::mpsc::UnboundedSender<Box<dyn Task>>> = OnceLock::new();
pub(crate) fn start_pizza_engine_runtime() {
std::thread::spawn(|| {
let rt = tokio::runtime::Runtime::new().unwrap();
let main = async {
let mut states: HashMap<String, Option<Box<dyn SearchSourceState>>> = HashMap::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
RUNTIME_TX.set(tx).unwrap();
while let Some(mut task) = rx.recv().await {
let opt_search_source_state = match states.entry(task.search_source_id().into()) {
Entry::Occupied(o) => o.into_mut(),
Entry::Vacant(v) => v.insert(None),
};
task.exec(opt_search_source_state).await;
}
};
rt.block_on(main);
});
}
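// A minimal sketch of a task that could run on this runtime (the `NoopTask` name is
// illustrative; real tasks are defined by the per-source modules). Tasks are keyed by
// `search_source_id()`, and the runtime hands each one the cached state for its source.
#[allow(dead_code)]
struct NoopTask;

#[async_trait::async_trait(?Send)]
impl Task for NoopTask {
    fn search_source_id(&self) -> &'static str {
        "noop"
    }

    async fn exec(&mut self, state: &mut Option<Box<dyn SearchSourceState>>) {
        // On the first run there is no cached state yet; later runs could downcast it
        // through `as_mut_any()`.
        let _ = state;
    }
}

// Submitting it, assuming the runtime thread has already been started:
// RUNTIME_TX.get().expect("runtime not started").send(Box::new(NoopTask)).ok();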
pub(crate) async fn init_local_search_source<R: Runtime>(
app_handle: &AppHandle<R>,
) -> Result<(), String> {
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.map_err(|e| e.to_string())?;
if enabled_status_store.is_empty() {
enabled_status_store.set(
application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME,
Json::Bool(true),
);
enabled_status_store.set(calculator::DATA_SOURCE_ID, Json::Bool(true));
}
let registry = app_handle.state::<SearchSourceRegistry>();
application::ApplicationSearchSource::init(app_handle.clone()).await?;
for (id, enabled) in enabled_status_store.entries() {
let enabled = match enabled {
Json::Bool(b) => b,
_ => unreachable!("enabled state should be stored as a boolean"),
};
if enabled {
if id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
registry
.register_source(application::ApplicationSearchSource)
.await;
}
if id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
registry.register_source(calculator_search).await;
}
}
}
Ok(())
}
#[tauri::command]
pub async fn get_disabled_local_query_sources<R: Runtime>(app_handle: AppHandle<R>) -> Vec<String> {
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.unwrap_or_else(|e| {
panic!(
"tauri store [{}] should exist and be loaded, but that's not true due to error [{}]",
TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e
)
});
let mut disabled_local_query_sources = Vec::new();
for (id, enabled) in enabled_status_store.entries() {
let enabled = match enabled {
Json::Bool(b) => b,
_ => unreachable!("enabled state should be stored as a boolean"),
};
if !enabled {
disabled_local_query_sources.push(id);
}
}
disabled_local_query_sources
}
#[tauri::command]
pub async fn enable_local_query_source<R: Runtime>(
app_handle: AppHandle<R>,
query_source_id: String,
) {
let registry = app_handle.state::<SearchSourceRegistry>();
if query_source_id == application::QUERYSOURCE_ID_DATASOURCE_ID_DATASOURCE_NAME {
let application_search = application::ApplicationSearchSource;
registry.register_source(application_search).await;
}
if query_source_id == calculator::DATA_SOURCE_ID {
let calculator_search = calculator::CalculatorSource::new(2000f64);
registry.register_source(calculator_search).await;
}
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.unwrap_or_else(|e| {
panic!(
"tauri store [{}] should exist and be loaded, but that's not true due to error [{}]",
TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e
)
});
enabled_status_store.set(query_source_id, Json::Bool(true));
}
#[tauri::command]
pub async fn disable_local_query_source<R: Runtime>(
app_handle: AppHandle<R>,
query_source_id: String,
) {
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(&query_source_id).await;
let enabled_status_store = app_handle
.store(TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE)
.unwrap_or_else(|e| {
panic!(
"tauri store [{}] should exist and be loaded, but that's not true due to error [{}]",
TAURI_STORE_LOCAL_QUERY_SOURCE_ENABLED_STATE, e
)
});
enabled_status_store.set(query_source_id, Json::Bool(false));
}

View File

@@ -1,55 +1,15 @@
use crate::common::error::SearchError;
use crate::common::register::SearchSourceRegistry;
use crate::common::search::{
FailedRequest, MultiSourceQueryResponse, QueryHits, QueryResponse, QuerySource, SearchQuery,
FailedRequest, MultiSourceQueryResponse, QueryHits, QuerySource, SearchQuery,
};
use crate::common::traits::SearchSource;
use function_name::named;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use std::cmp::Reverse;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::sync::Arc;
use tauri::{AppHandle, Manager, Runtime};
use tokio::time::error::Elapsed;
use tokio::time::{timeout, Duration};
/// Helper function to return the Future used for querying querysources.
///
/// It is a workaround for the limitations:
///
/// 1. Two async blocks have different types in Rust's type system even though
/// they are literally the same
/// 2. `futures::stream::FuturesUnordered` needs the `Futures` pushed to it to
/// have only 1 type
///
/// Putting the async block in a function to unify the types.
fn same_type_futures(
query_source: QuerySource,
query_source_trait_object: Arc<dyn SearchSource>,
timeout_duration: Duration,
search_query: SearchQuery,
) -> impl Future<
Output = (
QuerySource,
Result<Result<QueryResponse, SearchError>, Elapsed>,
),
> + 'static {
async move {
(
// Store `query_source` as part of future for debugging purposes.
query_source,
timeout(timeout_duration, async {
query_source_trait_object.search(search_query).await
})
.await,
)
}
}
#[named]
#[tauri::command]
pub async fn query_coco_fusion<R: Runtime>(
app_handle: AppHandle<R>,
@@ -58,153 +18,113 @@ pub async fn query_coco_fusion<R: Runtime>(
query_strings: HashMap<String, String>,
query_timeout: u64,
) -> Result<MultiSourceQueryResponse, SearchError> {
let query_keyword = query_strings
.get("query")
.unwrap_or(&"".to_string())
.clone();
let opt_query_source_id = query_strings.get("querysource");
let query_source_to_search = query_strings.get("querysource");
let search_sources = app_handle.state::<SearchSourceRegistry>();
let sources_future = search_sources.get_sources();
let mut futures = FuturesUnordered::new();
let mut sources = HashMap::new();
let mut sources_list = sources_future.await;
let sources_list_len = sources_list.len();
let sources_list = sources_future.await;
// Time limit for each query
let timeout_duration = Duration::from_millis(query_timeout);
log::debug!(
"{}(): {:?}, timeout: {:?}",
function_name!(),
query_strings,
timeout_duration
);
// Push all queries into futures
for query_source in sources_list {
let query_source_type = query_source.get_type().clone();
let search_query = SearchQuery::new(from, size, query_strings.clone());
if let Some(query_source_id) = opt_query_source_id {
// If this query source ID is specified, we only query this query source.
log::debug!(
"parameter [querysource={}] specified, will only query this querysource",
query_source_id
);
let opt_query_source_trait_object_index = sources_list
.iter()
.position(|query_source| &query_source.get_type().id == query_source_id);
let Some(query_source_trait_object_index) = opt_query_source_trait_object_index else {
// It is possible (an edge case) that the frontend invokes `query_coco_fusion()` with a
// datasource that does not exist in the source list:
//
// 1. Search applications
// 2. Navigate to the application sub page
// 3. Disable the application extension in settings
// 4. hide the search window
// 5. Re-open the search window and search for something
//
// The application search source is not in the source list because the extension
// has been disabled, but the last search is indeed invoked with parameter
// `datasource=application`.
return Ok(MultiSourceQueryResponse {
failed: Vec::new(),
hits: Vec::new(),
total_hits: 0,
});
};
let query_source_trait_object = sources_list.remove(query_source_trait_object_index);
let query_source = query_source_trait_object.get_type();
futures.push(same_type_futures(
query_source,
query_source_trait_object,
timeout_duration,
search_query,
));
} else {
for query_source_trait_object in sources_list {
let query_source = query_source_trait_object.get_type().clone();
log::debug!("will query querysource [{}]", query_source.id);
futures.push(same_type_futures(
query_source,
query_source_trait_object,
timeout_duration,
search_query.clone(),
));
if let Some(query_source_to_search) = query_source_to_search {
// We should not search this data source
if &query_source_type.id != query_source_to_search {
continue;
}
}
sources.insert(query_source_type.id.clone(), query_source_type);
let query = SearchQuery::new(from, size, query_strings.clone());
let query_source_clone = query_source.clone(); // Clone Arc to avoid ownership issues
futures.push(tokio::spawn(async move {
// Timeout each query execution
timeout(timeout_duration, async {
query_source_clone.search(query).await
})
.await
}));
}
let mut total_hits = 0;
let mut need_rerank = true; //TODO set default to false when boost supported in Pizza
let mut failed_requests = Vec::new();
let mut all_hits: Vec<(String, QueryHits, f64)> = Vec::new();
let mut hits_per_source: HashMap<String, Vec<(QueryHits, f64)>> = HashMap::new();
if sources_list_len > 1 {
need_rerank = true; // If we have more than one source, we need to rerank the hits
}
while let Some(result) = futures.next().await {
match result {
Ok(Ok(Ok(response))) => {
total_hits += response.total_hits;
let source_id = response.source.id.clone();
while let Some((query_source, timeout_result)) = futures.next().await {
match timeout_result {
// Ignore the `_timeout` variable as it won't provide any useful debugging information.
Err(_timeout) => {
log::warn!(
"searching query source [{}] timed out, skip this request",
query_source.id
);
// failed_requests.push(FailedRequest {
// source: query_source,
// status: 0,
// error: Some("querying timed out".into()),
// reason: None,
// });
for (doc, score) in response.hits {
let query_hit = QueryHits {
source: Some(response.source.clone()),
score,
document: doc,
};
all_hits.push((source_id.clone(), query_hit.clone(), score));
hits_per_source
.entry(source_id.clone())
.or_insert_with(Vec::new)
.push((query_hit, score));
}
}
Ok(Ok(Err(err))) => {
failed_requests.push(FailedRequest {
source: QuerySource {
r#type: "N/A".into(),
name: "N/A".into(),
id: "N/A".into(),
},
status: 0,
error: Some(err.to_string()),
reason: None,
});
}
Ok(Err(err)) => {
failed_requests.push(FailedRequest {
source: QuerySource {
r#type: "N/A".into(),
name: "N/A".into(),
id: "N/A".into(),
},
status: 0,
error: Some(err.to_string()),
reason: None,
});
}
// Timeout reached, skip this request
_ => {
failed_requests.push(FailedRequest {
source: QuerySource {
r#type: "N/A".into(),
name: "N/A".into(),
id: "N/A".into(),
},
status: 0,
error: Some(format!("{:?}", &result)),
reason: None,
});
}
Ok(query_result) => match query_result {
Ok(response) => {
total_hits += response.total_hits;
let source_id = response.source.id.clone();
for (doc, score) in response.hits {
log::debug!("doc: {}, {:?}, {}", doc.id, doc.title, score);
let query_hit = QueryHits {
source: Some(response.source.clone()),
score,
document: doc,
};
all_hits.push((source_id.clone(), query_hit.clone(), score));
hits_per_source
.entry(source_id.clone())
.or_insert_with(Vec::new)
.push((query_hit, score));
}
}
Err(search_error) => {
log::error!(
"searching query source [{}] failed, error [{}]",
query_source.id,
search_error
);
failed_requests.push(FailedRequest {
source: query_source,
status: 0,
error: Some(search_error.to_string()),
reason: None,
});
}
},
}
}
// Sort hits within each source by score (descending)
for hits in hits_per_source.values_mut() {
hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Greater));
hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
}
let total_sources = hits_per_source.len();
@@ -220,71 +140,16 @@ pub async fn query_coco_fusion<R: Runtime>(
// Distribute hits fairly across sources
for (_source_id, hits) in &mut hits_per_source {
let take_count = hits.len().min(max_hits_per_source);
for (doc, score) in hits.drain(0..take_count) {
for (doc, _) in hits.drain(0..take_count) {
if !seen_docs.contains(&doc.document.id) {
seen_docs.insert(doc.document.id.clone());
log::debug!(
"collect doc: {}, {:?}, {}",
doc.document.id,
doc.document.title,
score
);
final_hits.push(doc);
}
}
}
log::debug!("final hits: {:?}", final_hits.len());
let mut unique_sources = HashSet::new();
for hit in &final_hits {
if let Some(source) = &hit.source {
if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID {
unique_sources.insert(&source.id);
}
}
}
log::debug!(
"Multiple sources found: {:?}, no rerank needed",
unique_sources
);
if unique_sources.len() < 1 {
need_rerank = false; // No non-calculator sources contributed hits, so reranking is unnecessary
}
if need_rerank && final_hits.len() > 1 {
// Precollect (index, title)
let titles_to_score: Vec<(usize, &str)> = final_hits
.iter()
.enumerate()
.filter_map(|(idx, hit)| {
let source = hit.source.as_ref()?;
let title = hit.document.title.as_deref()?;
if source.id != crate::extension::built_in::calculator::DATA_SOURCE_ID {
Some((idx, title))
} else {
None
}
})
.collect();
// Score them
let scored_hits = boosted_levenshtein_rerank(query_keyword.as_str(), titles_to_score);
// Sort descending by score
let mut scored_hits = scored_hits;
scored_hits.sort_by_key(|&(_, score)| Reverse((score * 1000.0) as u64));
// Apply new scores to final_hits
for (idx, score) in scored_hits.into_iter().take(size as usize) {
final_hits[idx].score = score;
}
} else if final_hits.len() < size as usize {
// If we still need more hits, take the highest-scoring remaining ones
// If we still need more hits, take the highest-scoring remaining ones
if final_hits.len() < size as usize {
let remaining_needed = size as usize - final_hits.len();
// Sort all hits by score descending, removing duplicates by document ID
@@ -314,45 +179,9 @@ pub async fn query_coco_fusion<R: Runtime>(
.unwrap_or(std::cmp::Ordering::Equal)
});
if final_hits.len() < 5 {
//TODO: Add a recommendation system to suggest more sources
log::info!(
"Less than 5 hits found, consider using recommendation to find more suggestions."
);
//local: recent history, local extensions
//remote: ai agents, quick links, other tasks, managed by server
}
Ok(MultiSourceQueryResponse {
failed: failed_requests,
hits: final_hits,
total_hits,
})
}
fn boosted_levenshtein_rerank(query: &str, titles: Vec<(usize, &str)>) -> Vec<(usize, f64)> {
use strsim::levenshtein;
let query_lower = query.to_lowercase();
titles
.into_iter()
.map(|(idx, title)| {
let mut score = 0.0;
if title.contains(query) {
score += 0.4;
} else if title.to_lowercase().contains(&query_lower) {
score += 0.2;
}
let dist = levenshtein(&query_lower, &title.to_lowercase());
let max_len = query_lower.len().max(title.len());
if max_len > 0 {
score += (1.0 - (dist as f64 / max_len as f64)) as f32;
}
(idx, score.min(1.0) as f64)
})
.collect()
}
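
The fused query above fans every registered source out as its own future, applies a per-source timeout, and collects results as they complete. A minimal sketch of that shape, assuming the tokio and futures crates; FakeSource and its latencies are illustrative stand-ins for the real SearchSource trait objects:

```rust
use std::time::Duration;

use futures::stream::{FuturesUnordered, StreamExt};
use tokio::time::timeout;

// Stand-in for a search source: an id plus a simulated latency.
struct FakeSource {
    id: &'static str,
    latency_ms: u64,
}

impl FakeSource {
    async fn search(&self) -> Result<Vec<String>, String> {
        tokio::time::sleep(Duration::from_millis(self.latency_ms)).await;
        Ok(vec![format!("hit from {}", self.id)])
    }
}

#[tokio::main]
async fn main() {
    let sources = vec![
        FakeSource { id: "coco", latency_ms: 50 },
        FakeSource { id: "slow", latency_ms: 500 },
    ];
    let per_source_timeout = Duration::from_millis(200);

    // One timed future per source; FuturesUnordered yields them as they finish.
    let mut futures = FuturesUnordered::new();
    for source in sources {
        futures.push(async move {
            (source.id, timeout(per_source_timeout, source.search()).await)
        });
    }

    while let Some((id, outcome)) = futures.next().await {
        match outcome {
            Err(_elapsed) => eprintln!("source [{}] timed out, skipping it", id),
            Ok(Err(e)) => eprintln!("source [{}] failed: {}", id, e),
            Ok(Ok(hits)) => println!("source [{}] returned {} hit(s)", id, hits.len()),
        }
    }
}
```

Because every pushed future comes from the same async block, FuturesUnordered accepts them directly; a type-unification helper like the one discussed above is only needed when two differently written async blocks must share one collection.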

View File

@@ -15,6 +15,42 @@ pub struct UploadAttachmentResponse {
pub attachments: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AttachmentSource {
pub id: String,
pub created: String,
pub updated: String,
pub session: String,
pub name: String,
pub icon: String,
pub url: String,
pub size: u64,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AttachmentHit {
pub _index: String,
pub _type: Option<String>,
pub _id: String,
pub _score: Option<f64>,
pub _source: AttachmentSource,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AttachmentHits {
pub total: Value,
pub max_score: Option<f64>,
pub hits: Option<Vec<AttachmentHit>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GetAttachmentResponse {
pub took: u32,
pub timed_out: bool,
pub _shards: Option<Value>,
pub hits: AttachmentHits,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DeleteAttachmentResponse {
pub _id: String,
@@ -24,6 +60,7 @@ pub struct DeleteAttachmentResponse {
#[command]
pub async fn upload_attachment(
server_id: String,
session_id: String,
file_paths: Vec<PathBuf>,
) -> Result<UploadAttachmentResponse, String> {
let mut form = Form::new();
@@ -46,7 +83,7 @@ pub async fn upload_attachment(
}
let server = get_server_by_id(&server_id).ok_or("Server not found")?;
let url = HttpClient::join_url(&server.endpoint, &format!("attachment/_upload"));
let url = HttpClient::join_url(&server.endpoint, &format!("chat/{}/_upload", session_id));
let token = get_server_token(&server_id).await?;
let mut headers = HashMap::new();
@@ -70,17 +107,20 @@ pub async fn upload_attachment(
}
#[command]
pub async fn get_attachment(server_id: String, session_id: String) -> Result<Value, String> {
let mut query_params = Vec::new();
query_params.push(format!("session={}", session_id));
pub async fn get_attachment(
server_id: String,
session_id: String,
) -> Result<GetAttachmentResponse, String> {
let mut query_params = HashMap::new();
query_params.insert("session".to_string(), serde_json::Value::String(session_id));
let response = HttpClient::get(&server_id, "/attachment/_search", Some(query_params))
.await
.map_err(|e| format!("Request error: {}", e))?;
let body = get_response_body_text(response).await?;
serde_json::from_str::<Value>(&body)
serde_json::from_str::<GetAttachmentResponse>(&body)
.map_err(|e| format!("Failed to parse attachment response: {}", e))
}
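
The upload command now targets a session-scoped route. A minimal sketch of that call shape, assuming reqwest with its multipart feature and tokio's fs module; the endpoint, token, and helper name are placeholders rather than the app's actual wiring:

```rust
use std::path::PathBuf;

use reqwest::multipart::{Form, Part};

async fn upload_to_session(
    endpoint: &str,
    session_id: &str,
    token: &str,
    file_path: PathBuf,
) -> Result<String, String> {
    // Read the file and wrap it as a single multipart part named "files".
    let bytes = tokio::fs::read(&file_path)
        .await
        .map_err(|e| format!("failed to read {}: {}", file_path.display(), e))?;
    let file_name = file_path
        .file_name()
        .map(|n| n.to_string_lossy().into_owned())
        .unwrap_or_else(|| "attachment".into());
    let form = Form::new().part("files", Part::bytes(bytes).file_name(file_name));

    // Session-scoped route, mirroring the chat/{session}/_upload path above.
    let url = format!("{}/chat/{}/_upload", endpoint.trim_end_matches('/'), session_id);
    let response = reqwest::Client::new()
        .post(url)
        .header("X-API-TOKEN", token)
        .multipart(form)
        .send()
        .await
        .map_err(|e| format!("request error: {}", e))?;

    response
        .text()
        .await
        .map_err(|e| format!("failed to read the response body: {}", e))
}

#[tokio::main]
async fn main() {
    // Placeholder arguments; a real call would use the stored server endpoint and token.
    let result = upload_to_session(
        "https://example.invalid",
        "session-1",
        "token",
        PathBuf::from("notes.txt"),
    )
    .await;
    println!("{:?}", result);
}
```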

View File

@@ -1,8 +1,7 @@
use crate::common::connector::Connector;
use crate::common::search::parse_search_results;
use crate::server::http_client::{HttpClient, status_code_check};
use crate::server::http_client::HttpClient;
use crate::server::servers::get_all_servers;
use http::StatusCode;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
@@ -108,7 +107,6 @@ pub async fn fetch_connectors_by_server(id: &str) -> Result<Vec<Connector>, Stri
// dbg!("Error fetching connector for id {}: {}", &id, &e);
format!("Error fetching connector: {}", e)
})?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results directly from the response body
let datasource: Vec<Connector> = parse_search_results(resp)

View File

@@ -1,14 +1,20 @@
use crate::common::datasource::DataSource;
use crate::common::search::parse_search_results;
use crate::server::connector::get_connector_by_id;
use crate::server::http_client::{HttpClient, status_code_check};
use crate::server::http_client::HttpClient;
use crate::server::servers::get_all_servers;
use http::StatusCode;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tauri::{AppHandle, Runtime};
#[derive(serde::Deserialize, Debug)]
pub struct GetDatasourcesByServerOptions {
pub from: Option<u32>,
pub size: Option<u32>,
pub query: Option<String>,
}
lazy_static! {
static ref DATASOURCE_CACHE: Arc<RwLock<HashMap<String, HashMap<String, DataSource>>>> =
Arc::new(RwLock::new(HashMap::new()));
@@ -90,17 +96,50 @@ pub async fn refresh_all_datasources<R: Runtime>(_app_handle: &AppHandle<R>) ->
#[tauri::command]
pub async fn datasource_search(
id: &str,
query_params: Option<Vec<String>>, //["query=abc", "filter=er", "filter=efg", "from=0", "size=5"],
options: Option<GetDatasourcesByServerOptions>,
) -> Result<Vec<DataSource>, String> {
let from = options.as_ref().and_then(|opt| opt.from).unwrap_or(0);
let size = options.as_ref().and_then(|opt| opt.size).unwrap_or(10000);
let query = options
.and_then(|opt| opt.query)
.unwrap_or(String::default());
let mut body = serde_json::json!({
"from": from,
"size": size,
});
if !query.is_empty() {
body["query"] = serde_json::json!({
"bool": {
"must": [{
"query_string": {
"fields": ["combined_fulltext"],
"query": query,
"fuzziness": "AUTO",
"fuzzy_prefix_length": 2,
"fuzzy_max_expansions": 10,
"fuzzy_transpositions": true,
"allow_leading_wildcard": false
}
}]
}
});
}
// Perform the async HTTP request outside the cache lock
let resp = HttpClient::post(id, "/datasource/_search", query_params, None)
let resp = HttpClient::post(
id,
"/datasource/_search",
None,
Some(reqwest::Body::from(body.to_string())),
)
.await
.map_err(|e| format!("Error fetching datasource: {}", e))?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results from the response
let datasources: Vec<DataSource> = parse_search_results(resp).await.map_err(|e| {
//dbg!("Error parsing search results: {}", &e);
dbg!("Error parsing search results: {}", &e);
e.to_string()
})?;
@@ -113,17 +152,50 @@ pub async fn datasource_search(
#[tauri::command]
pub async fn mcp_server_search(
id: &str,
query_params: Option<Vec<String>>,
options: Option<GetDatasourcesByServerOptions>,
) -> Result<Vec<DataSource>, String> {
let from = options.as_ref().and_then(|opt| opt.from).unwrap_or(0);
let size = options.as_ref().and_then(|opt| opt.size).unwrap_or(10000);
let query = options
.and_then(|opt| opt.query)
.unwrap_or(String::default());
let mut body = serde_json::json!({
"from": from,
"size": size,
});
if !query.is_empty() {
body["query"] = serde_json::json!({
"bool": {
"must": [{
"query_string": {
"fields": ["combined_fulltext"],
"query": query,
"fuzziness": "AUTO",
"fuzzy_prefix_length": 2,
"fuzzy_max_expansions": 10,
"fuzzy_transpositions": true,
"allow_leading_wildcard": false
}
}]
}
});
}
// Perform the async HTTP request outside the cache lock
let resp = HttpClient::post(id, "/mcp_server/_search", query_params, None)
let resp = HttpClient::post(
id,
"/mcp_server/_search",
None,
Some(reqwest::Body::from(body.to_string())),
)
.await
.map_err(|e| format!("Error fetching datasource: {}", e))?;
status_code_check(&resp, &[StatusCode::OK, StatusCode::CREATED])?;
// Parse the search results from the response
let mcp_server: Vec<DataSource> = parse_search_results(resp).await.map_err(|e| {
//dbg!("Error parsing search results: {}", &e);
dbg!("Error parsing search results: {}", &e);
e.to_string()
})?;
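
Both search commands above assemble the same request body: paging fields always, and a fuzzy query_string clause only when a query term is given. A small sketch of that construction, assuming serde_json; the helper name is illustrative:

```rust
use serde_json::{json, Value};

// Builds the search body used above: paging always, a fuzzy query_string clause
// only when a query term is present. Field and fuzziness settings mirror the snippet.
fn build_search_body(from: u32, size: u32, query: &str) -> Value {
    let mut body = json!({ "from": from, "size": size });
    if !query.is_empty() {
        body["query"] = json!({
            "bool": {
                "must": [{
                    "query_string": {
                        "fields": ["combined_fulltext"],
                        "query": query,
                        "fuzziness": "AUTO",
                        "fuzzy_prefix_length": 2,
                        "fuzzy_max_expansions": 10,
                        "fuzzy_transpositions": true,
                        "allow_leading_wildcard": false
                    }
                }]
            }
        });
    }
    body
}

fn main() {
    // Without a query only paging is sent; with one, the bool/must clause is added.
    println!("{}", build_search_body(0, 10_000, ""));
    println!("{}", build_search_body(0, 20, "report"));
}
```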

View File

@@ -1,19 +1,17 @@
use crate::server::servers::{get_server_by_id, get_server_token};
use crate::util::app_lang::get_app_lang;
use crate::util::platform::Platform;
use http::{HeaderName, HeaderValue, StatusCode};
use http::{HeaderName, HeaderValue};
use once_cell::sync::Lazy;
use reqwest::{Client, Method, RequestBuilder};
use std::collections::HashMap;
use std::sync::LazyLock;
use std::time::Duration;
use tauri_plugin_store::JsonValue;
use tokio::sync::Mutex;
pub(crate) fn new_reqwest_http_client(accept_invalid_certs: bool) -> Client {
Client::builder()
.read_timeout(Duration::from_secs(3)) // Set a timeout of 3 second
.connect_timeout(Duration::from_secs(3)) // Set a timeout of 3 second
.timeout(Duration::from_secs(5 * 60)) // Set a timeout of 5 minute
.timeout(Duration::from_secs(10)) // Set a timeout of 10 seconds
.danger_accept_invalid_certs(accept_invalid_certs) // allow self-signed certificates
.build()
.expect("Failed to build client")
@@ -29,26 +27,6 @@ pub static HTTP_CLIENT: Lazy<Mutex<Client>> = Lazy::new(|| {
Mutex::new(new_reqwest_http_client(allow_self_signature))
});
/// These header values won't change during a process's lifetime.
static STATIC_HEADERS: LazyLock<HashMap<String, String>> = LazyLock::new(|| {
HashMap::from([
(
"X-OS-NAME".into(),
Platform::current()
.to_os_name_http_header_str()
.into_owned(),
),
(
"X-OS-VER".into(),
sysinfo::System::os_version()
.expect("sysinfo::System::os_version() should be Some on major systems"),
),
("X-OS-ARCH".into(), sysinfo::System::cpu_arch()),
("X-APP-NAME".into(), "coco-app".into()),
("X-APP-VER".into(), env!("CARGO_PKG_VERSION").into()),
])
});
pub struct HttpClient;
impl HttpClient {
@@ -62,7 +40,7 @@ impl HttpClient {
pub async fn send_raw_request(
method: Method,
url: &str,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>,
headers: Option<HashMap<String, String>>,
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
@@ -78,7 +56,7 @@ impl HttpClient {
Self::get_request_builder(method, url, headers, query_params, body).await;
let response = request_builder.send().await.map_err(|e| {
//dbg!("Failed to send request: {}", &e);
dbg!("Failed to send request: {}", &e);
format!("Failed to send request: {}", e)
})?;
@@ -96,7 +74,7 @@ impl HttpClient {
method: Method,
url: &str,
headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>, // Add query parameters
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> RequestBuilder {
let client = HTTP_CLIENT.lock().await; // Acquire the lock on HTTP_CLIENT
@@ -104,32 +82,8 @@ impl HttpClient {
// Build the request
let mut request_builder = client.request(method.clone(), url);
// Populate the headers defined by us
let mut req_headers = reqwest::header::HeaderMap::new();
for (key, value) in STATIC_HEADERS.iter() {
let key = HeaderName::from_bytes(key.as_bytes())
.expect("headers defined by us should be valid");
let value = HeaderValue::from_str(value.trim()).unwrap_or_else(|e| {
panic!(
"header value [{}] is invalid, error [{}], this should be unreachable",
value, e
);
});
req_headers.insert(key, value);
}
let app_lang = get_app_lang().await.to_string();
req_headers.insert(
"X-APP-LANG",
HeaderValue::from_str(&app_lang).unwrap_or_else(|e| {
panic!(
"header value [{}] is invalid, error [{}], this should be unreachable",
app_lang, e
);
}),
);
// Headers from the function parameter
if let Some(h) = headers {
let mut req_headers = reqwest::header::HeaderMap::new();
for (key, value) in h.into_iter() {
match (
HeaderName::from_bytes(key.as_bytes()),
@@ -152,9 +106,24 @@ impl HttpClient {
request_builder = request_builder.headers(req_headers);
}
if let Some(params) = query_params {
let query: Vec<(&str, &str)> =
params.iter().filter_map(|s| s.split_once('=')).collect();
if let Some(query) = query_params {
// Convert only supported value types into strings
let query: HashMap<String, String> = query
.into_iter()
.filter_map(|(k, v)| {
match v {
JsonValue::String(s) => Some((k, s)),
JsonValue::Number(n) => Some((k, n.to_string())),
JsonValue::Bool(b) => Some((k, b.to_string())),
_ => {
dbg!(
"Unsupported query parameter type. Only strings, numbers, and booleans are supported.",k,v,
);
None
} // skip arrays, objects, nulls
}
})
.collect();
request_builder = request_builder.query(&query);
}
@@ -171,7 +140,7 @@ impl HttpClient {
method: Method,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>,
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
// Fetch the server using the server_id
@@ -196,12 +165,12 @@ impl HttpClient {
headers.insert("X-API-TOKEN".to_string(), t);
}
// log::debug!(
// "Sending request to server: {}, url: {}, headers: {:?}",
// &server_id,
// &url,
// &headers
// );
log::debug!(
"Sending request to server: {}, url: {}, headers: {:?}",
&server_id,
&url,
&headers
);
Self::send_raw_request(method, &url, query_params, Some(headers), body).await
} else {
@@ -213,7 +182,7 @@ impl HttpClient {
pub async fn get(
server_id: &str,
path: &str,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::GET, path, None, query_params, None).await
}
@@ -222,7 +191,7 @@ impl HttpClient {
pub async fn post(
server_id: &str,
path: &str,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(server_id, Method::POST, path, None, query_params, body).await
@@ -232,7 +201,7 @@ impl HttpClient {
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(
@@ -252,7 +221,7 @@ impl HttpClient {
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
body: Option<reqwest::Body>,
) -> Result<reqwest::Response, String> {
HttpClient::send_request(
@@ -272,7 +241,7 @@ impl HttpClient {
server_id: &str,
path: &str,
custom_headers: Option<HashMap<String, String>>,
query_params: Option<Vec<String>>,
query_params: Option<HashMap<String, JsonValue>>, // Add query parameters
) -> Result<reqwest::Response, String> {
HttpClient::send_request(
server_id,
@@ -285,30 +254,3 @@ impl HttpClient {
.await
}
}
/// Helper function to check status code.
///
/// If the status code is not in the `allowed_status_codes` list, return an error.
pub(crate) fn status_code_check(
response: &reqwest::Response,
allowed_status_codes: &[StatusCode],
) -> Result<(), String> {
let status_code = response.status();
if !allowed_status_codes.contains(&status_code) {
let msg = format!(
"Response of request [{}] status code failed: status code [{}], which is not in the 'allow' list {:?}",
response.url(),
status_code,
allowed_status_codes
.iter()
.map(|status| status.to_string())
.collect::<Vec<String>>()
);
log::warn!("{}", msg);
Err(msg)
} else {
Ok(())
}
}
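
The client now accepts JSON-typed query parameters and silently drops values that have no obvious query-string form. A minimal sketch of that conversion, assuming serde_json; tauri_plugin_store's JsonValue is taken here to be the same re-exported Value type, which is an assumption of this sketch:

```rust
use std::collections::HashMap;

use serde_json::Value as JsonValue;

// Flattens JSON-typed query parameters into string pairs, skipping arrays,
// objects, and null, as the client above does.
fn to_query_pairs(params: HashMap<String, JsonValue>) -> Vec<(String, String)> {
    params
        .into_iter()
        .filter_map(|(key, value)| match value {
            JsonValue::String(s) => Some((key, s)),
            JsonValue::Number(n) => Some((key, n.to_string())),
            JsonValue::Bool(b) => Some((key, b.to_string())),
            _ => None,
        })
        .collect()
}

fn main() {
    let mut params = HashMap::new();
    params.insert("from".to_string(), JsonValue::from(0));
    params.insert("size".to_string(), JsonValue::from(10));
    params.insert("query".to_string(), JsonValue::from("coco"));
    params.insert("ignored".to_string(), JsonValue::Null);

    // Prints pairs like from=0, size=10, query=coco; "ignored" is dropped.
    for (k, v) in to_query_pairs(params) {
        println!("{}={}", k, v);
    }
}
```

reqwest's .query(&pairs) then serializes the resulting pairs, so numeric and boolean values reach the server as plain strings, much as the earlier "k=v" Vec<String> form did.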

View File

@@ -8,7 +8,6 @@ pub mod http_client;
pub mod profile;
pub mod search;
pub mod servers;
pub mod synthesize;
pub mod system_settings;
pub mod transcription;
pub mod websocket;

View File

@@ -1,4 +1,4 @@
use crate::common::document::{Document, OnOpened};
use crate::common::document::Document;
use crate::common::error::SearchError;
use crate::common::http::get_response_body_text;
use crate::common::search::{QueryHits, QueryResponse, QuerySource, SearchQuery, SearchResponse};
@@ -9,6 +9,7 @@ use async_trait::async_trait;
// use futures::stream::StreamExt;
use ordered_float::OrderedFloat;
use std::collections::HashMap;
use tauri_plugin_store::JsonValue;
// use std::hash::Hash;
#[allow(dead_code)]
@@ -92,58 +93,39 @@ impl SearchSource for CocoSearchSource {
async fn search(&self, query: SearchQuery) -> Result<QueryResponse, SearchError> {
let url = "/query/_search";
let mut total_hits = 0;
let mut hits: Vec<(Document, f64)> = Vec::new();
let mut query_params = Vec::new();
// Add from/size as number values
query_params.push(format!("from={}", query.from));
query_params.push(format!("size={}", query.size));
// Add query strings
let mut query_args: HashMap<String, JsonValue> = HashMap::new();
query_args.insert("from".into(), JsonValue::Number(query.from.into()));
query_args.insert("size".into(), JsonValue::Number(query.size.into()));
for (key, value) in query.query_strings {
query_params.push(format!("{}={}", key, value));
query_args.insert(key, JsonValue::String(value));
}
let response = HttpClient::get(&self.server.id, &url, Some(query_params))
let response = HttpClient::get(
&self.server.id,
&url,
Some(query_args),
)
.await
.map_err(|e| SearchError::HttpError(format!("{}", e)))?;
.map_err(|e| SearchError::HttpError(format!("Error to send search request: {}", e)))?;
// Use the helper function to parse the response body
let response_body = get_response_body_text(response)
.await
.map_err(|e| SearchError::ParseError(e))?;
.map_err(|e| SearchError::ParseError(format!("Failed to read response body: {}", e)))?;
// Check if the response body is empty
if !response_body.is_empty() {
// log::info!("Search response body: {}", &response_body);
// Parse the search response from the body text
let parsed: SearchResponse<Document> = serde_json::from_str(&response_body)
.map_err(|e| SearchError::ParseError(format!("Failed to parse search response: {}", e)))?;
// Parse the search response from the body text
let parsed: SearchResponse<Document> = serde_json::from_str(&response_body)
.map_err(|e| SearchError::ParseError(format!("{}", e)))?;
// Process the parsed response
total_hits = parsed.hits.total.value as usize;
if let Some(items) = parsed.hits.hits {
for hit in items {
let mut document = hit._source;
// Default _score to 0.0 if None
let score = hit._score.unwrap_or(0.0);
let on_opened = document
.url
.as_ref()
.map(|url| OnOpened::Document { url: url.clone() });
// Set the `on_opened` field as it won't be returned from Coco server
document.on_opened = on_opened;
hits.push((document, score));
}
}
}
// Process the parsed response
let total_hits = parsed.hits.total.value as usize;
let hits: Vec<(Document, f64)> = parsed
.hits
.hits
.into_iter()
.map(|hit| (hit._source, hit._score.unwrap_or(0.0))) // Default _score to 0.0 if None
.collect();
// Return the final result
Ok(QueryResponse {
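
The search source above parses the server response and maps each hit to a (document, score) pair, defaulting a missing _score to 0.0. A reduced sketch of that parsing, assuming serde and serde_json; the structs model only the fields the mapping touches and are not the app's real types:

```rust
use serde::Deserialize;

// Reduced stand-ins: only the fields the score mapping needs are modeled here.
#[derive(Deserialize)]
struct Doc {
    id: String,
    title: Option<String>,
}

#[derive(Deserialize)]
struct Hit {
    _source: Doc,
    _score: Option<f64>,
}

#[derive(Deserialize)]
struct Total {
    value: u64,
}

#[derive(Deserialize)]
struct Hits {
    total: Total,
    hits: Vec<Hit>,
}

#[derive(Deserialize)]
struct SearchResponse {
    hits: Hits,
}

fn main() -> Result<(), serde_json::Error> {
    let body = r#"{
        "hits": {
            "total": { "value": 2 },
            "hits": [
                { "_source": { "id": "a", "title": "Doc A" }, "_score": 1.5 },
                { "_source": { "id": "b", "title": null }, "_score": null }
            ]
        }
    }"#;

    let parsed: SearchResponse = serde_json::from_str(body)?;
    let total_hits = parsed.hits.total.value as usize;

    // A null or absent _score falls back to 0.0, as in the mapping above.
    let hits: Vec<(Doc, f64)> = parsed
        .hits
        .hits
        .into_iter()
        .map(|hit| (hit._source, hit._score.unwrap_or(0.0)))
        .collect();

    println!("total_hits: {}", total_hits);
    for (doc, score) in &hits {
        println!("{} {:?} -> {}", doc.id, doc.title, score);
    }
    Ok(())
}
```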

View File

@@ -59,7 +59,7 @@ pub fn save_server(server: &Server) -> bool {
}
fn remove_server_by_id(id: String) -> bool {
log::debug!("remove server by id: {}", &id);
dbg!("remove server by id:", &id);
let mut cache = SERVER_CACHE.write().unwrap();
let deleted = cache.remove(id.as_str());
deleted.is_some()
@@ -87,7 +87,7 @@ pub async fn persist_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<()
}
pub fn remove_server_token(id: &str) -> bool {
log::debug!("remove server token by id: {}", &id);
dbg!("remove server token by id:", &id);
let mut cache = SERVER_TOKEN.write().unwrap();
cache.remove(id).is_some()
}
@@ -104,7 +104,7 @@ pub fn persist_servers_token<R: Runtime>(app_handle: &AppHandle<R>) -> Result<()
.map(|server| serde_json::to_value(server).expect("Failed to serialize access_tokens")) // Automatically serialize all fields
.collect();
log::debug!("persist servers token: {:?}", &json_servers);
dbg!(format!("persist servers token: {:?}", &json_servers));
// Save the serialized servers to Tauri's store
app_handle
@@ -143,18 +143,17 @@ fn get_default_server() -> Server {
profile: None,
auth_provider: AuthProvider {
sso: Sso {
url: "https://coco.infini.cloud/sso/login/cloud?provider=coco-cloud&product=coco".to_string(),
url: "https://coco.infini.cloud/sso/login/".to_string(),
},
},
priority: 0,
stats: None,
}
}
pub async fn load_servers_token<R: Runtime>(
app_handle: &AppHandle<R>,
) -> Result<Vec<ServerAccessToken>, String> {
log::debug!("Attempting to load servers token");
dbg!("Attempting to load servers token");
let store = app_handle
.store(COCO_TAURI_STORE)
@@ -188,7 +187,10 @@ pub async fn load_servers_token<R: Runtime>(
save_access_token(server.id.clone(), server.clone());
}
log::debug!("loaded {:?} servers's token", &deserialized_tokens.len());
dbg!(format!(
"loaded {:?} servers's token",
&deserialized_tokens.len()
));
Ok(deserialized_tokens)
} else {
@@ -229,7 +231,7 @@ pub async fn load_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<Vec<S
save_server(&server);
}
log::debug!("load servers: {:?}", &deserialized_servers);
// dbg!(format!("load servers: {:?}", &deserialized_servers));
Ok(deserialized_servers)
} else {
@@ -241,18 +243,18 @@ pub async fn load_servers<R: Runtime>(app_handle: &AppHandle<R>) -> Result<Vec<S
pub async fn load_or_insert_default_server<R: Runtime>(
app_handle: &AppHandle<R>,
) -> Result<Vec<Server>, String> {
log::debug!("Attempting to load or insert default server");
dbg!("Attempting to load or insert default server");
let exists_servers = load_servers(&app_handle).await;
if exists_servers.is_ok() && !exists_servers.as_ref()?.is_empty() {
log::debug!("loaded {} servers", &exists_servers.clone()?.len());
dbg!(format!("loaded {} servers", &exists_servers.clone()?.len()));
return exists_servers;
}
let default = get_default_server();
save_server(&default);
log::debug!("loaded default servers");
dbg!("loaded default servers");
Ok(vec![default])
}
@@ -315,17 +317,10 @@ pub async fn refresh_coco_server_info<R: Runtime>(
// Send request to fetch updated server info
let response = HttpClient::get(&id, "/provider/_info", None)
.await
.map_err(|e| format!("Failed to contact the server: {}", e));
if response.is_err() {
let _ = mark_server_as_offline(app_handle, &id).await;
return Err(response.err().unwrap());
}
let response = response?;
.map_err(|e| format!("Failed to contact the server: {}", e))?;
if !response.status().is_success() {
let _ = mark_server_as_offline(app_handle, &id).await;
mark_server_as_offline(&id).await;
return Err(format!("Request failed with status: {}", response.status()));
}
@@ -336,9 +331,6 @@ pub async fn refresh_coco_server_info<R: Runtime>(
let mut updated_server: Server = serde_json::from_str(&body)
.map_err(|e| format!("Failed to deserialize the response: {}", e))?;
// Mark server as online
let _ = mark_server_as_online(app_handle.clone(), &id).await;
// Restore local state
updated_server.id = id.clone();
updated_server.builtin = is_builtin;
@@ -372,10 +364,10 @@ pub async fn add_coco_server<R: Runtime>(
let endpoint = endpoint.trim_end_matches('/');
if check_endpoint_exists(endpoint) {
log::debug!(
dbg!(format!(
"This Coco server has already been registered: {:?}",
&endpoint
);
));
return Err("This Coco server has already been registered.".into());
}
@@ -384,7 +376,7 @@ pub async fn add_coco_server<R: Runtime>(
.await
.map_err(|e| format!("Failed to send request to the server: {}", e))?;
log::debug!("Get provider info response: {:?}", &response);
dbg!(format!("Get provider info response: {:?}", &response));
let body = get_response_body_text(response).await?;
@@ -408,7 +400,7 @@ pub async fn add_coco_server<R: Runtime>(
.await
.map_err(|e| format!("Failed to persist Coco servers: {}", e))?;
log::debug!("Successfully registered server: {:?}", &endpoint);
dbg!(format!("Successfully registered server: {:?}", &endpoint));
Ok(server)
}
@@ -454,63 +446,26 @@ pub async fn try_register_server_to_search_source(
server: &Server,
) {
if server.enabled {
log::trace!(
"Server {} is public: {} and available: {}",
&server.name,
&server.public,
&server.available
);
if !server.public {
let token = get_server_token(&server.id).await;
if !token.is_ok() || token.is_ok() && token.unwrap().is_none() {
log::debug!("Server {} is not public and no token was found", &server.id);
return;
}
}
let registry = app_handle.state::<SearchSourceRegistry>();
let source = CocoSearchSource::new(server.clone());
registry.register_source(source).await;
}
}
#[tauri::command]
pub async fn mark_server_as_online<R: Runtime>(
app_handle: AppHandle<R>, id: &str) -> Result<(), ()> {
// println!("server_is_offline: {}", id);
let server = get_server_by_id(id);
if let Some(mut server) = server {
server.available = true;
server.health = None;
save_server(&server);
try_register_server_to_search_source(app_handle.clone(), &server).await;
}
Ok(())
}
#[tauri::command]
pub async fn mark_server_as_offline<R: Runtime>(
app_handle: AppHandle<R>,
id: &str,
) -> Result<(), ()> {
pub async fn mark_server_as_offline(id: &str) {
// println!("server_is_offline: {}", id);
let server = get_server_by_id(id);
if let Some(mut server) = server {
server.available = false;
server.health = None;
save_server(&server);
let registry = app_handle.state::<SearchSourceRegistry>();
registry.remove_source(id).await;
}
Ok(())
}
#[tauri::command]
pub async fn disable_server<R: Runtime>(app_handle: AppHandle<R>, id: String) -> Result<(), ()> {
println!("disable_server: {}", id);
let server = get_server_by_id(id.as_str());
if let Some(mut server) = server {
server.enabled = false;
@@ -531,48 +486,47 @@ pub async fn logout_coco_server<R: Runtime>(
app_handle: AppHandle<R>,
id: String,
) -> Result<(), String> {
log::debug!("Attempting to log out server by id: {}", &id);
dbg!("Attempting to log out server by id:", &id);
// Check if server token exists
if let Some(_token) = get_server_token(id.as_str()).await? {
log::debug!("Found server token for id: {}", &id);
dbg!("Found server token for id:", &id);
// Remove the server token from cache
remove_server_token(id.as_str());
// Persist the updated tokens
if let Err(e) = persist_servers_token(&app_handle) {
log::debug!("Failed to save tokens for id: {}. Error: {:?}", &id, &e);
dbg!("Failed to save tokens for id: {}. Error: {:?}", &id, &e);
return Err(format!("Failed to save tokens: {}", &e));
}
} else {
// Log the case where server token is not found
log::debug!("No server token found for id: {}", &id);
dbg!("No server token found for id: {}", &id);
}
// Check if the server exists
if let Some(mut server) = get_server_by_id(id.as_str()) {
log::debug!("Found server for id: {}", &id);
dbg!("Found server for id:", &id);
// Clear server profile
server.profile = None;
let _ = mark_server_as_offline(app_handle.clone(), id.as_str()).await;
// Save the updated server data
save_server(&server);
// Persist the updated server data
if let Err(e) = persist_servers(&app_handle).await {
log::debug!("Failed to save server for id: {}. Error: {:?}", &id, &e);
dbg!("Failed to save server for id: {}. Error: {:?}", &id, &e);
return Err(format!("Failed to save server: {}", &e));
}
} else {
// Log the case where server is not found
log::debug!("No server found for id: {}", &id);
dbg!("No server found for id: {}", &id);
return Err(format!("No server found for id: {}", id));
}
log::debug!("Successfully logged out server with id: {}", &id);
dbg!("Successfully logged out server with id:", &id);
Ok(())
}
@@ -623,7 +577,6 @@ fn test_trim_endpoint_last_forward_slash() {
},
},
priority: 0,
stats: None,
};
trim_endpoint_last_forward_slash(&mut server);
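
try_register_server_to_search_source above only registers a server when it is enabled and either public or backed by a stored token. A tiny sketch of that gate with plain data, independent of the tauri state types used in the real code:

```rust
// Decides whether a server qualifies as a search source, following the gate shown
// above: enabled, and either public or carrying a stored access token.
struct ServerInfo {
    enabled: bool,
    public: bool,
    token: Option<String>,
}

fn should_register(server: &ServerInfo) -> bool {
    server.enabled && (server.public || server.token.is_some())
}

fn main() {
    let private_no_token = ServerInfo { enabled: true, public: false, token: None };
    let private_with_token = ServerInfo { enabled: true, public: false, token: Some("t".into()) };
    let disabled = ServerInfo { enabled: false, public: true, token: None };

    assert!(!should_register(&private_no_token));
    assert!(should_register(&private_with_token));
    assert!(!should_register(&disabled));
    println!("registration gate behaves as expected");
}
```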

View File

@@ -1,57 +0,0 @@
use crate::server::http_client::HttpClient;
use futures_util::StreamExt;
use http::Method;
use serde_json::json;
use tauri::{command, AppHandle, Emitter, Runtime};
#[command]
pub async fn synthesize<R: Runtime>(
app_handle: AppHandle<R>,
client_id: String,
server_id: String,
voice: String,
content: String,
) -> Result<(), String> {
let body = json!({
"voice": voice,
"content": content,
})
.to_string();
let response = HttpClient::send_request(
server_id.as_str(),
Method::POST,
"/services/audio/synthesize",
None,
None,
Some(reqwest::Body::from(body.to_string())),
)
.await?;
log::info!("Synthesize response status: {}", response.status());
if response.status() == 429 {
return Ok(());
}
if !response.status().is_success() {
return Err(format!("Request Failed: {}", response.status()));
}
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
if let Err(err) = app_handle.emit(&client_id, bytes.to_vec()) {
log::error!("Emit error: {:?}", err);
}
}
Err(e) => {
log::error!("Stream error: {:?}", e);
break;
}
}
}
Ok(())
}
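
The removed synthesize command streams the response body chunk by chunk and forwards each chunk to the UI. A minimal sketch of that streaming shape, assuming reqwest with its stream feature, futures_util, and tokio; the URL and callback are placeholders for the real emit call:

```rust
use futures_util::StreamExt;

// Consumes a streaming HTTP body chunk by chunk and hands each chunk to a callback,
// the same shape as the removed command, which emitted chunks to the frontend.
async fn stream_body(url: &str, mut on_chunk: impl FnMut(Vec<u8>)) -> Result<(), String> {
    let response = reqwest::get(url)
        .await
        .map_err(|e| format!("request error: {}", e))?;
    if !response.status().is_success() {
        return Err(format!("request failed: {}", response.status()));
    }

    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(bytes) => on_chunk(bytes.to_vec()),
            Err(e) => return Err(format!("stream error: {}", e)),
        }
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    // Placeholder URL; the removed command posted to /services/audio/synthesize instead.
    if let Err(e) = stream_body("https://example.com", |bytes| {
        println!("received {} bytes", bytes.len());
    })
    .await
    {
        eprintln!("{}", e);
    }
}
```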

View File

@@ -1,96 +1,43 @@
use crate::common::http::get_response_body_text;
use crate::server::http_client::HttpClient;
use serde::{Deserialize, Serialize};
use serde_json::{from_str, Value};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use tauri::command;
#[derive(Debug, Serialize, Deserialize)]
pub struct TranscriptionResponse {
task_id: String,
results: Vec<Value>,
pub text: String,
}
#[command]
pub async fn transcription(
server_id: String,
audio_type: String,
audio_content: String,
) -> Result<TranscriptionResponse, String> {
// Send request to initiate transcription task
let init_response = HttpClient::post(
let mut query_params = HashMap::new();
query_params.insert("type".to_string(), JsonValue::String(audio_type));
query_params.insert("content".to_string(), JsonValue::String(audio_content));
// Send the HTTP POST request
let response = HttpClient::post(
&server_id,
"/services/audio/transcription",
Some(query_params),
None,
Some(audio_content.into()),
)
.await
.map_err(|e| format!("Failed to initiate transcription: {}", e))?;
// Extract response body as text
let init_response_text = get_response_body_text(init_response)
.await
.map_err(|e| format!("Failed to read initial response body: {}", e))?;
.map_err(|e| format!("Error sending transcription request: {}", e))?;
// Parse response JSON to extract task ID
let init_response_json: Value = from_str(&init_response_text).map_err(|e| {
format!(
"Failed to parse initial response JSON: {}. Raw response: {}",
e, init_response_text
)
})?;
let transcription_task_id = init_response_json["task_id"]
.as_str()
.ok_or_else(|| {
format!(
"Missing or invalid task_id in initial response: {}",
init_response_text
)
})?
.to_string();
// Set up polling with timeout
let polling_start = std::time::Instant::now();
const POLLING_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
const POLLING_INTERVAL: std::time::Duration = std::time::Duration::from_millis(200);
let mut transcription_response: TranscriptionResponse;
loop {
// Poll for transcription results
let poll_response = HttpClient::get(
&server_id,
&format!("/services/audio/task/{}", transcription_task_id),
None,
)
// Use get_response_body_text to extract the response body as text
let response_body = get_response_body_text(response)
.await
.map_err(|e| format!("Failed to poll transcription task: {}", e))?;
.map_err(|e| format!("Failed to read response body: {}", e))?;
// Extract poll response body
let poll_response_text = get_response_body_text(poll_response)
.await
.map_err(|e| format!("Failed to read poll response body: {}", e))?;
// Parse poll response JSON
transcription_response = from_str(&poll_response_text).map_err(|e| {
format!(
"Failed to parse poll response JSON: {}. Raw response: {}",
e, poll_response_text
)
})?;
// Check if transcription results are available
if !transcription_response.results.is_empty() {
break;
}
// Check for timeout
if polling_start.elapsed() >= POLLING_TIMEOUT {
return Err("Transcription task timed out after 30 seconds".to_string());
}
// Wait before next poll
tokio::time::sleep(POLLING_INTERVAL).await;
}
// Deserialize the response body into TranscriptionResponse
let transcription_response: TranscriptionResponse = serde_json::from_str(&response_body)
.map_err(|e| format!("Failed to parse transcription response: {}", e))?;
Ok(transcription_response)
}
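
The removed side of this diff polled the transcription task on an interval until results appeared or a 30-second deadline passed. A generic sketch of that poll-until-deadline shape, assuming tokio; the checker closure simulates a task that becomes ready on the third attempt:

```rust
use std::time::{Duration, Instant};

// Polls an async check until it yields a value or the deadline passes, the same
// shape as the removed task-polling loop. The checker is a stand-in closure.
async fn poll_until<T, F, Fut>(
    mut check: F,
    interval: Duration,
    deadline: Duration,
) -> Result<T, String>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Option<T>>,
{
    let started = Instant::now();
    loop {
        if let Some(value) = check().await {
            return Ok(value);
        }
        if started.elapsed() >= deadline {
            return Err("polling timed out".to_string());
        }
        tokio::time::sleep(interval).await;
    }
}

#[tokio::main]
async fn main() {
    // Simulated task that becomes ready on the third poll.
    let mut attempts = 0;
    let result = poll_until(
        || {
            attempts += 1;
            let ready = attempts >= 3;
            async move { if ready { Some("transcribed text") } else { None } }
        },
        Duration::from_millis(200),
        Duration::from_secs(30),
    )
    .await;
    println!("{:?}", result);
}
```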

View File

@@ -95,8 +95,8 @@ pub async fn connect_to_server<R: Runtime>(
true, // disable_nagle
Some(connector), // Connector
)
.await
.map_err(|e| format!("WebSocket TLS error: {:?}", e))?;
.await
.map_err(|e| format!("WebSocket TLS error: {:?}", e))?;
let (cancel_tx, mut cancel_rx) = mpsc::channel(1);
@@ -125,7 +125,6 @@ pub async fn connect_to_server<R: Runtime>(
let _ = app_handle_clone.emit(&format!("ws-message-{}", client_id_clone), text);
},
Some(Err(_)) | None => {
log::debug!("WebSocket connection closed or error");
let _ = app_handle_clone.emit(&format!("ws-error-{}", client_id_clone), id.clone());
break;
}
@@ -133,8 +132,7 @@ pub async fn connect_to_server<R: Runtime>(
}
}
_ = cancel_rx.recv() => {
log::debug!("WebSocket connection cancelled");
let _ = app_handle_clone.emit(&format!("ws-cancel-{}", client_id_clone), id.clone());
let _ = app_handle_clone.emit(&format!("ws-error-{}", client_id_clone), id.clone());
break;
}
}
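
The reader above selects between the socket stream and a cancellation channel and stops on whichever fires first. A minimal sketch of that select! shape, assuming tokio; an mpsc channel stands in for the websocket:

```rust
use std::time::Duration;

use tokio::sync::mpsc;

// Reads messages until the source closes or a cancel signal arrives, mirroring
// the websocket reader's select! loop, with a channel standing in for the socket.
#[tokio::main]
async fn main() {
    let (msg_tx, mut msg_rx) = mpsc::channel::<String>(8);
    let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);

    // Simulated message producer, plus a cancel request shortly after.
    tokio::spawn(async move {
        for i in 0..3 {
            let _ = msg_tx.send(format!("message {}", i)).await;
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    });
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(120)).await;
        let _ = cancel_tx.send(()).await;
    });

    loop {
        tokio::select! {
            maybe_msg = msg_rx.recv() => {
                match maybe_msg {
                    Some(text) => println!("got: {}", text),
                    None => { println!("stream closed"); break; }
                }
            }
            _ = cancel_rx.recv() => {
                println!("cancelled");
                break;
            }
        }
    }
}
```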

View File

@@ -1,9 +1,3 @@
use tauri::{App, WebviewWindow};
pub fn platform(
_app: &mut App,
_main_window: WebviewWindow,
_settings_window: WebviewWindow,
_check_window: WebviewWindow,
) {
}
pub fn platform(_app: &mut App, _main_window: WebviewWindow, _settings_window: WebviewWindow) {}

View File

@@ -1,5 +1,5 @@
//credits to: https://github.com/ayangweb/ayangweb-EcoPaste/blob/169323dbe6365ffe4abb64d867439ed2ea84c6d1/src-tauri/src/core/setup/mac.rs
use tauri::{App, Emitter, EventTarget, WebviewWindow};
use tauri::{ActivationPolicy, App, Emitter, EventTarget, WebviewWindow};
use tauri_nspanel::{cocoa::appkit::NSWindowCollectionBehavior, panel_delegate, WebviewWindowExt};
use crate::common::MAIN_WINDOW_LABEL;
@@ -12,12 +12,9 @@ const WINDOW_BLUR_EVENT: &str = "tauri://blur";
const WINDOW_MOVED_EVENT: &str = "tauri://move";
const WINDOW_RESIZED_EVENT: &str = "tauri://resize";
pub fn platform(
_app: &mut App,
main_window: WebviewWindow,
_settings_window: WebviewWindow,
_check_window: WebviewWindow,
) {
pub fn platform(app: &mut App, main_window: WebviewWindow, _settings_window: WebviewWindow) {
app.set_activation_policy(ActivationPolicy::Accessory);
// Convert ns_window to ns_panel
let panel = main_window.to_panel().unwrap();

View File

@@ -18,20 +18,10 @@ pub use windows::*;
#[cfg(target_os = "linux")]
pub use linux::*;
pub fn default(
app: &mut App,
main_window: WebviewWindow,
settings_window: WebviewWindow,
check_window: WebviewWindow,
) {
pub fn default(app: &mut App, main_window: WebviewWindow, settings_window: WebviewWindow) {
// Development mode automatically opens the console: https://tauri.app/develop/debug
#[cfg(debug_assertions)]
#[cfg(all(dev, debug_assertions))]
main_window.open_devtools();
platform(
app,
main_window.clone(),
settings_window.clone(),
check_window.clone(),
);
platform(app, main_window.clone(), settings_window.clone());
}

View File

@@ -1,9 +1,3 @@
use tauri::{App, WebviewWindow};
pub fn platform(
_app: &mut App,
_main_window: WebviewWindow,
_settings_window: WebviewWindow,
_check_window: WebviewWindow,
) {
}
pub fn platform(_app: &mut App, _main_window: WebviewWindow, _settings_window: WebviewWindow) {}

View File

@@ -17,7 +17,6 @@ const DEFAULT_SHORTCUT: &str = "ctrl+shift+space";
/// Set up the shortcut upon app start.
pub fn enable_shortcut(app: &App) {
log::trace!("setting up Coco hotkey");
let store = app
.store(COCO_TAURI_STORE)
.expect("creating a store should not fail");
@@ -44,7 +43,6 @@ pub fn enable_shortcut(app: &App) {
.expect("default shortcut should never be invalid");
_register_shortcut_upon_start(app, default_shortcut);
}
log::trace!("Coco hotkey has been set");
}
/// Get the stored shortcut as a string, same as [`_get_shortcut()`], except that
@@ -99,7 +97,7 @@ fn _register_shortcut<R: Runtime>(app: &AppHandle<R>, shortcut: Shortcut) {
.on_shortcut(shortcut, move |app, scut, event| {
if scut == &shortcut {
dbg!("shortcut pressed");
let main_window = app.get_webview_window(MAIN_WINDOW_LABEL).unwrap();
let main_window = app.get_window(MAIN_WINDOW_LABEL).unwrap();
if let ShortcutState::Pressed = event.state() {
let app_handle = app.clone();
if main_window.is_visible().unwrap() {
@@ -128,7 +126,7 @@ fn _register_shortcut_upon_start(app: &App, shortcut: Shortcut) {
tauri_plugin_global_shortcut::Builder::new()
.with_handler(move |app, scut, event| {
if scut == &shortcut {
let window = app.get_webview_window(MAIN_WINDOW_LABEL).unwrap();
let window = app.get_window(MAIN_WINDOW_LABEL).unwrap();
if let ShortcutState::Pressed = event.state() {
let app_handle = app.clone();

View File

@@ -1,62 +0,0 @@
//! Configuration entry App language is persisted in the frontend code, but we
//! need to access it on the backend.
//!
//! So we duplicate it here **in the MEMORY** and expose a setter method to the
//! frontend so that the value can be updated and stay up-to-date.
use function_name::named;
use tokio::sync::RwLock;
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(non_camel_case_types)]
pub(crate) enum Lang {
en_US,
zh_CN,
}
impl std::fmt::Display for Lang {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Lang::en_US => write!(f, "en_US"),
Lang::zh_CN => write!(f, "zh_CN"),
}
}
}
impl std::str::FromStr for Lang {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"en" => Ok(Lang::en_US),
"zh" => Ok(Lang::zh_CN),
_ => Err(format!("Invalid language: {}", s)),
}
}
}
/// Cache the language config in memory.
static APP_LANG: RwLock<Option<Lang>> = RwLock::const_new(None);
/// Frontend code uses this interface to update the in-memory cached `APP_LANG` config.
#[named]
#[tauri::command]
pub(crate) async fn update_app_lang(lang: String) {
let app_lang = lang.parse::<Lang>().unwrap_or_else(|e| {
panic!(
"frontend code passes an invalid argument [{}] to interface [{}], parsing error [{}]",
lang,
function_name!(),
e
)
});
let mut write_guard = APP_LANG.write().await;
*write_guard = Some(app_lang);
}
/// Helper getter method to handle the `None` case.
pub(crate) async fn get_app_lang() -> Lang {
let opt_lang = *APP_LANG.read().await;
opt_lang.expect("frontend code did not invoke [update_app_lang()] to set the APP_LANG")
}

View File

@@ -1,178 +0,0 @@
#[derive(Debug, Clone, PartialEq, Copy)]
pub(crate) enum FileType {
Folder,
JPEGImage,
PNGImage,
PDFDocument,
PlainTextDocument,
MicrosoftWordDocument,
MicrosoftExcelSpreadsheet,
AudioFile,
VideoFile,
CHeaderFile,
TOMLDocument,
RustScript,
CSourceCode,
MarkdownDocument,
TerminalSettings,
ZipArchive,
Dmg,
Html,
Json,
Xml,
Yaml,
Css,
Vue,
React,
Sql,
Csv,
Javascript,
Lnk,
Typescript,
Python,
Java,
Golang,
Ruby,
Php,
Sass,
Sketch,
AdobeAi,
AdobePsd,
AdobePr,
AdobeAu,
AdobeAe,
AdobeLr,
AdobeXd,
AdobeFl,
AdobeId,
Svg,
Epub,
Unknown,
}
async fn get_file_type(path: &str) -> FileType {
let path = camino::Utf8Path::new(path);
// stat() is more precise than file extension, use it if possible.
if path.is_dir() {
return FileType::Folder;
}
let Some(ext) = path.extension() else {
return FileType::Unknown;
};
let ext = ext.to_lowercase();
match ext.as_str() {
"pdf" => FileType::PDFDocument,
"txt" | "text" => FileType::PlainTextDocument,
"doc" | "docx" => FileType::MicrosoftWordDocument,
"xls" | "xlsx" => FileType::MicrosoftExcelSpreadsheet,
"jpg" | "jpeg" => FileType::JPEGImage,
"png" => FileType::PNGImage,
"mp3" | "wav" | "flac" | "aac" | "ogg" | "m4a" => FileType::AudioFile,
"mp4" | "avi" | "mov" | "mkv" | "wmv" | "flv" | "webm" => FileType::VideoFile,
"h" | "hpp" => FileType::CHeaderFile,
"c" | "cpp" | "cc" | "cxx" => FileType::CSourceCode,
"toml" => FileType::TOMLDocument,
"rs" => FileType::RustScript,
"md" | "markdown" => FileType::MarkdownDocument,
"terminal" => FileType::TerminalSettings,
"zip" | "rar" | "7z" | "tar" | "gz" | "bz2" => FileType::ZipArchive,
"dmg" => FileType::Dmg,
"html" | "htm" => FileType::Html,
"json" => FileType::Json,
"xml" => FileType::Xml,
"yaml" | "yml" => FileType::Yaml,
"css" => FileType::Css,
"vue" => FileType::Vue,
"jsx" | "tsx" => FileType::React,
"sql" => FileType::Sql,
"csv" => FileType::Csv,
"js" | "mjs" => FileType::Javascript,
"ts" => FileType::Typescript,
"py" | "pyw" => FileType::Python,
"java" => FileType::Java,
"go" => FileType::Golang,
"rb" => FileType::Ruby,
"php" => FileType::Php,
"sass" | "scss" => FileType::Sass,
"sketch" => FileType::Sketch,
"ai" => FileType::AdobeAi,
"psd" => FileType::AdobePsd,
"prproj" => FileType::AdobePr,
"aup" | "aup3" => FileType::AdobeAu,
"aep" => FileType::AdobeAe,
"lrcat" => FileType::AdobeLr,
"xd" => FileType::AdobeXd,
"fla" => FileType::AdobeFl,
"indd" => FileType::AdobeId,
"svg" => FileType::Svg,
"epub" => FileType::Epub,
"lnk" => FileType::Lnk,
_ => FileType::Unknown,
}
}
fn type_to_icon(ty: FileType) -> &'static str {
match ty {
FileType::Folder => "font_file_folder",
FileType::JPEGImage => "font_file_image",
FileType::PNGImage => "font_file_image",
FileType::PDFDocument => "font_file_document_pdf",
FileType::PlainTextDocument => "font_file_txt",
FileType::MicrosoftWordDocument => "font_file_document_word",
FileType::MicrosoftExcelSpreadsheet => "font_file_spreadsheet_excel",
FileType::AudioFile => "font_file_audio",
FileType::VideoFile => "font_file_video",
FileType::CHeaderFile => "font_file_csource",
FileType::TOMLDocument => "font_file_toml",
FileType::RustScript => "font_file_rustscript1",
FileType::CSourceCode => "font_file_csource",
FileType::MarkdownDocument => "font_file_markdown",
FileType::TerminalSettings => "font_file_terminal1",
FileType::ZipArchive => "font_file_zip",
FileType::Dmg => "font_file_dmg",
FileType::Html => "font_file_html",
FileType::Json => "font_file_json",
FileType::Xml => "font_file_xml",
FileType::Yaml => "font_file_yaml",
FileType::Css => "font_file_css",
FileType::Vue => "font_file_vue",
FileType::React => "font_file_react",
FileType::Sql => "font_file_sql",
FileType::Csv => "font_file_csv",
FileType::Javascript => "font_file_javascript",
FileType::Lnk => "font_file_lnk",
FileType::Typescript => "font_file_typescript",
FileType::Python => "font_file_python",
FileType::Java => "font_file_java",
FileType::Golang => "font_file_golang",
FileType::Ruby => "font_file_ruby",
FileType::Php => "font_file_php",
FileType::Sass => "font_file_sass",
FileType::Sketch => "font_file_sketch",
FileType::AdobeAi => "font_file_adobe_ai",
FileType::AdobePsd => "font_file_adobe_psd",
FileType::AdobePr => "font_file_adobe_pr",
FileType::AdobeAu => "font_file_adobe_au",
FileType::AdobeAe => "font_file_adobe_ae",
FileType::AdobeLr => "font_file_adobe_lr",
FileType::AdobeXd => "font_file_adobe_xd",
FileType::AdobeFl => "font_file_adobe_fl",
FileType::AdobeId => "font_file_adobe_id",
FileType::Svg => "font_file_svg",
FileType::Epub => "font_file_epub",
FileType::Unknown => "font_file_unknown",
}
}
#[tauri::command]
pub(crate) async fn get_file_icon(path: String) -> &'static str {
let ty = get_file_type(path.as_str()).await;
type_to_icon(ty)
}

View File

@@ -1,7 +1,3 @@
pub(crate) mod file;
pub(crate) mod platform;
pub(crate) mod app_lang;
use std::{path::Path, process::Command};
use tauri::{AppHandle, Runtime};
use tauri_plugin_shell::ShellExt;
@@ -71,6 +67,7 @@ fn get_linux_desktop_environment() -> Option<LinuxDesktopEnvironment> {
//
// tauri_plugin_shell::open() is deprecated, but we still use it.
#[allow(deprecated)]
#[tauri::command]
pub async fn open<R: Runtime>(app_handle: AppHandle<R>, path: String) -> Result<(), String> {
if cfg!(target_os = "linux") {
let borrowed_path = Path::new(&path);

View File

@@ -1,41 +0,0 @@
use serde::{Deserialize, Serialize};
use derive_more::Display;
use std::borrow::Cow;
#[derive(Debug, Deserialize, Serialize, Copy, Clone, Hash, PartialEq, Eq, Display)]
#[serde(rename_all(serialize = "lowercase", deserialize = "lowercase"))]
pub(crate) enum Platform {
#[display("macOS")]
Macos,
#[display("Linux")]
Linux,
#[display("windows")]
Windows,
}
impl Platform {
/// Helper function to determine the current platform.
pub(crate) fn current() -> Platform {
let os_str = std::env::consts::OS;
serde_plain::from_str(os_str).unwrap_or_else(|_e| {
panic!("std::env::consts::OS is [{}], which is not a valid value for [enum Platform], valid values: ['macos', 'linux', 'windows']", os_str)
})
}
/// Return the `X-OS-NAME` HTTP request header.
pub(crate) fn to_os_name_http_header_str(&self) -> Cow<'static, str> {
match self {
Self::Macos => {
Cow::Borrowed("macos")
}
Self::Windows => {
Cow::Borrowed("windows")
}
// For Linux, we need the actual distro `ID`, not just a "linux".
Self::Linux => {
Cow::Owned(sysinfo::System::distribution_id())
}
}
}
}

View File

@@ -41,9 +41,7 @@
"title": "Coco AI Settings",
"url": "/ui/settings",
"width": 1000,
"minWidth": 1000,
"height": 700,
"minHeight": 700,
"center": true,
"transparent": true,
"maximizable": false,
@@ -55,26 +53,6 @@
"effects": ["sidebar"],
"state": "active"
}
},
{
"label": "check",
"title": "Coco AI Update",
"url": "/ui/check",
"width": 340,
"minWidth": 340,
"height": 260,
"minHeight": 260,
"center": false,
"transparent": true,
"maximizable": false,
"skipTaskbar": false,
"dragDropEnabled": false,
"hiddenTitle": true,
"visible": false,
"windowEffects": {
"effects": ["sidebar"],
"state": "active"
}
}
],
"security": {
@@ -113,7 +91,21 @@
"icons/Square310x310Logo.png",
"icons/StoreLogo.png"
],
"resources": ["assets/**/*", "icons"]
"macOS": {
"minimumSystemVersion": "12.0",
"hardenedRuntime": true,
"dmg": {
"appPosition": {
"x": 180,
"y": 180
},
"applicationFolderPosition": {
"x": 480,
"y": 180
}
}
},
"resources": ["assets", "icons"]
},
"plugins": {
"features": {

View File

@@ -1,15 +0,0 @@
{
"identifier": "rs.coco.app",
"bundle": {
"linux": {
"deb": {
"depends": ["gstreamer1.0-plugins-good"],
"desktopTemplate": "./Coco.desktop"
},
"rpm": {
"depends": ["gstreamer1-plugins-good"],
"desktopTemplate": "./Coco.desktop"
}
}
}
}

View File

@@ -96,7 +96,7 @@ export const Get = <T>(
export const Post = <T>(
url: string,
data: IAnyObj | undefined,
data: IAnyObj,
params: IAnyObj = {},
headers: IAnyObj = {}
): Promise<[any, FcResponse<T> | undefined]> => {

View File

@@ -1,63 +0,0 @@
export async function streamPost({
url,
body,
queryParams,
headers,
onMessage,
onError,
}: {
url: string;
body: any;
queryParams?: Record<string, any>;
headers?: Record<string, string>;
onMessage: (chunk: string) => void;
onError?: (err: any) => void;
}) {
const appStore = JSON.parse(localStorage.getItem("app-store") || "{}");
let baseURL = appStore.state?.endpoint_http;
if (!baseURL || baseURL === "undefined") {
baseURL = "";
}
const headersStr = localStorage.getItem("headers") || "{}";
const headersStorage = JSON.parse(headersStr);
const query = new URLSearchParams(queryParams || {}).toString();
const fullUrl = `${baseURL}${url}?${query}`;
try {
const res = await fetch(fullUrl, {
method: "POST",
headers: {
"Content-Type": "application/json",
...(headersStorage),
...(headers || {}),
},
body: JSON.stringify(body),
});
if (!res.ok || !res.body) throw new Error("Stream failed");
const reader = res.body.getReader();
const decoder = new TextDecoder("utf-8");
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
for (let i = 0; i < lines.length - 1; i++) {
const line = lines[i].trim();
if (line) onMessage(line);
}
buffer = lines[lines.length - 1];
}
} catch (err) {
console.error("streamPost error:", err);
onError?.(err);
}
}

src/api/tauriFetchClient.ts (new file, 133 lines)
View File

@@ -0,0 +1,133 @@
import { fetch } from "@tauri-apps/plugin-http";
import { clientEnv } from "@/utils/env";
import { useLogStore } from "@/stores/logStore";
import { get_server_token } from "@/commands";
interface FetchRequestConfig {
url: string;
method?: "GET" | "POST" | "PUT" | "DELETE";
headers?: Record<string, string>;
body?: any;
timeout?: number;
parseAs?: "json" | "text" | "binary";
baseURL?: string;
}
interface FetchResponse<T = any> {
data: T;
status: number;
statusText: string;
headers: Headers;
}
const timeoutPromise = (ms: number) => {
return new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error(`Request timed out after ${ms} ms`)), ms)
);
};
export const tauriFetch = async <T = any>({
url,
method = "GET",
headers = {},
body,
timeout = 30,
parseAs = "json",
baseURL = clientEnv.COCO_SERVER_URL
}: FetchRequestConfig): Promise<FetchResponse<T>> => {
const addLog = useLogStore.getState().addLog;
try {
const appStore = JSON.parse(localStorage.getItem("app-store") || "{}");
const connectStore = JSON.parse(localStorage.getItem("connect-store") || "{}");
console.log("baseURL", appStore.state?.endpoint_http)
baseURL = appStore.state?.endpoint_http || baseURL;
const authStore = JSON.parse(localStorage.getItem("auth-store") || "{}")
const auth = authStore?.state?.auth
console.log("auth", auth)
if (baseURL.endsWith("/")) {
baseURL = baseURL.slice(0, -1);
}
if (!url.startsWith("http://") && !url.startsWith("https://")) {
// If not, prepend the defaultPrefix
url = baseURL + url;
}
if (method !== "GET") {
headers["Content-Type"] = "application/json";
}
const server_id = connectStore.state?.currentService?.id || "default_coco_server"
const res: any = await get_server_token(server_id);
headers["X-API-TOKEN"] = headers["X-API-TOKEN"] || res?.access_token || undefined;
// debug API
const requestInfo = {
url,
method,
headers,
body,
timeout,
parseAs,
};
const fetchPromise = fetch(url, {
method,
headers,
body,
});
const response = await Promise.race([
fetchPromise,
timeoutPromise(timeout * 1000),
]);
const statusText = response.ok ? "OK" : "Error";
let data: any;
if (parseAs === "json") {
data = await response.json();
} else if (parseAs === "text") {
data = await response.text();
} else {
data = await response.arrayBuffer();
}
// debug API
const log = {
request: requestInfo,
response: {
data,
status: response.status,
statusText,
headers: response.headers,
},
};
addLog(log);
return log.response;
} catch (error) {
console.error("Request failed:", error);
// debug API
const log = {
request: {
url,
method,
headers,
body,
timeout,
parseAs,
},
error,
};
addLog(log);
throw error;
}
};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary files not shown: 12 icon images removed (previous sizes ranged from 196 B to 1.3 KiB).

View File

@@ -16,32 +16,11 @@ import {
MultiSourceQueryResponse,
} from "@/types/commands";
import { useAppStore } from "@/stores/appStore";
import { useAuthStore } from "@/stores/authStore";
// Endpoints that don't require authentication
const WHITELIST_SERVERS = [
"list_coco_servers",
"add_coco_server",
"enable_server",
"disable_server",
"remove_coco_server",
"logout_coco_server",
"refresh_coco_server_info",
"handle_sso_callback",
"query_coco_fusion",
"open_session_chat", // TODO: quick ai access is a configured service, even if the current service is not logged in, it should not affect the configured service.
];
async function invokeWithErrorHandler<T>(
command: string,
args?: Record<string, any>
): Promise<T> {
const isCurrentLogin = useAuthStore.getState().isCurrentLogin;
if (!WHITELIST_SERVERS.includes(command) && !isCurrentLogin) {
console.error("This command requires authentication");
throw new Error("This command requires authentication");
}
//
const addError = useAppStore.getState().addError;
try {
const result = await invoke<T>(command, args);
@@ -51,7 +30,7 @@ async function invokeWithErrorHandler<T>(
const failedResult = result as any;
if (failedResult.failed?.length > 0 && failedResult?.hits?.length == 0) {
failedResult.failed.forEach((error: any) => {
addError(error.error, "error");
addError(error.error, 'error');
// console.error(error.error);
});
}
@@ -124,26 +103,12 @@ export function get_connectors_by_server(id: string): Promise<Connector[]> {
return invokeWithErrorHandler(`get_connectors_by_server`, { id });
}
export function datasource_search({
id,
queryParams,
}: {
id: string;
//["query=abc", "filter=er", "filter=efg", "from=0", "size=5"]
queryParams?: string[];
}): Promise<DataSource[]> {
return invokeWithErrorHandler(`datasource_search`, { id, queryParams });
export function datasource_search(id: string): Promise<DataSource[]> {
return invokeWithErrorHandler(`datasource_search`, { id });
}
export function mcp_server_search({
id,
queryParams,
}: {
id: string;
//["query=abc", "filter=er", "filter=efg", "from=0", "size=5"]
queryParams?: string[];
}): Promise<DataSource[]> {
return invokeWithErrorHandler(`mcp_server_search`, { id, queryParams });
export function mcp_server_search(id: string): Promise<DataSource[]> {
return invokeWithErrorHandler(`mcp_server_search`, { id });
}
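After this change both helpers take a bare server id and no longer accept query parameters; a minimal sketch (the id value is illustrative):
async function loadServerResources(serverId: string) {
  // Query/filter/pagination arguments were dropped; only the id remains.
  const dataSources = await datasource_search(serverId);
  const mcpServers = await mcp_server_search(serverId);
  return { dataSources, mcpServers };
}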
export function connect_to_server(id: string, clientId: string): Promise<void> {
@@ -238,7 +203,7 @@ export function new_chat({
queryParams,
}: {
serverId: string;
websocketId: string;
websocketId?: string;
message: string;
queryParams?: Record<string, any>;
}): Promise<GetResponse> {
@@ -250,22 +215,6 @@ export function new_chat({
});
}
export function chat_create({
serverId,
message,
queryParams,
}: {
serverId: string;
message: string;
queryParams?: Record<string, any>;
}): Promise<GetResponse> {
return invokeWithErrorHandler(`chat_create`, {
serverId,
message,
queryParams,
});
}
export function send_message({
serverId,
websocketId,
@@ -274,7 +223,7 @@ export function send_message({
queryParams,
}: {
serverId: string;
websocketId: string;
websocketId?: string;
sessionId: string;
message: string;
queryParams?: Record<string, any>;
@@ -288,25 +237,6 @@ export function send_message({
});
}
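With `websocketId` now optional, a send can omit it entirely; a minimal sketch with illustrative ids:
async function continueChat(serverId: string, sessionId: string) {
  // websocketId is intentionally omitted; the backend command is assumed
  // to resolve the connection itself (not stated in this diff).
  await send_message({
    serverId,
    sessionId,
    message: "hello",
  });
}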
export function chat_chat({
serverId,
sessionId,
message,
queryParams,
}: {
serverId: string;
sessionId: string;
message: string;
queryParams?: Record<string, any>;
}): Promise<string> {
return invokeWithErrorHandler(`chat_chat`, {
serverId,
sessionId,
message,
queryParams,
});
}
export const delete_session_chat = (serverId: string, sessionId: string) => {
return invokeWithErrorHandler<boolean>(`delete_session_chat`, {
serverId,
@@ -318,31 +248,19 @@ export const update_session_chat = (payload: {
serverId: string;
sessionId: string;
title?: string;
context?: Record<string, any>;
context?: {
attachments?: string[];
};
}): Promise<boolean> => {
return invokeWithErrorHandler<boolean>("update_session_chat", payload);
};
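The `context` payload is now limited to an attachments list; a minimal sketch with illustrative ids:
async function attachFilesToSession(serverId: string, sessionId: string) {
  const ok = await update_session_chat({
    serverId,
    sessionId,
    title: "Renamed chat",                         // optional
    context: { attachments: ["attachment-id-1"] }, // attachment ids are illustrative
  });
  return ok;
}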
export const assistant_search = (payload: {
serverId: string;
queryParams?: string[];
}): Promise<boolean> => {
return invokeWithErrorHandler<boolean>("assistant_search", payload);
};
export const assistant_get = (payload: {
serverId: string;
assistantId: string;
}): Promise<boolean> => {
return invokeWithErrorHandler<boolean>("assistant_get", payload);
};
export const assistant_get_multi = (payload: {
assistantId: string;
}): Promise<boolean> => {
return invokeWithErrorHandler<boolean>("assistant_get_multi", payload);
};
export const upload_attachment = async (payload: UploadAttachmentPayload) => {
const response = await invokeWithErrorHandler<UploadAttachmentResponse>(
"upload_attachment",

View File

@@ -26,12 +26,4 @@ export function show_coco(): Promise<void> {
export function show_settings(): Promise<void> {
return invoke('show_settings');
}
export function show_check(): Promise<void> {
return invoke('show_check');
}
export function hide_check(): Promise<void> {
return invoke('hide_check');
}

View File

@@ -1,85 +0,0 @@
import { useRef } from "react";
import platformAdapter from "@/utils/platformAdapter";
import { useConnectStore } from "@/stores/connectStore";
import { parseSearchQuery, unrequitable } from "@/utils";
interface AssistantFetcherProps {
debounceKeyword?: string;
assistantIDs?: string[];
}
export const AssistantFetcher = ({
debounceKeyword = "",
assistantIDs = [],
}: AssistantFetcherProps) => {
const { currentService, currentAssistant, setCurrentAssistant } =
useConnectStore();
const lastServerId = useRef<string | null>(null);
const fetchAssistant = async (params: {
current: number;
pageSize: number;
serverId?: string;
query?: string;
}) => {
try {
if (unrequitable()) {
return {
total: 0,
list: [],
};
}
const {
pageSize,
current,
serverId = currentService?.id,
query,
} = params;
const queryParams = parseSearchQuery({
from: (current - 1) * pageSize,
size: pageSize,
query: query ?? debounceKeyword,
fuzziness: 5,
filters: {
enabled: true,
id: assistantIDs,
},
});
const response = await platformAdapter.fetchAssistant(
serverId,
queryParams
);
let assistantList = response?.hits?.hits ?? [];
console.log("assistantList", assistantList);
if (
!currentAssistant?._id ||
currentService?.id !== lastServerId.current
) {
setCurrentAssistant(assistantList[0]);
}
lastServerId.current = currentService?.id;
return {
total: response.hits.total.value,
list: assistantList,
};
} catch (error) {
setCurrentAssistant(null);
console.error("assistant_search", error);
return {
total: 0,
list: [],
};
}
};
return { fetchAssistant };
};

View File

@@ -1,70 +0,0 @@
import { memo } from "react";
import clsx from "clsx";
import { Check } from "lucide-react";
import VisibleKey from "@/components/Common/VisibleKey";
import FontIcon from "@/components/Common/Icons/FontIcon";
import logoImg from "@/assets/icon.svg";
interface AssistantItemProps {
_id: string;
_source?: {
icon?: string;
name?: string;
description?: string;
};
name?: string;
isActive: boolean;
isHighlight: boolean;
isKeyboardActive: boolean;
onClick: () => void;
}
const AssistantItem = memo(
({
_id,
_source,
name,
isActive,
isHighlight,
isKeyboardActive = false,
onClick,
}: AssistantItemProps) => (
<button
key={_id}
className={clsx(
"w-full flex items-center h-[50px] gap-2 rounded-lg p-2 mb-1 transition",
{
"hover:bg-[#E6E6E6] dark:hover:bg-[#1F2937]": !isKeyboardActive,
"bg-[#E6E6E6] dark:bg-[#1F2937]": isHighlight || isActive,
}
)}
onClick={onClick}
>
<div className="flex items-center justify-center size-6 bg-white border border-[#E6E6E6] rounded-full overflow-hidden">
{_source?.icon?.startsWith("font_") ? (
<FontIcon name={_source?.icon} className="size-4" />
) : (
<img src={logoImg} className="size-4" alt={name} />
)}
</div>
<div className="text-left flex-1 min-w-0">
<div className="font-medium text-gray-900 dark:text-white truncate">
{_source?.name || "-"}
</div>
<div className="text-xs text-gray-500 dark:text-gray-400 truncate">
{_source?.description || ""}
</div>
</div>
{isActive && (
<div className="flex items-center">
<VisibleKey shortcut="↓↑" shortcutClassName="w-6 -translate-x-4">
<Check className="w-4 h-4 text-gray-500 dark:text-gray-400" />
</VisibleKey>
</div>
)}
</button>
)
);
export default AssistantItem;

View File

@@ -1,30 +1,48 @@
import { useState, useRef, useCallback, useEffect } from "react";
import { ChevronDownIcon, RefreshCw } from "lucide-react";
import { useState, useRef, useCallback, useMemo } from "react";
import {
ChevronDownIcon,
RefreshCw,
Check,
ChevronLeft,
ChevronRight,
} from "lucide-react";
import { useTranslation } from "react-i18next";
import { isNil } from "lodash-es";
import { Popover, PopoverButton, PopoverPanel } from "@headlessui/react";
import { useDebounce, useKeyPress, usePagination } from "ahooks";
import clsx from "clsx";
import { useAppStore } from "@/stores/appStore";
import logoImg from "@/assets/icon.svg";
import platformAdapter from "@/utils/platformAdapter";
import VisibleKey from "@/components/Common/VisibleKey";
import { useConnectStore } from "@/stores/connectStore";
import FontIcon from "@/components/Common/Icons/FontIcon";
import { useChatStore } from "@/stores/chatStore";
import { useShortcutsStore } from "@/stores/shortcutsStore";
import NoDataImage from "@/components/Common/NoDataImage";
import PopoverInput from "@/components/Common/PopoverInput";
import { AssistantFetcher } from "./AssistantFetcher";
import AssistantItem from "./AssistantItem";
import Pagination from "@/components/Common/Pagination";
import { useSearchStore } from "@/stores/searchStore";
import { Post } from "@/api/axiosRequest";
import { Popover, PopoverButton, PopoverPanel } from "@headlessui/react";
import {
useAsyncEffect,
useDebounce,
useKeyPress,
usePagination,
useReactive,
} from "ahooks";
import clsx from "clsx";
import NoDataImage from "../Common/NoDataImage";
import PopoverInput from "../Common/PopoverInput";
import { isNil } from "lodash-es";
interface AssistantListProps {
assistantIDs?: string[];
}
interface State {
allAssistants: any[];
}
export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
const { t } = useTranslation();
const { connected } = useChatStore();
const isTauri = useAppStore((state) => state.isTauri);
const setAssistantList = useConnectStore((state) => state.setAssistantList);
const currentService = useConnectStore((state) => state.currentService);
const currentAssistant = useConnectStore((state) => state.currentAssistant);
const setCurrentAssistant = useConnectStore((state) => {
@@ -38,30 +56,135 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
const searchInputRef = useRef<HTMLInputElement>(null);
const [keyword, setKeyword] = useState("");
const debounceKeyword = useDebounce(keyword, { wait: 500 });
const askAiAssistantId = useSearchStore((state) => state.askAiAssistantId);
const setAskAiAssistantId = useSearchStore((state) => {
return state.setAskAiAssistantId;
});
const assistantList = useConnectStore((state) => state.assistantList);
const { fetchAssistant } = AssistantFetcher({
debounceKeyword,
assistantIDs,
const state = useReactive<State>({
allAssistants: [],
});
const getAssistants = (params: { current: number; pageSize: number }) => {
return fetchAssistant(params);
const currentServiceId = useMemo(() => {
return currentService?.id;
}, [connected, currentService?.id]);
const fetchAssistant = async (params: {
current: number;
pageSize: number;
}) => {
try {
const { pageSize, current } = params;
const from = (current - 1) * pageSize;
const size = pageSize;
let response: any;
const body: Record<string, any> = {
serverId: currentServiceId,
from,
size,
};
if (debounceKeyword || assistantIDs.length > 0) {
body.query = {
bool: {
must: [],
},
};
if (debounceKeyword) {
body.query.bool.must.push({
query_string: {
fields: ["combined_fulltext"],
query: debounceKeyword,
fuzziness: "AUTO",
fuzzy_prefix_length: 2,
fuzzy_max_expansions: 10,
fuzzy_transpositions: true,
allow_leading_wildcard: false,
},
});
}
if (assistantIDs.length > 0) {
body.query.bool.must.push({
terms: {
id: assistantIDs.map((id) => id),
},
});
}
}
if (isTauri) {
if (!currentServiceId) {
throw new Error("currentServiceId is undefined");
}
response = await platformAdapter.commands("assistant_search", body);
} else {
const [error, res] = await Post(`/assistant/_search`, body);
if (error) {
throw new Error(error);
}
response = res;
}
console.log("assistant_search", response);
let assistantList = response?.hits?.hits ?? [];
console.log("assistantList", assistantList);
for (const item of assistantList) {
const index = state.allAssistants.findIndex((allItem: any) => {
return item._id === allItem._id;
});
if (index === -1) {
state.allAssistants.push(item);
} else {
state.allAssistants[index] = item;
}
}
console.log("state.allAssistants", state.allAssistants);
const matched = state.allAssistants.find((item: any) => {
return item._id === currentAssistant?._id;
});
console.log("matched", matched);
if (matched) {
setCurrentAssistant(matched);
} else {
setCurrentAssistant(assistantList[0]);
}
return {
total: response.hits.total.value,
list: assistantList,
};
} catch (error) {
setCurrentAssistant(null);
console.error("assistant_search", error);
return {
total: 0,
list: [],
};
}
};
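For reference, this is the request body the code above assembles when both a search keyword and an assistant id filter are present (the concrete values are illustrative):
const exampleSearchBody = {
  serverId: "server-1",
  from: 0,
  size: 5,
  query: {
    bool: {
      must: [
        {
          query_string: {
            fields: ["combined_fulltext"],
            query: "translate",
            fuzziness: "AUTO",
            fuzzy_prefix_length: 2,
            fuzzy_max_expansions: 10,
            fuzzy_transpositions: true,
            allow_leading_wildcard: false,
          },
        },
        { terms: { id: ["assistant-a", "assistant-b"] } },
      ],
    },
  },
};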
const { pagination, runAsync } = usePagination(getAssistants, {
useAsyncEffect(async () => {
const data = await fetchAssistant({ current: 1, pageSize: 1000 });
setAssistantList(data.list);
}, [currentServiceId]);
const { pagination, runAsync } = usePagination(fetchAssistant, {
defaultPageSize: 5,
refreshDeps: [currentService?.id, debounceKeyword, currentService?.enabled],
refreshDeps: [currentServiceId, debounceKeyword],
onSuccess(data) {
setAssistants(data.list);
if (data.list.length === 0) {
setCurrentAssistant(void 0);
}
},
});
@@ -73,22 +196,6 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
setTimeout(() => setIsRefreshing(false), 1000);
};
const [highlightIndex, setHighlightIndex] = useState<number>(-1);
const [isKeyboardActive, setIsKeyboardActive] = useState(false);
useEffect(() => {
if (!askAiAssistantId || assistantList.length === 0) return;
const matched = assistantList.find((item) => {
return item._id === askAiAssistantId;
});
if (!matched) return;
setCurrentAssistant(matched);
setAskAiAssistantId(void 0);
}, [assistantList, askAiAssistantId]);
useKeyPress(
["uparrow", "downarrow", "enter"],
(event, key) => {
@@ -99,7 +206,9 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
event.stopPropagation();
event.preventDefault();
setIsKeyboardActive(true);
if (key === "enter") {
return popoverButtonRef.current?.click();
}
const index = assistants.findIndex(
(item) => item._id === currentAssistant?._id
@@ -108,20 +217,15 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
if (length <= 1) return;
let nextIndex = highlightIndex === -1 ? index : highlightIndex;
let nextIndex = index;
if (key === "uparrow") {
nextIndex = nextIndex > 0 ? nextIndex - 1 : length - 1;
} else if (key === "downarrow") {
nextIndex = nextIndex < length - 1 ? nextIndex + 1 : 0;
nextIndex = index > 0 ? index - 1 : length - 1;
} else {
nextIndex = index < length - 1 ? index + 1 : 0;
}
if (key === "enter") {
setCurrentAssistant(assistants[nextIndex]);
return popoverButtonRef.current?.click();
}
setHighlightIndex(nextIndex);
setCurrentAssistant(assistants[nextIndex]);
},
{
target: popoverRef,
@@ -142,11 +246,6 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
pagination.changeCurrent(pagination.current + 1);
}, [pagination]);
const handleMouseMove = useCallback(() => {
setHighlightIndex(-1);
setIsKeyboardActive(false);
}, []);
return (
<div className="relative">
<Popover ref={popoverRef}>
@@ -181,10 +280,7 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
</VisibleKey>
</PopoverButton>
<PopoverPanel
className="absolute z-50 top-full mt-1 left-0 w-60 rounded-xl bg-white dark:bg-[#202126] p-3 text-sm/6 text-[#333] dark:text-[#D8D8D8] shadow-lg border dark:border-white/10 focus:outline-none max-h-[calc(100vh-150px)] overflow-y-auto"
onMouseMove={handleMouseMove}
>
<PopoverPanel className="absolute z-50 top-full mt-1 left-0 w-60 rounded-xl bg-white dark:bg-[#202126] p-3 text-sm/6 text-[#333] dark:text-[#D8D8D8] shadow-lg border dark:border-white/10 focus:outline-none max-h-[calc(100vh-80px)] overflow-y-auto">
<div className="flex items-center justify-between text-sm font-bold">
<div>
{t("assistant.popover.title")}{pagination.total}
@@ -223,36 +319,81 @@ export function AssistantList({ assistantIDs = [] }: AssistantListProps) {
placeholder={t("assistant.popover.search")}
className="w-full h-8 px-2 bg-transparent border rounded-md dark:border-white/10"
onChange={(event) => {
setKeyword(event.target.value);
console.log("onChange", event.target.value);
setKeyword(event.target.value.trim());
}}
/>
</VisibleKey>
{assistants.length > 0 ? (
<>
{assistants.map((assistant, index) => {
{assistants.map((assistant) => {
const { _id, _source, name } = assistant;
const isActive = currentAssistant?._id === _id;
return (
<AssistantItem
key={assistant._id}
{...assistant}
isActive={currentAssistant?._id === assistant._id}
isHighlight={highlightIndex === index}
isKeyboardActive={isKeyboardActive}
<button
key={_id}
className={clsx(
"w-full flex items-center h-[50px] gap-2 rounded-lg p-2 mb-1 hover:bg-[#E6E6E6] dark:hover:bg-[#1F2937] transition",
{
"bg-[#E6E6E6] dark:bg-[#1F2937]": isActive,
}
)}
onClick={() => {
setCurrentAssistant(assistant);
popoverButtonRef.current?.click();
}}
/>
>
<div className="flex items-center justify-center size-6 bg-white border border-[#E6E6E6] rounded-full overflow-hidden">
{_source?.icon?.startsWith("font_") ? (
<FontIcon name={_source?.icon} className="size-4" />
) : (
<img src={logoImg} className="size-4" alt={name} />
)}
</div>
<div className="text-left flex-1 min-w-0">
<div className="font-medium text-gray-900 dark:text-white truncate">
{_source?.name || "-"}
</div>
<div className="text-xs text-gray-500 dark:text-gray-400 truncate">
{_source?.description || ""}
</div>
</div>
{isActive && (
<div className="flex items-center">
<VisibleKey
shortcut="↓↑"
shortcutClassName="w-6 -translate-x-4"
>
<Check className="w-4 h-4 text-gray-500 dark:text-gray-400" />
</VisibleKey>
</div>
)}
</button>
);
})}
<Pagination
current={pagination.current}
totalPage={pagination.totalPage}
onPrev={handlePrev}
onNext={handleNext}
className="-mx-3 -mb-3"
/>
<div className="flex items-center justify-between h-8 -mx-3 -mb-3 px-3 text-[#999] border-t dark:border-t-white/10">
<VisibleKey shortcut="leftarrow" onKeyPress={handlePrev}>
<ChevronLeft
className="size-4 cursor-pointer"
onClick={handlePrev}
/>
</VisibleKey>
<div className="text-xs">
{pagination.current}/{pagination.totalPage}
</div>
<VisibleKey shortcut="rightarrow" onKeyPress={handleNext}>
<ChevronRight
className="size-4 cursor-pointer"
onClick={handleNext}
/>
</VisibleKey>
</div>
</>
) : (
<div className="flex justify-center items-center py-2">

View File

@@ -12,18 +12,16 @@ import { useChatStore } from "@/stores/chatStore";
import { useConnectStore } from "@/stores/connectStore";
import { useWindows } from "@/hooks/useWindows";
import useMessageChunkData from "@/hooks/useMessageChunkData";
import useWebSocket from "@/hooks/useWebSocket";
import { useChatActions } from "@/hooks/useChatActions";
import { useMessageHandler } from "@/hooks/useMessageHandler";
import { ChatSidebar } from "./ChatSidebar";
import { ChatHeader } from "./ChatHeader";
import { ChatContent } from "./ChatContent";
import ConnectPrompt from "./ConnectPrompt";
import type { Chat, StartPage } from "@/types/chat";
import type { Chat } from "./types";
import PrevSuggestion from "@/components/ChatMessage/PrevSuggestion";
import { useAppStore } from "@/stores/appStore";
import { useSearchStore } from "@/stores/searchStore";
import { useAuthStore } from "@/stores/authStore";
import Splash from "./Splash";
interface ChatAIProps {
isSearchActive?: boolean;
@@ -38,13 +36,12 @@ interface ChatAIProps {
getFileUrl: (path: string) => string;
showChatHistory?: boolean;
assistantIDs?: string[];
startPage?: StartPage;
formatUrl?: (data: any) => string;
}
export interface ChatAIRef {
init: (value: string) => void;
cancelChat: () => void;
reconnect: () => void;
clearChat: () => void;
}
@@ -64,72 +61,47 @@ const ChatAI = memo(
getFileUrl,
showChatHistory,
assistantIDs,
startPage,
formatUrl,
},
ref
) => {
useImperativeHandle(ref, () => ({
init: init,
cancelChat: () => cancelChat(activeChat),
reconnect: reconnect,
clearChat: clearChat,
}));
const { curChatEnd, setCurChatEnd } = useChatStore();
const { curChatEnd, setCurChatEnd, connected, setConnected } =
useChatStore();
const isTauri = useAppStore((state) => state.isTauri);
const isCurrentLogin = useAuthStore((state) => state.isCurrentLogin);
const setIsCurrentLogin = useAuthStore((state) => {
return state.setIsCurrentLogin;
const currentService = useConnectStore((state) => state.currentService);
const visibleStartPage = useConnectStore((state) => {
return state.visibleStartPage;
});
const { currentService, visibleStartPage } = useConnectStore();
const addError = useAppStore.getState().addError;
const [activeChat, setActiveChat] = useState<Chat>();
const [timedoutShow, setTimedoutShow] = useState(false);
const [isLogin, setIsLogin] = useState(true);
const curIdRef = useRef("");
const [isSidebarOpenChat, setIsSidebarOpenChat] = useState(isSidebarOpen);
const [chats, setChats] = useState<Chat[]>([]);
const askAiSessionId = useSearchStore((state) => state.askAiSessionId);
const setAskAiSessionId = useSearchStore(
(state) => state.setAskAiSessionId
);
const askAiServerId = useSearchStore((state) => {
return state.askAiServerId;
});
useEffect(() => {
activeChatProp && setActiveChat(activeChatProp);
}, [activeChatProp]);
useEffect(() => {
if (!isTauri) return;
if (!currentService?.enabled) {
setActiveChat(void 0);
setIsCurrentLogin(false);
}
if (showChatHistory) {
getChatHistory();
}
}, [currentService?.enabled, showChatHistory]);
useEffect(() => {
if (askAiServerId || !askAiSessionId) return;
onSelectChat({ _id: askAiSessionId });
setAskAiSessionId(void 0);
}, [askAiSessionId, askAiServerId]);
const [Question, setQuestion] = useState<string>("");
const [websocketSessionId, setWebsocketSessionId] = useState("");
const onWebsocketSessionId = useCallback((sessionId: string) => {
setWebsocketSessionId(sessionId);
}, []);
const {
data: {
query_intent,
@@ -156,6 +128,16 @@ const ChatAI = memo(
const dealMsgRef = useRef<((msg: string) => void) | null>(null);
const clientId = isChatPage ? "standalone" : "popup";
const { reconnect, updateDealMsg } = useWebSocket({
clientId,
connected,
setConnected,
currentService,
dealMsgRef,
onWebsocketSessionId,
});
const {
chatClose,
cancelChat,
@@ -169,6 +151,7 @@ const ChatAI = memo(
handleRename,
handleDelete,
} = useChatActions(
currentService?.id,
setActiveChat,
setCurChatEnd,
setTimedoutShow,
@@ -176,11 +159,11 @@ const ChatAI = memo(
setQuestion,
curIdRef,
setChats,
dealMsgRef,
isSearchActive,
isDeepThinkActive,
isMCPActive,
changeInput,
websocketSessionId,
showChatHistory
);
@@ -193,13 +176,6 @@ const ChatAI = memo(
handlers
);
const updateDealMsg = useCallback(
(newDealMsg: (msg: string) => void) => {
dealMsgRef.current = newDealMsg;
},
[dealMsgRef]
);
useEffect(() => {
if (dealMsg) {
dealMsgRef.current = dealMsg;
@@ -219,8 +195,8 @@ const ChatAI = memo(
const init = useCallback(
async (value: string) => {
try {
//console.log("init", curChatEnd, activeChat?._id);
if (!isCurrentLogin) {
//console.log("init", isLogin, curChatEnd, activeChat?._id);
if (!isLogin) {
addError("Please login to continue chatting");
return;
}
@@ -229,20 +205,21 @@ const ChatAI = memo(
return;
}
if (!activeChat?._id) {
await createNewChat(value, activeChat);
await createNewChat(value, activeChat, websocketSessionId);
} else {
await handleSendMessage(value, activeChat);
await handleSendMessage(value, activeChat, websocketSessionId);
}
} catch (error) {
console.error("Failed to initialize chat:", error);
}
},
[
isCurrentLogin,
isLogin,
curChatEnd,
activeChat?._id,
createNewChat,
handleSendMessage,
websocketSessionId,
]
);
@@ -318,7 +295,6 @@ const ChatAI = memo(
(chatId: string, title: string) => {
setChats((prev) => {
const chatIndex = prev.findIndex((chat) => chat._id === chatId);
if (chatIndex === -1) return prev;
const modifiedChat = {
@@ -327,8 +303,8 @@ const ChatAI = memo(
};
const result = [...prev];
result.splice(chatIndex, 1, modifiedChat);
return result;
result.splice(chatIndex, 1);
return [modifiedChat, ...result];
});
if (activeChat?._id === chatId) {
@@ -344,12 +320,16 @@ const ChatAI = memo(
);
return (
<>
<div
data-tauri-drag-region
className={`flex flex-col rounded-md relative h-full overflow-hidden`}
>
{showChatHistory && !setIsSidebarOpen && (
<ChatSidebar
isSidebarOpen={isSidebarOpenChat}
chats={chats}
activeChat={activeChat}
// onNewChat={clearChat}
onSelectChat={onSelectChat}
onDeleteChat={deleteChat}
fetchChatHistory={getChatHistory}
@@ -357,53 +337,48 @@ const ChatAI = memo(
onRename={renameChat}
/>
)}
<div
data-tauri-drag-region
className={`flex flex-col rounded-md h-full overflow-hidden relative`}
>
<ChatHeader
clearChat={clearChat}
onOpenChatAI={openChatAI}
setIsSidebarOpen={toggleSidebar}
isSidebarOpen={isSidebarOpenChat}
<ChatHeader
onCreateNewChat={clearChat}
onOpenChatAI={openChatAI}
setIsSidebarOpen={toggleSidebar}
isSidebarOpen={isSidebarOpenChat}
activeChat={activeChat}
reconnect={reconnect}
isChatPage={isChatPage}
isLogin={isLogin}
setIsLogin={setIsLogin}
showChatHistory={showChatHistory}
assistantIDs={assistantIDs}
/>
{isLogin ? (
<ChatContent
activeChat={activeChat}
isChatPage={isChatPage}
showChatHistory={showChatHistory}
assistantIDs={assistantIDs}
curChatEnd={curChatEnd}
query_intent={query_intent}
tools={tools}
fetch_source={fetch_source}
pick_source={pick_source}
deep_read={deep_read}
think={think}
response={response}
loadingStep={loadingStep}
timedoutShow={timedoutShow}
Question={Question}
handleSendMessage={(value) =>
handleSendMessage(value, activeChat)
}
getFileUrl={getFileUrl}
/>
) : (
<ConnectPrompt />
)}
{isCurrentLogin ? (
<>
<ChatContent
activeChat={activeChat}
curChatEnd={curChatEnd}
query_intent={query_intent}
tools={tools}
fetch_source={fetch_source}
pick_source={pick_source}
deep_read={deep_read}
think={think}
response={response}
loadingStep={loadingStep}
timedoutShow={timedoutShow}
Question={Question}
handleSendMessage={(value) =>
handleSendMessage(value, activeChat)
}
getFileUrl={getFileUrl}
formatUrl={formatUrl}
/>
<Splash assistantIDs={assistantIDs} startPage={startPage} />
</>
) : (
<ConnectPrompt />
)}
{!activeChat?._id && !visibleStartPage && (
<PrevSuggestion sendMessage={init} />
)}
</div>
</>
{!activeChat?._id && !visibleStartPage && (
<PrevSuggestion sendMessage={init} />
)}
</div>
);
}
)

Some files were not shown because too many files have changed in this diff.