0.32.0 (#208)
* separate other tts apis and improve chunking
* move old tts config to voice agent config and implement config widget ux elements for table editing
* elevenlabs updated to use their client and expose model selection
* linting
* separate character class into character.pt and start on voice routing
* linting
* tts hot swapping and chunking improvements
* linting
* add support for piper-tts
* update gitignore
* linting
* support google tts; fix issue where quick_toggle agent config didn't work on standard config items
* linting
* only show agent quick toggles if the agent is enabled
* change elevenlabs to use a locally maintained voice list
* tts generate before / after events
* voice library refactor
* linting
* update openai model and voices
* tweak configs
* voice library ux
* linting
* add support for kokoro tts
* fix add / remove voice
* voice library tags
* linting
* linting
* tts api status
* api infos and add more kokoro voices
* allow voice testing before saving a new voice
* tweaks to voice library ux and some api info text
* linting
* voice mixer
* polish
* voice files go into /tts instead of templates/voice
* change default narrator voice
* xtts confirmation note
* character voice select
* koboldai format template
* polish
* skip empty chunks
* change default voice
* replace em-dash with normal dash
* adjust limit
* replace linebreaks
* chunk cleanup for whitespace
* info updated
* remove invalid endif tag
* sort voices by ready api
* Character hashable type
* clarify set_simulated_environment use to avoid unwanted character deactivation
* allow manual generation of tts and fix assorted issues with tts
* tts websocket handler router renamed
* voice mixer: when there are only 2 voices auto adjust the other weight as needed
* separate persist character functions into own mixin
* auto assign voices
* fix chara load and auto assign voice during chara load
* smart speaker separation
* tts speaker separation config
* generate tts for intro text
* fix prompting issues with anthropic, google and openrouter clients
* decensor flag off again
* only do ai assisted voice markup on narrator messages
* openrouter provider configuration
* linting
* improved sound controls
* add support for chatterbox
* fix info
* chatterbox dependencies
* remove piper and xtts2
* linting
* voice params
* linting
* tts model overrides and move tts info to tab
* reorg toolbar
* allow overriding of test text
* more tts fixes, apply intensity, chatterbox voices
* confirm voice delete
* linting
* groq updates
* reorg decorators
* tts fixes
* cancelable audio queue
* voice library uploads
* scene voice library
* Config refactor (#13)
* config refactor progress
* config nuke continues
* fix system prompts
* linting
* client fun
* client config refactor
* fix kcpp auto embedding selection
* linting
* fix proxy config
* remove cruft
* fix remaining client bugs from config refactor; always use get_config(), don't keep an instance reference
* support for reasoning models
* more reasoning tweaks
* only allow one frontend to connect at a time
* fix tests
* relock
* relock
* more client adjustments
* pattern prefill
* some tts agent fixes
* fix ai assist cond
* tts nodes
* fix config retrieval
* assign voice node and fixes
* sim suite char gen assign voice
* fix voice assign template to consider used voices
* get rid of auto break repetition, which wasn't working right for a while anyhow
* linting
* generate tts node as string node
* linting
* voice change on character event
* tweak chatterbox max length
* koboldai default template
* linting
* fix saving of existing voice
* relock
* adjust params of eva default voice
* f5tts support
* f5tts samples
* f5tts support
* f5tts tweaks
* chunk size per tts api and reorg default f5tts voices
* chatterbox default voice reorg to match f5-tts default voices
* voice library ux polish pass
* cleanup
* f5-tts tweaks
* missing samples
* get rid of old save cmd
* add chatterbox and f5tts
* housekeeping
* fix some issues with world entry editing
* remove cruft
* replace exclamation marks
* fix save immutable check
* fix replace_exclamation_marks
* better error handling in websocket plugins and fix issue with saves
* agent config save on dialog close
* ctrl click to disable / enable agents
* fix quick config
* allow modifying response size of focal requests
* sim suite set goal always sets story intent; encourage calling of set goal during simulation start
* allow setting of model
* voice param tweaks
* tts tweaks
* fix character card load
* fix note_on_value
* add mixed speaker_separation mode
* indicate which message the audio is for and provide a way to stop audio from the message
* fix issue with some tts generation failing
* linting
* fix speaker separate modes
* bad idea
* linting
* refactor speaker separation prompt
* add kimi think pattern
* fix issue with unwanted cover image replacement
* no scene analysis for visual prompt generation (for now)
* linting
* tts for context investigation messages
* prompt tweaks
* tweak intro
* fix intro text tts not auto playing sometimes
* consider narrator voice when assigning voice to a character
* allow director log messages to go only into the director console
* linting
* startup performance fixes
* init time
* linting
* only show audio control for messages that can have it
* always create story intent and don't override existing saves during character card load
* fix history check in dynamic story line node; add HasHistory node
* linting
* fix intro message not having speaker separation
* voice library character manager
* sequential and cancelable auto assign all
* linting
* fix generation cancel handling
* tooltips
* fix auto assign voice from scene voices
* polish
* kokoro does not like lazy import
* update info text
* complete scene export / import
* linting
* wording
* remove cruft
* fix story intent generation during character card import
* fix generation cancelled emit status inf loop
* prompt tweak
* reasoning quick toggle, reasoning token slider, tooltips
* improved reasoning pattern handling
* fix indirect coercion response parsing
* fix streaming issue
* response length instructions
* more robust streaming
* adjust default
* adjust formatting
* linting
* remove debug output
* director console log function calls
* install cuda script updated
* linting
* add another step
* adjust default
* update dialogue examples
* fix voice selection issues
* what's happening here
* third time's the charm?
* Vite migration (#207)
* add vite config
* replace babel, webpack, vue-cli deps with vite; switch to esm modules; separate eslint config
* change process.env to import.meta.env
* update index.html for vite and move it to root
* update docs for vite
* remove vue cli config
* update example env with vite
* bump frontend deps after rebase to 32.0
  (Co-authored-by: pax-co <Pax_801@proton.me>)
* properly reference data type
* what's new
* better indication of dialogue examples supporting multiple lines; improve dialogue example display
* fix potential issue with cached scene analysis being reused when it shouldn't
* fix character creation issues with player character toggle
* fix issue where editing a message would sometimes lose parts of the message
* fix slider ux thumb labels (vuetify update)
* relock
* narrative conversation format
* remove planning step
* linting
* tweaks
* don't overthink
* update dialogue examples and intro
* don't dictate response length instructions when data structures are expected
* prompt tweaks
* prompt tweaks
* linting
* fix edit message not handling : well
* prompt tweaks
* fix tests
* fix manual revision when character message was generated in new narrative mode
* fix issue with message editing
* Docker packages release (#204)
* add CI workflow for Docker image build and MkDocs deployment
* rename CI workflow from 'ci' to 'package'
* refactor CI workflow: consolidate container build and documentation deployment into a single file
* fix: correct indentation for permissions in CI workflow
* fix: correct indentation for steps in deploy-docs job in CI workflow
* build both cpu and cuda image
* docs
* docs
* expose writing style during state reinforcement
* prompt tweaks
* test container build
* test container image
* update docker compose
* docs
* test-container-build
* test container build
* test container build
* update docker build workflows
* fix guidance prompt prefix not being dropped
* mount tts dir
* add gpt-5
* remove debug output
* docs
* openai auto toggle reasoning based on model selection
* linting

Co-authored-by: pax-co <123330830+pax-co@users.noreply.github.com>
Co-authored-by: pax-co <Pax_801@proton.me>
Co-authored-by: Luis Alexandre Deschamps Brandão <brandao_luis@yahoo.com>
.github/workflows/ci.yml
@@ -1,30 +1,57 @@
 name: ci

 on:
   push:
     branches:
-      - master
-      - main
-      - prep-0.26.0
+      - master
   release:
     types: [published]

 permissions:
   contents: write
+  packages: write

 jobs:
-  deploy:
+  container-build:
+    if: github.event_name == 'release'
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - name: Configure Git Credentials
-        run: |
-          git config user.name github-actions[bot]
-          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
-      - uses: actions/setup-python@v5
-        with:
-          python-version: 3.x
-      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - name: Log in to GHCR
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Build & push
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: Dockerfile
+          push: true
+          tags: |
+            ghcr.io/${{ github.repository }}:latest
+            ghcr.io/${{ github.repository }}:${{ github.ref_name }}
+
+  deploy-docs:
+    if: github.event_name == 'release'
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Configure Git credentials
+        run: |
+          git config user.name "github-actions[bot]"
+          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
+      - uses: actions/setup-python@v5
+        with: { python-version: '3.x' }
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
       - uses: actions/cache@v4
         with:
           key: mkdocs-material-${{ env.cache_id }}
           path: .cache
-          restore-keys: |
-            mkdocs-material-
+          restore-keys: mkdocs-material-
       - run: pip install mkdocs-material mkdocs-awesome-pages-plugin mkdocs-glightbox
       - run: mkdocs gh-deploy --force
.github/workflows/test-container-build.yml (new file)
@@ -0,0 +1,32 @@
name: test-container-build

on:
  push:
    branches: [ 'prep-*' ]

permissions:
  contents: read
  packages: write

jobs:
  container-build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Log in to GHCR
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Build & push
        uses: docker/build-push-action@v5
        with:
          context: .
          file: Dockerfile
          push: true
          # Tag with prep suffix to avoid conflicts with production
          tags: |
            ghcr.io/${{ github.repository }}:${{ github.ref_name }}
.gitignore
@@ -8,11 +8,20 @@
 talemate_env
 chroma
 config.yaml
 .cursor
+.claude
+
 # uv
 .venv/
 templates/llm-prompt/user/*.jinja2
 templates/world-state/*.yaml
+tts/voice/piper/*.onnx
+tts/voice/piper/*.json
+tts/voice/kokoro/*.pt
+tts/voice/xtts2/*.wav
+tts/voice/chatterbox/*.wav
+tts/voice/f5tts/*.wav
+tts/voice/voice-library.json
 scenes/
 !scenes/infinity-quest-dynamic-scenario/
 !scenes/infinity-quest-dynamic-scenario/assets/
@@ -21,4 +30,5 @@ scenes/
 !scenes/infinity-quest/assets/
 !scenes/infinity-quest/infinity-quest.json
 tts_voice_samples/*.wav
 third-party-docs/
+legacy-state-reinforcements.yaml
Dockerfile
@@ -35,18 +35,9 @@ COPY pyproject.toml uv.lock /app/
 # Copy the Python source code (needed for editable install)
 COPY ./src /app/src

-# Create virtual environment and install dependencies
+# Create virtual environment and install dependencies (includes CUDA support via pyproject.toml)
 RUN uv sync

-# Conditional PyTorch+CUDA install
-ARG CUDA_AVAILABLE=false
-RUN . /app/.venv/bin/activate && \
-    if [ "$CUDA_AVAILABLE" = "true" ]; then \
-        echo "Installing PyTorch with CUDA support..." && \
-        uv pip uninstall torch torchaudio && \
-        uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
-    fi
-
 # Stage 3: Final image
 FROM python:3.11-slim
docker-compose.manual.yml (new file)
@@ -0,0 +1,20 @@
version: '3.8'

services:
  talemate:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "${FRONTEND_PORT:-8080}:8080"
      - "${BACKEND_PORT:-5050}:5050"
    volumes:
      - ./config.yaml:/app/config.yaml
      - ./scenes:/app/scenes
      - ./templates:/app/templates
      - ./chroma:/app/chroma
      - ./tts:/app/tts
    environment:
      - PYTHONUNBUFFERED=1
      - PYTHONPATH=/app/src:$PYTHONPATH
    command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]
@@ -2,11 +2,7 @@ version: '3.8'

 services:
   talemate:
-    build:
-      context: .
-      dockerfile: Dockerfile
-      args:
-        - CUDA_AVAILABLE=${CUDA_AVAILABLE:-false}
+    image: ghcr.io/vegu-ai/talemate:latest
     ports:
       - "${FRONTEND_PORT:-8080}:8080"
       - "${BACKEND_PORT:-5050}:5050"
@@ -15,6 +11,7 @@ services:
       - ./scenes:/app/scenes
       - ./templates:/app/templates
       - ./chroma:/app/chroma
+      - ./tts:/app/tts
     environment:
       - PYTHONUNBUFFERED=1
       - PYTHONPATH=/app/src:$PYTHONPATH
@@ -27,10 +27,10 @@ uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234

 ### Letting the frontend know about the new host and port

-Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL`.
+Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VITE_TALEMATE_BACKEND_WEBSOCKET_URL`.

 ```env
-VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
+VITE_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
 ```

 Next rebuild the frontend.
@@ -1,22 +1,15 @@
 !!! example "Experimental"
     Talemate through docker has not received a lot of testing from me, so please let me know if you encounter any issues.

     You can do so by creating an issue on the [:material-github: GitHub repository](https://github.com/vegu-ai/talemate)

 ## Quick install instructions

 1. `git clone https://github.com/vegu-ai/talemate.git`
 1. `cd talemate`
 1. copy config file
     1. linux: `cp config.example.yaml config.yaml`
-    1. windows: `copy config.example.yaml config.yaml`
-1. If your host has a CUDA compatible Nvidia GPU
-    1. Windows (via PowerShell): `$env:CUDA_AVAILABLE="true"; docker compose up`
-    1. Linux: `CUDA_AVAILABLE=true docker compose up`
-1. If your host does **NOT** have a CUDA compatible Nvidia GPU
-    1. Windows: `docker compose up`
-    1. Linux: `docker compose up`
+    1. windows: `copy config.example.yaml config.yaml` (or just copy the file and rename it via the file explorer)
+1. `docker compose up`
 1. Navigate your browser to http://localhost:8080

+!!! info "Pre-built Images"
+    The default setup uses pre-built images from GitHub Container Registry that include CUDA support by default. To manually build the container instead, use `docker compose -f docker-compose.manual.yml up --build`.
+
 !!! note
     When connecting local APIs running on the hostmachine (e.g. text-generation-webui), you need to use `host.docker.internal` as the hostname.
Binary files added under docs/img/0.32.0/:
  add-chatterbox-voice.png, add-elevenlabs-voice.png, add-f5tts-voice.png,
  character-voice-assignment.png, chatterbox-api-settings.png, chatterbox-parameters.png,
  client-reasoning.png, client-reasoning-2.png, elevenlabs-api-settings.png,
  elevenlabs-copy-voice-id.png, f5tts-api-settings.png, f5tts-parameters.png,
  google-tts-api-settings.png, kokoro-mixer.png, openai-tts-api-settings.png,
  voice-agent-settings.png, voice-agent-status-characters.png, voice-library-access.png,
  voice-library-api-status.png, voice-library-interface.png
docs/user-guide/agents/voice/chatterbox.md (new file)
@@ -0,0 +1,58 @@
# Chatterbox

Local zero-shot voice cloning from .wav files.



##### Device

Auto-detects the best available option.

##### Model

Default Chatterbox model, optimized for speed.

##### Chunk size

Split text into chunks of this size. Smaller values will increase responsiveness at the cost of lost context between chunks (appropriate inflection, etc.). 0 = no chunking.

## Adding Chatterbox Voices

### Voice Requirements

Chatterbox voices require:

- Reference audio file (.wav format, 5-15 seconds optimal)
- Clear speech with minimal background noise
- Single speaker throughout the sample

### Creating a Voice

1. Open the Voice Library
2. Click **:material-plus: New**
3. Select "Chatterbox" as the provider
4. Configure the voice:



**Label:** Descriptive name (e.g., "Marcus - Deep Male")

**Voice ID / Upload File:** Upload a .wav file containing the voice sample. The uploaded reference audio will also be the voice ID.

**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)

**Tags:** Add descriptive tags for organization

**Extra voice parameters**

Some optional parameters can be set here on a per-voice level.



##### Exaggeration Level

Exaggeration (neutral = 0.5; extreme values can be unstable). Higher exaggeration tends to speed up speech; reducing cfg helps compensate with slower, more deliberate pacing.

##### CFG / Pace

If the reference speaker has a fast speaking style, lowering cfg to around 0.3 can improve pacing.
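To make the chunk-size trade-off concrete, here is a minimal sketch of sentence-boundary chunking in Python. It is illustrative only; the function name and exact splitting rules are assumptions, not Talemate's actual implementation:

```python
import re

def chunk_text(text: str, chunk_size: int) -> list[str]:
    """Split text into chunks of roughly chunk_size characters,
    breaking on sentence boundaries. chunk_size == 0 disables chunking."""
    if chunk_size <= 0:
        return [text]
    # split after sentence-ending punctuation followed by whitespace
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        if current and len(current) + len(sentence) + 1 > chunk_size:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks
```

With a small chunk size, a long narration is emitted as several short requests the TTS backend can start synthesizing immediately, at the cost of inflection context across chunk boundaries.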
docs/user-guide/agents/voice/elevenlabs.md
@@ -1,7 +1,41 @@
 # ElevenLabs

-If you have not configured the ElevenLabs TTS API, the voice agent will show that the API key is missing.
+Professional voice synthesis with voice cloning capabilities using ElevenLabs API.

-
+

-See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on how to set up the API key.
+## API Setup
+
+ElevenLabs requires an API key. See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on obtaining and setting an API key.
+
+## Configuration
+
+**Model:** Select from available ElevenLabs models
+
+!!! warning "Voice Limits"
+    Your ElevenLabs subscription allows you to maintain a set number of voices (10 for the cheapest plan). Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button. It is recommended to test voices via their voice library instead.
+
+## Adding ElevenLabs Voices
+
+### Getting Voice IDs
+
+1. Go to [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab) to view your voices
+2. Find or create the voice you want to use
+3. Click "More Actions" -> "Copy Voice ID" for the desired voice
+
+
+
+### Creating a Voice in Talemate
+
+
+
+1. Open the Voice Library
+2. Click "Add Voice"
+3. Select "ElevenLabs" as the provider
+4. Configure the voice:
+
+**Label:** Descriptive name for the voice
+
+**Provider ID:** Paste the ElevenLabs voice ID you copied
+
+**Tags:** Add descriptive tags for organization
docs/user-guide/agents/voice/f5tts.md (new file)
@@ -0,0 +1,78 @@
# F5-TTS

Local zero-shot voice cloning from .wav files.



##### Device

Auto-detects the best available option (GPU preferred).

##### Model

- F5TTS_v1_Base (default, most recent model)
- F5TTS_Base
- E2TTS_Base

##### NFE Step

Number of steps to generate the voice. Higher values result in more detailed voices.

##### Chunk size

Split text into chunks of this size. Smaller values will increase responsiveness at the cost of lost context between chunks (appropriate inflection, etc.). 0 = no chunking.

##### Replace exclamation marks

If checked, exclamation marks will be replaced with periods. This is recommended for `F5TTS_v1_Base`, since it seems to over-exaggerate exclamation marks.

## Adding F5-TTS Voices

### Voice Requirements

F5-TTS voices require:

- Reference audio file (.wav format, 10-30 seconds)
- Clear speech with minimal background noise
- Single speaker throughout the sample
- Reference text (optional but recommended)

### Creating a Voice

1. Open the Voice Library
2. Click "Add Voice"
3. Select "F5-TTS" as the provider
4. Configure the voice:



**Label:** Descriptive name (e.g., "Emma - Calm Female")

**Voice ID / Upload File:** Upload a .wav file containing the **reference audio** voice sample. The uploaded reference audio will also be the voice ID.

- Use 6-10 second samples (longer doesn't improve quality)
- Ensure clear speech with minimal background noise
- Record at natural speaking pace

**Reference Text:** Enter the exact text spoken in the reference audio for improved quality

- Enter exactly what is spoken in the reference audio
- Include proper punctuation and capitalization
- Improves voice cloning accuracy significantly

**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)

**Tags:** Add descriptive tags (gender, age, style) for organization

**Extra voice parameters**

Some optional parameters can be set here on a per-voice level.



##### Speed

Allows you to adjust the speed of the voice.

##### CFG Strength

A higher CFG strength generally leads to more faithful reproduction of the input text, while a lower CFG strength can result in more varied or creative speech output, potentially at the cost of text-to-speech accuracy.
docs/user-guide/agents/voice/google.md (new file)
@@ -0,0 +1,15 @@
# Google Gemini-TTS

Google Gemini-TTS provides access to Google's text-to-speech service.

## API Setup

Google Gemini-TTS requires a Google Cloud API key.

See the [Google Cloud API setup](/talemate/user-guide/apis/google/) for instructions on obtaining an API key.

## Configuration



**Model:** Select from available Google TTS models
@@ -1,6 +1,26 @@
 # Overview

-Talemate supports Text-to-Speech (TTS) functionality, allowing users to convert text into spoken audio. This document outlines the steps required to configure TTS for Talemate using different providers, including ElevenLabs and a local TTS API.
+In 0.32.0 Talemate's TTS (Text-to-Speech) agent has been completely refactored to provide advanced voice capabilities, including per-character voice assignment, speaker separation, and support for multiple local and remote APIs. The voice system now includes a comprehensive voice library for managing and organizing voices across all supported providers.
+
+## Key Features
+
+- **Per-character voice assignment** - Each character can have their own unique voice
+- **Speaker separation** - Automatic detection and separation of dialogue from narration
+- **Voice library management** - Centralized management of all voices across providers
+- **Multiple API support** - Support for both local and remote TTS providers
+- **Director integration** - Automatic voice assignment for new characters
+
+## Supported APIs
+
+### Local APIs
+- **Kokoro** - Fastest generation with predefined voice models and mixing
+- **F5-TTS** - Fast voice cloning with occasional mispronunciations
+- **Chatterbox** - High-quality voice cloning (slower generation)
+
+### Remote APIs
+- **ElevenLabs** - Professional voice synthesis with voice cloning
+- **Google Gemini-TTS** - Google's text-to-speech service
+- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models

 ## Enable the Voice agent
@@ -12,28 +32,30 @@ If your voice agent is disabled - indicated by the grey dot next to the agent -
  

 !!! note "Ctrl click to toggle agent"
     You can use Ctrl click to toggle the agent on and off.

-!!! abstract "Next: Connect to a TTS api"
-    Next you need to decide which service / api to use for audio generation and configure the voice agent accordingly.
+## Voice Library Management

-- [OpenAI](openai.md)
-- [ElevenLabs](elevenlabs.md)
-- [Local TTS](local_tts.md)
+Voices are managed through the Voice Library, accessible from the main application bar. The Voice Library allows you to:

-You can also find more information about the various settings [here](settings.md).
+- Add and organize voices from all supported providers
+- Assign voices to specific characters
+- Create mixed voices (Kokoro)
+- Manage both global and scene-specific voice libraries

-## Select a voice
+See the [Voice Library Guide](voice-library.md) for detailed instructions.

-
+## Character Voice Assignment

-Click on the agent to open the agent settings.
+

-Then click on the `Narrator Voice` dropdown and select a voice.
+Characters can have individual voices assigned through the Voice Library. When a character has a voice assigned:

-
+1. Their dialogue will use their specific voice
+2. The narrator voice is used for exposition in their messages (with speaker separation enabled)
+3. If their assigned voice's API is not available, it falls back to the narrator voice

-The selection is saved automatically, click anywhere outside the agent window to close it.
+The Voice agent status will show all assigned character voices and their current status.

-The Voice agent should now show that the voice is selected and be ready to use.
-
-
+
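The fallback behaviour in the numbered list above boils down to a small decision rule. Here is a minimal sketch, assuming hypothetical `character.voice`, `voice.provider`, and `api_ready()` shapes rather than Talemate's actual internals:

```python
def resolve_voice(character, narrator_voice, api_ready):
    """Pick the voice for a character message (illustrative logic only):
    use the character's assigned voice when its API is ready,
    otherwise fall back to the narrator voice."""
    voice = getattr(character, "voice", None)  # hypothetical attribute
    if voice is not None and api_ready(voice.provider):
        return voice
    return narrator_voice
```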
docs/user-guide/agents/voice/kokoro.md (new file)
@@ -0,0 +1,55 @@
# Kokoro

Kokoro provides predefined voice models and voice mixing capabilities for creating custom voices.

## Using Predefined Voices

Kokoro comes with built-in voice models that are ready to use immediately.

Available predefined voices include various male and female voices with different characteristics.

## Creating Mixed Voices

Kokoro allows you to mix voices together to create a new voice.

### Voice Mixing Interface

To create a mixed voice:

1. Open the Voice Library
2. Click ":material-plus: New"
3. Select "Kokoro" as the provider
4. Choose the ":material-tune: Mixer" option
5. Configure the mixed voice:



**Label:** Descriptive name for the mixed voice

**Base Voices:** Select 2-4 existing Kokoro voices to combine

**Weights:** Set the influence of each voice (0.1 to 1.0)

**Tags:** Descriptive tags for organization

### Weight Configuration

Each selected voice can have its weight adjusted:

- Higher weights make that voice more prominent in the mix
- Lower weights make that voice more subtle
- Total weights need to sum to 1.0
- Experiment with different combinations to achieve desired results

### Saving Mixed Voices

Once configured, click "Add Voice". Mixed voices are saved to your voice library and can be assigned to characters or used as narrator voices just like any other voice.

Saving a mixed voice may take a moment to complete.
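Conceptually, a mixed voice can be thought of as a weighted blend of the underlying voice tensors. The sketch below is a rough illustration under that assumption; the helper and file paths are hypothetical, not Kokoro's or Talemate's actual code, with weights normalized so they sum to 1.0 as required above:

```python
import torch

def mix_voices(voice_files: dict[str, float]) -> torch.Tensor:
    """Blend Kokoro-style .pt voice tensors by weight (illustrative).

    voice_files maps a voice file path to its mixing weight;
    weights are normalized so they sum to 1.0 before blending.
    """
    total = sum(voice_files.values())
    mixed = None
    for path, weight in voice_files.items():
        voice = torch.load(path)  # assumed: each file holds a voice embedding tensor
        contribution = voice * (weight / total)
        mixed = contribution if mixed is None else mixed + contribution
    return mixed

# Hypothetical usage: 70% of one voice, 30% of another
# mixed = mix_voices({"tts/voice/kokoro/af_heart.pt": 0.7,
#                     "tts/voice/kokoro/am_adam.pt": 0.3})
```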
docs/user-guide/agents/voice/local_tts.md (removed)
@@ -1,53 +0,0 @@
# Local TTS

!!! warning
    This has not been tested in a while and may not work as expected. It will likely be replaced with something different in the future. If this approach is currently broken, it is likely to remain so until it is replaced.

For running a local TTS API, Talemate requires specific dependencies to be installed.

### Windows Installation

Run `install-local-tts.bat` to install the necessary requirements.

### Linux Installation

Execute the following command:

```bash
pip install TTS
```

### Model and Device Configuration

1. Choose a TTS model from the [Coqui TTS model list](https://github.com/coqui-ai/TTS).
2. Decide whether to use `cuda` or `cpu` for the device setting.
3. The first time you run TTS through the local API, it will download the specified model. Please note that this may take some time, and the download progress will be visible in the Talemate backend output.

Example configuration snippet:

```yaml
tts:
  device: cuda # or 'cpu'
  model: tts_models/multilingual/multi-dataset/xtts_v2
```

### Voice Samples Configuration

Configure voice samples by setting the `value` field to the path of a .wav file voice sample. Official samples can be downloaded from [Coqui XTTS-v2 samples](https://huggingface.co/coqui/XTTS-v2/tree/main/samples).

Example configuration snippet:

```yaml
tts:
  voices:
    - label: English Male
      value: path/to/english_male.wav
    - label: English Female
      value: path/to/english_female.wav
```

## Saving the Configuration

After configuring the `config.yaml` file, save your changes. Talemate will use the updated settings the next time it starts.

For more detailed information on configuring Talemate, refer to the `config.py` file in the Talemate source code and the `config.example.yaml` file for a barebone configuration example.
@@ -8,16 +8,12 @@ See the [OpenAI API setup](/apis/openai.md) for instructions on how to set up th

 ## Settings

-
+

 ##### Model

 Which model to use for generation.

+- GPT-4o Mini TTS
 - TTS-1
-- TTS-1 HD
-
-!!! quote "OpenAI API documentation on quality"
-    For real-time applications, the standard tts-1 model provides the lowest latency but at a lower quality than the tts-1-hd model. Due to the way the audio is generated, tts-1 is likely to generate content that has more static in certain situations than tts-1-hd. In some cases, the audio may not have noticeable differences depending on your listening device and the individual person.
-
-Generally i have found that HD is fast enough for talemate, so this is the default.
+- TTS-1 HD
@@ -1,36 +1,65 @@
 # Settings

-
+

-##### API
+##### Enabled APIs

-The TTS API to use for voice generation.
+Select which TTS APIs to enable. You can enable multiple APIs simultaneously:

-- OpenAI
-- ElevenLabs
-- Local TTS
+- **Kokoro** - Fastest generation with predefined voice models and mixing
+- **F5-TTS** - Fast voice cloning with occasional mispronunciations
+- **Chatterbox** - High-quality voice cloning (slower generation)
+- **ElevenLabs** - Professional voice synthesis with voice cloning
+- **Google Gemini-TTS** - Google's text-to-speech service
+- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
+
+!!! note "Multi-API Support"
+    You can enable multiple APIs and assign different voices from different providers to different characters. The system will automatically route voice generation to the appropriate API based on the voice assignment.

 ##### Narrator Voice

-The voice to use for narration. Each API will come with its own set of voices.
+The default voice used for narration and as a fallback for characters without assigned voices.

-
+The dropdown shows all available voices from all enabled APIs, with the format: "Voice Name (Provider)"

-!!! note "Local TTS"
-    For local TTS, you will have to provide voice samples yourself. See [Local TTS Instructions](local_tts.md) for more information.
+!!! info "Voice Management"
+    Voices are managed through the Voice Library, accessible from the main application bar. Adding, removing, or modifying voices should be done through the Voice Library interface.

-##### Generate for player
+##### Speaker Separation

-Whether to generate voice for the player. If enabled, whenever the player speaks, the voice agent will generate audio for them.
+Controls how dialogue is separated from exposition in messages:

-##### Generate for NPCs
+- **No separation** - Character messages use the character voice entirely, narrator messages use the narrator voice
+- **Simple** - Basic separation of dialogue from exposition using punctuation analysis, with exposition being read by the narrator voice
+- **Mixed** - Enables AI-assisted separation for narrator messages and simple separation for character messages
+- **AI assisted** - AI-assisted separation for both narrator and character messages

-Whether to generate voice for NPCs. If enabled, whenever a non player character speaks, the voice agent will generate audio for them.
+!!! warning "AI Assisted Performance"
+    AI-assisted speaker separation sends additional prompts to your LLM, which may impact response time and API costs.

-##### Generate for narration
+##### Auto-generate for player

-Whether to generate voice for narration. If enabled, whenever the narrator speaks, the voice agent will generate audio for them.
+Generate voice automatically for player messages

-##### Split generation
+##### Auto-generate for AI characters

-If enabled, the voice agent will generate audio in chunks, allowing for faster generation. This does however cause it to lose context between chunks, and inflection may not be as good.
+Generate voice automatically for NPC/AI character messages
+
+##### Auto-generate for narration
+
+Generate voice automatically for narrator messages
+
+##### Auto-generate for context investigation
+
+Generate voice automatically for context investigation messages
+
+## Advanced Settings
+
+Advanced settings are configured per-API and can be found in the respective API configuration sections:
+
+- **Chunk size** - Maximum text length per generation request
+- **Model selection** - Choose specific models for each API
+- **Voice parameters** - Provider-specific voice settings
+
+!!! tip "Performance Optimization"
+    Each API has different optimal chunk sizes and parameters. The system automatically handles chunking and queuing for optimal performance across all enabled APIs.
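As a rough picture of what the "Simple" mode's punctuation analysis might look like, here is an illustrative sketch (assumed logic, not Talemate's actual implementation) that routes quoted spans to the character voice and everything else to the narrator voice:

```python
import re

def simple_speaker_separation(message: str) -> list[tuple[str, str]]:
    """Split a character message into (voice, text) segments:
    quoted spans go to the character voice, the rest to the narrator."""
    segments = []
    # the capture group keeps the quoted spans in the split result
    for part in re.split(r'("[^"]*")', message):
        part = part.strip()
        if not part:
            continue
        voice = "character" if part.startswith('"') else "narrator"
        segments.append((voice, part))
    return segments

# '"Hello," she said.' -> [("character", '"Hello,"'), ("narrator", "she said.")]
```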
docs/user-guide/agents/voice/voice-library.md (new file)
@@ -0,0 +1,156 @@
# Voice Library

The Voice Library is the central hub for managing all voices across all TTS providers in Talemate. It provides a unified interface for organizing, creating, and assigning voices to characters.

## Accessing the Voice Library

The Voice Library can be accessed from the main application bar at the top of the Talemate interface.



Click the voice icon to open the Voice Library dialog.

!!! note "Voice agent needs to be enabled"
    The Voice agent needs to be enabled for the voice library to be available.

## Voice Library Interface



The Voice Library interface consists of:

### Scope Tabs

- **Global** - Voices available across all scenes
- **Scene** - Voices specific to the current scene (only visible when a scene is loaded)
- **Characters** - Character voice assignments for the current scene (only visible when a scene is loaded)

### API Status

The toolbar shows the status of all TTS APIs:

- **Green** - API is enabled and ready
- **Orange** - API is enabled but not configured
- **Red** - API has configuration issues
- **Gray** - API is disabled



## Managing Voices

### Global Voice Library

The global voice library contains voices that are available across all scenes. These include:

- Default voices provided by each TTS provider
- Custom voices you've added

#### Adding New Voices

To add a new voice:

1. Click the "+ New" button
2. Select the TTS provider
3. Configure the voice parameters:
    - **Label** - Display name for the voice
    - **Provider ID** - Provider-specific identifier
    - **Tags** - Free-form descriptive tags you define (gender, age, style, etc.)
    - **Parameters** - Provider-specific settings

Check the provider-specific documentation for more information on how to configure the voice.

#### Voice Types by Provider

**F5-TTS & Chatterbox:**

- Upload .wav reference files for voice cloning
- Specify reference text for better quality
- Adjust speed and other parameters

**Kokoro:**

- Select from predefined voice models
- Create mixed voices by combining multiple models
- Adjust voice mixing weights

**ElevenLabs:**

- Select from available ElevenLabs voices
- Configure voice settings and stability
- Use custom cloned voices from your ElevenLabs account

**OpenAI:**

- Choose from available OpenAI voice models
- Configure model (GPT-4o Mini TTS, TTS-1, TTS-1-HD)

**Google Gemini-TTS:**

- Select from Google's voice models
- Configure language and gender settings

### Scene Voice Library

Scene-specific voices are only available within the current scene. This is useful for:

- Scene-specific characters
- Temporary voice experiments
- Custom voices for specific scenarios

Scene voices are saved with the scene and will be available when the scene is loaded.

## Character Voice Assignment

### Automatic Assignment

The Director agent can automatically assign voices to new characters based on:

- Character tags and attributes
- Voice tags matching character personality
- Available voices in the voice library

This feature can be enabled in the Director agent settings.

### Manual Assignment



To manually assign a voice to a character:

1. Go to the "Characters" tab in the Voice Library
2. Find the character in the list
3. Click the voice dropdown for that character
4. Select a voice from the available options
5. The assignment is saved automatically

### Character Voice Status

The character list shows:

- **Character name**
- **Currently assigned voice** (if any)
- **Voice status** - whether the voice's API is available
- **Quick assignment controls**

## Voice Tags and Organization

### Tagging System

Voices can be tagged with any descriptive attributes you choose. Tags are completely free-form and user-defined. Common examples include:

- **Gender**: male, female, neutral
- **Age**: young, mature, elderly
- **Style**: calm, energetic, dramatic, mysterious
- **Quality**: deep, high, raspy, smooth
- **Character types**: narrator, villain, hero, comic relief
- **Custom tags**: You can create any tags that help you organize your voices

### Filtering and Search

Use the search bar to filter voices by:

- Voice label/name
- Provider
- Tags
- Character assignments

This makes it easy to find the right voice for specific characters or situations.
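A search across label, provider, and tags like the one described above can be pictured with a small sketch (illustrative only; the voice dictionary shape is assumed):

```python
def filter_voices(voices: list[dict], query: str) -> list[dict]:
    """Case-insensitive filter over label, provider, and tags (illustrative)."""
    q = query.lower()
    return [
        v for v in voices
        if q in v["label"].lower()
        or q in v["provider"].lower()
        or any(q in tag.lower() for tag in v.get("tags", []))
    ]

# Hypothetical usage:
# filter_voices(library, "calm") -> every voice labeled or tagged "calm"
```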
docs/user-guide/clients/reasoning.md (new file)
@@ -0,0 +1,82 @@
# Reasoning Model Support

Talemate supports reasoning models that can perform step-by-step thinking before generating their final response. This feature allows models to work through complex problems internally before providing an answer.

## Enabling Reasoning Support

To enable reasoning support for a client:

1. Open the **Clients** dialog from the main toolbar
2. Select the client you want to configure
3. Navigate to the **Reasoning** tab in the client configuration



4. Check the **Enable Reasoning** checkbox

## Configuring Reasoning Tokens

Once reasoning is enabled, you can configure the **Reasoning Tokens** setting using the slider:



### Recommended Token Amounts

**For local reasoning models:** Use a high token allocation (recommended: 4096 tokens) to give the model sufficient space for complex reasoning.

**For remote APIs:** Start with lower amounts (512-1024 tokens) and adjust based on your needs and token costs.

### Token Allocation Behavior

The behavior of the reasoning tokens setting depends on your API provider:

**For APIs that support direct reasoning token specification:**

- The specified tokens will be allocated specifically for reasoning
- The model will use these tokens for internal thinking before generating the response

**For APIs that do NOT support reasoning token specification:**

- The tokens are added as extra allowance to the response token limit for ALL requests
- This may lead to more verbose responses than usual, since Talemate normally uses response token limits to control verbosity

!!! warning "Increased Verbosity"
    For providers without direct reasoning token support, enabling reasoning may result in more verbose responses since the extra tokens are added to all requests.

## Response Pattern Configuration

When reasoning is enabled, you may need to configure a **Pattern to strip from the response** to remove the thinking process from the final output.

### Default Patterns

Talemate provides quick-access buttons for common reasoning patterns:

- **Default** - Uses the built-in pattern: `.*?</think>`
- **`.*?◁/think▷`** - For models using arrow-style thinking delimiters
- **`.*?</think>`** - For models using XML-style think tags

### Custom Patterns

You can also specify a custom regular expression pattern that matches your model's reasoning format. This pattern will be used to strip the thinking tokens from the response before displaying it to the user.

## Model Compatibility

Not all models support reasoning. This feature works best with:

- Models specifically trained for chain-of-thought reasoning
- Models that support structured thinking patterns
- APIs that provide reasoning token specification

## Important Notes

- **Coercion Disabled**: When reasoning is enabled, LLM coercion (pre-filling responses) is automatically disabled, since reasoning models need to generate their complete thought process
- **Response Time**: Reasoning models may take longer to respond as they work through their thinking process

## Troubleshooting

### Pattern Not Working

If the reasoning pattern isn't properly stripping the thinking process:

1. Check your model's actual reasoning output format
2. Adjust the regular expression pattern to match your model's specific format
3. Test with the default pattern first to see if it works
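To see how such a strip pattern behaves, here is a small self-contained example using the default `.*?</think>` pattern. The helper name is hypothetical; the key detail is the DOTALL flag so `.` also matches newlines inside the thinking block:

```python
import re

# the "Default" pattern from the guide above
REASONING_PATTERN = r".*?</think>"

def strip_reasoning(response: str) -> str:
    """Remove the model's thinking prefix, keeping only the final answer."""
    return re.sub(REASONING_PATTERN, "", response,
                  count=1, flags=re.DOTALL).strip()

raw = "<think>\nThe user wants a greeting...\n</think>\nHello there!"
print(strip_reasoning(raw))  # -> "Hello there!"
```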
@@ -35,4 +35,19 @@ A unique name for the client that makes sense to you.
 Which model to use. Currently defaults to `gpt-4o`.

 !!! note "Talemate lags behind OpenAI"
     When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
+
+##### Reasoning models (o1, o3, gpt-5)
+
+!!! important "Enable reasoning and allocate tokens"
+    The `o1`, `o3`, and `gpt-5` families are reasoning models. They always perform internal thinking before producing the final answer. To use them effectively in Talemate:
+
+    - Enable the **Reasoning** option in the client configuration.
+    - Set **Reasoning Tokens** to a sufficiently high value to make room for the model's thinking process.
+
+    A good starting range is 512–1024 tokens. Increase if your tasks are complex. Without enabling reasoning and allocating tokens, these models may return minimal or empty visible content because the token budget is consumed by internal reasoning.
+
+    See the detailed guide: [Reasoning Model Support](/talemate/user-guide/clients/reasoning/).
+
+!!! tip "Getting empty responses?"
+    If these models return empty or very short answers, it usually means the reasoning budget was exhausted. Increase **Reasoning Tokens** and try again.
@@ -4,4 +4,4 @@
 uv pip uninstall torch torchaudio

 # install torch and torchaudio with CUDA support
-uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
+uv pip install torch~=2.7.1 torchaudio~=2.7.1 --index-url https://download.pytorch.org/whl/cu128
pyproject.toml
@@ -1,11 +1,12 @@
 [project]
 name = "talemate"
-version = "0.31.0"
+version = "0.32.0"
 description = "AI-backed roleplay and narrative tools"
 authors = [{name = "VeguAITools"}]
 license = {text = "GNU Affero General Public License v3.0"}
-requires-python = ">=3.10,<3.14"
+requires-python = ">=3.11,<3.14"
 dependencies = [
     "pip",
     "astroid>=2.8",
     "jedi>=0.18",
     "black",
@@ -50,11 +51,20 @@ dependencies = [
     # ChromaDB
     "chromadb>=1.0.12",
     "InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
-    "torch>=2.7.0",
-    "torchaudio>=2.7.0",
-    # locked for instructor embeddings
-    #sentence-transformers==2.2.2
+    "torch>=2.7.1",
+    "torchaudio>=2.7.1",
     "sentence_transformers>=2.7.0",
     # TTS
     "elevenlabs>=2.7.1",
-    # Local TTS
+    # Chatterbox TTS
+    #"chatterbox-tts @ https://github.com/rsxdalv/chatterbox/archive/refs/heads/fast.zip",
+    "chatterbox-tts==0.1.2",
+    # kokoro TTS
+    "kokoro>=0.9.4",
+    "soundfile>=0.13.1",
+    # F5-TTS
+    "f5-tts>=1.1.7",
 ]

 [project.optional-dependencies]
@@ -105,3 +115,25 @@ force_grid_wrap = 0
 use_parentheses = true
 ensure_newline_before_comments = true
 line_length = 88
+
+[tool.uv]
+override-dependencies = [
+    # chatterbox wants torch 2.6.0, but is confirmed working with 2.7.1
+    "torchaudio>=2.7.1",
+    "torch>=2.7.1",
+    "numpy>=2",
+    "pydantic>=2.11",
+]
+
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-cu128" },
+]
+torchaudio = [
+    { index = "pytorch-cu128" },
+]
+
+[[tool.uv.index]]
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
+explicit = true
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"description": "Captain Elmer Farstield and his trusty first officer, Kaira, embark upon a daring mission into uncharted space. Their small but mighty exploration vessel, the Starlight Nomad, is equipped with state-of-the-art technology and crewed by an elite team of scientists, engineers, and pilots. Together they brave the vast cosmos seeking answers to humanity's most pressing questions about life beyond our solar system.",
|
||||
"intro": "You awaken aboard your ship, the Starlight Nomad, surrounded by darkness. A soft hum resonates throughout the vessel indicating its systems are online. Your mind struggles to recall what brought you here - where 'here' actually is. You remember nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there was a bright light followed by this endless void.\n\nGingerly, you make your way through the dimly lit corridors of the ship. It seems smaller than you expected given the magnitude of the mission ahead. However, each room reveals intricate technology designed specifically for long-term space travel and exploration. There appears to be no other living soul besides yourself. An eerie silence fills every corner.",
|
||||
"intro": "Elmer awoke aboard his ship, the Starlight Nomad, surrounded by darkness. A soft hum resonated throughout the vessel indicating its systems were online. His mind struggled to recall what had brought him here - where here actually was. He remembered nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there had been a bright light followed by this endless void.\n\nGingerly, he made his way through the dimly lit corridors of the ship. It seemed smaller than he had expected given the magnitude of the mission ahead. However, each room revealed intricate technology designed specifically for long-term space travel and exploration. There appeared to be no other living soul besides himself. An eerie silence filled every corner.",
|
||||
"name": "Infinity Quest",
|
||||
"history": [],
|
||||
"environment": "scene",
|
||||
@@ -90,11 +90,11 @@
|
||||
"gender": "female",
|
||||
"color": "red",
|
||||
"example_dialogue": [
|
||||
"Kaira: \"Yes Captain, I believe that is the best course of action\" She nods slightly, as if to punctuate her approval of the decision*",
|
||||
"Kaira: \"This device appears to have multiple functions, Captain. Allow me to analyze its capabilities and determine if it could be useful in our exploration efforts.\"",
|
||||
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" Excitement bubbles beneath her calm exterior as she shares the news",
|
||||
"Kaira: \"Captain, I understand why you would want us to pursue this course of action based on our current data, but I cannot shake the feeling that there might be unforeseen consequences if we proceed without further investigation into potential hazards.\"",
|
||||
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world... But then again, perhaps it was fate that led me here, onto this ship bound for destinations unknown...\""
|
||||
"Kaira: \"Yes Captain, I believe that is the best course of action.\" Kaira glanced at the navigation display, then back at him with a slight nod. \"The numbers check out on my end too. If we adjust our trajectory by 2.7 degrees and increase thrust by fifteen percent, we should reach the nebula's outer edge within six hours.\" Her violet fingers moved efficiently across the controls as she pulled up the gravitational readings.",
|
||||
"Kaira: The scanner hummed as it analyzed the alien artifact. Kaira knelt beside the strange object, frowning in concentration at the shifting symbols on its warm metal surface. \"This device appears to have multiple functions, Captain,\" she said, adjusting her scanner's settings. \"Give me a few minutes to run a full analysis and I'll know what we're dealing with. The material composition is fascinating - it's responding to our ship's systems in ways that shouldn't be possible.\"",
|
||||
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" The excitement in her voice was unmistakable as Kaira looked up from her console. \"These readings are incredible - I've never seen anything like this before. There are structures beneath the surface that predate our oldest sites by millions of years.\" She paused, processing the implications. \"If these readings are accurate, we may have found something truly significant.\"",
|
||||
"Kaira: Something felt off about the proposed course of action. Kaira moved from her station to stand beside the Captain, organizing her thoughts carefully. \"Captain, I understand why you would want us to pursue this based on our current data,\" she began respectfully, clasping her hands behind her back. \"But something feels wrong about this. The quantum signatures have subtle variations that remind me of Hegemony cloaking technology. Maybe we should run a few more scans before we commit to a full approach.\"",
|
||||
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world,\" she said softly, not turning from the observation deck viewport as footsteps approached. The stars wheeled slowly past in their eternal dance. \"Sometimes I dream of Altrusia's crystal gardens, the way our twin suns would set over the mountains.\" Kaira finally turned, her expression thoughtful. \"Then again, I suppose I was always meant to end up out here somehow. Perhaps this journey is exactly where I'm supposed to be.\""
|
||||
],
|
||||
"history_events": [],
|
||||
"is_player": false,
|
||||
|
||||
@@ -494,34 +494,6 @@
"registry": "state/GetState",
"base_type": "core/Node"
},
"8a050403-5c69-46f7-abe2-f65db4553942": {
"title": "TRUE",
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
"properties": {
"value": true
},
"x": 460,
"y": 670,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
},
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
"title": "Create Character",
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
"properties": {},
"x": 580,
"y": 550,
"width": 245,
"height": 186,
"collapsed": false,
"inherited": false,
"registry": "agents/creator/CreateCharacter",
"base_type": "core/Graph"
},
"970f12d0-330e-41b3-b025-9a53bcf2fc6f": {
"title": "SET - created_character",
"id": "970f12d0-330e-41b3-b025-9a53bcf2fc6f",
@@ -628,6 +600,34 @@
"inherited": false,
"registry": "data/string/MakeText",
"base_type": "core/Node"
},
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
"title": "Create Character",
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
"properties": {},
"x": 580,
"y": 550,
"width": 245,
"height": 206,
"collapsed": false,
"inherited": false,
"registry": "agents/creator/CreateCharacter",
"base_type": "core/Graph"
},
"8a050403-5c69-46f7-abe2-f65db4553942": {
"title": "TRUE",
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
"properties": {
"value": true
},
"x": 400,
"y": 720,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
}
},
"edges": {
@@ -714,17 +714,6 @@
"fd4cd318-121b-47de-84a0-b1ab62c5601b.value": [
"41c0c2cd-d39b-4528-a5fe-459094393ba3.list"
],
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
],
"a2457116-35cb-4571-ba1a-cbf63851544e.value": [
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.state",
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.character",
@@ -738,6 +727,18 @@
],
"5041f12d-507f-4f5d-a26f-048625974602.value": [
"b50e23d4-b456-4385-b80e-c2b6884c7855.template"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
],
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.assign_voice"
]
},
"groups": [
@@ -808,13 +809,14 @@
"inputs": [],
"outputs": [
{
"id": "f78fa84b-8b6f-4c8a-83c0-754537bb9060",
"id": "92423e0a-89d8-4ad8-a42d-fdcba75b31d1",
"name": "fn",
"optional": false,
"group": null,
"socket_type": "function"
}
],
"module_properties": {},
"style": {
"title_color": "#573a2e",
"node_color": "#392f2c",

@@ -116,8 +116,8 @@
"context_aware": true,
"history_aware": true
},
"x": 568,
"y": 861,
"x": 372,
"y": 857,
"width": 249,
"height": 406,
"collapsed": false,
@@ -130,7 +130,7 @@
"id": "a8110d74-0fb5-4601-b883-6c63ceaa9d31",
"properties": {},
"x": 24,
"y": 2084,
"y": 2088,
"width": 140,
"height": 106,
"collapsed": false,
@@ -145,7 +145,7 @@
"attribute": "title"
},
"x": 213,
"y": 2094,
"y": 2098,
"width": 210,
"height": 98,
"collapsed": false,
@@ -160,7 +160,7 @@
"value": "The Simulation Suite"
},
"x": 213,
"y": 2284,
"y": 2288,
"width": 210,
"height": 58,
"collapsed": false,
@@ -177,7 +177,7 @@
"case_sensitive": true
},
"x": 523,
"y": 2184,
"y": 2188,
"width": 210,
"height": 126,
"collapsed": false,
@@ -192,7 +192,7 @@
"pass_through": true
},
"x": 774,
"y": 2185,
"y": 2189,
"width": 210,
"height": 78,
"collapsed": false,
@@ -207,7 +207,7 @@
"attribute": "0"
},
"x": 1629,
"y": 2411,
"y": 2415,
"width": 210,
"height": 98,
"collapsed": false,
@@ -222,7 +222,7 @@
"attribute": "title"
},
"x": 2189,
"y": 2321,
"y": 2325,
"width": 210,
"height": 98,
"collapsed": false,
@@ -237,7 +237,7 @@
"stage": 5
},
"x": 2459,
"y": 2331,
"y": 2335,
"width": 210,
"height": 118,
"collapsed": true,
@@ -250,7 +250,7 @@
"id": "8208d05c-1822-4f4a-ba75-cfd18d2de8ca",
"properties": {},
"x": 1989,
"y": 2331,
"y": 2335,
"width": 140,
"height": 106,
"collapsed": true,
@@ -266,7 +266,7 @@
"chars": "\"'*"
},
"x": 1899,
"y": 2411,
"y": 2415,
"width": 210,
"height": 102,
"collapsed": false,
@@ -282,7 +282,7 @@
"max_splits": 1
},
"x": 1359,
"y": 2411,
"y": 2415,
"width": 210,
"height": 102,
"collapsed": false,
@@ -304,7 +304,7 @@
"history_aware": true
},
"x": 1014,
"y": 2185,
"y": 2189,
"width": 276,
"height": 406,
"collapsed": false,
@@ -376,8 +376,8 @@
"name": "arg_goal",
"scope": "local"
},
"x": 228,
"y": 1041,
"x": 32,
"y": 1037,
"width": 256,
"height": 122,
"collapsed": false,
@@ -393,7 +393,7 @@
"max_scene_types": 2
},
"x": 671,
"y": 1490,
"y": 1494,
"width": 210,
"height": 122,
"collapsed": false,
@@ -409,7 +409,7 @@
"scope": "local"
},
"x": 30,
"y": 1551,
"y": 1555,
"width": 256,
"height": 122,
"collapsed": false,
@@ -448,37 +448,6 @@
"registry": "state/GetState",
"base_type": "core/Node"
},
"59d31050-f61d-4798-9790-e22d34ecbd4b": {
"title": "GET local.auto_direct_enabled",
"id": "59d31050-f61d-4798-9790-e22d34ecbd4b",
"properties": {
"name": "auto_direct_enabled",
"scope": "local"
},
"x": 20,
"y": 850,
"width": 256,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8": {
"title": "Switch",
"id": "e8f19a05-43fe-4e4a-9cc4-bec0a29779d8",
"properties": {
"pass_through": true
},
"x": 310,
"y": 870,
"width": 210,
"height": 78,
"collapsed": false,
"inherited": false,
"registry": "core/Switch",
"base_type": "core/Node"
},
"b03fa942-c48e-4c04-b9ae-a009a7e0f947": {
"title": "GET local.auto_direct_enabled",
"id": "b03fa942-c48e-4c04-b9ae-a009a7e0f947",
@@ -487,7 +456,7 @@
"scope": "local"
},
"x": 24,
"y": 1367,
"y": 1371,
"width": 256,
"height": 122,
"collapsed": false,
@@ -502,7 +471,7 @@
"pass_through": true
},
"x": 360,
"y": 1390,
"y": 1394,
"width": 210,
"height": 78,
"collapsed": false,
@@ -518,7 +487,7 @@
"scope": "local"
},
"x": 30,
"y": 1821,
"y": 1826,
"width": 256,
"height": 122,
"collapsed": false,
@@ -533,7 +502,7 @@
"stage": 4
},
"x": 1100,
"y": 1870,
"y": 1875,
"width": 210,
"height": 118,
"collapsed": true,
@@ -546,7 +515,7 @@
"id": "9db37d1e-3cf8-49bd-bdc5-8663494e5657",
"properties": {},
"x": 670,
"y": 1840,
"y": 1845,
"width": 226,
"height": 62,
"collapsed": false,
@@ -561,7 +530,7 @@
"stage": 3
},
"x": 1080,
"y": 1520,
"y": 1524,
"width": 210,
"height": 118,
"collapsed": true,
@@ -576,7 +545,7 @@
"pass_through": true
},
"x": 370,
"y": 1840,
"y": 1845,
"width": 210,
"height": 78,
"collapsed": false,
@@ -590,8 +559,8 @@
"properties": {
"stage": 2
},
"x": 1280,
"y": 890,
"x": 1084,
"y": 886,
"width": 210,
"height": 118,
"collapsed": true,
@@ -599,14 +568,29 @@
"registry": "core/Stage",
"base_type": "core/Node"
},
"5559196c-f6b1-4223-8e13-2bf64e3cfef0": {
"title": "true",
"id": "5559196c-f6b1-4223-8e13-2bf64e3cfef0",
"properties": {
"value": true
},
"x": 24,
"y": 876,
"width": 210,
"height": 58,
"collapsed": false,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
},
"6ef94917-f9b1-4c18-af15-617430e50cfe": {
"title": "Set Scene Intent",
"id": "6ef94917-f9b1-4c18-af15-617430e50cfe",
"properties": {
"intent": ""
},
"x": 930,
"y": 860,
"x": 731,
"y": 853,
"width": 210,
"height": 78,
"collapsed": false,
@@ -693,14 +677,8 @@
"e4cd1391-daed-4951-a6c6-438d993c07a9.state"
],
"c66bdaeb-4166-4835-9415-943af547c926.value": [
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description",
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state"
],
"59d31050-f61d-4798-9790-e22d34ecbd4b.value": [
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.value"
],
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.yes": [
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state",
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description"
],
"b03fa942-c48e-4c04-b9ae-a009a7e0f947.value": [
"8ad7c42c-110e-46ae-b649-4a1e6d055e25.value"
@@ -717,6 +695,9 @@
"6a8762c4-16cf-4e8c-9d10-8af7597c4097.yes": [
"9db37d1e-3cf8-49bd-bdc5-8663494e5657.state"
],
"5559196c-f6b1-4223-8e13-2bf64e3cfef0.value": [
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
],
"6ef94917-f9b1-4c18-af15-617430e50cfe.state": [
"f4cd34d9-0628-4145-a3da-ec1215cd356c.state"
]
@@ -745,7 +726,7 @@
{
"title": "Generate Scene Types",
"x": -1,
"y": 1287,
"y": 1290,
"width": 1298,
"height": 408,
"color": "#8AA",
@@ -756,8 +737,8 @@
"title": "Set story intention",
"x": -1,
"y": 773,
"width": 1539,
"height": 512,
"width": 1320,
"height": 514,
"color": "#8AA",
"font_size": 24,
"inherited": false
@@ -765,7 +746,7 @@
{
"title": "Evaluate Scene Intent",
"x": -1,
"y": 1697,
"y": 1701,
"width": 1293,
"height": 302,
"color": "#8AA",
@@ -775,7 +756,7 @@
{
"title": "Set title",
"x": -1,
"y": 2003,
"y": 2006,
"width": 2618,
"height": 637,
"color": "#8AA",
@@ -794,7 +775,7 @@
{
"text": "Sometimes the AI will produce more text after the title; we only care about the title on the first line.",
"x": 1359,
"y": 2311,
"y": 2315,
"width": 471,
"inherited": false
}
@@ -804,13 +785,14 @@
"inputs": [],
"outputs": [
{
"id": "dede1a38-2107-4475-9db5-358c09cb0d12",
"id": "5c8dee64-5832-40ba-b1e2-2a411d913cc7",
"name": "fn",
"optional": false,
"group": null,
"socket_type": "function"
}
],
"module_properties": {},
"style": {
"title_color": "#573a2e",
"node_color": "#392f2c",

@@ -666,23 +666,6 @@
"registry": "data/ListAppend",
"base_type": "core/Node"
},
"4eb36f21-1020-4609-85e5-d16b42019c66": {
"title": "AI Function Calling",
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
"properties": {
"template": "computer",
"max_calls": 5,
"retries": 1
},
"x": 1000,
"y": 2581,
"width": 212,
"height": 206,
"collapsed": false,
"inherited": false,
"registry": "focal/Focal",
"base_type": "core/Node"
},
"3da498c0-55de-4ec3-9943-6486279b9826": {
"title": "GET scene loop.user_message",
"id": "3da498c0-55de-4ec3-9943-6486279b9826",
@@ -1167,6 +1150,24 @@
"inherited": false,
"registry": "data/MakeList",
"base_type": "core/Node"
},
"4eb36f21-1020-4609-85e5-d16b42019c66": {
"title": "AI Function Calling",
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
"properties": {
"template": "computer",
"max_calls": 5,
"retries": 1,
"response_length": 1408
},
"x": 1000,
"y": 2581,
"width": 210,
"height": 230,
"collapsed": false,
"inherited": false,
"registry": "focal/Focal",
"base_type": "core/Node"
}
},
"edges": {
@@ -1296,9 +1297,6 @@
"d45f07ff-c3ce-4aac-bae6-b9c77089cb69.list": [
"4eb36f21-1020-4609-85e5-d16b42019c66.callbacks"
],
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
],
"3da498c0-55de-4ec3-9943-6486279b9826.value": [
"6d707f1c-af55-481a-970a-eb7a9f9c45dd.message"
],
@@ -1380,6 +1378,9 @@
],
"632d4a0e-3327-409b-aaa4-ed38b932286b.list": [
"06175a92-abbe-4483-9ab2-abaee2104728.value"
],
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
]
},
"groups": [

@@ -19,7 +19,7 @@ You have access to the following functions you must call to fulfill the user's r
focal.callbacks.set_simulated_environment.render(
    "Create or change the simulated environment. This means the location, time, specific conditions, or any other aspect of the simulation that is not directly related to the characters.",
    instructions="Instructions on how to change the simulated environment. These will be given to the simulation computer to setup the new environment. REQUIRED.",
    reset="If true, the environment should be reset and all simulated characters are removed. If false, the environment should be changed but the characters should remain. REQUIRED.",
    reset="If true, the environment should be reset and ALL simulated characters are removed. IMPORTANT: If you set reset=true, this function MUST be the FIRST call in your stack; otherwise, set reset=false to avoid deactivating characters added earlier. REQUIRED.",
    examples=[
        {"instructions": "Change the location to a lush forest, with a river running through it.", "reset":true},
        {"instructions": "The simulation suite flickers and changes to a bustling city street.", "reset":true},
@@ -123,7 +123,7 @@ You have access to the following functions you must call to fulfill the user's r

{{
focal.callbacks.set_simulation_goal.render(
    "Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives.",
    "Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives. You MUST call this on new simulations or if the user has requested a change in the simulation's goal.",
    goal="The overall goal of the simulation. This should be a clear and concise statement that outlines the user's objective. REQUIRED.",
    examples=[
        {"goal": "The user is exploring a mysterious alien planet to uncover the secrets of an ancient civilization."},

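The updated `reset` description encodes an ordering constraint: a reset wipes all simulated characters, so any call stack that both resets the environment and adds characters must place the reset first. A minimal sketch of a guard that enforces this rule; the `calls` structure and `validate_reset_order` helper are hypothetical illustrations, not part of the talemate focal API:

```python
# Hypothetical guard for the ordering rule stated in the prompt text:
# reset=True must be the FIRST call, or characters added earlier are lost.
def validate_reset_order(calls: list[dict]) -> list[dict]:
    """Demote resets that arrive after other calls to non-destructive changes."""
    for i, call in enumerate(calls):
        args = call.get("arguments", {})
        if call.get("name") == "set_simulated_environment" and args.get("reset"):
            if i != 0:
                # Not first in the stack: keep the environment change,
                # but do not wipe characters created by earlier calls.
                args["reset"] = False
    return calls

calls = [
    {"name": "add_simulated_character", "arguments": {"name": "Kaira"}},
    {"name": "set_simulated_environment", "arguments": {"instructions": "A forest.", "reset": True}},
]
print(validate_reset_order(calls))  # reset demoted to False because it is not first
```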
@@ -18,10 +18,11 @@ from talemate.agents.context import ActiveAgent, active_agent
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopStartEvent
|
||||
from talemate.context import active_scene
|
||||
import talemate.config as config
|
||||
from talemate.ux.schema import Column
|
||||
from talemate.config import get_config, Config
|
||||
import talemate.config.schema as config_schema
|
||||
from talemate.client.context import (
|
||||
ClientContext,
|
||||
set_client_context_attribute,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -53,19 +54,20 @@ class AgentActionConfig(pydantic.BaseModel):
|
||||
type: str
|
||||
label: str
|
||||
description: str = ""
|
||||
value: Union[int, float, str, bool, list, None] = None
|
||||
default_value: Union[int, float, str, bool] = None
|
||||
max: Union[int, float, None] = None
|
||||
min: Union[int, float, None] = None
|
||||
step: Union[int, float, None] = None
|
||||
value: int | float | str | bool | list | None = None
|
||||
default_value: int | float | str | bool | None = None
|
||||
max: int | float | None = None
|
||||
min: int | float | None = None
|
||||
step: int | float | None = None
|
||||
scope: str = "global"
|
||||
choices: Union[list[dict[str, str]], None] = None
|
||||
choices: list[dict[str, str | int | float | bool]] | None = None
|
||||
note: Union[str, None] = None
|
||||
expensive: bool = False
|
||||
quick_toggle: bool = False
|
||||
condition: Union[AgentActionConditional, None] = None
|
||||
title: Union[str, None] = None
|
||||
value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True)
|
||||
condition: AgentActionConditional | None = None
|
||||
title: str | None = None
|
||||
value_migration: Callable | None = pydantic.Field(default=None, exclude=True)
|
||||
columns: list[Column] | None = None
|
||||
|
||||
note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict)
|
||||
|
||||
@@ -78,20 +80,21 @@ class AgentAction(pydantic.BaseModel):
|
||||
label: str
|
||||
description: str = ""
|
||||
warning: str = ""
|
||||
config: Union[dict[str, AgentActionConfig], None] = None
|
||||
condition: Union[AgentActionConditional, None] = None
|
||||
config: dict[str, AgentActionConfig] | None = None
|
||||
condition: AgentActionConditional | None = None
|
||||
container: bool = False
|
||||
icon: Union[str, None] = None
|
||||
icon: str | None = None
|
||||
can_be_disabled: bool = False
|
||||
quick_toggle: bool = False
|
||||
experimental: bool = False
|
||||
|
||||
|
||||
class AgentDetail(pydantic.BaseModel):
|
||||
value: Union[str, None] = None
|
||||
description: Union[str, None] = None
|
||||
icon: Union[str, None] = None
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
icon: str | None = None
|
||||
color: str = "grey"
|
||||
hidden: bool = False
|
||||
|
||||
|
||||
class DynamicInstruction(pydantic.BaseModel):
|
||||
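These hunks swap `typing.Union[...]` for PEP 604 pipe unions throughout the agent config models. Behavior is unchanged; on Python 3.10+ pydantic treats both spellings identically. A standalone equivalence check (not talemate code):

```python
from typing import Union

import pydantic


class OldStyle(pydantic.BaseModel):
    title: Union[str, None] = None


class NewStyle(pydantic.BaseModel):
    title: str | None = None


# Both spellings validate identically and share the same default.
assert OldStyle(title="x").title == NewStyle(title="x").title
assert OldStyle().title is None and NewStyle().title is None
```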
@@ -172,11 +175,6 @@ def set_processing(fn):
        if scene:
            scene.continue_actions()

        if getattr(scene, "config", None):
            set_client_context_attribute(
                "app_config_system_prompts", scene.config.get("system_prompts", {})
            )

        with ActiveAgent(self, fn, args, kwargs) as active_agent_context:
            try:
                await self.emit_status(processing=True)
@@ -221,7 +219,6 @@ class Agent(ABC):
    verbose_name = None
    set_processing = set_processing
    requires_llm_client = True
    auto_break_repetition = False
    websocket_handler = None
    essential = True
    ready_check_error = None
@@ -235,6 +232,10 @@

        return actions

    @property
    def config(self) -> Config:
        return get_config()

    @property
    def agent_details(self):
        if hasattr(self, "client"):
@@ -244,6 +245,12 @@

    @property
    def ready(self):
        if not self.requires_llm_client:
            return True

        if not hasattr(self, "client"):
            return False

        if not getattr(self.client, "enabled", True):
            return False

@@ -326,7 +333,7 @@

    # scene state

    def context_fingerpint(self, extra: list[str] = []) -> str | None:
    def context_fingerprint(self, extra: list[str] = None) -> str | None:
        active_agent_context = active_agent.get()

        if not active_agent_context:
@@ -337,8 +344,9 @@
        else:
            fingerprint = f"START-{active_agent_context.first.fingerprint}"

        for extra_key in extra:
            fingerprint += f"-{hash(extra_key)}"
        if extra:
            for extra_key in extra:
                fingerprint += f"-{hash(extra_key)}"

        return fingerprint

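Besides correcting the `context_fingerpint` typo, this hunk replaces the mutable default `extra: list[str] = []` with `None` plus an `if extra:` guard, avoiding Python's shared-default pitfall: a list default is created once at function definition and reused across calls. A standalone illustration of the bug class being avoided:

```python
def bad(extra: list[str] = []) -> list[str]:
    extra.append("x")  # mutates the single list shared by every call
    return extra


def good(extra: list[str] = None) -> list[str]:
    items = list(extra) if extra else []  # fresh list per call
    items.append("x")
    return items


bad()
print(bad())   # ['x', 'x'] - state leaked between calls
good()
print(good())  # ['x']      - no leakage
```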
@@ -448,25 +456,26 @@
        except AttributeError:
            pass

    async def save_config(self, app_config: config.Config | None = None):
    async def save_config(self):
        """
        Saves the agent config to the config file.

        If no config object is provided, the config is loaded from the config file.
        """

        if not app_config:
            app_config: config.Config = config.load_config(as_model=True)
        app_config: Config = get_config()

        app_config.agents[self.agent_type] = config.Agent(
        app_config.agents[self.agent_type] = config_schema.Agent(
            name=self.agent_type,
            client=self.client.name if self.client else None,
            client=self.client.name if getattr(self, "client", None) else None,
            enabled=self.enabled,
            actions={
                action_key: config.AgentAction(
                action_key: config_schema.AgentAction(
                    enabled=action.enabled,
                    config={
                        config_key: config.AgentActionConfig(value=config_obj.value)
                        config_key: config_schema.AgentActionConfig(
                            value=config_obj.value
                        )
                        for config_key, config_obj in action.config.items()
                    },
                )
@@ -478,7 +487,8 @@
            agent=self.agent_type,
            config=app_config.agents[self.agent_type],
        )
        config.save_config(app_config)

        app_config.dirty = True

    async def on_game_loop_start(self, event: GameLoopStartEvent):
        """
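The rewritten `save_config` no longer writes to disk itself; it mutates the shared object returned by `get_config()` and sets `dirty = True`, deferring persistence to whoever owns the config lifecycle. This matches the changelog note "always use get_config(), dont keep an instance reference." A rough sketch of the pattern under stated assumptions; `AppConfig` and `flush_if_dirty` are hypothetical names, not talemate APIs:

```python
import json
import pathlib


class AppConfig:
    """Process-wide config object: mutate it, mark it dirty, flush later."""

    def __init__(self, path: str):
        self.path = pathlib.Path(path)
        self.data: dict = {}
        self.dirty = False


_config = AppConfig("config.json")


def get_config() -> AppConfig:
    return _config  # one shared instance, never copied


def flush_if_dirty() -> None:
    cfg = get_config()
    if cfg.dirty:
        cfg.path.write_text(json.dumps(cfg.data, indent=2))
        cfg.dirty = False


# Caller pattern mirroring the new save_config:
cfg = get_config()
cfg.data.setdefault("agents", {})["tts"] = {"enabled": True}
cfg.dirty = True
flush_if_dirty()
```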
@@ -602,7 +612,7 @@
        exclude_fn: Callable = None,
    ):
        current_memory_context = []
        memory_helper = self.scene.get_helper("memory")
        memory_helper = instance.get_agent("memory")
        if memory_helper:
            history_messages = "\n".join(
                self.scene.recent_history(memory_history_context_max)

@@ -23,6 +23,7 @@ from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConfig,
    AgentActionNote,
    AgentDetail,
    AgentEmission,
    DynamicInstruction,
@@ -85,12 +86,22 @@ class ConversationAgent(MemoryRAGMixin, Agent):
            "format": AgentActionConfig(
                type="text",
                label="Format",
                description="The generation format of the scene context, as seen by the AI.",
                description="The generation format of the scene progression, as seen by the AI. Has no direct effect on your view of the scene, but will affect the way the AI perceives the scene and its characters, leading to changes in the response, for better or worse.",
                choices=[
                    {"label": "Screenplay", "value": "movie_script"},
                    {"label": "Chat (legacy)", "value": "chat"},
                    {
                        "label": "Narrative (NEW, experimental)",
                        "value": "narrative",
                    },
                ],
                value="movie_script",
                note_on_value={
                    "narrative": AgentActionNote(
                        type="primary",
                        text="Will attempt to generate flowing, novel-like prose with scene intent awareness and character goal consideration. A reasoning model is STRONGLY recommended. Experimental and more prone to generate out of turn character actions and dialogue.",
                    )
                },
            ),
            "length": AgentActionConfig(
                type="number",
@@ -133,12 +144,6 @@
            ),
        },
    ),
    "auto_break_repetition": AgentAction(
        enabled=True,
        can_be_disabled=True,
        label="Auto Break Repetition",
        description="Will attempt to automatically break AI repetition.",
    ),
    "content": AgentAction(
        enabled=True,
        can_be_disabled=False,
@@ -161,7 +166,7 @@

    def __init__(
        self,
        client: client.TaleMateClient,
        client: client.ClientBase | None = None,
        kind: Optional[str] = "pygmalion",
        logging_enabled: Optional[bool] = True,
        **kwargs,
@@ -453,21 +458,31 @@
        if total_result.startswith(":\n") or total_result.startswith(": "):
            total_result = total_result[2:]

        # movie script format
        # {uppercase character name}
        # {dialogue}
        total_result = total_result.replace(f"{character.name.upper()}\n", "")
        conversation_format = self.conversation_format

        # chat format
        # {character name}: {dialogue}
        total_result = total_result.replace(f"{character.name}:", "")
        if conversation_format == "narrative":
            # For narrative format, the LLM generates pure prose without character name prefixes
            # We need to store it internally in the standard {name}: {text} format
            total_result = util.clean_dialogue(total_result, main_name=character.name)
            # Only add character name if it's not already there
            if not total_result.startswith(character.name + ":"):
                total_result = f"{character.name}: {total_result}"
        else:
            # movie script format
            # {uppercase character name}
            # {dialogue}
            total_result = total_result.replace(f"{character.name.upper()}\n", "")

        # Removes partial sentence at the end
        total_result = util.clean_dialogue(total_result, main_name=character.name)
            # chat format
            # {character name}: {dialogue}
            total_result = total_result.replace(f"{character.name}:", "")

        # Check if total_result starts with character name, if not, prepend it
        if not total_result.startswith(character.name + ":"):
            total_result = f"{character.name}: {total_result}"
            # Removes partial sentence at the end
            total_result = util.clean_dialogue(total_result, main_name=character.name)

            # Check if total_result starts with character name, if not, prepend it
            if not total_result.startswith(character.name + ":"):
                total_result = f"{character.name}: {total_result}"

        total_result = total_result.strip()

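The branch above normalizes every response to the internal `{name}: {text}` shape regardless of the generation format: narrative output only gets the name prefix added, while screenplay/chat output first has the model's own name markers stripped. A condensed, standalone sketch of that normalization; `clean_dialogue` here is a trivial stand-in for `talemate.util.clean_dialogue`, which also trims partial sentences:

```python
def clean_dialogue(text: str, main_name: str) -> str:
    return text.strip()  # stand-in for talemate.util.clean_dialogue


def normalize(result: str, name: str, fmt: str) -> str:
    if fmt != "narrative":
        # strip movie-script ("KAIRA\n") and chat ("Kaira:") markers first
        result = result.replace(f"{name.upper()}\n", "")
        result = result.replace(f"{name}:", "")
    result = clean_dialogue(result, main_name=name)
    if not result.startswith(f"{name}:"):
        result = f"{name}: {result}"
    return result


print(normalize('KAIRA\n"Yes, Captain."', "Kaira", "movie_script"))
print(normalize("She nodded slowly.", "Kaira", "narrative"))
# both print a line starting with "Kaira: "
```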
@@ -499,9 +514,6 @@
    def allow_repetition_break(
        self, kind: str, agent_function_name: str, auto: bool = False
    ):
        if auto and not self.actions["auto_break_repetition"].enabled:
            return False

        return agent_function_name == "converse"

    def inject_prompt_paramters(

@@ -39,7 +39,7 @@ class CreatorAgent(

    def __init__(
        self,
        client: client.ClientBase,
        client: client.ClientBase | None = None,
        **kwargs,
    ):
        self.client = client

@@ -21,7 +21,7 @@
"num": 3
},
"x": 39,
"y": 507,
"y": 236,
"width": 210,
"height": 154,
"collapsed": false,
@@ -40,7 +40,7 @@
"num": 1
},
"x": 38,
"y": 107,
"y": -164,
"width": 210,
"height": 154,
"collapsed": false,
@@ -59,7 +59,7 @@
"num": 0
},
"x": 38,
"y": -93,
"y": -364,
"width": 210,
"height": 154,
"collapsed": false,
@@ -76,7 +76,7 @@
"num": 0
},
"x": 288,
"y": -63,
"y": -334,
"width": 210,
"height": 106,
"collapsed": true,
@@ -89,7 +89,7 @@
"id": "553125be-2c2b-4404-98b5-d6333a4f9655",
"properties": {},
"x": 348,
"y": 507,
"y": 236,
"width": 140,
"height": 26,
"collapsed": false,
@@ -102,7 +102,7 @@
"id": "207e357e-5e83-4d40-a331-d0041b9dfa49",
"properties": {},
"x": 348,
"y": 307,
"y": 36,
"width": 140,
"height": 26,
"collapsed": false,
@@ -115,7 +115,7 @@
"id": "bad7d2ed-f6fa-452b-8b47-d753ae7a45e0",
"properties": {},
"x": 348,
"y": 107,
"y": -164,
"width": 140,
"height": 26,
"collapsed": false,
@@ -131,7 +131,7 @@
"scope": "local"
},
"x": 598,
"y": 107,
"y": -164,
"width": 210,
"height": 122,
"collapsed": false,
@@ -147,7 +147,7 @@
"scope": "local"
},
"x": 598,
"y": 307,
"y": 36,
"width": 210,
"height": 122,
"collapsed": false,
@@ -163,7 +163,7 @@
"scope": "local"
},
"x": 598,
"y": 507,
"y": 236,
"width": 210,
"height": 122,
"collapsed": false,
@@ -176,7 +176,7 @@
"id": "1765a51c-82ac-4d96-9c60-d0adc7faaa68",
"properties": {},
"x": 498,
"y": 707,
"y": 436,
"width": 171,
"height": 26,
"collapsed": false,
@@ -192,7 +192,7 @@
"scope": "local"
},
"x": 798,
"y": 706,
"y": 435,
"width": 244,
"height": 122,
"collapsed": false,
@@ -205,7 +205,7 @@
"id": "f3dc37df-1748-4235-b575-60e55b8bec73",
"properties": {},
"x": 498,
"y": 906,
"y": 635,
"width": 171,
"height": 26,
"collapsed": false,
@@ -220,7 +220,7 @@
"stage": 0
},
"x": 1208,
"y": 366,
"y": 95,
"width": 210,
"height": 118,
"collapsed": true,
@@ -238,7 +238,7 @@
"icon": "F1719"
},
"x": 1178,
"y": -144,
"y": -415,
"width": 210,
"height": 130,
"collapsed": false,
@@ -253,7 +253,7 @@
"default": true
},
"x": 298,
"y": 736,
"y": 465,
"width": 210,
"height": 58,
"collapsed": true,
@@ -268,7 +268,7 @@
"default": false
},
"x": 298,
"y": 936,
"y": 665,
"width": 210,
"height": 58,
"collapsed": true,
@@ -287,7 +287,7 @@
"num": 2
},
"x": 38,
"y": 306,
"y": 35,
"width": 210,
"height": 154,
"collapsed": false,
@@ -303,7 +303,7 @@
"scope": "local"
},
"x": 798,
"y": 906,
"y": 635,
"width": 244,
"height": 122,
"collapsed": false,
@@ -322,7 +322,7 @@
"num": 5
},
"x": 38,
"y": 706,
"y": 435,
"width": 237,
"height": 154,
"collapsed": false,
@@ -341,7 +341,7 @@
"num": 7
},
"x": 38,
"y": 906,
"y": 635,
"width": 210,
"height": 154,
"collapsed": false,
@@ -360,7 +360,7 @@
"num": 4
},
"x": 48,
"y": 1326,
"y": 1055,
"width": 237,
"height": 154,
"collapsed": false,
@@ -375,7 +375,7 @@
"default": true
},
"x": 318,
"y": 1356,
"y": 1085,
"width": 210,
"height": 58,
"collapsed": true,
@@ -388,7 +388,7 @@
"id": "6d387c67-6b32-4435-b984-4760f0f1f8d2",
"properties": {},
"x": 498,
"y": 1336,
"y": 1065,
"width": 171,
"height": 26,
"collapsed": false,
@@ -404,7 +404,7 @@
"scope": "local"
},
"x": 798,
"y": 1316,
"y": 1045,
"width": 244,
"height": 122,
"collapsed": false,
@@ -455,7 +455,7 @@
"num": 6
},
"x": 38,
"y": 1106,
"y": 835,
"width": 252,
"height": 154,
"collapsed": false,
@@ -471,7 +471,7 @@
"apply_on_unresolved": true
},
"x": 518,
"y": 1136,
"y": 865,
"width": 210,
"height": 102,
"collapsed": true,
@@ -1101,7 +1101,7 @@
"num": 8
},
"x": 49,
"y": 1546,
"y": 1275,
"width": 210,
"height": 154,
"collapsed": false,
@@ -1116,7 +1116,7 @@
"default": false
},
"x": 329,
"y": 1586,
"y": 1315,
"width": 210,
"height": 58,
"collapsed": true,
@@ -1129,7 +1129,7 @@
"id": "8acfe789-fbb5-4e29-8fd8-2217b987c086",
"properties": {},
"x": 499,
"y": 1566,
"y": 1295,
"width": 171,
"height": 26,
"collapsed": false,
@@ -1145,7 +1145,7 @@
"scope": "local"
},
"x": 799,
"y": 1526,
"y": 1255,
"width": 244,
"height": 122,
"collapsed": false,
@@ -1258,8 +1258,8 @@
"output_name": "character",
"num": 0
},
"x": 331,
"y": 5739,
"x": 332,
"y": 6178,
"width": 210,
"height": 106,
"collapsed": false,
@@ -1274,8 +1274,8 @@
"name": "character",
"scope": "local"
},
"x": 51,
"y": 5729,
"x": 52,
"y": 6168,
"width": 210,
"height": 122,
"collapsed": false,
@@ -1291,8 +1291,8 @@
"output_name": "actor",
"num": 0
},
"x": 331,
"y": 5949,
"x": 332,
"y": 6388,
"width": 210,
"height": 106,
"collapsed": false,
@@ -1307,8 +1307,8 @@
"name": "actor",
"scope": "local"
},
"x": 51,
"y": 5949,
"x": 52,
"y": 6388,
"width": 210,
"height": 122,
"collapsed": false,
@@ -1323,7 +1323,7 @@
"stage": 0
},
"x": 1320,
"y": 1160,
"y": 889,
"width": 210,
"height": 118,
"collapsed": true,
@@ -1414,7 +1414,7 @@
"writing_style": null
},
"x": 320,
"y": 1200,
"y": 929,
"width": 270,
"height": 122,
"collapsed": true,
@@ -1430,7 +1430,7 @@
"scope": "local"
},
"x": 790,
"y": 1110,
"y": 839,
"width": 244,
"height": 122,
"collapsed": false,
@@ -1475,6 +1475,159 @@
"inherited": false,
"registry": "agents/creator/ContextualGenerate",
"base_type": "core/Node"
},
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1": {
"title": "assign_voice",
"id": "d61de1ad-6f2a-447f-918a-dce7e76ea3a1",
"properties": {},
"x": 509,
"y": 1544,
"width": 171,
"height": 26,
"collapsed": false,
"inherited": false,
"registry": "core/Watch",
"base_type": "core/Node"
},
"3d655827-b66b-4355-910d-96097e7f2f13": {
"title": "SET local.assign_voice",
"id": "3d655827-b66b-4355-910d-96097e7f2f13",
"properties": {
"name": "assign_voice",
"scope": "local"
},
"x": 810,
"y": 1506,
"width": 244,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/SetState",
"base_type": "core/Node"
},
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1": {
"title": "Stage 0",
"id": "6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1",
"properties": {
"stage": 0
},
"x": 1170,
"y": 1546,
"width": 210,
"height": 118,
"collapsed": true,
"inherited": false,
"registry": "core/Stage",
"base_type": "core/Node"
},
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f": {
"title": "IN assign_voice",
"id": "9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f",
"properties": {
"input_type": "bool",
"input_name": "assign_voice",
"input_optional": true,
"input_group": "",
"num": 9
},
"x": 60,
"y": 1527,
"width": 210,
"height": 154,
"collapsed": false,
"inherited": false,
"registry": "core/Input",
"base_type": "core/Node"
},
"de11206d-13db-44a5-befd-de559fb68d09": {
"title": "GET local.assign_voice",
"id": "de11206d-13db-44a5-befd-de559fb68d09",
"properties": {
"name": "assign_voice",
"scope": "local"
},
"x": 25,
"y": 5720,
"width": 240,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"97b196a3-e7a6-4cfa-905e-686a744890b7": {
"title": "Switch",
"id": "97b196a3-e7a6-4cfa-905e-686a744890b7",
"properties": {
"pass_through": true
},
"x": 355,
"y": 5740,
"width": 210,
"height": 78,
"collapsed": false,
"inherited": false,
"registry": "core/Switch",
"base_type": "core/Node"
},
"3956711a-a3df-4213-b739-104cbe704964": {
"title": "GET local.character",
"id": "3956711a-a3df-4213-b739-104cbe704964",
"properties": {
"name": "character",
"scope": "local"
},
"x": 25,
"y": 5930,
"width": 210,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"47b2b492-4178-4254-b468-3877a5341f66": {
"title": "Assign Voice",
"id": "47b2b492-4178-4254-b468-3877a5341f66",
"properties": {},
"x": 665,
"y": 5830,
"width": 161,
"height": 66,
"collapsed": false,
"inherited": false,
"registry": "agents/director/AssignVoice",
"base_type": "core/Node"
},
"44d78795-2d41-468b-94d5-399b9b655888": {
"title": "Stage 6",
"id": "44d78795-2d41-468b-94d5-399b9b655888",
"properties": {
"stage": 6
},
"x": 885,
"y": 5860,
"width": 210,
"height": 118,
"collapsed": true,
"inherited": false,
"registry": "core/Stage",
"base_type": "core/Node"
},
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85": {
"title": "As Bool",
"id": "f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85",
"properties": {
"default": true
},
"x": 330,
"y": 1560,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/AsBool",
"base_type": "core/Node"
}
},
"edges": {
@@ -1726,15 +1879,39 @@
],
"4acb67ea-68ee-43ae-a8c6-98a2b0e0f053.text": [
"d41f0d98-14d5-49dd-8e57-7812fb9fee94.value"
],
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value": [
"3d655827-b66b-4355-910d-96097e7f2f13.value"
],
"3d655827-b66b-4355-910d-96097e7f2f13.value": [
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1.state"
],
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f.value": [
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value"
],
"de11206d-13db-44a5-befd-de559fb68d09.value": [
"97b196a3-e7a6-4cfa-905e-686a744890b7.value"
],
"97b196a3-e7a6-4cfa-905e-686a744890b7.yes": [
"47b2b492-4178-4254-b468-3877a5341f66.state"
],
"3956711a-a3df-4213-b739-104cbe704964.value": [
"47b2b492-4178-4254-b468-3877a5341f66.character"
],
"47b2b492-4178-4254-b468-3877a5341f66.state": [
"44d78795-2d41-468b-94d5-399b9b655888.state"
],
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value": [
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value"
]
},
"groups": [
{
"title": "Process Arguments - Stage 0",
"x": 1,
"y": -218,
"width": 1446,
"height": 1948,
"y": -490,
"width": 1432,
"height": 2216,
"color": "#3f789e",
"font_size": 24,
"inherited": false
@@ -1792,12 +1969,22 @@
{
"title": "Outputs",
"x": 0,
"y": 5642,
"y": 6080,
"width": 595,
"height": 472,
"color": "#8A8",
"font_size": 24,
"inherited": false
},
{
"title": "Assign Voice - Stage 6",
"x": 0,
"y": 5640,
"width": 1120,
"height": 437,
"color": "#3f789e",
"font_size": 24,
"inherited": false
}
],
"comments": [],
@@ -1805,6 +1992,7 @@
"base_type": "core/Graph",
"inputs": [],
"outputs": [],
"module_properties": {},
"style": {
"title_color": "#572e44",
"node_color": "#392c34",

@@ -1,33 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List

import structlog
import traceback

import talemate.instance as instance
from talemate.emit import emit
from talemate.scene_message import DirectorMessage
from talemate.util import random_color
from talemate.character import deactivate_character
from talemate.status import LoadingStatus
from talemate.exceptions import GenerationCancelled
from talemate.scene_message import DirectorMessage, Flags

from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
from talemate.agents.base import Agent, AgentAction, AgentActionConfig
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.client import ClientBase
from talemate.game.focal.schema import Call

from .guide import GuideSceneMixin
from .generate_choices import GenerateChoicesMixin
from .legacy_scene_instructions import LegacySceneInstructionsMixin
from .auto_direct import AutoDirectMixin
from .websocket_handler import DirectorWebsocketHandler

from .character_management import CharacterManagementMixin
import talemate.agents.director.nodes # noqa: F401

if TYPE_CHECKING:
    from talemate import Character, Scene

log = structlog.get_logger("talemate.agent.director")


@@ -38,6 +29,7 @@ class DirectorAgent(
    GenerateChoicesMixin,
    AutoDirectMixin,
    LegacySceneInstructionsMixin,
    CharacterManagementMixin,
    Agent,
):
    agent_type = "director"
@@ -76,9 +68,10 @@ class DirectorAgent(
        GenerateChoicesMixin.add_actions(actions)
        GuideSceneMixin.add_actions(actions)
        AutoDirectMixin.add_actions(actions)
        CharacterManagementMixin.add_actions(actions)
        return actions

    def __init__(self, client, **kwargs):
    def __init__(self, client: ClientBase | None = None, **kwargs):
        self.is_enabled = True
        self.client = client
        self.next_direct_character = {}
@@ -101,178 +94,24 @@
    def actor_direction_mode(self):
        return self.actions["direct"].config["actor_direction_mode"].value

    @set_processing
    async def persist_characters_from_worldstate(
        self, exclude: list[str] = None
    ) -> List[Character]:
        created_characters = []

        for character_name in self.scene.world_state.characters.keys():
            if exclude and character_name.lower() in exclude:
                continue

            if character_name in self.scene.character_names:
                continue

            character = await self.persist_character(name=character_name)

            created_characters.append(character)

        self.scene.emit_status()

        return created_characters

    @set_processing
    async def persist_character(
        self,
        name: str,
        content: str = None,
        attributes: str = None,
        determine_name: bool = True,
        templates: list[str] = None,
        active: bool = True,
        narrate_entry: bool = False,
        narrate_entry_direction: str = "",
        augment_attributes: str = "",
        generate_attributes: bool = True,
        description: str = "",
    ) -> Character:
        world_state = instance.get_agent("world_state")
        creator = instance.get_agent("creator")
        narrator = instance.get_agent("narrator")
        memory = instance.get_agent("memory")
        scene: "Scene" = self.scene
        any_attribute_templates = False

        loading_status = LoadingStatus(max_steps=None, cancellable=True)

        # Start of character creation
        log.debug("persist_character", name=name)

        # Determine the character's name (or clarify if it's already set)
        if determine_name:
            loading_status("Determining character name")
            name = await creator.determine_character_name(name, instructions=content)
            log.debug("persist_character", adjusted_name=name)

        # Create the blank character
        character: Character = self.scene.Character(name=name)

        # Add the character to the scene
        character.color = random_color()
        actor = self.scene.Actor(
            character=character, agent=instance.get_agent("conversation")
    async def log_function_call(self, call: Call):
        log.debug("director.log_function_call", call=call)
        message = DirectorMessage(
            message=f"Called {call.name}",
            action=call.name,
            flags=Flags.HIDDEN,
            subtype="function_call",
        )
        await self.scene.add_actor(actor)
        emit("director", message, data={"function_call": call.model_dump()})

        try:
            # Apply any character generation templates
            if templates:
                loading_status("Applying character generation templates")
                templates = scene.world_state_manager.template_collection.collect_all(
                    templates
                )
                log.debug("persist_character", applying_templates=templates)
                await scene.world_state_manager.apply_templates(
                    templates.values(),
                    character_name=character.name,
                    information=content,
                )

                # if any of the templates are attribute templates, then we no longer need to
                # generate a character sheet
                any_attribute_templates = any(
                    template.template_type == "character_attribute"
                    for template in templates.values()
                )
                log.debug(
                    "persist_character", any_attribute_templates=any_attribute_templates
                )

                if (
                    any_attribute_templates
                    and augment_attributes
                    and generate_attributes
                ):
                    log.debug(
                        "persist_character", augmenting_attributes=augment_attributes
                    )
                    loading_status("Augmenting character attributes")
                    additional_attributes = await world_state.extract_character_sheet(
                        name=name,
                        text=content,
                        augmentation_instructions=augment_attributes,
                    )
                    character.base_attributes.update(additional_attributes)

            # Generate a character sheet if there are no attribute templates
            if not any_attribute_templates and generate_attributes:
                loading_status("Generating character sheet")
                log.debug("persist_character", extracting_character_sheet=True)
                if not attributes:
                    attributes = await world_state.extract_character_sheet(
                        name=name, text=content
                    )
                else:
                    attributes = world_state._parse_character_sheet(attributes)

                log.debug("persist_character", attributes=attributes)
                character.base_attributes = attributes

            # Generate a description for the character
            if not description:
                loading_status("Generating character description")
                description = await creator.determine_character_description(
                    character, information=content
                )
            character.description = description
            log.debug("persist_character", description=description)

            # Generate a dialogue instructions for the character
            loading_status("Generating acting instructions")
            dialogue_instructions = (
                await creator.determine_character_dialogue_instructions(
                    character, information=content
                )
            )
            character.dialogue_instructions = dialogue_instructions
            log.debug("persist_character", dialogue_instructions=dialogue_instructions)

            # Narrate the character's entry if the option is selected
            if active and narrate_entry:
                loading_status("Narrating character entry")
                is_present = await world_state.is_character_present(name)
                if not is_present:
                    await narrator.action_to_narration(
                        "narrate_character_entry",
                        emit_message=True,
                        character=character,
                        narrative_direction=narrate_entry_direction,
                    )

            # Deactivate the character if not active
            if not active:
                await deactivate_character(scene, character)

            # Commit the character's details to long term memory
            await character.commit_to_memory(memory)
            self.scene.emit_status()
            self.scene.world_state.emit()

            loading_status.done(
                message=f"{character.name} added to scene", status="success"
            )
            return character
        except GenerationCancelled:
            loading_status.done(message="Character creation cancelled", status="idle")
            await scene.remove_actor(actor)
        except Exception:
            loading_status.done(message="Character creation failed", status="error")
            await scene.remove_actor(actor)
            log.error("Error persisting character", error=traceback.format_exc())

    async def log_action(self, action: str, action_description: str):
        message = DirectorMessage(message=action_description, action=action)
    async def log_action(
        self, action: str, action_description: str, console_only: bool = False
    ):
        message = DirectorMessage(
            message=action_description,
            action=action,
            flags=Flags.HIDDEN if console_only else Flags.NONE,
        )
        self.scene.push_history(message)
        emit("director", message)

src/talemate/agents/director/character_management.py (new file, 333 additions)
@@ -0,0 +1,333 @@
from typing import TYPE_CHECKING
import traceback
import structlog
import talemate.instance as instance
import talemate.agents.tts.voice_library as voice_library
from talemate.agents.tts.schema import Voice
from talemate.util import random_color
from talemate.character import deactivate_character, set_voice
from talemate.status import LoadingStatus
from talemate.exceptions import GenerationCancelled
from talemate.agents.base import AgentAction, AgentActionConfig, set_processing
import talemate.game.focal as focal


__all__ = [
    "CharacterManagementMixin",
]

log = structlog.get_logger()

if TYPE_CHECKING:
    from talemate import Character, Scene
    from talemate.agents.tts import TTSAgent


class VoiceCandidate(Voice):
    used: bool = False
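`VoiceCandidate` extends the TTS `Voice` schema with a `used` flag, letting the voice-assignment prompt mark voices already taken in the scene (the changelog's "fix voice assign template to consider used voices"). A minimal sketch of filtering candidates before prompting; the `id` and `label` fields are assumptions about the `Voice` schema, not confirmed API:

```python
# Illustrative only: prefer voices that are not yet assigned in the scene.
candidates = [
    VoiceCandidate(id="v1", label="Warm narrator", used=True),   # field names assumed
    VoiceCandidate(id="v2", label="Young officer", used=False),
]
available = [c for c in candidates if not c.used] or candidates  # fall back if all taken
```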
class CharacterManagementMixin:
    """
    Director agent mixin that provides functionality for automatically guiding
    the actors or the narrator during the scene progression.
    """

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        actions["character_management"] = AgentAction(
            enabled=True,
            container=True,
            can_be_disabled=False,
            label="Character Management",
            icon="mdi-account",
            description="Configure how the director manages characters.",
            config={
                "assign_voice": AgentActionConfig(
                    type="bool",
                    label="Assign Voice (TTS)",
                    description="If enabled, the director is allowed to assign a text-to-speech voice when persisting a character.",
                    value=True,
                    title="Persisting Characters",
                ),
            },
        )

    # config property helpers

    @property
    def cm_assign_voice(self) -> bool:
        return self.actions["character_management"].config["assign_voice"].value

    @property
    def cm_should_assign_voice(self) -> bool:
        if not self.cm_assign_voice:
            return False

        tts_agent: "TTSAgent" = instance.get_agent("tts")
        if not tts_agent.enabled:
            return False

        if not tts_agent.ready_apis:
            return False

        return True
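`cm_should_assign_voice` layers three cheap checks (user setting on, TTS agent enabled, at least one TTS API ready) so the director can skip voice assignment without spending any generation budget. The same short-circuit pattern in isolation; `FeatureGate` is an illustrative name, not a talemate API:

```python
class FeatureGate:
    """Short-circuiting gate: cheapest checks run first, any one can veto."""

    def __init__(self, setting_on: bool, agent_enabled: bool, ready_apis: list[str]):
        self.setting_on = setting_on
        self.agent_enabled = agent_enabled
        self.ready_apis = ready_apis

    @property
    def allowed(self) -> bool:
        if not self.setting_on:
            return False
        if not self.agent_enabled:
            return False
        if not self.ready_apis:
            return False
        return True


assert FeatureGate(True, True, ["elevenlabs"]).allowed
assert not FeatureGate(True, True, []).allowed  # no ready API vetoes the feature
```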
    # actions

    @set_processing
    async def persist_characters_from_worldstate(
        self, exclude: list[str] = None
    ) -> list["Character"]:
        created_characters = []

        for character_name in self.scene.world_state.characters.keys():
            if exclude and character_name.lower() in exclude:
                continue

            if character_name in self.scene.character_names:
                continue

            character = await self.persist_character(name=character_name)

            created_characters.append(character)

        self.scene.emit_status()

        return created_characters

    @set_processing
    async def persist_character(
        self,
        name: str,
        content: str = None,
        attributes: str = None,
        determine_name: bool = True,
        templates: list[str] = None,
        active: bool = True,
        narrate_entry: bool = False,
        narrate_entry_direction: str = "",
        augment_attributes: str = "",
        generate_attributes: bool = True,
        description: str = "",
        assign_voice: bool = True,
        is_player: bool = False,
    ) -> "Character":
        world_state = instance.get_agent("world_state")
        creator = instance.get_agent("creator")
        narrator = instance.get_agent("narrator")
        memory = instance.get_agent("memory")
        scene: "Scene" = self.scene
        any_attribute_templates = False

        loading_status = LoadingStatus(max_steps=None, cancellable=True)

        # Start of character creation
        log.debug("persist_character", name=name)

        # Determine the character's name (or clarify if it's already set)
        if determine_name:
            loading_status("Determining character name")
            name = await creator.determine_character_name(name, instructions=content)
            log.debug("persist_character", adjusted_name=name)

        # Create the blank character
        character: "Character" = self.scene.Character(name=name, is_player=is_player)

        # Add the character to the scene
        character.color = random_color()

        if is_player:
            actor = self.scene.Player(
                character=character, agent=instance.get_agent("conversation")
            )
        else:
            actor = self.scene.Actor(
                character=character, agent=instance.get_agent("conversation")
            )

        await self.scene.add_actor(actor)

        try:
            # Apply any character generation templates
            if templates:
                loading_status("Applying character generation templates")
                templates = scene.world_state_manager.template_collection.collect_all(
                    templates
                )
                log.debug("persist_character", applying_templates=templates)
                await scene.world_state_manager.apply_templates(
                    templates.values(),
                    character_name=character.name,
                    information=content,
                )

                # if any of the templates are attribute templates, then we no longer need to
                # generate a character sheet
                any_attribute_templates = any(
                    template.template_type == "character_attribute"
                    for template in templates.values()
                )
                log.debug(
                    "persist_character", any_attribute_templates=any_attribute_templates
                )

                if (
                    any_attribute_templates
                    and augment_attributes
                    and generate_attributes
                ):
                    log.debug(
                        "persist_character", augmenting_attributes=augment_attributes
                    )
                    loading_status("Augmenting character attributes")
                    additional_attributes = await world_state.extract_character_sheet(
                        name=name,
                        text=content,
                        augmentation_instructions=augment_attributes,
                    )
                    character.base_attributes.update(additional_attributes)

            # Generate a character sheet if there are no attribute templates
            if not any_attribute_templates and generate_attributes:
                loading_status("Generating character sheet")
                log.debug("persist_character", extracting_character_sheet=True)
                if not attributes:
                    attributes = await world_state.extract_character_sheet(
                        name=name, text=content
                    )
                else:
                    attributes = world_state._parse_character_sheet(attributes)

                log.debug("persist_character", attributes=attributes)
                character.base_attributes = attributes

            # Generate a description for the character
|
||||
if not description:
|
||||
loading_status("Generating character description")
|
||||
description = await creator.determine_character_description(
|
||||
character, information=content
|
||||
)
|
||||
character.description = description
|
||||
log.debug("persist_character", description=description)
|
||||
|
||||
# Generate a dialogue instructions for the character
|
||||
loading_status("Generating acting instructions")
|
||||
dialogue_instructions = (
|
||||
await creator.determine_character_dialogue_instructions(
|
||||
character, information=content
|
||||
)
|
||||
)
|
||||
character.dialogue_instructions = dialogue_instructions
|
||||
log.debug("persist_character", dialogue_instructions=dialogue_instructions)
|
||||
|
||||
# Narrate the character's entry if the option is selected
|
||||
if active and narrate_entry:
|
||||
loading_status("Narrating character entry")
|
||||
is_present = await world_state.is_character_present(name)
|
||||
if not is_present:
|
||||
await narrator.action_to_narration(
|
||||
"narrate_character_entry",
|
||||
emit_message=True,
|
||||
character=character,
|
||||
narrative_direction=narrate_entry_direction,
|
||||
)
|
||||
|
||||
if assign_voice:
|
||||
await self.assign_voice_to_character(character)
|
||||
|
||||
# Deactivate the character if not active
|
||||
if not active:
|
||||
await deactivate_character(scene, character)
|
||||
|
||||
# Commit the character's details to long term memory
|
||||
await character.commit_to_memory(memory)
|
||||
self.scene.emit_status()
|
||||
self.scene.world_state.emit()
|
||||
|
||||
loading_status.done(
|
||||
message=f"{character.name} added to scene", status="success"
|
||||
)
|
||||
return character
|
||||
except GenerationCancelled:
|
||||
loading_status.done(message="Character creation cancelled", status="idle")
|
||||
await scene.remove_actor(actor)
|
||||
except Exception:
|
||||
loading_status.done(message="Character creation failed", status="error")
|
||||
await scene.remove_actor(actor)
|
||||
log.error("Error persisting character", error=traceback.format_exc())
|
||||
|
||||
@set_processing
|
||||
async def assign_voice_to_character(
|
||||
self, character: "Character"
|
||||
) -> list[focal.Call]:
|
||||
tts_agent: "TTSAgent" = instance.get_agent("tts")
|
||||
if not self.cm_should_assign_voice:
|
||||
log.debug("assign_voice_to_character", skip=True, reason="not enabled")
|
||||
return
|
||||
|
||||
vl: voice_library.VoiceLibrary = voice_library.get_instance()
|
||||
|
||||
ready_tts_apis = tts_agent.ready_apis
|
||||
|
||||
voices_global = voice_library.voices_for_apis(ready_tts_apis, vl)
|
||||
voices_scene = voice_library.voices_for_apis(
|
||||
ready_tts_apis, self.scene.voice_library
|
||||
)
|
||||
|
||||
voices = voices_global + voices_scene
|
||||
|
||||
if not voices:
|
||||
log.debug(
|
||||
"assign_voice_to_character", skip=True, reason="no voices available"
|
||||
)
|
||||
return
|
||||
|
||||
voice_candidates = {
|
||||
voice.id: VoiceCandidate(**voice.model_dump()) for voice in voices
|
||||
}
|
||||
|
||||
for scene_character in self.scene.all_characters:
|
||||
if scene_character.voice:
|
||||
voice_candidates[scene_character.voice.id].used = True
|
||||
|
||||
async def assign_voice(voice_id: str):
|
||||
voice = vl.get_voice(voice_id) or self.scene.voice_library.get_voice(
|
||||
voice_id
|
||||
)
|
||||
if not voice:
|
||||
log.error(
|
||||
"assign_voice_to_character",
|
||||
skip=True,
|
||||
reason="voice not found",
|
||||
voice_id=voice_id,
|
||||
)
|
||||
return
|
||||
await set_voice(character, voice, auto=True)
|
||||
await self.log_action(
|
||||
f"Assigned voice `{voice.label}` to `{character.name}`",
|
||||
"Assigned voice",
|
||||
console_only=True,
|
||||
)
|
||||
|
||||
focal_handler = focal.Focal(
|
||||
self.client,
|
||||
callbacks=[
|
||||
focal.Callback(
|
||||
name="assign_voice",
|
||||
arguments=[focal.Argument(name="voice_id", type="str")],
|
||||
fn=assign_voice,
|
||||
),
|
||||
],
|
||||
max_calls=1,
|
||||
character=character,
|
||||
voices=list(voice_candidates.values()),
|
||||
scene=self.scene,
|
||||
narrator_voice=tts_agent.narrator_voice,
|
||||
)
|
||||
|
||||
await focal_handler.request("director.cm-assign-voice")
|
||||
|
||||
log.debug("assign_voice_to_character", calls=focal_handler.state.calls)
|
||||
|
||||
return focal_handler.state.calls
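The voice-assignment flow above reduces to: gather voices from the global and scene libraries, flag the ones already claimed by other characters, then let the LLM pick via a single focal callback. A minimal sketch of the candidate-marking step, reusing the models from the mixin (the surrounding scene objects are illustrative stand-ins, not the actual Talemate API):

    # illustrative only -- assumes Voice exposes id/label and model_dump()
    candidates = {v.id: VoiceCandidate(**v.model_dump()) for v in voices}
    for scene_character in scene.all_characters:
        if scene_character.voice:
            # an already-used voice stays selectable but is flagged so the
            # prompt can steer the model toward unused voices first
            candidates[scene_character.voice.id].used = True

Because max_calls=1, the focal handler invokes assign_voice at most once per request, so a single generation can never assign multiple voices.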
@@ -231,7 +231,9 @@ class GuideSceneMixin:
        if cached_guidance:
            if not analysis:
                return cached_guidance.get("guidance")
-           elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]):
+           elif cached_guidance.get("fp") == self.context_fingerprint(
+               extra=[analysis]
+           ):
                return cached_guidance.get("guidance")

        return None
@@ -250,7 +252,7 @@ class GuideSceneMixin:
        self.set_scene_states(
            **{
                key: {
-                   "fp": self.context_fingerpint(extra=[analysis]),
+                   "fp": self.context_fingerprint(extra=[analysis]),
                    "guidance": guidance,
                    "analysis_type": analysis_type,
                    "character": character.name if character else None,
@@ -7,7 +7,7 @@ from talemate.game.engine.nodes.core import (
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode

from talemate.character import Character

TYPE_CHOICES.extend(
    [
@@ -77,3 +77,90 @@ class PersistCharacter(AgentNode):
        )

        self.set_output_values({"state": state, "character": character})


@register("agents/director/AssignVoice")
class AssignVoice(AgentNode):
    """
    Assigns a voice to a character.
    """

    _agent_name: ClassVar[str] = "director"

    def __init__(self, title="Assign Voice", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("state")
        self.add_input("character", socket_type="character")

        self.add_output("state")
        self.add_output("character", socket_type="character")
        self.add_output("voice", socket_type="tts/voice")

    async def run(self, state: GraphState):
        character: "Character" = self.require_input("character")

        await self.agent.assign_voice_to_character(character)

        voice = character.voice

        self.set_output_values({"state": state, "character": character, "voice": voice})
@register("agents/director/LogAction")
class LogAction(AgentNode):
    """
    Logs an action to the console.
    """

    _agent_name: ClassVar[str] = "director"

    class Fields:
        action = PropertyField(
            name="action",
            type="str",
            description="The action to log",
            default="",
        )
        action_description = PropertyField(
            name="action_description",
            type="str",
            description="The description of the action",
            default="",
        )
        console_only = PropertyField(
            name="console_only",
            type="bool",
            description="Whether to log the action to the console only",
            default=False,
        )

    def __init__(self, title="Log Director Action", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("state")
        self.add_input("action", socket_type="str")
        self.add_input("action_description", socket_type="str")
        self.add_input("console_only", socket_type="bool", optional=True)

        self.set_property("action", "")
        self.set_property("action_description", "")
        self.set_property("console_only", False)

        self.add_output("state")

    async def run(self, state: GraphState):
        state = self.require_input("state")
        action = self.require_input("action")
        action_description = self.require_input("action_description")
        console_only = self.normalized_input_value("console_only") or False

        await self.agent.log_action(
            action=action,
            action_description=action_description,
            console_only=console_only,
        )

        self.set_output_values({"state": state})
@@ -5,8 +5,9 @@ from typing import TYPE_CHECKING

from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
-from talemate.context import interaction
+from talemate.context import interaction, handle_generation_cancelled
from talemate.status import set_loading
+from talemate.exceptions import GenerationCancelled

if TYPE_CHECKING:
    from talemate.tale_mate import Scene
@@ -45,6 +46,12 @@ class PersistCharacterPayload(pydantic.BaseModel):
    content: str = ""
    description: str = ""

+   is_player: bool = False
+
+
+class AssignVoiceToCharacterPayload(pydantic.BaseModel):
+   character_name: str


class DirectorWebsocketHandler(Plugin):
    """
@@ -105,7 +112,13 @@ class DirectorWebsocketHandler(Plugin):

        async def handle_task_done(task):
            if task.exception():
-               log.error("Error persisting character", error=task.exception())
+               exc = task.exception()
+               log.error("Error persisting character", error=exc)
+
+               # Handle GenerationCancelled properly to reset cancel_requested flag
+               if isinstance(exc, GenerationCancelled):
+                   handle_generation_cancelled(exc)
+
                await self.signal_operation_failed("Error persisting character")
            else:
                self.websocket_handler.queue_put(
@@ -118,3 +131,63 @@ class DirectorWebsocketHandler(Plugin):
                await self.signal_operation_done()

        task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))

    async def handle_assign_voice_to_character(self, data: dict):
        """
        Assign a voice to a character using the director agent
        """
        try:
            payload = AssignVoiceToCharacterPayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        scene: "Scene" = self.scene
        if not scene:
            await self.signal_operation_failed("No scene active")
            return

        character = scene.get_character(payload.character_name)
        if not character:
            await self.signal_operation_failed(
                f"Character '{payload.character_name}' not found"
            )
            return

        character.voice = None

        # Add as asyncio task
        task = asyncio.create_task(self.director.assign_voice_to_character(character))

        async def handle_task_done(task):
            if task.exception():
                exc = task.exception()
                log.error("Error assigning voice to character", error=exc)

                # Handle GenerationCancelled properly to reset cancel_requested flag
                if isinstance(exc, GenerationCancelled):
                    handle_generation_cancelled(exc)

                self.websocket_handler.queue_put(
                    {
                        "type": self.router,
                        "action": "assign_voice_to_character_failed",
                        "character_name": payload.character_name,
                        "error": str(exc),
                    }
                )
                await self.signal_operation_failed(
                    f"Error assigning voice to character: {exc}"
                )
            else:
                self.websocket_handler.queue_put(
                    {
                        "type": self.router,
                        "action": "assign_voice_to_character_done",
                        "character_name": payload.character_name,
                    }
                )
                await self.signal_operation_done()
                self.scene.emit_status()

        task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
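Note the add_done_callback(lambda task: asyncio.create_task(...)) construction used in both handlers: done callbacks must be synchronous, so the async handler is scheduled as a fresh task from inside the callback. A self-contained sketch of the pattern:

    import asyncio

    async def work() -> str:
        return "done"

    async def on_done(task: asyncio.Task):
        # runs as its own task, so awaiting here is safe
        if task.exception():
            print("failed:", task.exception())
        else:
            print("result:", task.result())

    async def main():
        task = asyncio.create_task(work())
        # the callback itself must be sync; it only schedules the async handler
        task.add_done_callback(lambda t: asyncio.create_task(on_done(t)))
        await asyncio.sleep(0.1)

    asyncio.run(main())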
@@ -6,6 +6,7 @@ import structlog

import talemate.emit.async_signals
import talemate.util as util
+from talemate.client import ClientBase
from talemate.prompts import Prompt

from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
@@ -86,7 +87,7 @@ class EditorAgent(
        RevisionMixin.add_actions(actions)
        return actions

-   def __init__(self, client, **kwargs):
+   def __init__(self, client: ClientBase | None = None, **kwargs):
        self.client = client
        self.is_enabled = True
        self.actions = EditorAgent.init_actions()

@@ -60,5 +60,8 @@ class EditorWebsocketHandler(Plugin):
            character=character,
        )
        revised = await editor.revision_revise(info)
+       if isinstance(message, CharacterMessage):
+           if not revised.startswith(character.name + ":"):
+               revised = f"{character.name}: {revised}"

        scene.edit_message(message.id, revised)
@@ -23,10 +23,9 @@ from talemate.agents.base import (
    AgentDetail,
    set_processing,
)
from talemate.config import load_config
from talemate.config.schema import EmbeddingFunctionPreset
from talemate.context import scene_is_loading, active_scene
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.emit.async_signals as async_signals
from talemate.agents.memory.context import memory_request, MemoryRequest
from talemate.agents.memory.exceptions import (
@@ -107,14 +106,13 @@ class MemoryAgent(Agent):
        }
        return actions

-   def __init__(self, scene, **kwargs):
+   def __init__(self, **kwargs):
        self.db = None
-       self.scene = scene
        self.memory_tracker = {}
        self.config = load_config()
        self._ready_to_add = False

        handlers["config_saved"].connect(self.on_config_saved)
        async_signals.get("config.changed").connect(self.on_config_changed)

        async_signals.get("client.embeddings_available").connect(
            self.on_client_embeddings_available
        )
@@ -136,28 +134,29 @@ class MemoryAgent(Agent):

    @property
    def get_presets(self):
-       def _label(embedding: dict):
-           prefix = (
-               embedding["client"] if embedding["client"] else embedding["embeddings"]
-           )
-           if embedding["model"]:
-               return f"{prefix}: {embedding['model']}"
+       def _label(embedding: EmbeddingFunctionPreset):
+           prefix = embedding.client if embedding.client else embedding.embeddings
+           if embedding.model:
+               return f"{prefix}: {embedding.model}"
            else:
                return f"{prefix}"

        return [
            {"value": k, "label": _label(v)}
-           for k, v in self.config.get("presets", {}).get("embeddings", {}).items()
+           for k, v in self.config.presets.embeddings.items()
        ]

    @property
    def embeddings_config(self):
        _embeddings = self.actions["_config"].config["embeddings"].value
-       return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {})
+       return self.config.presets.embeddings.get(_embeddings)

    @property
    def embeddings(self):
-       return self.embeddings_config.get("embeddings", "sentence-transformer")
+       try:
+           return self.embeddings_config.embeddings
+       except AttributeError:
+           return None

    @property
    def using_openai_embeddings(self):
@@ -181,22 +180,31 @@ class MemoryAgent(Agent):

    @property
    def embeddings_client(self):
-       return self.embeddings_config.get("client")
+       try:
+           return self.embeddings_config.client
+       except AttributeError:
+           return None

    @property
    def max_distance(self) -> float:
-       distance = float(self.embeddings_config.get("distance", 1.0))
-       distance_mod = float(self.embeddings_config.get("distance_mod", 1.0))
+       distance = float(self.embeddings_config.distance)
+       distance_mod = float(self.embeddings_config.distance_mod)

        return distance * distance_mod

    @property
    def model(self):
-       return self.embeddings_config.get("model")
+       try:
+           return self.embeddings_config.model
+       except AttributeError:
+           return None

    @property
    def distance_function(self):
-       return self.embeddings_config.get("distance_function", "l2")
+       try:
+           return self.embeddings_config.distance_function
+       except AttributeError:
+           return None

    @property
    def device(self) -> str:
@@ -204,7 +212,10 @@ class MemoryAgent(Agent):

    @property
    def trust_remote_code(self) -> bool:
-       return self.embeddings_config.get("trust_remote_code", False)
+       try:
+           return self.embeddings_config.trust_remote_code
+       except AttributeError:
+           return False

    @property
    def fingerprint(self) -> str:
@@ -241,7 +252,7 @@ class MemoryAgent(Agent):
        if self.using_sentence_transformer_embeddings and not self.model:
            self.actions["_config"].config["embeddings"].value = "default"

-       if not scene or not scene.get_helper("memory"):
+       if not scene or not scene.active:
            return

        self.close_db(scene)
@@ -255,7 +266,14 @@ class MemoryAgent(Agent):
        self.actions["_config"].config["embeddings"].choices = self.get_presets
        return self.actions["_config"].config["embeddings"].choices

-   def on_config_saved(self, event):
+   async def fix_broken_embeddings(self):
+       if not self.embeddings_config:
+           self.actions["_config"].config["embeddings"].value = "default"
+           await self.emit_status()
+           await self.handle_embeddings_change()
+           await self.save_config()
+
+   async def on_config_changed(self, event):
        loop = asyncio.get_running_loop()
        openai_key = self.openai_api_key

@@ -263,7 +281,6 @@ class MemoryAgent(Agent):

        old_presets = self.actions["_config"].config["embeddings"].choices.copy()

-       self.config = load_config()
        new_presets = self.sync_presets()
        if fingerprint != self.fingerprint:
            log.warning(
@@ -285,10 +302,13 @@ class MemoryAgent(Agent):
        if emit_status:
            loop.run_until_complete(self.emit_status())

        await self.fix_broken_embeddings()

    async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
        current_embeddings = self.actions["_config"].config["embeddings"].value

        if current_embeddings == event.client.embeddings_identifier:
            event.seen = True
            return

        if not self.using_client_api_embeddings or not self.ready:
@@ -304,6 +324,7 @@ class MemoryAgent(Agent):
        await self.emit_status()
        await self.handle_embeddings_change()
        await self.save_config()
+       event.seen = True

    @set_processing
    async def set_db(self):
@@ -837,7 +858,7 @@ class ChromaDBMemoryAgent(MemoryAgent):

    @property
    def openai_api_key(self):
-       return self.config.get("openai", {}).get("api_key")
+       return self.config.openai.api_key

    @property
    def embedding_function(self) -> Callable:
@@ -138,11 +138,6 @@ class NarratorAgent(MemoryRAGMixin, Agent):
                ),
            },
        ),
-       "auto_break_repetition": AgentAction(
-           enabled=True,
-           label="Auto Break Repetition",
-           description="Will attempt to automatically break AI repetition.",
-       ),
        "content": AgentAction(
            enabled=True,
            can_be_disabled=False,
@@ -210,7 +205,7 @@ class NarratorAgent(MemoryRAGMixin, Agent):

    def __init__(
        self,
-       client: client.TaleMateClient,
+       client: client.ClientBase | None = None,
        **kwargs,
    ):
        self.client = client
@@ -753,9 +748,6 @@ class NarratorAgent(MemoryRAGMixin, Agent):
    def allow_repetition_break(
        self, kind: str, agent_function_name: str, auto: bool = False
    ):
-       if auto and not self.actions["auto_break_repetition"].enabled:
-           return False
-
        return True

    def set_generation_overrides(self, prompt_param: dict):
@@ -16,7 +16,7 @@ from talemate.scene_message import (
    ReinforcementMessage,
)
from talemate.world_state.templates import GenerationOptions

from talemate.client import ClientBase
from talemate.agents.base import (
    Agent,
    AgentAction,
@@ -34,6 +34,7 @@ from talemate.history import ArchiveEntry
from .analyze_scene import SceneAnalyzationMixin
from .context_investigation import ContextInvestigationMixin
from .layered_history import LayeredHistoryMixin
+from .tts_utils import TTSUtilsMixin

if TYPE_CHECKING:
    from talemate.tale_mate import Character
@@ -71,6 +72,7 @@ class SummarizeAgent(
    ContextInvestigationMixin,
    # Needs to be after ContextInvestigationMixin so signals are connected in the right order
    SceneAnalyzationMixin,
+   TTSUtilsMixin,
    Agent,
):
    """
@@ -129,7 +131,7 @@ class SummarizeAgent(
        ContextInvestigationMixin.add_actions(actions)
        return actions

-   def __init__(self, client, **kwargs):
+   def __init__(self, client: ClientBase | None = None, **kwargs):
        self.client = client

        self.actions = SummarizeAgent.init_actions()
@@ -16,12 +16,29 @@ from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.agents.context import active_agent
from talemate.agents.base import RagBuildSubInstructionEmission
+from contextvars import ContextVar

if TYPE_CHECKING:
    from talemate.tale_mate import Character

log = structlog.get_logger()

+## CONTEXT
+
+scene_analysis_disabled_context = ContextVar("scene_analysis_disabled", default=False)
+
+
+class SceneAnalysisDisabled:
+   """
+   Context manager to disable scene analysis during specific agent actions.
+   """
+
+   def __enter__(self):
+       self.token = scene_analysis_disabled_context.set(True)
+
+   def __exit__(self, _exc_type, _exc_value, _traceback):
+       scene_analysis_disabled_context.reset(self.token)
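Usage follows the standard ContextVar token pattern, set on enter and reset on exit, so nested scopes unwind correctly. For example, inside an async agent action:

    with SceneAnalysisDisabled():
        # any emission handled in this scope sees
        # scene_analysis_disabled_context.get() == True and skips analysis
        await narrator.action_to_narration("narrate_scene", emit_message=True)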
talemate.emit.async_signals.register(
    "agent.summarization.scene_analysis.before",
@@ -184,6 +201,16 @@ class SceneAnalyzationMixin:
        if not self.analyze_scene:
            return

+       try:
+           if scene_analysis_disabled_context.get():
+               log.debug(
+                   "on_inject_instructions: scene analysis disabled through context",
+                   emission=emission,
+               )
+               return
+       except LookupError:
+           pass
+
        analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}")

        if not analyze_scene_for_type:
@@ -259,7 +286,14 @@ class SceneAnalyzationMixin:
        if not cached_analysis:
            return None

-       fingerprint = self.context_fingerpint()
+       fingerprint = self.context_fingerprint()
+
+       log.debug(
+           "get_cached_analysis",
+           fingerprint=fingerprint,
+           cached_analysis_fp=cached_analysis.get("fp"),
+           match=cached_analysis.get("fp") == fingerprint,
+       )

        if cached_analysis.get("fp") == fingerprint:
            return cached_analysis["guidance"]
@@ -271,7 +305,7 @@ class SceneAnalyzationMixin:
        Sets the cached analysis for the given type.
        """

-       fingerprint = self.context_fingerpint()
+       fingerprint = self.context_fingerprint()

        self.set_scene_states(
            **{
src/talemate/agents/summarize/tts_utils.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import structlog
from talemate.agents.base import (
    set_processing,
)
from talemate.prompts import Prompt
from talemate.status import set_loading
from talemate.util.dialogue import separate_dialogue_from_exposition

log = structlog.get_logger("talemate.agents.summarize.tts_utils")


class TTSUtilsMixin:
    """
    Summarizer mixin for text-to-speech utilities.
    """

    @set_loading("Preparing TTS context")
    @set_processing
    async def markup_context_for_tts(self, text: str) -> str:
        """
        Markup the context for text-to-speech.
        """

        original_text = text

        log.debug("Markup context for TTS", text=text)

        # if there are no quotes in the text, there is nothing to separate
        if '"' not in text:
            return original_text

        # here we separate dialogue from exposition into obvious segments.
        # It seems to have a positive effect on some LLMs returning the
        # complete text.
        separate_chunks = separate_dialogue_from_exposition(text)

        numbered_chunks = []
        for i, chunk in enumerate(separate_chunks):
            numbered_chunks.append(f"[{i + 1}] {chunk.text.strip()}")

        text = "\n".join(numbered_chunks)

        response = await Prompt.request(
            "summarizer.markup-context-for-tts",
            self.client,
            "investigate_1024",
            vars={
                "text": text,
                "max_tokens": self.client.max_token_length,
                "scene": self.scene,
            },
        )

        try:
            response = response.split("<MARKUP>")[1].split("</MARKUP>")[0].strip()
            return response
        except IndexError:
            log.error("Failed to extract markup from response", response=response)
            return original_text
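Concretely, the numbering step flattens the separated segments into an enumerated prompt body. A rough sketch of the transformation (the exact segmentation depends on separate_dialogue_from_exposition):

    text = 'She smiled. "You came back," she said.'
    # after separation and numbering, the prompt body looks like:
    # [1] She smiled.
    # [2] "You came back," she said.
    # The LLM is asked to return the marked-up text wrapped in
    # <MARKUP>...</MARKUP>, which the try/except above extracts.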
@@ -1,670 +0,0 @@
from __future__ import annotations

import asyncio
import base64
import functools
import io
import os
import tempfile
import time
import uuid
from typing import Union

import httpx
import nltk
import pydantic
import structlog
from nltk.tokenize import sent_tokenize
from openai import AsyncOpenAI

import talemate.config as config
import talemate.emit.async_signals
import talemate.instance as instance
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage

from .base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)
from .registry import register

try:
    from TTS.api import TTS
except ImportError:
    TTS = None

log = structlog.get_logger("talemate.agents.tts")

if not TTS:
    # TTS installation is massive and requires a lot of dependencies
    # so we don't want to require it unless the user wants to use it
    log.info(
        "TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
    )


def parse_chunks(text: str) -> list[str]:
    """
    Takes a string and splits it into chunks based on punctuation.

    In case of an error it will return the original text as a single chunk and
    the error will be logged.
    """

    try:
        text = text.replace("...", "__ellipsis__")
        chunks = sent_tokenize(text)
        cleaned_chunks = []

        for chunk in chunks:
            chunk = chunk.replace("*", "")
            if not chunk:
                continue
            cleaned_chunks.append(chunk)

        for i, chunk in enumerate(cleaned_chunks):
            chunk = chunk.replace("__ellipsis__", "...")
            cleaned_chunks[i] = chunk

        return cleaned_chunks
    except Exception as e:
        log.error("chunking error", error=e, text=text)
        return [text.replace("__ellipsis__", "...").replace("*", "")]


def clean_quotes(chunk: str):
    # if there is an uneven number of quotes, remove the last one if it's
    # at the end of the chunk. If it's in the middle, add a quote to the end
    if chunk.count('"') % 2 == 1:
        if chunk.endswith('"'):
            chunk = chunk[:-1]
        else:
            chunk += '"'

    return chunk
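The effect of the quote balancing is easiest to see by example:

    clean_quotes('He said, "hello')    # odd count, mid-chunk -> 'He said, "hello"'
    clean_quotes('a dangling quote"')  # odd count, at end    -> 'a dangling quote'
    clean_quotes('"all good"')         # even count           -> unchanged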
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
    """
    Will combine chunks split by punctuation into a single chunk until
    max chunk size is reached
    """

    joined_chunks = []

    current_chunk = ""

    for chunk in chunks:
        if len(current_chunk) + len(chunk) > chunk_size:
            joined_chunks.append(clean_quotes(current_chunk))
            current_chunk = ""

        current_chunk += chunk

    if current_chunk:
        joined_chunks.append(clean_quotes(current_chunk))
    return joined_chunks
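Together the two helpers form the chunking pipeline used by generate() below: split into sentences first, then greedily re-pack up to the per-API length budget. For example:

    text = 'A long paragraph. "A short quote." Another sentence follows here.'
    chunks = parse_chunks(text)                     # sentence-level pieces
    chunks = rejoin_chunks(chunks, chunk_size=250)  # greedy re-packing
    # each resulting chunk is at most ~250 characters with balanced quotes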
class Voice(pydantic.BaseModel):
    value: str
    label: str


class VoiceLibrary(pydantic.BaseModel):
    api: str
    voices: list[Voice] = pydantic.Field(default_factory=list)
    last_synced: float = None


@register()
class TTSAgent(Agent):
    """
    Text to speech agent
    """

    agent_type = "tts"
    verbose_name = "Voice"
    requires_llm_client = False
    essential = False

    @classmethod
    def config_options(cls, agent=None):
        config_options = super().config_options(agent=agent)

        if agent:
            config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
                voice.model_dump() for voice in agent.list_voices_sync()
            ]

        return config_options

    def __init__(self, **kwargs):
        self.is_enabled = False

        try:
            nltk.data.find("tokenizers/punkt")
        except LookupError:
            try:
                nltk.download("punkt", quiet=True)
            except Exception as e:
                log.error("nltk download error", error=e)
        except Exception as e:
            log.error("nltk find error", error=e)

        self.voices = {
            "elevenlabs": VoiceLibrary(api="elevenlabs"),
            "tts": VoiceLibrary(api="tts"),
            "openai": VoiceLibrary(api="openai"),
        }
        self.config = config.load_config()
        self.playback_done_event = asyncio.Event()
        self.preselect_voice = None
        self.actions = {
            "_config": AgentAction(
                enabled=True,
                label="Configure",
                description="TTS agent configuration",
                config={
                    "api": AgentActionConfig(
                        type="text",
                        choices=[
                            {"value": "tts", "label": "TTS (Local)"},
                            {"value": "elevenlabs", "label": "Eleven Labs"},
                            {"value": "openai", "label": "OpenAI"},
                        ],
                        value="tts",
                        label="API",
                        description="Which TTS API to use",
                        onchange="emit",
                    ),
                    "voice_id": AgentActionConfig(
                        type="text",
                        value="default",
                        label="Narrator Voice",
                        description="Voice ID/Name to use for TTS",
                        choices=[],
                    ),
                    "generate_for_player": AgentActionConfig(
                        type="bool",
                        value=False,
                        label="Generate for player",
                        description="Generate audio for player messages",
                    ),
                    "generate_for_npc": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Generate for NPCs",
                        description="Generate audio for NPC messages",
                    ),
                    "generate_for_narration": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Generate for narration",
                        description="Generate audio for narration messages",
                    ),
                    "generate_chunks": AgentActionConfig(
                        type="bool",
                        value=False,
                        label="Split generation",
                        description="Generate audio chunks for each sentence - will be much more responsive but may lose context to inform inflection",
                    ),
                },
            ),
            "openai": AgentAction(
                enabled=True,
                container=True,
                icon="mdi-server-outline",
                condition=AgentActionConditional(
                    attribute="_config.config.api", value="openai"
                ),
                label="OpenAI",
                config={
                    "model": AgentActionConfig(
                        type="text",
                        value="tts-1",
                        choices=[
                            {"value": "tts-1", "label": "TTS 1"},
                            {"value": "tts-1-hd", "label": "TTS 1 HD"},
                        ],
                        label="Model",
                        description="TTS model to use",
                    ),
                },
            ),
        }

        self.actions["_config"].model_dump()
        handlers["config_saved"].connect(self.on_config_saved)

    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return False

    @property
    def not_ready_reason(self) -> str:
        """
        Returns a string explaining why the agent is not ready
        """

        if self.ready:
            return ""

        if self.api == "tts":
            if not TTS:
                return "TTS not installed"

        elif self.requires_token and not self.token:
            return "No API token"

        elif not self.default_voice_id:
            return "No voice selected"

    @property
    def agent_details(self):
        details = {
            "api": AgentDetail(
                icon="mdi-server-outline",
                value=self.api_label,
                description="The backend to use for TTS",
            ).model_dump(),
        }

        if self.ready and self.enabled:
            details["voice"] = AgentDetail(
                icon="mdi-account-voice",
                value=self.voice_id_to_label(self.default_voice_id) or "",
                description="The voice to use for TTS",
                color="info",
            ).model_dump()
        elif self.enabled:
            details["error"] = AgentDetail(
                icon="mdi-alert",
                value=self.not_ready_reason,
                description=self.not_ready_reason,
                color="error",
            ).model_dump()

        return details

    @property
    def api(self):
        return self.actions["_config"].config["api"].value

    @property
    def api_label(self):
        choices = self.actions["_config"].config["api"].choices
        api = self.api
        for choice in choices:
            if choice["value"] == api:
                return choice["label"]
        return api

    @property
    def token(self):
        api = self.api
        return self.config.get(api, {}).get("api_key")

    @property
    def default_voice_id(self):
        return self.actions["_config"].config["voice_id"].value

    @property
    def requires_token(self):
        return self.api != "tts"

    @property
    def ready(self):
        if self.api == "tts":
            if not TTS:
                return False
            return True

        return (not self.requires_token or self.token) and self.default_voice_id

    @property
    def status(self):
        if not self.enabled:
            return "disabled"
        if self.ready:
            if getattr(self, "processing_bg", 0) > 0:
                return "busy_bg" if not getattr(self, "processing", False) else "busy"
            return "active" if not getattr(self, "processing", False) else "busy"
        if self.requires_token and not self.token:
            return "error"
        if self.api == "tts":
            if not TTS:
                return "error"
        return "uninitialized"

    @property
    def max_generation_length(self):
        if self.api == "elevenlabs":
            return 1024
        elif self.api == "coqui":
            return 250

        return 250

    @property
    def openai_api_key(self):
        return self.config.get("openai", {}).get("api_key")

    async def apply_config(self, *args, **kwargs):
        try:
            api = kwargs["actions"]["_config"]["config"]["api"]["value"]
        except KeyError:
            api = self.api

        api_changed = api != self.api

        # log.debug(
        #     "apply_config",
        #     api=api,
        #     api_changed=api != self.api,
        #     current_api=self.api,
        #     args=args,
        #     kwargs=kwargs,
        # )

        try:
            self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
                "value"
            ]
        except KeyError:
            self.preselect_voice = self.default_voice_id

        await super().apply_config(*args, **kwargs)

        if api_changed:
            try:
                self.actions["_config"].config["voice_id"].value = (
                    self.voices[api].voices[0].value
                )
            except IndexError:
                self.actions["_config"].config["voice_id"].value = ""

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop_new_message").connect(
            self.on_game_loop_new_message
        )

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        instance.emit_agent_status(self.__class__, self)

    async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
        """
        Called when a conversation is generated
        """

        if not self.enabled or not self.ready:
            return

        if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
            return

        if (
            isinstance(emission.message, NarratorMessage)
            and not self.actions["_config"].config["generate_for_narration"].value
        ):
            return

        if isinstance(emission.message, CharacterMessage):
            if (
                emission.message.source == "player"
                and not self.actions["_config"].config["generate_for_player"].value
            ):
                return
            elif (
                emission.message.source == "ai"
                and not self.actions["_config"].config["generate_for_npc"].value
            ):
                return

        if isinstance(emission.message, CharacterMessage):
            character_prefix = emission.message.split(":", 1)[0]
        else:
            character_prefix = ""

        log.info(
            "reactive tts", message=emission.message, character_prefix=character_prefix
        )

        await self.generate(str(emission.message).replace(character_prefix + ": ", ""))

    def voice(self, voice_id: str) -> Union[Voice, None]:
        for voice in self.voices[self.api].voices:
            if voice.value == voice_id:
                return voice
        return None

    def voice_id_to_label(self, voice_id: str):
        for voice in self.voices[self.api].voices:
            if voice.value == voice_id:
                return voice.label
        return None

    def list_voices_sync(self):
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.list_voices())

    async def list_voices(self):
        if self.requires_token and not self.token:
            return []

        library = self.voices[self.api]

        # TODO: allow re-syncing voices
        if library.last_synced:
            return library.voices

        list_fn = getattr(self, f"_list_voices_{self.api}")
        log.info("Listing voices", api=self.api)

        library.voices = await list_fn()
        library.last_synced = time.time()

        if self.preselect_voice:
            if self.voice(self.preselect_voice):
                self.actions["_config"].config["voice_id"].value = self.preselect_voice
                self.preselect_voice = None

        # if the current voice cannot be found, reset it
        if not self.voice(self.default_voice_id):
            self.actions["_config"].config["voice_id"].value = ""

        # set loading to false
        return library.voices

    @set_processing
    async def generate(self, text: str):
        if not self.enabled or not self.ready or not text:
            return

        self.playback_done_event.set()

        generate_fn = getattr(self, f"_generate_{self.api}")

        if self.actions["_config"].config["generate_chunks"].value:
            chunks = parse_chunks(text)
            chunks = rejoin_chunks(chunks)
        else:
            chunks = parse_chunks(text)
            chunks = rejoin_chunks(chunks, chunk_size=self.max_generation_length)

        # Start generating audio chunks in the background
        generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
        await self.set_background_processing(generation_task)

        # Wait for both tasks to complete
        # await asyncio.gather(generation_task)

    async def generate_chunks(self, generate_fn, chunks):
        for chunk in chunks:
            chunk = chunk.replace("*", "").strip()
            log.info("Generating audio", api=self.api, chunk=chunk)
            audio_data = await generate_fn(chunk)
            self.play_audio(audio_data)

    def play_audio(self, audio_data):
        # play audio through the python audio player
        # play(audio_data)

        emit(
            "audio_queue",
            data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
        )

        self.playback_done_event.set()  # Signal that playback is finished

    # LOCAL

    async def _generate_tts(self, text: str) -> Union[bytes, None]:
        if not TTS:
            return

        tts_config = self.config.get("tts", {})
        model = tts_config.get("model")
        device = tts_config.get("device", "cpu")

        log.debug("tts local", model=model, device=device)

        if not hasattr(self, "tts_instance"):
            self.tts_instance = TTS(model).to(device)

        tts = self.tts_instance

        loop = asyncio.get_event_loop()

        voice = self.voice(self.default_voice_id)

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            await loop.run_in_executor(
                None,
                functools.partial(
                    tts.tts_to_file,
                    text=text,
                    speaker_wav=voice.value,
                    language="en",
                    file_path=file_path,
                ),
            )
            # tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)

            with open(file_path, "rb") as f:
                return f.read()

    async def _list_voices_tts(self) -> dict[str, str]:
        return [
            Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
        ]

    # ELEVENLABS

    async def _generate_elevenlabs(
        self, text: str, chunk_size: int = 1024
    ) -> Union[bytes, None]:
        api_key = self.token
        if not api_key:
            return

        async with httpx.AsyncClient() as client:
            url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.default_voice_id}"
            headers = {
                "Accept": "audio/mpeg",
                "Content-Type": "application/json",
                "xi-api-key": api_key,
            }
            data = {
                "text": text,
                "model_id": self.config.get("elevenlabs", {}).get("model"),
                "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
            }

            response = await client.post(url, json=data, headers=headers, timeout=300)

            if response.status_code == 200:
                bytes_io = io.BytesIO()
                for chunk in response.iter_bytes(chunk_size=chunk_size):
                    if chunk:
                        bytes_io.write(chunk)

                # Put the audio data in the queue for playback
                return bytes_io.getvalue()
            else:
                log.error(f"Error generating audio: {response.text}")

    async def _list_voices_elevenlabs(self) -> dict[str, str]:
        url_voices = "https://api.elevenlabs.io/v1/voices"

        voices = []

        async with httpx.AsyncClient() as client:
            headers = {
                "Accept": "application/json",
                "xi-api-key": self.token,
            }
            response = await client.get(
                url_voices, headers=headers, params={"per_page": 1000}
            )
            speakers = response.json()["voices"]
            voices.extend(
                [
                    Voice(value=speaker["voice_id"], label=speaker["name"])
                    for speaker in speakers
                ]
            )

        # sort by name
        voices.sort(key=lambda x: x.label)

        return voices

    # OPENAI

    async def _generate_openai(self, text: str, chunk_size: int = 1024):
        client = AsyncOpenAI(api_key=self.openai_api_key)

        model = self.actions["openai"].config["model"].value

        response = await client.audio.speech.create(
            model=model, voice=self.default_voice_id, input=text
        )

        bytes_io = io.BytesIO()
        for chunk in response.iter_bytes(chunk_size=chunk_size):
            if chunk:
                bytes_io.write(chunk)

        # Put the audio data in the queue for playback
        return bytes_io.getvalue()

    async def _list_voices_openai(self) -> dict[str, str]:
        return [
            Voice(value="alloy", label="Alloy"),
            Voice(value="echo", label="Echo"),
            Voice(value="fable", label="Fable"),
            Voice(value="onyx", label="Onyx"),
            Voice(value="nova", label="Nova"),
            Voice(value="shimmer", label="Shimmer"),
        ]
src/talemate/agents/tts/__init__.py (new file, 995 lines)
@@ -0,0 +1,995 @@
from __future__ import annotations

import asyncio
import base64
import re
import traceback
from typing import TYPE_CHECKING

import uuid
from collections import deque

import structlog
from nltk.tokenize import sent_tokenize

import talemate.util.dialogue as dialogue_utils
import talemate.emit.async_signals as async_signals
import talemate.instance as instance
from talemate.ux.schema import Note
from talemate.emit import emit
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import (
    CharacterMessage,
    NarratorMessage,
    ContextInvestigationMessage,
)
from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConfig,
    AgentDetail,
    AgentActionNote,
    set_processing,
)
from talemate.agents.registry import register

from .schema import (
    APIStatus,
    Voice,
    VoiceLibrary,
    GenerationContext,
    Chunk,
    VoiceGenerationEmission,
)
from .providers import provider

import talemate.agents.tts.voice_library as voice_library

from .elevenlabs import ElevenLabsMixin
from .openai import OpenAIMixin
from .google import GoogleMixin
from .kokoro import KokoroMixin
from .chatterbox import ChatterboxMixin
from .websocket_handler import TTSWebsocketHandler
from .f5tts import F5TTSMixin

import talemate.agents.tts.nodes as tts_nodes  # noqa: F401

if TYPE_CHECKING:
    from talemate.character import Character, VoiceChangedEvent
    from talemate.agents.summarize import SummarizeAgent
    from talemate.game.engine.nodes.scene import SceneLoopEvent

log = structlog.get_logger("talemate.agents.tts")

HOT_SWAP_NOTIFICATION_TIME = 60

VOICE_LIBRARY_NOTE = "Voices are not managed here, but in the voice library which can be accessed through the Talemate application bar at the top. When disabling/enabling APIs, close and open this window to refresh the choices."

async_signals.register(
    "agent.tts.prepare.before",
    "agent.tts.prepare.after",
    "agent.tts.generate.before",
    "agent.tts.generate.after",
)


def parse_chunks(text: str) -> list[str]:
    """
    Takes a string and splits it into chunks based on punctuation.

    In case of an error it will return the original text as a single chunk and
    the error will be logged.
    """

    try:
        text = text.replace("*", "")

        # ensure sentence terminators are before quotes
        # otherwise the beginning of dialog will bleed into narration
        text = re.sub(r'([^.?!]+) "', r'\1. "', text)

        text = text.replace("...", "__ellipsis__")
        chunks = sent_tokenize(text)
        cleaned_chunks = []

        for chunk in chunks:
            if not chunk.strip():
                continue
            cleaned_chunks.append(chunk)

        for i, chunk in enumerate(cleaned_chunks):
            chunk = chunk.replace("__ellipsis__", "...")
            cleaned_chunks[i] = chunk

        return cleaned_chunks
    except Exception as e:
        log.error("chunking error", error=e, text=text)
        return [text.replace("__ellipsis__", "...").replace("*", "")]
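The regex pre-pass is the notable change from the old parse_chunks: a sentence terminator is injected before an opening quote so the tokenizer does not glue trailing narration onto the start of dialogue. Roughly:

    import re

    text = 'He leaned closer "Did you hear that?"'
    fixed = re.sub(r'([^.?!]+) "', r'\1. "', text)
    # -> 'He leaned closer. "Did you hear that?"'
    # sent_tokenize(fixed) now yields narration and dialogue as separate chunks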
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
    """
    Will combine chunks split by punctuation into a single chunk until
    max chunk size is reached
    """

    joined_chunks = []

    current_chunk = ""

    for chunk in chunks:
        if len(current_chunk) + len(chunk) > chunk_size:
            joined_chunks.append(current_chunk)
            current_chunk = ""

        current_chunk += chunk

    if current_chunk:
        joined_chunks.append(current_chunk)
    return joined_chunks


@register()
class TTSAgent(
    ElevenLabsMixin,
    OpenAIMixin,
    GoogleMixin,
    KokoroMixin,
    ChatterboxMixin,
    F5TTSMixin,
    Agent,
):
    """
    Text to speech agent
    """

    agent_type = "tts"
    verbose_name = "Voice"
    requires_llm_client = False
    essential = False

    # websocket handler for frontend voice library management
    websocket_handler = TTSWebsocketHandler

    @classmethod
    def config_options(cls, agent=None):
        config_options = super().config_options(agent=agent)

        if not agent:
            return config_options

        narrator_voice_id = config_options["actions"]["_config"]["config"][
            "narrator_voice_id"
        ]

        narrator_voice_id["choices"] = cls.narrator_voice_id_choices(agent)

        return config_options

    @classmethod
    def narrator_voice_id_choices(cls, agent: "TTSAgent") -> list[dict[str, str]]:
        choices = voice_library.voices_for_apis(agent.ready_apis, agent.voice_library)
        choices.sort(key=lambda x: x.label)
        return [
            {
                "label": f"{voice.label} ({voice.provider})",
                "value": voice.id,
            }
            for voice in choices
        ]

    @classmethod
    def init_actions(cls) -> dict[str, AgentAction]:
        actions = {
            "_config": AgentAction(
                enabled=True,
                label="Configure",
                description="TTS agent configuration",
                config={
                    "apis": AgentActionConfig(
                        type="flags",
                        value=[
                            "kokoro",
                        ],
                        label="Enabled APIs",
                        description="APIs to use for TTS",
                        choices=[],
                    ),
                    "narrator_voice_id": AgentActionConfig(
                        type="autocomplete",
                        value="kokoro:am_adam",
                        label="Narrator Voice",
                        description="Voice to use for narration",
                        choices=[],
                        note=VOICE_LIBRARY_NOTE,
                    ),
                    "speaker_separation": AgentActionConfig(
                        type="text",
                        value="simple",
                        label="Speaker separation",
                        description="How to separate speaker dialogue from exposition",
                        choices=[
                            {"label": "No separation", "value": "none"},
                            {"label": "Simple", "value": "simple"},
                            {"label": "Mixed", "value": "mixed"},
                            {"label": "AI assisted", "value": "ai_assisted"},
                        ],
                        note_on_value={
                            "none": AgentActionNote(
                                type="primary",
                                text="Character messages will be voiced entirely by the character's voice with a fallback to the narrator voice if the character has no voice selected. Narrator messages will be voiced exclusively by the narrator voice.",
                            ),
                            "simple": AgentActionNote(
                                type="primary",
                                text="Exposition and dialogue will be separated in character messages. Narrator messages will be voiced exclusively by the narrator voice.",
                            ),
                            "mixed": AgentActionNote(
                                type="primary",
                                text="A mix of `simple` and `ai_assisted`. Character messages are separated into narrator and the character's voice. Narrator messages that have dialogue are analyzed by the Summarizer agent to determine the appropriate speaker(s).",
                            ),
                            "ai_assisted": AgentActionNote(
                                type="primary",
                                text="Appropriate speaker separation will be attempted based on the content of the message with help from the Summarizer agent. This sends an extra prompt to the LLM to determine the appropriate speaker(s).",
                            ),
                        },
                    ),
                    "generate_for_player": AgentActionConfig(
                        type="bool",
                        value=False,
                        label="Auto-generate for player",
                        description="Generate audio for player messages",
                    ),
                    "generate_for_npc": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Auto-generate for AI characters",
                        description="Generate audio for NPC messages",
                    ),
                    "generate_for_narration": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Auto-generate for narration",
                        description="Generate audio for narration messages",
                    ),
                    "generate_for_context_investigation": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Auto-generate for context investigation",
                        description="Generate audio for context investigation messages",
                    ),
                },
            ),
        }

        KokoroMixin.add_actions(actions)
        ChatterboxMixin.add_actions(actions)
        GoogleMixin.add_actions(actions)
        ElevenLabsMixin.add_actions(actions)
        OpenAIMixin.add_actions(actions)
        F5TTSMixin.add_actions(actions)

        return actions

    def __init__(self, **kwargs):
        self.is_enabled = False  # tts agent is disabled by default
        self.actions = TTSAgent.init_actions()
        self.playback_done_event = asyncio.Event()

        # Queue management for voice generation
        # Each queue instance gets a unique id so it can later be referenced
        # (e.g. for cancellation of all remaining items).
        # Only one queue can be active at a time. New generation requests that
        # arrive while a queue is processing will be appended to the same
        # queue. Once the queue is fully processed it is discarded and a new
        # one will be created for subsequent generation requests.
        # Queue now holds individual (context, chunk) pairs so interruption can
        # happen between chunks even when a single context produced many.
        self._generation_queue: deque[tuple[GenerationContext, Chunk]] = deque()
        self._queue_id: str | None = None
        self._queue_task: asyncio.Task | None = None
        self._queue_lock = asyncio.Lock()
|
||||
|
||||
# general helpers
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def voice_library(self) -> VoiceLibrary:
|
||||
return voice_library.get_instance()
|
||||
|
||||
# config helpers
|
||||
|
||||
@property
|
||||
def narrator_voice_id(self) -> str:
|
||||
return self.actions["_config"].config["narrator_voice_id"].value
|
||||
|
||||
@property
|
||||
def generate_for_player(self) -> bool:
|
||||
return self.actions["_config"].config["generate_for_player"].value
|
||||
|
||||
@property
|
||||
def generate_for_npc(self) -> bool:
|
||||
return self.actions["_config"].config["generate_for_npc"].value
|
||||
|
||||
@property
|
||||
def generate_for_narration(self) -> bool:
|
||||
return self.actions["_config"].config["generate_for_narration"].value
|
||||
|
||||
@property
|
||||
def generate_for_context_investigation(self) -> bool:
|
||||
return (
|
||||
self.actions["_config"].config["generate_for_context_investigation"].value
|
||||
)
|
||||
|
||||
@property
|
||||
def speaker_separation(self) -> str:
|
||||
return self.actions["_config"].config["speaker_separation"].value
|
||||
|
||||
@property
|
||||
def apis(self) -> list[str]:
|
||||
return self.actions["_config"].config["apis"].value
|
||||
|
||||
@property
|
||||
def all_apis(self) -> list[str]:
|
||||
return [api["value"] for api in self.actions["_config"].config["apis"].choices]
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
details = {}
|
||||
|
||||
self.actions["_config"].config[
|
||||
"narrator_voice_id"
|
||||
].choices = self.narrator_voice_id_choices(self)
|
||||
|
||||
if not self.enabled:
|
||||
return details
|
||||
|
||||
used_apis: set[str] = set()
|
||||
|
||||
used_disabled_apis: set[str] = set()
|
||||
|
||||
if self.narrator_voice:
|
||||
#
|
||||
|
||||
label = self.narrator_voice.label
|
||||
color = "primary"
|
||||
used_apis.add(self.narrator_voice.provider)
|
||||
|
||||
if not self.api_enabled(self.narrator_voice.provider):
|
||||
used_disabled_apis.add(self.narrator_voice.provider)
|
||||
|
||||
if not self.api_ready(self.narrator_voice.provider):
|
||||
color = "error"
|
||||
|
||||
details["narrator_voice"] = AgentDetail(
|
||||
icon="mdi-script-text",
|
||||
value=label,
|
||||
description="Default voice",
|
||||
color=color,
|
||||
).model_dump()
|
||||
|
||||
scene = getattr(self, "scene", None)
|
||||
if scene:
|
||||
for character in scene.characters:
|
||||
if character.voice:
|
||||
label = character.voice.label
|
||||
color = "primary"
|
||||
used_apis.add(character.voice.provider)
|
||||
if not self.api_enabled(character.voice.provider):
|
||||
used_disabled_apis.add(character.voice.provider)
|
||||
if not self.api_ready(character.voice.provider):
|
||||
color = "error"
|
||||
|
||||
details[f"{character.name}_voice"] = AgentDetail(
|
||||
icon="mdi-account-voice",
|
||||
value=f"{character.name}",
|
||||
description=f"{character.name}'s voice: {label} ({character.voice.provider})",
|
||||
color=color,
|
||||
).model_dump()
|
||||
|
||||
for api in used_disabled_apis:
|
||||
details[f"{api}_disabled"] = AgentDetail(
|
||||
icon="mdi-alert-circle",
|
||||
value=f"{api} disabled",
|
||||
description=f"{api} disabled - at least one voice is attempting to use this api but is not enabled",
|
||||
color="error",
|
||||
).model_dump()

        for api in used_apis:
            fn = getattr(self, f"{api}_agent_details", None)
            if fn:
                details.update(fn)
        return details

    @property
    def status(self):
        if not self.enabled:
            return "disabled"
        if self.ready:
            if getattr(self, "processing_bg", 0) > 0:
                return "busy_bg" if not getattr(self, "processing", False) else "busy"
            return "idle" if not getattr(self, "processing", False) else "busy"
        return "uninitialized"

    @property
    def narrator_voice(self) -> Voice | None:
        return self.voice_library.get_voice(self.narrator_voice_id)

    @property
    def api_status(self) -> list[APIStatus]:
        api_status: list[APIStatus] = []

        for api in self.all_apis:
            not_configured_reason = getattr(self, f"{api}_not_configured_reason", None)
            not_configured_action = getattr(self, f"{api}_not_configured_action", None)
            api_info: str | None = getattr(self, f"{api}_info", None)
            messages: list[Note] = []
            if not_configured_reason:
                messages.append(
                    Note(
                        text=not_configured_reason,
                        color="error",
                        icon="mdi-alert-circle-outline",
                        actions=[not_configured_action]
                        if not_configured_action
                        else None,
                    )
                )
            if api_info:
                messages.append(
                    Note(
                        text=api_info.strip(),
                        color="muted",
                        icon="mdi-information-outline",
                    )
                )
            _status = APIStatus(
                api=api,
                enabled=self.api_enabled(api),
                ready=self.api_ready(api),
                configured=self.api_configured(api),
                messages=messages,
                supports_mixing=getattr(self, f"{api}_supports_mixing", False),
                provider=provider(api),
                default_model=getattr(self, f"{api}_model", None),
                model_choices=getattr(self, f"{api}_model_choices", []),
            )
            api_status.append(_status)

        # order by api
        api_status.sort(key=lambda x: x.api)

        return api_status
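
    # The agent discovers per-API behaviour purely by attribute naming
    # convention: a mixin opts in by defining `{api}_configured`, `{api}_info`,
    # `{api}_not_configured_reason`, `{api}_generate`, and so on, which the
    # core agent probes for with getattr (see api_status above). A minimal
    # sketch of a hypothetical backend following that convention (`mytts` is
    # illustrative, not a real provider):
    #
    #     class MyTTSMixin:
    #         @property
    #         def mytts_configured(self) -> bool:
    #             return True
    #
    #         @property
    #         def mytts_max_generation_length(self) -> int:
    #             return 512
    #
    #         async def mytts_generate(
    #             self, chunk: Chunk, context: GenerationContext
    #         ) -> bytes | None:
    #             ...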

    # events

    def connect(self, scene):
        super().connect(scene)
        async_signals.get("game_loop_new_message").connect(
            self.on_game_loop_new_message
        )
        async_signals.get("voice_library.update.after").connect(
            self.on_voice_library_update
        )
        async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init)
        async_signals.get("character.voice_changed").connect(
            self.on_character_voice_changed
        )

    async def on_scene_loop_init(self, event: "SceneLoopEvent"):
        if not self.enabled or not self.ready or not self.generate_for_narration:
            return

        if self.scene.environment == "creative":
            return

        content_messages = self.scene.last_message_of_type(
            ["character", "narrator", "context_investigation"]
        )

        if content_messages:
            # we already have a history, so we don't need to generate TTS for the intro
            return

        await self.generate(self.scene.get_intro(), character=None)

    async def on_voice_library_update(self, voice_library: VoiceLibrary):
        log.debug("Voice library updated - refreshing narrator voice choices")
        self.actions["_config"].config[
            "narrator_voice_id"
        ].choices = self.narrator_voice_id_choices(self)
        await self.emit_status()

    async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
        """
        Called when a conversation is generated
        """

        if self.scene.environment == "creative":
            return

        character: Character | None = None

        if not self.enabled or not self.ready:
            return

        if not isinstance(
            emission.message,
            (CharacterMessage, NarratorMessage, ContextInvestigationMessage),
        ):
            return

        if (
            isinstance(emission.message, NarratorMessage)
            and not self.generate_for_narration
        ):
            return

        if (
            isinstance(emission.message, ContextInvestigationMessage)
            and not self.generate_for_context_investigation
        ):
            return

        if isinstance(emission.message, CharacterMessage):
            if emission.message.source == "player" and not self.generate_for_player:
                return
            elif emission.message.source == "ai" and not self.generate_for_npc:
                return

            character = self.scene.get_character(emission.message.character_name)

        if isinstance(emission.message, CharacterMessage):
            character_prefix = emission.message.split(":", 1)[0]
            text_to_generate = str(emission.message).replace(
                character_prefix + ": ", ""
            )
        elif isinstance(emission.message, ContextInvestigationMessage):
            character_prefix = ""
            text_to_generate = (
                emission.message.message
            )  # Use just the message content, not the title prefix
        else:
            character_prefix = ""
            text_to_generate = str(emission.message)

        log.info(
            "reactive tts", message=emission.message, character_prefix=character_prefix
        )

        await self.generate(
            text_to_generate,
            character=character,
            message=emission.message,
        )

    async def on_character_voice_changed(self, event: "VoiceChangedEvent"):
        log.debug(
            "Character voice changed", character=event.character, voice=event.voice
        )
        await self.emit_status()

    # voice helpers

    @property
    def ready_apis(self) -> list[str]:
        """
        Returns a list of apis that are ready
        """
        return [api for api in self.apis if self.api_ready(api)]

    @property
    def used_apis(self) -> list[str]:
        """
        Returns a list of apis that are in use

        The api is in use if it is the narrator voice or if any of the active characters in the scene use a voice from the api.
        """
        return [api for api in self.apis if self.api_used(api)]

    def api_enabled(self, api: str) -> bool:
        """
        Returns whether the api is currently in the .apis list, which means it is enabled.
        """
        return api in self.apis

    def api_ready(self, api: str) -> bool:
        """
        Returns whether the api is ready.

        The api must be enabled and configured.
        """

        if not self.api_enabled(api):
            return False

        return self.api_configured(api)

    def api_configured(self, api: str) -> bool:
        return getattr(self, f"{api}_configured", True)

    def api_used(self, api: str) -> bool:
        """
        Returns whether the narrator or any of the active characters in the scene
        use a voice from the given api

        Args:
            api (str): The api to check

        Returns:
            bool: Whether the api is in use
        """

        if self.narrator_voice and self.narrator_voice.provider == api:
            return True

        if not getattr(self, "scene", None):
            return False

        for character in self.scene.characters:
            if not character.voice:
                continue
            voice = self.voice_library.get_voice(character.voice.id)
            if voice and voice.provider == api:
                return True

        return False

    def use_ai_assisted_speaker_separation(
        self,
        text: str,
        message: CharacterMessage
        | NarratorMessage
        | ContextInvestigationMessage
        | None,
    ) -> bool:
        """
        Returns whether the ai assisted speaker separation should be used for the given text.
        """
        try:
            if not message and '"' not in text:
                return False

            if not message and '"' in text:
                return self.speaker_separation in ["ai_assisted", "mixed"]

            if message.source == "player":
                return False

            if self.speaker_separation == "ai_assisted":
                return True

            if (
                isinstance(message, NarratorMessage)
                and self.speaker_separation == "mixed"
            ):
                return True

            return False
        except Exception as e:
            log.error(
                "Error using ai assisted speaker separation",
                error=e,
                traceback=traceback.format_exc(),
            )
            return False
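
    # Summary of the branches above: player-sourced messages never use AI
    # assistance; mode "ai_assisted" applies it to all other messages; mode
    # "mixed" applies it only to narrator messages; raw text without a message
    # object qualifies only when it contains quoted dialogue.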

    # tts markup cache

    async def get_tts_markup_cache(self, text: str) -> str | None:
        """
        Returns the cached tts markup for the given text.
        """
        fp = hash(text)
        cached_markup = self.get_scene_state("tts_markup_cache")
        if cached_markup and cached_markup.get("fp") == fp:
            return cached_markup.get("markup")
        return None

    async def set_tts_markup_cache(self, text: str, markup: str):
        fp = hash(text)
        self.set_scene_states(
            tts_markup_cache={
                "fp": fp,
                "markup": markup,
            }
        )
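
    # Worth noting: Python's built-in hash() is salted per process for strings
    # (PYTHONHASHSEED), so the fingerprint above is only stable for the
    # lifetime of the current backend process. If the cache ever needed to
    # survive restarts, a stable digest would be required - a minimal sketch,
    # assuming hashlib were imported:
    #
    #     import hashlib
    #     fp = hashlib.sha256(text.encode("utf-8")).hexdigest()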

    # generation

    @set_processing
    async def generate(
        self,
        text: str,
        character: Character | None = None,
        force_voice: Voice | None = None,
        message: CharacterMessage
        | NarratorMessage
        | ContextInvestigationMessage
        | None = None,
    ):
        """
        Public entry-point for voice generation.

        The actual audio generation happens sequentially inside a single
        background queue. If a queue is currently active, we simply append the
        new request to it; if not, we create a new queue (with its own unique
        id) and start processing.
        """
        if not self.enabled or not self.ready or not text:
            return

        self.playback_done_event.set()

        summarizer: "SummarizeAgent" = instance.get_agent("summarizer")

        context = GenerationContext(voice_id=self.narrator_voice_id)
        character_voice: Voice = force_voice or self.narrator_voice

        if character and character.voice:
            voice = character.voice
            if voice and self.api_ready(voice.provider):
                character_voice = voice
            else:
                log.warning(
                    "Character voice not available",
                    character=character.name,
                    voice=character.voice,
                )

        log.debug("Voice routing", character=character, voice=character_voice)

        # initial chunking by separating dialogue from exposition
        chunks: list[Chunk] = []
        if self.speaker_separation != "none":
            if self.use_ai_assisted_speaker_separation(text, message):
                markup = await self.get_tts_markup_cache(text)
                if not markup:
                    log.debug("No markup cache found, generating markup")
                    markup = await summarizer.markup_context_for_tts(text)
                    await self.set_tts_markup_cache(text, markup)
                else:
                    log.debug("Using markup cache")
                # Use the new markup parser for AI-assisted format
                dlg_chunks = dialogue_utils.parse_tts_markup(markup)
            else:
                # Use the original parser for non-AI-assisted format
                dlg_chunks = dialogue_utils.separate_dialogue_from_exposition(text)

            for _dlg_chunk in dlg_chunks:
                _voice = (
                    character_voice
                    if _dlg_chunk.type == "dialogue"
                    else self.narrator_voice
                )

                if _dlg_chunk.speaker is not None:
                    # speaker name has been identified
                    _character = self.scene.get_character(_dlg_chunk.speaker)
                    log.debug(
                        "Identified speaker",
                        speaker=_dlg_chunk.speaker,
                        character=_character,
                    )
                    if (
                        _character
                        and _character.voice
                        and self.api_ready(_character.voice.provider)
                    ):
                        log.debug(
                            "Using character voice",
                            character=_character.name,
                            voice=_character.voice,
                        )
                        _voice = _character.voice

                _api: str = _voice.provider if _voice else self.api
                chunk = Chunk(
                    api=_api,
                    voice=Voice(**_voice.model_dump()),
                    model=_voice.provider_model,
                    generate_fn=getattr(self, f"{_api}_generate"),
                    prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
                    character_name=character.name if character else None,
                    text=[_dlg_chunk.text],
                    type=_dlg_chunk.type,
                    message_id=message.id if message else None,
                )
                chunks.append(chunk)
        else:
            _voice = character_voice if character else self.narrator_voice
            _api: str = _voice.provider if _voice else self.api
            chunks = [
                Chunk(
                    api=_api,
                    voice=Voice(**_voice.model_dump()),
                    model=_voice.provider_model,
                    generate_fn=getattr(self, f"{_api}_generate"),
                    prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
                    character_name=character.name if character else None,
                    text=[text],
                    type="dialogue" if character else "exposition",
                    message_id=message.id if message else None,
                )
            ]

        # second chunking by splitting into chunks of max_generation_length

        for chunk in chunks:
            api_chunk_size = getattr(self, f"{chunk.api}_chunk_size", 0)

            log.debug("chunking", api=chunk.api, api_chunk_size=api_chunk_size)

            _text = []

            max_generation_length = getattr(self, f"{chunk.api}_max_generation_length")

            if api_chunk_size > 0:
                max_generation_length = min(max_generation_length, api_chunk_size)

            for _chunk_text in chunk.text:
                if len(_chunk_text) <= max_generation_length:
                    _text.append(_chunk_text)
                    continue

                _parsed = parse_chunks(_chunk_text)
                _joined = rejoin_chunks(_parsed, chunk_size=max_generation_length)
                _text.extend(_joined)

            log.debug("chunked for size", before=chunk.text, after=_text)

            chunk.text = _text

        context.chunks = chunks

        # Enqueue each chunk individually for fine-grained interruptibility
        async with self._queue_lock:
            if self._queue_id is None:
                self._queue_id = str(uuid.uuid4())

            for chunk in context.chunks:
                self._generation_queue.append((context, chunk))

            # Start processing task if needed
            if self._queue_task is None or self._queue_task.done():
                self._queue_task = asyncio.create_task(
                    self._process_queue(self._queue_id)
                )

            log.debug(
                "tts queue enqueue",
                queue_id=self._queue_id,
                total_items=len(self._generation_queue),
            )

        # The caller doesn't need to wait for the queue to finish; it runs in
        # the background. We still register the task with Talemate's
        # background-processing tracking so that UI can reflect activity.
        await self.set_background_processing(self._queue_task)

    # ---------------------------------------------------------------------
    # Queue helpers
    # ---------------------------------------------------------------------

    async def _process_queue(self, queue_id: str):
        """Sequentially processes all GenerationContext objects in the queue.

        Once the last context has been processed the queue state is reset so a
        future generation call will create a new queue (and therefore a new
        id). The *queue_id* argument allows us to later add cancellation logic
        that can target a specific queue instance.
        """

        try:
            while True:
                async with self._queue_lock:
                    if not self._generation_queue:
                        break

                    context, chunk = self._generation_queue.popleft()

                    log.debug(
                        "tts queue dequeue",
                        queue_id=queue_id,
                        total_items=len(self._generation_queue),
                        chunk_type=chunk.type,
                    )

                # Process outside lock so other coroutines can enqueue
                await self._generate_chunk(chunk, context)
        except Exception as e:
            log.error(
                "Error processing queue", error=e, traceback=traceback.format_exc()
            )
        finally:
            # Clean up queue state after finishing (or on cancellation)
            async with self._queue_lock:
                if queue_id == self._queue_id:
                    self._queue_id = None
                    self._queue_task = None
                    self._generation_queue.clear()

    # Public helper so external code (e.g. later cancellation UI) can find the current queue id
    def current_queue_id(self) -> str | None:
        return self._queue_id

    async def _generate_chunk(self, chunk: Chunk, context: GenerationContext):
        """Generate audio for a single chunk (all its sub-chunks)."""

        for _chunk in chunk.sub_chunks:
            if not _chunk.cleaned_text.strip():
                continue

            emission: VoiceGenerationEmission = VoiceGenerationEmission(
                chunk=_chunk, context=context
            )

            if _chunk.prepare_fn:
                await async_signals.get("agent.tts.prepare.before").send(emission)
                await _chunk.prepare_fn(_chunk)
                await async_signals.get("agent.tts.prepare.after").send(emission)

            log.info(
                "Generating audio",
                api=chunk.api,
                text=_chunk.cleaned_text,
                parameters=_chunk.voice.parameters,
                prepare_fn=_chunk.prepare_fn,
            )

            await async_signals.get("agent.tts.generate.before").send(emission)
            try:
                emission.wav_bytes = await _chunk.generate_fn(_chunk, context)
            except Exception as e:
                log.error("Error generating audio", error=e, chunk=_chunk)
                continue
            await async_signals.get("agent.tts.generate.after").send(emission)
            self.play_audio(emission.wav_bytes, chunk.message_id)
            await asyncio.sleep(0.1)

    # Deprecated: kept for backward compatibility but no longer used.
    async def generate_chunks(self, context: GenerationContext):
        for chunk in context.chunks:
            await self._generate_chunk(chunk, context)

    def play_audio(self, audio_data, message_id: int | None = None):
        # play audio through the websocket (browser)

        audio_data_encoded: str = base64.b64encode(audio_data).decode("utf-8")

        emit(
            "audio_queue",
            data={"audio_data": audio_data_encoded, "message_id": message_id},
        )

        self.playback_done_event.set()  # Signal that playback is finished

    async def stop_and_clear_queue(self):
        """Cancel any ongoing generation and clear the pending queue.

        This is triggered by UI actions that request immediate stop of TTS
        synthesis and playback. It cancels the background task (if still
        running) and clears all queued items in a thread-safe manner.
        """
        async with self._queue_lock:
            # Clear all queued items
            self._generation_queue.clear()

            # Cancel the background task if it is still running
            if self._queue_task and not self._queue_task.done():
                self._queue_task.cancel()

            # Reset queue identifiers/state
            self._queue_id = None
            self._queue_task = None

        # Ensure downstream components know playback is finished
        self.playback_done_event.set()
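
    # A rough usage sketch of the queue lifecycle (names as defined above;
    # the scenario itself is illustrative):
    #
    #     agent = instance.get_agent("tts")
    #     await agent.generate("Hello there.")      # creates queue, enqueues chunks
    #     await agent.generate("And more text.")    # appended to the same queue
    #     await agent.stop_and_clear_queue()        # cancels task, clears queue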

src/talemate/agents/tts/chatterbox.py (new file, 317 lines)
@@ -0,0 +1,317 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic

import torch


# Lazy imports for heavy dependencies
def _import_heavy_deps():
    global ta, ChatterboxTTS
    import torchaudio as ta
    from chatterbox.tts import ChatterboxTTS


CUDA_AVAILABLE = torch.cuda.is_available()

from talemate.agents.base import (
    AgentAction,
    AgentActionConfig,
    AgentDetail,
)
from talemate.ux.schema import Field

from .schema import Voice, Chunk, GenerationContext, VoiceProvider, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
from .providers import register, provider
from .util import voice_is_talemate_asset

log = structlog.get_logger("talemate.agents.tts.chatterbox")

add_default_voices(
    [
        Voice(
            label="Eva",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/eva.wav",
            tags=["female", "calm", "mature", "thoughtful"],
        ),
        Voice(
            label="Lisa",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/lisa.wav",
            tags=["female", "energetic", "young"],
        ),
        Voice(
            label="Adam",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/adam.wav",
            tags=["male", "calm", "mature", "thoughtful", "deep"],
        ),
        Voice(
            label="Bradford",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/bradford.wav",
            tags=["male", "calm", "mature", "thoughtful", "deep"],
        ),
        Voice(
            label="Julia",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/julia.wav",
            tags=["female", "calm", "mature"],
        ),
        Voice(
            label="Zoe",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/zoe.wav",
            tags=["female"],
        ),
        Voice(
            label="William",
            provider="chatterbox",
            provider_id="tts/voice/chatterbox/william.wav",
            tags=["male", "young"],
        ),
    ]
)

CHATTERBOX_INFO = """
Chatterbox is a local text to speech model.

The voice id is the path to the .wav file for the voice.

The path can be relative to the talemate root directory; you can put new *.wav samples
in the `tts/voice/chatterbox` directory. Loading files from elsewhere also works, as long as the file path is accessible to the talemate backend.

The first generation will download the models (2.13GB + 1.06GB).

Uses about 4GB of VRAM.
"""


@register()
class ChatterboxProvider(VoiceProvider):
    name: str = "chatterbox"
    allow_model_override: bool = False
    allow_file_upload: bool = True
    upload_file_types: list[str] = ["audio/wav"]
    voice_parameters: list[Field] = [
        Field(
            name="exaggeration",
            type="number",
            label="Exaggeration level",
            value=0.5,
            min=0.25,
            max=2.0,
            step=0.05,
        ),
        Field(
            name="cfg_weight",
            type="number",
            label="CFG/Pace",
            value=0.5,
            min=0.2,
            max=1.0,
            step=0.1,
        ),
        Field(
            name="temperature",
            type="number",
            label="Temperature",
            value=0.8,
            min=0.05,
            max=5.0,
            step=0.05,
        ),
    ]


class ChatterboxInstance(pydantic.BaseModel):
    model: "ChatterboxTTS"
    device: str

    class Config:
        arbitrary_types_allowed = True


class ChatterboxMixin:
    """
    Chatterbox agent mixin for local text to speech.
    """

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        actions["_config"].config["apis"].choices.append(
            {
                "value": "chatterbox",
                "label": "Chatterbox (Local)",
                "help": "Chatterbox is a local text to speech model.",
            }
        )

        actions["chatterbox"] = AgentAction(
            enabled=True,
            container=True,
            icon="mdi-server-outline",
            label="Chatterbox",
            description="Chatterbox is a local text to speech model.",
            config={
                "device": AgentActionConfig(
                    type="text",
                    value="cuda" if CUDA_AVAILABLE else "cpu",
                    label="Device",
                    choices=[
                        {"value": "cpu", "label": "CPU"},
                        {"value": "cuda", "label": "CUDA"},
                    ],
                    description="Device to use for TTS",
                ),
                "chunk_size": AgentActionConfig(
                    type="number",
                    min=0,
                    step=64,
                    max=2048,
                    value=256,
                    label="Chunk size",
                    note=INFO_CHUNK_SIZE,
                ),
            },
        )
        return actions

    @property
    def chatterbox_configured(self) -> bool:
        return True

    @property
    def chatterbox_max_generation_length(self) -> int:
        return 512

    @property
    def chatterbox_device(self) -> str:
        return self.actions["chatterbox"].config["device"].value

    @property
    def chatterbox_chunk_size(self) -> int:
        return self.actions["chatterbox"].config["chunk_size"].value

    @property
    def chatterbox_info(self) -> str:
        return CHATTERBOX_INFO

    @property
    def chatterbox_agent_details(self) -> dict:
        if not self.chatterbox_configured:
            return {}
        details = {}

        details["chatterbox_device"] = AgentDetail(
            icon="mdi-memory",
            value=f"Chatterbox: {self.chatterbox_device}",
            description="The device to use for Chatterbox",
        ).model_dump()

        return details

    def chatterbox_delete_voice(self, voice: Voice):
        """
        Remove the voice from the file system.
        Only do this if the path is within TALEMATE_ROOT.
        """

        is_talemate_asset, resolved = voice_is_talemate_asset(
            voice, provider(voice.provider)
        )

        log.debug(
            "chatterbox_delete_voice",
            voice_id=voice.provider_id,
            is_talemate_asset=is_talemate_asset,
            resolved=resolved,
        )

        if not is_talemate_asset:
            return

        try:
            if resolved.exists() and resolved.is_file():
                resolved.unlink()
                log.debug("Deleted chatterbox voice file", path=str(resolved))
        except Exception as e:
            log.error(
                "Failed to delete chatterbox voice file", error=e, path=str(resolved)
            )

    def _chatterbox_generate_file(
        self,
        model: "ChatterboxTTS",
        text: str,
        audio_prompt_path: str,
        output_path: str,
        **kwargs,
    ):
        wav = model.generate(text=text, audio_prompt_path=audio_prompt_path, **kwargs)
        ta.save(output_path, wav, model.sr)
        return output_path

    async def chatterbox_generate(
        self, chunk: Chunk, context: GenerationContext
    ) -> bytes | None:
        chatterbox_instance: ChatterboxInstance | None = getattr(
            self, "chatterbox_instance", None
        )

        reload: bool = False

        if not chatterbox_instance:
            reload = True
        elif chatterbox_instance.device != self.chatterbox_device:
            reload = True

        if reload:
            log.debug(
                "chatterbox - reinitializing tts instance",
                device=self.chatterbox_device,
            )
            # Lazy import heavy dependencies only when needed
            _import_heavy_deps()

            self.chatterbox_instance = ChatterboxInstance(
                model=ChatterboxTTS.from_pretrained(device=self.chatterbox_device),
                device=self.chatterbox_device,
            )
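
            # Note: from_pretrained above runs on the event-loop thread, so
            # the first generation blocks until the model has been downloaded
            # and loaded; only the synthesis below is offloaded to an executor.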

        model: "ChatterboxTTS" = self.chatterbox_instance.model

        loop = asyncio.get_event_loop()

        voice = chunk.voice

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            await loop.run_in_executor(
                None,
                functools.partial(
                    self._chatterbox_generate_file,
                    model=model,
                    text=chunk.cleaned_text,
                    audio_prompt_path=voice.provider_id,
                    output_path=file_path,
                    **voice.parameters,
                ),
            )

            with open(file_path, "rb") as f:
                return f.read()

    async def chatterbox_prepare_chunk(self, chunk: Chunk):
        voice = chunk.voice
        P = provider(voice.provider)
        exaggeration = P.voice_parameter(voice, "exaggeration")

        voice.parameters["exaggeration"] = exaggeration

src/talemate/agents/tts/elevenlabs.py (new file, 248 lines)
@@ -0,0 +1,248 @@
import io
from typing import Union

import structlog


# Lazy imports for heavy dependencies
def _import_heavy_deps():
    global AsyncElevenLabs, ApiError
    from elevenlabs.client import AsyncElevenLabs

    # Added explicit ApiError import for clearer error handling
    from elevenlabs.core.api_error import ApiError


from talemate.ux.schema import Action

from talemate.agents.base import (
    AgentAction,
    AgentActionConfig,
    AgentDetail,
)
from .schema import Voice, VoiceLibrary, GenerationContext, Chunk, INFO_CHUNK_SIZE
from .voice_library import add_default_voices

# emit helper to propagate status messages to the UX
from talemate.emit import emit

log = structlog.get_logger("talemate.agents.tts.elevenlabs")


add_default_voices(
    [
        Voice(
            label="Adam",
            provider="elevenlabs",
            provider_id="wBXNqKUATyqu0RtYt25i",
            tags=["male", "deep"],
        ),
        Voice(
            label="Amy",
            provider="elevenlabs",
            provider_id="oGn4Ha2pe2vSJkmIJgLQ",
            tags=["female"],
        ),
    ]
)


ELEVENLABS_INFO = """
ElevenLabs is a cloud-based text to speech API.

To add new voices, head to their voice library at [https://elevenlabs.io/app/voice-library](https://elevenlabs.io/app/voice-library) and note the voice id of the voice you want to use. (Click 'More Actions -> Copy Voice ID')

**About elevenlabs voices**
Your elevenlabs subscription allows you to maintain a set number of voices (10 on the cheapest plan).

Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button above, so it is recommended to test voices via their voice library instead.
"""


class ElevenLabsMixin:
    """
    ElevenLabs TTS agent mixin for cloud-based text to speech.
    """

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        actions["_config"].config["apis"].choices.append(
            {
                "value": "elevenlabs",
                "label": "ElevenLabs",
                "help": "ElevenLabs is a cloud-based text to speech model that uses the ElevenLabs API. (API key required)",
            }
        )

        actions["elevenlabs"] = AgentAction(
            enabled=True,
            container=True,
            icon="mdi-server-outline",
            label="ElevenLabs",
            description="ElevenLabs is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> ElevenLabs)",
            config={
                "model": AgentActionConfig(
                    type="text",
                    value="eleven_flash_v2_5",
                    label="Model",
                    description="Model to use for TTS",
                    choices=[
                        {
                            "value": "eleven_multilingual_v2",
                            "label": "Eleven Multilingual V2",
                        },
                        {"value": "eleven_flash_v2_5", "label": "Eleven Flash V2.5"},
                        {"value": "eleven_turbo_v2_5", "label": "Eleven Turbo V2.5"},
                    ],
                ),
                "chunk_size": AgentActionConfig(
                    type="number",
                    min=0,
                    step=64,
                    max=2048,
                    value=0,
                    label="Chunk size",
                    note=INFO_CHUNK_SIZE,
                ),
            },
        )

        return actions

    @classmethod
    def add_voices(cls, voices: dict[str, VoiceLibrary]):
        voices["elevenlabs"] = VoiceLibrary(api="elevenlabs", local=True)

    @property
    def elevenlabs_chunk_size(self) -> int:
        return self.actions["elevenlabs"].config["chunk_size"].value

    @property
    def elevenlabs_configured(self) -> bool:
        api_key_set = bool(self.elevenlabs_api_key)
        model_set = bool(self.elevenlabs_model)
        return api_key_set and model_set

    @property
    def elevenlabs_not_configured_reason(self) -> str | None:
        if not self.elevenlabs_api_key:
            return "ElevenLabs API key not set"
        if not self.elevenlabs_model:
            return "ElevenLabs model not set"
        return None

    @property
    def elevenlabs_not_configured_action(self) -> Action | None:
        if not self.elevenlabs_api_key:
            return Action(
                action_name="openAppConfig",
                arguments=["application", "elevenlabs_api"],
                label="Set API Key",
                icon="mdi-key",
            )
        if not self.elevenlabs_model:
            return Action(
                action_name="openAgentSettings",
                arguments=["tts", "elevenlabs"],
                label="Set Model",
                icon="mdi-brain",
            )
        return None

    @property
    def elevenlabs_max_generation_length(self) -> int:
        return 1024

    @property
    def elevenlabs_model(self) -> str:
        return self.actions["elevenlabs"].config["model"].value

    @property
    def elevenlabs_model_choices(self) -> list[str]:
        return [
            {"label": choice["label"], "value": choice["value"]}
            for choice in self.actions["elevenlabs"].config["model"].choices
        ]

    @property
    def elevenlabs_info(self) -> str:
        return ELEVENLABS_INFO

    @property
    def elevenlabs_agent_details(self) -> dict:
        details = {}

        if not self.elevenlabs_configured:
            details["elevenlabs_api_key"] = AgentDetail(
                icon="mdi-key",
                value="ElevenLabs API key not set",
                description="ElevenLabs API key not set. You can set it in the Talemate Settings -> Application -> ElevenLabs",
                color="error",
            ).model_dump()
        else:
            details["elevenlabs_model"] = AgentDetail(
                icon="mdi-brain",
                value=self.elevenlabs_model,
                description="The model to use for ElevenLabs",
            ).model_dump()

        return details

    @property
    def elevenlabs_api_key(self) -> str:
        return self.config.elevenlabs.api_key

    async def elevenlabs_generate(
        self, chunk: Chunk, context: GenerationContext, chunk_size: int = 1024
    ) -> Union[bytes, None]:
        api_key = self.elevenlabs_api_key
        if not api_key:
            return

        # Lazy import heavy dependencies only when needed
        _import_heavy_deps()

        client = AsyncElevenLabs(api_key=api_key)

        try:
            response_async_iter = client.text_to_speech.convert(
                text=chunk.cleaned_text,
                voice_id=chunk.voice.provider_id,
                model_id=chunk.model or self.elevenlabs_model,
            )

            bytes_io = io.BytesIO()

            async for _chunk_bytes in response_async_iter:
                if _chunk_bytes:
                    bytes_io.write(_chunk_bytes)

            return bytes_io.getvalue()

        except ApiError as e:
            # Emit detailed status message to the frontend UI
            error_message = "ElevenLabs API Error"
            try:
                # The ElevenLabs ApiError often contains a JSON body with details
                detail = e.body.get("detail", {}) if hasattr(e, "body") else {}
                error_message = detail.get("message", str(e)) or str(e)
            except Exception:
                error_message = str(e)

            log.error("ElevenLabs API error", error=str(e))
            emit(
                "status",
                message=f"ElevenLabs TTS: {error_message}",
                status="error",
            )
            raise e

        except Exception as e:
            # Catch-all to ensure the app does not crash on unexpected errors
            log.error("ElevenLabs TTS generation error", error=str(e))
            emit(
                "status",
                message=f"ElevenLabs TTS: Unexpected error - {str(e)}",
                status="error",
            )
            raise e

src/talemate/agents/tts/f5tts.py (new file, 436 lines)
@@ -0,0 +1,436 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic
import re

import torch


# Lazy imports for heavy dependencies
def _import_heavy_deps():
    global F5TTS
    from f5_tts.api import F5TTS


CUDA_AVAILABLE = torch.cuda.is_available()

from talemate.agents.base import (
    AgentAction,
    AgentActionConfig,
    AgentDetail,
)
from talemate.ux.schema import Field

from .schema import Voice, Chunk, GenerationContext, VoiceProvider, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
from .providers import register, provider
from .util import voice_is_talemate_asset

log = structlog.get_logger("talemate.agents.tts.f5tts")

REF_TEXT = "You awaken aboard your ship, the Starlight Nomad. A soft hum resonates throughout the vessel indicating its systems are online."

add_default_voices(
    [
        Voice(
            label="Adam",
            provider="f5tts",
            provider_id="tts/voice/f5tts/adam.wav",
            tags=["male", "calm", "mature", "deep", "thoughtful"],
            parameters={
                "speed": 1.05,
                "ref_text": REF_TEXT,
            },
        ),
        Voice(
            label="Bradford",
            provider="f5tts",
            provider_id="tts/voice/f5tts/bradford.wav",
            tags=["male", "calm", "mature"],
            parameters={
                "speed": 1,
                "ref_text": REF_TEXT,
            },
        ),
        Voice(
            label="Julia",
            provider="f5tts",
            provider_id="tts/voice/f5tts/julia.wav",
            tags=["female", "calm", "mature"],
            parameters={
                "speed": 1.1,
                "ref_text": REF_TEXT,
            },
        ),
        Voice(
            label="Lisa",
            provider="f5tts",
            provider_id="tts/voice/f5tts/lisa.wav",
            tags=["female", "young", "energetic"],
            parameters={
                "speed": 1.2,
                "ref_text": REF_TEXT,
            },
        ),
        Voice(
            label="Eva",
            provider="f5tts",
            provider_id="tts/voice/f5tts/eva.wav",
            tags=["female", "mature", "thoughtful"],
            parameters={
                "speed": 1.15,
                "ref_text": REF_TEXT,
            },
        ),
        Voice(
            label="Zoe",
            provider="f5tts",
            provider_id="tts/voice/f5tts/zoe.wav",
            tags=["female"],
            parameters={
                "speed": 1.15,
                "ref_text": REF_TEXT,
            },
        ),
        Voice(
            label="William",
            provider="f5tts",
            provider_id="tts/voice/f5tts/william.wav",
            tags=["male", "young"],
            parameters={
                "speed": 1.15,
                "ref_text": REF_TEXT,
            },
        ),
    ]
)

F5TTS_INFO = """
F5-TTS is a local text-to-speech model.

The voice id is the path to the reference *.wav* file that contains a short
voice sample (≈3-5 s). You can place new samples in the
`tts/voice/f5tts` directory of your Talemate workspace or supply an absolute
path that is accessible to the backend.

The first generation will download the model weights (~1.3 GB) if they are not
cached yet.
"""


@register()
class F5TTSProvider(VoiceProvider):
    """Metadata for the F5-TTS provider."""

    name: str = "f5tts"
    allow_model_override: bool = False
    allow_file_upload: bool = True
    upload_file_types: list[str] = ["audio/wav"]

    # Provider-specific tunable parameters that can be stored per-voice
    voice_parameters: list[Field] = [
        Field(
            name="speed",
            type="number",
            label="Speed",
            value=1.0,
            min=0.25,
            max=2.0,
            step=0.05,
            description="If the speech is too fast or slow, adjust this value. 1.0 is normal speed.",
        ),
        Field(
            name="ref_text",
            type="text",
            label="Reference text",
            value="",
            description="Text that matches the reference audio sample (improves synthesis quality).",
            required=True,
        ),
        Field(
            name="cfg_strength",
            type="number",
            label="CFG Strength",
            value=2.0,
            min=0.1,
            step=0.1,
            max=10.0,
            description="CFG strength for the model.",
        ),
    ]


class F5TTSInstance(pydantic.BaseModel):
    """Holds a single F5-TTS model instance (lazy-initialised)."""

    model: "F5TTS"  # Forward reference for lazy loading
    model_name: str

    class Config:
        arbitrary_types_allowed = True


class F5TTSMixin:
    """F5-TTS agent mixin for local text-to-speech generation."""

    # ---------------------------------------------------------------------
    # UI integration / configuration helpers
    # ---------------------------------------------------------------------

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        """Expose the F5-TTS backend in the global TTS agent settings."""

        actions["_config"].config["apis"].choices.append(
            {
                "value": "f5tts",
                "label": "F5-TTS (Local)",
                "help": "F5-TTS is a local text-to-speech model.",
            }
        )

        actions["f5tts"] = AgentAction(
            enabled=True,
            container=True,
            icon="mdi-server-outline",
            label="F5-TTS",
            description="F5-TTS is a local text-to-speech model.",
            config={
                "device": AgentActionConfig(
                    type="text",
                    value="cuda" if CUDA_AVAILABLE else "cpu",
                    label="Device",
                    choices=[
                        {"value": "cpu", "label": "CPU"},
                        {"value": "cuda", "label": "CUDA"},
                    ],
                    description="Device to use for TTS",
                ),
                "model_name": AgentActionConfig(
                    type="text",
                    value="F5TTS_v1_Base",
                    label="Model",
                    description="Model will be downloaded on first use.",
                    choices=[
                        {"value": "E2TTS_Base", "label": "E2TTS_Base"},
                        {"value": "F5TTS_Base", "label": "F5TTS_Base"},
                        {"value": "F5TTS_v1_Base", "label": "F5TTS_v1_Base"},
                    ],
                ),
                "nfe_step": AgentActionConfig(
                    type="number",
                    label="NFE Step",
                    value=32,
                    min=32,
                    step=16,
                    max=64,
                    description="Number of diffusion steps.",
                ),
                "chunk_size": AgentActionConfig(
                    type="number",
                    min=0,
                    step=32,
                    max=1024,
                    value=64,
                    label="Chunk size",
                    note=INFO_CHUNK_SIZE,
                ),
                "replace_exclamation_marks": AgentActionConfig(
                    type="bool",
                    value=True,
                    label="Replace exclamation marks",
                    description="Some models tend to over-emphasise exclamation marks, so this is a workaround to make the speech more natural.",
                ),
            },
        )

        # No additional per-API settings (model/device) required for F5-TTS.
        return actions

    # ------------------------------------------------------------------
    # Convenience properties consumed by the core TTS agent
    # ------------------------------------------------------------------

    @property
    def f5tts_configured(self) -> bool:
        # Local backend - always available once the model weights are present.
        return True

    @property
    def f5tts_device(self) -> str:
        return self.actions["f5tts"].config["device"].value

    @property
    def f5tts_chunk_size(self) -> int:
        return self.actions["f5tts"].config["chunk_size"].value

    @property
    def f5tts_replace_exclamation_marks(self) -> bool:
        return self.actions["f5tts"].config["replace_exclamation_marks"].value

    @property
    def f5tts_model_name(self) -> str:
        return self.actions["f5tts"].config["model_name"].value

    @property
    def f5tts_nfe_step(self) -> int:
        return self.actions["f5tts"].config["nfe_step"].value

    @property
    def f5tts_max_generation_length(self) -> int:
        return 1024

    @property
    def f5tts_info(self) -> str:
        return F5TTS_INFO

    @property
    def f5tts_agent_details(self) -> dict:
        if not self.f5tts_configured:
            return {}
        details = {}

        device = self.f5tts_device
        model_name = self.f5tts_model_name

        details["f5tts_device"] = AgentDetail(
            icon="mdi-memory",
            value=f"{model_name}@{device}",
            description="The model and device to use for F5-TTS",
        ).model_dump()

        return details

    # ------------------------------------------------------------------
    # Voice housekeeping helpers
    # ------------------------------------------------------------------

    def f5tts_delete_voice(self, voice: Voice):
        """Delete *voice* reference file if it is inside the Talemate workspace."""

        is_talemate_asset, resolved = voice_is_talemate_asset(
            voice, provider(voice.provider)
        )

        log.debug(
            "f5tts_delete_voice",
            voice_id=voice.provider_id,
            is_talemate_asset=is_talemate_asset,
            resolved=resolved,
        )

        if not is_talemate_asset:
            return

        try:
            if resolved.exists() and resolved.is_file():
                resolved.unlink()
                log.debug("Deleted F5-TTS voice file", path=str(resolved))
        except Exception as e:
            log.error("Failed to delete F5-TTS voice file", error=e, path=str(resolved))

    # ------------------------------------------------------------------
    # Generation helpers
    # ------------------------------------------------------------------

    def _f5tts_generate_file(
        self,
        model: "F5TTS",
        chunk: Chunk,
        voice: Voice,
        output_path: str,
    ) -> str:
        """Blocking generation helper executed in a thread-pool."""

        wav, sr, _ = model.infer(
            ref_file=voice.provider_id,
            ref_text=voice.parameters.get("ref_text", ""),
            gen_text=chunk.cleaned_text,
            file_wave=output_path,
            speed=voice.parameters.get("speed", 1.0),
            cfg_strength=voice.parameters.get("cfg_strength", 2.0),
            nfe_step=self.f5tts_nfe_step,
        )

        # Some versions of F5-TTS don't write *file_wave*. Drop-in save as fallback.
        # if not os.path.exists(output_path):
        #     ta.save(output_path, wav, sr)

        return output_path

    async def f5tts_generate(
        self, chunk: Chunk, context: GenerationContext
    ) -> bytes | None:
        """Asynchronously synthesise *chunk* using F5-TTS."""

        # Lazy initialisation & caching across invocations
        f5tts_instance: "F5TTSInstance | None" = getattr(self, "f5tts_instance", None)

        device = self.f5tts_device
        model_name: str = self.f5tts_model_name

        reload_model = (
            f5tts_instance is None
            or f5tts_instance.model.device != device
            or f5tts_instance.model_name != model_name
        )

        if reload_model:
            if f5tts_instance is not None:
                log.debug(
                    "Reloading F5-TTS backend", device=device, model_name=model_name
                )
            else:
                log.debug(
                    "Initialising F5-TTS backend", device=device, model_name=model_name
                )

            # Lazy import heavy dependencies only when needed
            _import_heavy_deps()

            f5tts_instance = F5TTSInstance(
                model=F5TTS(device=device, model=model_name),
                model_name=model_name,
            )
            self.f5tts_instance = f5tts_instance

        model: "F5TTS" = f5tts_instance.model

        loop = asyncio.get_event_loop()

        voice = chunk.voice

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            # Delegate blocking work to the default ThreadPoolExecutor
            await loop.run_in_executor(
                None,
                functools.partial(
                    self._f5tts_generate_file, model, chunk, voice, file_path
                ),
            )

            # Read the generated WAV and return bytes for websocket playback
            with open(file_path, "rb") as f:
                return f.read()

    async def f5tts_prepare_chunk(self, chunk: Chunk):
        text = chunk.text[0]

        # f5-tts seems to have issues with ellipses
        text = text.replace("…", "...").replace("...", ".")

        # hyphenated words also seem to be a problem
        text = re.sub(r"(\w)-(\w)", r"\1 \2", text)
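        # (net effect of the two replaces above: any ellipsis, unicode or
        # three-dot, ends up as a single period before synthesis)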

        if self.f5tts_replace_exclamation_marks:
            text = text.replace("!", ".")

        chunk.text[0] = text

        return chunk

src/talemate/agents/tts/google.py (new file, 319 lines)
@@ -0,0 +1,319 @@
import io
import wave
from typing import Union, Optional

import structlog
from google import genai
from google.genai import types
from talemate.ux.schema import Action
from talemate.agents.base import (
    AgentAction,
    AgentActionConfig,
    AgentDetail,
)
from .schema import Voice, VoiceLibrary, Chunk, GenerationContext, INFO_CHUNK_SIZE
from .voice_library import add_default_voices

log = structlog.get_logger("talemate.agents.tts.google")

GOOGLE_INFO = """
Google Gemini TTS is a cloud-based text to speech model.

A list of available voices can be found at [https://ai.google.dev/gemini-api/docs/speech-generation](https://ai.google.dev/gemini-api/docs/speech-generation).
"""

add_default_voices(
    [
        Voice(label="Zephyr", provider="google", provider_id="Zephyr", tags=["female"]),
        Voice(label="Puck", provider="google", provider_id="Puck", tags=["male"]),
        Voice(label="Charon", provider="google", provider_id="Charon", tags=["male"]),
        Voice(label="Kore", provider="google", provider_id="Kore", tags=["female"]),
        Voice(label="Fenrir", provider="google", provider_id="Fenrir", tags=["male"]),
        Voice(label="Leda", provider="google", provider_id="Leda", tags=["female"]),
        Voice(label="Orus", provider="google", provider_id="Orus", tags=["male"]),
        Voice(label="Aoede", provider="google", provider_id="Aoede", tags=["female"]),
        Voice(label="Callirrhoe", provider="google", provider_id="Callirrhoe", tags=["female"]),
        Voice(label="Autonoe", provider="google", provider_id="Autonoe", tags=["female"]),
        Voice(label="Enceladus", provider="google", provider_id="Enceladus", tags=["male", "deep"]),
        Voice(label="Iapetus", provider="google", provider_id="Iapetus", tags=["male"]),
        Voice(label="Umbriel", provider="google", provider_id="Umbriel", tags=["male"]),
        Voice(label="Algieba", provider="google", provider_id="Algieba", tags=["male", "deep"]),
        Voice(label="Despina", provider="google", provider_id="Despina", tags=["female", "young"]),
        Voice(label="Erinome", provider="google", provider_id="Erinome", tags=["female"]),
        Voice(label="Algenib", provider="google", provider_id="Algenib", tags=["male"]),
        Voice(label="Rasalgethi", provider="google", provider_id="Rasalgethi", tags=["male", "neutral"]),
        Voice(label="Laomedeia", provider="google", provider_id="Laomedeia", tags=["female"]),
        Voice(label="Achernar", provider="google", provider_id="Achernar", tags=["female", "young"]),
        Voice(label="Alnilam", provider="google", provider_id="Alnilam", tags=["male"]),
        Voice(label="Schedar", provider="google", provider_id="Schedar", tags=["male"]),
        Voice(label="Gacrux", provider="google", provider_id="Gacrux", tags=["female", "mature"]),
        Voice(label="Pulcherrima", provider="google", provider_id="Pulcherrima", tags=["female", "mature"]),
        Voice(label="Achird", provider="google", provider_id="Achird", tags=["male", "energetic"]),
        Voice(label="Zubenelgenubi", provider="google", provider_id="Zubenelgenubi", tags=["male"]),
        Voice(label="Vindemiatrix", provider="google", provider_id="Vindemiatrix", tags=["female", "mature"]),
        Voice(label="Sadachbia", provider="google", provider_id="Sadachbia", tags=["male"]),
        Voice(label="Sadaltager", provider="google", provider_id="Sadaltager", tags=["male"]),
        Voice(label="Sulafat", provider="google", provider_id="Sulafat", tags=["female", "young"]),
    ]
)


class GoogleMixin:
    """Google Gemini TTS mixin (Flash/Pro preview models)."""

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        actions["_config"].config["apis"].choices.append(
            {
                "value": "google",
                "label": "Google Gemini",
                "help": "Google Gemini is a cloud-based text to speech model that uses the Google Gemini API. (API key required)",
            }
        )

        actions["google"] = AgentAction(
            enabled=True,
            container=True,
            icon="mdi-server-outline",
            label="Google Gemini",
            description="Google Gemini is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> Google)",
            config={
                "model": AgentActionConfig(
                    type="text",
                    value="gemini-2.5-flash-preview-tts",
                    choices=[
                        {
                            "value": "gemini-2.5-flash-preview-tts",
                            "label": "Gemini 2.5 Flash TTS (Preview)",
                        },
                        {
                            "value": "gemini-2.5-pro-preview-tts",
                            "label": "Gemini 2.5 Pro TTS (Preview)",
                        },
                    ],
                    label="Model",
                    description="Google TTS model to use",
                ),
                "chunk_size": AgentActionConfig(
                    type="number",
                    min=0,
                    step=64,
                    max=2048,
                    value=0,
                    label="Chunk size",
                    note=INFO_CHUNK_SIZE,
                ),
            },
        )

        return actions

    @classmethod
    def add_voices(cls, voices: dict[str, VoiceLibrary]):
        voices["google"] = VoiceLibrary(api="google")

    @property
    def google_configured(self) -> bool:
        return bool(self.google_api_key) and bool(self.google_model)

    @property
    def google_chunk_size(self) -> int:
        return self.actions["google"].config["chunk_size"].value

    @property
    def google_not_configured_reason(self) -> str | None:
        if not self.google_api_key:
            return "Google API key not set"
        if not self.google_model:
            return "Google model not set"
        return None

    @property
    def google_not_configured_action(self) -> Action | None:
        if not self.google_api_key:
            return Action(
                action_name="openAppConfig",
                arguments=["application", "google_api"],
                label="Set API Key",
                icon="mdi-key",
            )
        if not self.google_model:
            return Action(
                action_name="openAgentSettings",
                arguments=["tts", "google"],
                label="Set Model",
                icon="mdi-brain",
            )
        return None

    @property
    def google_info(self) -> str:
        return GOOGLE_INFO

    @property
    def google_max_generation_length(self) -> int:
        return 1024

    @property
    def google_model(self) -> str:
        return self.actions["google"].config["model"].value

    @property
    def google_model_choices(self) -> list[dict]:
        return [
            {"label": choice["label"], "value": choice["value"]}
            for choice in self.actions["google"].config["model"].choices
        ]

    @property
    def google_api_key(self) -> Optional[str]:
        return self.config.google.api_key

    @property
    def google_agent_details(self) -> dict:
        details = {}

        if not self.google_configured:
            details["google_api_key"] = AgentDetail(
                icon="mdi-key",
                value="Google API key not set",
                description="Google API key not set. You can set it in the Talemate Settings -> Application -> Google",
                color="error",
            ).model_dump()
        else:
            details["google_model"] = AgentDetail(
                icon="mdi-brain",
                value=self.google_model,
                description="The model to use for Google",
            ).model_dump()

        return details

    def _make_google_client(self) -> genai.Client:
        """Return a fresh genai.Client so updated creds propagate immediately."""
        return genai.Client(api_key=self.google_api_key or None)

    async def google_generate(
        self,
        chunk: Chunk,
        context: GenerationContext,
        chunk_size: int = 1024,  # kept for signature parity
    ) -> Union[bytes, None]:
        """Generate audio and wrap raw PCM into a playable WAV container."""

        voice_name = chunk.voice.provider_id
        client = self._make_google_client()

        try:
            response = await client.aio.models.generate_content(
                model=chunk.model or self.google_model,
                contents=chunk.cleaned_text,
                config=types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=types.SpeechConfig(
                        voice_config=types.VoiceConfig(
                            prebuilt_voice_config=types.PrebuiltVoiceConfig(
                                voice_name=voice_name,
                            )
                        )
                    ),
                ),
            )

            # Extract raw 24 kHz 16-bit PCM (mono) bytes from the first candidate
            part = response.candidates[0].content.parts[0].inline_data
            if not part or not part.data:
                return None
            pcm_bytes: bytes = part.data

            # Wrap into a WAV container that browsers can decode
            wav_io = io.BytesIO()
            with wave.open(wav_io, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)  # 16-bit
                wf.setframerate(24000)  # Hz
                wf.writeframes(pcm_bytes)
            return wav_io.getvalue()

        except Exception as e:
            import traceback

            traceback.print_exc()
            log.error("google_generate failed", error=str(e))
            return None
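As an aside, the PCM-to-WAV wrapping step that google_generate performs can be isolated into a small stdlib-only helper. A sketch, assuming the same 24 kHz mono 16-bit format the Gemini TTS preview models return; the helper name pcm_to_wav is hypothetical:

import io
import wave


def pcm_to_wav(pcm_bytes: bytes, rate: int = 24000) -> bytes:
    """Wrap raw mono 16-bit PCM bytes in a WAV container."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)  # mono
        wf.setsampwidth(2)  # 16-bit samples
        wf.setframerate(rate)
        wf.writeframes(pcm_bytes)
    return buf.getvalue()


wav_bytes = pcm_to_wav(b"\x00\x00" * 24000)  # one second of silence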
src/talemate/agents/tts/kokoro.py (new file, 324 additions)
@@ -0,0 +1,324 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic
import traceback
from pathlib import Path

import torch
import soundfile as sf
from kokoro import KPipeline

from talemate.agents.base import (
    AgentAction,
    AgentActionConfig,
)
from .schema import (
    Voice,
    Chunk,
    GenerationContext,
    VoiceMixer,
    VoiceProvider,
    INFO_CHUNK_SIZE,
)
from .providers import register
from .voice_library import add_default_voices

log = structlog.get_logger("talemate.agents.tts.kokoro")

CUSTOM_VOICE_STORAGE = (
    Path(__file__).parent.parent.parent.parent.parent / "tts" / "voice" / "kokoro"
)

add_default_voices(
    [
        Voice(label="Alloy", provider="kokoro", provider_id="af_alloy", tags=["female"]),
        Voice(label="Aoede", provider="kokoro", provider_id="af_aoede", tags=["female"]),
        Voice(label="Bella", provider="kokoro", provider_id="af_bella", tags=["female"]),
        Voice(label="Heart", provider="kokoro", provider_id="af_heart", tags=["female"]),
        Voice(label="Jessica", provider="kokoro", provider_id="af_jessica", tags=["female"]),
        Voice(label="Kore", provider="kokoro", provider_id="af_kore", tags=["female"]),
        Voice(label="Nicole", provider="kokoro", provider_id="af_nicole", tags=["female"]),
        Voice(label="Nova", provider="kokoro", provider_id="af_nova", tags=["female"]),
        Voice(label="River", provider="kokoro", provider_id="af_river", tags=["female"]),
        Voice(label="Sarah", provider="kokoro", provider_id="af_sarah", tags=["female"]),
        Voice(label="Sky", provider="kokoro", provider_id="af_sky", tags=["female"]),
        Voice(label="Adam", provider="kokoro", provider_id="am_adam", tags=["male"]),
        Voice(label="Echo", provider="kokoro", provider_id="am_echo", tags=["male"]),
        Voice(label="Eric", provider="kokoro", provider_id="am_eric", tags=["male"]),
        Voice(label="Fenrir", provider="kokoro", provider_id="am_fenrir", tags=["male"]),
        Voice(label="Liam", provider="kokoro", provider_id="am_liam", tags=["male"]),
        Voice(label="Michael", provider="kokoro", provider_id="am_michael", tags=["male"]),
        Voice(label="Onyx", provider="kokoro", provider_id="am_onyx", tags=["male"]),
        Voice(label="Puck", provider="kokoro", provider_id="am_puck", tags=["male"]),
        Voice(label="Santa", provider="kokoro", provider_id="am_santa", tags=["male"]),
        Voice(label="Alice", provider="kokoro", provider_id="bf_alice", tags=["female"]),
        Voice(label="Emma", provider="kokoro", provider_id="bf_emma", tags=["female"]),
        Voice(label="Isabella", provider="kokoro", provider_id="bf_isabella", tags=["female"]),
        Voice(label="Lily", provider="kokoro", provider_id="bf_lily", tags=["female"]),
        Voice(label="Daniel", provider="kokoro", provider_id="bm_daniel", tags=["male"]),
        Voice(label="Fable", provider="kokoro", provider_id="bm_fable", tags=["male"]),
        Voice(label="George", provider="kokoro", provider_id="bm_george", tags=["male"]),
        Voice(label="Lewis", provider="kokoro", provider_id="bm_lewis", tags=["male"]),
    ]
)

KOKORO_INFO = """
Kokoro is a local text to speech model.

**WILL DOWNLOAD**: Voices will be downloaded on first use, so the first generation will take longer to complete.
"""


@register()
class KokoroProvider(VoiceProvider):
    name: str = "kokoro"
    allow_model_override: bool = False


class KokoroInstance(pydantic.BaseModel):
    pipeline: "KPipeline"  # string annotation; KPipeline is an arbitrary (non-pydantic) type

    class Config:
        arbitrary_types_allowed = True


class KokoroMixin:
    """
    Kokoro agent mixin for local text to speech.
    """

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        actions["_config"].config["apis"].choices.append(
            {
                "value": "kokoro",
                "label": "Kokoro (Local)",
                "help": "Kokoro is a local text to speech model.",
            }
        )

        actions["kokoro"] = AgentAction(
            enabled=True,
            container=True,
            icon="mdi-server-outline",
            label="Kokoro",
            description="Kokoro is a local text to speech model.",
            config={
                "chunk_size": AgentActionConfig(
                    type="number",
                    min=0,
                    step=64,
                    max=2048,
                    value=512,
                    label="Chunk size",
                    note=INFO_CHUNK_SIZE,
                ),
            },
        )
        return actions

    @property
    def kokoro_configured(self) -> bool:
        return True

    @property
    def kokoro_chunk_size(self) -> int:
        return self.actions["kokoro"].config["chunk_size"].value

    @property
    def kokoro_max_generation_length(self) -> int:
        return 256

    @property
    def kokoro_agent_details(self) -> dict:
        return {}

    @property
    def kokoro_supports_mixing(self) -> bool:
        return True

    @property
    def kokoro_info(self) -> str:
        return KOKORO_INFO

    def kokoro_delete_voice(self, voice_id: str) -> None:
        """
        If the voice_id is a file in the CUSTOM_VOICE_STORAGE directory, delete it.
        """

        # A deletable voice id is a relative or absolute path to a file in
        # the CUSTOM_VOICE_STORAGE directory, so verify containment first.
        voice_path = Path(voice_id).resolve()
        log.debug(
            "Kokoro - Checking if voice id is deletable",
            voice_id=voice_id,
            exists=voice_path.exists(),
            parent=voice_path.parent,
            is_custom_voice_storage=voice_path.parent == CUSTOM_VOICE_STORAGE,
        )
        if voice_path.exists() and voice_path.parent == CUSTOM_VOICE_STORAGE:
            log.debug("Kokoro - Deleting voice file", voice_id=voice_id)
            try:
                voice_path.unlink()
            except FileNotFoundError:
                pass

    def _kokoro_mix(self, mixer: VoiceMixer) -> "torch.Tensor":
        pipeline = KPipeline(lang_code="a")

        packs = [
            {
                "voice_tensor": pipeline.load_single_voice(voice.id),
                "weight": voice.weight,
            }
            for voice in mixer.voices
        ]

        mixed_voice = None
        for pack in packs:
            if mixed_voice is None:
                mixed_voice = pack["voice_tensor"] * pack["weight"]
            else:
                mixed_voice += pack["voice_tensor"] * pack["weight"]

        # TODO: ensure weights sum to 1

        return mixed_voice

    async def kokoro_test_mix(self, mixer: VoiceMixer):
        """Test a mixed voice by generating a sample."""
        mixed_voice_tensor = self._kokoro_mix(mixer)

        loop = asyncio.get_event_loop()

        pipeline = KPipeline(lang_code="a")

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            await loop.run_in_executor(
                None,
                functools.partial(
                    self._kokoro_generate,
                    pipeline,
                    "This is a test of the mixed voice.",
                    mixed_voice_tensor,
                    file_path,
                ),
            )

            # Read and play the audio
            with open(file_path, "rb") as f:
                audio_data = f.read()
            self.play_audio(audio_data)

    async def kokoro_save_mix(self, voice_id: str, mixer: VoiceMixer) -> Path:
        """Save a mixed voice tensor to disk."""
        # Ensure the directory exists
        CUSTOM_VOICE_STORAGE.mkdir(parents=True, exist_ok=True)

        save_to_path = CUSTOM_VOICE_STORAGE / f"{voice_id}.pt"
        voice_tensor = self._kokoro_mix(mixer)
        torch.save(voice_tensor, save_to_path)
        return save_to_path

    def _kokoro_generate(
        self,
        pipeline: "KPipeline",
        text: str,
        voice: "str | torch.Tensor",
        file_path: str,
    ) -> None:
        """Generate audio from text using the given voice."""
        try:
            generator = pipeline(text, voice=voice)
            # Collect every generated segment before writing; writing inside
            # the loop would overwrite the file and keep only the last segment.
            segments = [audio for _gs, _ps, audio in generator]
            sf.write(file_path, torch.cat([torch.as_tensor(s) for s in segments]), 24000)
        except Exception:
            traceback.print_exc()
            raise

    async def kokoro_generate(
        self, chunk: Chunk, context: GenerationContext
    ) -> bytes | None:
        kokoro_instance = getattr(self, "kokoro_instance", None)

        if not kokoro_instance:
            log.debug("kokoro - initializing tts instance")
            self.kokoro_instance = KokoroInstance(
                # lang_code "a" = American English
                # TODO: allow config of language???
                pipeline=KPipeline(lang_code="a")
            )

        pipeline = self.kokoro_instance.pipeline

        loop = asyncio.get_event_loop()

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            await loop.run_in_executor(
                None,
                functools.partial(
                    self._kokoro_generate,
                    pipeline,
                    chunk.cleaned_text,
                    chunk.voice.provider_id,
                    file_path,
                ),
            )

            with open(file_path, "rb") as f:
                return f.read()
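The mixer above boils down to a weighted sum of voice embedding tensors, which is why the weights are expected to sum to 1.0. A sketch with made-up tensor shapes (Kokoro's actual voice pack shape is not asserted here):

import torch

a = torch.randn(510, 1, 256)  # stand-in for a loaded voice pack
b = torch.randn(510, 1, 256)
mixed = 0.7 * a + 0.3 * b  # weights sum to 1.0, preserving overall scale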
src/talemate/agents/tts/nodes.py (new file, 165 additions)
@@ -0,0 +1,165 @@
import structlog
from typing import ClassVar
from talemate.game.engine.nodes.core import (
    GraphState,
    PropertyField,
    TYPE_CHOICES,
    UNRESOLVED,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
from talemate.agents.tts.schema import Voice, VoiceLibrary

TYPE_CHOICES.extend(
    [
        "tts/voice",
    ]
)

log = structlog.get_logger("talemate.game.engine.nodes.agents.tts")


@register("agents/tts/Settings")
class TTSAgentSettings(AgentSettingsNode):
    """
    Base node to render TTS agent settings.
    """

    _agent_name: ClassVar[str] = "tts"

    def __init__(self, title="TTS Agent Settings", **kwargs):
        super().__init__(title=title, **kwargs)


@register("agents/tts/GetVoice")
class GetVoice(AgentNode):
    """
    Gets a voice from the TTS agent's voice library.
    """

    _agent_name: ClassVar[str] = "tts"

    class Fields:
        voice_id = PropertyField(
            name="voice_id",
            type="str",
            description="The ID of the voice to get",
            default=UNRESOLVED,
        )

    def __init__(self, title="Get Voice", **kwargs):
        super().__init__(title=title, **kwargs)

    @property
    def voice_library(self) -> VoiceLibrary:
        return self.agent.voice_library

    def setup(self):
        self.add_input("voice_id", socket_type="str", optional=True)
        self.set_property("voice_id", UNRESOLVED)

        self.add_output("voice", socket_type="tts/voice")

    async def run(self, state: GraphState):
        voice_id = self.require_input("voice_id")

        voice = self.voice_library.get_voice(voice_id)

        self.set_output_values({"voice": voice})


@register("agents/tts/GetNarratorVoice")
class GetNarratorVoice(AgentNode):
    """
    Gets the narrator voice from the TTS agent.
    """

    _agent_name: ClassVar[str] = "tts"

    def __init__(self, title="Get Narrator Voice", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_output("voice", socket_type="tts/voice")

    async def run(self, state: GraphState):
        voice = self.agent.narrator_voice

        self.set_output_values({"voice": voice})


@register("agents/tts/UnpackVoice")
class UnpackVoice(AgentNode):
    """
    Unpacks a voice into its individual fields.
    """

    _agent_name: ClassVar[str] = "tts"

    def __init__(self, title="Unpack Voice", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("voice", socket_type="tts/voice")
        self.add_output("voice", socket_type="tts/voice")
        self.add_output("label", socket_type="str")
        self.add_output("provider", socket_type="str")
        self.add_output("provider_id", socket_type="str")
        self.add_output("provider_model", socket_type="str")
        self.add_output("tags", socket_type="list")
        self.add_output("parameters", socket_type="dict")
        self.add_output("is_scene_asset", socket_type="bool")

    async def run(self, state: GraphState):
        voice: Voice = self.require_input("voice")

        self.set_output_values(
            {
                "voice": voice,
                **voice.model_dump(),
            }
        )


@register("agents/tts/Generate")
class Generate(AgentNode):
    """
    Generates speech through the TTS agent.
    """

    _agent_name: ClassVar[str] = "tts"

    class Fields:
        text = PropertyField(
            name="text",
            type="text",
            description="The text to generate",
            default=UNRESOLVED,
        )

    def __init__(self, title="Generate TTS", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("state")
        self.add_input("text", socket_type="text", optional=True)
        self.add_input("voice", socket_type="tts/voice", optional=True)
        self.add_input("character", socket_type="character", optional=True)
        self.set_property("text", UNRESOLVED)
        self.add_output("state")

    async def run(self, state: GraphState):
        text = self.require_input("text")
        voice = self.normalized_input_value("voice")
        character = self.normalized_input_value("character")

        if not voice and not character:
            raise ValueError("Either voice or character must be provided")

        await self.agent.generate(
            text=text,
            character=character,
            force_voice=voice,
        )

        self.set_output_values({"state": state})
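UnpackVoice works by spreading Voice.model_dump() into the output values, so the output socket names must match the Voice field names. A trimmed sketch with a stand-in model rather than the real schema class:

import pydantic


class VoiceStandIn(pydantic.BaseModel):
    label: str
    provider: str
    provider_id: str


v = VoiceStandIn(label="Nova", provider="openai", provider_id="nova")
outputs = {"voice": v, **v.model_dump()}
print(sorted(outputs))  # ['label', 'provider', 'provider_id', 'voice']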
src/talemate/agents/tts/openai.py (new file, 230 additions)
@@ -0,0 +1,230 @@
import io
from typing import Union

import structlog
from openai import AsyncOpenAI
from talemate.ux.schema import Action
from talemate.agents.base import AgentAction, AgentActionConfig, AgentDetail
from .schema import Voice, VoiceLibrary, Chunk, GenerationContext, INFO_CHUNK_SIZE
from .voice_library import add_default_voices

log = structlog.get_logger("talemate.agents.tts.openai")

OPENAI_INFO = """
OpenAI TTS is a cloud-based text to speech model.

A list of available voices can be found at [https://platform.openai.com/docs/guides/text-to-speech#voice-options](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
"""

add_default_voices(
    [
        Voice(label="Alloy", provider="openai", provider_id="alloy", tags=["neutral", "female"]),
        Voice(label="Ash", provider="openai", provider_id="ash", tags=["male"]),
        Voice(label="Ballad", provider="openai", provider_id="ballad", tags=["male", "energetic"]),
        Voice(label="Coral", provider="openai", provider_id="coral", tags=["female", "energetic"]),
        Voice(label="Echo", provider="openai", provider_id="echo", tags=["male", "neutral"]),
        Voice(label="Fable", provider="openai", provider_id="fable", tags=["neutral", "feminine"]),
        Voice(label="Onyx", provider="openai", provider_id="onyx", tags=["male"]),
        Voice(label="Nova", provider="openai", provider_id="nova", tags=["female"]),
        Voice(label="Sage", provider="openai", provider_id="sage", tags=["female"]),
        Voice(label="Shimmer", provider="openai", provider_id="shimmer", tags=["female"]),
    ]
)


class OpenAIMixin:
    """
    OpenAI TTS agent mixin for cloud-based text to speech.
    """

    @classmethod
    def add_actions(cls, actions: dict[str, AgentAction]):
        actions["_config"].config["apis"].choices.append(
            {
                "value": "openai",
                "label": "OpenAI",
                "help": "OpenAI is a cloud-based text to speech model that uses the OpenAI API. (API key required)",
            }
        )

        actions["openai"] = AgentAction(
            enabled=True,
            container=True,
            icon="mdi-server-outline",
            label="OpenAI",
            description="OpenAI TTS is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> OpenAI)",
            config={
                "model": AgentActionConfig(
                    type="text",
                    value="gpt-4o-mini-tts",
                    choices=[
                        {"value": "gpt-4o-mini-tts", "label": "GPT-4o Mini TTS"},
                        {"value": "tts-1", "label": "TTS 1"},
                        {"value": "tts-1-hd", "label": "TTS 1 HD"},
                    ],
                    label="Model",
                    description="TTS model to use",
                ),
                "chunk_size": AgentActionConfig(
                    type="number",
                    min=0,
                    step=64,
                    max=2048,
                    value=512,
                    label="Chunk size",
                    note=INFO_CHUNK_SIZE,
                ),
            },
        )

        return actions

    @classmethod
    def add_voices(cls, voices: dict[str, VoiceLibrary]):
        voices["openai"] = VoiceLibrary(api="openai")

    @property
    def openai_chunk_size(self) -> int:
        return self.actions["openai"].config["chunk_size"].value

    @property
    def openai_max_generation_length(self) -> int:
        return 1024

    @property
    def openai_model(self) -> str:
        return self.actions["openai"].config["model"].value

    @property
    def openai_model_choices(self) -> list[dict]:
        return [
            {"label": choice["label"], "value": choice["value"]}
            for choice in self.actions["openai"].config["model"].choices
        ]

    @property
    def openai_api_key(self) -> str:
        return self.config.openai.api_key

    @property
    def openai_configured(self) -> bool:
        return bool(self.openai_api_key) and bool(self.openai_model)

    @property
    def openai_info(self) -> str:
        return OPENAI_INFO

    @property
    def openai_not_configured_reason(self) -> str | None:
        if not self.openai_api_key:
            return "OpenAI API key not set"
        if not self.openai_model:
            return "OpenAI model not set"
        return None

    @property
    def openai_not_configured_action(self) -> Action | None:
        if not self.openai_api_key:
            return Action(
                action_name="openAppConfig",
                arguments=["application", "openai_api"],
                label="Set API Key",
                icon="mdi-key",
            )
        if not self.openai_model:
            return Action(
                action_name="openAgentSettings",
                arguments=["tts", "openai"],
                label="Set Model",
                icon="mdi-brain",
            )
        return None

    @property
    def openai_agent_details(self) -> dict:
        details = {}

        if not self.openai_configured:
            details["openai_api_key"] = AgentDetail(
                icon="mdi-key",
                value="OpenAI API key not set",
                description="OpenAI API key not set. You can set it in the Talemate Settings -> Application -> OpenAI",
                color="error",
            ).model_dump()
        else:
            details["openai_model"] = AgentDetail(
                icon="mdi-brain",
                value=self.openai_model,
                description="The model to use for OpenAI",
            ).model_dump()

        return details

    async def openai_generate(
        self, chunk: Chunk, context: GenerationContext, chunk_size: int = 1024
    ) -> Union[bytes, None]:
        client = AsyncOpenAI(api_key=self.openai_api_key)

        model = chunk.model or self.openai_model

        response = await client.audio.speech.create(
            model=model, voice=chunk.voice.provider_id, input=chunk.cleaned_text
        )

        bytes_io = io.BytesIO()
        # distinct loop name so the Chunk argument isn't shadowed
        for audio_chunk in response.iter_bytes(chunk_size=chunk_size):
            if audio_chunk:
                bytes_io.write(audio_chunk)

        # Return the assembled audio bytes for playback
        return bytes_io.getvalue()
src/talemate/agents/tts/providers.py (new file, 24 additions)
@@ -0,0 +1,24 @@
from .schema import VoiceProvider
from typing import Generator

__all__ = ["register", "provider", "providers"]

PROVIDERS = {}


class register:
    def __call__(self, cls: type[VoiceProvider]):
        PROVIDERS[cls().name] = cls
        return cls


def provider(name: str) -> VoiceProvider:
    cls = PROVIDERS.get(name)
    if not cls:
        return VoiceProvider(name=name)
    return cls()


def providers() -> Generator[VoiceProvider, None, None]:
    for cls in PROVIDERS.values():
        yield cls()
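Usage is straightforward: decorating a VoiceProvider subclass with @register() stores it under its declared name, and provider() falls back to a bare VoiceProvider for unknown names. A sketch; ExampleProvider is hypothetical:

from talemate.agents.tts.providers import register, provider, providers
from talemate.agents.tts.schema import VoiceProvider


@register()
class ExampleProvider(VoiceProvider):
    name: str = "example"
    allow_model_override: bool = False


assert provider("example").allow_model_override is False  # registered class
assert provider("missing").name == "missing"  # bare VoiceProvider fallback
print([p.name for p in providers()])  # fresh instances of every registered provider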
src/talemate/agents/tts/schema.py (new file, 201 additions)
@@ -0,0 +1,201 @@
import pydantic
from pathlib import Path
import re
from typing import Callable, Literal

from talemate.ux.schema import Note, Field
from talemate.path import TALEMATE_ROOT

__all__ = [
    "APIStatus",
    "Chunk",
    "GenerationContext",
    "VoiceProvider",
    "Voice",
    "VoiceLibrary",
    "VoiceWeight",
    "VoiceMixer",
    "VoiceGenerationEmission",
    "INFO_CHUNK_SIZE",
]


MAX_TAG_LENGTH: int = 64  # Maximum number of characters per tag (configurable)
MAX_TAGS_PER_VOICE: int = 10  # Maximum number of tags per voice (configurable)

DEFAULT_VOICE_DIR = TALEMATE_ROOT / "tts" / "voice"

INFO_CHUNK_SIZE = "Split text into chunks of this size. Smaller values will increase responsiveness at the cost of lost context between chunks (things like appropriate inflection). 0 = no chunking."


class VoiceProvider(pydantic.BaseModel):
    name: str
    voice_parameters: list[Field] = pydantic.Field(default_factory=list)
    allow_model_override: bool = True
    allow_file_upload: bool = False
    upload_file_types: list[str] | None = None

    @property
    def default_parameters(self) -> dict[str, str | float | int | bool]:
        return {param.name: param.value for param in self.voice_parameters}

    @property
    def default_voice_dir(self) -> Path:
        return DEFAULT_VOICE_DIR / self.name

    def voice_parameter(
        self, voice: "Voice", name: str
    ) -> str | float | int | bool | None:
        """
        Get a parameter from the voice.
        If the parameter is not set, return the default parameter from the provider.
        """
        if name in voice.parameters:
            return voice.parameters[name]
        return self.default_parameters.get(name)


class VoiceWeight(pydantic.BaseModel):
    id: str
    weight: float


class VoiceMixer(pydantic.BaseModel):
    voices: list[VoiceWeight]


class Voice(pydantic.BaseModel):
    # arbitrary voice label to allow a human to easily identify the voice
    label: str

    # voice provider; this is the TTS api backing the voice
    provider: str

    # voice id as known to the voice provider
    provider_id: str

    # optionally pin the voice to a specific model
    provider_model: str | None = None

    # free-form tags for categorizing the voice (e.g. "male", "energetic")
    tags: list[str] = pydantic.Field(default_factory=list)

    # provider specific parameters for the voice
    parameters: dict[str, str | float | int | bool] = pydantic.Field(
        default_factory=dict
    )

    is_scene_asset: bool = False

    @pydantic.field_validator("tags")
    @classmethod
    def _validate_tags(cls, v: list[str]):
        """Validate tag list length and individual tag length."""
        if len(v) > MAX_TAGS_PER_VOICE:
            raise ValueError(
                f"Too many tags – maximum {MAX_TAGS_PER_VOICE} tags are allowed per voice"
            )
        for tag in v:
            if len(tag) > MAX_TAG_LENGTH:
                raise ValueError(
                    f"Tag '{tag}' exceeds maximum length of {MAX_TAG_LENGTH} characters"
                )
        return v

    model_config = pydantic.ConfigDict(validate_assignment=True, exclude_none=True)

    @pydantic.computed_field(description="The unique identifier for the voice")
    @property
    def id(self) -> str:
        return f"{self.provider}:{self.provider_id}"


class VoiceLibrary(pydantic.BaseModel):
    version: int = 1
    voices: dict[str, Voice] = pydantic.Field(default_factory=dict)

    def get_voice(self, voice_id: str) -> Voice | None:
        return self.voices.get(voice_id)


class Chunk(pydantic.BaseModel):
    text: list[str] = pydantic.Field(default_factory=list)
    type: Literal["dialogue", "exposition"]
    character_name: str | None = None
    api: str | None = None
    voice: Voice | None = None
    model: str | None = None
    generate_fn: Callable | None = None
    prepare_fn: Callable | None = None
    message_id: int | None = None

    @property
    def cleaned_text(self) -> str:
        cleaned: str = self.text[0].replace("*", "").replace('"', "").replace("`", "")

        # troublemakers
        cleaned = cleaned.replace("—", " - ").replace("…", "...").replace(";", ",")

        # collapse any grouped-up whitespace into a single space
        cleaned = re.sub(r"\s+", " ", cleaned)

        # lowercase fully uppercase words
        # e.g. "HELLO" -> "hello"
        cleaned = re.sub(r"[A-Z]{2,}", lambda m: m.group(0).lower(), cleaned)

        cleaned = cleaned.strip(",").strip()

        # If there is no common sentence-ending punctuation, add a period
        if len(cleaned) > 0 and cleaned[-1] not in [".", "!", "?"]:
            cleaned += "."

        return cleaned.strip().strip(",").strip()

    @property
    def sub_chunks(self) -> list["Chunk"]:
        if len(self.text) == 1:
            return [self]

        return [
            Chunk(
                text=[text],
                type=self.type,
                character_name=self.character_name,
                api=self.api,
                voice=Voice(**self.voice.model_dump()),
                model=self.model,
                generate_fn=self.generate_fn,
                prepare_fn=self.prepare_fn,
            )
            for text in self.text
        ]


class GenerationContext(pydantic.BaseModel):
    chunks: list[Chunk] = pydantic.Field(default_factory=list)


class VoiceGenerationEmission(pydantic.BaseModel):
    chunk: Chunk
    context: GenerationContext
    wav_bytes: bytes | None = None


class ModelChoice(pydantic.BaseModel):
    label: str
    value: str


class APIStatus(pydantic.BaseModel):
    """Status of an API."""

    api: str
    enabled: bool
    ready: bool
    configured: bool
    provider: VoiceProvider
    messages: list[Note] = pydantic.Field(default_factory=list)
    supports_mixing: bool = False

    default_model: str | None = None
    model_choices: list[ModelChoice] = pydantic.Field(default_factory=list)
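Two behaviours of the schema worth calling out: a voice's id is always derived from provider and provider_id, and cleaned_text normalizes troublesome characters before synthesis. A sketch with made-up sample strings:

from talemate.agents.tts.schema import Voice, Chunk

v = Voice(label="Nova", provider="openai", provider_id="nova", tags=["female"])
assert v.id == "openai:nova"  # computed from provider and provider_id

c = Chunk(text=['He said — "WAIT…  for me"'], type="dialogue")
print(c.cleaned_text)  # He said - wait... for me.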
src/talemate/agents/tts/util.py (new file, 111 additions)
@@ -0,0 +1,111 @@
from pathlib import Path
from typing import TYPE_CHECKING
import structlog

from .schema import TALEMATE_ROOT, Voice, VoiceProvider

from .voice_library import get_instance

if TYPE_CHECKING:
    from talemate.tale_mate import Scene

log = structlog.get_logger("talemate.agents.tts.util")

__all__ = [
    "voice_parameter",
    "voice_is_talemate_asset",
    "voice_is_scene_asset",
    "get_voice",
]


def voice_parameter(
    voice: Voice, provider: VoiceProvider, name: str
) -> str | float | int | bool | None:
    """
    Get a parameter from the voice, falling back to the provider's defaults.
    """
    if name in voice.parameters:
        return voice.parameters[name]
    return provider.default_parameters.get(name)


def voice_is_talemate_asset(
    voice: Voice, provider: VoiceProvider
) -> tuple[bool, Path | None]:
    """
    Check if the voice is a Talemate asset.
    """

    if not provider.allow_file_upload:
        return False, None

    path = Path(voice.provider_id)
    if not path.is_absolute():
        path = TALEMATE_ROOT / path
    try:
        resolved = path.resolve(strict=False)
    except Exception as e:
        log.error(
            "voice_is_talemate_asset - invalid path",
            error=e,
            voice_id=voice.provider_id,
        )
        return False, None

    root = TALEMATE_ROOT.resolve()
    log.debug(
        "voice_is_talemate_asset - resolved", resolved=str(resolved), root=str(root)
    )
    # Path.is_relative_to avoids the false positives a plain string-prefix
    # check would allow (e.g. a sibling directory sharing the prefix)
    if not resolved.is_relative_to(root):
        return False, None

    return True, resolved


def voice_is_scene_asset(voice: Voice, provider: VoiceProvider) -> bool:
    """
    Check if the voice is a scene asset.

    Scene assets are stored in the scene's assets directory.

    This function does NOT check .is_scene_asset but does path resolution to
    determine if the voice is a scene asset.
    """

    is_talemate_asset, resolved = voice_is_talemate_asset(voice, provider)
    if not is_talemate_asset:
        return False

    SCENES_DIR = TALEMATE_ROOT / "scenes"

    return resolved.is_relative_to(SCENES_DIR.resolve())


def get_voice(scene: "Scene", voice_id: str) -> Voice | None:
    """Return a Voice by *voice_id*, preferring the scene's library (if any).

    Args:
        scene: Scene instance or ``None``.
        voice_id: The fully-qualified voice identifier (``provider:provider_id``).

    The function first checks *scene.voice_library* (if present) and falls back
    to the global voice library instance.
    """

    try:
        if scene and getattr(scene, "voice_library", None):
            voice = scene.voice_library.get_voice(voice_id)
            if voice:
                return voice
    except Exception as e:
        log.error("get_voice - scene lookup failed", error=e)

    try:
        return get_instance().get_voice(voice_id)
    except Exception as e:
        log.error("get_voice - global lookup failed", error=e)
        return None
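voice_parameter is a simple two-level lookup: per-voice parameters win, and the provider's declared defaults apply otherwise. A sketch; the parameter names are made up:

from talemate.agents.tts.schema import Voice, VoiceProvider
from talemate.agents.tts.util import voice_parameter

prov = VoiceProvider(name="demo")  # no declared defaults
v = Voice(label="A", provider="demo", provider_id="a", parameters={"speed": 1.2})

assert voice_parameter(v, prov, "speed") == 1.2  # set on the voice itself
assert voice_parameter(v, prov, "pitch") is None  # unknown everywhere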
src/talemate/agents/tts/voice_library.py (new file, 169 additions)
@@ -0,0 +1,169 @@
import structlog
from pathlib import Path
import pydantic

import talemate.emit.async_signals as async_signals

from .schema import VoiceLibrary, Voice
from typing import TYPE_CHECKING, Callable, Literal

if TYPE_CHECKING:
    from talemate.tale_mate import Scene

__all__ = [
    "load_voice_library",
    "save_voice_library",
    "get_instance",
    "add_default_voices",
    "DEFAULT_VOICES",
    "VOICE_LIBRARY_PATH",
    "require_instance",
    "load_scene_voice_library",
    "save_scene_voice_library",
    "scoped_voice_library",
]

log = structlog.get_logger("talemate.agents.tts.voice_library")

async_signals.register(
    "voice_library.update.before",
    "voice_library.update.after",
)

VOICE_LIBRARY_PATH = (
    Path(__file__).parent.parent.parent.parent.parent
    / "tts"
    / "voice"
    / "voice-library.json"
)


DEFAULT_VOICES = {}

# TODO: does this need to be made thread safe?
VOICE_LIBRARY = None


class ScopedVoiceLibrary(pydantic.BaseModel):
    voice_library: VoiceLibrary
    fn_save: Callable[[VoiceLibrary], None]

    async def save(self):
        await self.fn_save(self.voice_library)


def scoped_voice_library(
    scope: Literal["global", "scene"], scene: "Scene | None" = None
) -> ScopedVoiceLibrary:
    if scope == "global":
        return ScopedVoiceLibrary(
            voice_library=get_instance(), fn_save=save_voice_library
        )
    else:
        if not scene:
            raise ValueError("Scene is required for scoped voice library")

        async def _save(library: VoiceLibrary):
            await save_scene_voice_library(scene, library)

        return ScopedVoiceLibrary(voice_library=scene.voice_library, fn_save=_save)


async def require_instance():
    global VOICE_LIBRARY
    if not VOICE_LIBRARY:
        VOICE_LIBRARY = await load_voice_library()
    return VOICE_LIBRARY


async def load_voice_library() -> VoiceLibrary:
    """
    Load the voice library from the file.
    """
    try:
        with open(VOICE_LIBRARY_PATH, "r") as f:
            return VoiceLibrary.model_validate_json(f.read())
    except FileNotFoundError:
        library = VoiceLibrary(voices=DEFAULT_VOICES)
        await save_voice_library(library)
        return library
    finally:
        log.debug("loaded voice library", path=str(VOICE_LIBRARY_PATH))


async def save_voice_library(voice_library: VoiceLibrary):
    """
    Save the voice library to the file.
    """
    await async_signals.get("voice_library.update.before").send(voice_library)
    with open(VOICE_LIBRARY_PATH, "w") as f:
        f.write(voice_library.model_dump_json(indent=2))
    await async_signals.get("voice_library.update.after").send(voice_library)


def get_instance() -> VoiceLibrary:
    """
    Get the shared voice library instance.
    """
    if not VOICE_LIBRARY:
        raise RuntimeError("Voice library not loaded yet.")
    return VOICE_LIBRARY


def add_default_voices(voices: list[Voice]):
    """
    Add default voices to the voice library.
    """
    global DEFAULT_VOICES
    for voice in voices:
        DEFAULT_VOICES[voice.id] = voice


def voices_for_apis(apis: list[str], voice_library: VoiceLibrary) -> list[Voice]:
    """
    Get the voices for the given apis.
    """
    return [voice for voice in voice_library.voices.values() if voice.provider in apis]


def _scene_library_path(scene: "Scene") -> Path:
    """Return the path to the *scene* voice-library.json file."""

    return Path(scene.info_dir) / "voice-library.json"


async def load_scene_voice_library(scene: "Scene") -> VoiceLibrary:
    """Load and return the voice library for *scene*.

    If the file does not exist an empty ``VoiceLibrary`` instance is returned.
    The returned instance is *not* stored on the scene – caller decides.
    """

    path = _scene_library_path(scene)

    try:
        if path.exists():
            with open(path, "r") as f:
                library = VoiceLibrary.model_validate_json(f.read())
        else:
            library = VoiceLibrary()
    except Exception as e:
        log.error("load_scene_voice_library", error=e, path=str(path))
        library = VoiceLibrary()

    return library


async def save_scene_voice_library(scene: "Scene", library: VoiceLibrary):
    """Persist *library* to the scene's ``voice-library.json``.

    The directory ``scene/{name}/info`` is created if necessary.
    """

    path = _scene_library_path(scene)
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            f.write(library.model_dump_json(indent=2))
    except Exception as e:
        log.error("save_scene_voice_library", error=e, path=str(path))
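The scoped library abstraction lets the websocket handlers mutate either library through one code path, with only the save target differing. A sketch of the flow, assuming the global library was loaded via require_instance at startup:

from talemate.agents.tts.schema import Voice
from talemate.agents.tts.voice_library import scoped_voice_library


async def add_globally(voice: Voice):
    scoped = scoped_voice_library("global")  # or ("scene", scene) for scene scope
    scoped.voice_library.voices[voice.id] = voice
    await scoped.save()  # persists to tts/voice/voice-library.json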
src/talemate/agents/tts/websocket_handler.py (new file, 674 additions)
@@ -0,0 +1,674 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
|
||||
import talemate.emit.async_signals as async_signals
|
||||
from talemate.instance import get_agent
|
||||
from talemate.server.websocket_plugin import Plugin
|
||||
|
||||
import talemate.scene_message as scene_message
|
||||
|
||||
from .voice_library import (
|
||||
get_instance as get_voice_library,
|
||||
save_voice_library,
|
||||
scoped_voice_library,
|
||||
ScopedVoiceLibrary,
|
||||
)
|
||||
from .schema import (
|
||||
Voice,
|
||||
GenerationContext,
|
||||
Chunk,
|
||||
APIStatus,
|
||||
VoiceMixer,
|
||||
VoiceWeight,
|
||||
TALEMATE_ROOT,
|
||||
VoiceLibrary,
|
||||
)
|
||||
|
||||
from .util import voice_is_scene_asset
|
||||
from .providers import provider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.agents.tts import TTSAgent
|
||||
from talemate.tale_mate import Scene
|
||||
from talemate.character import Character
|
||||
|
||||
__all__ = [
|
||||
"TTSWebsocketHandler",
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.server.voice_library")
|
||||
|
||||
|
||||
class EditVoicePayload(pydantic.BaseModel):
|
||||
"""Payload for editing an existing voice. Only specified fields are updated."""
|
||||
|
||||
voice_id: str
|
||||
scope: Literal["global", "scene"]
|
||||
|
||||
label: str
|
||||
provider: str
|
||||
provider_id: str
|
||||
provider_model: str | None = None
|
||||
tags: list[str] = pydantic.Field(default_factory=list)
|
||||
parameters: dict[str, int | float | str | bool] = pydantic.Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
|
||||
class VoiceRefPayload(pydantic.BaseModel):
|
||||
"""Payload referencing an existing voice by its id (used for remove / test)."""
|
||||
|
||||
voice_id: str
|
||||
scope: Literal["global", "scene"]
|
||||
|
||||
|
||||
class TestVoicePayload(pydantic.BaseModel):
|
||||
"""Payload for testing a voice."""
|
||||
|
||||
provider: str
|
||||
provider_id: str
|
||||
provider_model: str | None = None
|
||||
text: str | None = None
|
||||
parameters: dict[str, int | float | str | bool] = pydantic.Field(
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
|
||||
class TestCharacterVoicePayload(pydantic.BaseModel):
|
||||
"""Payload for testing a character voice."""
|
||||
|
||||
character_name: str
|
||||
text: str | None = None
|
||||
|
||||
|
||||
class AddVoicePayload(Voice):
|
||||
"""Explicit payload for adding a new voice - identical fields to Voice."""
|
||||
|
||||
scope: Literal["global", "scene"]
|
||||
|
||||
|
||||
class TestMixedVoicePayload(pydantic.BaseModel):
|
||||
"""Payload for testing a mixed voice."""
|
||||
|
||||
provider: str
|
||||
voices: list[VoiceWeight]
|
||||
|
||||
|
||||
class SaveMixedVoicePayload(pydantic.BaseModel):
|
||||
"""Payload for saving a mixed voice."""
|
||||
|
||||
provider: str
|
||||
label: str
|
||||
voices: list[VoiceWeight]
|
||||
tags: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
|
||||
class UploadVoiceFilePayload(pydantic.BaseModel):
|
||||
"""Payload for uploading a new voice file for providers that support it."""
|
||||
|
||||
provider: str
|
||||
label: str
|
||||
content: str # Base64 data URL (e.g. data:audio/wav;base64,AAAB...)
|
||||
as_scene_asset: bool = False
|
||||
|
||||
@pydantic.field_validator("content")
|
||||
@classmethod
|
||||
def _validate_content(cls, v: str):
|
||||
if not v.startswith("data:") or ";base64," not in v:
|
||||
raise ValueError("Content must be a base64 data URL")
|
||||
return v
|
||||
|
||||
|
||||
class GenerateForSceneMessagePayload(pydantic.BaseModel):
|
||||
"""Payload for generating a voice for a scene message."""
|
||||
|
||||
message_id: int | Literal["intro"]
|
||||
|
||||
|
||||
class TTSWebsocketHandler(Plugin):
|
||||
"""Websocket plugin to manage the TTS voice library."""
|
||||
|
||||
router = "tts"
|
||||
|
||||
def __init__(self, websocket_handler):
|
||||
super().__init__(websocket_handler)
|
||||
# Immediately send current voice list to the frontend
|
||||
asyncio.create_task(self._send_voice_list())
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Events
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
def connect(self):
|
||||
# needs to be after config is saved so the TTS agent has already
|
||||
# refreshed to the latest config
|
||||
async_signals.get("config.changed.follow").connect(
|
||||
self.on_app_config_change_followup
|
||||
)
|
||||
|
||||
async def on_app_config_change_followup(self, event):
|
||||
self._send_api_status()
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Helper methods
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def _send_voice_list(self, select_voice_id: str | None = None):
|
||||
# global voice library
|
||||
voice_library = get_voice_library()
|
||||
voices = [v.model_dump() for v in voice_library.voices.values()]
|
||||
voices.sort(key=lambda x: x["label"])
|
||||
|
||||
# scene voice library
|
||||
if self.scene:
|
||||
scene_voice_library = self.scene.voice_library
|
||||
scene_voices = [v.model_dump() for v in scene_voice_library.voices.values()]
|
||||
scene_voices.sort(key=lambda x: x["label"])
|
||||
else:
|
||||
scene_voices = []
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": self.router,
|
||||
"action": "voices",
|
||||
"voices": voices,
|
||||
"scene_voices": scene_voices,
|
||||
"select_voice_id": select_voice_id,
|
||||
}
|
||||
)
|
||||
|
||||
def _voice_exists(self, voice_library: VoiceLibrary, voice_id: str) -> bool:
|
||||
return voice_id in voice_library.voices
|
||||
|
||||
def _broadcast_update(self, select_voice_id: str | None = None):
|
||||
# After any mutation we broadcast the full list for simplicity
|
||||
asyncio.create_task(self._send_voice_list(select_voice_id))
|
||||
|
||||
def _send_api_status(self):
|
||||
tts_agent: "TTSAgent" = get_agent("tts")
|
||||
api_status: list[APIStatus] = tts_agent.api_status
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": self.router,
|
||||
"action": "api_status",
|
||||
"api_status": [s.model_dump() for s in api_status],
|
||||
}
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Handlers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def handle_list(self, data: dict):
|
||||
await self._send_voice_list()
|
||||
|
||||
async def handle_api_status(self, data: dict):
|
||||
self._send_api_status()
|
||||
|
||||
async def handle_add(self, data: dict):
|
||||
try:
|
||||
voice = AddVoicePayload(**data)
|
||||
except pydantic.ValidationError as e:
|
||||
await self.signal_operation_failed(str(e))
|
||||
return
|
||||
|
||||
if voice.scope == "scene" and not self.scene:
|
||||
await self.signal_operation_failed("No scene active")
|
||||
return
|
||||
|
||||
scoped: ScopedVoiceLibrary = scoped_voice_library(voice.scope, self.scene)
|
||||
voice.is_scene_asset = voice.scope == "scene"
|
||||
|
||||
if self._voice_exists(scoped.voice_library, voice.id):
|
||||
await self.signal_operation_failed("Voice already exists")
|
||||
return
|
||||
|
||||
scoped.voice_library.voices[voice.id] = voice
|
||||
|
||||
await scoped.save()
|
||||
|
||||
self._broadcast_update()
|
||||
await self.signal_operation_done()
|
||||
|
||||
async def handle_remove(self, data: dict):
|
||||
try:
|
||||
payload = VoiceRefPayload(**data)
|
||||
except pydantic.ValidationError as e:
|
||||
await self.signal_operation_failed(str(e))
|
||||
return
|
||||
|
||||
tts_agent: "TTSAgent" = get_agent("tts")
|
||||
|
||||
scoped: ScopedVoiceLibrary = scoped_voice_library(payload.scope, self.scene)
|
||||
|
||||
log.debug("Removing voice", voice_id=payload.voice_id, scope=payload.scope)
|
||||
|
||||
try:
|
||||
voice = scoped.voice_library.voices.pop(payload.voice_id)
|
||||
except KeyError:
|
||||
await self.signal_operation_failed("Voice not found (1)")
|
||||
return
|
||||
|
||||
provider = voice.provider
|
||||
# check if porivder has a delete method
|
||||
delete_method = getattr(tts_agent, f"{provider}_delete_voice", None)
|
||||
if delete_method:
|
||||
delete_method(voice)
|
||||
|
||||
await scoped.save()
|
||||
self._broadcast_update()
|
||||
await self.signal_operation_done()
|
||||
|
||||
async def handle_edit(self, data: dict):
|
||||
try:
|
||||
payload = EditVoicePayload(**data)
|
||||
except pydantic.ValidationError as e:
|
||||
await self.signal_operation_failed(str(e))
|
||||
return
|
||||
|
||||
scoped: ScopedVoiceLibrary = scoped_voice_library(payload.scope, self.scene)
|
||||
voice = scoped.voice_library.voices.get(payload.voice_id)
|
||||
if not voice:
|
||||
await self.signal_operation_failed("Voice not found")
|
||||
return
|
||||
|
||||
# all fields are always provided
|
||||
voice.label = payload.label
|
||||
voice.provider = payload.provider
|
||||
voice.provider_id = payload.provider_id
|
||||
voice.provider_model = payload.provider_model
|
||||
voice.tags = payload.tags
|
||||
voice.parameters = payload.parameters
|
||||
voice.is_scene_asset = voice_is_scene_asset(voice, provider(voice.provider))
|
||||
|
||||
# If provider or provider_id changed, id changes -> reinsert
|
||||
new_id = voice.id
|
||||
if new_id != payload.voice_id:
|
||||
# Remove old key, insert new
|
||||
del scoped.voice_library.voices[payload.voice_id]
|
||||
scoped.voice_library.voices[new_id] = voice
|
||||
|
||||
await scoped.save()
|
||||
self._broadcast_update()
|
||||
await self.signal_operation_done()
|
||||
|
||||
    async def handle_test(self, data: dict):
        """Handle a request to test a voice.

        Supports two payload formats:

        1. Existing voice - identified by ``voice_id`` (legacy behaviour)
        2. Unsaved voice - identified by at least ``provider`` and ``provider_id``.
        """

        tts_agent: "TTSAgent" = get_agent("tts")

        try:
            payload = TestVoicePayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        voice = Voice(
            label=f"{payload.provider_id} (test)",
            provider=payload.provider,
            provider_id=payload.provider_id,
            provider_model=payload.provider_model,
            parameters=payload.parameters,
        )

        if not tts_agent or not tts_agent.api_ready(voice.provider):
            await self.signal_operation_failed(f"API '{voice.provider}' not ready")
            return

        generate_fn = getattr(tts_agent, f"{voice.provider}_generate", None)
        if not generate_fn:
            await self.signal_operation_failed("Provider not supported by TTS agent")
            return

        prepare_fn = getattr(tts_agent, f"{voice.provider}_prepare_chunk", None)

        # Use provided text or default
        test_text = payload.text or "This is a test of the selected voice."

        # Build minimal generation context
        context = GenerationContext()
        chunk = Chunk(
            text=[test_text],
            type="dialogue",
            api=voice.provider,
            voice=voice,
            model=voice.provider_model,
            generate_fn=generate_fn,
            prepare_fn=prepare_fn,
            character_name=None,
        )
        context.chunks.append(chunk)

        # Run generation in background so we don't block the event loop
        async def _run_test():
            try:
                await tts_agent.generate_chunks(context)
            finally:
                await self.signal_operation_done(signal_only=True)

        asyncio.create_task(_run_test())
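
    # Illustrative payloads for handle_test (shapes inferred from the
    # TestVoicePayload usage above; the provider and id values are made up):
    #
    #   unsaved voice: {"provider": "elevenlabs", "provider_id": "some-voice-id",
    #                   "provider_model": "some-model", "text": "Hello there."}
    #   existing voice (legacy): {"voice_id": "..."} -- resolution of voice_id
    #   happens inside TestVoicePayload, which is defined elsewhere.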

    async def handle_test_character_voice(self, data: dict):
        """Handle a request to test a character voice."""

        try:
            payload = TestCharacterVoicePayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        character = self.scene.get_character(payload.character_name)
        if not character:
            await self.signal_operation_failed("Character not found")
            return

        if not character.voice:
            await self.signal_operation_failed("Character has no voice")
            return

        text: str = payload.text or "This is a test of the selected voice."

        await self.handle_test(
            {
                "provider": character.voice.provider,
                "provider_id": character.voice.provider_id,
                "provider_model": character.voice.provider_model,
                "parameters": character.voice.parameters,
                "text": text,
            }
        )

    async def handle_test_mixed(self, data: dict):
        """Handle a request to test a mixed voice."""

        tts_agent: "TTSAgent" = get_agent("tts")

        try:
            payload = TestMixedVoicePayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        # Validate that weights sum to 1.0
        total_weight = sum(v.weight for v in payload.voices)
        if abs(total_weight - 1.0) > 0.001:
            await self.signal_operation_failed(
                f"Weights must sum to 1.0, got {total_weight}"
            )
            return

        if not tts_agent or not tts_agent.api_ready(payload.provider):
            await self.signal_operation_failed(f"{payload.provider} API not ready")
            return

        # Build mixer
        mixer = VoiceMixer(voices=payload.voices)

        # Run test in background using the appropriate provider's test method
        test_method = getattr(tts_agent, f"{payload.provider}_test_mix", None)
        if not test_method:
            await self.signal_operation_failed(
                f"{payload.provider} does not implement voice mixing"
            )
            return

        async def _run_test():
            try:
                await test_method(mixer)
            finally:
                await self.signal_operation_done(signal_only=True)

        asyncio.create_task(_run_test())
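
    # Illustrative handle_test_mixed payload (handle_save_mixed below runs the
    # same weight validation). Only the "weight" field is confirmed by the code
    # above; the other entry fields and all values are assumptions:
    #
    #   {"provider": "chatterbox",
    #    "voices": [{"voice_id": "voice-a", "weight": 0.6},
    #               {"voice_id": "voice-b", "weight": 0.4}]}
    #
    # 0.6 + 0.4 == 1.0, so it passes the abs(total_weight - 1.0) > 0.001 check.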

    async def handle_save_mixed(self, data: dict):
        """Handle a request to save a mixed voice."""

        tts_agent: "TTSAgent" = get_agent("tts")

        try:
            payload = SaveMixedVoicePayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        # Validate that weights sum to 1.0
        total_weight = sum(v.weight for v in payload.voices)
        if abs(total_weight - 1.0) > 0.001:
            await self.signal_operation_failed(
                f"Weights must sum to 1.0, got {total_weight}"
            )
            return

        if not tts_agent or not tts_agent.api_ready(payload.provider):
            await self.signal_operation_failed(f"{payload.provider} API not ready")
            return

        # Build mixer
        mixer = VoiceMixer(voices=payload.voices)

        # Create a unique voice id for the mixed voice
        voice_id = f"{payload.label.lower().replace(' ', '-')}"

        # Mix and save the voice using the appropriate provider's methods
        save_method = getattr(tts_agent, f"{payload.provider}_save_mix", None)

        if not save_method:
            await self.signal_operation_failed(
                f"{payload.provider} does not implement voice mixing"
            )
            return

        try:
            saved_path = await save_method(voice_id, mixer)

            # voice id is a Path relative to the talemate root
            voice_id = str(saved_path.relative_to(TALEMATE_ROOT))

            # Add the voice to the library
            new_voice = Voice(
                label=payload.label,
                provider=payload.provider,
                provider_id=voice_id,
                tags=payload.tags,
                mix=mixer,
            )

            voice_library = get_voice_library()
            voice_library.voices[new_voice.id] = new_voice
            await save_voice_library(voice_library)
            self._broadcast_update(new_voice.id)
            await self.signal_operation_done()

        except Exception as e:
            log.error("Failed to save mixed voice", error=e)
            await self.signal_operation_failed(f"Failed to save mixed voice: {str(e)}")

    async def handle_generate_for_scene_message(self, data: dict):
        """Handle a request to generate a voice for a scene message."""

        tts_agent: "TTSAgent" = get_agent("tts")
        scene: "Scene" = self.scene

        log.debug("Generating TTS for scene message", data=data)

        try:
            payload = GenerateForSceneMessagePayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        log.debug("Payload", payload=payload)

        character: "Character | None" = None
        text: str = ""
        message: scene_message.SceneMessage | None = None

        if payload.message_id == "intro":
            text = scene.get_intro()
        else:
            message = scene.get_message(payload.message_id)

            if not message:
                await self.signal_operation_failed("Message not found")
                return

            if message.typ not in ["character", "narrator", "context_investigation"]:
                await self.signal_operation_failed(
                    "Message is not a character, narrator, or context investigation message"
                )
                return

            log.debug("Message type", message_type=message.typ)

            if isinstance(message, scene_message.CharacterMessage):
                character = scene.get_character(message.character_name)

                if not character:
                    await self.signal_operation_failed("Character not found")
                    return

                text = message.without_name
            elif isinstance(message, scene_message.ContextInvestigationMessage):
                text = message.message
            else:
                text = message.message

        if not text:
            await self.signal_operation_failed("No text to generate speech for.")
            return

        await tts_agent.generate(text, character, message=message)

        await self.signal_operation_done()

    async def handle_stop_and_clear(self, data: dict):
        """Handle a request from the frontend to stop and clear the current TTS queue."""

        tts_agent: "TTSAgent" = get_agent("tts")

        if not tts_agent:
            await self.signal_operation_failed("TTS agent not available")
            return

        try:
            await tts_agent.stop_and_clear_queue()
            await self.signal_operation_done()
        except Exception as e:
            log.error("Failed to stop and clear TTS queue", error=e)
            await self.signal_operation_failed(str(e))

    async def handle_upload_voice_file(self, data: dict):
        """Handle uploading a new audio file for a voice.

        The *provider* defines which MIME types it accepts via
        ``VoiceProvider.upload_file_types``. This method therefore:

        1. Parses the data-URL to obtain the raw bytes **and** MIME type.
        2. Verifies the MIME type against the provider's allowed list
           (if the provider restricts uploads).
        3. Stores the file under

           ``tts/voice/<provider>/<slug(label)>.<extension>``

           where *extension* is derived from the MIME type (e.g. ``audio/wav`` → ``wav``).
        4. Returns the relative path ("provider_id") back to the frontend so
           it can populate the voice's ``provider_id`` field.
        """

        try:
            payload = UploadVoiceFilePayload(**data)
        except pydantic.ValidationError as e:
            await self.signal_operation_failed(str(e))
            return

        # Check provider allows file uploads
        from .providers import provider as get_provider

        P = get_provider(payload.provider)
        if not P.allow_file_upload:
            await self.signal_operation_failed(
                f"Provider '{payload.provider}' does not support file uploads"
            )
            return

        # Build filename from label
        def slugify(text: str) -> str:
            text = text.lower().strip()
            text = re.sub(r"[^a-z0-9]+", "-", text)
            return text.strip("-")

        filename_no_ext = slugify(payload.label or "voice") or "voice"

        # Determine media type and validate against provider
        try:
            header, b64data = payload.content.split(",", 1)
            media_type = header.split(":", 1)[1].split(";", 1)[0]
        except Exception:
            await self.signal_operation_failed("Invalid data URL format")
            return

        if P.upload_file_types and media_type not in P.upload_file_types:
            await self.signal_operation_failed(
                f"File type '{media_type}' not allowed for provider '{payload.provider}'"
            )
            return

        extension = media_type.split("/")[1]
        filename = f"{filename_no_ext}.{extension}"

        # Determine target directory and path
        if not payload.as_scene_asset:
            target_dir = P.default_voice_dir
        else:
            target_dir = Path(self.scene.assets.asset_directory) / "tts"

        os.makedirs(target_dir, exist_ok=True)
        target_path = target_dir / filename

        log.debug(
            "Target path",
            target_path=target_path,
            as_scene_asset=payload.as_scene_asset,
        )

        # Decode base64 data URL
        try:
            file_bytes = base64.b64decode(b64data)
        except Exception as e:
            await self.signal_operation_failed(f"Invalid base64 data: {e}")
            return

        try:
            with open(target_path, "wb") as f:
                f.write(file_bytes)
        except Exception as e:
            await self.signal_operation_failed(f"Failed to save file: {e}")
            return

        provider_id = str(target_path.relative_to(TALEMATE_ROOT))

        # Send response back to frontend so it can set provider_id
        self.websocket_handler.queue_put(
            {
                "type": self.router,
                "action": "voice_file_uploaded",
                "provider_id": provider_id,
            }
        )
        await self.signal_operation_done(signal_only=True)
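
    # Illustrative walk-through of the data-URL handling above (values made up):
    #
    #   payload.content = "data:audio/wav;base64,UklGRi..."
    #   header     -> "data:audio/wav;base64"
    #   media_type -> "audio/wav"
    #   extension  -> "wav"
    #   filename   -> "my-voice.wav" for label "My Voice"
    #   target_path (non scene-asset case) -> <default_voice_dir>/my-voice.wav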
@@ -14,10 +14,9 @@ from talemate.agents.base import (
)
from talemate.agents.registry import register
from talemate.agents.editor.revision import RevisionDisabled
from talemate.agents.summarize.analyze_scene import SceneAnalysisDisabled
from talemate.client.base import ClientBase
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers as signal_handlers
from talemate.prompts.base import Prompt

from .commands import *  # noqa
@@ -152,16 +151,13 @@ class VisualBase(Agent):

        return actions

    def __init__(self, client: ClientBase, *kwargs):
    def __init__(self, client: ClientBase | None = None, **kwargs):
        self.client = client
        self.is_enabled = False
        self.backend_ready = False
        self.initialized = False
        self.config = load_config()
        self.actions = VisualBase.init_actions()

        signal_handlers["config_saved"].connect(self.on_config_saved)

    @property
    def enabled(self):
        return self.is_enabled
@@ -231,6 +227,10 @@ class VisualBase(Agent):
            or f"{self.backend_name} is not ready for processing",
        ).model_dump()

        backend_detail_fn = getattr(self, f"{self.backend.lower()}_agent_details", None)
        if backend_detail_fn:
            details.update(backend_detail_fn())

        return details

    @property
@@ -241,11 +241,6 @@ class VisualBase(Agent):
    def allow_automatic_generation(self):
        return self.actions["automatic_generation"].enabled

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        asyncio.create_task(self.emit_status())

    async def on_ready_check_success(self):
        prev_ready = self.backend_ready
        self.backend_ready = True
@@ -406,7 +401,11 @@ class VisualBase(Agent):
            f"data:image/png;base64,{image}"
        )
        character.cover_image = asset.id
        self.scene.assets.cover_image = asset.id

        # Only set scene cover image if scene doesn't already have one
        if not self.scene.assets.cover_image:
            self.scene.assets.cover_image = asset.id

        self.scene.emit_status()

    async def emit_image(self, image: str):
@@ -538,7 +537,7 @@ class VisualBase(Agent):

    @set_processing
    async def generate_environment_prompt(self, instructions: str = None):
        with RevisionDisabled():
        with RevisionDisabled(), SceneAnalysisDisabled():
            response = await Prompt.request(
                "visual.generate-environment-prompt",
                self.client,
@@ -557,7 +556,7 @@ class VisualBase(Agent):
    ):
        character = self.scene.get_character(character_name)

        with RevisionDisabled():
        with RevisionDisabled(), SceneAnalysisDisabled():
            response = await Prompt.request(
                "visual.generate-character-prompt",
                self.client,

@@ -10,7 +10,12 @@ import httpx
import pydantic
import structlog

from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig
from talemate.agents.base import (
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
)

from .handlers import register
from .schema import RenderSettings, Resolution
@@ -164,11 +169,16 @@ class ComfyUIMixin:
                    label="Checkpoint",
                    choices=[],
                    description="The main checkpoint to use.",
                    note="If the agent is enabled and connected, but the checkpoint list is empty, try closing this window and opening it again.",
                ),
            },
        )
    }

    @property
    def comfyui_checkpoint(self):
        return self.actions["comfyui"].config["checkpoint"].value

    @property
    def comfyui_workflow_filename(self):
        base_name = self.actions["comfyui"].config["workflow"].value
@@ -219,10 +229,27 @@ class ComfyUIMixin:
    async def comfyui_checkpoints(self):
        loader_node = (await self.comfyui_object_info)["CheckpointLoaderSimple"]
        _checkpoints = loader_node["input"]["required"]["ckpt_name"][0]
        log.debug("comfyui_checkpoints", _checkpoints=_checkpoints)
        return [
            {"label": checkpoint, "value": checkpoint} for checkpoint in _checkpoints
        ]

    def comfyui_agent_details(self):
        checkpoint: str = self.comfyui_checkpoint
        if not checkpoint:
            return {}

        # remove .safetensors
        checkpoint = checkpoint.replace(".safetensors", "")

        return {
            "checkpoint": AgentDetail(
                icon="mdi-brain",
                value=checkpoint,
                description="The checkpoint to use for comfyui",
            ).model_dump()
        }

    async def comfyui_get_image(self, filename: str, subfolder: str, folder_type: str):
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)

@@ -55,7 +55,7 @@ class OpenAIImageMixin:

    @property
    def openai_api_key(self):
        return self.config.get("openai", {}).get("api_key")
        return self.config.openai.api_key

    @property
    def openai_model_type(self):

@@ -91,7 +91,7 @@ class VisualWebsocketHandler(Plugin):
        await visual.generate_character_portrait(
            payload.context.character_name,
            payload.context.instructions,
            replace=True,
            replace=payload.context.replace,
            prompt_only=payload.context.prompt_only,
        )


@@ -13,6 +13,7 @@ import talemate.util as util
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.instance import get_agent
from talemate.client import ClientBase
from talemate.prompts import Prompt
from talemate.scene_message import (
    ReinforcementMessage,
@@ -125,7 +126,7 @@ class WorldStateAgent(CharacterProgressionMixin, Agent):
        CharacterProgressionMixin.add_actions(actions)
        return actions

    def __init__(self, client, **kwargs):
    def __init__(self, client: ClientBase | None = None, **kwargs):
        self.client = client
        self.is_enabled = True
        self.next_update = 0

@@ -1,16 +1,539 @@
from typing import TYPE_CHECKING, Union
import pydantic
import structlog
import random
import re
import traceback

from talemate.instance import get_agent
import talemate.util as util
import talemate.instance as instance
import talemate.scene_message as scene_message
import talemate.agents.base as agent_base
from talemate.agents.tts.schema import Voice
import talemate.emit.async_signals as async_signals

if TYPE_CHECKING:
    from talemate.tale_mate import Character, Scene

    from talemate.tale_mate import Scene, Actor

__all__ = [
    "Character",
    "VoiceChangedEvent",
    "deactivate_character",
    "activate_character",
    "set_voice",
]

log = structlog.get_logger("talemate.character")

async_signals.register("character.voice_changed")


class Character(pydantic.BaseModel):
    # core character information
    name: str
    description: str = ""
    greeting_text: str = ""
    color: str = "#fff"
    is_player: bool = False
    memory_dirty: bool = False
    cover_image: str | None = None
    voice: Voice | None = None

    # dialogue instructions and examples
    dialogue_instructions: str | None = None
    example_dialogue: list[str] = pydantic.Field(default_factory=list)

    # attribute and detail storage
    base_attributes: dict[str, str | int | float | bool] = pydantic.Field(
        default_factory=dict
    )
    details: dict[str, str] = pydantic.Field(default_factory=dict)

    # helpful references
    agent: agent_base.Agent | None = pydantic.Field(default=None, exclude=True)
    actor: "Actor | None" = pydantic.Field(default=None, exclude=True)

    class Config:
        arbitrary_types_allowed = True

    @property
    def gender(self) -> str:
        return self.base_attributes.get("gender", "")

    @property
    def sheet(self) -> str:
        sheet = self.base_attributes or {
            "name": self.name,
            "description": self.description,
        }

        sheet_list = []

        for key, value in sheet.items():
            sheet_list.append(f"{key}: {value}")

        return "\n".join(sheet_list)

    @property
    def random_dialogue_example(self):
        """
        Get a random example dialogue line for this character.

        Returns:
            str: The random example dialogue line.
        """
        if not self.example_dialogue:
            return ""

        return random.choice(self.example_dialogue)

    def __str__(self):
        return f"Character: {self.name}"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash(self.name)

    def set_color(self, color: str | None = None):
        # if no color is provided, choose a random color

        if color is None:
            color = util.random_color()
        self.color = color

    def set_cover_image(self, asset_id: str, initial_only: bool = False):
        if self.cover_image and initial_only:
            return

        self.cover_image = asset_id

    def sheet_filtered(self, *exclude):
        sheet = self.base_attributes or {
            "name": self.name,
            "gender": self.gender,
            "description": self.description,
        }

        sheet_list = []

        for key, value in sheet.items():
            if key not in exclude:
                sheet_list.append(f"{key}: {value}")

        return "\n".join(sheet_list)

    def random_dialogue_examples(
        self,
        scene: "Scene",
        num: int = 3,
        strip_name: bool = False,
        max_backlog: int = 250,
        max_length: int = 192,
        history_threshold: int = 15,
    ) -> list[str]:
        """
        Get multiple random example dialogue lines for this character.

        Will return up to `num` examples, without duplicates.
        """

        if len(scene.history) < history_threshold and self.example_dialogue:
            # when the history is too short, just use the prepared
            # examples
            return self._random_dialogue_examples(num, strip_name)

        history_examples = self._random_dialogue_examples_from_history(
            scene, num, max_backlog
        )

        if len(history_examples) < num:
            random_examples = self._random_dialogue_examples(
                num - len(history_examples), strip_name
            )

            for example in random_examples:
                history_examples.append(example)

        # ensure sane example lengths

        history_examples = [
            util.strip_partial_sentences(example[:max_length])
            for example in history_examples
        ]

        log.debug("random_dialogue_examples", history_examples=history_examples)
        return history_examples

    def _random_dialogue_examples_from_history(
        self, scene: "Scene", num: int = 3, max_backlog: int = 250
    ) -> list[str]:
        """
        Get multiple random example dialogue lines for this character from the scene's history.

        Checks the last `max_backlog` messages in the scene's history and returns up to `num` examples.
        """

        history = scene.history[-max_backlog:]

        examples = []

        for message in history:
            if not isinstance(message, scene_message.CharacterMessage):
                continue

            if message.character_name != self.name:
                continue

            examples.append(message.without_name.strip())

        if not examples:
            return []

        return random.sample(examples, min(num, len(examples)))

    def _random_dialogue_examples(
        self, num: int = 3, strip_name: bool = False
    ) -> list[str]:
        """
        Get multiple random example dialogue lines for this character.

        Will return up to `num` examples, without duplicates.
        """

        if not self.example_dialogue:
            return []

        # create a copy of example_dialogue so we don't modify the original

        examples = self.example_dialogue.copy()

        # shuffle the examples so we get a random order

        random.shuffle(examples)

        # now pop examples until we have `num` examples or we run out of examples

        if strip_name:
            examples = [example.split(":", 1)[1].strip() for example in examples]

        return [examples.pop() for _ in range(min(num, len(examples)))]

    def filtered_sheet(self, attributes: list[str]):
        """
        Same as sheet but only returns the attributes in the given list.

        Attributes that don't exist will be ignored.
        """

        sheet_list = []

        for key, value in self.base_attributes.items():
            if key.lower() not in attributes:
                continue
            sheet_list.append(f"{key}: {value}")

        return "\n".join(sheet_list)

    def rename(self, new_name: str):
        """
        Rename the character.

        Args:
            new_name (str): The new name of the character.

        Returns:
            None
        """

        orig_name = self.name
        self.name = new_name

        if orig_name.lower() == "you":
            # we don't want to replace "you" in the description
            # or anywhere else, so we can just return here
            return

        if self.description:
            self.description = self.description.replace(f"{orig_name}", self.name)
        for k, v in self.base_attributes.items():
            if isinstance(v, str):
                self.base_attributes[k] = v.replace(f"{orig_name}", self.name)
        for i, v in list(self.details.items()):
            if isinstance(v, str):
                self.details[i] = v.replace(f"{orig_name}", self.name)
        self.memory_dirty = True

    def introduce_main_character(self, character: "Character"):
        """
        Makes this character aware of the main character's name in the scene.

        This will replace all occurrences of {{user}} (case-insensitive) in all of the character's properties
        with the main character's name.
        """

        properties = ["description", "greeting_text"]

        pattern = re.compile(re.escape("{{user}}"), re.IGNORECASE)

        for prop in properties:
            prop_value = getattr(self, prop)

            try:
                updated_prop_value = pattern.sub(character.name, prop_value)
            except Exception as e:
                log.error(
                    "introduce_main_character",
                    error=e,
                    traceback=traceback.format_exc(),
                )
                updated_prop_value = prop_value
            setattr(self, prop, updated_prop_value)

        # also replace in all example dialogue

        for i, dialogue in enumerate(self.example_dialogue):
            self.example_dialogue[i] = pattern.sub(character.name, dialogue)
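
    # e.g. (illustrative): with character.name == "Alice",
    #   "Hello {{USER}}!" -> "Hello Alice!"
    # since the pattern above is compiled with re.IGNORECASE.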

    def update(self, **kwargs):
        """
        Update character properties with given key-value pairs.
        """

        for key, value in kwargs.items():
            setattr(self, key, value)

        self.memory_dirty = True

    async def purge_from_memory(self):
        """
        Purges this character's details from memory.
        """
        memory_agent = instance.get_agent("memory")
        await memory_agent.delete({"character": self.name})
        log.info("purged character from memory", character=self.name)

    async def commit_to_memory(self, memory_agent):
        """
        Commits this character's details to the memory agent. (vectordb)
        """

        items = []

        if not self.base_attributes or "description" not in self.base_attributes:
            if not self.description:
                self.description = ""
            description_chunks = [
                chunk.strip() for chunk in self.description.split("\n") if chunk.strip()
            ]

            for idx in range(len(description_chunks)):
                chunk = description_chunks[idx]

                items.append(
                    {
                        "text": f"{self.name}: {chunk}",
                        "id": f"{self.name}.description.{idx}",
                        "meta": {
                            "character": self.name,
                            "attr": "description",
                            "typ": "base_attribute",
                        },
                    }
                )

        seen_attributes = set()

        for attr, value in self.base_attributes.items():
            if attr.startswith("_"):
                continue

            if attr.lower() in ["name", "scenario_context", "_prompt", "_template"]:
                continue

            seen_attributes.add(attr)

            items.append(
                {
                    "text": f"{self.name}'s {attr}: {value}",
                    "id": f"{self.name}.{attr}",
                    "meta": {
                        "character": self.name,
                        "attr": attr,
                        "typ": "base_attribute",
                    },
                }
            )

        for key, detail in self.details.items():
            # if colliding with attribute name, prefix with detail_
            if key in seen_attributes:
                key = f"detail_{key}"

            items.append(
                {
                    "text": f"{self.name} - {key}: {detail}",
                    "id": f"{self.name}.{key}",
                    "meta": {
                        "character": self.name,
                        "typ": "details",
                        "detail": key,
                    },
                }
            )

        if items:
            await memory_agent.add_many(items)

        self.memory_dirty = False

    async def commit_single_attribute_to_memory(
        self, memory_agent, attribute: str, value: str
    ):
        """
        Commits a single attribute to memory
        """

        items = []

        # remove old attribute if it exists

        await memory_agent.delete(
            {"character": self.name, "typ": "base_attribute", "attr": attribute}
        )

        self.base_attributes[attribute] = value

        items.append(
            {
                "text": f"{self.name}'s {attribute}: {self.base_attributes[attribute]}",
                "id": f"{self.name}.{attribute}",
                "meta": {
                    "character": self.name,
                    "attr": attribute,
                    "typ": "base_attribute",
                },
            }
        )

        log.debug("commit_single_attribute_to_memory", items=items)

        await memory_agent.add_many(items)

    async def commit_single_detail_to_memory(
        self, memory_agent, detail: str, value: str
    ):
        """
        Commits a single detail to memory
        """

        items = []

        # remove old detail if it exists

        await memory_agent.delete(
            {"character": self.name, "typ": "details", "detail": detail}
        )

        self.details[detail] = value

        items.append(
            {
                "text": f"{self.name} - {detail}: {value}",
                "id": f"{self.name}.{detail}",
                "meta": {
                    "character": self.name,
                    "typ": "details",
                    "detail": detail,
                },
            }
        )

        log.debug("commit_single_detail_to_memory", items=items)

        await memory_agent.add_many(items)

    async def set_detail(self, name: str, value):
        memory_agent = instance.get_agent("memory")
        if not value:
            try:
                del self.details[name]
                await memory_agent.delete(
                    {"character": self.name, "typ": "details", "detail": name}
                )
            except KeyError:
                pass
        else:
            self.details[name] = value
            await self.commit_single_detail_to_memory(memory_agent, name, value)

    def set_detail_defer(self, name: str, value):
        self.details[name] = value
        self.memory_dirty = True

    def get_detail(self, name: str):
        return self.details.get(name)

    async def set_base_attribute(self, name: str, value):
        memory_agent = instance.get_agent("memory")

        if not value:
            try:
                del self.base_attributes[name]
                await memory_agent.delete(
                    {"character": self.name, "typ": "base_attribute", "attr": name}
                )
            except KeyError:
                pass
        else:
            self.base_attributes[name] = value
            await self.commit_single_attribute_to_memory(memory_agent, name, value)

    def set_base_attribute_defer(self, name: str, value):
        self.base_attributes[name] = value
        self.memory_dirty = True

    def get_base_attribute(self, name: str):
        return self.base_attributes.get(name)

    async def set_description(self, description: str):
        memory_agent = instance.get_agent("memory")
        self.description = description

        items = []

        await memory_agent.delete(
            {"character": self.name, "typ": "base_attribute", "attr": "description"}
        )

        description_chunks = [
            chunk.strip() for chunk in self.description.split("\n") if chunk.strip()
        ]

        for idx in range(len(description_chunks)):
            chunk = description_chunks[idx]

            items.append(
                {
                    "text": f"{self.name}: {chunk}",
                    "id": f"{self.name}.description.{idx}",
                    "meta": {
                        "character": self.name,
                        "attr": "description",
                        "typ": "base_attribute",
                    },
                }
            )

        await memory_agent.add_many(items)


class VoiceChangedEvent(pydantic.BaseModel):
    character: "Character"
    voice: Voice | None
    auto: bool = False


async def deactivate_character(scene: "Scene", character: Union[str, "Character"]):
    """
@@ -51,9 +574,18 @@ async def activate_character(scene: "Scene", character: Union[str, "Character"])
        return False

    if not character.is_player:
        actor = scene.Actor(character, get_agent("conversation"))
        actor = scene.Actor(character, instance.get_agent("conversation"))
    else:
        actor = scene.Player(character, None)

    await scene.add_actor(actor)
    del scene.inactive_characters[character.name]


async def set_voice(character: "Character", voice: Voice | None, auto: bool = False):
    character.voice = voice
    emission: VoiceChangedEvent = VoiceChangedEvent(
        character=character, voice=voice, auto=auto
    )
    await async_signals.get("character.voice_changed").send(emission)
    return emission
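
# A minimal sketch of consuming the "character.voice_changed" signal registered
# above; the handler below and the connect() call are assumptions for
# illustration, not part of this diff:
#
#   async def on_voice_changed(event: VoiceChangedEvent):
#       log.debug("voice changed", character=event.character.name, auto=event.auto)
#
#   async_signals.get("character.voice_changed").connect(on_voice_changed)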

@@ -9,10 +9,8 @@ from talemate.client.remote import (
    EndpointOverrideMixin,
    endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers

__all__ = [
    "AnthropicClient",
@@ -33,10 +31,13 @@ SUPPORTED_MODELS = [
    "claude-opus-4-20250514",
]

DEFAULT_MODEL = "claude-3-5-sonnet-latest"
MIN_THINKING_TOKENS = 1024


class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "claude-3-5-sonnet-latest"
    model: str = DEFAULT_MODEL
    double_coercion: str = None


@@ -52,7 +53,6 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):

    client_type = "anthropic"
    conversation_retries = 0
    auto_break_repetition_enabled = False
    # TODO: make this configurable?
    decensor_enabled = False
    config_cls = ClientConfig
@@ -66,22 +66,13 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()

    def __init__(self, model="claude-3-5-sonnet-latest", **kwargs):
        self.model_name = model
        self.api_key_status = None
        self._reconfigure_endpoint_override(**kwargs)
        self.config = load_config()
        super().__init__(**kwargs)

        handlers["config_saved"].connect(self.on_config_saved)

    @property
    def can_be_coerced(self) -> bool:
        return True
        return not self.reason_enabled

    @property
    def anthropic_api_key(self):
        return self.config.get("anthropic", {}).get("api_key")
        return self.config.anthropic.api_key

    @property
    def supported_parameters(self):
@@ -92,17 +83,25 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
            "max_tokens",
        ]

    @property
    def min_reason_tokens(self) -> int:
        return MIN_THINKING_TOKENS

    @property
    def requires_reasoning_pattern(self) -> bool:
        return False

    def emit_status(self, processing: bool = None):
        error_action = None
        error_message: str | None = None
        if processing is not None:
            self.processing = processing

        if self.anthropic_api_key:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            status = "error"
            model_name = "No API key set"
            error_message = "No API key set"
            error_action = ErrorAction(
                title="Set API Key",
                action_name="openAppConfig",
@@ -115,7 +114,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):

        if not self.model_name:
            status = "error"
            model_name = "No model loaded"
            error_message = "No model loaded"

        self.current_status = status

@@ -124,73 +123,18 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
            "double_coercion": self.double_coercion,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
            "error_message": error_message,
        }
        data.update(self._common_status_data())
        emit(
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            details=self.model_name,
            status=status if self.enabled else "disabled",
            data=data,
        )

    def set_client(self, max_token_length: int = None):
        if (
            not self.anthropic_api_key
            and not self.endpoint_override_base_url_configured
        ):
            self.client = AsyncAnthropic(api_key="sk-1111")
            log.error("No anthropic API key set")
            if self.api_key_status:
                self.api_key_status = False
                emit("request_client_status")
                emit("request_agent_status")
            return

        if not self.model_name:
            self.model_name = "claude-3-opus-20240229"

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        model = self.model_name

        self.client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)
        self.max_token_length = max_token_length or 16384

        if not self.api_key_status:
        if self.api_key_status is False:
            emit("request_client_status")
            emit("request_agent_status")
            self.api_key_status = True

        log.info(
            "anthropic set client",
            max_token_length=self.max_token_length,
            provided_max_token_length=max_token_length,
            model=model,
        )

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        if "double_coercion" in kwargs:
            self.double_coercion = kwargs["double_coercion"]

        self._reconfigure_common_parameters(**kwargs)
        self._reconfigure_endpoint_override(**kwargs)

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    def response_tokens(self, response: str):
        return response.usage.output_tokens

@@ -200,13 +144,6 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
    async def status(self):
        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
        """
        Anthropic handles the prompt template internally, so we just
        give the prompt as is.
        """
        return prompt

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
@@ -218,17 +155,35 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
        ):
            raise Exception("No anthropic API key set")

        prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
        client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)

        if self.can_be_coerced:
            prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
        else:
            coercion_prompt = None

        system_message = self.get_system_message(kind)

        messages = [{"role": "user", "content": prompt.strip()}]

        if coercion_prompt:
            log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
            messages.append({"role": "assistant", "content": coercion_prompt.strip()})

        if self.reason_enabled:
            parameters["thinking"] = {
                "type": "enabled",
                "budget_tokens": self.validated_reason_tokens,
            }
            # thinking doesn't support temperature, top_p, or top_k
            # and the API will error if they are set
            parameters.pop("temperature", None)
            parameters.pop("top_p", None)
            parameters.pop("top_k", None)

        self.log.debug(
            "generate",
            model=self.model_name,
            prompt=prompt[:128] + " ...",
            parameters=parameters,
            system_message=system_message,
@@ -238,7 +193,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
        prompt_tokens = 0

        try:
            stream = await self.client.messages.create(
            stream = await client.messages.create(
                model=self.model_name,
                system=system_message,
                messages=messages,
@@ -247,13 +202,25 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
            )

            response = ""
            reasoning = ""

            async for event in stream:
                if event.type == "content_block_delta":
                if (
                    event.type == "content_block_delta"
                    and event.delta.type == "text_delta"
                ):
                    content = event.delta.text
                    response += content
                    self.update_request_tokens(self.count_tokens(content))

                elif (
                    event.type == "content_block_delta"
                    and event.delta.type == "thinking_delta"
                ):
                    content = event.delta.thinking
                    reasoning += content
                    self.update_request_tokens(self.count_tokens(content))

                elif event.type == "message_start":
                    prompt_tokens = event.message.usage.input_tokens

@@ -262,8 +229,9 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):

            self._returned_prompt_tokens = prompt_tokens
            self._returned_response_tokens = completion_tokens
            self._reasoning_response = reasoning

            log.debug("generated response", response=response)
            log.debug("generated response", response=response, reasoning=reasoning)

            return response
        except PermissionDeniedError as e:

@@ -4,6 +4,7 @@ A unified client base, based on the openai API

import ipaddress
import logging
import re
import random
import time
import traceback
@@ -14,20 +15,27 @@ import pydantic
import dataclasses
import structlog
import urllib3
from openai import AsyncOpenAI, PermissionDeniedError
from openai import PermissionDeniedError

import talemate.client.presets as presets
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import active_agent
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
from talemate.client.model_prompts import model_prompt, DEFAULT_TEMPLATE
from talemate.client.ratelimit import CounterRateLimiter
from talemate.context import active_scene
from talemate.prompts.base import Prompt
from talemate.emit import emit
from talemate.config import load_config, save_config, EmbeddingFunctionPreset
from talemate.config import get_config, Config
from talemate.config.schema import EmbeddingFunctionPreset, Client as ClientConfig
import talemate.emit.async_signals as async_signals
from talemate.exceptions import SceneInactiveError, GenerationCancelled
from talemate.exceptions import (
    SceneInactiveError,
    GenerationCancelled,
    GenerationProcessingError,
    ReasoningResponseError,
)
import talemate.ux.schema as ux_schema

from talemate.client.system_prompts import SystemPrompts
@@ -43,6 +51,11 @@ STOPPING_STRINGS = ["<|im_end|>", "</s>"]
REPLACE_SMART_QUOTES = True


INDIRECT_COERCION_PROMPT = "\nStart your response with: "

DEFAULT_REASONING_PATTERN = r".*?</think>"


class ClientDisabledError(OSError):
    def __init__(self, client: "ClientBase"):
        self.client = client
@@ -63,6 +76,7 @@ class PromptData(pydantic.BaseModel):
    generation_parameters: dict = pydantic.Field(default_factory=dict)
    inference_preset: str = None
    preset_group: str | None = None
    reasoning: str | None = None


class ErrorAction(pydantic.BaseModel):
@@ -76,6 +90,9 @@ class CommonDefaults(pydantic.BaseModel):
    rate_limit: int | None = None
    data_format: Literal["yaml", "json"] | None = None
    preset_group: str | None = None
    reason_enabled: bool = False
    reason_tokens: int = 0
    reason_response_pattern: str | None = None


class Defaults(CommonDefaults, pydantic.BaseModel):
@@ -99,6 +116,7 @@ class ExtraField(pydantic.BaseModel):
    description: str
    group: FieldGroup | None = None
    note: ux_schema.Note | None = None
    choices: list[str | int | float | bool] | None = None


class ParameterReroute(pydantic.BaseModel):
@@ -162,39 +180,36 @@ class RequestInformation(pydantic.BaseModel):
class ClientEmbeddingsStatus:
    client: "ClientBase | None" = None
    embedding_name: str | None = None
    seen: bool = False


@dataclasses.dataclass
class ClientStatus:
    client: "ClientBase | None" = None
    enabled: bool = False


async_signals.register(
    "client.embeddings_available",
    "client.enabled",
    "client.disabled",
)


class ClientBase:
    api_url: str
    model_name: str
    api_key: str = None
    name: str = None
    enabled: bool = True
    name: str
    remote_model_name: str | None = None
    remote_model_locked: bool = False
    current_status: str = None
    max_token_length: int = 8192
    processing: bool = False
    connected: bool = False
    conversation_retries: int = 0
    auto_break_repetition_enabled: bool = True
    decensor_enabled: bool = True
    auto_determine_prompt_template: bool = False
    finalizers: list[str] = []
    double_coercion: Union[str, None] = None
    data_format: Literal["yaml", "json"] | None = None
    rate_limit: int | None = None
    client_type = "base"
    request_information: RequestInformation | None = None

    status_request_timeout: int = 2

    system_prompts: SystemPrompts = SystemPrompts()
    preset_group: str | None = ""

    rate_limit_counter: CounterRateLimiter = None

    class Meta(pydantic.BaseModel):
@@ -207,27 +222,92 @@ class ClientBase:

    def __init__(
        self,
        api_url: str = None,
        name: str = None,
        **kwargs,
    ):
        self.api_url = api_url
        self.name = name or self.client_type
        self.remote_model_name = None
        self.auto_determine_prompt_template_attempt = None
        self.log = structlog.get_logger(f"client.{self.client_type}")
        self.double_coercion = kwargs.get("double_coercion", None)
        self._reconfigure_common_parameters(**kwargs)
        self.enabled = kwargs.get("enabled", True)
        if "max_token_length" in kwargs:
            self.max_token_length = (
                int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
            )

        self.set_client(max_token_length=self.max_token_length)

    def __str__(self):
        return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"

    #####

    # config getters

    @property
    def config(self) -> Config:
        return get_config()

    @property
    def client_config(self) -> ClientConfig:
        try:
            return get_config().clients[self.name]
        except KeyError:
            return ClientConfig(type=self.client_type, name=self.name)

    @property
    def model(self) -> str | None:
        return self.client_config.model

    @property
    def model_name(self) -> str | None:
        if self.remote_model_locked:
            return self.remote_model_name
        return self.remote_model_name or self.model

    @property
    def api_key(self) -> str | None:
        return self.client_config.api_key

    @property
    def api_url(self) -> str | None:
        return self.client_config.api_url

    @property
    def max_token_length(self) -> int:
        return self.client_config.max_token_length

    @property
    def double_coercion(self) -> str | None:
        return self.client_config.double_coercion

    @property
    def rate_limit(self) -> int | None:
        return self.client_config.rate_limit

    @property
    def data_format(self) -> Literal["yaml", "json"]:
        return self.client_config.data_format

    @property
    def enabled(self) -> bool:
        return self.client_config.enabled

    @property
    def system_prompts(self) -> SystemPrompts:
        return self.client_config.system_prompts

    @property
    def preset_group(self) -> str | None:
        return self.client_config.preset_group

    @property
    def reason_enabled(self) -> bool:
        return self.client_config.reason_enabled

    @property
    def reason_tokens(self) -> int:
        return self.client_config.reason_tokens

    @property
    def reason_response_pattern(self) -> str:
        return self.client_config.reason_response_pattern or DEFAULT_REASONING_PATTERN

    #####

    @property
    def experimental(self):
        return False
@@ -238,6 +318,9 @@ class ClientBase:
        Determines whether or not this client can pass LLM coercion. (e.g., is able
        to predefine partial LLM output in the prompt)
        """
        if self.reason_enabled:
            # We are not able to coerce via pre-filling if reasoning is enabled
            return False
        return self.Meta().requires_prompt_template

    @property
@@ -283,7 +366,49 @@ class ClientBase:
    def embeddings_identifier(self) -> str:
        return f"client-api/{self.name}/{self.embeddings_model_name}"

    async def destroy(self, config: dict):
    @property
    def reasoning_response(self) -> str | None:
        return getattr(self, "_reasoning_response", None)

    @property
    def min_reason_tokens(self) -> int:
        return 0

    @property
    def validated_reason_tokens(self) -> int:
        return max(self.reason_tokens, self.min_reason_tokens)

    @property
    def default_prompt_template(self) -> str:
        return DEFAULT_TEMPLATE

    @property
    def requires_reasoning_pattern(self) -> bool:
        return True

    async def enable(self):
        self.client_config.enabled = True
        self.emit_status()

        await self.config.set_dirty()
        await self.status()
        await async_signals.get("client.enabled").send(
            ClientStatus(client=self, enabled=True)
        )

    async def disable(self):
        self.client_config.enabled = False
        self.emit_status()

        if self.supports_embeddings:
            await self.reset_embeddings()
        await self.config.set_dirty()
        await self.status()
        await async_signals.get("client.disabled").send(
            ClientStatus(client=self, enabled=False)
        )

    async def destroy(self):
        """
        This is called before the client is removed from talemate.instance.clients

@@ -294,16 +419,13 @@ class ClientBase:
        """

        if self.supports_embeddings:
            self.remove_embeddings(config)
            await self.remove_embeddings()

    def reset_embeddings(self):
    async def reset_embeddings(self):
        self._embeddings_model_name = None
        self._embeddings_status = False

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")

    def set_embeddings(self):
    async def set_embeddings(self):
        log.debug(
            "setting embeddings",
            client=self.name,
@@ -314,7 +436,7 @@ class ClientBase:
        if not self.supports_embeddings or not self.embeddings_status:
            return

        config = load_config(as_model=True)
        config: Config = get_config()

        key = self.embeddings_identifier

@@ -334,30 +456,25 @@ class ClientBase:
            custom=True,
        )

        save_config(config)
        await config.set_dirty()

    def remove_embeddings(self, config: dict | None = None):
    async def remove_embeddings(self):
        # remove all embeddings for this client
        for key, value in list(config["presets"]["embeddings"].items()):
            if value["client"] == self.name and value["embeddings"] == "client-api":
        config: Config = get_config()
        for key, value in list(config.presets.embeddings.items()):
            if value.client == self.name and value.embeddings == "client-api":
                log.warning("!!! removing embeddings", client=self.name, key=key)
                config["presets"]["embeddings"].pop(key)

    def set_system_prompts(self, system_prompts: dict | SystemPrompts):
        if isinstance(system_prompts, dict):
            self.system_prompts = SystemPrompts(**system_prompts)
        elif not isinstance(system_prompts, SystemPrompts):
            raise ValueError(
                "system_prompts must be a `dict` or `SystemPrompts` instance"
            )
        else:
            self.system_prompts = system_prompts
                config.presets.embeddings.pop(key)
        await config.set_dirty()

    def prompt_template(self, sys_msg: str, prompt: str):
        """
        Applies the appropriate prompt template for the model.
        """

        if not self.Meta().requires_prompt_template:
            return prompt

        if not self.model_name:
            self.log.warning("prompt template not applied", reason="no model loaded")
            return f"{sys_msg}\n{prompt}"
@@ -372,13 +489,22 @@ class ClientBase:
        else:
            double_coercion = None

        return model_prompt(self.model_name, sys_msg, prompt, double_coercion)[0]
        return model_prompt(
            self.model_name,
            sys_msg,
            prompt,
            double_coercion,
            default_template=self.default_prompt_template,
        )[0]

    def prompt_template_example(self):
        if not getattr(self, "model_name", None):
            return None, None
        return model_prompt(
            self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
            self.model_name,
            "{sysmsg}",
            "{prompt}<|BOT|>{LLM coercion}",
            default_template=self.default_prompt_template,
        )

    def split_prompt_for_coercion(self, prompt: str) -> tuple[str, str]:
@@ -386,59 +512,29 @@ class ClientBase:
        Splits the prompt and the prefill/coercion prompt.
        """
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            prompt, coercion = prompt.split("<|BOT|>", 1)

            if self.double_coercion:
                right = f"{self.double_coercion}\n\n{right}"
                coercion = f"{self.double_coercion}\n\n{coercion}"

            return prompt, right
            return prompt, coercion
        return prompt, None
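
    # e.g. (illustrative, with double_coercion unset):
    #   "Write the next line.<|BOT|>Certainly," -> ("Write the next line.", "Certainly,")
    # while a prompt without the <|BOT|> marker returns (prompt, None).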

def reconfigure(self, **kwargs):
def rate_limit_update(self):
"""
Reconfigures the client.
Updates the rate limit counter for the client.

Keyword Arguments:

- api_url: the API URL to use
- max_token_length: the max token length to use
- enabled: whether the client is enabled
If the rate limit is set to 0, the rate limit counter is set to None.
"""

if "api_url" in kwargs:
self.api_url = kwargs["api_url"]

if kwargs.get("max_token_length"):
self.max_token_length = int(kwargs["max_token_length"])

if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if not self.enabled and self.supports_embeddings and self.embeddings_status:
self.reset_embeddings()

if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]

self._reconfigure_common_parameters(**kwargs)

def _reconfigure_common_parameters(self, **kwargs):
if "rate_limit" in kwargs:
self.rate_limit = kwargs["rate_limit"]
if self.rate_limit:
if not self.rate_limit_counter:
self.rate_limit_counter = CounterRateLimiter(
rate_per_minute=self.rate_limit
)
else:
self.rate_limit_counter.update_rate_limit(self.rate_limit)
if self.rate_limit:
if not self.rate_limit_counter:
self.rate_limit_counter = CounterRateLimiter(
rate_per_minute=self.rate_limit
)
else:
self.rate_limit_counter = None

if "data_format" in kwargs:
self.data_format = kwargs["data_format"]

if "preset_group" in kwargs:
self.preset_group = kwargs["preset_group"]
self.rate_limit_counter.update_rate_limit(self.rate_limit)
else:
self.rate_limit_counter = None
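
The rate-limit branch reads more clearly in isolation; this sketch uses only the CounterRateLimiter calls visible in this patch (the constructor, update_rate_limit, and the increment call used later in _send_prompt):

    def configure_rate_limit(client, rate_limit):
        client.rate_limit = rate_limit
        if rate_limit:
            if not client.rate_limit_counter:
                client.rate_limit_counter = CounterRateLimiter(rate_per_minute=rate_limit)
            else:
                client.rate_limit_counter.update_rate_limit(rate_limit)
        else:
            # a rate limit of 0 (or None) disables the counter entirely
            client.rate_limit_counter = None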

def host_is_remote(self, url: str) -> bool:
"""
@@ -491,43 +587,40 @@ class ClientBase:

- kind: the kind of generation
"""

app_config_system_prompts = client_context_attribute(
"app_config_system_prompts"
)

if app_config_system_prompts:
self.system_prompts.parent = SystemPrompts(**app_config_system_prompts)

return self.system_prompts.get(kind, self.decensor_enabled)
config: Config = get_config()
self.system_prompts.parent = config.system_prompts
sys_prompt = self.system_prompts.get(kind, self.decensor_enabled)
return sys_prompt

def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
error_message: str | None = None

if processing is not None:
self.processing = processing

if not self.enabled:
status = "disabled"
model_name = "Disabled"
error_message = "Disabled"
elif not self.connected:
status = "error"
model_name = "Could not connect"
error_message = "Could not connect"
elif self.model_name:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
model_name = "No model loaded"
error_message = "No model loaded"
status = "warning"

status_change = status != self.current_status
self.current_status = status

default_prompt_template = self.default_prompt_template

prompt_template_example, prompt_template_file = self.prompt_template_example()
has_prompt_template = (
prompt_template_file and prompt_template_file != "default.jinja2"
prompt_template_file and prompt_template_file != default_prompt_template
)

if not has_prompt_template and self.auto_determine_prompt_template:
@@ -545,21 +638,28 @@ class ClientBase:
self.prompt_template_example()
)
has_prompt_template = (
prompt_template_file and prompt_template_file != "default.jinja2"
prompt_template_file
and prompt_template_file != default_prompt_template
)

dedicated_default_template = default_prompt_template != DEFAULT_TEMPLATE

data = {
"api_key": self.api_key,
"prompt_template_example": prompt_template_example,
"has_prompt_template": has_prompt_template,
"dedicated_default_template": dedicated_default_template,
"template_file": prompt_template_file,
"meta": self.Meta().model_dump(),
"error_action": None,
"double_coercion": self.double_coercion,
"enabled": self.enabled,
"system_prompts": self.system_prompts.model_dump(),
"error_message": error_message,
}

if self.Meta().enable_api_auth:
data["api_key"] = self.api_key

data.update(self._common_status_data())

for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
@@ -571,7 +671,7 @@ class ClientBase:
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status,
data=data,
)
@@ -595,6 +695,11 @@ class ClientBase:
"supports_embeddings": self.supports_embeddings,
"embeddings_status": self.embeddings_status,
"embeddings_model_name": self.embeddings_model_name,
"reason_enabled": self.reason_enabled,
"reason_tokens": self.reason_tokens,
"min_reason_tokens": self.min_reason_tokens,
"reason_response_pattern": self.reason_response_pattern,
"requires_reasoning_pattern": self.requires_reasoning_pattern,
"request_information": self.request_information.model_dump()
if self.request_information
else None,
@@ -646,20 +751,16 @@ class ClientBase:
return

try:
self.model_name = await self.get_model_name()
self.remote_model_name = await self.get_model_name()
except Exception as e:
self.log.warning("client status error", e=e, client=self.name)
self.model_name = None
self.remote_model_name = None
self.connected = False
self.emit_status()
return

self.connected = True

if not self.model_name or self.model_name == "None":
self.emit_status()
return

self.emit_status()

def generate_prompt_parameters(self, kind: str):
@@ -682,6 +783,15 @@ class ClientBase:
parameters, kind, agent_context.action
)

if self.reason_enabled and self.reason_tokens > 0:
log.debug(
"padding for reasoning",
client=self.client_type,
reason_tokens=self.reason_tokens,
validated_reason_tokens=self.validated_reason_tokens,
)
parameters["max_tokens"] += self.validated_reason_tokens

if client_context_attribute(
"nuke_repetition"
) > 0.0 and self.jiggle_enabled_for(kind):
@@ -838,12 +948,89 @@ class ClientBase:
else:
self.request_information.tokens += tokens

def strip_coercion_prompt(self, response: str, coercion_prompt: str = None) -> str:
"""
Strips the coercion prompt from the response if it is present.
"""
if not coercion_prompt or not response.startswith(coercion_prompt):
return response

return response.replace(coercion_prompt, "").lstrip()

def strip_reasoning(self, response: str) -> tuple[str, str]:
"""
Strips the reasoning from the response if the model is reasoning.
"""

if not self.reason_enabled:
return response, None

if not self.requires_reasoning_pattern:
# reasoning handled automatically during streaming
return response, None

pattern = self.reason_response_pattern
if not pattern:
pattern = DEFAULT_REASONING_PATTERN

log.debug("reasoning pattern", pattern=pattern)

extract_reason = re.search(pattern, response, re.DOTALL)

if extract_reason:
reasoning_response = extract_reason.group(0)
return response.replace(reasoning_response, ""), reasoning_response

raise ReasoningResponseError()
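
A hypothetical illustration of the extraction above; the actual DEFAULT_REASONING_PATTERN is not shown in this patch, so a <think> block pattern is assumed:

    import re

    pattern = r"<think>.*?</think>"  # assumed stand-in for DEFAULT_REASONING_PATTERN
    response = "<think>plan the reply</think>The actual answer."
    match = re.search(pattern, response, re.DOTALL)
    if match:
        reasoning = match.group(0)
        cleaned = response.replace(reasoning, "")  # -> "The actual answer."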

def attach_response_length_instruction(
self, prompt: str, response_length: int | None
) -> str:
"""
Attaches the response length instruction to the prompt.
"""

if not response_length or response_length < 0:
log.warning("response length instruction", response_length=response_length)
return prompt

instructions_prompt = Prompt.get(
"common.response-length",
vars={
"response_length": response_length,
"attach_response_length_instruction": True,
},
)

instructions_prompt = instructions_prompt.render()

if instructions_prompt.strip() in prompt:
log.debug(
"response length instruction already in prompt",
instructions_prompt=instructions_prompt,
)
return prompt

log.debug(
"response length instruction", instructions_prompt=instructions_prompt
)

if "<|RESPONSE_LENGTH_INSTRUCTIONS|>" in prompt:
return prompt.replace(
"<|RESPONSE_LENGTH_INSTRUCTIONS|>", instructions_prompt
)
elif "<|BOT|>" in prompt:
return prompt.replace("<|BOT|>", f"{instructions_prompt}<|BOT|>")
else:
return f"{prompt}{instructions_prompt}"
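
The three insertion paths, condensed (the instruction string is invented):

    instruction = "(Aim for roughly 192 tokens.)"
    if "<|RESPONSE_LENGTH_INSTRUCTIONS|>" in prompt:
        # a template may reserve an explicit slot for the instruction
        prompt = prompt.replace("<|RESPONSE_LENGTH_INSTRUCTIONS|>", instruction)
    elif "<|BOT|>" in prompt:
        # keep the instruction ahead of the prefill marker
        prompt = prompt.replace("<|BOT|>", f"{instruction}<|BOT|>")
    else:
        prompt = f"{prompt}{instruction}"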

async def send_prompt(
self,
prompt: str,
kind: str = "conversation",
finalize: Callable = lambda x: x,
retries: int = 2,
data_expected: bool | None = None,
) -> str:
"""
Send a prompt to the AI and return its response.
@@ -852,7 +1039,9 @@ class ClientBase:
"""

try:
return await self._send_prompt(prompt, kind, finalize, retries)
return await self._send_prompt(
prompt, kind, finalize, retries, data_expected
)
except GenerationCancelled:
await self.abort_generation()
raise
@@ -863,6 +1052,7 @@ class ClientBase:
kind: str = "conversation",
finalize: Callable = lambda x: x,
retries: int = 2,
data_expected: bool | None = None,
) -> str:
"""
Send a prompt to the AI and return its response.
@@ -871,6 +1061,7 @@ class ClientBase:
"""

try:
self.rate_limit_update()
if self.rate_limit_counter:
aborted: bool = False
while not self.rate_limit_counter.increment():
@@ -927,12 +1118,27 @@ class ClientBase:
try:
self._returned_prompt_tokens = None
self._returned_response_tokens = None
self._reasoning_response = None

self.emit_status(processing=True)
await self.status()

prompt_param = self.generate_prompt_parameters(kind)

if self.reason_enabled and not data_expected:
prompt = self.attach_response_length_instruction(
prompt,
(prompt_param.get(self.max_tokens_param_name) or 0)
- self.reason_tokens,
)

if not self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
if coercion_prompt:
prompt += f"{INDIRECT_COERCION_PROMPT}{coercion_prompt}"
else:
coercion_prompt = None

finalized_prompt = self.prompt_template(
self.get_system_message(kind), prompt
).strip(" ")
@@ -954,11 +1160,26 @@ class ClientBase:
max_token_length=self.max_token_length,
parameters=prompt_param,
)
prompt_sent = self.repetition_adjustment(finalized_prompt)

if "<|RESPONSE_LENGTH_INSTRUCTIONS|>" in finalized_prompt:
finalized_prompt = finalized_prompt.replace(
"\n<|RESPONSE_LENGTH_INSTRUCTIONS|>", ""
)

self.new_request()

response = await self._cancelable_generate(prompt_sent, prompt_param, kind)
response = await self._cancelable_generate(
finalized_prompt, prompt_param, kind
)

response, reasoning_response = self.strip_reasoning(response)
if reasoning_response:
self._reasoning_response = reasoning_response

if coercion_prompt:
response = self.process_response_for_indirect_coercion(
finalized_prompt, response, coercion_prompt
)

self.end_request()

@@ -966,14 +1187,13 @@ class ClientBase:
# generation was cancelled
raise response

# response = await self.generate(prompt_sent, prompt_param, kind)

response, finalized_prompt = await self.auto_break_repetition(
finalized_prompt, prompt_param, response, kind, retries
)

if REPLACE_SMART_QUOTES:
response = response.replace("“", '"').replace("”", '"')
response = (
response.replace("“", '"')
.replace("”", '"')
.replace("‘", "'")
.replace("’", "'")
)

time_end = time.time()

@@ -991,7 +1211,7 @@ class ClientBase:
"prompt_sent",
data=PromptData(
kind=kind,
prompt=prompt_sent,
prompt=finalized_prompt,
response=response,
prompt_tokens=self._returned_prompt_tokens or token_length,
response_tokens=self._returned_response_tokens
@@ -1003,12 +1223,17 @@ class ClientBase:
generation_parameters=prompt_param,
inference_preset=client_context_attribute("inference_preset"),
preset_group=self.preset_group,
reasoning=self._reasoning_response,
).model_dump(),
)

return response
except GenerationCancelled:
raise
except GenerationProcessingError as e:
self.log.error("send_prompt error", e=e)
emit("status", message=str(e), status="error")
return ""
except Exception:
self.log.error("send_prompt error", e=traceback.format_exc())
emit(
@@ -1023,130 +1248,6 @@ class ClientBase:
if self.rate_limit_counter:
self.rate_limit_counter.increment()

async def auto_break_repetition(
self,
finalized_prompt: str,
prompt_param: dict,
response: str,
kind: str,
retries: int,
pad_max_tokens: int = 32,
) -> str:
"""
If repetition breaking is enabled, this will retry the prompt if its
response is too similar to other messages in the prompt

This requires the agent to have the allow_repetition_break method
and the jiggle_enabled_for method and the client to have the
auto_break_repetition_enabled attribute set to True

Arguments:

- finalized_prompt: the prompt that was sent
- prompt_param: the parameters that were used
- response: the response that was received
- kind: the kind of generation
- retries: the number of retries left
- pad_max_tokens: increase response max_tokens by this amount per iteration

Returns:

- the response
"""

if not self.auto_break_repetition_enabled or not response.strip():
return response, finalized_prompt

agent_context = active_agent.get()
if self.jiggle_enabled_for(kind, auto=True):
# check if the response is a repetition
# using the default similarity threshold of 98, meaning it needs
# to be really similar to be considered a repetition

is_repetition, similarity_score, matched_line = util.similarity_score(
response, finalized_prompt.split("\n"), similarity_threshold=80
)

if not is_repetition:
# not a repetition, return the response

self.log.debug(
"send_prompt no similarity", similarity_score=similarity_score
)
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=False
)
return response, finalized_prompt

while is_repetition and retries > 0:
# it's a repetition, retry the prompt with adjusted parameters

self.log.warn(
"send_prompt similarity retry",
agent=agent_context.agent.agent_type,
similarity_score=similarity_score,
retries=retries,
)

# first we apply the client's randomness jiggle which will adjust
# parameters like temperature and repetition_penalty, depending
# on the client
#
# this is a cumulative adjustment, so it will add to the previous
# iteration's adjustment, this also means retries should be kept low
# otherwise it will get out of hand and start generating nonsense

self.jiggle_randomness(prompt_param, offset=0.5)

# then we pad the max_tokens by the pad_max_tokens amount

prompt_param[self.max_tokens_param_name] += pad_max_tokens

# send the prompt again
# we use the repetition_adjustment method to further encourage
# the AI to break the repetition on its own as well.

finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=True
)

response = retried_response = await self.generate(
finalized_prompt, prompt_param, kind
)

self.log.debug(
"send_prompt dedupe sentences",
response=response,
matched_line=matched_line,
)

# a lot of the time the response will now contain the repetition + something new
# so we dedupe the response to remove the repetition at the sentence level

response = util.dedupe_sentences(
response, matched_line, similarity_threshold=85, debug=True
)
self.log.debug(
"send_prompt dedupe sentences (after)", response=response
)

# deduping may have removed the entire response, so we check for that

if not util.strip_partial_sentences(response).strip():
# if the response is empty, we set the response to the original
# and try again next loop

response = retried_response

# check if the response is a repetition again

is_repetition, similarity_score, matched_line = util.similarity_score(
response, finalized_prompt.split("\n"), similarity_threshold=80
)
retries -= 1

return response, finalized_prompt

def count_tokens(self, content: str):
return util.count_tokens(content)

@@ -1169,31 +1270,9 @@ class ClientBase:

return agent.allow_repetition_break(kind, agent_context.action, auto=auto)

def repetition_adjustment(self, prompt: str, is_repetitive: bool = False):
"""
Breaks the prompt into lines and checks each line for a match with
[$REPETITION|{repetition_adjustment}].

On match and if is_repetitive is True, the line is removed from the prompt and
replaced with the repetition_adjustment.

On match and if is_repetitive is False, the line is removed from the prompt.
"""

lines = prompt.split("\n")
new_lines = []
for line in lines:
if line.startswith("[$REPETITION|"):
if is_repetitive:
new_lines.append(line.split("|")[1][:-1])
else:
new_lines.append("")
else:
new_lines.append(line)

return "\n".join(new_lines)
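
A worked example of the [$REPETITION|...] handling (values invented):

    line = "[$REPETITION|Avoid repeating earlier lines.]"
    payload = line.split("|")[1][:-1]  # -> "Avoid repeating earlier lines."
    # is_repetitive=True  -> the marker line is replaced with the payload
    # is_repetitive=False -> the marker line is replaced with an empty line
    # caveat: a payload containing "|" would be cut short by split("|")[1]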

def process_response_for_indirect_coercion(self, prompt: str, response: str) -> str:
def process_response_for_indirect_coercion(
self, prompt: str, response: str, coercion_prompt: str
) -> str:
"""
A lot of remote APIs don't let us control the prompt template and we cannot directly
append the beginning of the desired response to the prompt.
@@ -1202,13 +1281,19 @@ class ClientBase:
and then hopefully it will adhere to it and we can strip it off the actual response.
"""

_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response and expected_response.startswith("{"):
if coercion_prompt and coercion_prompt.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()

if right and response.startswith(right):
response = response[len(right) :].strip()
log.debug(
"process_response_for_indirect_coercion",
response=f"|{response[:100]}...|",
coercion_prompt=f"|{coercion_prompt}|",
)

if coercion_prompt and response.startswith(coercion_prompt):
response = response[len(coercion_prompt) :].strip()
elif coercion_prompt and response.lstrip().startswith(coercion_prompt):
response = response.lstrip()[len(coercion_prompt) :].strip()

return response
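
The stripping in miniature: if the model echoed the coercion text, the echoed prefix is removed (strings invented):

    coercion_prompt = "Certainly,"
    response = "Certainly, the door creaks open."
    if response.lstrip().startswith(coercion_prompt):
        response = response.lstrip()[len(coercion_prompt):].strip()
    # response == "the door creaks open."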

@@ -15,9 +15,8 @@ from talemate.client.remote(
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens

__all__ = [
@@ -54,7 +53,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):

client_type = "cohere"
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig

@@ -67,18 +65,9 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
defaults: Defaults = Defaults()

def __init__(self, model="command-r-plus", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)

handlers["config_saved"].connect(self.on_config_saved)

@property
def cohere_api_key(self):
return self.config.get("cohere", {}).get("api_key")
return self.config.cohere.api_key

@property
def supported_parameters(self):
@@ -96,15 +85,15 @@ class CohereClient(EndpointOverrideMixin, ClientBase):

def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing

if self.cohere_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -117,7 +106,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):

if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"

self.current_status = status

@@ -125,67 +114,18 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)

def set_client(self, max_token_length: int = None):
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncClientV2("sk-1111")
log.error("No cohere API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return

if not self.model_name:
self.model_name = "command-r-plus"

if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)

model = self.model_name

self.client = AsyncClientV2(self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384

if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True

log.info(
"cohere set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)

def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))

if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])

self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)

def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)

def response_tokens(self, response: str):
return count_tokens(response)

@@ -195,16 +135,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()

def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")

return prompt

def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)

@@ -228,13 +158,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No cohere API key set")

right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
client = AsyncClientV2(self.api_key, base_url=self.base_url)

human_message = prompt.strip()
system_message = self.get_system_message(kind)
@@ -263,7 +187,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
# manager, so attempting to use `async with` raises a `TypeError` as seen
# in issue logs. We therefore iterate over the generator directly.

stream = self.client.chat_stream(
stream = client.chat_stream(
model=self.model_name,
messages=messages,
**parameters,
@@ -283,13 +207,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):

log.debug("generated response", response=response)

if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()

if right and response.startswith(right):
response = response[len(right) :].strip()

return response
# except PermissionDeniedError as e:
# self.log.error("generate error", e=e)

@@ -4,9 +4,7 @@ from openai import AsyncOpenAI, PermissionDeniedError

from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens

__all__ = [
@@ -40,7 +38,6 @@ class DeepSeekClient(ClientBase):

client_type = "deepseek"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False

@@ -52,17 +49,9 @@ class DeepSeekClient(ClientBase):
requires_prompt_template: bool = False
defaults: Defaults = Defaults()

def __init__(self, model="deepseek-chat", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)

handlers["config_saved"].connect(self.on_config_saved)

@property
def deepseek_api_key(self):
return self.config.get("deepseek", {}).get("api_key")
return self.config.deepseek.api_key

@property
def supported_parameters(self):
@@ -75,15 +64,15 @@ class DeepSeekClient(ClientBase):

def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing

if self.deepseek_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -96,7 +85,7 @@ class DeepSeekClient(ClientBase):

if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"

self.current_status = status

@@ -104,66 +93,18 @@ class DeepSeekClient(ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)

def set_client(self, max_token_length: int = None):
if not self.deepseek_api_key:
self.client = AsyncOpenAI(api_key="sk-1111", base_url=BASE_URL)
log.error("No DeepSeek API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return

if not self.model_name:
self.model_name = "deepseek-chat"

if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)

model = self.model_name

self.client = AsyncOpenAI(api_key=self.deepseek_api_key, base_url=BASE_URL)
self.max_token_length = max_token_length or 16384

if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True

log.info(
"deepseek set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)

def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))

if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])

self._reconfigure_common_parameters(**kwargs)

def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)

def count_tokens(self, content: str):
if not self.model_name:
return 0
@@ -172,18 +113,6 @@ class DeepSeekClient(ClientBase):
async def status(self):
self.emit_status()

def prompt_template(self, system_message: str, prompt: str):
# only gpt-4-1106-preview supports json_object response coercion

if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")

return prompt

def response_tokens(self, response: str):
# Count tokens in a response string using the util.count_tokens helper
return self.count_tokens(response)
@@ -200,20 +129,7 @@ class DeepSeekClient(ClientBase):
if not self.deepseek_api_key:
raise Exception("No DeepSeek API key set")

# only gpt-4-* supports enforcing json object
supports_json_object = (
self.model_name.startswith("gpt-4-")
or self.model_name in JSON_OBJECT_RESPONSE_MODELS
)
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
client = AsyncOpenAI(api_key=self.deepseek_api_key, base_url=BASE_URL)

human_message = {"role": "user", "content": prompt.strip()}
system_message = {"role": "system", "content": self.get_system_message(kind)}
@@ -227,7 +143,7 @@ class DeepSeekClient(ClientBase):

try:
# Use streaming so we can update_request_tokens incrementally
stream = await self.client.chat.completions.create(
stream = await client.chat.completions.create(
model=self.model_name,
messages=[system_message, human_message],
stream=True,

@@ -251,20 +167,6 @@ class DeepSeekClient(ClientBase):
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)

# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()

if right and response.startswith(right):
response = response[len(right) :].strip()

return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -21,10 +21,8 @@ from talemate.client.remote(
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens

__all__ = [
@@ -41,10 +39,11 @@ SUPPORTED_MODELS = [
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash-lite-preview-06-17",
"gemini-2.5-flash-preview-05-20",
"gemini-2.5-pro-preview-03-25",
"gemini-2.5-flash",
"gemini-2.5-pro-preview-06-05",
"gemini-2.5-pro",
]


@@ -59,6 +58,9 @@ class ClientConfig(EndpointOverride, BaseClientConfig):
disable_safety_settings: bool = False


MIN_THINKING_TOKENS = 0


@register()
class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"""
@@ -67,7 +69,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

client_type = "google"
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig

@@ -90,21 +91,23 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
extra_fields.update(endpoint_override_extra_fields())

def __init__(self, model="gemini-2.0-flash", **kwargs):
self.model_name = model
self.setup_status = None
self.model_instance = None
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
self.google_credentials_read = False
self.google_project_id = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)

handlers["config_saved"].connect(self.on_config_saved)
@property
def disable_safety_settings(self):
return self.client_config.disable_safety_settings

@property
def min_reason_tokens(self) -> int:
return MIN_THINKING_TOKENS

@property
def can_be_coerced(self) -> bool:
return True
return not self.reason_enabled

@property
def google_credentials(self):
@@ -116,15 +119,15 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

@property
def google_credentials_path(self):
return self.config.get("google").get("gcloud_credentials_path")
return self.config.google.gcloud_credentials_path

@property
def google_location(self):
return self.config.get("google").get("gcloud_location")
return self.config.google.gcloud_location

@property
def google_api_key(self):
return self.config.get("google").get("api_key")
return self.config.google.api_key

@property
def vertexai_ready(self) -> bool:
@@ -197,6 +200,16 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

return genai_types.HttpOptions(base_url=self.base_url)

@property
def thinking_config(self) -> genai_types.ThinkingConfig | None:
if not self.reason_enabled:
return None

return genai_types.ThinkingConfig(
thinking_budget=self.validated_reason_tokens,
include_thoughts=True,
)

@property
def supported_parameters(self):
return [
@@ -211,6 +224,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
),
]

@property
def requires_reasoning_pattern(self) -> bool:
return False

def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -269,46 +286,20 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"Error setting client base URL", error=e, client=self.client_type
)

def set_client(self, max_token_length: int = None, **kwargs):
if not self.ready:
log.error("Google cloud setup incomplete")
if self.setup_status:
self.setup_status = False
emit("request_client_status")
emit("request_agent_status")
return

if not self.model_name:
self.model_name = "gemini-2.0-flash"

if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)

def make_client(self) -> genai.Client:
if self.google_credentials_path:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_credentials_path

model = self.model_name

self.max_token_length = max_token_length or 16384

if self.vertexai_ready and not self.developer_api_ready:
self.client = genai.Client(
return genai.Client(
vertexai=True,
project=self.google_project_id,
location=self.google_location,
)
else:
self.client = genai.Client(
return genai.Client(
api_key=self.api_key or None, http_options=self.http_options
)

log.info(
"google set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)

def response_tokens(self, response: str):
"""Return token count for a response which may be a string or SDK object."""
return count_tokens(response)
@@ -316,22 +307,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
def prompt_tokens(self, prompt: str):
return count_tokens(prompt)

def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))

if "disable_safety_settings" in kwargs:
self.disable_safety_settings = kwargs["disable_safety_settings"]

if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])

if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]

self._reconfigure_common_parameters(**kwargs)

def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)

@@ -339,13 +314,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
if "top_k" in parameters and parameters["top_k"] == 0:
del parameters["top_k"]

def prompt_template(self, system_message: str, prompt: str):
"""
Google handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt

async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -354,7 +322,12 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
if not self.ready:
raise Exception("Google setup incomplete")

prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
client = self.make_client()

if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None

human_message = prompt.strip()
system_message = self.get_system_message(kind)
@@ -371,6 +344,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
]

if coercion_prompt:
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
contents.append(
genai_types.Content(
role="model",
@@ -384,49 +358,57 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

self.log.debug(
"generate",
model=self.model_name,
base_url=self.base_url,
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
disable_safety_settings=self.disable_safety_settings,
safety_settings=self.safety_settings,
thinking_config=self.thinking_config,
)

try:
# Use streaming so we can update_request_tokens incrementally
# stream = await chat.send_message_async(
# human_message,
# safety_settings=self.safety_settings,
# generation_config=parameters,
# stream=True
# )

stream = await self.client.aio.models.generate_content_stream(
stream = await client.aio.models.generate_content_stream(
model=self.model_name,
contents=contents,
config=genai_types.GenerateContentConfig(
safety_settings=self.safety_settings,
http_options=self.http_options,
thinking_config=self.thinking_config,
**parameters,
),
)

response = ""

reasoning = ""
# https://ai.google.dev/gemini-api/docs/thinking#summaries
async for chunk in stream:
# For each streamed chunk, append content and update token counts
content_piece = getattr(chunk, "text", None)
if not content_piece:
# Some SDK versions wrap text under candidates[0].text
try:
content_piece = chunk.candidates[0].text # type: ignore
except Exception:
content_piece = None
try:
if not chunk:
continue

if content_piece:
response += content_piece
# Incrementally update token usage
self.update_request_tokens(count_tokens(content_piece))
if not chunk.candidates:
continue

if not chunk.candidates[0].content.parts:
continue

for part in chunk.candidates[0].content.parts:
if not part.text:
continue
if part.thought:
reasoning += part.text
else:
response += part.text
self.update_request_tokens(count_tokens(part.text))
except Exception as e:
log.error("error processing chunk", e=e, chunk=chunk)
continue

if reasoning:
self._reasoning_response = reasoning

# Store total token accounting for prompt/response
self._returned_prompt_tokens = self.prompt_tokens(prompt)
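
A condensed sketch of the new per-request flow, using only calls that appear in this patch (genai.Client, generate_content_stream, GenerateContentConfig, ThinkingConfig); parameter values are illustrative:

    client = genai.Client(api_key="...")  # or vertexai=True with project/location
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.5-flash",
        contents=contents,
        config=genai_types.GenerateContentConfig(
            thinking_config=genai_types.ThinkingConfig(
                thinking_budget=1024, include_thoughts=True
            ),
        ),
    )
    reasoning, response = "", ""
    async for chunk in stream:
        for part in chunk.candidates[0].content.parts:
            if not part.text:
                continue
            if part.thought:  # thought summaries arrive as flagged parts
                reasoning += part.text
            else:
                response += part.text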

@@ -4,9 +4,8 @@ from groq import AsyncGroq, PermissionDeniedError

from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, ExtraField
from talemate.client.registry import register
from talemate.config import load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
@@ -23,6 +22,10 @@ SUPPORTED_MODELS = [
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
"llama-3.3-70b-versatile",
"qwen/qwen3-32b",
"moonshotai/kimi-k2-instruct",
"deepseek-r1-distill-llama-70b",
]

JSON_OBJECT_RESPONSE_MODELS = []
@@ -30,7 +33,11 @@ JSON_OBJECT_RESPONSE_MODELS = []

class Defaults(EndpointOverride, pydantic.BaseModel):
max_token_length: int = 8192
model: str = "llama3-70b-8192"
model: str = "moonshotai/kimi-k2-instruct"


class ClientConfig(EndpointOverride, BaseClientConfig):
pass


@register()
@@ -41,9 +48,9 @@ class GroqClient(EndpointOverrideMixin, ClientBase):

client_type = "groq"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = True
config_cls = ClientConfig

class Meta(ClientBase.Meta):
name_prefix: str = "Groq"
@@ -54,19 +61,13 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()

def __init__(self, model="llama3-70b-8192", **kwargs):
self.model_name = model
self.api_key_status = None
# Apply any endpoint override parameters provided via kwargs before creating client
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)

handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return not self.reason_enabled

@property
def groq_api_key(self):
return self.config.get("groq", {}).get("api_key")
return self.config.groq.api_key

@property
def supported_parameters(self):
@@ -83,15 +84,15 @@ class GroqClient(EndpointOverrideMixin, ClientBase):

def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing

if self.groq_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -104,7 +105,7 @@ class GroqClient(EndpointOverrideMixin, ClientBase):

if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"

self.current_status = status

@@ -112,6 +113,7 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
# Include shared/common status data (rate limit, etc.)
data.update(self._common_status_data())
@@ -120,66 +122,11 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)

def set_client(self, max_token_length: int = None):
# Determine if we should use the globally configured API key or the override key
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
# No API key and no endpoint override – cannot initialize client correctly
self.client = AsyncGroq(api_key="sk-1111")
log.error("No groq.ai API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return

if not self.model_name:
self.model_name = "llama3-70b-8192"

if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)

model = self.model_name

# Use the override values (if any) when constructing the Groq client
self.client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384

if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True

log.info(
"groq.ai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)

def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))

if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])

# Allow dynamic reconfiguration of endpoint override parameters
self._reconfigure_endpoint_override(**kwargs)
# Reconfigure any common parameters (rate limit, data format, etc.)
self._reconfigure_common_parameters(**kwargs)

def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)

def response_tokens(self, response: str):
return response.usage.completion_tokens

@@ -189,16 +136,6 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()

def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")

return prompt

async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -207,16 +144,12 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No groq.ai API key set")

supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)

if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None

system_message = self.get_system_message(kind)

@@ -225,6 +158,10 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
{"role": "user", "content": prompt},
]

if coercion_prompt:
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
messages.append({"role": "assistant", "content": coercion_prompt.strip()})

self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
@@ -233,27 +170,25 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
)

try:
response = await self.client.chat.completions.create(
stream = await client.chat.completions.create(
model=self.model_name,
messages=messages,
stream=True,
**parameters,
)

response = response.choices[0].message.content
response = ""

# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()

if right and response.startswith(right):
response = response[len(right) :].strip()
# Iterate over streamed chunks
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta and getattr(delta, "content", None):
content_piece = delta.content
response += content_piece
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))

return response
except PermissionDeniedError as e:
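
The move from a single completion to streaming means content arrives as deltas; a minimal consumer of the same OpenAI-style stream shape used above:

    response = ""
    async for chunk in stream:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta and getattr(delta, "content", None):
            response += delta.content  # token accounting can happen per piece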

@@ -75,6 +75,7 @@ class KoboldEmbeddingFunction(EmbeddingFunction):
class KoboldCppClient(ClientBase):
auto_determine_prompt_template: bool = True
client_type = "koboldcpp"
remote_model_locked: bool = True

class Meta(ClientBase.Meta):
name_prefix: str = "KoboldCpp"
@@ -188,6 +189,10 @@ class KoboldCppClient(ClientBase):
def embeddings_function(self):
return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)

@property
def default_prompt_template(self) -> str:
return "KoboldAI.jinja2"

def api_endpoint_specified(self, url: str) -> bool:
return "/v1" in self.api_url

@@ -200,14 +205,9 @@ class KoboldCppClient(ClientBase):
self.api_url += "/"

def __init__(self, **kwargs):
self.api_key = kwargs.pop("api_key", "")
super().__init__(**kwargs)
self.ensure_api_endpoint_specified()

def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)
self.ensure_api_endpoint_specified()

async def get_embeddings_model_name(self):
# if self._embeddings_model_name is set, return it
if self.embeddings_model_name:
@@ -245,15 +245,21 @@ class KoboldCppClient(ClientBase):
model_name=self.embeddings_model_name,
)

self.set_embeddings()
await self.set_embeddings()

await async_signals.get("client.embeddings_available").send(
ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)
emission = ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)

await async_signals.get("client.embeddings_available").send(emission)

if not emission.seen:
# the suggestion has not been seen by the memory agent
# yet, so we unset the embeddings model name so it will
# get suggested again
self._embeddings_model_name = None
|
||||
|
||||
async def get_model_name(self):
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
@@ -437,12 +443,6 @@ class KoboldCppClient(ClientBase):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if "api_key" in kwargs:
|
||||
self.api_key = kwargs.pop("api_key")
|
||||
|
||||
super().reconfigure(**kwargs)
|
||||
|
||||
async def visual_automatic1111_setup(self, visual_agent: "VisualBase") -> bool:
|
||||
"""
|
||||
Automatically configure the visual agent for automatic1111
|
||||
|
||||
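The embeddings change above adds an acknowledgement handshake: the emission object is mutable, and if no receiver marks it `seen`, the suggested model name is forgotten so it gets suggested again. A simplified, self-contained sketch of that pattern (the signal bus and `ClientEmbeddingsStatus` here are stand-ins):

```python
import asyncio
from dataclasses import dataclass

@dataclass
class ClientEmbeddingsStatus:
    embedding_name: str
    seen: bool = False  # a receiver flips this to acknowledge

receivers = []

async def send(emission: ClientEmbeddingsStatus):
    for receiver in receivers:
        await receiver(emission)

async def memory_agent(emission: ClientEmbeddingsStatus):
    emission.seen = True  # the memory agent acknowledges the suggestion

async def suggest(name: str) -> str | None:
    emission = ClientEmbeddingsStatus(embedding_name=name)
    await send(emission)
    # if nobody saw it, drop the name so it will be suggested again later
    return name if emission.seen else None

async def main():
    print(await suggest("bge-small"))   # None: no receiver connected yet
    receivers.append(memory_agent)
    print(await suggest("bge-small"))   # "bge-small": acknowledged

asyncio.run(main())
```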
@@ -14,6 +14,7 @@ class Defaults(CommonDefaults, pydantic.BaseModel):
class LMStudioClient(ClientBase):
    auto_determine_prompt_template: bool = True
    client_type = "lmstudio"
    remote_model_locked: bool = True

    class Meta(ClientBase.Meta):
        name_prefix: str = "LMStudio"
@@ -32,17 +33,16 @@ class LMStudioClient(ClientBase):
        ),
    ]

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

    def reconfigure(self, **kwargs):
        super().reconfigure(**kwargs)

        if self.client and self.client.base_url != self.api_url:
            self.set_client()
    def make_client(self):
        return AsyncOpenAI(base_url=self.api_url + "/v1", api_key=self.api_key)

    async def get_model_name(self):
        model_name = await super().get_model_name()
        client = self.make_client()
        models = await client.models.list(timeout=self.status_request_timeout)
        try:
            model_name = models.data[0].id
        except IndexError:
            return None

        # model name comes back as a file path, so we need to extract the model name
        # the path could be windows or linux so it needs to handle both backslash and forward slash
@@ -65,9 +65,11 @@ class LMStudioClient(ClientBase):
            parameters=parameters,
        )

        client = self.make_client()

        try:
            # Send the request in streaming mode so we can update token counts
            stream = await self.client.completions.create(
            stream = await client.completions.create(
                model=self.model_name,
                prompt=prompt,
                stream=True,
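The path comment above is worth spelling out: LM Studio reports the model id as a file path, which may use either separator. One way to do the extraction (the actual helper isn't shown in the hunk, so this is an illustrative sketch):

```python
import re

def model_name_from_path(path: str) -> str:
    # split on both backslash and forward slash, keep the last component
    return re.split(r"[\\/]", path)[-1]

print(model_name_from_path("C:\\models\\llama-3.1-8b.gguf"))          # llama-3.1-8b.gguf
print(model_name_from_path("/home/user/models/llama-3.1-8b.gguf"))    # llama-3.1-8b.gguf
```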
@@ -15,9 +15,8 @@ from talemate.client.remote import (
    EndpointOverrideMixin,
    endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers

__all__ = [
    "MistralAIClient",
@@ -33,14 +32,7 @@ SUPPORTED_MODELS = [
    "mistral-small-latest",
    "mistral-medium-latest",
    "mistral-large-latest",
]

JSON_OBJECT_RESPONSE_MODELS = [
    "open-mixtral-8x22b",
    "open-mistral-nemo",
    "mistral-small-latest",
    "mistral-medium-latest",
    "mistral-large-latest",
    "magistral-medium-2506",
]


@@ -61,7 +53,6 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):

    client_type = "mistral"
    conversation_retries = 0
    auto_break_repetition_enabled = False
    # TODO: make this configurable?
    decensor_enabled = True
    config_cls = ClientConfig
@@ -75,17 +66,13 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()

    def __init__(self, model="open-mixtral-8x22b", **kwargs):
        self.model_name = model
        self.api_key_status = None
        self._reconfigure_endpoint_override(**kwargs)
        self.config = load_config()
        super().__init__(**kwargs)
        handlers["config_saved"].connect(self.on_config_saved)
    @property
    def can_be_coerced(self) -> bool:
        return not self.reason_enabled

    @property
    def mistral_api_key(self):
        return self.config.get("mistralai", {}).get("api_key")
        return self.config.mistralai.api_key

    @property
    def supported_parameters(self):
@@ -97,15 +84,15 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):

    def emit_status(self, processing: bool = None):
        error_action = None
        error_message = None
        if processing is not None:
            self.processing = processing

        if self.mistral_api_key:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            status = "error"
            model_name = "No API key set"
            error_message = "No API key set"
            error_action = ErrorAction(
                title="Set API Key",
                action_name="openAppConfig",
@@ -118,74 +105,25 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):

        if not self.model_name:
            status = "error"
            model_name = "No model loaded"
            error_message = "No model loaded"

        self.current_status = status
        data = {
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
            "error_message": error_message,
        }
        data.update(self._common_status_data())
        emit(
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            details=self.model_name,
            status=status if self.enabled else "disabled",
            data=data,
        )

    def set_client(self, max_token_length: int = None):
        if not self.mistral_api_key and not self.endpoint_override_base_url_configured:
            self.client = Mistral(api_key="sk-1111")
            log.error("No mistral.ai API key set")
            if self.api_key_status:
                self.api_key_status = False
                emit("request_client_status")
                emit("request_agent_status")
            return

        if not self.model_name:
            self.model_name = "open-mixtral-8x22b"

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        model = self.model_name

        self.client = Mistral(api_key=self.api_key, server_url=self.base_url)
        self.max_token_length = max_token_length or 16384

        if not self.api_key_status:
        if self.api_key_status is False:
            emit("request_client_status")
            emit("request_agent_status")
            self.api_key_status = True

        log.info(
            "mistral.ai set client",
            max_token_length=self.max_token_length,
            provided_max_token_length=max_token_length,
            model=model,
        )

    def reconfigure(self, **kwargs):
        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        self._reconfigure_common_parameters(**kwargs)
        self._reconfigure_endpoint_override(**kwargs)

        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    def response_tokens(self, response: str):
        return response.usage.completion_tokens

@@ -195,16 +133,6 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
    async def status(self):
        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    def clean_prompt_parameters(self, parameters: dict):
        super().clean_prompt_parameters(parameters)
        # clamp temperature to 0.1 and 1.0
@@ -220,16 +148,12 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
        if not self.mistral_api_key:
            raise Exception("No mistral.ai API key set")

        supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
        right = None
        expected_response = None
        try:
            _, right = prompt.split("\nStart your response with: ")
            expected_response = right.strip()
            if expected_response.startswith("{") and supports_json_object:
                parameters["response_format"] = {"type": "json_object"}
        except (IndexError, ValueError):
            pass
        client = Mistral(api_key=self.api_key, server_url=self.base_url)

        if self.can_be_coerced:
            prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
        else:
            coercion_prompt = None

        system_message = self.get_system_message(kind)

@@ -238,6 +162,16 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
            {"role": "user", "content": prompt.strip()},
        ]

        if coercion_prompt:
            log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
            messages.append(
                {
                    "role": "assistant",
                    "content": coercion_prompt.strip(),
                    "prefix": True,
                }
            )

        self.log.debug(
            "generate",
            base_url=self.base_url,
@@ -247,7 +181,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
        )

        try:
            event_stream = await self.client.chat.stream_async(
            event_stream = await client.chat.stream_async(
                model=self.model_name,
                messages=messages,
                **parameters,
@@ -271,22 +205,6 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
            self._returned_prompt_tokens = prompt_tokens
            self._returned_response_tokens = completion_tokens

            # response = response.choices[0].message.content

            # older models don't support json_object response coercion
            # and often like to return the response wrapped in ```json
            # so we strip that out if the expected response is a json object
            if (
                not supports_json_object
                and expected_response
                and expected_response.startswith("{")
            ):
                if response.startswith("```json") and response.endswith("```"):
                    response = response[7:-3].strip()

            if right and response.startswith(right):
                response = response[len(right) :].strip()

            return response
        except SDKError as e:
            self.log.error("generate error", e=e)
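The "pattern prefill" change shows up here as an assistant message flagged `"prefix": True`, which asks Mistral's chat API to continue from that text instead of starting fresh. A small sketch of the message construction, taken directly from the shape used in the hunk above:

```python
def build_messages(system: str, prompt: str, coercion: str | None) -> list[dict]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt.strip()},
    ]
    if coercion:
        # assistant pre-fill: the model continues from this partial output
        messages.append(
            {"role": "assistant", "content": coercion.strip(), "prefix": True}
        )
    return messages

print(build_messages("You are a narrator.", "Describe the room.", '{"description": "'))
```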
@@ -27,6 +27,8 @@ TALEMATE_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "talemate")
# user overrides
USER_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "user")

DEFAULT_TEMPLATE = "default.jinja2"

TEMPLATE_IDENTIFIERS = []


@@ -73,10 +75,11 @@ class ModelPrompt:
        system_message: str,
        prompt: str,
        double_coercion: str = None,
        default_template: str = DEFAULT_TEMPLATE,
    ):
        template, template_file = self.get_template(model_name)
        if not template:
            template_file = "default.jinja2"
            template_file = default_template
            template = self.env.get_template(template_file)

        if not double_coercion:
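The change above turns the hardcoded `"default.jinja2"` fallback into a caller-supplied parameter, which is what lets the KoboldAI client fall back to its own `KoboldAI.jinja2` template. A hedged sketch of the fallback behavior (the template contents here are illustrative, not the real files):

```python
from jinja2 import Environment, DictLoader

env = Environment(loader=DictLoader({
    "default.jinja2": "{{ system }}\n{{ prompt }}",
    "KoboldAI.jinja2": "### System: {{ system }}\n### Prompt: {{ prompt }}",
}))

def render(template_file: str | None,
           default_template: str = "default.jinja2", **ctx) -> str:
    # fall back to the caller-supplied default when no model match was found
    template = env.get_template(template_file or default_template)
    return template.render(**ctx)

print(render(None, default_template="KoboldAI.jinja2", system="sys", prompt="hi"))
```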
@@ -12,7 +12,7 @@ from talemate.client.base import (
    ExtraField,
)
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
from talemate.config.schema import Client as BaseClientConfig

log = structlog.get_logger("talemate.client.ollama")

@@ -24,12 +24,10 @@ class OllamaClientDefaults(CommonDefaults):
    api_url: str = "http://localhost:11434"  # Default Ollama URL
    model: str = ""  # Allow empty default, will fetch from Ollama
    api_handles_prompt_template: bool = False
    allow_thinking: bool = False


class ClientConfig(BaseClientConfig):
    api_handles_prompt_template: bool = False
    allow_thinking: bool = False


@register()
@@ -58,13 +56,6 @@ class OllamaClient(ClientBase):
            required=False,
            description="Let Ollama handle the prompt template. Only do this if you don't know which prompt template to use. Letting talemate handle the prompt template will generally lead to improved responses.",
        ),
        "allow_thinking": ExtraField(
            name="allow_thinking",
            type="bool",
            label="Allow thinking",
            required=False,
            description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
        ),
    }

    @property
@@ -90,51 +81,25 @@ class OllamaClient(ClientBase):
        "extra_stopping_strings",
    ]

    def __init__(
        self,
        **kwargs,
    ):
        self._available_models = []
        self._models_last_fetched = 0
        super().__init__(**kwargs)

    @property
    def can_be_coerced(self):
        """
        Determines whether or not this client can pass LLM coercion. (e.g., is able
        to predefine partial LLM output in the prompt)
        """
        return not self.api_handles_prompt_template
        return not self.api_handles_prompt_template and not self.reason_enabled

    @property
    def can_think(self) -> bool:
        """
        Allow reasoning models to think before responding.
        """
        return self.allow_thinking

    def __init__(
        self,
        model=None,
        api_handles_prompt_template=False,
        allow_thinking=False,
        **kwargs,
    ):
        self.model_name = model
        self.api_handles_prompt_template = api_handles_prompt_template
        self.allow_thinking = allow_thinking
        self._available_models = []
        self._models_last_fetched = 0
        self.client = None
        super().__init__(**kwargs)

    def set_client(self, **kwargs):
        """
        Initialize the Ollama client with the API URL.
        """
        # Update model if provided
        if kwargs.get("model"):
            self.model_name = kwargs["model"]

        # Create async client with the configured API URL
        # Ollama's AsyncClient expects just the base URL without any path
        self.client = ollama.AsyncClient(host=self.api_url)
        self.api_handles_prompt_template = kwargs.get(
            "api_handles_prompt_template", self.api_handles_prompt_template
        )
        self.allow_thinking = kwargs.get("allow_thinking", self.allow_thinking)
    def api_handles_prompt_template(self) -> bool:
        return self.client_config.api_handles_prompt_template

    async def status(self):
        """
@@ -177,7 +142,9 @@ class OllamaClient(ClientBase):
        if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
            return self._available_models

        response = await self.client.list()
        client = ollama.AsyncClient(host=self.api_url)

        response = await client.list()
        models = response.get("models", [])
        model_names = [model.model for model in models]
        self._available_models = sorted(model_names)
@@ -192,19 +159,11 @@ class OllamaClient(ClientBase):
        return data

    async def get_model_name(self):
        return self.model_name
        return self.model

    def prompt_template(self, system_message: str, prompt: str):
        if not self.api_handles_prompt_template:
            return super().prompt_template(system_message, prompt)

        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    def tune_prompt_parameters(self, parameters: dict, kind: str):
@@ -251,6 +210,8 @@ class OllamaClient(ClientBase):
        if not self.model_name:
            raise Exception("No model specified or available in Ollama")

        client = ollama.AsyncClient(host=self.api_url)

        # Prepare options for Ollama
        options = parameters

@@ -258,12 +219,11 @@ class OllamaClient(ClientBase):

        try:
            # Use generate endpoint for completion
            stream = await self.client.generate(
            stream = await client.generate(
                model=self.model_name,
                prompt=prompt.strip(),
                options=options,
                raw=self.can_be_coerced,
                think=self.can_think,
                stream=True,
            )

@@ -306,20 +266,3 @@ class OllamaClient(ClientBase):
            prompt_config["repetition_penalty"] = random.uniform(
                rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
            )

    def reconfigure(self, **kwargs):
        """
        Reconfigure the client with new settings.
        """
        # Handle model update
        if kwargs.get("model"):
            self.model_name = kwargs["model"]

        super().reconfigure(**kwargs)

        # Re-initialize client if API URL changed or model changed
        if "api_url" in kwargs or "model" in kwargs:
            self.set_client(**kwargs)

        if "api_handles_prompt_template" in kwargs:
            self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
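The Ollama generate call above is the interesting one: `raw=` skips Ollama's own prompt template (so a pre-filled partial response survives intact), and `think=` gates reasoning output. A minimal sketch, assuming the `ollama` Python package's async streaming interface (`think` requires a recent ollama-python release):

```python
import asyncio
import ollama

async def generate(api_url: str, model: str, prompt: str,
                   can_be_coerced: bool, can_think: bool) -> str:
    client = ollama.AsyncClient(host=api_url)
    stream = await client.generate(
        model=model,
        prompt=prompt.strip(),
        raw=can_be_coerced,   # raw mode bypasses Ollama's own prompt template
        think=can_think,      # allow/suppress reasoning output
        stream=True,
    )
    response = ""
    async for chunk in stream:
        response += getattr(chunk, "response", "") or ""
    return response

# asyncio.run(generate("http://localhost:11434", "llama3.1", "Hello", True, False))
```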
@@ -12,9 +12,8 @@ from talemate.client.remote import (
    EndpointOverrideMixin,
    endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers

__all__ = [
    "OpenAIClient",
@@ -44,22 +43,15 @@ SUPPORTED_MODELS = [
    "o1-preview",
    "o1-mini",
    "o3-mini",
]

# any model starting with gpt-4- is assumed to support 'json_object'
# for others we need to explicitly state the model name
JSON_OBJECT_RESPONSE_MODELS = [
    "gpt-4o-2024-08-06",
    "gpt-4o-2024-11-20",
    "gpt-4o-realtime-preview",
    "gpt-4o-mini-realtime-preview",
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-3.5-turbo-0125",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
]


def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
    # TODO this whole function probably needs to be rewritten at this point

    """Return the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
@@ -83,7 +75,7 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model or "o1" in model or "o3" in model:
    elif "gpt-4" in model or "o1" in model or "o3" in model or "gpt-5" in model:
        return num_tokens_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
@@ -104,9 +96,13 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
    return num_tokens


DEFAULT_MODEL = "gpt-4o"


class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "gpt-4o"
    model: str = DEFAULT_MODEL
    reason_tokens: int = 1024


class ClientConfig(EndpointOverride, BaseClientConfig):
@@ -121,7 +117,6 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):

    client_type = "openai"
    conversation_retries = 0
    auto_break_repetition_enabled = False
    # TODO: make this configurable?
    decensor_enabled = False
    config_cls = ClientConfig
@@ -135,18 +130,9 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()

    def __init__(self, model="gpt-4o", **kwargs):
        self.model_name = model
        self.api_key_status = None
        self._reconfigure_endpoint_override(**kwargs)
        self.config = load_config()
        super().__init__(**kwargs)

        handlers["config_saved"].connect(self.on_config_saved)

    @property
    def openai_api_key(self):
        return self.config.get("openai", {}).get("api_key")
        return self.config.openai.api_key

    @property
    def supported_parameters(self):
@@ -157,17 +143,35 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
        "max_tokens",
    ]

    @property
    def requires_reasoning_pattern(self) -> bool:
        return False

    def emit_status(self, processing: bool = None):
        error_action = None
        error_message = None
        if processing is not None:
            self.processing = processing

        # Auto-toggle reasoning based on selected model (OpenAI-specific)
        # o1/o3/gpt-5 families are reasoning models
        try:
            if self.model_name:
                is_reasoning_model = (
                    "o1" in self.model_name
                    or "o3" in self.model_name
                    or "gpt-5" in self.model_name
                )
                if self.client_config.reason_enabled != is_reasoning_model:
                    self.client_config.reason_enabled = is_reasoning_model
        except Exception:
            pass

        if self.openai_api_key:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            status = "error"
            model_name = "No API key set"
            error_message = "No API key set"
            error_action = ErrorAction(
                title="Set API Key",
                action_name="openAppConfig",
@@ -180,7 +184,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):

        if not self.model_name:
            status = "error"
            model_name = "No model loaded"
            error_message = "No model loaded"

        self.current_status = status

@@ -188,6 +192,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
            "error_message": error_message,
        }
        data.update(self._common_status_data())

@@ -195,74 +200,11 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            details=self.model_name,
            status=status if self.enabled else "disabled",
            data=data,
        )

    def set_client(self, max_token_length: int = None):
        if not self.openai_api_key and not self.endpoint_override_base_url_configured:
            self.client = AsyncOpenAI(api_key="sk-1111")
            log.error("No OpenAI API key set")
            if self.api_key_status:
                self.api_key_status = False
                emit("request_client_status")
                emit("request_agent_status")
            return

        if not self.model_name:
            self.model_name = "gpt-3.5-turbo-16k"

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        model = self.model_name

        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
        if model == "gpt-3.5-turbo":
            self.max_token_length = min(max_token_length or 4096, 4096)
        elif model == "gpt-4":
            self.max_token_length = min(max_token_length or 8192, 8192)
        elif model == "gpt-3.5-turbo-16k":
            self.max_token_length = min(max_token_length or 16384, 16384)
        elif model.startswith("gpt-4o") and model != "gpt-4o-2024-05-13":
            self.max_token_length = min(max_token_length or 16384, 16384)
        elif model == "gpt-4o-2024-05-13":
            self.max_token_length = min(max_token_length or 4096, 4096)
        elif model == "gpt-4-1106-preview":
            self.max_token_length = min(max_token_length or 128000, 128000)
        else:
            self.max_token_length = max_token_length or 8192

        if not self.api_key_status:
        if self.api_key_status is False:
            emit("request_client_status")
            emit("request_agent_status")
            self.api_key_status = True

        log.info(
            "openai set client",
            max_token_length=self.max_token_length,
            provided_max_token_length=max_token_length,
            model=model,
        )

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        self._reconfigure_common_parameters(**kwargs)
        self._reconfigure_endpoint_override(**kwargs)

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    def count_tokens(self, content: str):
        if not self.model_name:
            return 0
@@ -271,18 +213,6 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
    async def status(self):
        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
        # only gpt-4-1106-preview supports json_object response coercion

        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
@@ -291,26 +221,17 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
        if not self.openai_api_key and not self.endpoint_override_base_url_configured:
            raise Exception("No OpenAI API key set")

        # only gpt-4-* supports enforcing json object
        supports_json_object = (
            self.model_name.startswith("gpt-4-")
            or self.model_name in JSON_OBJECT_RESPONSE_MODELS
        )
        right = None
        expected_response = None
        try:
            _, right = prompt.split("\nStart your response with: ")
            expected_response = right.strip()
            if expected_response.startswith("{") and supports_json_object:
                parameters["response_format"] = {"type": "json_object"}
        except (IndexError, ValueError):
            pass
        client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

        human_message = {"role": "user", "content": prompt.strip()}
        system_message = {"role": "system", "content": self.get_system_message(kind)}

        # o1 and o3 models don't support system_message
        if "o1" in self.model_name or "o3" in self.model_name:
        if (
            "o1" in self.model_name
            or "o3" in self.model_name
            or "gpt-5" in self.model_name
        ):
            messages = [human_message]
            # parameters need to be munged
            # `max_tokens` becomes `max_completion_tokens`
@@ -339,13 +260,20 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):

        self.log.debug(
            "generate",
            model=self.model_name,
            prompt=prompt[:128] + " ...",
            parameters=parameters,
            system_message=system_message,
        )

        # GPT-5 models do not allow streaming for non-verified orgs; use non-streaming path
        if "gpt-5" in self.model_name:
            return await self._generate_non_streaming_completion(
                client, messages, parameters
            )

        try:
            stream = await self.client.chat.completions.create(
            stream = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                stream=True,
@@ -365,23 +293,6 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
                    # Incrementally track token usage
                    self.update_request_tokens(self.count_tokens(content_piece))

            # self._returned_prompt_tokens = self.prompt_tokens(prompt)
            # self._returned_response_tokens = self.response_tokens(response)

            # older models don't support json_object response coercion
            # and often like to return the response wrapped in ```json
            # so we strip that out if the expected response is a json object
            if (
                not supports_json_object
                and expected_response
                and expected_response.startswith("{")
            ):
                if response.startswith("```json") and response.endswith("```"):
                    response = response[7:-3].strip()

            if right and response.startswith(right):
                response = response[len(right) :].strip()

            return response
        except PermissionDeniedError as e:
            self.log.error("generate error", e=e)
@@ -389,3 +300,36 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
            return ""
        except Exception:
            raise

    async def _generate_non_streaming_completion(
        self, client: AsyncOpenAI, messages: list[dict], parameters: dict
    ) -> str:
        """Perform a non-streaming chat completion request and return the content.

        This is used for GPT-5 models which disallow streaming for non-verified orgs.
        """
        try:
            response = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                # No stream flag -> non-streaming
                **parameters,
            )

            if not response.choices:
                return ""

            message = response.choices[0].message
            content = getattr(message, "content", "") or ""

            if content:
                # Update token usage based on the full content
                self.update_request_tokens(self.count_tokens(content))

            return content
        except PermissionDeniedError as e:
            self.log.error("generate (non-streaming) error", e=e)
            emit("status", message="OpenAI API: Permission Denied", status="error")
            return ""
        except Exception:
            raise
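The comments in the reasoning-model branch above ("o1 and o3 models don't support system_message", "`max_tokens` becomes `max_completion_tokens`") describe a small munging step that is easy to get wrong. A sketch of that transformation, generalized to the o1/o3/gpt-5 families as the hunk does:

```python
def munge_for_reasoning(model: str, parameters: dict,
                        system: dict, human: dict) -> tuple[list[dict], dict]:
    reasoning = any(tag in model for tag in ("o1", "o3", "gpt-5"))
    if not reasoning:
        return [system, human], parameters
    parameters = dict(parameters)
    # reasoning models reject `max_tokens`; they take `max_completion_tokens`
    if "max_tokens" in parameters:
        parameters["max_completion_tokens"] = parameters.pop("max_tokens")
    # and they reject a system message, so only the user message is sent
    return [human], parameters

messages, params = munge_for_reasoning(
    "gpt-5-mini", {"max_tokens": 1024},
    {"role": "system", "content": "sys"}, {"role": "user", "content": "hi"},
)
print(messages, params)  # system message dropped, max_completion_tokens set
```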
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI, PermissionDeniedError

from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit

log = structlog.get_logger("talemate.client.openai_compat")
@@ -51,13 +51,9 @@ class OpenAICompatibleClient(ClientBase):
        )
    }

    def __init__(
        self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
    ):
        self.model_name = model
        self.api_key = api_key
        self.api_handles_prompt_template = api_handles_prompt_template
        super().__init__(**kwargs)
    @property
    def api_handles_prompt_template(self) -> bool:
        return self.client_config.api_handles_prompt_template

    @property
    def experimental(self):
@@ -69,7 +65,7 @@ class OpenAICompatibleClient(ClientBase):
        Determines whether or not this client can pass LLM coercion. (e.g., is able
        to predefine partial LLM output in the prompt)
        """
        return not self.api_handles_prompt_template
        return not self.reason_enabled

    @property
    def supported_parameters(self):
@@ -80,43 +76,21 @@ class OpenAICompatibleClient(ClientBase):
        "max_tokens",
    ]

    def set_client(self, **kwargs):
        self.api_key = kwargs.get("api_key", self.api_key)
        self.api_handles_prompt_template = kwargs.get(
            "api_handles_prompt_template", self.api_handles_prompt_template
        )
        url = self.api_url
        self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
        self.model_name = (
            kwargs.get("model") or kwargs.get("model_name") or self.model_name
        )

    def prompt_template(self, system_message: str, prompt: str):
        log.debug(
            "IS API HANDLING PROMPT TEMPLATE",
            api_handles_prompt_template=self.api_handles_prompt_template,
        )

        if not self.api_handles_prompt_template:
            return super().prompt_template(system_message, prompt)

        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    async def get_model_name(self):
        return self.model_name
        return self.model

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """

        client = AsyncOpenAI(base_url=self.api_url, api_key=self.api_key)

        try:
            if self.api_handles_prompt_template:
                # OpenAI API handles prompt template
@@ -126,15 +100,37 @@ class OpenAICompatibleClient(ClientBase):
                    prompt=prompt[:128] + " ...",
                    parameters=parameters,
                )
                human_message = {"role": "user", "content": prompt.strip()}
                response = await self.client.chat.completions.create(

                if self.can_be_coerced:
                    prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
                else:
                    coercion_prompt = None

                messages = [
                    {"role": "system", "content": self.get_system_message(kind)},
                    {"role": "user", "content": prompt.strip()},
                ]

                if coercion_prompt:
                    log.debug(
                        "Adding coercion pre-fill", coercion_prompt=coercion_prompt
                    )
                    messages.append(
                        {
                            "role": "assistant",
                            "content": coercion_prompt.strip(),
                            "prefix": True,
                        }
                    )

                response = await client.chat.completions.create(
                    model=self.model_name,
                    messages=[human_message],
                    messages=messages,
                    stream=False,
                    **parameters,
                )
                response = response.choices[0].message.content
                return self.process_response_for_indirect_coercion(prompt, response)
                return response
            else:
                # Talemate handles prompt template
                # Use the completions endpoint
@@ -144,7 +140,7 @@ class OpenAICompatibleClient(ClientBase):
                    parameters=parameters,
                )
                parameters["prompt"] = prompt
                response = await self.client.completions.create(
                response = await client.completions.create(
                    model=self.model_name, stream=False, **parameters
                )
                return response.choices[0].text
@@ -159,34 +155,6 @@ class OpenAICompatibleClient(ClientBase):
            )
            return ""

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
        if "api_url" in kwargs:
            self.api_url = kwargs["api_url"]
        if "max_token_length" in kwargs:
            self.max_token_length = (
                int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
            )
        if "api_key" in kwargs:
            self.api_key = kwargs["api_key"]
        if "api_handles_prompt_template" in kwargs:
            self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
        # TODO: why isn't this calling super()?
        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        if "double_coercion" in kwargs:
            self.double_coercion = kwargs["double_coercion"]

        if "rate_limit" in kwargs:
            self.rate_limit = kwargs["rate_limit"]

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        self.set_client(**kwargs)

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
        """
        adjusts temperature and presence penalty
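This client leans on two base-class helpers that the hunks only reference: `split_prompt_for_coercion` (peel the expected prefix off the prompt so it can be pre-filled) and `process_response_for_indirect_coercion` (strip that prefix back off when the model echoes it). The real implementations aren't shown here, so these stand-ins only mirror the observable behavior:

```python
COERCION_MARKER = "\nStart your response with: "

def split_prompt_for_coercion(prompt: str) -> tuple[str, str | None]:
    # direct coercion: separate the expected response prefix from the prompt
    if COERCION_MARKER in prompt:
        prompt, expected = prompt.split(COERCION_MARKER, 1)
        return prompt, expected
    return prompt, None

def process_response_for_indirect_coercion(prompt: str, response: str) -> str:
    # indirect coercion: the model repeated the requested prefix, strip it
    _, expected = split_prompt_for_coercion(prompt)
    if expected and response.startswith(expected):
        response = response[len(expected):].strip()
    return response

prompt = "Describe the room." + COERCION_MARKER + '{"description": "'
print(split_prompt_for_coercion(prompt))
print(process_response_for_indirect_coercion(prompt, '{"description": "dim"}'))
```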
@@ -4,9 +4,17 @@ import httpx
import asyncio
import json

from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.base import (
    ClientBase,
    ErrorAction,
    CommonDefaults,
    ExtraField,
    FieldGroup,
)
from talemate.config.schema import Client as BaseClientConfig
from talemate.config import get_config

from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers

@@ -18,6 +26,91 @@ log = structlog.get_logger("talemate.client.openrouter")

# Available models will be populated when first client with API key is initialized
AVAILABLE_MODELS = []

# Static list of providers that are supported by OpenRouter
# https://openrouter.ai/docs/features/provider-routing#json-schema-for-provider-preferences


AVAILABLE_PROVIDERS = [
    "AnyScale",
    "Cent-ML",
    "HuggingFace",
    "Hyperbolic 2",
    "Lepton",
    "Lynn 2",
    "Lynn",
    "Mancer",
    "Modal",
    "OctoAI",
    "Recursal",
    "Reflection",
    "Replicate",
    "SambaNova 2",
    "SF Compute",
    "Together 2",
    "01.AI",
    "AI21",
    "AionLabs",
    "Alibaba",
    "Amazon Bedrock",
    "Anthropic",
    "AtlasCloud",
    "Atoma",
    "Avian",
    "Azure",
    "BaseTen",
    "Cerebras",
    "Chutes",
    "Cloudflare",
    "Cohere",
    "CrofAI",
    "Crusoe",
    "DeepInfra",
    "DeepSeek",
    "Enfer",
    "Featherless",
    "Fireworks",
    "Friendli",
    "GMICloud",
    "Google",
    "Google AI Studio",
    "Groq",
    "Hyperbolic",
    "Inception",
    "InferenceNet",
    "Infermatic",
    "Inflection",
    "InoCloud",
    "Kluster",
    "Lambda",
    "Liquid",
    "Mancer 2",
    "Meta",
    "Minimax",
    "Mistral",
    "Moonshot AI",
    "Morph",
    "NCompass",
    "Nebius",
    "NextBit",
    "Nineteen",
    "Novita",
    "OpenAI",
    "OpenInference",
    "Parasail",
    "Perplexity",
    "Phala",
    "SambaNova",
    "Stealth",
    "Switchpoint",
    "Targon",
    "Together",
    "Ubicloud",
    "Venice",
    "xAI",
]
AVAILABLE_PROVIDERS.sort()

DEFAULT_MODEL = ""
MODELS_FETCHED = False

@@ -25,7 +118,6 @@ MODELS_FETCHED = False
async def fetch_available_models(api_key: str = None):
    """Fetch available models from OpenRouter API"""
    global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED

    if not api_key:
        return []

@@ -37,6 +129,7 @@ async def fetch_available_models(api_key: str = None):
        return AVAILABLE_MODELS

    try:
        log.debug("Fetching models from OpenRouter")
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://openrouter.ai/api/v1/models", timeout=10.0
@@ -61,19 +154,36 @@ async def fetch_available_models(api_key: str = None):
    return AVAILABLE_MODELS


def fetch_models_sync(event):
    api_key = event.data.get("openrouter", {}).get("api_key")
def fetch_models_sync(api_key: str):
    loop = asyncio.get_event_loop()
    loop.run_until_complete(fetch_available_models(api_key))


handlers["config_saved"].connect(fetch_models_sync)
handlers["talemate_started"].connect(fetch_models_sync)
def on_talemate_started(event):
    fetch_models_sync(get_config().openrouter.api_key)


handlers["talemate_started"].connect(on_talemate_started)


class Defaults(CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = DEFAULT_MODEL
    provider_only: list[str] = pydantic.Field(default_factory=list)
    provider_ignore: list[str] = pydantic.Field(default_factory=list)


class ClientConfig(BaseClientConfig):
    provider_only: list[str] = pydantic.Field(default_factory=list)
    provider_ignore: list[str] = pydantic.Field(default_factory=list)


PROVIDER_FIELD_GROUP = FieldGroup(
    name="provider",
    label="Provider",
    description="Configure OpenRouter provider routing.",
    icon="mdi-server-network",
)


@register()
@@ -84,9 +194,9 @@ class OpenRouterClient(ClientBase):

    client_type = "openrouter"
    conversation_retries = 0
    auto_break_repetition_enabled = False
    # TODO: make this configurable?
    decensor_enabled = False
    config_cls = ClientConfig

    class Meta(ClientBase.Meta):
        name_prefix: str = "OpenRouter"
@@ -97,23 +207,46 @@ class OpenRouterClient(ClientBase):
        )
        requires_prompt_template: bool = False
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = {
            "provider_only": ExtraField(
                name="provider_only",
                type="flags",
                label="Only use these providers",
                choices=AVAILABLE_PROVIDERS,
                description="Manually limit the providers to use for the selected model. This will override the default provider selection for this model.",
                group=PROVIDER_FIELD_GROUP,
                required=False,
            ),
            "provider_ignore": ExtraField(
                name="provider_ignore",
                type="flags",
                label="Ignore these providers",
                choices=AVAILABLE_PROVIDERS,
                description="Ignore these providers for the selected model. This will override the default provider selection for this model.",
                group=PROVIDER_FIELD_GROUP,
                required=False,
            ),
        }

    def __init__(self, model=None, **kwargs):
        self.model_name = model or DEFAULT_MODEL
        self.api_key_status = None
        self.config = load_config()
    def __init__(self, **kwargs):
        self._models_fetched = False
        super().__init__(**kwargs)

        handlers["config_saved"].connect(self.on_config_saved)
    @property
    def provider_only(self) -> list[str]:
        return self.client_config.provider_only

    @property
    def provider_ignore(self) -> list[str]:
        return self.client_config.provider_ignore

    @property
    def can_be_coerced(self) -> bool:
        return True
        return not self.reason_enabled

    @property
    def openrouter_api_key(self):
        return self.config.get("openrouter", {}).get("api_key")
        return self.config.openrouter.api_key

    @property
    def supported_parameters(self):
@@ -130,15 +263,15 @@ class OpenRouterClient(ClientBase):

    def emit_status(self, processing: bool = None):
        error_action = None
        error_message = None
        if processing is not None:
            self.processing = processing

        if self.openrouter_api_key:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            status = "error"
            model_name = "No API key set"
            error_message = "No API key set"
            error_action = ErrorAction(
                title="Set API Key",
                action_name="openAppConfig",
@@ -151,7 +284,7 @@ class OpenRouterClient(ClientBase):

        if not self.model_name:
            status = "error"
            model_name = "No model loaded"
            error_message = "No model loaded"

        self.current_status = status

@@ -159,6 +292,7 @@ class OpenRouterClient(ClientBase):
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
            "error_message": error_message,
        }
        data.update(self._common_status_data())

@@ -166,60 +300,11 @@ class OpenRouterClient(ClientBase):
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            details=self.model_name,
            status=status if self.enabled else "disabled",
            data=data,
        )

    def set_client(self, max_token_length: int = None):
        # Unlike other clients, we don't need to set up a client instance
        # We'll use httpx directly in the generate method

        if not self.openrouter_api_key:
            log.error("No OpenRouter API key set")
            if self.api_key_status:
                self.api_key_status = False
                emit("request_client_status")
                emit("request_agent_status")
            return

        if not self.model_name:
            self.model_name = DEFAULT_MODEL

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        # Set max token length (default to 16k if not specified)
        self.max_token_length = max_token_length or 16384

        if not self.api_key_status:
        if self.api_key_status is False:
            emit("request_client_status")
            emit("request_agent_status")
            self.api_key_status = True

        log.info(
            "openrouter set client",
            max_token_length=self.max_token_length,
            provided_max_token_length=max_token_length,
            model=self.model_name,
        )

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        self._reconfigure_common_parameters(**kwargs)

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    async def status(self):
        # Fetch models if we have an API key and haven't fetched yet
        if self.openrouter_api_key and not self._models_fetched:
@@ -229,13 +314,6 @@ class OpenRouterClient(ClientBase):

        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
        """
        Open-router handles the prompt template internally, so we just
        give the prompt as is.
        """
        return prompt

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters using OpenRouter API.
@@ -244,7 +322,10 @@ class OpenRouterClient(ClientBase):
        if not self.openrouter_api_key:
            raise Exception("No OpenRouter API key set")

        prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
        if self.can_be_coerced:
            prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
        else:
            coercion_prompt = None

        # Prepare messages for chat completion
        messages = [
@@ -253,7 +334,23 @@ class OpenRouterClient(ClientBase):
        ]

        if coercion_prompt:
            messages.append({"role": "assistant", "content": coercion_prompt.strip()})
            log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
            messages.append(
                {
                    "role": "assistant",
                    "content": coercion_prompt.strip(),
                    "prefix": True,
                }
            )

        provider = {}
        if self.provider_only:
            provider["only"] = self.provider_only
        if self.provider_ignore:
            provider["ignore"] = self.provider_ignore

        if provider:
            parameters["provider"] = provider

        # Prepare request payload
        payload = {
@@ -320,7 +417,7 @@ class OpenRouterClient(ClientBase):
                                self.count_tokens(content)
                            )

                except json.JSONDecodeError:
                except (json.JSONDecodeError, KeyError):
                    pass

        # Extract the response content
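The new "provider configuration" boils down to building OpenRouter's provider-preferences object (see https://openrouter.ai/docs/features/provider-routing). A sketch of the request payload, following the `only` / `ignore` keys used in the hunk above; the model name is just an example:

```python
def build_payload(model: str, messages: list[dict],
                  provider_only: list[str], provider_ignore: list[str]) -> dict:
    payload = {"model": model, "messages": messages}
    provider = {}
    if provider_only:
        provider["only"] = provider_only      # restrict routing to these providers
    if provider_ignore:
        provider["ignore"] = provider_ignore  # never route to these providers
    if provider:
        payload["provider"] = provider
    return payload

print(build_payload(
    "meta-llama/llama-3.1-70b-instruct",
    [{"role": "user", "content": "hi"}],
    provider_only=["DeepInfra", "Together"],
    provider_ignore=[],
))
```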
@@ -3,8 +3,7 @@ from typing import TYPE_CHECKING
import structlog

from talemate.client.context import set_client_context_attribute
from talemate.config import InferencePresets, InferencePresetGroup, load_config
from talemate.emit.signals import handlers
from talemate.config import get_config

if TYPE_CHECKING:
    from talemate.client.base import ClientBase
@@ -20,42 +19,19 @@ __all__ = [

log = structlog.get_logger("talemate.client.presets")

config = load_config(as_model=True)


# Load the config
CONFIG = {
    "inference": config.presets.inference,
    "inference_groups": config.presets.inference_groups,
}


# Sync the config when it is saved
def sync_config(event):
    CONFIG["inference"] = InferencePresets(
        **event.data.get("presets", {}).get("inference", {})
    )
    CONFIG["inference_groups"] = {
        group: InferencePresetGroup(**data)
        for group, data in event.data.get("presets", {})
        .get("inference_groups", {})
        .items()
    }


handlers["config_saved"].connect(sync_config)


def get_inference_parameters(preset_name: str, group: str | None = None) -> dict:
    """
    Returns the inference parameters for the given preset name.
    """

    presets = CONFIG["inference"].model_dump()
    config = get_config()

    presets = config.presets.inference.model_dump()

    if group:
        try:
            group_presets = CONFIG["inference_groups"].get(group).model_dump()
            group_presets = config.presets.inference_groups.get(group).model_dump()
            presets.update(group_presets["presets"])
        except AttributeError:
            log.warning(
@@ -74,6 +50,7 @@ def configure(parameters: dict, kind: str, total_budget: int, client: "ClientBas
    """
    set_preset(parameters, kind, client)
    set_max_tokens(parameters, kind, total_budget)

    return parameters


@@ -141,7 +118,7 @@ def preset_for_kind(kind: str, client: "ClientBase") -> dict:
    if not preset_name:
        log.warning(
            f"No preset found for kind {kind}, defaulting to 'scene_direction'",
            presets=CONFIG["inference"],
            presets=get_config().presets.inference,
        )
        preset_name = "scene_direction"
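The merge semantics above are worth a worked example: group presets are layered over the base inference presets, with same-named presets in the group winning. A standalone sketch with made-up preset values:

```python
def get_inference_parameters(presets: dict, groups: dict, preset_name: str,
                             group: str | None = None) -> dict:
    merged = dict(presets)
    if group and group in groups:
        # group presets override base presets of the same name
        merged.update(groups[group]["presets"])
    # fall back to scene_direction when the preset is unknown
    return merged.get(preset_name, merged.get("scene_direction", {}))

base = {"scene_direction": {"temperature": 0.85}, "analyze": {"temperature": 0.1}}
groups = {"deterministic": {"presets": {"analyze": {"temperature": 0.0}}}}
print(get_inference_parameters(base, groups, "analyze", group="deterministic"))
# {'temperature': 0.0}
```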
@@ -69,17 +69,13 @@ class EndpointOverrideAPIKeyField(EndpointOverrideField):


class EndpointOverrideMixin:
    override_base_url: str | None = None
    override_api_key: str | None = None
    @property
    def override_base_url(self) -> str | None:
        return self.client_config.override_base_url

    def set_client_api_key(self, api_key: str | None):
        if getattr(self, "client", None):
            try:
                self.client.api_key = api_key
            except Exception as e:
                log.error(
                    "Error setting client API key", error=e, client=self.client_type
                )
    @property
    def override_api_key(self) -> str | None:
        return self.client_config.override_api_key

    @property
    def api_key(self) -> str | None:
@@ -108,41 +104,7 @@ class EndpointOverrideMixin:
            and self.endpoint_override_api_key_configured
        )

    def _reconfigure_endpoint_override(self, **kwargs):
        if "override_base_url" in kwargs:
            orig = getattr(self, "override_base_url", None)
            self.override_base_url = kwargs["override_base_url"]
            if getattr(self, "client", None) and orig != self.override_base_url:
                log.info("Reconfiguring client base URL", new=self.override_base_url)
                self.set_client(kwargs.get("max_token_length"))

        if "override_api_key" in kwargs:
            self.override_api_key = kwargs["override_api_key"]
            self.set_client_api_key(self.override_api_key)


class RemoteServiceMixin:
    def prompt_template(self, system_message: str, prompt: str):
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))
        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    async def status(self):
        self.emit_status()
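This hunk captures the config refactor in miniature: mutable per-instance attributes become read-only properties that always read from the live `client_config`, so there is no stale copy to reconcile. A simplified sketch of that pattern (the class names here mimic the diff but are stand-ins):

```python
import pydantic

class ClientConfig(pydantic.BaseModel):
    override_base_url: str | None = None
    override_api_key: str | None = None

class EndpointOverrideMixin:
    client_config: ClientConfig

    @property
    def override_base_url(self) -> str | None:
        # always read from the live config, never a cached attribute
        return self.client_config.override_base_url

    @property
    def override_api_key(self) -> str | None:
        return self.client_config.override_api_key

class DummyClient(EndpointOverrideMixin):
    def __init__(self, config: ClientConfig):
        self.client_config = config

client = DummyClient(ClientConfig(override_base_url="http://localhost:8080/v1"))
print(client.override_base_url)  # reflects whatever the config currently holds
```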
@@ -9,7 +9,7 @@ import dotenv
import runpod
import structlog

from talemate.config import load_config
from talemate.config import get_config

from .bootstrap import ClientBootstrap, ClientType, register_list

@@ -17,7 +17,6 @@ log = structlog.get_logger("talemate.client.runpod")

dotenv.load_dotenv()

runpod.api_key = load_config().get("runpod", {}).get("api_key", "")

TEXTGEN_IDENTIFIERS = ["textgen", "thebloke llms", "text-generation-webui"]

@@ -35,6 +34,7 @@ async def _async_get_pods():
    """
    asyncio wrapper around get_pods.
    """
    runpod.api_key = get_config().runpod.api_key

    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, runpod.get_pods)
@@ -44,6 +44,7 @@ async def get_textgen_pods():
    """
    Return a list of text generation pods.
    """
    runpod.api_key = get_config().runpod.api_key

    if not runpod.api_key:
        return
@@ -60,6 +61,8 @@ async def get_automatic1111_pods():
    Return a list of automatic1111 pods.
    """

    runpod.api_key = get_config().runpod.api_key

    if not runpod.api_key:
        return
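Two things happen in the runpod hunks: the API key is now refreshed from config at call time instead of being frozen at import, and the blocking SDK call is pushed onto a thread executor so the event loop stays responsive. A minimal sketch of the executor half, with a stand-in for the blocking call:

```python
import asyncio

def get_pods_blocking():
    # stand-in for runpod.get_pods(), which performs blocking HTTP I/O
    return ["pod-a", "pod-b"]

async def async_get_pods():
    loop = asyncio.get_event_loop()
    # run_in_executor moves the blocking call off the event loop thread
    return await loop.run_in_executor(None, get_pods_blocking)

print(asyncio.run(async_get_pods()))  # ['pod-a', 'pod-b']
```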