* separate other tts apis and improve chunking

* move old tts config to voice agent config and implement config widget ux elements for table editing

* elevenlabs updated to use their client and expose model selection

* linting

* separate character class into character.pt and start on voice routing

* linting

* tts hot swapping and chunking improvements

* linting

* add support for piper-tts

* update gitignore

* linting

* support google tts
fix issue where quick_toggle agent config didn't work on standard config items

* linting

* only show agent quick toggles if the agent is enabled

* change elevenlabs to use a locally maintained voice list

* tts generate before / after events

* voice library refactor

* linting

* update openai model and voices

* tweak configs

* voice library ux

* linting

* add support for kokoro tts

* fix add / remove voice

* voice library tags

* linting

* linting

* tts api status

* api infos and add more kokoro voices

* allow voice testing before saving a new voice

* tweaks to voice library ux and some api info text

* linting

* voice mixer

* polish

* voice files go into /tts instead of templates/voice

* change default narrator voice

* xtts confirmation note

* character voice select

* koboldai format template

* polish

* skip empty chunks

* change default voice

* replace em-dash with normal dash

* adjust limit

* replace linebreaks

* chunk cleanup for whitespace

* info updated

* remove invalid endif tag

* sort voices by ready api

* Character hashable type

* clarify set_simulated_environment use to avoid unwanted character deactivation

* allow manual generation of tts and fix assorted issues with tts

* tts websocket handler router renamed

* voice mixer: when there are only 2 voices auto adjust the other weight as needed

* separate persist character functions into own mixin

* auto assign voices

* fix chara load and auto assign voice during chara load

* smart speaker separation

* tts speaker separation config

* generate tts for intro text

* fix prompting issues with anthropic, google and openrouter clients

* decensor flag off again

* only do ai assisted voice markup on narrator messages

* openrouter provider configuration

* linting

* improved sound controls

* add support for chatterbox

* fix info

* chatterbox dependencies

* remove piper and xtts2

* linting

* voice params

* linting

* tts model overrides and move tts info to tab

* reorg toolbar

* allow overriding of test text

* more tts fixes, apply intensity, chatterbox voices

* confirm voice delete

* linting

* groq updates

* reorg decorators

* tts fixes

* cancelable audio queue

* voice library uploads

* scene voice library

* Config refactor (#13)

* config refactor progress

* config nuke continues

* fix system prompts

* linting

* client fun

* client config refactor

* fix kcpp auto embedding selection

* linting

* fix proxy config

* remove cruft

* fix remaining client bugs from config refactor
always use get_config(), don't keep an instance reference

* support for reasoning models

* more reasoning tweaks

* only allow one frontend to connect at a time

* fix tests

* relock

* relock

* more client adjustments

* pattern prefill

* some tts agent fixes

* fix ai assist cond

* tts nodes

* fix config retrieval

* assign voice node and fixes

* sim suite char gen assign voice

* fix voice assign template to consider used voices

* get rid of auto break repetition which wasn't working right for a while anyhow

* linting

* generate tts node
as string node

* linting

* voice change on character event

* tweak chatterbox max length

* koboldai default template

* linting

* fix saving of existing voice

* relock

* adjust params of eva default voice

* f5tts support

* f5tts samples

* f5tts support

* f5tts tweaks

* chunk size per tts api and reorg default f5tts voices

* chatterbox default voice reorg to match f5-tts default voices

* voice library ux polish pass

* cleanup

* f5-tts tweaks

* missing samples

* get rid of old save cmd

* add chatterbox and f5tts

* housekeeping

* fix some issues with world entry editing

* remove cruft

* replace exclamation marks

* fix save immutable check

* fix replace_exclamation_marks

* better error handling in websocket plugins and fix issue with saves

* agent config save on dialog close

* ctrl click to disable / enable agents

* fix quick config

* allow modifying response size of focal requests

* sim suite set goal always sets story intent, encourage calling of set goal during simulation start

* allow setting of model

* voice param tweaks

* tts tweaks

* fix character card load

* fix note_on_value

* add mixed speaker_separation mode

* indicate which message the audio is for and provide way to stop audio from the message

* fix issue with some tts generation failing

* linting

* fix speaker separate modes

* bad idea

* linting

* refactor speaker separation prompt

* add kimi think pattern

* fix issue with unwanted cover image replacement

* no scene analysis for visual prompt generation (for now)

* linting

* tts for context investigation messages

* prompt tweaks

* tweak intro

* fix intro text tts not auto playing sometimes

* consider narrator voice when assigning voice to a character

* allow director log messages to go only into the director console

* linting

* startup performance fixes

* init time

* linting

* only show audio control for messages that can have it

* always create story intent and don't override existing saves during character card load

* fix history check in dynamic story line node
add HasHistory node

* linting

* fix intro message not having speaker separation

* voice library character manager

* sequential and cancelable auto assign all

* linting

* fix generation cancel handling

* tooltips

* fix auto assign voice from scene voices

* polish

* kokoro does not like lazy import

* update info text

* complete scene export / import

* linting

* wording

* remove cruft

* fix story intent generation during character card import

* fix generation cancelled emit status inf loop

* prompt tweak

* reasoning quick toggle, reasoning token slider, tooltips

* improved reasoning pattern handling

* fix indirect coercion response parsing

* fix streaming issue

* response length instructions

* more robust streaming

* adjust default

* adjust formatting

* linting

* remove debug output

* director console log function calls

* install cuda script updated

* linting

* add another step

* adjust default

* update dialogue examples

* fix voice selection issues

* what's happening here

* third time's the charm?

* Vite migration (#207)

* add vite config

* replace babel, webpack, vue-cli deps with vite, switch to esm modules, separate eslint config

* change process.env to import.meta.env

* update index.html for vite and move to root

* update docs for vite

* remove vue cli config

* update example env with vite

* bump frontend deps after rebase to 32.0

---------

Co-authored-by: pax-co <Pax_801@proton.me>

* properly reference data type

* what's new

* better indication of dialogue example supporting multiple lines, improve dialogue example display

* fix potential issue with cached scene analysis being reused when it shouldn't

* fix character creation issues with player character toggle

* fix issue where editing a message would sometimes lose parts of the message

* fix slider ux thumb labels (vuetify update)

* relock

* narrative conversation format

* remove planning step

* linting

* tweaks

* don't overthink

* update dialogue examples and intro

* don't dictate response length instructions when data structures are expected

* prompt tweaks

* prompt tweaks

* linting

* fix edit message not handling : well

* prompt tweaks

* fix tests

* fix manual revision when character message was generated in new narrative mode

* fix issue with message editing

* Docker packages release (#204)

* add CI workflow for Docker image build and MkDocs deployment

* rename CI workflow from 'ci' to 'package'

* refactor CI workflow: consolidate container build and documentation deployment into a single file

* fix: correct indentation for permissions in CI workflow

* fix: correct indentation for steps in deploy-docs job in CI workflow

* build both cpu and cuda image

* docs

* docs

* expose writing style during state reinforcement

* prompt tweaks

* test container build

* test container image

* update docker compose

* docs

* test-container-build

* test container build

* test container build

* update docker build workflows

* fix guidance prompt prefix not being dropped

* mount tts dir

* add gpt-5

* remove debug output

* docs

* openai auto toggle reasoning based on model selection

* linting

---------

Co-authored-by: pax-co <123330830+pax-co@users.noreply.github.com>
Co-authored-by: pax-co <Pax_801@proton.me>
Co-authored-by: Luis Alexandre Deschamps Brandão <brandao_luis@yahoo.com>
Author: veguAI
Date: 2025-08-08 13:56:29 +03:00
Committed by: GitHub
Parent: 685ca994f9
Commit: ce4c302d73
223 changed files with 16882 additions and 16488 deletions


@@ -1,30 +1,57 @@
name: ci
name: ci
on:
push:
branches:
- master
- main
- prep-0.26.0
- master
release:
types: [published]
permissions:
contents: write
packages: write
jobs:
deploy:
container-build:
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v5
- name: Log in to GHCR
uses: docker/login-action@v3
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build & push
uses: docker/build-push-action@v5
with:
context: .
file: Dockerfile
push: true
tags: |
ghcr.io/${{ github.repository }}:latest
ghcr.io/${{ github.repository }}:${{ github.ref_name }}
deploy-docs:
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git credentials
run: |
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
- uses: actions/setup-python@v5
with: { python-version: '3.x' }
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
restore-keys: mkdocs-material-
- run: pip install mkdocs-material mkdocs-awesome-pages-plugin mkdocs-glightbox
- run: mkdocs gh-deploy --force


@@ -0,0 +1,32 @@
name: test-container-build
on:
push:
branches: [ 'prep-*' ]
permissions:
contents: read
packages: write
jobs:
container-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build & push
uses: docker/build-push-action@v5
with:
context: .
file: Dockerfile
push: true
# Tag with prep suffix to avoid conflicts with production
tags: |
ghcr.io/${{ github.repository }}:${{ github.ref_name }}

.gitignore

@@ -8,11 +8,20 @@
talemate_env
chroma
config.yaml
.cursor
.claude
# uv
.venv/
templates/llm-prompt/user/*.jinja2
templates/world-state/*.yaml
tts/voice/piper/*.onnx
tts/voice/piper/*.json
tts/voice/kokoro/*.pt
tts/voice/xtts2/*.wav
tts/voice/chatterbox/*.wav
tts/voice/f5tts/*.wav
tts/voice/voice-library.json
scenes/
!scenes/infinity-quest-dynamic-scenario/
!scenes/infinity-quest-dynamic-scenario/assets/
@@ -21,4 +30,5 @@ scenes/
!scenes/infinity-quest/assets/
!scenes/infinity-quest/infinity-quest.json
tts_voice_samples/*.wav
third-party-docs/
third-party-docs/
legacy-state-reinforcements.yaml


@@ -35,18 +35,9 @@ COPY pyproject.toml uv.lock /app/
# Copy the Python source code (needed for editable install)
COPY ./src /app/src
# Create virtual environment and install dependencies
# Create virtual environment and install dependencies (includes CUDA support via pyproject.toml)
RUN uv sync
# Conditional PyTorch+CUDA install
ARG CUDA_AVAILABLE=false
RUN . /app/.venv/bin/activate && \
if [ "$CUDA_AVAILABLE" = "true" ]; then \
echo "Installing PyTorch with CUDA support..." && \
uv pip uninstall torch torchaudio && \
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
fi
# Stage 3: Final image
FROM python:3.11-slim

docker-compose.manual.yml (new file)

@@ -0,0 +1,20 @@
version: '3.8'
services:
talemate:
build:
context: .
dockerfile: Dockerfile
ports:
- "${FRONTEND_PORT:-8080}:8080"
- "${BACKEND_PORT:-5050}:5050"
volumes:
- ./config.yaml:/app/config.yaml
- ./scenes:/app/scenes
- ./templates:/app/templates
- ./chroma:/app/chroma
- ./tts:/app/tts
environment:
- PYTHONUNBUFFERED=1
- PYTHONPATH=/app/src:$PYTHONPATH
command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]


@@ -2,11 +2,7 @@ version: '3.8'
services:
talemate:
build:
context: .
dockerfile: Dockerfile
args:
- CUDA_AVAILABLE=${CUDA_AVAILABLE:-false}
image: ghcr.io/vegu-ai/talemate:latest
ports:
- "${FRONTEND_PORT:-8080}:8080"
- "${BACKEND_PORT:-5050}:5050"
@@ -15,6 +11,7 @@ services:
- ./scenes:/app/scenes
- ./templates:/app/templates
- ./chroma:/app/chroma
- ./tts:/app/tts
environment:
- PYTHONUNBUFFERED=1
- PYTHONPATH=/app/src:$PYTHONPATH


@@ -27,10 +27,10 @@ uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234
### Letting the frontend know about the new host and port
Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL`.
Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VITE_TALEMATE_BACKEND_WEBSOCKET_URL`.
```env
VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
VITE_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
```
Next rebuild the frontend.


@@ -1,22 +1,15 @@
!!! example "Experimental"
Talemate through docker has not received a lot of testing from me, so please let me know if you encounter any issues.
You can do so by creating an issue on the [:material-github: GitHub repository](https://github.com/vegu-ai/talemate)
## Quick install instructions
1. `git clone https://github.com/vegu-ai/talemate.git`
1. `cd talemate`
1. copy config file
1. linux: `cp config.example.yaml config.yaml`
1. windows: `copy config.example.yaml config.yaml`
1. If your host has a CUDA compatible Nvidia GPU
1. Windows (via PowerShell): `$env:CUDA_AVAILABLE="true"; docker compose up`
1. Linux: `CUDA_AVAILABLE=true docker compose up`
1. If your host does **NOT** have a CUDA compatible Nvidia GPU
1. Windows: `docker compose up`
1. Linux: `docker compose up`
1. windows: `copy config.example.yaml config.yaml` (or just copy the file and rename it via the file explorer)
1. `docker compose up`
1. Navigate your browser to http://localhost:8080
!!! info "Pre-built Images"
The default setup uses pre-built images from GitHub Container Registry that include CUDA support by default. To manually build the container instead, use `docker compose -f docker-compose.manual.yml up --build`.
!!! note
When connecting to local APIs running on the host machine (e.g. text-generation-webui), use `host.docker.internal` as the hostname instead of `localhost`.


@@ -0,0 +1,58 @@
# Chatterbox
Local zero shot voice cloning from .wav files.
![Chatterbox API settings](/talemate/img/0.32.0/chatterbox-api-settings.png)
##### Device
Auto-detects best available option
##### Model
Default Chatterbox model optimized for speed
##### Chunk size
Split text into chunks of this size. Smaller values increase responsiveness at the cost of lost context between chunks (such as appropriate inflection). 0 = no chunking.
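As a rough illustration of this trade-off, the sketch below greedily packs sentences into chunks no longer than the configured size. It is an assumption for illustration only, not Talemate's actual chunking code; the function name and sentence splitting are hypothetical.
```python
import re

def chunk_text(text: str, chunk_size: int) -> list[str]:
    """Greedily pack sentences into chunks of at most chunk_size characters."""
    if chunk_size <= 0:  # 0 = no chunking, generate the whole text at once
        return [text]
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        candidate = f"{current} {sentence}".strip()
        if current and len(candidate) > chunk_size:
            chunks.append(current)  # flush before the chunk gets too long
            current = sentence      # over-long single sentences stay whole
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks
```
Smaller chunk sizes mean the first chunk reaches the TTS backend sooner, but each chunk is generated without its surrounding sentences, which is why inflection can suffer.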
## Adding Chatterbox Voices
### Voice Requirements
Chatterbox voices require:
- Reference audio file (.wav format, 5-15 seconds optimal)
- Clear speech with minimal background noise
- Single speaker throughout the sample
### Creating a Voice
1. Open the Voice Library
2. Click **:material-plus: New**
3. Select "Chatterbox" as the provider
4. Configure the voice:
![Add Chatterbox voice](/talemate/img/0.32.0/add-chatterbox-voice.png)
**Label:** Descriptive name (e.g., "Marcus - Deep Male")
**Voice ID / Upload File** Upload a .wav file containing the voice sample. The uploaded reference audio will also be the voice ID.
**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)
**Tags:** Add descriptive tags for organization
**Extra voice parameters**
There are some optional parameters that can be set here on a per-voice level.
![Chatterbox extra voice parameters](/talemate/img/0.32.0/chatterbox-parameters.png)
##### Exaggeration Level
Exaggeration (Neutral = 0.5, extreme values can be unstable). Higher exaggeration tends to speed up speech; reducing cfg helps compensate with slower, more deliberate pacing.
##### CFG / Pace
If the reference speaker has a fast speaking style, lowering cfg to around 0.3 can improve pacing.


@@ -1,7 +1,41 @@
# ElevenLabs
If you have not configured the ElevenLabs TTS API, the voice agent will show that the API key is missing.
Professional voice synthesis with voice cloning capabilities using ElevenLabs API.
![Elevenlaps api key missing](/talemate/img/0.26.0/voice-agent-missing-api-key.png)
![ElevenLabs API settings](/talemate/img/0.32.0/elevenlabs-api-settings.png)
See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on how to set up the API key.
## API Setup
ElevenLabs requires an API key. See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on obtaining and setting an API key.
## Configuration
**Model:** Select from available ElevenLabs models
!!! warning "Voice Limits"
Your ElevenLabs subscription allows you to maintain a set number of voices (10 for the cheapest plan). Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button. It is recommended to test voices via their voice library instead.
## Adding ElevenLabs Voices
### Getting Voice IDs
1. Go to [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab) to view your voices
2. Find or create the voice you want to use
3. Click "More Actions" -> "Copy Voice ID" for the desired voice
![Copy Voice ID](/talemate/img/0.32.0/elevenlabs-copy-voice-id.png)
### Creating a Voice in Talemate
![Add ElevenLabs voice](/talemate/img/0.32.0/add-elevenlabs-voice.png)
1. Open the Voice Library
2. Click "Add Voice"
3. Select "ElevenLabs" as the provider
4. Configure the voice:
**Label:** Descriptive name for the voice
**Provider ID:** Paste the ElevenLabs voice ID you copied
**Tags:** Add descriptive tags for organization


@@ -0,0 +1,78 @@
# F5-TTS
Local zero shot voice cloning from .wav files.
![F5-TTS configuration](/talemate/img/0.32.0/f5tts-api-settings.png)
##### Device
Auto-detects best available option (GPU preferred)
##### Model
- F5TTS_v1_Base (default, most recent model)
- F5TTS_Base
- E2TTS_Base
##### NFE Step
Number of steps to generate the voice. Higher values result in more detailed voices.
##### Chunk size
Split text into chunks of this size. Smaller values increase responsiveness at the cost of lost context between chunks (such as appropriate inflection). 0 = no chunking.
##### Replace exclamation marks
If checked, exclamation marks will be replaced with periods. This is recommended for `F5TTS_v1_Base` since it seems to over-exaggerate exclamation marks.
## Adding F5-TTS Voices
### Voice Requirements
F5-TTS voices require:
- Reference audio file (.wav format, 10-30 seconds)
- Clear speech with minimal background noise
- Single speaker throughout the sample
- Reference text (optional but recommended)
### Creating a Voice
1. Open the Voice Library
2. Click "Add Voice"
3. Select "F5-TTS" as the provider
4. Configure the voice:
![Add F5-TTS voice](/talemate/img/0.32.0/add-f5tts-voice.png)
**Label:** Descriptive name (e.g., "Emma - Calm Female")
**Voice ID / Upload File** Upload a .wav file containing the **reference audio** voice sample. The uploaded reference audio will also be the voice ID.
- Use 6-10 second samples (longer doesn't improve quality)
- Ensure clear speech with minimal background noise
- Record at natural speaking pace
**Reference Text:** Enter the exact text spoken in the reference audio for improved quality
- Enter exactly what is spoken in the reference audio
- Include proper punctuation and capitalization
- Improves voice cloning accuracy significantly
**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)
**Tags:** Add descriptive tags (gender, age, style) for organization
**Extra voice parameters**
There are some optional parameters that can be set here on a per-voice level.
![F5-TTS extra voice parameters](/talemate/img/0.32.0/f5tts-parameters.png)
##### Speed
Allows you to adjust the speed of the voice.
##### CFG Strength
A higher CFG strength generally leads to more faithful reproduction of the input text, while a lower CFG strength can result in more varied or creative speech output, potentially at the cost of text-to-speech accuracy.


@@ -0,0 +1,15 @@
# Google Gemini-TTS
Google Gemini-TTS provides access to Google's text-to-speech service.
## API Setup
Google Gemini-TTS requires a Google Cloud API key.
See the [Google Cloud API setup](/talemate/user-guide/apis/google/) for instructions on obtaining an API key.
## Configuration
![Google TTS settings](/talemate/img/0.32.0/google-tts-api-settings.png)
**Model:** Select from available Google TTS models


@@ -1,6 +1,26 @@
# Overview
Talemate supports Text-to-Speech (TTS) functionality, allowing users to convert text into spoken audio. This document outlines the steps required to configure TTS for Talemate using different providers, including ElevenLabs and a local TTS API.
In 0.32.0 Talemate's TTS (Text-to-Speech) agent has been completely refactored to provide advanced voice capabilities including per-character voice assignment, speaker separation, and support for multiple local and remote APIs. The voice system now includes a comprehensive voice library for managing and organizing voices across all supported providers.
## Key Features
- **Per-character voice assignment** - Each character can have their own unique voice
- **Speaker separation** - Automatic detection and separation of dialogue from narration (see the sketch after this list)
- **Voice library management** - Centralized management of all voices across providers
- **Multiple API support** - Support for both local and remote TTS providers
- **Director integration** - Automatic voice assignment for new characters
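As a rough sketch of what punctuation-based ("simple") speaker separation can look like, the snippet below treats quoted spans as character dialogue and everything else as narration. The function name and exact behaviour are assumptions for illustration, not Talemate's actual implementation.
```python
import re

def split_dialogue_and_exposition(message: str) -> list[tuple[str, str]]:
    """Split a message into (role, text) parts using quotation marks."""
    parts: list[tuple[str, str]] = []
    for span in re.split(r'("[^"]*")', message):
        if not span.strip():
            continue
        role = "character" if span.startswith('"') else "narrator"
        parts.append((role, span.strip()))
    return parts

print(split_dialogue_and_exposition('Kaira nodded. "Yes Captain, I agree."'))
# [('narrator', 'Kaira nodded.'), ('character', '"Yes Captain, I agree."')]
```
The AI assisted modes described in the agent settings cover cases a heuristic like this cannot, such as unquoted or ambiguous dialogue.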
## Supported APIs
### Local APIs
- **Kokoro** - Fastest generation with predefined voice models and mixing
- **F5-TTS** - Fast voice cloning with occasional mispronunciations
- **Chatterbox** - High-quality voice cloning (slower generation)
### Remote APIs
- **ElevenLabs** - Professional voice synthesis with voice cloning
- **Google Gemini-TTS** - Google's text-to-speech service
- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
## Enable the Voice agent
@@ -12,28 +32,30 @@ If your voice agent is disabled - indicated by the grey dot next to the agent -
![Agent disabled](/talemate/img/0.26.0/agent-disabled.png) ![Agent enabled](/talemate/img/0.26.0/agent-enabled.png)
!!! note "Ctrl click to toggle agent"
You can use Ctrl click to toggle the agent on and off.
!!! abstract "Next: Connect to a TTS api"
Next you need to decide which service / api to use for audio generation and configure the voice agent accordingly.
## Voice Library Management
- [OpenAI](openai.md)
- [ElevenLabs](elevenlabs.md)
- [Local TTS](local_tts.md)
Voices are managed through the Voice Library, accessible from the main application bar. The Voice Library allows you to:
You can also find more information about the various settings [here](settings.md).
- Add and organize voices from all supported providers
- Assign voices to specific characters
- Create mixed voices (Kokoro)
- Manage both global and scene-specific voice libraries
## Select a voice
See the [Voice Library Guide](voice-library.md) for detailed instructions.
![Elevenlaps voice missing](/talemate/img/0.26.0/voice-agent-no-voice-selected.png)
## Character Voice Assignment
Click on the agent to open the agent settings.
![Character voice assignment](/talemate/img/0.32.0/character-voice-assignment.png)
Then click on the `Narrator Voice` dropdown and select a voice.
Characters can have individual voices assigned through the Voice Library. When a character has a voice assigned:
![Elevenlaps voice selected](/talemate/img/0.26.0/voice-agent-select-voice.png)
1. Their dialogue will use their specific voice
2. The narrator voice is used for exposition in their messages (with speaker separation enabled)
3. If their assigned voice's API is not available, it falls back to the narrator voice
The selection is saved automatically, click anywhere outside the agent window to close it.
The Voice agent status will show all assigned character voices and their current status.
The Voice agent should now show that the voice is selected and be ready to use.
![Elevenlabs ready](/talemate/img/0.26.0/elevenlabs-ready.png)
![Voice agent status with characters](/talemate/img/0.32.0/voice-agent-status-characters.png)


@@ -0,0 +1,55 @@
# Kokoro
Kokoro provides predefined voice models and voice mixing capabilities for creating custom voices.
## Using Predefined Voices
Kokoro comes with built-in voice models that are ready to use immediately.
Available predefined voices include various male and female voices with different characteristics.
## Creating Mixed Voices
Kokoro allows you to mix voices together to create a new voice.
### Voice Mixing Interface
To create a mixed voice:
1. Open the Voice Library
2. Click ":material-plus: New"
3. Select "Kokoro" as the provider
4. Choose ":material-tune:Mixer" option
5. Configure the mixed voice:
![Voice mixing interface](/talemate/img/0.32.0/kokoro-mixer.png)
**Label:** Descriptive name for the mixed voice
**Base Voices:** Select 2-4 existing Kokoro voices to combine
**Weights:** Set the influence of each voice (0.1 to 1.0)
**Tags:** Descriptive tags for organization
### Weight Configuration
Each selected voice can have its weight adjusted (see the sketch after this list):
- Higher weights make that voice more prominent in the mix
- Lower weights make that voice more subtle
- Total weights need to sum to 1.0
- Experiment with different combinations to achieve desired results
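A minimal sketch of the weight rule, using a hypothetical normalization helper (voice names are illustrative; Talemate's mixer handles this for you, and with only two voices adjusting one weight automatically adjusts the other):
```python
def normalize_weights(weights: dict[str, float]) -> dict[str, float]:
    """Scale mix weights so they sum to 1.0, as required above."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("at least one voice needs a positive weight")
    return {voice: weight / total for voice, weight in weights.items()}

# A two-voice mix entered as 0.7 / 0.5 normalizes to roughly 0.58 / 0.42
print(normalize_weights({"voice_a": 0.7, "voice_b": 0.5}))
```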
### Saving Mixed Voices
Once configured, click "Add Voice". Mixed voices are saved to your voice library and can be:
- Assigned to characters
- Used as narrator voices
just like any other voice.
Saving a mixed voice may take a moment to complete.


@@ -1,53 +0,0 @@
# Local TTS
!!! warning
This has not been tested in a while and may not work as expected. It will likely be replaced with something different in the future. If this approach is currently broken its likely to remain so until it is replaced.
For running a local TTS API, Talemate requires specific dependencies to be installed.
### Windows Installation
Run `install-local-tts.bat` to install the necessary requirements.
### Linux Installation
Execute the following command:
```bash
pip install TTS
```
### Model and Device Configuration
1. Choose a TTS model from the [Coqui TTS model list](https://github.com/coqui-ai/TTS).
2. Decide whether to use `cuda` or `cpu` for the device setting.
3. The first time you run TTS through the local API, it will download the specified model. Please note that this may take some time, and the download progress will be visible in the Talemate backend output.
Example configuration snippet:
```yaml
tts:
device: cuda # or 'cpu'
model: tts_models/multilingual/multi-dataset/xtts_v2
```
### Voice Samples Configuration
Configure voice samples by setting the `value` field to the path of a .wav file voice sample. Official samples can be downloaded from [Coqui XTTS-v2 samples](https://huggingface.co/coqui/XTTS-v2/tree/main/samples).
Example configuration snippet:
```yaml
tts:
voices:
- label: English Male
value: path/to/english_male.wav
- label: English Female
value: path/to/english_female.wav
```
## Saving the Configuration
After configuring the `config.yaml` file, save your changes. Talemate will use the updated settings the next time it starts.
For more detailed information on configuring Talemate, refer to the `config.py` file in the Talemate source code and the `config.example.yaml` file for a barebone configuration example.


@@ -8,16 +8,12 @@ See the [OpenAI API setup](/apis/openai.md) for instructions on how to set up th
## Settings
![Voice agent openai settings](/talemate/img/0.26.0/voice-agent-openai-settings.png)
![Voice agent openai settings](/talemate/img/0.32.0/openai-tts-api-settings.png)
##### Model
Which model to use for generation.
- GPT-4o Mini TTS
- TTS-1
- TTS-1 HD
!!! quote "OpenAI API documentation on quality"
For real-time applications, the standard tts-1 model provides the lowest latency but at a lower quality than the tts-1-hd model. Due to the way the audio is generated, tts-1 is likely to generate content that has more static in certain situations than tts-1-hd. In some cases, the audio may not have noticeable differences depending on your listening device and the individual person.
Generally I have found that HD is fast enough for Talemate, so this is the default.
- TTS-1 HD


@@ -1,36 +1,65 @@
# Settings
![Voice agent settings](/talemate/img/0.26.0/voice-agent-settings.png)
![Voice agent settings](/talemate/img/0.32.0/voice-agent-settings.png)
##### API
##### Enabled APIs
The TTS API to use for voice generation.
Select which TTS APIs to enable. You can enable multiple APIs simultaneously:
- OpenAI
- ElevenLabs
- Local TTS
- **Kokoro** - Fastest generation with predefined voice models and mixing
- **F5-TTS** - Fast voice cloning with occasional mispronunciations
- **Chatterbox** - High-quality voice cloning (slower generation)
- **ElevenLabs** - Professional voice synthesis with voice cloning
- **Google Gemini-TTS** - Google's text-to-speech service
- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
!!! note "Multi-API Support"
You can enable multiple APIs and assign different voices from different providers to different characters. The system will automatically route voice generation to the appropriate API based on the voice assignment.
##### Narrator Voice
The voice to use for narration. Each API will come with its own set of voices.
The default voice used for narration and as a fallback for characters without assigned voices.
![Narrator voice](/talemate/img/0.26.0/voice-agent-select-voice.png)
The dropdown shows all available voices from all enabled APIs, with the format: "Voice Name (Provider)"
!!! note "Local TTS"
For local TTS, you will have to provide voice samples yourself. See [Local TTS Instructions](local_tts.md) for more information.
!!! info "Voice Management"
Voices are managed through the Voice Library, accessible from the main application bar. Adding, removing, or modifying voices should be done through the Voice Library interface.
##### Generate for player
##### Speaker Separation
Whether to generate voice for the player. If enabled, whenever the player speaks, the voice agent will generate audio for them.
Controls how dialogue is separated from exposition in messages:
##### Generate for NPCs
- **No separation** - Character messages use character voice entirely, narrator messages use narrator voice
- **Simple** - Basic separation of dialogue from exposition using punctuation analysis, with exposition being read by the narrator voice
- **Mixed** - Enables AI assisted separation for narrator messages and simple separation for character messages
- **AI assisted** - AI assisted separation for both narrator and character messages
Whether to generate voice for NPCs. If enabled, whenever a non player character speaks, the voice agent will generate audio for them.
!!! warning "AI Assisted Performance"
AI-assisted speaker separation sends additional prompts to your LLM, which may impact response time and API costs.
##### Generate for narration
##### Auto-generate for player
Whether to generate voice for narration. If enabled, whenever the narrator speaks, the voice agent will generate audio for them.
Generate voice automatically for player messages
##### Split generation
##### Auto-generate for AI characters
If enabled, the voice agent will generate audio in chunks, allowing for faster generation. This does however cause it lose context between chunks, and inflection may not be as good.
Generate voice automatically for NPC/AI character messages
##### Auto-generate for narration
Generate voice automatically for narrator messages
##### Auto-generate for context investigation
Generate voice automatically for context investigation messages
## Advanced Settings
Advanced settings are configured per-API and can be found in the respective API configuration sections:
- **Chunk size** - Maximum text length per generation request
- **Model selection** - Choose specific models for each API
- **Voice parameters** - Provider-specific voice settings
!!! tip "Performance Optimization"
Each API has different optimal chunk sizes and parameters. The system automatically handles chunking and queuing for optimal performance across all enabled APIs.


@@ -0,0 +1,156 @@
# Voice Library
The Voice Library is the central hub for managing all voices across all TTS providers in Talemate. It provides a unified interface for organizing, creating, and assigning voices to characters.
## Accessing the Voice Library
The Voice Library can be accessed from the main application bar at the top of the Talemate interface.
![Voice Library access](/talemate/img/0.32.0/voice-library-access.png)
Click the voice icon to open the Voice Library dialog.
!!! note "Voice agent needs to be enabled"
The Voice agent needs to be enabled for the voice library to be available.
## Voice Library Interface
![Voice Library interface](/talemate/img/0.32.0/voice-library-interface.png)
The Voice Library interface consists of:
### Scope Tabs
- **Global** - Voices available across all scenes
- **Scene** - Voices specific to the current scene (only visible when a scene is loaded)
- **Characters** - Character voice assignments for the current scene (only visible when a scene is loaded)
### API Status
The toolbar shows the status of all TTS APIs:
- **Green** - API is enabled and ready
- **Orange** - API is enabled but not configured
- **Red** - API has configuration issues
- **Gray** - API is disabled
![API status](/talemate/img/0.32.0/voice-library-api-status.png)
## Managing Voices
### Global Voice Library
The global voice library contains voices that are available across all scenes. These include:
- Default voices provided by each TTS provider
- Custom voices you've added
#### Adding New Voices
To add a new voice:
1. Click the "+ New" button
2. Select the TTS provider
3. Configure the voice parameters:
- **Label** - Display name for the voice
- **Provider ID** - Provider-specific identifier
- **Tags** - Free-form descriptive tags you define (gender, age, style, etc.)
- **Parameters** - Provider-specific settings
Check the provider specific documentation for more information on how to configure the voice.
#### Voice Types by Provider
**F5-TTS & Chatterbox:**
- Upload .wav reference files for voice cloning
- Specify reference text for better quality
- Adjust speed and other parameters
**Kokoro:**
- Select from predefined voice models
- Create mixed voices by combining multiple models
- Adjust voice mixing weights
**ElevenLabs:**
- Select from available ElevenLabs voices
- Configure voice settings and stability
- Use custom cloned voices from your ElevenLabs account
**OpenAI:**
- Choose from available OpenAI voice models
- Configure model (GPT-4o Mini TTS, TTS-1, TTS-1-HD)
**Google Gemini-TTS:**
- Select from Google's voice models
- Configure language and gender settings
### Scene Voice Library
Scene-specific voices are only available within the current scene. This is useful for:
- Scene-specific characters
- Temporary voice experiments
- Custom voices for specific scenarios
Scene voices are saved with the scene and will be available when the scene is loaded.
## Character Voice Assignment
### Automatic Assignment
The Director agent can automatically assign voices to new characters based on:
- Character tags and attributes
- Voice tags matching character personality
- Available voices in the voice library
This feature can be enabled in the Director agent settings.
### Manual Assignment
![Character voice assignment](/talemate/img/0.32.0/character-voice-assignment.png)
To manually assign a voice to a character:
1. Go to the "Characters" tab in the Voice Library
2. Find the character in the list
3. Click the voice dropdown for that character
4. Select a voice from the available options
5. The assignment is saved automatically
### Character Voice Status
The character list shows:
- **Character name**
- **Currently assigned voice** (if any)
- **Voice status** - whether the voice's API is available
- **Quick assignment controls**
## Voice Tags and Organization
### Tagging System
Voices can be tagged with any descriptive attributes you choose. Tags are completely free-form and user-defined. Common examples include:
- **Gender**: male, female, neutral
- **Age**: young, mature, elderly
- **Style**: calm, energetic, dramatic, mysterious
- **Quality**: deep, high, raspy, smooth
- **Character types**: narrator, villain, hero, comic relief
- **Custom tags**: You can create any tags that help you organize your voices
### Filtering and Search
Use the search bar to filter voices by:
- Voice label/name
- Provider
- Tags
- Character assignments
This makes it easy to find the right voice for specific characters or situations.


@@ -0,0 +1,82 @@
# Reasoning Model Support
Talemate supports reasoning models that can perform step-by-step thinking before generating their final response. This feature allows models to work through complex problems internally before providing an answer.
## Enabling Reasoning Support
To enable reasoning support for a client:
1. Open the **Clients** dialog from the main toolbar
2. Select the client you want to configure
3. Navigate to the **Reasoning** tab in the client configuration
![Client reasoning configuration](/talemate/img/0.32.0/client-reasoning-2.png)
4. Check the **Enable Reasoning** checkbox
## Configuring Reasoning Tokens
Once reasoning is enabled, you can configure the **Reasoning Tokens** setting using the slider:
![Reasoning tokens configuration](/talemate/img/0.32.0/client-reasoning.png)
### Recommended Token Amounts
**For local reasoning models:** Use a high token allocation (recommended: 4096 tokens) to give the model sufficient space for complex reasoning.
**For remote APIs:** Start with lower amounts (512-1024 tokens) and adjust based on your needs and token costs.
### Token Allocation Behavior
The behavior of the reasoning tokens setting depends on your API provider:
**For APIs that support direct reasoning token specification:**
- The specified tokens will be allocated specifically for reasoning
- The model will use these tokens for internal thinking before generating the response
**For APIs that do NOT support reasoning token specification:**
- The tokens are added as extra allowance to the response token limit for ALL requests
- This may lead to more verbose responses than usual since Talemate normally uses response token limits to control verbosity
!!! warning "Increased Verbosity"
For providers without direct reasoning token support, enabling reasoning may result in more verbose responses since the extra tokens are added to all requests.
## Response Pattern Configuration
When reasoning is enabled, you may need to configure a **Pattern to strip from the response** to remove the thinking process from the final output.
### Default Patterns
Talemate provides quick-access buttons for common reasoning patterns:
- **Default** - Uses the built-in pattern: `.*?</think>`
- **`.*?◁/think▷`** - For models using arrow-style thinking delimiters
- **`.*?</think>`** - For models using XML-style think tags
### Custom Patterns
You can also specify a custom regular expression pattern that matches your model's reasoning format. This pattern will be used to strip the thinking tokens from the response before displaying it to the user.
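To illustrate how such a pattern is applied, here is a minimal sketch that strips a `</think>`-delimited reasoning block from a response (variable and function names are illustrative, not Talemate's internals):
```python
import re

DEFAULT_REASONING_PATTERN = r".*?</think>"

def strip_reasoning(response: str, pattern: str = DEFAULT_REASONING_PATTERN) -> str:
    """Remove the thinking block so only the final answer is shown."""
    # re.DOTALL lets ".*?" span the newlines inside the thinking block
    return re.sub(pattern, "", response, count=1, flags=re.DOTALL).lstrip()

raw = "<think>Working through the scene beats...</think>The tavern falls silent."
print(strip_reasoning(raw))  # -> "The tavern falls silent."
```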
## Model Compatibility
Not all models support reasoning. This feature works best with:
- Models specifically trained for chain-of-thought reasoning
- Models that support structured thinking patterns
- APIs that provide reasoning token specification
## Important Notes
- **Coercion Disabled**: When reasoning is enabled, LLM coercion (pre-filling responses) is automatically disabled since reasoning models need to generate their complete thought process
- **Response Time**: Reasoning models may take longer to respond as they work through their thinking process
## Troubleshooting
### Pattern Not Working
If the reasoning pattern isn't properly stripping the thinking process:
1. Check your model's actual reasoning output format
2. Adjust the regular expression pattern to match your model's specific format
3. Test with the default pattern first to see if it works


@@ -35,4 +35,19 @@ A unique name for the client that makes sense to you.
Which model to use. Currently defaults to `gpt-4o`.
!!! note "Talemate lags behind OpenAI"
When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
##### Reasoning models (o1, o3, gpt-5)
!!! important "Enable reasoning and allocate tokens"
The `o1`, `o3`, and `gpt-5` families are reasoning models. They always perform internal thinking before producing the final answer. To use them effectively in Talemate:
- Enable the **Reasoning** option in the client configuration.
- Set **Reasoning Tokens** to a sufficiently high value to make room for the model's thinking process.
A good starting range is 512-1024 tokens. Increase if your tasks are complex. Without enabling reasoning and allocating tokens, these models may return minimal or empty visible content because the token budget is consumed by internal reasoning.
See the detailed guide: [Reasoning Model Support](/talemate/user-guide/clients/reasoning/).
!!! tip "Getting empty responses?"
If these models return empty or very short answers, it usually means the reasoning budget was exhausted. Increase **Reasoning Tokens** and try again.


@@ -4,4 +4,4 @@
uv pip uninstall torch torchaudio
# install torch and torchaudio with CUDA support
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
uv pip install torch~=2.7.1 torchaudio~=2.7.1 --index-url https://download.pytorch.org/whl/cu128


@@ -1,11 +1,12 @@
[project]
name = "talemate"
version = "0.31.0"
version = "0.32.0"
description = "AI-backed roleplay and narrative tools"
authors = [{name = "VeguAITools"}]
license = {text = "GNU Affero General Public License v3.0"}
requires-python = ">=3.10,<3.14"
requires-python = ">=3.11,<3.14"
dependencies = [
"pip",
"astroid>=2.8",
"jedi>=0.18",
"black",
@@ -50,11 +51,20 @@ dependencies = [
# ChromaDB
"chromadb>=1.0.12",
"InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
"torch>=2.7.0",
"torchaudio>=2.7.0",
# locked for instructor embeddings
#sentence-transformers==2.2.2
"torch>=2.7.1",
"torchaudio>=2.7.1",
"sentence_transformers>=2.7.0",
# TTS
"elevenlabs>=2.7.1",
# Local TTS
# Chatterbox TTS
#"chatterbox-tts @ https://github.com/rsxdalv/chatterbox/archive/refs/heads/fast.zip",
"chatterbox-tts==0.1.2",
# kokoro TTS
"kokoro>=0.9.4",
"soundfile>=0.13.1",
# F5-TTS
"f5-tts>=1.1.7",
]
[project.optional-dependencies]
@@ -105,3 +115,25 @@ force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
[tool.uv]
override-dependencies = [
# chatterbox wants torch 2.6.0, but is confirmed working with 2.7.1
"torchaudio>=2.7.1",
"torch>=2.7.1",
"numpy>=2",
"pydantic>=2.11",
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]
torchaudio = [
{ index = "pytorch-cu128" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true


@@ -1,6 +1,6 @@
{
"description": "Captain Elmer Farstield and his trusty first officer, Kaira, embark upon a daring mission into uncharted space. Their small but mighty exploration vessel, the Starlight Nomad, is equipped with state-of-the-art technology and crewed by an elite team of scientists, engineers, and pilots. Together they brave the vast cosmos seeking answers to humanity's most pressing questions about life beyond our solar system.",
"intro": "You awaken aboard your ship, the Starlight Nomad, surrounded by darkness. A soft hum resonates throughout the vessel indicating its systems are online. Your mind struggles to recall what brought you here - where 'here' actually is. You remember nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there was a bright light followed by this endless void.\n\nGingerly, you make your way through the dimly lit corridors of the ship. It seems smaller than you expected given the magnitude of the mission ahead. However, each room reveals intricate technology designed specifically for long-term space travel and exploration. There appears to be no other living soul besides yourself. An eerie silence fills every corner.",
"intro": "Elmer awoke aboard his ship, the Starlight Nomad, surrounded by darkness. A soft hum resonated throughout the vessel indicating its systems were online. His mind struggled to recall what had brought him here - where here actually was. He remembered nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there had been a bright light followed by this endless void.\n\nGingerly, he made his way through the dimly lit corridors of the ship. It seemed smaller than he had expected given the magnitude of the mission ahead. However, each room revealed intricate technology designed specifically for long-term space travel and exploration. There appeared to be no other living soul besides himself. An eerie silence filled every corner.",
"name": "Infinity Quest",
"history": [],
"environment": "scene",
@@ -90,11 +90,11 @@
"gender": "female",
"color": "red",
"example_dialogue": [
"Kaira: \"Yes Captain, I believe that is the best course of action\" She nods slightly, as if to punctuate her approval of the decision*",
"Kaira: \"This device appears to have multiple functions, Captain. Allow me to analyze its capabilities and determine if it could be useful in our exploration efforts.\"",
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" Excitement bubbles beneath her calm exterior as she shares the news",
"Kaira: \"Captain, I understand why you would want us to pursue this course of action based on our current data, but I cannot shake the feeling that there might be unforeseen consequences if we proceed without further investigation into potential hazards.\"",
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world... But then again, perhaps it was fate that led me here, onto this ship bound for destinations unknown...\""
"Kaira: \"Yes Captain, I believe that is the best course of action.\" Kaira glanced at the navigation display, then back at him with a slight nod. \"The numbers check out on my end too. If we adjust our trajectory by 2.7 degrees and increase thrust by fifteen percent, we should reach the nebula's outer edge within six hours.\" Her violet fingers moved efficiently across the controls as she pulled up the gravitational readings.",
"Kaira: The scanner hummed as it analyzed the alien artifact. Kaira knelt beside the strange object, frowning in concentration at the shifting symbols on its warm metal surface. \"This device appears to have multiple functions, Captain,\" she said, adjusting her scanner's settings. \"Give me a few minutes to run a full analysis and I'll know what we're dealing with. The material composition is fascinating - it's responding to our ship's systems in ways that shouldn't be possible.\"",
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" The excitement in her voice was unmistakable as Kaira looked up from her console. \"These readings are incredible - I've never seen anything like this before. There are structures beneath the surface that predate our oldest sites by millions of years.\" She paused, processing the implications. \"If these readings are accurate, we may have found something truly significant.\"",
"Kaira: Something felt off about the proposed course of action. Kaira moved from her station to stand beside the Captain, organizing her thoughts carefully. \"Captain, I understand why you would want us to pursue this based on our current data,\" she began respectfully, clasping her hands behind her back. \"But something feels wrong about this. The quantum signatures have subtle variations that remind me of Hegemony cloaking technology. Maybe we should run a few more scans before we commit to a full approach.\"",
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world,\" she said softly, not turning from the observation deck viewport as footsteps approached. The stars wheeled slowly past in their eternal dance. \"Sometimes I dream of Altrusia's crystal gardens, the way our twin suns would set over the mountains.\" Kaira finally turned, her expression thoughtful. \"Then again, I suppose I was always meant to end up out here somehow. Perhaps this journey is exactly where I'm supposed to be.\""
],
"history_events": [],
"is_player": false,


@@ -494,34 +494,6 @@
"registry": "state/GetState",
"base_type": "core/Node"
},
"8a050403-5c69-46f7-abe2-f65db4553942": {
"title": "TRUE",
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
"properties": {
"value": true
},
"x": 460,
"y": 670,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
},
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
"title": "Create Character",
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
"properties": {},
"x": 580,
"y": 550,
"width": 245,
"height": 186,
"collapsed": false,
"inherited": false,
"registry": "agents/creator/CreateCharacter",
"base_type": "core/Graph"
},
"970f12d0-330e-41b3-b025-9a53bcf2fc6f": {
"title": "SET - created_character",
"id": "970f12d0-330e-41b3-b025-9a53bcf2fc6f",
@@ -628,6 +600,34 @@
"inherited": false,
"registry": "data/string/MakeText",
"base_type": "core/Node"
},
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
"title": "Create Character",
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
"properties": {},
"x": 580,
"y": 550,
"width": 245,
"height": 206,
"collapsed": false,
"inherited": false,
"registry": "agents/creator/CreateCharacter",
"base_type": "core/Graph"
},
"8a050403-5c69-46f7-abe2-f65db4553942": {
"title": "TRUE",
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
"properties": {
"value": true
},
"x": 400,
"y": 720,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
}
},
"edges": {
@@ -714,17 +714,6 @@
"fd4cd318-121b-47de-84a0-b1ab62c5601b.value": [
"41c0c2cd-d39b-4528-a5fe-459094393ba3.list"
],
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
],
"a2457116-35cb-4571-ba1a-cbf63851544e.value": [
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.state",
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.character",
@@ -738,6 +727,18 @@
],
"5041f12d-507f-4f5d-a26f-048625974602.value": [
"b50e23d4-b456-4385-b80e-c2b6884c7855.template"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
],
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.assign_voice"
]
},
"groups": [
@@ -808,13 +809,14 @@
"inputs": [],
"outputs": [
{
"id": "f78fa84b-8b6f-4c8a-83c0-754537bb9060",
"id": "92423e0a-89d8-4ad8-a42d-fdcba75b31d1",
"name": "fn",
"optional": false,
"group": null,
"socket_type": "function"
}
],
"module_properties": {},
"style": {
"title_color": "#573a2e",
"node_color": "#392f2c",


@@ -116,8 +116,8 @@
"context_aware": true,
"history_aware": true
},
"x": 568,
"y": 861,
"x": 372,
"y": 857,
"width": 249,
"height": 406,
"collapsed": false,
@@ -130,7 +130,7 @@
"id": "a8110d74-0fb5-4601-b883-6c63ceaa9d31",
"properties": {},
"x": 24,
"y": 2084,
"y": 2088,
"width": 140,
"height": 106,
"collapsed": false,
@@ -145,7 +145,7 @@
"attribute": "title"
},
"x": 213,
"y": 2094,
"y": 2098,
"width": 210,
"height": 98,
"collapsed": false,
@@ -160,7 +160,7 @@
"value": "The Simulation Suite"
},
"x": 213,
"y": 2284,
"y": 2288,
"width": 210,
"height": 58,
"collapsed": false,
@@ -177,7 +177,7 @@
"case_sensitive": true
},
"x": 523,
"y": 2184,
"y": 2188,
"width": 210,
"height": 126,
"collapsed": false,
@@ -192,7 +192,7 @@
"pass_through": true
},
"x": 774,
"y": 2185,
"y": 2189,
"width": 210,
"height": 78,
"collapsed": false,
@@ -207,7 +207,7 @@
"attribute": "0"
},
"x": 1629,
"y": 2411,
"y": 2415,
"width": 210,
"height": 98,
"collapsed": false,
@@ -222,7 +222,7 @@
"attribute": "title"
},
"x": 2189,
"y": 2321,
"y": 2325,
"width": 210,
"height": 98,
"collapsed": false,
@@ -237,7 +237,7 @@
"stage": 5
},
"x": 2459,
"y": 2331,
"y": 2335,
"width": 210,
"height": 118,
"collapsed": true,
@@ -250,7 +250,7 @@
"id": "8208d05c-1822-4f4a-ba75-cfd18d2de8ca",
"properties": {},
"x": 1989,
"y": 2331,
"y": 2335,
"width": 140,
"height": 106,
"collapsed": true,
@@ -266,7 +266,7 @@
"chars": "\"'*"
},
"x": 1899,
"y": 2411,
"y": 2415,
"width": 210,
"height": 102,
"collapsed": false,
@@ -282,7 +282,7 @@
"max_splits": 1
},
"x": 1359,
"y": 2411,
"y": 2415,
"width": 210,
"height": 102,
"collapsed": false,
@@ -304,7 +304,7 @@
"history_aware": true
},
"x": 1014,
"y": 2185,
"y": 2189,
"width": 276,
"height": 406,
"collapsed": false,
@@ -376,8 +376,8 @@
"name": "arg_goal",
"scope": "local"
},
"x": 228,
"y": 1041,
"x": 32,
"y": 1037,
"width": 256,
"height": 122,
"collapsed": false,
@@ -393,7 +393,7 @@
"max_scene_types": 2
},
"x": 671,
"y": 1490,
"y": 1494,
"width": 210,
"height": 122,
"collapsed": false,
@@ -409,7 +409,7 @@
"scope": "local"
},
"x": 30,
"y": 1551,
"y": 1555,
"width": 256,
"height": 122,
"collapsed": false,
@@ -448,37 +448,6 @@
"registry": "state/GetState",
"base_type": "core/Node"
},
"59d31050-f61d-4798-9790-e22d34ecbd4b": {
"title": "GET local.auto_direct_enabled",
"id": "59d31050-f61d-4798-9790-e22d34ecbd4b",
"properties": {
"name": "auto_direct_enabled",
"scope": "local"
},
"x": 20,
"y": 850,
"width": 256,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8": {
"title": "Switch",
"id": "e8f19a05-43fe-4e4a-9cc4-bec0a29779d8",
"properties": {
"pass_through": true
},
"x": 310,
"y": 870,
"width": 210,
"height": 78,
"collapsed": false,
"inherited": false,
"registry": "core/Switch",
"base_type": "core/Node"
},
"b03fa942-c48e-4c04-b9ae-a009a7e0f947": {
"title": "GET local.auto_direct_enabled",
"id": "b03fa942-c48e-4c04-b9ae-a009a7e0f947",
@@ -487,7 +456,7 @@
"scope": "local"
},
"x": 24,
"y": 1367,
"y": 1371,
"width": 256,
"height": 122,
"collapsed": false,
@@ -502,7 +471,7 @@
"pass_through": true
},
"x": 360,
"y": 1390,
"y": 1394,
"width": 210,
"height": 78,
"collapsed": false,
@@ -518,7 +487,7 @@
"scope": "local"
},
"x": 30,
"y": 1821,
"y": 1826,
"width": 256,
"height": 122,
"collapsed": false,
@@ -533,7 +502,7 @@
"stage": 4
},
"x": 1100,
"y": 1870,
"y": 1875,
"width": 210,
"height": 118,
"collapsed": true,
@@ -546,7 +515,7 @@
"id": "9db37d1e-3cf8-49bd-bdc5-8663494e5657",
"properties": {},
"x": 670,
"y": 1840,
"y": 1845,
"width": 226,
"height": 62,
"collapsed": false,
@@ -561,7 +530,7 @@
"stage": 3
},
"x": 1080,
"y": 1520,
"y": 1524,
"width": 210,
"height": 118,
"collapsed": true,
@@ -576,7 +545,7 @@
"pass_through": true
},
"x": 370,
"y": 1840,
"y": 1845,
"width": 210,
"height": 78,
"collapsed": false,
@@ -590,8 +559,8 @@
"properties": {
"stage": 2
},
"x": 1280,
"y": 890,
"x": 1084,
"y": 886,
"width": 210,
"height": 118,
"collapsed": true,
@@ -599,14 +568,29 @@
"registry": "core/Stage",
"base_type": "core/Node"
},
"5559196c-f6b1-4223-8e13-2bf64e3cfef0": {
"title": "true",
"id": "5559196c-f6b1-4223-8e13-2bf64e3cfef0",
"properties": {
"value": true
},
"x": 24,
"y": 876,
"width": 210,
"height": 58,
"collapsed": false,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
},
"6ef94917-f9b1-4c18-af15-617430e50cfe": {
"title": "Set Scene Intent",
"id": "6ef94917-f9b1-4c18-af15-617430e50cfe",
"properties": {
"intent": ""
},
"x": 930,
"y": 860,
"x": 731,
"y": 853,
"width": 210,
"height": 78,
"collapsed": false,
@@ -693,14 +677,8 @@
"e4cd1391-daed-4951-a6c6-438d993c07a9.state"
],
"c66bdaeb-4166-4835-9415-943af547c926.value": [
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description",
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state"
],
"59d31050-f61d-4798-9790-e22d34ecbd4b.value": [
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.value"
],
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.yes": [
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state",
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description"
],
"b03fa942-c48e-4c04-b9ae-a009a7e0f947.value": [
"8ad7c42c-110e-46ae-b649-4a1e6d055e25.value"
@@ -717,6 +695,9 @@
"6a8762c4-16cf-4e8c-9d10-8af7597c4097.yes": [
"9db37d1e-3cf8-49bd-bdc5-8663494e5657.state"
],
"5559196c-f6b1-4223-8e13-2bf64e3cfef0.value": [
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
],
"6ef94917-f9b1-4c18-af15-617430e50cfe.state": [
"f4cd34d9-0628-4145-a3da-ec1215cd356c.state"
]
@@ -745,7 +726,7 @@
{
"title": "Generate Scene Types",
"x": -1,
"y": 1287,
"y": 1290,
"width": 1298,
"height": 408,
"color": "#8AA",
@@ -756,8 +737,8 @@
"title": "Set story intention",
"x": -1,
"y": 773,
"width": 1539,
"height": 512,
"width": 1320,
"height": 514,
"color": "#8AA",
"font_size": 24,
"inherited": false
@@ -765,7 +746,7 @@
{
"title": "Evaluate Scene Intent",
"x": -1,
"y": 1697,
"y": 1701,
"width": 1293,
"height": 302,
"color": "#8AA",
@@ -775,7 +756,7 @@
{
"title": "Set title",
"x": -1,
"y": 2003,
"y": 2006,
"width": 2618,
"height": 637,
"color": "#8AA",
@@ -794,7 +775,7 @@
{
"text": "Some times the AI will produce more text after the title, we only care about the title on the first line.",
"x": 1359,
"y": 2311,
"y": 2315,
"width": 471,
"inherited": false
}
@@ -804,13 +785,14 @@
"inputs": [],
"outputs": [
{
"id": "dede1a38-2107-4475-9db5-358c09cb0d12",
"id": "5c8dee64-5832-40ba-b1e2-2a411d913cc7",
"name": "fn",
"optional": false,
"group": null,
"socket_type": "function"
}
],
"module_properties": {},
"style": {
"title_color": "#573a2e",
"node_color": "#392f2c",

View File

@@ -666,23 +666,6 @@
"registry": "data/ListAppend",
"base_type": "core/Node"
},
"4eb36f21-1020-4609-85e5-d16b42019c66": {
"title": "AI Function Calling",
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
"properties": {
"template": "computer",
"max_calls": 5,
"retries": 1
},
"x": 1000,
"y": 2581,
"width": 212,
"height": 206,
"collapsed": false,
"inherited": false,
"registry": "focal/Focal",
"base_type": "core/Node"
},
"3da498c0-55de-4ec3-9943-6486279b9826": {
"title": "GET scene loop.user_message",
"id": "3da498c0-55de-4ec3-9943-6486279b9826",
@@ -1167,6 +1150,24 @@
"inherited": false,
"registry": "data/MakeList",
"base_type": "core/Node"
},
"4eb36f21-1020-4609-85e5-d16b42019c66": {
"title": "AI Function Calling",
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
"properties": {
"template": "computer",
"max_calls": 5,
"retries": 1,
"response_length": 1408
},
"x": 1000,
"y": 2581,
"width": 210,
"height": 230,
"collapsed": false,
"inherited": false,
"registry": "focal/Focal",
"base_type": "core/Node"
}
},
"edges": {
@@ -1296,9 +1297,6 @@
"d45f07ff-c3ce-4aac-bae6-b9c77089cb69.list": [
"4eb36f21-1020-4609-85e5-d16b42019c66.callbacks"
],
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
],
"3da498c0-55de-4ec3-9943-6486279b9826.value": [
"6d707f1c-af55-481a-970a-eb7a9f9c45dd.message"
],
@@ -1380,6 +1378,9 @@
],
"632d4a0e-3327-409b-aaa4-ed38b932286b.list": [
"06175a92-abbe-4483-9ab2-abaee2104728.value"
],
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
]
},
"groups": [

View File

@@ -19,7 +19,7 @@ You have access to the following functions you must call to fulfill the user's r
focal.callbacks.set_simulated_environment.render(
"Create or change the simulated environment. This means the location, time, specific conditions, or any other aspect of the simulation that is not directly related to the characters.",
instructions="Instructions on how to change the simulated environment. These will be given to the simulation computer to setup the new environment. REQUIRED.",
reset="If true, the environment should be reset and all simulated characters are removed. If false, the environment should be changed but the characters should remain. REQUIRED.",
reset="If true, the environment should be reset and ALL simulated characters are removed. IMPORTANT: If you set reset=true, this function MUST be the FIRST call in your stack; otherwise, set reset=false to avoid deactivating characters added earlier. REQUIRED.",
examples=[
{"instructions": "Change the location to a lush forest, with a river running through it.", "reset":true},
{"instructions": "The simulation suite flickers and changes to a bustling city street.", "reset":true},
@@ -123,7 +123,7 @@ You have access to the following functions you must call to fulfill the user's r
{{
focal.callbacks.set_simulation_goal.render(
"Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives.",
"Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives. You MUST call this on new simulations or if the user has requested a change in the simulation's goal.",
goal="The overall goal of the simulation. This should be a clear and concise statement that outlines the user's objective. REQUIRED.",
examples=[
{"goal": "The user is exploring a mysterious alien planet to uncover the secrets of an ancient civilization."},

View File

@@ -18,10 +18,11 @@ from talemate.agents.context import ActiveAgent, active_agent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
from talemate.context import active_scene
import talemate.config as config
from talemate.ux.schema import Column
from talemate.config import get_config, Config
import talemate.config.schema as config_schema
from talemate.client.context import (
ClientContext,
set_client_context_attribute,
)
__all__ = [
@@ -53,19 +54,20 @@ class AgentActionConfig(pydantic.BaseModel):
type: str
label: str
description: str = ""
value: Union[int, float, str, bool, list, None] = None
default_value: Union[int, float, str, bool] = None
max: Union[int, float, None] = None
min: Union[int, float, None] = None
step: Union[int, float, None] = None
value: int | float | str | bool | list | None = None
default_value: int | float | str | bool | None = None
max: int | float | None = None
min: int | float | None = None
step: int | float | None = None
scope: str = "global"
choices: Union[list[dict[str, str]], None] = None
choices: list[dict[str, str | int | float | bool]] | None = None
note: Union[str, None] = None
expensive: bool = False
quick_toggle: bool = False
condition: Union[AgentActionConditional, None] = None
title: Union[str, None] = None
value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True)
condition: AgentActionConditional | None = None
title: str | None = None
value_migration: Callable | None = pydantic.Field(default=None, exclude=True)
columns: list[Column] | None = None
note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict)
@@ -78,20 +80,21 @@ class AgentAction(pydantic.BaseModel):
label: str
description: str = ""
warning: str = ""
config: Union[dict[str, AgentActionConfig], None] = None
condition: Union[AgentActionConditional, None] = None
config: dict[str, AgentActionConfig] | None = None
condition: AgentActionConditional | None = None
container: bool = False
icon: Union[str, None] = None
icon: str | None = None
can_be_disabled: bool = False
quick_toggle: bool = False
experimental: bool = False
class AgentDetail(pydantic.BaseModel):
value: Union[str, None] = None
description: Union[str, None] = None
icon: Union[str, None] = None
value: str | None = None
description: str | None = None
icon: str | None = None
color: str = "grey"
hidden: bool = False
class DynamicInstruction(pydantic.BaseModel):
@@ -172,11 +175,6 @@ def set_processing(fn):
if scene:
scene.continue_actions()
if getattr(scene, "config", None):
set_client_context_attribute(
"app_config_system_prompts", scene.config.get("system_prompts", {})
)
with ActiveAgent(self, fn, args, kwargs) as active_agent_context:
try:
await self.emit_status(processing=True)
@@ -221,7 +219,6 @@ class Agent(ABC):
verbose_name = None
set_processing = set_processing
requires_llm_client = True
auto_break_repetition = False
websocket_handler = None
essential = True
ready_check_error = None
@@ -235,6 +232,10 @@ class Agent(ABC):
return actions
@property
def config(self) -> Config:
return get_config()
@property
def agent_details(self):
if hasattr(self, "client"):
@@ -244,6 +245,12 @@ class Agent(ABC):
@property
def ready(self):
if not self.requires_llm_client:
return True
if not hasattr(self, "client"):
return False
if not getattr(self.client, "enabled", True):
return False
@@ -326,7 +333,7 @@ class Agent(ABC):
# scene state
def context_fingerpint(self, extra: list[str] = []) -> str | None:
def context_fingerprint(self, extra: list[str] = None) -> str | None:
active_agent_context = active_agent.get()
if not active_agent_context:
@@ -337,8 +344,9 @@ class Agent(ABC):
else:
fingerprint = f"START-{active_agent_context.first.fingerprint}"
for extra_key in extra:
fingerprint += f"-{hash(extra_key)}"
if extra:
for extra_key in extra:
fingerprint += f"-{hash(extra_key)}"
return fingerprint
@@ -448,25 +456,26 @@ class Agent(ABC):
except AttributeError:
pass
async def save_config(self, app_config: config.Config | None = None):
async def save_config(self):
"""
Saves the agent config to the config file.
If no config object is provided, the config is loaded from the config file.
"""
if not app_config:
app_config: config.Config = config.load_config(as_model=True)
app_config: Config = get_config()
app_config.agents[self.agent_type] = config.Agent(
app_config.agents[self.agent_type] = config_schema.Agent(
name=self.agent_type,
client=self.client.name if self.client else None,
client=self.client.name if getattr(self, "client", None) else None,
enabled=self.enabled,
actions={
action_key: config.AgentAction(
action_key: config_schema.AgentAction(
enabled=action.enabled,
config={
config_key: config.AgentActionConfig(value=config_obj.value)
config_key: config_schema.AgentActionConfig(
value=config_obj.value
)
for config_key, config_obj in action.config.items()
},
)
@@ -478,7 +487,8 @@ class Agent(ABC):
agent=self.agent_type,
config=app_config.agents[self.agent_type],
)
config.save_config(app_config)
app_config.dirty = True
async def on_game_loop_start(self, event: GameLoopStartEvent):
"""
@@ -602,7 +612,7 @@ class Agent(ABC):
exclude_fn: Callable = None,
):
current_memory_context = []
memory_helper = self.scene.get_helper("memory")
memory_helper = instance.get_agent("memory")
if memory_helper:
history_messages = "\n".join(
self.scene.recent_history(memory_history_context_max)
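
Editor's sketch of the config-access pattern the hunks above converge on (anything marked illustrative is not from the diff): agents resolve configuration lazily through get_config() and persist changes by flagging the shared Config dirty rather than writing the file directly.

# Minimal sketch, assuming get_config() returns the live shared Config instance
# (as the new Agent.config property above does) and that setting dirty = True
# is what schedules persistence instead of an explicit save call.
from talemate.agents.base import Agent
from talemate.config import get_config

class ExampleAgent(Agent):                      # hypothetical agent, for illustration only
    agent_type = "example"

    @property
    def openai_api_key(self) -> str | None:
        # equivalent to self.config.openai.api_key
        return get_config().openai.api_key

    async def remember_client(self, client_name: str) -> None:
        app_config = get_config()
        app_config.agents[self.agent_type].client = client_name   # illustrative field access
        app_config.dirty = True                                    # persistence handled elsewhere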

View File

@@ -23,6 +23,7 @@ from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
AgentActionNote,
AgentDetail,
AgentEmission,
DynamicInstruction,
@@ -85,12 +86,22 @@ class ConversationAgent(MemoryRAGMixin, Agent):
"format": AgentActionConfig(
type="text",
label="Format",
description="The generation format of the scene context, as seen by the AI.",
description="The generation format of the scene progression, as seen by the AI. Has no direct effect on your view of the scene, but will affect the way the AI perceives the scene and its characters, leading to changes in the response, for better or worse.",
choices=[
{"label": "Screenplay", "value": "movie_script"},
{"label": "Chat (legacy)", "value": "chat"},
{
"label": "Narrative (NEW, experimental)",
"value": "narrative",
},
],
value="movie_script",
note_on_value={
"narrative": AgentActionNote(
type="primary",
text="Will attempt to generate flowing, novel-like prose with scene intent awareness and character goal consideration. A reasoning model is STRONGLY recommended. Experimental and more prone to generate out of turn character actions and dialogue.",
)
},
),
"length": AgentActionConfig(
type="number",
@@ -133,12 +144,6 @@ class ConversationAgent(MemoryRAGMixin, Agent):
),
},
),
"auto_break_repetition": AgentAction(
enabled=True,
can_be_disabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"content": AgentAction(
enabled=True,
can_be_disabled=False,
@@ -161,7 +166,7 @@ class ConversationAgent(MemoryRAGMixin, Agent):
def __init__(
self,
client: client.TaleMateClient,
client: client.ClientBase | None = None,
kind: Optional[str] = "pygmalion",
logging_enabled: Optional[bool] = True,
**kwargs,
@@ -453,21 +458,31 @@ class ConversationAgent(MemoryRAGMixin, Agent):
if total_result.startswith(":\n") or total_result.startswith(": "):
total_result = total_result[2:]
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", "")
conversation_format = self.conversation_format
# chat format
# {character name}: {dialogue}
total_result = total_result.replace(f"{character.name}:", "")
if conversation_format == "narrative":
# For narrative format, the LLM generates pure prose without character name prefixes
# We need to store it internally in the standard {name}: {text} format
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Only add character name if it's not already there
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"
else:
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", "")
# Removes partial sentence at the end
total_result = util.clean_dialogue(total_result, main_name=character.name)
# chat format
# {character name}: {dialogue}
total_result = total_result.replace(f"{character.name}:", "")
# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"
# Removes partial sentence at the end
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"
total_result = total_result.strip()
@@ -499,9 +514,6 @@ class ConversationAgent(MemoryRAGMixin, Agent):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return agent_function_name == "converse"
def inject_prompt_paramters(
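
A tiny, self-contained illustration of the narrative-format normalization described in the comments above (made-up text; util.clean_dialogue is omitted, only the name-prefix step is shown):

raw_output = "She lowered the lantern and listened to the rain."   # hypothetical LLM output
character_name = "Elara"                                            # hypothetical character

# narrative output carries no speaker prefix, so it is stored internally
# in the standard "{name}: {text}" format, exactly as the branch above does
result = raw_output.strip()
if not result.startswith(character_name + ":"):
    result = f"{character_name}: {result}"

print(result)   # Elara: She lowered the lantern and listened to the rain.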

View File

@@ -39,7 +39,7 @@ class CreatorAgent(
def __init__(
self,
client: client.ClientBase,
client: client.ClientBase | None = None,
**kwargs,
):
self.client = client

View File

@@ -21,7 +21,7 @@
"num": 3
},
"x": 39,
"y": 507,
"y": 236,
"width": 210,
"height": 154,
"collapsed": false,
@@ -40,7 +40,7 @@
"num": 1
},
"x": 38,
"y": 107,
"y": -164,
"width": 210,
"height": 154,
"collapsed": false,
@@ -59,7 +59,7 @@
"num": 0
},
"x": 38,
"y": -93,
"y": -364,
"width": 210,
"height": 154,
"collapsed": false,
@@ -76,7 +76,7 @@
"num": 0
},
"x": 288,
"y": -63,
"y": -334,
"width": 210,
"height": 106,
"collapsed": true,
@@ -89,7 +89,7 @@
"id": "553125be-2c2b-4404-98b5-d6333a4f9655",
"properties": {},
"x": 348,
"y": 507,
"y": 236,
"width": 140,
"height": 26,
"collapsed": false,
@@ -102,7 +102,7 @@
"id": "207e357e-5e83-4d40-a331-d0041b9dfa49",
"properties": {},
"x": 348,
"y": 307,
"y": 36,
"width": 140,
"height": 26,
"collapsed": false,
@@ -115,7 +115,7 @@
"id": "bad7d2ed-f6fa-452b-8b47-d753ae7a45e0",
"properties": {},
"x": 348,
"y": 107,
"y": -164,
"width": 140,
"height": 26,
"collapsed": false,
@@ -131,7 +131,7 @@
"scope": "local"
},
"x": 598,
"y": 107,
"y": -164,
"width": 210,
"height": 122,
"collapsed": false,
@@ -147,7 +147,7 @@
"scope": "local"
},
"x": 598,
"y": 307,
"y": 36,
"width": 210,
"height": 122,
"collapsed": false,
@@ -163,7 +163,7 @@
"scope": "local"
},
"x": 598,
"y": 507,
"y": 236,
"width": 210,
"height": 122,
"collapsed": false,
@@ -176,7 +176,7 @@
"id": "1765a51c-82ac-4d96-9c60-d0adc7faaa68",
"properties": {},
"x": 498,
"y": 707,
"y": 436,
"width": 171,
"height": 26,
"collapsed": false,
@@ -192,7 +192,7 @@
"scope": "local"
},
"x": 798,
"y": 706,
"y": 435,
"width": 244,
"height": 122,
"collapsed": false,
@@ -205,7 +205,7 @@
"id": "f3dc37df-1748-4235-b575-60e55b8bec73",
"properties": {},
"x": 498,
"y": 906,
"y": 635,
"width": 171,
"height": 26,
"collapsed": false,
@@ -220,7 +220,7 @@
"stage": 0
},
"x": 1208,
"y": 366,
"y": 95,
"width": 210,
"height": 118,
"collapsed": true,
@@ -238,7 +238,7 @@
"icon": "F1719"
},
"x": 1178,
"y": -144,
"y": -415,
"width": 210,
"height": 130,
"collapsed": false,
@@ -253,7 +253,7 @@
"default": true
},
"x": 298,
"y": 736,
"y": 465,
"width": 210,
"height": 58,
"collapsed": true,
@@ -268,7 +268,7 @@
"default": false
},
"x": 298,
"y": 936,
"y": 665,
"width": 210,
"height": 58,
"collapsed": true,
@@ -287,7 +287,7 @@
"num": 2
},
"x": 38,
"y": 306,
"y": 35,
"width": 210,
"height": 154,
"collapsed": false,
@@ -303,7 +303,7 @@
"scope": "local"
},
"x": 798,
"y": 906,
"y": 635,
"width": 244,
"height": 122,
"collapsed": false,
@@ -322,7 +322,7 @@
"num": 5
},
"x": 38,
"y": 706,
"y": 435,
"width": 237,
"height": 154,
"collapsed": false,
@@ -341,7 +341,7 @@
"num": 7
},
"x": 38,
"y": 906,
"y": 635,
"width": 210,
"height": 154,
"collapsed": false,
@@ -360,7 +360,7 @@
"num": 4
},
"x": 48,
"y": 1326,
"y": 1055,
"width": 237,
"height": 154,
"collapsed": false,
@@ -375,7 +375,7 @@
"default": true
},
"x": 318,
"y": 1356,
"y": 1085,
"width": 210,
"height": 58,
"collapsed": true,
@@ -388,7 +388,7 @@
"id": "6d387c67-6b32-4435-b984-4760f0f1f8d2",
"properties": {},
"x": 498,
"y": 1336,
"y": 1065,
"width": 171,
"height": 26,
"collapsed": false,
@@ -404,7 +404,7 @@
"scope": "local"
},
"x": 798,
"y": 1316,
"y": 1045,
"width": 244,
"height": 122,
"collapsed": false,
@@ -455,7 +455,7 @@
"num": 6
},
"x": 38,
"y": 1106,
"y": 835,
"width": 252,
"height": 154,
"collapsed": false,
@@ -471,7 +471,7 @@
"apply_on_unresolved": true
},
"x": 518,
"y": 1136,
"y": 865,
"width": 210,
"height": 102,
"collapsed": true,
@@ -1101,7 +1101,7 @@
"num": 8
},
"x": 49,
"y": 1546,
"y": 1275,
"width": 210,
"height": 154,
"collapsed": false,
@@ -1116,7 +1116,7 @@
"default": false
},
"x": 329,
"y": 1586,
"y": 1315,
"width": 210,
"height": 58,
"collapsed": true,
@@ -1129,7 +1129,7 @@
"id": "8acfe789-fbb5-4e29-8fd8-2217b987c086",
"properties": {},
"x": 499,
"y": 1566,
"y": 1295,
"width": 171,
"height": 26,
"collapsed": false,
@@ -1145,7 +1145,7 @@
"scope": "local"
},
"x": 799,
"y": 1526,
"y": 1255,
"width": 244,
"height": 122,
"collapsed": false,
@@ -1258,8 +1258,8 @@
"output_name": "character",
"num": 0
},
"x": 331,
"y": 5739,
"x": 332,
"y": 6178,
"width": 210,
"height": 106,
"collapsed": false,
@@ -1274,8 +1274,8 @@
"name": "character",
"scope": "local"
},
"x": 51,
"y": 5729,
"x": 52,
"y": 6168,
"width": 210,
"height": 122,
"collapsed": false,
@@ -1291,8 +1291,8 @@
"output_name": "actor",
"num": 0
},
"x": 331,
"y": 5949,
"x": 332,
"y": 6388,
"width": 210,
"height": 106,
"collapsed": false,
@@ -1307,8 +1307,8 @@
"name": "actor",
"scope": "local"
},
"x": 51,
"y": 5949,
"x": 52,
"y": 6388,
"width": 210,
"height": 122,
"collapsed": false,
@@ -1323,7 +1323,7 @@
"stage": 0
},
"x": 1320,
"y": 1160,
"y": 889,
"width": 210,
"height": 118,
"collapsed": true,
@@ -1414,7 +1414,7 @@
"writing_style": null
},
"x": 320,
"y": 1200,
"y": 929,
"width": 270,
"height": 122,
"collapsed": true,
@@ -1430,7 +1430,7 @@
"scope": "local"
},
"x": 790,
"y": 1110,
"y": 839,
"width": 244,
"height": 122,
"collapsed": false,
@@ -1475,6 +1475,159 @@
"inherited": false,
"registry": "agents/creator/ContextualGenerate",
"base_type": "core/Node"
},
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1": {
"title": "assign_voice",
"id": "d61de1ad-6f2a-447f-918a-dce7e76ea3a1",
"properties": {},
"x": 509,
"y": 1544,
"width": 171,
"height": 26,
"collapsed": false,
"inherited": false,
"registry": "core/Watch",
"base_type": "core/Node"
},
"3d655827-b66b-4355-910d-96097e7f2f13": {
"title": "SET local.assign_voice",
"id": "3d655827-b66b-4355-910d-96097e7f2f13",
"properties": {
"name": "assign_voice",
"scope": "local"
},
"x": 810,
"y": 1506,
"width": 244,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/SetState",
"base_type": "core/Node"
},
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1": {
"title": "Stage 0",
"id": "6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1",
"properties": {
"stage": 0
},
"x": 1170,
"y": 1546,
"width": 210,
"height": 118,
"collapsed": true,
"inherited": false,
"registry": "core/Stage",
"base_type": "core/Node"
},
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f": {
"title": "IN assign_voice",
"id": "9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f",
"properties": {
"input_type": "bool",
"input_name": "assign_voice",
"input_optional": true,
"input_group": "",
"num": 9
},
"x": 60,
"y": 1527,
"width": 210,
"height": 154,
"collapsed": false,
"inherited": false,
"registry": "core/Input",
"base_type": "core/Node"
},
"de11206d-13db-44a5-befd-de559fb68d09": {
"title": "GET local.assign_voice",
"id": "de11206d-13db-44a5-befd-de559fb68d09",
"properties": {
"name": "assign_voice",
"scope": "local"
},
"x": 25,
"y": 5720,
"width": 240,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"97b196a3-e7a6-4cfa-905e-686a744890b7": {
"title": "Switch",
"id": "97b196a3-e7a6-4cfa-905e-686a744890b7",
"properties": {
"pass_through": true
},
"x": 355,
"y": 5740,
"width": 210,
"height": 78,
"collapsed": false,
"inherited": false,
"registry": "core/Switch",
"base_type": "core/Node"
},
"3956711a-a3df-4213-b739-104cbe704964": {
"title": "GET local.character",
"id": "3956711a-a3df-4213-b739-104cbe704964",
"properties": {
"name": "character",
"scope": "local"
},
"x": 25,
"y": 5930,
"width": 210,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"47b2b492-4178-4254-b468-3877a5341f66": {
"title": "Assign Voice",
"id": "47b2b492-4178-4254-b468-3877a5341f66",
"properties": {},
"x": 665,
"y": 5830,
"width": 161,
"height": 66,
"collapsed": false,
"inherited": false,
"registry": "agents/director/AssignVoice",
"base_type": "core/Node"
},
"44d78795-2d41-468b-94d5-399b9b655888": {
"title": "Stage 6",
"id": "44d78795-2d41-468b-94d5-399b9b655888",
"properties": {
"stage": 6
},
"x": 885,
"y": 5860,
"width": 210,
"height": 118,
"collapsed": true,
"inherited": false,
"registry": "core/Stage",
"base_type": "core/Node"
},
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85": {
"title": "As Bool",
"id": "f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85",
"properties": {
"default": true
},
"x": 330,
"y": 1560,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/AsBool",
"base_type": "core/Node"
}
},
"edges": {
@@ -1726,15 +1879,39 @@
],
"4acb67ea-68ee-43ae-a8c6-98a2b0e0f053.text": [
"d41f0d98-14d5-49dd-8e57-7812fb9fee94.value"
],
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value": [
"3d655827-b66b-4355-910d-96097e7f2f13.value"
],
"3d655827-b66b-4355-910d-96097e7f2f13.value": [
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1.state"
],
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f.value": [
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value"
],
"de11206d-13db-44a5-befd-de559fb68d09.value": [
"97b196a3-e7a6-4cfa-905e-686a744890b7.value"
],
"97b196a3-e7a6-4cfa-905e-686a744890b7.yes": [
"47b2b492-4178-4254-b468-3877a5341f66.state"
],
"3956711a-a3df-4213-b739-104cbe704964.value": [
"47b2b492-4178-4254-b468-3877a5341f66.character"
],
"47b2b492-4178-4254-b468-3877a5341f66.state": [
"44d78795-2d41-468b-94d5-399b9b655888.state"
],
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value": [
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value"
]
},
"groups": [
{
"title": "Process Arguments - Stage 0",
"x": 1,
"y": -218,
"width": 1446,
"height": 1948,
"y": -490,
"width": 1432,
"height": 2216,
"color": "#3f789e",
"font_size": 24,
"inherited": false
@@ -1792,12 +1969,22 @@
{
"title": "Outputs",
"x": 0,
"y": 5642,
"y": 6080,
"width": 595,
"height": 472,
"color": "#8A8",
"font_size": 24,
"inherited": false
},
{
"title": "Assign Voice - Stage 6",
"x": 0,
"y": 5640,
"width": 1120,
"height": 437,
"color": "#3f789e",
"font_size": 24,
"inherited": false
}
],
"comments": [],
@@ -1805,6 +1992,7 @@
"base_type": "core/Graph",
"inputs": [],
"outputs": [],
"module_properties": {},
"style": {
"title_color": "#572e44",
"node_color": "#392c34",

View File

@@ -1,33 +1,24 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
import structlog
import traceback
import talemate.instance as instance
from talemate.emit import emit
from talemate.scene_message import DirectorMessage
from talemate.util import random_color
from talemate.character import deactivate_character
from talemate.status import LoadingStatus
from talemate.exceptions import GenerationCancelled
from talemate.scene_message import DirectorMessage, Flags
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
from talemate.agents.base import Agent, AgentAction, AgentActionConfig
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.client import ClientBase
from talemate.game.focal.schema import Call
from .guide import GuideSceneMixin
from .generate_choices import GenerateChoicesMixin
from .legacy_scene_instructions import LegacySceneInstructionsMixin
from .auto_direct import AutoDirectMixin
from .websocket_handler import DirectorWebsocketHandler
from .character_management import CharacterManagementMixin
import talemate.agents.director.nodes # noqa: F401
if TYPE_CHECKING:
from talemate import Character, Scene
log = structlog.get_logger("talemate.agent.director")
@@ -38,6 +29,7 @@ class DirectorAgent(
GenerateChoicesMixin,
AutoDirectMixin,
LegacySceneInstructionsMixin,
CharacterManagementMixin,
Agent,
):
agent_type = "director"
@@ -76,9 +68,10 @@ class DirectorAgent(
GenerateChoicesMixin.add_actions(actions)
GuideSceneMixin.add_actions(actions)
AutoDirectMixin.add_actions(actions)
CharacterManagementMixin.add_actions(actions)
return actions
def __init__(self, client, **kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.is_enabled = True
self.client = client
self.next_direct_character = {}
@@ -101,178 +94,24 @@ class DirectorAgent(
def actor_direction_mode(self):
return self.actions["direct"].config["actor_direction_mode"].value
@set_processing
async def persist_characters_from_worldstate(
self, exclude: list[str] = None
) -> List[Character]:
created_characters = []
for character_name in self.scene.world_state.characters.keys():
if exclude and character_name.lower() in exclude:
continue
if character_name in self.scene.character_names:
continue
character = await self.persist_character(name=character_name)
created_characters.append(character)
self.scene.emit_status()
return created_characters
@set_processing
async def persist_character(
self,
name: str,
content: str = None,
attributes: str = None,
determine_name: bool = True,
templates: list[str] = None,
active: bool = True,
narrate_entry: bool = False,
narrate_entry_direction: str = "",
augment_attributes: str = "",
generate_attributes: bool = True,
description: str = "",
) -> Character:
world_state = instance.get_agent("world_state")
creator = instance.get_agent("creator")
narrator = instance.get_agent("narrator")
memory = instance.get_agent("memory")
scene: "Scene" = self.scene
any_attribute_templates = False
loading_status = LoadingStatus(max_steps=None, cancellable=True)
# Start of character creation
log.debug("persist_character", name=name)
# Determine the character's name (or clarify if it's already set)
if determine_name:
loading_status("Determining character name")
name = await creator.determine_character_name(name, instructions=content)
log.debug("persist_character", adjusted_name=name)
# Create the blank character
character: Character = self.scene.Character(name=name)
# Add the character to the scene
character.color = random_color()
actor = self.scene.Actor(
character=character, agent=instance.get_agent("conversation")
async def log_function_call(self, call: Call):
log.debug("director.log_function_call", call=call)
message = DirectorMessage(
message=f"Called {call.name}",
action=call.name,
flags=Flags.HIDDEN,
subtype="function_call",
)
await self.scene.add_actor(actor)
emit("director", message, data={"function_call": call.model_dump()})
try:
# Apply any character generation templates
if templates:
loading_status("Applying character generation templates")
templates = scene.world_state_manager.template_collection.collect_all(
templates
)
log.debug("persist_character", applying_templates=templates)
await scene.world_state_manager.apply_templates(
templates.values(),
character_name=character.name,
information=content,
)
# if any of the templates are attribute templates, then we no longer need to
# generate a character sheet
any_attribute_templates = any(
template.template_type == "character_attribute"
for template in templates.values()
)
log.debug(
"persist_character", any_attribute_templates=any_attribute_templates
)
if (
any_attribute_templates
and augment_attributes
and generate_attributes
):
log.debug(
"persist_character", augmenting_attributes=augment_attributes
)
loading_status("Augmenting character attributes")
additional_attributes = await world_state.extract_character_sheet(
name=name,
text=content,
augmentation_instructions=augment_attributes,
)
character.base_attributes.update(additional_attributes)
# Generate a character sheet if there are no attribute templates
if not any_attribute_templates and generate_attributes:
loading_status("Generating character sheet")
log.debug("persist_character", extracting_character_sheet=True)
if not attributes:
attributes = await world_state.extract_character_sheet(
name=name, text=content
)
else:
attributes = world_state._parse_character_sheet(attributes)
log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
# Generate a description for the character
if not description:
loading_status("Generating character description")
description = await creator.determine_character_description(
character, information=content
)
character.description = description
log.debug("persist_character", description=description)
# Generate a dialogue instructions for the character
loading_status("Generating acting instructions")
dialogue_instructions = (
await creator.determine_character_dialogue_instructions(
character, information=content
)
)
character.dialogue_instructions = dialogue_instructions
log.debug("persist_character", dialogue_instructions=dialogue_instructions)
# Narrate the character's entry if the option is selected
if active and narrate_entry:
loading_status("Narrating character entry")
is_present = await world_state.is_character_present(name)
if not is_present:
await narrator.action_to_narration(
"narrate_character_entry",
emit_message=True,
character=character,
narrative_direction=narrate_entry_direction,
)
# Deactivate the character if not active
if not active:
await deactivate_character(scene, character)
# Commit the character's details to long term memory
await character.commit_to_memory(memory)
self.scene.emit_status()
self.scene.world_state.emit()
loading_status.done(
message=f"{character.name} added to scene", status="success"
)
return character
except GenerationCancelled:
loading_status.done(message="Character creation cancelled", status="idle")
await scene.remove_actor(actor)
except Exception:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.error("Error persisting character", error=traceback.format_exc())
async def log_action(self, action: str, action_description: str):
message = DirectorMessage(message=action_description, action=action)
async def log_action(
self, action: str, action_description: str, console_only: bool = False
):
message = DirectorMessage(
message=action_description,
action=action,
flags=Flags.HIDDEN if console_only else Flags.NONE,
)
self.scene.push_history(message)
emit("director", message)

View File

@@ -0,0 +1,333 @@
from typing import TYPE_CHECKING
import traceback
import structlog
import talemate.instance as instance
import talemate.agents.tts.voice_library as voice_library
from talemate.agents.tts.schema import Voice
from talemate.util import random_color
from talemate.character import deactivate_character, set_voice
from talemate.status import LoadingStatus
from talemate.exceptions import GenerationCancelled
from talemate.agents.base import AgentAction, AgentActionConfig, set_processing
import talemate.game.focal as focal
__all__ = [
"CharacterManagementMixin",
]
log = structlog.get_logger()
if TYPE_CHECKING:
from talemate import Character, Scene
from talemate.agents.tts import TTSAgent
class VoiceCandidate(Voice):
used: bool = False
class CharacterManagementMixin:
"""
Director agent mixin that provides functionality for creating and managing
characters: persisting them from the world state and assigning text-to-speech voices.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["character_management"] = AgentAction(
enabled=True,
container=True,
can_be_disabled=False,
label="Character Management",
icon="mdi-account",
description="Configure how the director manages characters.",
config={
"assign_voice": AgentActionConfig(
type="bool",
label="Assign Voice (TTS)",
description="If enabled, the director is allowed to assign a text-to-speech voice when persisting a character.",
value=True,
title="Persisting Characters",
),
},
)
# config property helpers
@property
def cm_assign_voice(self) -> bool:
return self.actions["character_management"].config["assign_voice"].value
@property
def cm_should_assign_voice(self) -> bool:
if not self.cm_assign_voice:
return False
tts_agent: "TTSAgent" = instance.get_agent("tts")
if not tts_agent.enabled:
return False
if not tts_agent.ready_apis:
return False
return True
# actions
@set_processing
async def persist_characters_from_worldstate(
self, exclude: list[str] = None
) -> list["Character"]:
created_characters = []
for character_name in self.scene.world_state.characters.keys():
if exclude and character_name.lower() in exclude:
continue
if character_name in self.scene.character_names:
continue
character = await self.persist_character(name=character_name)
created_characters.append(character)
self.scene.emit_status()
return created_characters
@set_processing
async def persist_character(
self,
name: str,
content: str = None,
attributes: str = None,
determine_name: bool = True,
templates: list[str] = None,
active: bool = True,
narrate_entry: bool = False,
narrate_entry_direction: str = "",
augment_attributes: str = "",
generate_attributes: bool = True,
description: str = "",
assign_voice: bool = True,
is_player: bool = False,
) -> "Character":
world_state = instance.get_agent("world_state")
creator = instance.get_agent("creator")
narrator = instance.get_agent("narrator")
memory = instance.get_agent("memory")
scene: "Scene" = self.scene
any_attribute_templates = False
loading_status = LoadingStatus(max_steps=None, cancellable=True)
# Start of character creation
log.debug("persist_character", name=name)
# Determine the character's name (or clarify if it's already set)
if determine_name:
loading_status("Determining character name")
name = await creator.determine_character_name(name, instructions=content)
log.debug("persist_character", adjusted_name=name)
# Create the blank character
character: "Character" = self.scene.Character(name=name, is_player=is_player)
# Add the character to the scene
character.color = random_color()
if is_player:
actor = self.scene.Player(
character=character, agent=instance.get_agent("conversation")
)
else:
actor = self.scene.Actor(
character=character, agent=instance.get_agent("conversation")
)
await self.scene.add_actor(actor)
try:
# Apply any character generation templates
if templates:
loading_status("Applying character generation templates")
templates = scene.world_state_manager.template_collection.collect_all(
templates
)
log.debug("persist_character", applying_templates=templates)
await scene.world_state_manager.apply_templates(
templates.values(),
character_name=character.name,
information=content,
)
# if any of the templates are attribute templates, then we no longer need to
# generate a character sheet
any_attribute_templates = any(
template.template_type == "character_attribute"
for template in templates.values()
)
log.debug(
"persist_character", any_attribute_templates=any_attribute_templates
)
if (
any_attribute_templates
and augment_attributes
and generate_attributes
):
log.debug(
"persist_character", augmenting_attributes=augment_attributes
)
loading_status("Augmenting character attributes")
additional_attributes = await world_state.extract_character_sheet(
name=name,
text=content,
augmentation_instructions=augment_attributes,
)
character.base_attributes.update(additional_attributes)
# Generate a character sheet if there are no attribute templates
if not any_attribute_templates and generate_attributes:
loading_status("Generating character sheet")
log.debug("persist_character", extracting_character_sheet=True)
if not attributes:
attributes = await world_state.extract_character_sheet(
name=name, text=content
)
else:
attributes = world_state._parse_character_sheet(attributes)
log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
# Generate a description for the character
if not description:
loading_status("Generating character description")
description = await creator.determine_character_description(
character, information=content
)
character.description = description
log.debug("persist_character", description=description)
# Generate a dialogue instructions for the character
loading_status("Generating acting instructions")
dialogue_instructions = (
await creator.determine_character_dialogue_instructions(
character, information=content
)
)
character.dialogue_instructions = dialogue_instructions
log.debug("persist_character", dialogue_instructions=dialogue_instructions)
# Narrate the character's entry if the option is selected
if active and narrate_entry:
loading_status("Narrating character entry")
is_present = await world_state.is_character_present(name)
if not is_present:
await narrator.action_to_narration(
"narrate_character_entry",
emit_message=True,
character=character,
narrative_direction=narrate_entry_direction,
)
if assign_voice:
await self.assign_voice_to_character(character)
# Deactivate the character if not active
if not active:
await deactivate_character(scene, character)
# Commit the character's details to long term memory
await character.commit_to_memory(memory)
self.scene.emit_status()
self.scene.world_state.emit()
loading_status.done(
message=f"{character.name} added to scene", status="success"
)
return character
except GenerationCancelled:
loading_status.done(message="Character creation cancelled", status="idle")
await scene.remove_actor(actor)
except Exception:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.error("Error persisting character", error=traceback.format_exc())
@set_processing
async def assign_voice_to_character(
self, character: "Character"
) -> list[focal.Call]:
tts_agent: "TTSAgent" = instance.get_agent("tts")
if not self.cm_should_assign_voice:
log.debug("assign_voice_to_character", skip=True, reason="not enabled")
return
vl: voice_library.VoiceLibrary = voice_library.get_instance()
ready_tts_apis = tts_agent.ready_apis
voices_global = voice_library.voices_for_apis(ready_tts_apis, vl)
voices_scene = voice_library.voices_for_apis(
ready_tts_apis, self.scene.voice_library
)
voices = voices_global + voices_scene
if not voices:
log.debug(
"assign_voice_to_character", skip=True, reason="no voices available"
)
return
voice_candidates = {
voice.id: VoiceCandidate(**voice.model_dump()) for voice in voices
}
for scene_character in self.scene.all_characters:
if scene_character.voice:
voice_candidates[scene_character.voice.id].used = True
async def assign_voice(voice_id: str):
voice = vl.get_voice(voice_id) or self.scene.voice_library.get_voice(
voice_id
)
if not voice:
log.error(
"assign_voice_to_character",
skip=True,
reason="voice not found",
voice_id=voice_id,
)
return
await set_voice(character, voice, auto=True)
await self.log_action(
f"Assigned voice `{voice.label}` to `{character.name}`",
"Assigned voice",
console_only=True,
)
focal_handler = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="assign_voice",
arguments=[focal.Argument(name="voice_id", type="str")],
fn=assign_voice,
),
],
max_calls=1,
character=character,
voices=list(voice_candidates.values()),
scene=self.scene,
narrator_voice=tts_agent.narrator_voice,
)
await focal_handler.request("director.cm-assign-voice")
log.debug("assign_voice_to_character", calls=focal_handler.state.calls)
return focal_handler.state.calls
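
Usage sketch (the call site and character are hypothetical; the keyword arguments are the ones defined on persist_character above): persist a new NPC and let the director pick a TTS voice for it.

import talemate.instance as instance

async def add_innkeeper():
    director = instance.get_agent("director")
    return await director.persist_character(
        name="Maren",                                              # hypothetical character
        content="A weary innkeeper who knows every rumor in town.",
        narrate_entry=True,
        assign_voice=True,   # silently skipped when TTS is disabled or no TTS API is ready
    )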

View File

@@ -231,7 +231,9 @@ class GuideSceneMixin:
if cached_guidance:
if not analysis:
return cached_guidance.get("guidance")
elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]):
elif cached_guidance.get("fp") == self.context_fingerprint(
extra=[analysis]
):
return cached_guidance.get("guidance")
return None
@@ -250,7 +252,7 @@ class GuideSceneMixin:
self.set_scene_states(
**{
key: {
"fp": self.context_fingerpint(extra=[analysis]),
"fp": self.context_fingerprint(extra=[analysis]),
"guidance": guidance,
"analysis_type": analysis_type,
"character": character.name if character else None,

View File

@@ -7,7 +7,7 @@ from talemate.game.engine.nodes.core import (
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
from talemate.character import Character
TYPE_CHOICES.extend(
[
@@ -77,3 +77,90 @@ class PersistCharacter(AgentNode):
)
self.set_output_values({"state": state, "character": character})
@register("agents/director/AssignVoice")
class AssignVoice(AgentNode):
"""
Assigns a voice to a character.
"""
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Assign Voice", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_output("state")
self.add_output("character", socket_type="character")
self.add_output("voice", socket_type="tts/voice")
async def run(self, state: GraphState):
character: "Character" = self.require_input("character")
await self.agent.assign_voice_to_character(character)
voice = character.voice
self.set_output_values({"state": state, "character": character, "voice": voice})
@register("agents/director/LogAction")
class LogAction(AgentNode):
"""
Logs a director action as a DirectorMessage (hidden from the scene view when console_only is set).
"""
_agent_name: ClassVar[str] = "director"
class Fields:
action = PropertyField(
name="action",
type="str",
description="The action to log",
default="",
)
action_description = PropertyField(
name="action_description",
type="str",
description="The description of the action",
default="",
)
console_only = PropertyField(
name="console_only",
type="bool",
description="Whether to log the action to the console only",
default=False,
)
def __init__(self, title="Log Director Action", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("action", socket_type="str")
self.add_input("action_description", socket_type="str")
self.add_input("console_only", socket_type="bool", optional=True)
self.set_property("action", "")
self.set_property("action_description", "")
self.set_property("console_only", False)
self.add_output("state")
async def run(self, state: GraphState):
state = self.require_input("state")
action = self.require_input("action")
action_description = self.require_input("action_description")
console_only = self.normalized_input_value("console_only") or False
await self.agent.log_action(
action=action,
action_description=action_description,
console_only=console_only,
)
self.set_output_values({"state": state})

View File

@@ -5,8 +5,9 @@ from typing import TYPE_CHECKING
from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
from talemate.context import interaction
from talemate.context import interaction, handle_generation_cancelled
from talemate.status import set_loading
from talemate.exceptions import GenerationCancelled
if TYPE_CHECKING:
from talemate.tale_mate import Scene
@@ -45,6 +46,12 @@ class PersistCharacterPayload(pydantic.BaseModel):
content: str = ""
description: str = ""
is_player: bool = False
class AssignVoiceToCharacterPayload(pydantic.BaseModel):
character_name: str
class DirectorWebsocketHandler(Plugin):
"""
@@ -105,7 +112,13 @@ class DirectorWebsocketHandler(Plugin):
async def handle_task_done(task):
if task.exception():
log.error("Error persisting character", error=task.exception())
exc = task.exception()
log.error("Error persisting character", error=exc)
# Handle GenerationCancelled properly to reset cancel_requested flag
if isinstance(exc, GenerationCancelled):
handle_generation_cancelled(exc)
await self.signal_operation_failed("Error persisting character")
else:
self.websocket_handler.queue_put(
@@ -118,3 +131,63 @@ class DirectorWebsocketHandler(Plugin):
await self.signal_operation_done()
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
async def handle_assign_voice_to_character(self, data: dict):
"""
Assign a voice to a character using the director agent
"""
try:
payload = AssignVoiceToCharacterPayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
scene: "Scene" = self.scene
if not scene:
await self.signal_operation_failed("No scene active")
return
character = scene.get_character(payload.character_name)
if not character:
await self.signal_operation_failed(
f"Character '{payload.character_name}' not found"
)
return
character.voice = None
# Add as asyncio task
task = asyncio.create_task(self.director.assign_voice_to_character(character))
async def handle_task_done(task):
if task.exception():
exc = task.exception()
log.error("Error assigning voice to character", error=exc)
# Handle GenerationCancelled properly to reset cancel_requested flag
if isinstance(exc, GenerationCancelled):
handle_generation_cancelled(exc)
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "assign_voice_to_character_failed",
"character_name": payload.character_name,
"error": str(exc),
}
)
await self.signal_operation_failed(
f"Error assigning voice to character: {exc}"
)
else:
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "assign_voice_to_character_done",
"character_name": payload.character_name,
}
)
await self.signal_operation_done()
self.scene.emit_status()
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
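
For reference, a hedged sketch of the message a frontend might send to trigger this handler. Only character_name is confirmed by the payload model above; the envelope keys and the action-to-handler dispatch are assumptions about the Plugin base class.

# Assumed shape -- "type" matching the handler's router and "action" being
# dispatched to handle_assign_voice_to_character are assumptions, not confirmed here.
assign_voice_request = {
    "type": "director",
    "action": "assign_voice_to_character",
    "character_name": "Maren",                 # hypothetical character
}
# On completion the handler queues either "assign_voice_to_character_done" or
# "assign_voice_to_character_failed" back to the frontend, as shown above.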

View File

@@ -6,6 +6,7 @@ import structlog
import talemate.emit.async_signals
import talemate.util as util
from talemate.client import ClientBase
from talemate.prompts import Prompt
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
@@ -86,7 +87,7 @@ class EditorAgent(
RevisionMixin.add_actions(actions)
return actions
def __init__(self, client, **kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.client = client
self.is_enabled = True
self.actions = EditorAgent.init_actions()

View File

@@ -60,5 +60,8 @@ class EditorWebsocketHandler(Plugin):
character=character,
)
revised = await editor.revision_revise(info)
if isinstance(message, CharacterMessage):
if not revised.startswith(character.name + ":"):
revised = f"{character.name}: {revised}"
scene.edit_message(message.id, revised)

View File

@@ -23,10 +23,9 @@ from talemate.agents.base import (
AgentDetail,
set_processing,
)
from talemate.config import load_config
from talemate.config.schema import EmbeddingFunctionPreset
from talemate.context import scene_is_loading, active_scene
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.emit.async_signals as async_signals
from talemate.agents.memory.context import memory_request, MemoryRequest
from talemate.agents.memory.exceptions import (
@@ -107,14 +106,13 @@ class MemoryAgent(Agent):
}
return actions
def __init__(self, scene, **kwargs):
def __init__(self, **kwargs):
self.db = None
self.scene = scene
self.memory_tracker = {}
self.config = load_config()
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
async_signals.get("config.changed").connect(self.on_config_changed)
async_signals.get("client.embeddings_available").connect(
self.on_client_embeddings_available
)
@@ -136,28 +134,29 @@ class MemoryAgent(Agent):
@property
def get_presets(self):
def _label(embedding: dict):
prefix = (
embedding["client"] if embedding["client"] else embedding["embeddings"]
)
if embedding["model"]:
return f"{prefix}: {embedding['model']}"
def _label(embedding: EmbeddingFunctionPreset):
prefix = embedding.client if embedding.client else embedding.embeddings
if embedding.model:
return f"{prefix}: {embedding.model}"
else:
return f"{prefix}"
return [
{"value": k, "label": _label(v)}
for k, v in self.config.get("presets", {}).get("embeddings", {}).items()
for k, v in self.config.presets.embeddings.items()
]
@property
def embeddings_config(self):
_embeddings = self.actions["_config"].config["embeddings"].value
return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {})
return self.config.presets.embeddings.get(_embeddings)
@property
def embeddings(self):
return self.embeddings_config.get("embeddings", "sentence-transformer")
try:
return self.embeddings_config.embeddings
except AttributeError:
return None
@property
def using_openai_embeddings(self):
@@ -181,22 +180,31 @@ class MemoryAgent(Agent):
@property
def embeddings_client(self):
return self.embeddings_config.get("client")
try:
return self.embeddings_config.client
except AttributeError:
return None
@property
def max_distance(self) -> float:
distance = float(self.embeddings_config.get("distance", 1.0))
distance_mod = float(self.embeddings_config.get("distance_mod", 1.0))
distance = float(self.embeddings_config.distance)
distance_mod = float(self.embeddings_config.distance_mod)
return distance * distance_mod
@property
def model(self):
return self.embeddings_config.get("model")
try:
return self.embeddings_config.model
except AttributeError:
return None
@property
def distance_function(self):
return self.embeddings_config.get("distance_function", "l2")
try:
return self.embeddings_config.distance_function
except AttributeError:
return None
@property
def device(self) -> str:
@@ -204,7 +212,10 @@ class MemoryAgent(Agent):
@property
def trust_remote_code(self) -> bool:
return self.embeddings_config.get("trust_remote_code", False)
try:
return self.embeddings_config.trust_remote_code
except AttributeError:
return False
@property
def fingerprint(self) -> str:
@@ -241,7 +252,7 @@ class MemoryAgent(Agent):
if self.using_sentence_transformer_embeddings and not self.model:
self.actions["_config"].config["embeddings"].value = "default"
if not scene or not scene.get_helper("memory"):
if not scene or not scene.active:
return
self.close_db(scene)
@@ -255,7 +266,14 @@ class MemoryAgent(Agent):
self.actions["_config"].config["embeddings"].choices = self.get_presets
return self.actions["_config"].config["embeddings"].choices
def on_config_saved(self, event):
async def fix_broken_embeddings(self):
if not self.embeddings_config:
self.actions["_config"].config["embeddings"].value = "default"
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
async def on_config_changed(self, event):
loop = asyncio.get_running_loop()
openai_key = self.openai_api_key
@@ -263,7 +281,6 @@ class MemoryAgent(Agent):
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
self.config = load_config()
new_presets = self.sync_presets()
if fingerprint != self.fingerprint:
log.warning(
@@ -285,10 +302,13 @@ class MemoryAgent(Agent):
if emit_status:
loop.run_until_complete(self.emit_status())
await self.fix_broken_embeddings()
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
current_embeddings = self.actions["_config"].config["embeddings"].value
if current_embeddings == event.client.embeddings_identifier:
event.seen = True
return
if not self.using_client_api_embeddings or not self.ready:
@@ -304,6 +324,7 @@ class MemoryAgent(Agent):
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
event.seen = True
@set_processing
async def set_db(self):
@@ -837,7 +858,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
return self.config.openai.api_key
@property
def embedding_function(self) -> Callable:

View File

@@ -138,11 +138,6 @@ class NarratorAgent(MemoryRAGMixin, Agent):
),
},
),
"auto_break_repetition": AgentAction(
enabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"content": AgentAction(
enabled=True,
can_be_disabled=False,
@@ -210,7 +205,7 @@ class NarratorAgent(MemoryRAGMixin, Agent):
def __init__(
self,
client: client.TaleMateClient,
client: client.ClientBase | None = None,
**kwargs,
):
self.client = client
@@ -753,9 +748,6 @@ class NarratorAgent(MemoryRAGMixin, Agent):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return True
def set_generation_overrides(self, prompt_param: dict):

View File

@@ -16,7 +16,7 @@ from talemate.scene_message import (
ReinforcementMessage,
)
from talemate.world_state.templates import GenerationOptions
from talemate.client import ClientBase
from talemate.agents.base import (
Agent,
AgentAction,
@@ -34,6 +34,7 @@ from talemate.history import ArchiveEntry
from .analyze_scene import SceneAnalyzationMixin
from .context_investigation import ContextInvestigationMixin
from .layered_history import LayeredHistoryMixin
from .tts_utils import TTSUtilsMixin
if TYPE_CHECKING:
from talemate.tale_mate import Character
@@ -71,6 +72,7 @@ class SummarizeAgent(
ContextInvestigationMixin,
# Needs to be after ContextInvestigationMixin so signals are connected in the right order
SceneAnalyzationMixin,
TTSUtilsMixin,
Agent,
):
"""
@@ -129,7 +131,7 @@ class SummarizeAgent(
ContextInvestigationMixin.add_actions(actions)
return actions
def __init__(self, client, **kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.client = client
self.actions = SummarizeAgent.init_actions()

View File

@@ -16,12 +16,29 @@ from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.agents.context import active_agent
from talemate.agents.base import RagBuildSubInstructionEmission
from contextvars import ContextVar
if TYPE_CHECKING:
from talemate.tale_mate import Character
log = structlog.get_logger()
## CONTEXT
scene_analysis_disabled_context = ContextVar("scene_analysis_disabled", default=False)
class SceneAnalysisDisabled:
"""
Context manager to disable scene analysis during specific agent actions.
"""
def __enter__(self):
self.token = scene_analysis_disabled_context.set(True)
def __exit__(self, _exc_type, _exc_value, _traceback):
scene_analysis_disabled_context.reset(self.token)
talemate.emit.async_signals.register(
"agent.summarization.scene_analysis.before",
@@ -184,6 +201,16 @@ class SceneAnalyzationMixin:
if not self.analyze_scene:
return
try:
if scene_analysis_disabled_context.get():
log.debug(
"on_inject_instructions: scene analysis disabled through context",
emission=emission,
)
return
except LookupError:
pass
analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}")
if not analyze_scene_for_type:
@@ -259,7 +286,14 @@ class SceneAnalyzationMixin:
if not cached_analysis:
return None
fingerprint = self.context_fingerpint()
fingerprint = self.context_fingerprint()
log.debug(
"get_cached_analysis",
fingerprint=fingerprint,
cached_analysis_fp=cached_analysis.get("fp"),
match=cached_analysis.get("fp") == fingerprint,
)
if cached_analysis.get("fp") == fingerprint:
return cached_analysis["guidance"]
@@ -271,7 +305,7 @@ class SceneAnalyzationMixin:
Sets the cached analysis for the given type.
"""
fingerprint = self.context_fingerpint()
fingerprint = self.context_fingerprint()
self.set_scene_states(
**{

View File

@@ -0,0 +1,59 @@
import structlog
from talemate.agents.base import (
set_processing,
)
from talemate.prompts import Prompt
from talemate.status import set_loading
from talemate.util.dialogue import separate_dialogue_from_exposition
log = structlog.get_logger("talemate.agents.summarize.tts_utils")
class TTSUtilsMixin:
"""
Summarizer Mixin for text-to-speech utilities.
"""
@set_loading("Preparing TTS context")
@set_processing
async def markup_context_for_tts(self, text: str) -> str:
"""
Markup the context for text-to-speech.
"""
original_text = text
log.debug("Markup context for TTS", text=text)
# if there are no quotes in the text, there is nothing to separate
if '"' not in text:
return original_text
# here we separate dialogue from exposition into obvious
# segments. It seems to have a positive effect on some LLMs
# returning the complete text.
separate_chunks = separate_dialogue_from_exposition(text)
numbered_chunks = []
for i, chunk in enumerate(separate_chunks):
numbered_chunks.append(f"[{i + 1}] {chunk.text.strip()}")
text = "\n".join(numbered_chunks)
response = await Prompt.request(
"summarizer.markup-context-for-tts",
self.client,
"investigate_1024",
vars={
"text": text,
"max_tokens": self.client.max_token_length,
"scene": self.scene,
},
)
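# the prompt template is expected to wrap its result in <MARKUP> tags;
# if the tags are missing we fall back to the original text below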
try:
response = response.split("<MARKUP>")[1].split("</MARKUP>")[0].strip()
return response
except IndexError:
log.error("Failed to extract markup from response", response=response)
return original_text

View File

@@ -1,670 +0,0 @@
from __future__ import annotations
import asyncio
import base64
import functools
import io
import os
import tempfile
import time
import uuid
from typing import Union
import httpx
import nltk
import pydantic
import structlog
from nltk.tokenize import sent_tokenize
from openai import AsyncOpenAI
import talemate.config as config
import talemate.emit.async_signals
import talemate.instance as instance
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage
from .base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .registry import register
try:
from TTS.api import TTS
except ImportError:
TTS = None
log = structlog.get_logger("talemate.agents.tts") #
if not TTS:
# TTS installation is massive and requires a lot of dependencies
# so we don't want to require it unless the user wants to use it
log.info(
"TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
)
def parse_chunks(text: str) -> list[str]:
"""
Takes a string and splits it into chunks based on punctuation.
In case of an error it will return the original text as a single chunk and
the error will be logged.
"""
try:
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
chunk = chunk.replace("*", "")
if not chunk:
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
except Exception as e:
log.error("chunking error", error=e, text=text)
return [text.replace("__ellipsis__", "...").replace("*", "")]
def clean_quotes(chunk: str):
# if there is an uneven number of quotes, remove the last one if its
# at the end of the chunk. If its in the middle, add a quote to the end
if chunk.count('"') % 2 == 1:
if chunk.endswith('"'):
chunk = chunk[:-1]
else:
chunk += '"'
return chunk
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(clean_quotes(current_chunk))
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(clean_quotes(current_chunk))
return joined_chunks
class Voice(pydantic.BaseModel):
value: str
label: str
class VoiceLibrary(pydantic.BaseModel):
api: str
voices: list[Voice] = pydantic.Field(default_factory=list)
last_synced: float = None
@register()
class TTSAgent(Agent):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Voice"
requires_llm_client = False
essential = False
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if agent:
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
voice.model_dump() for voice in agent.list_voices_sync()
]
return config_options
def __init__(self, **kwargs):
self.is_enabled = False #
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
try:
nltk.download("punkt", quiet=True)
except Exception as e:
log.error("nltk download error", error=e)
except Exception as e:
log.error("nltk find error", error=e)
self.voices = {
"elevenlabs": VoiceLibrary(api="elevenlabs"),
"tts": VoiceLibrary(api="tts"),
"openai": VoiceLibrary(api="openai"),
}
self.config = config.load_config()
self.playback_done_event = asyncio.Event()
self.preselect_voice = None
self.actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"api": AgentActionConfig(
type="text",
choices=[
{"value": "tts", "label": "TTS (Local)"},
{"value": "elevenlabs", "label": "Eleven Labs"},
{"value": "openai", "label": "OpenAI"},
],
value="tts",
label="API",
description="Which TTS API to use",
onchange="emit",
),
"voice_id": AgentActionConfig(
type="text",
value="default",
label="Narrator Voice",
description="Voice ID/Name to use for TTS",
choices=[],
),
"generate_for_player": AgentActionConfig(
type="bool",
value=False,
label="Generate for player",
description="Generate audio for player messages",
),
"generate_for_npc": AgentActionConfig(
type="bool",
value=True,
label="Generate for NPCs",
description="Generate audio for NPC messages",
),
"generate_for_narration": AgentActionConfig(
type="bool",
value=True,
label="Generate for narration",
description="Generate audio for narration messages",
),
"generate_chunks": AgentActionConfig(
type="bool",
value=False,
label="Split generation",
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
),
},
),
"openai": AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
condition=AgentActionConditional(
attribute="_config.config.api", value="openai"
),
label="OpenAI",
config={
"model": AgentActionConfig(
type="text",
value="tts-1",
choices=[
{"value": "tts-1", "label": "TTS 1"},
{"value": "tts-1-hd", "label": "TTS 1 HD"},
],
label="Model",
description="TTS model to use",
),
},
),
}
self.actions["_config"].model_dump()
handlers["config_saved"].connect(self.on_config_saved)
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def not_ready_reason(self) -> str:
"""
Returns a string explaining why the agent is not ready
"""
if self.ready:
return ""
if self.api == "tts":
if not TTS:
return "TTS not installed"
elif self.requires_token and not self.token:
return "No API token"
elif not self.default_voice_id:
return "No voice selected"
@property
def agent_details(self):
details = {
"api": AgentDetail(
icon="mdi-server-outline",
value=self.api_label,
description="The backend to use for TTS",
).model_dump(),
}
if self.ready and self.enabled:
details["voice"] = AgentDetail(
icon="mdi-account-voice",
value=self.voice_id_to_label(self.default_voice_id) or "",
description="The voice to use for TTS",
color="info",
).model_dump()
elif self.enabled:
details["error"] = AgentDetail(
icon="mdi-alert",
value=self.not_ready_reason,
description=self.not_ready_reason,
color="error",
).model_dump()
return details
@property
def api(self):
return self.actions["_config"].config["api"].value
@property
def api_label(self):
choices = self.actions["_config"].config["api"].choices
api = self.api
for choice in choices:
if choice["value"] == api:
return choice["label"]
return api
@property
def token(self):
api = self.api
return self.config.get(api, {}).get("api_key")
@property
def default_voice_id(self):
return self.actions["_config"].config["voice_id"].value
@property
def requires_token(self):
return self.api != "tts"
@property
def ready(self):
if self.api == "tts":
if not TTS:
return False
return True
return (not self.requires_token or self.token) and self.default_voice_id
@property
def status(self):
if not self.enabled:
return "disabled"
if self.ready:
if getattr(self, "processing_bg", 0) > 0:
return "busy_bg" if not getattr(self, "processing", False) else "busy"
return "active" if not getattr(self, "processing", False) else "busy"
if self.requires_token and not self.token:
return "error"
if self.api == "tts":
if not TTS:
return "error"
return "uninitialized"
@property
def max_generation_length(self):
if self.api == "elevenlabs":
return 1024
elif self.api == "coqui":
return 250
return 250
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
async def apply_config(self, *args, **kwargs):
try:
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
except KeyError:
api = self.api
api_changed = api != self.api
# log.debug(
# "apply_config",
# api=api,
# api_changed=api != self.api,
# current_api=self.api,
# args=args,
# kwargs=kwargs,
# )
try:
self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
"value"
]
except KeyError:
self.preselect_voice = self.default_voice_id
await super().apply_config(*args, **kwargs)
if api_changed:
try:
self.actions["_config"].config["voice_id"].value = (
self.voices[api].voices[0].value
)
except IndexError:
self.actions["_config"].config["voice_id"].value = ""
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_new_message").connect(
self.on_game_loop_new_message
)
def on_config_saved(self, event):
config = event.data
self.config = config
instance.emit_agent_status(self.__class__, self)
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if not self.enabled or not self.ready:
return
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
return
if (
isinstance(emission.message, NarratorMessage)
and not self.actions["_config"].config["generate_for_narration"].value
):
return
if isinstance(emission.message, CharacterMessage):
if (
emission.message.source == "player"
and not self.actions["_config"].config["generate_for_player"].value
):
return
elif (
emission.message.source == "ai"
and not self.actions["_config"].config["generate_for_npc"].value
):
return
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
else:
character_prefix = ""
log.info(
"reactive tts", message=emission.message, character_prefix=character_prefix
)
await self.generate(str(emission.message).replace(character_prefix + ": ", ""))
def voice(self, voice_id: str) -> Union[Voice, None]:
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice
return None
def voice_id_to_label(self, voice_id: str):
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice.label
return None
def list_voices_sync(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.list_voices())
async def list_voices(self):
if self.requires_token and not self.token:
return []
library = self.voices[self.api]
# TODO: allow re-syncing voices
if library.last_synced:
return library.voices
list_fn = getattr(self, f"_list_voices_{self.api}")
log.info("Listing voices", api=self.api)
library.voices = await list_fn()
library.last_synced = time.time()
if self.preselect_voice:
if self.voice(self.preselect_voice):
self.actions["_config"].config["voice_id"].value = self.preselect_voice
self.preselect_voice = None
# if the current voice cannot be found, reset it
if not self.voice(self.default_voice_id):
self.actions["_config"].config["voice_id"].value = ""
# set loading to false
return library.voices
@set_processing
async def generate(self, text: str):
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
generate_fn = getattr(self, f"_generate_{self.api}")
if self.actions["_config"].config["generate_chunks"].value:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks)
else:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks, chunk_size=self.max_generation_length)
# Start generating audio chunks in the background
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
await self.set_background_processing(generation_task)
# Wait for both tasks to complete
# await asyncio.gather(generation_task)
async def generate_chunks(self, generate_fn, chunks):
for chunk in chunks:
chunk = chunk.replace("*", "").strip()
log.info("Generating audio", api=self.api, chunk=chunk)
audio_data = await generate_fn(chunk)
self.play_audio(audio_data)
def play_audio(self, audio_data):
# play audio through the python audio player
# play(audio_data)
emit(
"audio_queue",
data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
)
self.playback_done_event.set() # Signal that playback is finished
# LOCAL
async def _generate_tts(self, text: str) -> Union[bytes, None]:
if not TTS:
return
tts_config = self.config.get("tts", {})
model = tts_config.get("model")
device = tts_config.get("device", "cpu")
log.debug("tts local", model=model, device=device)
if not hasattr(self, "tts_instance"):
self.tts_instance = TTS(model).to(device)
tts = self.tts_instance
loop = asyncio.get_event_loop()
voice = self.voice(self.default_voice_id)
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(
None,
functools.partial(
tts.tts_to_file,
text=text,
speaker_wav=voice.value,
language="en",
file_path=file_path,
),
)
# tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
with open(file_path, "rb") as f:
return f.read()
async def _list_voices_tts(self) -> dict[str, str]:
return [
Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
]
# ELEVENLABS
async def _generate_elevenlabs(
self, text: str, chunk_size: int = 1024
) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
async with httpx.AsyncClient() as client:
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.default_voice_id}"
headers = {
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": api_key,
}
data = {
"text": text,
"model_id": self.config.get("elevenlabs", {}).get("model"),
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
}
response = await client.post(url, json=data, headers=headers, timeout=300)
if response.status_code == 200:
bytes_io = io.BytesIO()
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
# Put the audio data in the queue for playback
return bytes_io.getvalue()
else:
log.error(f"Error generating audio: {response.text}")
async def _list_voices_elevenlabs(self) -> dict[str, str]:
url_voices = "https://api.elevenlabs.io/v1/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Accept": "application/json",
"xi-api-key": self.token,
}
response = await client.get(
url_voices, headers=headers, params={"per_page": 1000}
)
speakers = response.json()["voices"]
voices.extend(
[
Voice(value=speaker["voice_id"], label=speaker["name"])
for speaker in speakers
]
)
# sort by name
voices.sort(key=lambda x: x.label)
return voices
# OPENAI
async def _generate_openai(self, text: str, chunk_size: int = 1024):
client = AsyncOpenAI(api_key=self.openai_api_key)
model = self.actions["openai"].config["model"].value
response = await client.audio.speech.create(
model=model, voice=self.default_voice_id, input=text
)
bytes_io = io.BytesIO()
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
# Put the audio data in the queue for playback
return bytes_io.getvalue()
async def _list_voices_openai(self) -> dict[str, str]:
return [
Voice(value="alloy", label="Alloy"),
Voice(value="echo", label="Echo"),
Voice(value="fable", label="Fable"),
Voice(value="onyx", label="Onyx"),
Voice(value="nova", label="Nova"),
Voice(value="shimmer", label="Shimmer"),
]

View File

@@ -0,0 +1,995 @@
from __future__ import annotations
import asyncio
import base64
import re
import traceback
from typing import TYPE_CHECKING
import uuid
from collections import deque
import structlog
from nltk.tokenize import sent_tokenize
import talemate.util.dialogue as dialogue_utils
import talemate.emit.async_signals as async_signals
import talemate.instance as instance
from talemate.ux.schema import Note
from talemate.emit import emit
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import (
CharacterMessage,
NarratorMessage,
ContextInvestigationMessage,
)
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
AgentDetail,
AgentActionNote,
set_processing,
)
from talemate.agents.registry import register
from .schema import (
APIStatus,
Voice,
VoiceLibrary,
GenerationContext,
Chunk,
VoiceGenerationEmission,
)
from .providers import provider
import talemate.agents.tts.voice_library as voice_library
from .elevenlabs import ElevenLabsMixin
from .openai import OpenAIMixin
from .google import GoogleMixin
from .kokoro import KokoroMixin
from .chatterbox import ChatterboxMixin
from .websocket_handler import TTSWebsocketHandler
from .f5tts import F5TTSMixin
import talemate.agents.tts.nodes as tts_nodes # noqa: F401
if TYPE_CHECKING:
from talemate.character import Character, VoiceChangedEvent
from talemate.agents.summarize import SummarizeAgent
from talemate.game.engine.nodes.scene import SceneLoopEvent
log = structlog.get_logger("talemate.agents.tts")
HOT_SWAP_NOTIFICATION_TIME = 60
VOICE_LIBRARY_NOTE = "Voices are not managed here, but in the voice library which can be accessed through the Talemate application bar at the top. When disabling/enabling APIS, close and open this window to refresh the choices."
async_signals.register(
"agent.tts.prepare.before",
"agent.tts.prepare.after",
"agent.tts.generate.before",
"agent.tts.generate.after",
)
def parse_chunks(text: str) -> list[str]:
"""
Takes a string and splits it into chunks based on punctuation.
In case of an error it will return the original text as a single chunk and
the error will be logged.
"""
try:
text = text.replace("*", "")
# ensure sentence terminators are before quotes
# otherwise the beginning of dialog will bleed into narration
text = re.sub(r'([^.?!]+) "', r'\1. "', text)
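# e.g. 'She nodded "Hello there."' becomes 'She nodded. "Hello there."'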
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
if not chunk.strip():
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
except Exception as e:
log.error("chunking error", error=e, text=text)
return [text.replace("__ellipsis__", "...").replace("*", "")]
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(current_chunk)
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(current_chunk)
return joined_chunks
@register()
class TTSAgent(
ElevenLabsMixin,
OpenAIMixin,
GoogleMixin,
KokoroMixin,
ChatterboxMixin,
F5TTSMixin,
Agent,
):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Voice"
requires_llm_client = False
essential = False
# websocket handler for frontend voice library management
websocket_handler = TTSWebsocketHandler
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if not agent:
return config_options
narrator_voice_id = config_options["actions"]["_config"]["config"][
"narrator_voice_id"
]
narrator_voice_id["choices"] = cls.narrator_voice_id_choices(agent)
return config_options
@classmethod
def narrator_voice_id_choices(cls, agent: "TTSAgent") -> list[dict[str, str]]:
choices = voice_library.voices_for_apis(agent.ready_apis, agent.voice_library)
choices.sort(key=lambda x: x.label)
return [
{
"label": f"{voice.label} ({voice.provider})",
"value": voice.id,
}
for voice in choices
]
@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"apis": AgentActionConfig(
type="flags",
value=[
"kokoro",
],
label="Enabled APIs",
description="APIs to use for TTS",
choices=[],
),
"narrator_voice_id": AgentActionConfig(
type="autocomplete",
value="kokoro:am_adam",
label="Narrator Voice",
description="Voice to use for narration",
choices=[],
note=VOICE_LIBRARY_NOTE,
),
"speaker_separation": AgentActionConfig(
type="text",
value="simple",
label="Speaker separation",
description="How to separate speaker dialogue from exposition",
choices=[
{"label": "No separation", "value": "none"},
{"label": "Simple", "value": "simple"},
{"label": "Mixed", "value": "mixed"},
{"label": "AI assisted", "value": "ai_assisted"},
],
note_on_value={
"none": AgentActionNote(
type="primary",
text="Character messages will be voiced entirely by the character's voice with a fallback to the narrator voice if the character has no voice selecte. Narrator messages will be voiced exclusively by the narrator voice.",
),
"simple": AgentActionNote(
type="primary",
text="Exposition and dialogue will be separated in character messages. Narrator messages will be voiced exclusively by the narrator voice. This means",
),
"mixed": AgentActionNote(
type="primary",
text="A mix of `simple` and `ai_assisted`. Character messages are separated into narrator and the character's voice. Narrator messages that have dialogue are analyzed by the Summarizer agent to determine the appropriate speaker(s).",
),
"ai_assisted": AgentActionNote(
type="primary",
text="Appropriate speaker separation will be attempted based on the content of the message with help from the Summarizer agent. This sends an extra prompt to the LLM to determine the appropriate speaker(s).",
),
},
),
"generate_for_player": AgentActionConfig(
type="bool",
value=False,
label="Auto-generate for player",
description="Generate audio for player messages",
),
"generate_for_npc": AgentActionConfig(
type="bool",
value=True,
label="Auto-generate for AI characters",
description="Generate audio for NPC messages",
),
"generate_for_narration": AgentActionConfig(
type="bool",
value=True,
label="Auto-generate for narration",
description="Generate audio for narration messages",
),
"generate_for_context_investigation": AgentActionConfig(
type="bool",
value=True,
label="Auto-generate for context investigation",
description="Generate audio for context investigation messages",
),
},
),
}
KokoroMixin.add_actions(actions)
ChatterboxMixin.add_actions(actions)
GoogleMixin.add_actions(actions)
ElevenLabsMixin.add_actions(actions)
OpenAIMixin.add_actions(actions)
F5TTSMixin.add_actions(actions)
return actions
def __init__(self, **kwargs):
self.is_enabled = False # tts agent is disabled by default
self.actions = TTSAgent.init_actions()
self.playback_done_event = asyncio.Event()
# Queue management for voice generation
# Each queue instance gets a unique id so it can later be referenced
# (e.g. for cancellation of all remaining items).
# Only one queue can be active at a time. New generation requests that
# arrive while a queue is processing will be appended to the same
# queue. Once the queue is fully processed it is discarded and a new
# one will be created for subsequent generation requests.
# Queue now holds individual (context, chunk) pairs so interruption can
# happen between chunks even when a single context produced many.
self._generation_queue: deque[tuple[GenerationContext, Chunk]] = deque()
self._queue_id: str | None = None
self._queue_task: asyncio.Task | None = None
self._queue_lock = asyncio.Lock()
# general helpers
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def voice_library(self) -> VoiceLibrary:
return voice_library.get_instance()
# config helpers
@property
def narrator_voice_id(self) -> str:
return self.actions["_config"].config["narrator_voice_id"].value
@property
def generate_for_player(self) -> bool:
return self.actions["_config"].config["generate_for_player"].value
@property
def generate_for_npc(self) -> bool:
return self.actions["_config"].config["generate_for_npc"].value
@property
def generate_for_narration(self) -> bool:
return self.actions["_config"].config["generate_for_narration"].value
@property
def generate_for_context_investigation(self) -> bool:
return (
self.actions["_config"].config["generate_for_context_investigation"].value
)
@property
def speaker_separation(self) -> str:
return self.actions["_config"].config["speaker_separation"].value
@property
def apis(self) -> list[str]:
return self.actions["_config"].config["apis"].value
@property
def all_apis(self) -> list[str]:
return [api["value"] for api in self.actions["_config"].config["apis"].choices]
@property
def agent_details(self):
details = {}
self.actions["_config"].config[
"narrator_voice_id"
].choices = self.narrator_voice_id_choices(self)
if not self.enabled:
return details
used_apis: set[str] = set()
used_disabled_apis: set[str] = set()
if self.narrator_voice:
#
label = self.narrator_voice.label
color = "primary"
used_apis.add(self.narrator_voice.provider)
if not self.api_enabled(self.narrator_voice.provider):
used_disabled_apis.add(self.narrator_voice.provider)
if not self.api_ready(self.narrator_voice.provider):
color = "error"
details["narrator_voice"] = AgentDetail(
icon="mdi-script-text",
value=label,
description="Default voice",
color=color,
).model_dump()
scene = getattr(self, "scene", None)
if scene:
for character in scene.characters:
if character.voice:
label = character.voice.label
color = "primary"
used_apis.add(character.voice.provider)
if not self.api_enabled(character.voice.provider):
used_disabled_apis.add(character.voice.provider)
if not self.api_ready(character.voice.provider):
color = "error"
details[f"{character.name}_voice"] = AgentDetail(
icon="mdi-account-voice",
value=f"{character.name}",
description=f"{character.name}'s voice: {label} ({character.voice.provider})",
color=color,
).model_dump()
for api in used_disabled_apis:
details[f"{api}_disabled"] = AgentDetail(
icon="mdi-alert-circle",
value=f"{api} disabled",
description=f"{api} disabled - at least one voice is attempting to use this api but is not enabled",
color="error",
).model_dump()
for api in used_apis:
fn = getattr(self, f"{api}_agent_details", None)
if fn:
details.update(fn)
return details
@property
def status(self):
if not self.enabled:
return "disabled"
if self.ready:
if getattr(self, "processing_bg", 0) > 0:
return "busy_bg" if not getattr(self, "processing", False) else "busy"
return "idle" if not getattr(self, "processing", False) else "busy"
return "uninitialized"
@property
def narrator_voice(self) -> Voice | None:
return self.voice_library.get_voice(self.narrator_voice_id)
@property
def api_status(self) -> list[APIStatus]:
api_status: list[APIStatus] = []
for api in self.all_apis:
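# each API mixin may expose optional attributes such as <api>_not_configured_reason,
# <api>_info or <api>_supports_mixing; missing attributes simply fall back to the
# getattr defaults below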
not_configured_reason = getattr(self, f"{api}_not_configured_reason", None)
not_configured_action = getattr(self, f"{api}_not_configured_action", None)
api_info: str | None = getattr(self, f"{api}_info", None)
messages: list[Note] = []
if not_configured_reason:
messages.append(
Note(
text=not_configured_reason,
color="error",
icon="mdi-alert-circle-outline",
actions=[not_configured_action]
if not_configured_action
else None,
)
)
if api_info:
messages.append(
Note(
text=api_info.strip(),
color="muted",
icon="mdi-information-outline",
)
)
_status = APIStatus(
api=api,
enabled=self.api_enabled(api),
ready=self.api_ready(api),
configured=self.api_configured(api),
messages=messages,
supports_mixing=getattr(self, f"{api}_supports_mixing", False),
provider=provider(api),
default_model=getattr(self, f"{api}_model", None),
model_choices=getattr(self, f"{api}_model_choices", []),
)
api_status.append(_status)
# order by api
api_status.sort(key=lambda x: x.api)
return api_status
# events
def connect(self, scene):
super().connect(scene)
async_signals.get("game_loop_new_message").connect(
self.on_game_loop_new_message
)
async_signals.get("voice_library.update.after").connect(
self.on_voice_library_update
)
async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init)
async_signals.get("character.voice_changed").connect(
self.on_character_voice_changed
)
async def on_scene_loop_init(self, event: "SceneLoopEvent"):
if not self.enabled or not self.ready or not self.generate_for_narration:
return
if self.scene.environment == "creative":
return
content_messages = self.scene.last_message_of_type(
["character", "narrator", "context_investigation"]
)
if content_messages:
# we already have a history, so we don't need to generate TTS for the intro
return
await self.generate(self.scene.get_intro(), character=None)
async def on_voice_library_update(self, voice_library: VoiceLibrary):
log.debug("Voice library updated - refreshing narrator voice choices")
self.actions["_config"].config[
"narrator_voice_id"
].choices = self.narrator_voice_id_choices(self)
await self.emit_status()
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if self.scene.environment == "creative":
return
character: Character | None = None
if not self.enabled or not self.ready:
return
if not isinstance(
emission.message,
(CharacterMessage, NarratorMessage, ContextInvestigationMessage),
):
return
if (
isinstance(emission.message, NarratorMessage)
and not self.generate_for_narration
):
return
if (
isinstance(emission.message, ContextInvestigationMessage)
and not self.generate_for_context_investigation
):
return
if isinstance(emission.message, CharacterMessage):
if emission.message.source == "player" and not self.generate_for_player:
return
elif emission.message.source == "ai" and not self.generate_for_npc:
return
character = self.scene.get_character(emission.message.character_name)
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
text_to_generate = str(emission.message).replace(
character_prefix + ": ", ""
)
elif isinstance(emission.message, ContextInvestigationMessage):
character_prefix = ""
text_to_generate = (
emission.message.message
) # Use just the message content, not the title prefix
else:
character_prefix = ""
text_to_generate = str(emission.message)
log.info(
"reactive tts", message=emission.message, character_prefix=character_prefix
)
await self.generate(
text_to_generate,
character=character,
message=emission.message,
)
async def on_character_voice_changed(self, event: "VoiceChangedEvent"):
log.debug(
"Character voice changed", character=event.character, voice=event.voice
)
await self.emit_status()
# voice helpers
@property
def ready_apis(self) -> list[str]:
"""
Returns a list of apis that are ready
"""
return [api for api in self.apis if self.api_ready(api)]
@property
def used_apis(self) -> list[str]:
"""
Returns a list of apis that are in use
An API is in use if the narrator voice belongs to it or if any of the active characters in the scene use a voice from it.
"""
return [api for api in self.apis if self.api_used(api)]
def api_enabled(self, api: str) -> bool:
"""
Returns whether the api is currently in the .apis list, which means it is enabled.
"""
return api in self.apis
def api_ready(self, api: str) -> bool:
"""
Returns whether the api is ready.
The api must be enabled and configured.
"""
if not self.api_enabled(api):
return False
return self.api_configured(api)
def api_configured(self, api: str) -> bool:
return getattr(self, f"{api}_configured", True)
def api_used(self, api: str) -> bool:
"""
Returns whether the narrator or any of the active characters in the scene
use a voice from the given api
Args:
api (str): The api to check
Returns:
bool: Whether the api is in use
"""
if self.narrator_voice and self.narrator_voice.provider == api:
return True
if not getattr(self, "scene", None):
return False
for character in self.scene.characters:
if not character.voice:
continue
voice = self.voice_library.get_voice(character.voice.id)
if voice and voice.provider == api:
return True
return False
def use_ai_assisted_speaker_separation(
self,
text: str,
message: CharacterMessage
| NarratorMessage
| ContextInvestigationMessage
| None,
) -> bool:
"""
Returns whether the ai assisted speaker separation should be used for the given text.
"""
try:
if not message and '"' not in text:
return False
if not message and '"' in text:
return self.speaker_separation in ["ai_assisted", "mixed"]
if message.source == "player":
return False
if self.speaker_separation == "ai_assisted":
return True
if (
isinstance(message, NarratorMessage)
and self.speaker_separation == "mixed"
):
return True
return False
except Exception as e:
log.error(
"Error using ai assisted speaker separation",
error=e,
traceback=traceback.format_exc(),
)
return False
# tts markup cache
async def get_tts_markup_cache(self, text: str) -> str | None:
"""
Returns the cached tts markup for the given text.
"""
fp = hash(text)
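# note: hash() of a str is salted per interpreter process, so a fingerprint
# cached in an earlier session will not match after a restart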
cached_markup = self.get_scene_state("tts_markup_cache")
if cached_markup and cached_markup.get("fp") == fp:
return cached_markup.get("markup")
return None
async def set_tts_markup_cache(self, text: str, markup: str):
fp = hash(text)
self.set_scene_states(
tts_markup_cache={
"fp": fp,
"markup": markup,
}
)
# generation
@set_processing
async def generate(
self,
text: str,
character: Character | None = None,
force_voice: Voice | None = None,
message: CharacterMessage | NarratorMessage | None = None,
):
"""
Public entry-point for voice generation.
The actual audio generation happens sequentially inside a single
background queue. If a queue is currently active, we simply append the
new request to it; if not, we create a new queue (with its own unique
id) and start processing.
"""
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
summarizer: "SummarizeAgent" = instance.get_agent("summarizer")
context = GenerationContext(voice_id=self.narrator_voice_id)
character_voice: Voice = force_voice or self.narrator_voice
if character and character.voice:
voice = character.voice
if voice and self.api_ready(voice.provider):
character_voice = voice
else:
log.warning(
"Character voice not available",
character=character.name,
voice=character.voice,
)
log.debug("Voice routing", character=character, voice=character_voice)
# initial chunking by separating dialogue from exposition
chunks: list[Chunk] = []
if self.speaker_separation != "none":
if self.use_ai_assisted_speaker_separation(text, message):
markup = await self.get_tts_markup_cache(text)
if not markup:
log.debug("No markup cache found, generating markup")
markup = await summarizer.markup_context_for_tts(text)
await self.set_tts_markup_cache(text, markup)
else:
log.debug("Using markup cache")
# Use the new markup parser for AI-assisted format
dlg_chunks = dialogue_utils.parse_tts_markup(markup)
else:
# Use the original parser for non-AI-assisted format
dlg_chunks = dialogue_utils.separate_dialogue_from_exposition(text)
for _dlg_chunk in dlg_chunks:
_voice = (
character_voice
if _dlg_chunk.type == "dialogue"
else self.narrator_voice
)
if _dlg_chunk.speaker is not None:
# speaker name has been identified
_character = self.scene.get_character(_dlg_chunk.speaker)
log.debug(
"Identified speaker",
speaker=_dlg_chunk.speaker,
character=_character,
)
if (
_character
and _character.voice
and self.api_ready(_character.voice.provider)
):
log.debug(
"Using character voice",
character=_character.name,
voice=_character.voice,
)
_voice = _character.voice
_api: str = _voice.provider if _voice else self.api
chunk = Chunk(
api=_api,
voice=Voice(**_voice.model_dump()),
model=_voice.provider_model,
generate_fn=getattr(self, f"{_api}_generate"),
prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
character_name=character.name if character else None,
text=[_dlg_chunk.text],
type=_dlg_chunk.type,
message_id=message.id if message else None,
)
chunks.append(chunk)
else:
_voice = character_voice if character else self.narrator_voice
_api: str = _voice.provider if _voice else self.api
chunks = [
Chunk(
api=_api,
voice=Voice(**_voice.model_dump()),
model=_voice.provider_model,
generate_fn=getattr(self, f"{_api}_generate"),
prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
character_name=character.name if character else None,
text=[text],
type="dialogue" if character else "exposition",
message_id=message.id if message else None,
)
]
# second chunking by splitting into chunks of max_generation_length
for chunk in chunks:
api_chunk_size = getattr(self, f"{chunk.api}_chunk_size", 0)
log.debug("chunking", api=chunk.api, api_chunk_size=api_chunk_size)
_text = []
max_generation_length = getattr(self, f"{chunk.api}_max_generation_length")
if api_chunk_size > 0:
max_generation_length = min(max_generation_length, api_chunk_size)
for _chunk_text in chunk.text:
if len(_chunk_text) <= max_generation_length:
_text.append(_chunk_text)
continue
_parsed = parse_chunks(_chunk_text)
_joined = rejoin_chunks(_parsed, chunk_size=max_generation_length)
_text.extend(_joined)
log.debug("chunked for size", before=chunk.text, after=_text)
chunk.text = _text
context.chunks = chunks
# Enqueue each chunk individually for fine-grained interruptibility
async with self._queue_lock:
if self._queue_id is None:
self._queue_id = str(uuid.uuid4())
for chunk in context.chunks:
self._generation_queue.append((context, chunk))
# Start processing task if needed
if self._queue_task is None or self._queue_task.done():
self._queue_task = asyncio.create_task(
self._process_queue(self._queue_id)
)
log.debug(
"tts queue enqueue",
queue_id=self._queue_id,
total_items=len(self._generation_queue),
)
# The caller doesn't need to wait for the queue to finish; it runs in
# the background. We still register the task with Talemate's
# background-processing tracking so that UI can reflect activity.
await self.set_background_processing(self._queue_task)
# ---------------------------------------------------------------------
# Queue helpers
# ---------------------------------------------------------------------
async def _process_queue(self, queue_id: str):
"""Sequentially processes all GenerationContext objects in the queue.
Once the last context has been processed the queue state is reset so a
future generation call will create a new queue (and therefore a new
id). The *queue_id* argument allows us to later add cancellation logic
that can target a specific queue instance.
"""
try:
while True:
async with self._queue_lock:
if not self._generation_queue:
break
context, chunk = self._generation_queue.popleft()
log.debug(
"tts queue dequeue",
queue_id=queue_id,
total_items=len(self._generation_queue),
chunk_type=chunk.type,
)
# Process outside lock so other coroutines can enqueue
await self._generate_chunk(chunk, context)
except Exception as e:
log.error(
"Error processing queue", error=e, traceback=traceback.format_exc()
)
finally:
# Clean up queue state after finishing (or on cancellation)
async with self._queue_lock:
if queue_id == self._queue_id:
self._queue_id = None
self._queue_task = None
self._generation_queue.clear()
# Public helper so external code (e.g. later cancellation UI) can find the current queue id
def current_queue_id(self) -> str | None:
return self._queue_id
async def _generate_chunk(self, chunk: Chunk, context: GenerationContext):
"""Generate audio for a single chunk (all its sub-chunks)."""
for _chunk in chunk.sub_chunks:
if not _chunk.cleaned_text.strip():
continue
emission: VoiceGenerationEmission = VoiceGenerationEmission(
chunk=_chunk, context=context
)
if _chunk.prepare_fn:
await async_signals.get("agent.tts.prepare.before").send(emission)
await _chunk.prepare_fn(_chunk)
await async_signals.get("agent.tts.prepare.after").send(emission)
log.info(
"Generating audio",
api=chunk.api,
text=_chunk.cleaned_text,
parameters=_chunk.voice.parameters,
prepare_fn=_chunk.prepare_fn,
)
await async_signals.get("agent.tts.generate.before").send(emission)
try:
emission.wav_bytes = await _chunk.generate_fn(_chunk, context)
except Exception as e:
log.error("Error generating audio", error=e, chunk=_chunk)
continue
await async_signals.get("agent.tts.generate.after").send(emission)
self.play_audio(emission.wav_bytes, chunk.message_id)
await asyncio.sleep(0.1)
# Deprecated: kept for backward compatibility but no longer used.
async def generate_chunks(self, context: GenerationContext):
for chunk in context.chunks:
await self._generate_chunk(chunk, context)
def play_audio(self, audio_data, message_id: int | None = None):
# play audio through the websocket (browser)
audio_data_encoded: str = base64.b64encode(audio_data).decode("utf-8")
emit(
"audio_queue",
data={"audio_data": audio_data_encoded, "message_id": message_id},
)
self.playback_done_event.set() # Signal that playback is finished
async def stop_and_clear_queue(self):
"""Cancel any ongoing generation and clear the pending queue.
This is triggered by UI actions that request immediate stop of TTS
synthesis and playback. It cancels the background task (if still
running) and clears all queued items in a thread-safe manner.
"""
async with self._queue_lock:
# Clear all queued items
self._generation_queue.clear()
# Cancel the background task if it is still running
if self._queue_task and not self._queue_task.done():
self._queue_task.cancel()
# Reset queue identifiers/state
self._queue_id = None
self._queue_task = None
# Ensure downstream components know playback is finished
self.playback_done_event.set()

View File

@@ -0,0 +1,317 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic
import torch
# Lazy imports for heavy dependencies
def _import_heavy_deps():
global ta, ChatterboxTTS
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
CUDA_AVAILABLE = torch.cuda.is_available()
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
AgentDetail,
)
from talemate.ux.schema import Field
from .schema import Voice, Chunk, GenerationContext, VoiceProvider, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
from .providers import register, provider
from .util import voice_is_talemate_asset
log = structlog.get_logger("talemate.agents.tts.chatterbox")
add_default_voices(
[
Voice(
label="Eva",
provider="chatterbox",
provider_id="tts/voice/chatterbox/eva.wav",
tags=["female", "calm", "mature", "thoughtful"],
),
Voice(
label="Lisa",
provider="chatterbox",
provider_id="tts/voice/chatterbox/lisa.wav",
tags=["female", "energetic", "young"],
),
Voice(
label="Adam",
provider="chatterbox",
provider_id="tts/voice/chatterbox/adam.wav",
tags=["male", "calm", "mature", "thoughtful", "deep"],
),
Voice(
label="Bradford",
provider="chatterbox",
provider_id="tts/voice/chatterbox/bradford.wav",
tags=["male", "calm", "mature", "thoughtful", "deep"],
),
Voice(
label="Julia",
provider="chatterbox",
provider_id="tts/voice/chatterbox/julia.wav",
tags=["female", "calm", "mature"],
),
Voice(
label="Zoe",
provider="chatterbox",
provider_id="tts/voice/chatterbox/zoe.wav",
tags=["female"],
),
Voice(
label="William",
provider="chatterbox",
provider_id="tts/voice/chatterbox/william.wav",
tags=["male", "young"],
),
]
)
CHATTERBOX_INFO = """
Chatterbox is a local text to speech model.
The voice id is the path to the .wav file for the voice.
The path can be relative to the Talemate root directory, and you can put new *.wav samples
in the `tts/voice/chatterbox` directory. You can also load files from elsewhere, as long as the path is accessible to the Talemate backend.
First generation will download the models (2.13GB + 1.06GB).
Uses about 4GB of VRAM.
"""
@register()
class ChatterboxProvider(VoiceProvider):
name: str = "chatterbox"
allow_model_override: bool = False
allow_file_upload: bool = True
upload_file_types: list[str] = ["audio/wav"]
voice_parameters: list[Field] = [
Field(
name="exaggeration",
type="number",
label="Exaggeration level",
value=0.5,
min=0.25,
max=2.0,
step=0.05,
),
Field(
name="cfg_weight",
type="number",
label="CFG/Pace",
value=0.5,
min=0.2,
max=1.0,
step=0.1,
),
Field(
name="temperature",
type="number",
label="Temperature",
value=0.8,
min=0.05,
max=5.0,
step=0.05,
),
]
class ChatterboxInstance(pydantic.BaseModel):
model: "ChatterboxTTS"
device: str
class Config:
arbitrary_types_allowed = True
class ChatterboxMixin:
"""
Chatterbox agent mixin for local text to speech.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "chatterbox",
"label": "Chatterbox (Local)",
"help": "Chatterbox is a local text to speech model.",
}
)
actions["chatterbox"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="Chatterbox",
description="Chatterbox is a local text to speech model.",
config={
"device": AgentActionConfig(
type="text",
value="cuda" if CUDA_AVAILABLE else "cpu",
label="Device",
choices=[
{"value": "cpu", "label": "CPU"},
{"value": "cuda", "label": "CUDA"},
],
description="Device to use for TTS",
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=256,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@property
def chatterbox_configured(self) -> bool:
return True
@property
def chatterbox_max_generation_length(self) -> int:
return 512
@property
def chatterbox_device(self) -> str:
return self.actions["chatterbox"].config["device"].value
@property
def chatterbox_chunk_size(self) -> int:
return self.actions["chatterbox"].config["chunk_size"].value
@property
def chatterbox_info(self) -> str:
return CHATTERBOX_INFO
@property
def chatterbox_agent_details(self) -> dict:
if not self.chatterbox_configured:
return {}
details = {}
details["chatterbox_device"] = AgentDetail(
icon="mdi-memory",
value=f"Chatterbox: {self.chatterbox_device}",
description="The device to use for Chatterbox",
).model_dump()
return details
def chatterbox_delete_voice(self, voice: Voice):
"""
Remove the voice from the file system.
Only do this if the path is within TALEMATE_ROOT.
"""
is_talemate_asset, resolved = voice_is_talemate_asset(
voice, provider(voice.provider)
)
log.debug(
"chatterbox_delete_voice",
voice_id=voice.provider_id,
is_talemate_asset=is_talemate_asset,
resolved=resolved,
)
if not is_talemate_asset:
return
try:
if resolved.exists() and resolved.is_file():
resolved.unlink()
log.debug("Deleted chatterbox voice file", path=str(resolved))
except Exception as e:
log.error(
"Failed to delete chatterbox voice file", error=e, path=str(resolved)
)
def _chatterbox_generate_file(
self,
model: "ChatterboxTTS",
text: str,
audio_prompt_path: str,
output_path: str,
**kwargs,
):
wav = model.generate(text=text, audio_prompt_path=audio_prompt_path, **kwargs)
ta.save(output_path, wav, model.sr)
return output_path
async def chatterbox_generate(
self, chunk: Chunk, context: GenerationContext
) -> bytes | None:
chatterbox_instance: ChatterboxInstance | None = getattr(
self, "chatterbox_instance", None
)
reload: bool = False
if not chatterbox_instance:
reload = True
elif chatterbox_instance.device != self.chatterbox_device:
reload = True
if reload:
log.debug(
"chatterbox - reinitializing tts instance",
device=self.chatterbox_device,
)
# Lazy import heavy dependencies only when needed
_import_heavy_deps()
self.chatterbox_instance = ChatterboxInstance(
model=ChatterboxTTS.from_pretrained(device=self.chatterbox_device),
device=self.chatterbox_device,
)
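# keep the loaded model on the agent instance so subsequent generations
# reuse it until the configured device changes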
model: "ChatterboxTTS" = self.chatterbox_instance.model
loop = asyncio.get_event_loop()
voice = chunk.voice
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(
None,
functools.partial(
self._chatterbox_generate_file,
model=model,
text=chunk.cleaned_text,
audio_prompt_path=voice.provider_id,
output_path=file_path,
**voice.parameters,
),
)
with open(file_path, "rb") as f:
return f.read()
async def chatterbox_prepare_chunk(self, chunk: Chunk):
voice = chunk.voice
P = provider(voice.provider)
exaggeration = P.voice_parameter(voice, "exaggeration")
voice.parameters["exaggeration"] = exaggeration

View File

@@ -0,0 +1,248 @@
import io
from typing import Union
import structlog
# Lazy imports for heavy dependencies
def _import_heavy_deps():
global AsyncElevenLabs, ApiError
from elevenlabs.client import AsyncElevenLabs
# Added explicit ApiError import for clearer error handling
from elevenlabs.core.api_error import ApiError
from talemate.ux.schema import Action
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
AgentDetail,
)
from .schema import Voice, VoiceLibrary, GenerationContext, Chunk, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
# emit helper to propagate status messages to the UX
from talemate.emit import emit
log = structlog.get_logger("talemate.agents.tts.elevenlabs")
add_default_voices(
[
Voice(
label="Adam",
provider="elevenlabs",
provider_id="wBXNqKUATyqu0RtYt25i",
tags=["male", "deep"],
),
Voice(
label="Amy",
provider="elevenlabs",
provider_id="oGn4Ha2pe2vSJkmIJgLQ",
tags=["female"],
),
]
)
ELEVENLABS_INFO = """
ElevenLabs is a cloud-based text to speech API.
To add new voices, head to their voice library at [https://elevenlabs.io/app/voice-library](https://elevenlabs.io/app/voice-library) and note the voice id of the voice you want to use. (Click 'More Actions -> Copy Voice ID')
**About ElevenLabs voices**
Your ElevenLabs subscription allows you to maintain a set number of voices (10 on the cheapest plan).
Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button above. It is recommended to test via their voice library instead.
"""
class ElevenLabsMixin:
"""
ElevenLabs TTS agent mixin for cloud-based text to speech.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "elevenlabs",
"label": "ElevenLabs",
"help": "ElevenLabs is a cloud-based text to speech model that uses the ElevenLabs API. (API key required)",
}
)
actions["elevenlabs"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="ElevenLabs",
description="ElevenLabs is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> ElevenLabs)",
config={
"model": AgentActionConfig(
type="text",
value="eleven_flash_v2_5",
label="Model",
description="Model to use for TTS",
choices=[
{
"value": "eleven_multilingual_v2",
"label": "Eleven Multilingual V2",
},
{"value": "eleven_flash_v2_5", "label": "Eleven Flash V2.5"},
{"value": "eleven_turbo_v2_5", "label": "Eleven Turbo V2.5"},
],
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=0,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@classmethod
def add_voices(cls, voices: dict[str, VoiceLibrary]):
voices["elevenlabs"] = VoiceLibrary(api="elevenlabs", local=True)
@property
def elevenlabs_chunk_size(self) -> int:
return self.actions["elevenlabs"].config["chunk_size"].value
@property
def elevenlabs_configured(self) -> bool:
api_key_set = bool(self.elevenlabs_api_key)
model_set = bool(self.elevenlabs_model)
return api_key_set and model_set
@property
def elevenlabs_not_configured_reason(self) -> str | None:
if not self.elevenlabs_api_key:
return "ElevenLabs API key not set"
if not self.elevenlabs_model:
return "ElevenLabs model not set"
return None
@property
def elevenlabs_not_configured_action(self) -> Action | None:
if not self.elevenlabs_api_key:
return Action(
action_name="openAppConfig",
arguments=["application", "elevenlabs_api"],
label="Set API Key",
icon="mdi-key",
)
if not self.elevenlabs_model:
return Action(
action_name="openAgentSettings",
arguments=["tts", "elevenlabs"],
label="Set Model",
icon="mdi-brain",
)
return None
@property
def elevenlabs_max_generation_length(self) -> int:
return 1024
@property
def elevenlabs_model(self) -> str:
return self.actions["elevenlabs"].config["model"].value
@property
def elevenlabs_model_choices(self) -> list[str]:
return [
{"label": choice["label"], "value": choice["value"]}
for choice in self.actions["elevenlabs"].config["model"].choices
]
@property
def elevenlabs_info(self) -> str:
return ELEVENLABS_INFO
@property
def elevenlabs_agent_details(self) -> dict:
details = {}
if not self.elevenlabs_configured:
details["elevenlabs_api_key"] = AgentDetail(
icon="mdi-key",
value="ElevenLabs API key not set",
description="ElevenLabs API key not set. You can set it in the Talemate Settings -> Application -> ElevenLabs",
color="error",
).model_dump()
else:
details["elevenlabs_model"] = AgentDetail(
icon="mdi-brain",
value=self.elevenlabs_model,
description="The model to use for ElevenLabs",
).model_dump()
return details
@property
def elevenlabs_api_key(self) -> str:
return self.config.elevenlabs.api_key
async def elevenlabs_generate(
self, chunk: Chunk, context: GenerationContext, chunk_size: int = 1024
) -> Union[bytes, None]:
api_key = self.elevenlabs_api_key
if not api_key:
return
# Lazy import heavy dependencies only when needed
_import_heavy_deps()
client = AsyncElevenLabs(api_key=api_key)
try:
response_async_iter = client.text_to_speech.convert(
text=chunk.cleaned_text,
voice_id=chunk.voice.provider_id,
model_id=chunk.model or self.elevenlabs_model,
)
bytes_io = io.BytesIO()
async for _chunk_bytes in response_async_iter:
if _chunk_bytes:
bytes_io.write(_chunk_bytes)
return bytes_io.getvalue()
except ApiError as e:
# Emit detailed status message to the frontend UI
error_message = "ElevenLabs API Error"
try:
# The ElevenLabs ApiError often contains a JSON body with details
detail = e.body.get("detail", {}) if hasattr(e, "body") else {}
error_message = detail.get("message", str(e)) or str(e)
except Exception:
error_message = str(e)
log.error("ElevenLabs API error", error=str(e))
emit(
"status",
message=f"ElevenLabs TTS: {error_message}",
status="error",
)
raise e
except Exception as e:
# Catch-all to ensure the app does not crash on unexpected errors
log.error("ElevenLabs TTS generation error", error=str(e))
emit(
"status",
message=f"ElevenLabs TTS: Unexpected error {str(e)}",
status="error",
)
raise e

View File

@@ -0,0 +1,436 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic
import re
import torch
# Lazy imports for heavy dependencies
def _import_heavy_deps():
global F5TTS
from f5_tts.api import F5TTS
CUDA_AVAILABLE = torch.cuda.is_available()
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
AgentDetail,
)
from talemate.ux.schema import Field
from .schema import Voice, Chunk, GenerationContext, VoiceProvider, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
from .providers import register, provider
from .util import voice_is_talemate_asset
log = structlog.get_logger("talemate.agents.tts.f5tts")
REF_TEXT = "You awaken aboard your ship, the Starlight Nomad. A soft hum resonates throughout the vessel indicating its systems are online."
add_default_voices(
[
Voice(
label="Adam",
provider="f5tts",
provider_id="tts/voice/f5tts/adam.wav",
tags=["male", "calm", "mature", "deep", "thoughtful"],
parameters={
"speed": 1.05,
"ref_text": REF_TEXT,
},
),
Voice(
label="Bradford",
provider="f5tts",
provider_id="tts/voice/f5tts/bradford.wav",
tags=["male", "calm", "mature"],
parameters={
"speed": 1,
"ref_text": REF_TEXT,
},
),
Voice(
label="Julia",
provider="f5tts",
provider_id="tts/voice/f5tts/julia.wav",
tags=["female", "calm", "mature"],
parameters={
"speed": 1.1,
"ref_text": REF_TEXT,
},
),
Voice(
label="Lisa",
provider="f5tts",
provider_id="tts/voice/f5tts/lisa.wav",
tags=["female", "young", "energetic"],
parameters={
"speed": 1.2,
"ref_text": REF_TEXT,
},
),
Voice(
label="Eva",
provider="f5tts",
provider_id="tts/voice/f5tts/eva.wav",
tags=["female", "mature", "thoughtful"],
parameters={
"speed": 1.15,
"ref_text": REF_TEXT,
},
),
Voice(
label="Zoe",
provider="f5tts",
provider_id="tts/voice/f5tts/zoe.wav",
tags=["female"],
parameters={
"speed": 1.15,
"ref_text": REF_TEXT,
},
),
Voice(
label="William",
provider="f5tts",
provider_id="tts/voice/f5tts/william.wav",
tags=["male", "young"],
parameters={
"speed": 1.15,
"ref_text": REF_TEXT,
},
),
]
)
F5TTS_INFO = """
F5-TTS is a local text-to-speech model.
The voice id is the path to the reference *.wav* file that contains a short
voice sample (≈3-5 s). You can place new samples in the
`tts/voice/f5tts` directory of your Talemate workspace or supply an absolute
path that is accessible to the backend.
The first generation will download the model weights (~1.3 GB) if they are not
cached yet.
"""
@register()
class F5TTSProvider(VoiceProvider):
"""Metadata for the F5-TTS provider."""
name: str = "f5tts"
allow_model_override: bool = False
allow_file_upload: bool = True
upload_file_types: list[str] = ["audio/wav"]
# Provider-specific tunable parameters that can be stored per-voice
voice_parameters: list[Field] = [
Field(
name="speed",
type="number",
label="Speed",
value=1.0,
min=0.25,
max=2.0,
step=0.05,
description="If the speech is too fast or slow, adjust this value. 1.0 is normal speed.",
),
Field(
name="ref_text",
type="text",
label="Reference text",
value="",
description="Text that matches the reference audio sample (improves synthesis quality).",
required=True,
),
Field(
name="cfg_strength",
type="number",
label="CFG Strength",
value=2.0,
min=0.1,
step=0.1,
max=10.0,
description="CFG strength for the model.",
),
]
class F5TTSInstance(pydantic.BaseModel):
"""Holds a single F5-TTS model instance (lazy-initialised)."""
model: "F5TTS" # Forward reference for lazy loading
model_name: str
class Config:
arbitrary_types_allowed = True
class F5TTSMixin:
"""F5-TTS agent mixin for local text-to-speech generation."""
# ---------------------------------------------------------------------
# UI integration / configuration helpers
# ---------------------------------------------------------------------
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
"""Expose the F5-TTS backend in the global TTS agent settings."""
actions["_config"].config["apis"].choices.append(
{
"value": "f5tts",
"label": "F5-TTS (Local)",
"help": "F5-TTS is a local text-to-speech model.",
}
)
actions["f5tts"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="F5-TTS",
description="F5-TTS is a local text-to-speech model.",
config={
"device": AgentActionConfig(
type="text",
value="cuda" if CUDA_AVAILABLE else "cpu",
label="Device",
choices=[
{"value": "cpu", "label": "CPU"},
{"value": "cuda", "label": "CUDA"},
],
description="Device to use for TTS",
),
"model_name": AgentActionConfig(
type="text",
value="F5TTS_v1_Base",
label="Model",
description="Model will be downloaded on first use.",
choices=[
{"value": "E2TTS_Base", "label": "E2TTS_Base"},
{"value": "F5TTS_Base", "label": "F5TTS_Base"},
{"value": "F5TTS_v1_Base", "label": "F5TTS_v1_Base"},
],
),
"nfe_step": AgentActionConfig(
type="number",
label="NFE Step",
value=32,
min=32,
step=16,
max=64,
description="Number of diffusion steps.",
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=32,
max=1024,
value=64,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
"replace_exclamation_marks": AgentActionConfig(
type="bool",
value=True,
label="Replace exclamation marks",
description="Some models tend to over-emphasise exclamation marks, so this is a workaround to make the speech more natural.",
),
},
)
# Device, model and chunking settings above cover everything F5-TTS needs.
return actions
# ------------------------------------------------------------------
# Convenience properties consumed by the core TTS agent
# ------------------------------------------------------------------
@property
def f5tts_configured(self) -> bool:
# Local backend always available once the model weights are present.
return True
@property
def f5tts_device(self) -> str:
return self.actions["f5tts"].config["device"].value
@property
def f5tts_chunk_size(self) -> int:
return self.actions["f5tts"].config["chunk_size"].value
@property
def f5tts_replace_exclamation_marks(self) -> bool:
return self.actions["f5tts"].config["replace_exclamation_marks"].value
@property
def f5tts_model_name(self) -> str:
return self.actions["f5tts"].config["model_name"].value
@property
def f5tts_nfe_step(self) -> int:
return self.actions["f5tts"].config["nfe_step"].value
@property
def f5tts_max_generation_length(self) -> int:
return 1024
@property
def f5tts_info(self) -> str:
return F5TTS_INFO
@property
def f5tts_agent_details(self) -> dict:
if not self.f5tts_configured:
return {}
details = {}
device = self.f5tts_device
model_name = self.f5tts_model_name
details["f5tts_device"] = AgentDetail(
icon="mdi-memory",
value=f"{model_name}@{device}",
description="The model and device to use for F5-TTS",
).model_dump()
return details
# ------------------------------------------------------------------
# Voice housekeeping helpers
# ------------------------------------------------------------------
def f5tts_delete_voice(self, voice: Voice):
"""Delete *voice* reference file if it is inside the Talemate workspace."""
is_talemate_asset, resolved = voice_is_talemate_asset(
voice, provider(voice.provider)
)
log.debug(
"f5tts_delete_voice",
voice_id=voice.provider_id,
is_talemate_asset=is_talemate_asset,
resolved=resolved,
)
if not is_talemate_asset:
return
try:
if resolved.exists() and resolved.is_file():
resolved.unlink()
log.debug("Deleted F5-TTS voice file", path=str(resolved))
except Exception as e:
log.error("Failed to delete F5-TTS voice file", error=e, path=str(resolved))
# ------------------------------------------------------------------
# Generation helpers
# ------------------------------------------------------------------
def _f5tts_generate_file(
self,
model: "F5TTS",
chunk: Chunk,
voice: Voice,
output_path: str,
) -> str:
"""Blocking generation helper executed in a thread-pool."""
wav, sr, _ = model.infer(
ref_file=voice.provider_id,
ref_text=voice.parameters.get("ref_text", ""),
gen_text=chunk.cleaned_text,
file_wave=output_path,
speed=voice.parameters.get("speed", 1.0),
cfg_strength=voice.parameters.get("cfg_strength", 2.0),
nfe_step=self.f5tts_nfe_step,
)
# Some versions of F5-TTS don't write *file_wave*; a manual save could be added here as a fallback.
# if not os.path.exists(output_path):
# ta.save(output_path, wav, sr)
return output_path
async def f5tts_generate(
self, chunk: Chunk, context: GenerationContext
) -> bytes | None:
"""Asynchronously synthesise *chunk* using F5-TTS."""
# Lazy initialisation & caching across invocations
f5tts_instance: "F5TTSInstance | None" = getattr(self, "f5tts_instance", None)
device = self.f5tts_device
model_name: str = self.f5tts_model_name
reload_model = (
f5tts_instance is None
or f5tts_instance.model.device != device
or f5tts_instance.model_name != model_name
)
if reload_model:
if f5tts_instance is not None:
log.debug(
"Reloading F5-TTS backend", device=device, model_name=model_name
)
else:
log.debug(
"Initialising F5-TTS backend", device=device, model_name=model_name
)
# Lazy import heavy dependencies only when needed
_import_heavy_deps()
f5tts_instance = F5TTSInstance(
model=F5TTS(device=device, model=model_name),
model_name=model_name,
)
self.f5tts_instance = f5tts_instance
model: "F5TTS" = f5tts_instance.model
loop = asyncio.get_event_loop()
voice = chunk.voice
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
# Delegate blocking work to the default ThreadPoolExecutor
await loop.run_in_executor(
None,
functools.partial(
self._f5tts_generate_file, model, chunk, voice, file_path
),
)
# Read the generated WAV and return bytes for websocket playback
with open(file_path, "rb") as f:
return f.read()
async def f5tts_prepare_chunk(self, chunk: Chunk):
text = chunk.text[0]
# f5-tts seems to have issues with ellipses
text = text.replace("…", "...").replace("...", ".")
# hyphenated words also seem to be a problem
text = re.sub(r"(\w)-(\w)", r"\1 \2", text)
if self.f5tts_replace_exclamation_marks:
text = text.replace("!", ".")
chunk.text[0] = text
return chunk
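
As a quick illustration of the path-based voice ids described in F5TTS_INFO above, a custom voice entry might be registered like the sketch below. The module paths are inferred from the logger names in this diff, and the sample file and transcript are hypothetical.

from talemate.agents.tts.schema import Voice
from talemate.agents.tts.voice_library import add_default_voices

# Hypothetical sample placed under the workspace's tts/voice/f5tts directory.
add_default_voices(
    [
        Voice(
            label="My Narrator",
            provider="f5tts",
            provider_id="tts/voice/f5tts/my_narrator.wav",
            tags=["male", "calm"],
            parameters={
                "speed": 1.0,
                "ref_text": "A short transcript of the reference sample.",
            },
        )
    ]
)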

View File

@@ -0,0 +1,319 @@
import io
import wave
from typing import Union, Optional
import structlog
from google import genai
from google.genai import types
from talemate.ux.schema import Action
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
AgentDetail,
)
from .schema import Voice, VoiceLibrary, Chunk, GenerationContext, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
log = structlog.get_logger("talemate.agents.tts.google")
GOOGLE_INFO = """
Google Gemini TTS is a cloud-based text to speech model.
A list of available voices can be found at [https://ai.google.dev/gemini-api/docs/speech-generation](https://ai.google.dev/gemini-api/docs/speech-generation).
"""
add_default_voices(
[
Voice(label="Zephyr", provider="google", provider_id="Zephyr", tags=["female"]),
Voice(label="Puck", provider="google", provider_id="Puck", tags=["male"]),
Voice(label="Charon", provider="google", provider_id="Charon", tags=["male"]),
Voice(label="Kore", provider="google", provider_id="Kore", tags=["female"]),
Voice(label="Fenrir", provider="google", provider_id="Fenrir", tags=["male"]),
Voice(label="Leda", provider="google", provider_id="Leda", tags=["female"]),
Voice(label="Orus", provider="google", provider_id="Orus", tags=["male"]),
Voice(label="Aoede", provider="google", provider_id="Aoede", tags=["female"]),
Voice(
label="Callirrhoe",
provider="google",
provider_id="Callirrhoe",
tags=["female"],
),
Voice(
label="Autonoe", provider="google", provider_id="Autonoe", tags=["female"]
),
Voice(
label="Enceladus",
provider="google",
provider_id="Enceladus",
tags=["male", "deep"],
),
Voice(label="Iapetus", provider="google", provider_id="Iapetus", tags=["male"]),
Voice(label="Umbriel", provider="google", provider_id="Umbriel", tags=["male"]),
Voice(
label="Algieba",
provider="google",
provider_id="Algieba",
tags=["male", "deep"],
),
Voice(
label="Despina",
provider="google",
provider_id="Despina",
tags=["female", "young"],
),
Voice(
label="Erinome", provider="google", provider_id="Erinome", tags=["female"]
),
Voice(label="Algenib", provider="google", provider_id="Algenib", tags=["male"]),
Voice(
label="Rasalgethi",
provider="google",
provider_id="Rasalgethi",
tags=["male", "neutral"],
),
Voice(
label="Laomedeia",
provider="google",
provider_id="Laomedeia",
tags=["female"],
),
Voice(
label="Achernar",
provider="google",
provider_id="Achernar",
tags=["female", "young"],
),
Voice(label="Alnilam", provider="google", provider_id="Alnilam", tags=["male"]),
Voice(label="Schedar", provider="google", provider_id="Schedar", tags=["male"]),
Voice(
label="Gacrux",
provider="google",
provider_id="Gacrux",
tags=["female", "mature"],
),
Voice(
label="Pulcherrima",
provider="google",
provider_id="Pulcherrima",
tags=["female", "mature"],
),
Voice(
label="Achird",
provider="google",
provider_id="Achird",
tags=["male", "energetic"],
),
Voice(
label="Zubenelgenubi",
provider="google",
provider_id="Zubenelgenubi",
tags=["male"],
),
Voice(
label="Vindemiatrix",
provider="google",
provider_id="Vindemiatrix",
tags=["female", "mature"],
),
Voice(
label="Sadachbia", provider="google", provider_id="Sadachbia", tags=["male"]
),
Voice(
label="Sadaltager",
provider="google",
provider_id="Sadaltager",
tags=["male"],
),
Voice(
label="Sulafat",
provider="google",
provider_id="Sulafat",
tags=["female", "young"],
),
]
)
class GoogleMixin:
"""Google Gemini TTS mixin (Flash/Pro preview models)."""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "google",
"label": "Google Gemini",
"help": "Google Gemini is a cloud-based text to speech model that uses the Google Gemini API. (API key required)",
}
)
actions["google"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="Google Gemini",
description="Google Gemini is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> Google)",
config={
"model": AgentActionConfig(
type="text",
value="gemini-2.5-flash-preview-tts",
choices=[
{
"value": "gemini-2.5-flash-preview-tts",
"label": "Gemini 2.5 Flash TTS (Preview)",
},
{
"value": "gemini-2.5-pro-preview-tts",
"label": "Gemini 2.5 Pro TTS (Preview)",
},
],
label="Model",
description="Google TTS model to use",
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=0,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@classmethod
def add_voices(cls, voices: dict[str, VoiceLibrary]):
voices["google"] = VoiceLibrary(api="google")
@property
def google_configured(self) -> bool:
return bool(self.google_api_key) and bool(self.google_model)
@property
def google_chunk_size(self) -> int:
return self.actions["google"].config["chunk_size"].value
@property
def google_not_configured_reason(self) -> str | None:
if not self.google_api_key:
return "Google API key not set"
if not self.google_model:
return "Google model not set"
return None
@property
def google_not_configured_action(self) -> Action | None:
if not self.google_api_key:
return Action(
action_name="openAppConfig",
arguments=["application", "google_api"],
label="Set API Key",
icon="mdi-key",
)
if not self.google_model:
return Action(
action_name="openAgentSettings",
arguments=["tts", "google"],
label="Set Model",
icon="mdi-brain",
)
return None
@property
def google_info(self) -> str:
return GOOGLE_INFO
@property
def google_max_generation_length(self) -> int:
return 1024
@property
def google_model(self) -> str:
return self.actions["google"].config["model"].value
@property
def google_model_choices(self) -> list[dict]:
return [
{"label": choice["label"], "value": choice["value"]}
for choice in self.actions["google"].config["model"].choices
]
@property
def google_api_key(self) -> Optional[str]:
return self.config.google.api_key
@property
def google_agent_details(self) -> dict:
details = {}
if not self.google_configured:
details["google_api_key"] = AgentDetail(
icon="mdi-key",
value="Google API key not set",
description="Google API key not set. You can set it in the Talemate Settings -> Application -> Google",
color="error",
).model_dump()
else:
details["google_model"] = AgentDetail(
icon="mdi-brain",
value=self.google_model,
description="The model to use for Google",
).model_dump()
return details
def _make_google_client(self) -> genai.Client:
"""Return a fresh genai.Client so updated creds propagate immediately."""
return genai.Client(api_key=self.google_api_key or None)
async def google_generate(
self,
chunk: Chunk,
context: GenerationContext,
chunk_size: int = 1024, # kept for signature parity
) -> Union[bytes, None]:
"""Generate audio and wrap raw PCM into a playable WAV container."""
voice_name = chunk.voice.provider_id
client = self._make_google_client()
try:
response = await client.aio.models.generate_content(
model=chunk.model or self.google_model,
contents=chunk.cleaned_text,
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=voice_name,
)
)
),
),
)
# Extract raw 24kHz 16bit PCM (mono) bytes from first candidate
part = response.candidates[0].content.parts[0].inline_data
if not part or not part.data:
return None
pcm_bytes: bytes = part.data
# Wrap into a WAV container that browsers can decode
wav_io = io.BytesIO()
with wave.open(wav_io, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2) # 16bit
wf.setframerate(24000) # Hz
wf.writeframes(pcm_bytes)
return wav_io.getvalue()
except Exception as e:
import traceback
traceback.print_exc()
log.error("google_generate failed", error=str(e))
return None

View File

@@ -0,0 +1,324 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic
import traceback
from pathlib import Path
import torch
import soundfile as sf
from kokoro import KPipeline
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
)
from .schema import (
Voice,
Chunk,
GenerationContext,
VoiceMixer,
VoiceProvider,
INFO_CHUNK_SIZE,
)
from .providers import register
from .voice_library import add_default_voices
log = structlog.get_logger("talemate.agents.tts.kokoro")
CUSTOM_VOICE_STORAGE = (
Path(__file__).parent.parent.parent.parent.parent / "tts" / "voice" / "kokoro"
)
add_default_voices(
[
Voice(
label="Alloy", provider="kokoro", provider_id="af_alloy", tags=["female"]
),
Voice(
label="Aoede", provider="kokoro", provider_id="af_aoede", tags=["female"]
),
Voice(
label="Bella", provider="kokoro", provider_id="af_bella", tags=["female"]
),
Voice(
label="Heart", provider="kokoro", provider_id="af_heart", tags=["female"]
),
Voice(
label="Jessica",
provider="kokoro",
provider_id="af_jessica",
tags=["female"],
),
Voice(label="Kore", provider="kokoro", provider_id="af_kore", tags=["female"]),
Voice(
label="Nicole", provider="kokoro", provider_id="af_nicole", tags=["female"]
),
Voice(label="Nova", provider="kokoro", provider_id="af_nova", tags=["female"]),
Voice(
label="River", provider="kokoro", provider_id="af_river", tags=["female"]
),
Voice(
label="Sarah", provider="kokoro", provider_id="af_sarah", tags=["female"]
),
Voice(label="Sky", provider="kokoro", provider_id="af_sky", tags=["female"]),
Voice(label="Adam", provider="kokoro", provider_id="am_adam", tags=["male"]),
Voice(label="Echo", provider="kokoro", provider_id="am_echo", tags=["male"]),
Voice(label="Eric", provider="kokoro", provider_id="am_eric", tags=["male"]),
Voice(
label="Fenrir", provider="kokoro", provider_id="am_fenrir", tags=["male"]
),
Voice(label="Liam", provider="kokoro", provider_id="am_liam", tags=["male"]),
Voice(
label="Michael", provider="kokoro", provider_id="am_michael", tags=["male"]
),
Voice(label="Onyx", provider="kokoro", provider_id="am_onyx", tags=["male"]),
Voice(label="Puck", provider="kokoro", provider_id="am_puck", tags=["male"]),
Voice(label="Santa", provider="kokoro", provider_id="am_santa", tags=["male"]),
Voice(
label="Alice", provider="kokoro", provider_id="bf_alice", tags=["female"]
),
Voice(label="Emma", provider="kokoro", provider_id="bf_emma", tags=["female"]),
Voice(
label="Isabella",
provider="kokoro",
provider_id="bf_isabella",
tags=["female"],
),
Voice(label="Lily", provider="kokoro", provider_id="bf_lily", tags=["female"]),
Voice(
label="Daniel", provider="kokoro", provider_id="bm_daniel", tags=["male"]
),
Voice(label="Fable", provider="kokoro", provider_id="bm_fable", tags=["male"]),
Voice(
label="George", provider="kokoro", provider_id="bm_george", tags=["male"]
),
Voice(label="Lewis", provider="kokoro", provider_id="bm_lewis", tags=["male"]),
]
)
KOKORO_INFO = """
Kokoro is a local text to speech model.
**WILL DOWNLOAD**: Voices will be downloaded on first use, so the first generation will take longer to complete.
"""
@register()
class KokoroProvider(VoiceProvider):
name: str = "kokoro"
allow_model_override: bool = False
class KokoroInstance(pydantic.BaseModel):
pipeline: "KPipeline" # Forward reference for lazy loading
class Config:
arbitrary_types_allowed = True
class KokoroMixin:
"""
Kokoro agent mixin for local text to speech.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "kokoro",
"label": "Kokoro (Local)",
"help": "Kokoro is a local text to speech model.",
}
)
actions["kokoro"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="Kokoro",
description="Kokoro is a local text to speech model.",
config={
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=512,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@property
def kokoro_configured(self) -> bool:
return True
@property
def kokoro_chunk_size(self) -> int:
return self.actions["kokoro"].config["chunk_size"].value
@property
def kokoro_max_generation_length(self) -> int:
return 256
@property
def kokoro_agent_details(self) -> dict:
return {}
@property
def kokoro_supports_mixing(self) -> bool:
return True
@property
def kokoro_info(self) -> str:
return KOKORO_INFO
def kokoro_delete_voice(self, voice_id: str) -> None:
"""
If the voice_id is a file in the CUSTOM_VOICE_STORAGE directory, delete it.
"""
# if voice id is a deletable file it'll be a relative or absolute path
# to a file in the CUSTOM_VOICE_STORAGE directory
# we must verify that it is in the CUSTOM_VOICE_STORAGE directory
voice_path = Path(voice_id).resolve()
log.debug(
"Kokoro - Checking if voice id is deletable",
voice_id=voice_id,
exists=voice_path.exists(),
parent=voice_path.parent,
is_custom_voice_storage=voice_path.parent == CUSTOM_VOICE_STORAGE,
)
if voice_path.exists() and voice_path.parent == CUSTOM_VOICE_STORAGE:
log.debug("Kokoro - Deleting voice file", voice_id=voice_id)
try:
voice_path.unlink()
except FileNotFoundError:
pass
def _kokoro_mix(self, mixer: VoiceMixer) -> "torch.Tensor":
pipeline = KPipeline(lang_code="a")
packs = [
{
"voice_tensor": pipeline.load_single_voice(voice.id),
"weight": voice.weight,
}
for voice in mixer.voices
]
mixed_voice = None
for pack in packs:
if mixed_voice is None:
mixed_voice = pack["voice_tensor"] * pack["weight"]
else:
mixed_voice += pack["voice_tensor"] * pack["weight"]
# TODO: ensure weights sum to 1
return mixed_voice
async def kokoro_test_mix(self, mixer: VoiceMixer):
"""Test a mixed voice by generating a sample."""
mixed_voice_tensor = self._kokoro_mix(mixer)
loop = asyncio.get_event_loop()
pipeline = KPipeline(lang_code="a")
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(
None,
functools.partial(
self._kokoro_generate,
pipeline,
"This is a test of the mixed voice.",
mixed_voice_tensor,
file_path,
),
)
# Read and play the audio
with open(file_path, "rb") as f:
audio_data = f.read()
self.play_audio(audio_data)
async def kokoro_save_mix(self, voice_id: str, mixer: VoiceMixer) -> Path:
"""Save a voice tensor to disk."""
# Ensure the directory exists
CUSTOM_VOICE_STORAGE.mkdir(parents=True, exist_ok=True)
save_to_path = CUSTOM_VOICE_STORAGE / f"{voice_id}.pt"
voice_tensor = self._kokoro_mix(mixer)
torch.save(voice_tensor, save_to_path)
return save_to_path
def _kokoro_generate(
self,
pipeline: "KPipeline",
text: str,
voice: "str | torch.Tensor",
file_path: str,
) -> None:
"""Generate audio from text using the given voice."""
try:
generator = pipeline(text, voice=voice)
for i, (gs, ps, audio) in enumerate(generator):
sf.write(file_path, audio, 24000)
except Exception as e:
traceback.print_exc()
raise e
async def kokoro_generate(
self, chunk: Chunk, context: GenerationContext
) -> bytes | None:
kokoro_instance = getattr(self, "kokoro_instance", None)
reload: bool = False
if not kokoro_instance:
reload = True
if reload:
log.debug(
"kokoro - reinitializing tts instance",
)
# (Re)construct the Kokoro pipeline (KPipeline is imported at module load, not lazily)
self.kokoro_instance = KokoroInstance(
# lang_code "a" = American English
# TODO: allow configuration of the language
pipeline=KPipeline(lang_code="a")
)
pipeline = self.kokoro_instance.pipeline
loop = asyncio.get_event_loop()
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(
None,
functools.partial(
self._kokoro_generate,
pipeline,
chunk.cleaned_text,
chunk.voice.provider_id,
file_path,
),
)
with open(file_path, "rb") as f:
return f.read()
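
For reference, Kokoro voice mixing as implemented above is just a weighted sum of voice tensors. A minimal usage sketch, assuming a running TTS agent that includes this mixin and weights that sum to 1.0 (the blend is hypothetical):

from talemate.agents.tts.schema import VoiceMixer, VoiceWeight
from talemate.instance import get_agent

async def make_blend():
    tts_agent = get_agent("tts")
    # Hypothetical 60/40 blend of two bundled voices.
    mixer = VoiceMixer(
        voices=[
            VoiceWeight(id="af_bella", weight=0.6),
            VoiceWeight(id="af_nova", weight=0.4),
        ]
    )
    await tts_agent.kokoro_test_mix(mixer)  # audition the blend
    # Writes the mixed tensor to tts/voice/kokoro/bella-nova.pt and returns the path.
    return await tts_agent.kokoro_save_mix("bella-nova", mixer)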

View File

@@ -0,0 +1,165 @@
import structlog
from typing import ClassVar
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
TYPE_CHOICES,
UNRESOLVED,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
from talemate.agents.tts.schema import Voice, VoiceLibrary
TYPE_CHOICES.extend(
[
"tts/voice",
]
)
log = structlog.get_logger("talemate.game.engine.nodes.agents.tts")
@register("agents/tts/Settings")
class TTSAgentSettings(AgentSettingsNode):
"""
Base node to render TTS agent settings.
"""
_agent_name: ClassVar[str] = "tts"
def __init__(self, title="TTS Agent Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/tts/GetVoice")
class GetVoice(AgentNode):
"""
Gets a voice from the TTS agent.
"""
_agent_name: ClassVar[str] = "tts"
class Fields:
voice_id = PropertyField(
name="voice_id",
type="str",
description="The ID of the voice to get",
default=UNRESOLVED,
)
def __init__(self, title="Get Voice", **kwargs):
super().__init__(title=title, **kwargs)
@property
def voice_library(self) -> VoiceLibrary:
return self.agent.voice_library
def setup(self):
self.add_input("voice_id", socket_type="str", optional=True)
self.set_property("voice_id", UNRESOLVED)
self.add_output("voice", socket_type="tts/voice")
async def run(self, state: GraphState):
voice_id = self.require_input("voice_id")
voice = self.voice_library.get_voice(voice_id)
self.set_output_values({"voice": voice})
@register("agents/tts/GetNarratorVoice")
class GetNarratorVoice(AgentNode):
"""
Gets the narrator voice from the TTS agent.
"""
_agent_name: ClassVar[str] = "tts"
def __init__(self, title="Get Narrator Voice", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_output("voice", socket_type="tts/voice")
async def run(self, state: GraphState):
voice = self.agent.narrator_voice
self.set_output_values({"voice": voice})
@register("agents/tts/UnpackVoice")
class UnpackVoice(AgentNode):
"""
Unpacks a voice from the TTS agent.
"""
_agent_name: ClassVar[str] = "tts"
def __init__(self, title="Unpack Voice", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("voice", socket_type="tts/voice")
self.add_output("voice", socket_type="tts/voice")
self.add_output("label", socket_type="str")
self.add_output("provider", socket_type="str")
self.add_output("provider_id", socket_type="str")
self.add_output("provider_model", socket_type="str")
self.add_output("tags", socket_type="list")
self.add_output("parameters", socket_type="dict")
self.add_output("is_scene_asset", socket_type="bool")
async def run(self, state: GraphState):
voice: Voice = self.require_input("voice")
self.set_output_values(
{
"voice": voice,
**voice.model_dump(),
}
)
@register("agents/tts/Generate")
class Generate(AgentNode):
"""
Generates a voice from the TTS agent.
"""
_agent_name: ClassVar[str] = "tts"
class Fields:
text = PropertyField(
name="text",
type="text",
description="The text to generate",
default=UNRESOLVED,
)
def __init__(self, title="Generate TTS", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("text", socket_type="text", optional=True)
self.add_input("voice", socket_type="tts/voice", optional=True)
self.add_input("character", socket_type="character", optional=True)
self.set_property("text", UNRESOLVED)
self.add_output("state")
async def run(self, state: GraphState):
text = self.require_input("text")
voice = self.normalized_input_value("voice")
character = self.normalized_input_value("character")
if not voice and not character:
raise ValueError("Either voice or character must be provided")
await self.agent.generate(
text=text,
character=character,
force_voice=voice,
)
self.set_output_values({"state": state})

View File

@@ -0,0 +1,230 @@
import io
from typing import Union
import structlog
from openai import AsyncOpenAI
from talemate.ux.schema import Action
from talemate.agents.base import AgentAction, AgentActionConfig, AgentDetail
from .schema import Voice, VoiceLibrary, Chunk, GenerationContext, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
log = structlog.get_logger("talemate.agents.tts.openai")
OPENAI_INFO = """
OpenAI TTS is a cloud-based text to speech model.
A list of available voices can be found at [https://platform.openai.com/docs/guides/text-to-speech#voice-options](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
"""
add_default_voices(
[
Voice(
label="Alloy",
provider="openai",
provider_id="alloy",
tags=["neutral", "female"],
),
Voice(
label="Ash",
provider="openai",
provider_id="ash",
tags=["male"],
),
Voice(
label="Ballad",
provider="openai",
provider_id="ballad",
tags=["male", "energetic"],
),
Voice(
label="Coral",
provider="openai",
provider_id="coral",
tags=["female", "energetic"],
),
Voice(
label="Echo",
provider="openai",
provider_id="echo",
tags=["male", "neutral"],
),
Voice(
label="Fable",
provider="openai",
provider_id="fable",
tags=["neutral", "feminine"],
),
Voice(
label="Onyx",
provider="openai",
provider_id="onyx",
tags=["male"],
),
Voice(
label="Nova",
provider="openai",
provider_id="nova",
tags=["female"],
),
Voice(
label="Sage",
provider="openai",
provider_id="sage",
tags=["female"],
),
Voice(
label="Shimmer",
provider="openai",
provider_id="shimmer",
tags=["female"],
),
]
)
class OpenAIMixin:
"""
OpenAI TTS agent mixin for cloud-based text to speech.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "openai",
"label": "OpenAI",
"help": "OpenAI is a cloud-based text to speech model that uses the OpenAI API. (API key required)",
}
)
actions["openai"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="OpenAI",
description="OpenAI TTS is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> OpenAI)",
config={
"model": AgentActionConfig(
type="text",
value="gpt-4o-mini-tts",
choices=[
{"value": "gpt-4o-mini-tts", "label": "GPT-4o Mini TTS"},
{"value": "tts-1", "label": "TTS 1"},
{"value": "tts-1-hd", "label": "TTS 1 HD"},
],
label="Model",
description="TTS model to use",
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=512,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@classmethod
def add_voices(cls, voices: dict[str, VoiceLibrary]):
voices["openai"] = VoiceLibrary(api="openai")
@property
def openai_chunk_size(self) -> int:
return self.actions["openai"].config["chunk_size"].value
@property
def openai_max_generation_length(self) -> int:
return 1024
@property
def openai_model(self) -> str:
return self.actions["openai"].config["model"].value
@property
def openai_model_choices(self) -> list[dict]:
return [
{"label": choice["label"], "value": choice["value"]}
for choice in self.actions["openai"].config["model"].choices
]
@property
def openai_api_key(self) -> str:
return self.config.openai.api_key
@property
def openai_configured(self) -> bool:
return bool(self.openai_api_key) and bool(self.openai_model)
@property
def openai_info(self) -> str:
return OPENAI_INFO
@property
def openai_not_configured_reason(self) -> str | None:
if not self.openai_api_key:
return "OpenAI API key not set"
if not self.openai_model:
return "OpenAI model not set"
return None
@property
def openai_not_configured_action(self) -> Action | None:
if not self.openai_api_key:
return Action(
action_name="openAppConfig",
arguments=["application", "openai_api"],
label="Set API Key",
icon="mdi-key",
)
if not self.openai_model:
return Action(
action_name="openAgentSettings",
arguments=["tts", "openai"],
label="Set Model",
icon="mdi-brain",
)
return None
@property
def openai_agent_details(self) -> dict:
details = {}
if not self.openai_configured:
details["openai_api_key"] = AgentDetail(
icon="mdi-key",
value="OpenAI API key not set",
description="OpenAI API key not set. You can set it in the Talemate Settings -> Application -> OpenAI",
color="error",
).model_dump()
else:
details["openai_model"] = AgentDetail(
icon="mdi-brain",
value=self.openai_model,
description="The model to use for OpenAI",
).model_dump()
return details
async def openai_generate(
self, chunk: Chunk, context: GenerationContext, chunk_size: int = 1024
) -> Union[bytes, None]:
client = AsyncOpenAI(api_key=self.openai_api_key)
model = chunk.model or self.openai_model
response = await client.audio.speech.create(
model=model, voice=chunk.voice.provider_id, input=chunk.cleaned_text
)
bytes_io = io.BytesIO()
for audio_bytes in response.iter_bytes(chunk_size=chunk_size):
if audio_bytes:
bytes_io.write(audio_bytes)
# Return the collected audio bytes; the caller queues them for playback
return bytes_io.getvalue()

View File

@@ -0,0 +1,24 @@
from .schema import VoiceProvider
from typing import Generator
__all__ = ["register", "provider", "providers"]
PROVIDERS = {}
class register:
def __call__(self, cls: type[VoiceProvider]):
PROVIDERS[cls().name] = cls
return cls
def provider(name: str) -> VoiceProvider:
cls = PROVIDERS.get(name)
if not cls:
return VoiceProvider(name=name)
return cls()
def providers() -> Generator[VoiceProvider, None, None]:
for cls in PROVIDERS.values():
yield cls()
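
A short sketch of how this registry is intended to be used (the provider name is made up): a class decorates itself with register() and is later resolved by name, with unknown names falling back to a bare VoiceProvider.

from talemate.agents.tts.schema import VoiceProvider
from talemate.agents.tts.providers import provider, register

@register()
class ExampleProvider(VoiceProvider):
    name: str = "example"
    allow_model_override: bool = False

assert provider("example").allow_model_override is False
# Unknown names still return a usable default provider object.
assert provider("does-not-exist").name == "does-not-exist"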

View File

@@ -0,0 +1,201 @@
import pydantic
from pathlib import Path
import re
from typing import Callable, Literal
from talemate.ux.schema import Note, Field
from talemate.path import TALEMATE_ROOT
__all__ = [
"APIStatus",
"Chunk",
"GenerationContext",
"VoiceProvider",
"Voice",
"VoiceLibrary",
"VoiceWeight",
"VoiceMixer",
"VoiceGenerationEmission",
"INFO_CHUNK_SIZE",
]
MAX_TAG_LENGTH: int = 64 # Maximum number of characters per tag (configurable)
MAX_TAGS_PER_VOICE: int = 10 # Maximum number of tags per voice (configurable)
DEFAULT_VOICE_DIR = TALEMATE_ROOT / "tts" / "voice"
INFO_CHUNK_SIZE = "Split text into chunks of this size. Smaller values increase responsiveness at the cost of context lost between chunks (e.g., appropriate inflection). 0 = no chunking."
class VoiceProvider(pydantic.BaseModel):
name: str
voice_parameters: list[Field] = pydantic.Field(default_factory=list)
allow_model_override: bool = True
allow_file_upload: bool = False
upload_file_types: list[str] | None = None
@property
def default_parameters(self) -> dict[str, str | float | int | bool]:
return {param.name: param.value for param in self.voice_parameters}
@property
def default_voice_dir(self) -> Path:
return DEFAULT_VOICE_DIR / self.name
def voice_parameter(
self, voice: "Voice", name: str
) -> str | float | int | bool | None:
"""
Get a parameter from the voice.
If the parameter is not set, return the default parameter from the provider.
"""
if name in voice.parameters:
return voice.parameters[name]
return self.default_parameters.get(name)
class VoiceWeight(pydantic.BaseModel):
id: str
weight: float
class VoiceMixer(pydantic.BaseModel):
voices: list[VoiceWeight]
class Voice(pydantic.BaseModel):
# arbitrary voice label to allow a human to easily identify the voice
label: str
# voice provider, this would be the TTS api in the voice
provider: str
# voice id as known to the voice provider
provider_id: str
# allows to also override to a specific model
provider_model: str | None = None
# free-form tags for categorizing the voice (e.g. "male", "energetic")
tags: list[str] = pydantic.Field(default_factory=list)
# provider specific parameters for the voice
parameters: dict[str, str | float | int | bool] = pydantic.Field(
default_factory=dict
)
is_scene_asset: bool = False
@pydantic.field_validator("tags")
@classmethod
def _validate_tags(cls, v: list[str]):
"""Validate tag list length and individual tag length."""
if len(v) > MAX_TAGS_PER_VOICE:
raise ValueError(
f"Too many tags maximum {MAX_TAGS_PER_VOICE} tags are allowed per voice"
)
for tag in v:
if len(tag) > MAX_TAG_LENGTH:
raise ValueError(
f"Tag '{tag}' exceeds maximum length of {MAX_TAG_LENGTH} characters"
)
return v
model_config = pydantic.ConfigDict(validate_assignment=True, exclude_none=True)
@pydantic.computed_field(description="The unique identifier for the voice")
@property
def id(self) -> str:
return f"{self.provider}:{self.provider_id}"
class VoiceLibrary(pydantic.BaseModel):
version: int = 1
voices: dict[str, Voice] = pydantic.Field(default_factory=dict)
def get_voice(self, voice_id: str) -> Voice | None:
return self.voices.get(voice_id)
class Chunk(pydantic.BaseModel):
text: list[str] = pydantic.Field(default_factory=list)
type: Literal["dialogue", "exposition"]
character_name: str | None = None
api: str | None = None
voice: Voice | None = None
model: str | None = None
generate_fn: Callable | None = None
prepare_fn: Callable | None = None
message_id: int | None = None
@property
def cleaned_text(self) -> str:
cleaned: str = self.text[0].replace("*", "").replace('"', "").replace("`", "")
# troublemakers
cleaned = cleaned.replace("—", " - ").replace("…", "...").replace(";", ",")
# replace any grouped up whitespace with a single space
cleaned = re.sub(r"\s+", " ", cleaned)
# replace full uppercase word with lowercase
# e.g. "HELLO" -> "hello"
cleaned = re.sub(r"[A-Z]{2,}", lambda m: m.group(0).lower(), cleaned)
cleaned = cleaned.strip(",").strip()
# If there is no common sentence-ending punctuation, add a period
if len(cleaned) > 0 and cleaned[-1] not in [".", "!", "?"]:
cleaned += "."
return cleaned.strip().strip(",").strip()
@property
def sub_chunks(self) -> list["Chunk"]:
if len(self.text) == 1:
return [self]
return [
Chunk(
text=[text],
type=self.type,
character_name=self.character_name,
api=self.api,
voice=Voice(**self.voice.model_dump()),
model=self.model,
generate_fn=self.generate_fn,
prepare_fn=self.prepare_fn,
)
for text in self.text
]
class GenerationContext(pydantic.BaseModel):
chunks: list[Chunk] = pydantic.Field(default_factory=list)
class VoiceGenerationEmission(pydantic.BaseModel):
chunk: Chunk
context: GenerationContext
wav_bytes: bytes | None = None
class ModelChoice(pydantic.BaseModel):
label: str
value: str
class APIStatus(pydantic.BaseModel):
"""Status of an API."""
api: str
enabled: bool
ready: bool
configured: bool
provider: VoiceProvider
messages: list[Note] = pydantic.Field(default_factory=list)
supports_mixing: bool = False
default_model: str | None = None
model_choices: list[ModelChoice] = pydantic.Field(default_factory=list)
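
To make the id and parameter-fallback semantics above concrete, here is a small sketch (the provider name and values are illustrative; the Field keyword arguments follow the usage in F5TTSProvider earlier in this diff):

from talemate.ux.schema import Field
from talemate.agents.tts.schema import Voice, VoiceProvider

prov = VoiceProvider(
    name="demo",
    voice_parameters=[Field(name="speed", type="number", label="Speed", value=1.0)],
)
voice = Voice(label="Demo", provider="demo", provider_id="demo-1")

print(voice.id)                              # "demo:demo-1"
print(prov.voice_parameter(voice, "speed"))  # 1.0 - falls back to the provider default

voice.parameters["speed"] = 1.2
print(prov.voice_parameter(voice, "speed"))  # 1.2 - the per-voice value wins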

View File

@@ -0,0 +1,111 @@
from pathlib import Path
from typing import TYPE_CHECKING
import structlog
from .schema import TALEMATE_ROOT, Voice, VoiceProvider
from .voice_library import get_instance
if TYPE_CHECKING:
from talemate.tale_mate import Scene
log = structlog.get_logger("talemate.agents.tts.util")
__all__ = [
"voice_parameter",
"voice_is_talemate_asset",
"voice_is_scene_asset",
"get_voice",
]
def voice_parameter(
voice: Voice, provider: VoiceProvider, name: str
) -> str | float | int | bool | None:
"""
Get a parameter from the voice.
"""
if name in voice.parameters:
return voice.parameters[name]
return provider.default_parameters.get(name)
def voice_is_talemate_asset(
voice: Voice, provider: VoiceProvider
) -> tuple[bool, Path | None]:
"""
Check if the voice is a Talemate asset.
"""
if not provider.allow_file_upload:
return False, None
path = Path(voice.provider_id)
if not path.is_absolute():
path = TALEMATE_ROOT / path
try:
resolved = path.resolve(strict=False)
except Exception as e:
log.error(
"voice_is_talemate_asset - invalid path",
error=e,
voice_id=voice.provider_id,
)
return False, None
root = TALEMATE_ROOT.resolve()
log.debug(
"voice_is_talemate_asset - resolved", resolved=str(resolved), root=str(root)
)
if not str(resolved).startswith(str(root)):
return False, None
return True, resolved
def voice_is_scene_asset(voice: Voice, provider: VoiceProvider) -> bool:
"""
Check if the voice is a scene asset.
Scene assets are stored in the scene's assets directory.
This function does NOT check .is_scene_asset but does path resolution to
determine if the voice is a scene asset.
"""
is_talemate_asset, resolved = voice_is_talemate_asset(voice, provider)
if not is_talemate_asset:
return False
SCENES_DIR = TALEMATE_ROOT / "scenes"
if str(resolved).startswith(str(SCENES_DIR.resolve())):
return True
return False
def get_voice(scene: "Scene", voice_id: str) -> Voice | None:
"""Return a Voice by *voice_id* preferring the scene's library (if any).
Args:
scene: Scene instance or ``None``.
voice_id: The fully-qualified voice identifier (``provider:provider_id``).
The function first checks *scene.voice_library* (if present) and falls back
to the global voice library instance.
"""
try:
if scene and getattr(scene, "voice_library", None):
voice = scene.voice_library.get_voice(voice_id)
if voice:
return voice
except Exception as e:
log.error("get_voice - scene lookup failed", error=e)
try:
return get_instance().get_voice(voice_id)
except Exception as e:
log.error("get_voice - global lookup failed", error=e)
return None

View File

@@ -0,0 +1,169 @@
import structlog
from pathlib import Path
import pydantic
import talemate.emit.async_signals as async_signals
from .schema import VoiceLibrary, Voice
from typing import TYPE_CHECKING, Callable, Literal
if TYPE_CHECKING:
from talemate.tale_mate import Scene
__all__ = [
"load_voice_library",
"save_voice_library",
"get_instance",
"add_default_voices",
"DEFAULT_VOICES",
"VOICE_LIBRARY_PATH",
"require_instance",
"load_scene_voice_library",
"save_scene_voice_library",
"scoped_voice_library",
]
log = structlog.get_logger("talemate.agents.tts.voice_library")
async_signals.register(
"voice_library.update.before",
"voice_library.update.after",
)
VOICE_LIBRARY_PATH = (
Path(__file__).parent.parent.parent.parent.parent
/ "tts"
/ "voice"
/ "voice-library.json"
)
DEFAULT_VOICES = {}
# TODO: does this need to be made thread safe?
VOICE_LIBRARY = None
class ScopedVoiceLibrary(pydantic.BaseModel):
voice_library: VoiceLibrary
fn_save: Callable[[VoiceLibrary], None]
async def save(self):
await self.fn_save(self.voice_library)
def scoped_voice_library(
scope: Literal["global", "scene"], scene: "Scene | None" = None
) -> ScopedVoiceLibrary:
if scope == "global":
return ScopedVoiceLibrary(
voice_library=get_instance(), fn_save=save_voice_library
)
else:
if not scene:
raise ValueError("Scene is required for scoped voice library")
async def _save(library: VoiceLibrary):
await save_scene_voice_library(scene, library)
return ScopedVoiceLibrary(voice_library=scene.voice_library, fn_save=_save)
async def require_instance():
global VOICE_LIBRARY
if not VOICE_LIBRARY:
VOICE_LIBRARY = await load_voice_library()
return VOICE_LIBRARY
async def load_voice_library() -> VoiceLibrary:
"""
Load the voice library from the file.
"""
try:
with open(VOICE_LIBRARY_PATH, "r") as f:
return VoiceLibrary.model_validate_json(f.read())
except FileNotFoundError:
library = VoiceLibrary(voices=DEFAULT_VOICES)
await save_voice_library(library)
return library
finally:
log.debug("loaded voice library", path=str(VOICE_LIBRARY_PATH))
async def save_voice_library(voice_library: VoiceLibrary):
"""
Save the voice library to the file.
"""
await async_signals.get("voice_library.update.before").send(voice_library)
with open(VOICE_LIBRARY_PATH, "w") as f:
f.write(voice_library.model_dump_json(indent=2))
await async_signals.get("voice_library.update.after").send(voice_library)
def get_instance() -> VoiceLibrary:
"""
Get the shared voice library instance.
"""
if not VOICE_LIBRARY:
raise RuntimeError("Voice library not loaded yet.")
return VOICE_LIBRARY
def add_default_voices(voices: list[Voice]):
"""
Add default voices to the voice library.
"""
global DEFAULT_VOICES
for voice in voices:
DEFAULT_VOICES[voice.id] = voice
def voices_for_apis(apis: list[str], voice_library: VoiceLibrary) -> list[Voice]:
"""
Get the voices for the given apis.
"""
return [voice for voice in voice_library.voices.values() if voice.provider in apis]
def _scene_library_path(scene: "Scene") -> Path:
"""Return the path to the *scene* voice-library.json file."""
return Path(scene.info_dir) / "voice-library.json"
async def load_scene_voice_library(scene: "Scene") -> VoiceLibrary:
"""Load and return the voice library for *scene*.
If the file does not exist an empty ``VoiceLibrary`` instance is returned.
The returned instance is *not* stored on the scene; the caller decides where to keep it.
"""
path = _scene_library_path(scene)
try:
if path.exists():
with open(path, "r") as f:
library = VoiceLibrary.model_validate_json(f.read())
else:
library = VoiceLibrary()
except Exception as e:
log.error("load_scene_voice_library", error=e, path=str(path))
library = VoiceLibrary()
return library
async def save_scene_voice_library(scene: "Scene", library: VoiceLibrary):
"""Persist *library* to the scene's ``voice-library.json``.
The directory ``scene/{name}/info`` is created if necessary.
"""
path = _scene_library_path(scene)
try:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
f.write(library.model_dump_json(indent=2))
except Exception as e:
log.error("save_scene_voice_library", error=e, path=str(path))

View File

@@ -0,0 +1,674 @@
from __future__ import annotations
import asyncio
from pathlib import Path
import pydantic
import structlog
from typing import TYPE_CHECKING, Literal
import base64
import os
import re
import talemate.emit.async_signals as async_signals
from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
import talemate.scene_message as scene_message
from .voice_library import (
get_instance as get_voice_library,
save_voice_library,
scoped_voice_library,
ScopedVoiceLibrary,
)
from .schema import (
Voice,
GenerationContext,
Chunk,
APIStatus,
VoiceMixer,
VoiceWeight,
TALEMATE_ROOT,
VoiceLibrary,
)
from .util import voice_is_scene_asset
from .providers import provider
if TYPE_CHECKING:
from talemate.agents.tts import TTSAgent
from talemate.tale_mate import Scene
from talemate.character import Character
__all__ = [
"TTSWebsocketHandler",
]
log = structlog.get_logger("talemate.server.voice_library")
class EditVoicePayload(pydantic.BaseModel):
"""Payload for editing an existing voice. Only specified fields are updated."""
voice_id: str
scope: Literal["global", "scene"]
label: str
provider: str
provider_id: str
provider_model: str | None = None
tags: list[str] = pydantic.Field(default_factory=list)
parameters: dict[str, int | float | str | bool] = pydantic.Field(
default_factory=dict
)
class VoiceRefPayload(pydantic.BaseModel):
"""Payload referencing an existing voice by its id (used for remove / test)."""
voice_id: str
scope: Literal["global", "scene"]
class TestVoicePayload(pydantic.BaseModel):
"""Payload for testing a voice."""
provider: str
provider_id: str
provider_model: str | None = None
text: str | None = None
parameters: dict[str, int | float | str | bool] = pydantic.Field(
default_factory=dict
)
class TestCharacterVoicePayload(pydantic.BaseModel):
"""Payload for testing a character voice."""
character_name: str
text: str | None = None
class AddVoicePayload(Voice):
"""Explicit payload for adding a new voice - identical fields to Voice."""
scope: Literal["global", "scene"]
class TestMixedVoicePayload(pydantic.BaseModel):
"""Payload for testing a mixed voice."""
provider: str
voices: list[VoiceWeight]
class SaveMixedVoicePayload(pydantic.BaseModel):
"""Payload for saving a mixed voice."""
provider: str
label: str
voices: list[VoiceWeight]
tags: list[str] = pydantic.Field(default_factory=list)
class UploadVoiceFilePayload(pydantic.BaseModel):
"""Payload for uploading a new voice file for providers that support it."""
provider: str
label: str
content: str # Base64 data URL (e.g. data:audio/wav;base64,AAAB...)
as_scene_asset: bool = False
@pydantic.field_validator("content")
@classmethod
def _validate_content(cls, v: str):
if not v.startswith("data:") or ";base64," not in v:
raise ValueError("Content must be a base64 data URL")
return v
class GenerateForSceneMessagePayload(pydantic.BaseModel):
"""Payload for generating a voice for a scene message."""
message_id: int | Literal["intro"]
class TTSWebsocketHandler(Plugin):
"""Websocket plugin to manage the TTS voice library."""
router = "tts"
def __init__(self, websocket_handler):
super().__init__(websocket_handler)
# Immediately send current voice list to the frontend
asyncio.create_task(self._send_voice_list())
# ---------------------------------------------------------------------
# Events
# ---------------------------------------------------------------------
def connect(self):
# needs to be after config is saved so the TTS agent has already
# refreshed to the latest config
async_signals.get("config.changed.follow").connect(
self.on_app_config_change_followup
)
async def on_app_config_change_followup(self, event):
self._send_api_status()
# ---------------------------------------------------------------------
# Helper methods
# ---------------------------------------------------------------------
async def _send_voice_list(self, select_voice_id: str | None = None):
# global voice library
voice_library = get_voice_library()
voices = [v.model_dump() for v in voice_library.voices.values()]
voices.sort(key=lambda x: x["label"])
# scene voice library
if self.scene:
scene_voice_library = self.scene.voice_library
scene_voices = [v.model_dump() for v in scene_voice_library.voices.values()]
scene_voices.sort(key=lambda x: x["label"])
else:
scene_voices = []
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "voices",
"voices": voices,
"scene_voices": scene_voices,
"select_voice_id": select_voice_id,
}
)
def _voice_exists(self, voice_library: VoiceLibrary, voice_id: str) -> bool:
return voice_id in voice_library.voices
def _broadcast_update(self, select_voice_id: str | None = None):
# After any mutation we broadcast the full list for simplicity
asyncio.create_task(self._send_voice_list(select_voice_id))
def _send_api_status(self):
tts_agent: "TTSAgent" = get_agent("tts")
api_status: list[APIStatus] = tts_agent.api_status
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "api_status",
"api_status": [s.model_dump() for s in api_status],
}
)
# ---------------------------------------------------------------------
# Handlers
# ---------------------------------------------------------------------
async def handle_list(self, data: dict):
await self._send_voice_list()
async def handle_api_status(self, data: dict):
self._send_api_status()
async def handle_add(self, data: dict):
try:
voice = AddVoicePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
if voice.scope == "scene" and not self.scene:
await self.signal_operation_failed("No scene active")
return
scoped: ScopedVoiceLibrary = scoped_voice_library(voice.scope, self.scene)
voice.is_scene_asset = voice.scope == "scene"
if self._voice_exists(scoped.voice_library, voice.id):
await self.signal_operation_failed("Voice already exists")
return
scoped.voice_library.voices[voice.id] = voice
await scoped.save()
self._broadcast_update()
await self.signal_operation_done()
async def handle_remove(self, data: dict):
try:
payload = VoiceRefPayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
tts_agent: "TTSAgent" = get_agent("tts")
scoped: ScopedVoiceLibrary = scoped_voice_library(payload.scope, self.scene)
log.debug("Removing voice", voice_id=payload.voice_id, scope=payload.scope)
try:
voice = scoped.voice_library.voices.pop(payload.voice_id)
except KeyError:
await self.signal_operation_failed("Voice not found (1)")
return
provider = voice.provider
# check if the provider has a delete method
delete_method = getattr(tts_agent, f"{provider}_delete_voice", None)
if delete_method:
delete_method(voice)
await scoped.save()
self._broadcast_update()
await self.signal_operation_done()
async def handle_edit(self, data: dict):
try:
payload = EditVoicePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
scoped: ScopedVoiceLibrary = scoped_voice_library(payload.scope, self.scene)
voice = scoped.voice_library.voices.get(payload.voice_id)
if not voice:
await self.signal_operation_failed("Voice not found")
return
# all fields are always provided
voice.label = payload.label
voice.provider = payload.provider
voice.provider_id = payload.provider_id
voice.provider_model = payload.provider_model
voice.tags = payload.tags
voice.parameters = payload.parameters
voice.is_scene_asset = voice_is_scene_asset(voice, provider(voice.provider))
# If provider or provider_id changed, id changes -> reinsert
new_id = voice.id
if new_id != payload.voice_id:
# Remove old key, insert new
del scoped.voice_library.voices[payload.voice_id]
scoped.voice_library.voices[new_id] = voice
await scoped.save()
self._broadcast_update()
await self.signal_operation_done()
async def handle_test(self, data: dict):
"""Handle a request to test a voice.
Supports two payload formats:
1. Existing voice - identified by ``voice_id`` (legacy behaviour)
2. Unsaved voice - identified by at least ``provider`` and ``provider_id``.
"""
tts_agent: "TTSAgent" = get_agent("tts")
try:
payload = TestVoicePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
voice = Voice(
label=f"{payload.provider_id} (test)",
provider=payload.provider,
provider_id=payload.provider_id,
provider_model=payload.provider_model,
parameters=payload.parameters,
)
if not tts_agent or not tts_agent.api_ready(voice.provider):
await self.signal_operation_failed(f"API '{voice.provider}' not ready")
return
generate_fn = getattr(tts_agent, f"{voice.provider}_generate", None)
if not generate_fn:
await self.signal_operation_failed("Provider not supported by TTS agent")
return
prepare_fn = getattr(tts_agent, f"{voice.provider}_prepare_chunk", None)
# Use provided text or default
test_text = payload.text or "This is a test of the selected voice."
# Build minimal generation context
context = GenerationContext()
chunk = Chunk(
text=[test_text],
type="dialogue",
api=voice.provider,
voice=voice,
model=voice.provider_model,
generate_fn=generate_fn,
prepare_fn=prepare_fn,
character_name=None,
)
context.chunks.append(chunk)
# Run generation in background so we don't block the event loop
async def _run_test():
try:
await tts_agent.generate_chunks(context)
finally:
await self.signal_operation_done(signal_only=True)
asyncio.create_task(_run_test())
async def handle_test_character_voice(self, data: dict):
"""Handle a request to test a character voice."""
try:
payload = TestCharacterVoicePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
character = self.scene.get_character(payload.character_name)
if not character:
await self.signal_operation_failed("Character not found")
return
if not character.voice:
await self.signal_operation_failed("Character has no voice")
return
text: str = payload.text or "This is a test of the selected voice."
await self.handle_test(
{
"provider": character.voice.provider,
"provider_id": character.voice.provider_id,
"provider_model": character.voice.provider_model,
"parameters": character.voice.parameters,
"text": text,
}
)
async def handle_test_mixed(self, data: dict):
"""Handle a request to test a mixed voice."""
tts_agent: "TTSAgent" = get_agent("tts")
try:
payload = TestMixedVoicePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
# Validate that weights sum to 1.0
total_weight = sum(v.weight for v in payload.voices)
if abs(total_weight - 1.0) > 0.001:
await self.signal_operation_failed(
f"Weights must sum to 1.0, got {total_weight}"
)
return
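# e.g. two voices weighted 0.6 and 0.4 sum to 1.0 and pass this check;
# 0.5 and 0.6 (sum 1.1) would be rejected (illustrative values)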
if not tts_agent or not tts_agent.api_ready(payload.provider):
await self.signal_operation_failed(f"{payload.provider} API not ready")
return
# Build mixer
mixer = VoiceMixer(voices=payload.voices)
# Run test in background using the appropriate provider's test method
test_method = getattr(tts_agent, f"{payload.provider}_test_mix", None)
if not test_method:
await self.signal_operation_failed(
f"{payload.provider} does not implement voice mixing"
)
return
async def _run_test():
try:
await test_method(mixer)
finally:
await self.signal_operation_done(signal_only=True)
asyncio.create_task(_run_test())
async def handle_save_mixed(self, data: dict):
"""Handle a request to save a mixed voice."""
tts_agent: "TTSAgent" = get_agent("tts")
try:
payload = SaveMixedVoicePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
# Validate that weights sum to 1.0
total_weight = sum(v.weight for v in payload.voices)
if abs(total_weight - 1.0) > 0.001:
await self.signal_operation_failed(
f"Weights must sum to 1.0, got {total_weight}"
)
return
if not tts_agent or not tts_agent.api_ready(payload.provider):
await self.signal_operation_failed(f"{payload.provider} API not ready")
return
# Build mixer
mixer = VoiceMixer(voices=payload.voices)
# Create a unique voice id for the mixed voice
voice_id = f"{payload.label.lower().replace(' ', '-')}"
# Mix and save the voice using the appropriate provider's methods
save_method = getattr(tts_agent, f"{payload.provider}_save_mix", None)
if not save_method:
await self.signal_operation_failed(
f"{payload.provider} does not implement voice mixing"
)
return
try:
saved_path = await save_method(voice_id, mixer)
# voice id is the saved path relative to the talemate root
voice_id = str(saved_path.relative_to(TALEMATE_ROOT))
# Add the voice to the library
new_voice = Voice(
label=payload.label,
provider=payload.provider,
provider_id=voice_id,
tags=payload.tags,
mix=mixer,
)
voice_library = get_voice_library()
voice_library.voices[new_voice.id] = new_voice
await save_voice_library(voice_library)
self._broadcast_update(new_voice.id)
await self.signal_operation_done()
except Exception as e:
log.error("Failed to save mixed voice", error=e)
await self.signal_operation_failed(f"Failed to save mixed voice: {str(e)}")
async def handle_generate_for_scene_message(self, data: dict):
"""Handle a request to generate a voice for a scene message."""
tts_agent: "TTSAgent" = get_agent("tts")
scene: "Scene" = self.scene
log.debug("Generating TTS for scene message", data=data)
try:
payload = GenerateForSceneMessagePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
log.debug("Payload", payload=payload)
character: "Character | None" = None
text: str = ""
message: scene_message.SceneMessage | None = None
if payload.message_id == "intro":
text = scene.get_intro()
else:
message = scene.get_message(payload.message_id)
if not message:
await self.signal_operation_failed("Message not found")
return
if message.typ not in ["character", "narrator", "context_investigation"]:
await self.signal_operation_failed(
"Message is not a character, narrator, or context investigation message"
)
return
log.debug("Message type", message_type=message.typ)
if isinstance(message, scene_message.CharacterMessage):
character = scene.get_character(message.character_name)
if not character:
await self.signal_operation_failed("Character not found")
return
text = message.without_name
elif isinstance(message, scene_message.ContextInvestigationMessage):
text = message.message
else:
text = message.message
if not text:
await self.signal_operation_failed("No text to generate speech for.")
return
await tts_agent.generate(text, character, message=message)
await self.signal_operation_done()
async def handle_stop_and_clear(self, data: dict):
"""Handle a request from the frontend to stop and clear the current TTS queue."""
tts_agent: "TTSAgent" = get_agent("tts")
if not tts_agent:
await self.signal_operation_failed("TTS agent not available")
return
try:
await tts_agent.stop_and_clear_queue()
await self.signal_operation_done()
except Exception as e:
log.error("Failed to stop and clear TTS queue", error=e)
await self.signal_operation_failed(str(e))
async def handle_upload_voice_file(self, data: dict):
"""Handle uploading a new audio file for a voice.
The *provider* defines which MIME types it accepts via
``VoiceProvider.upload_file_types``. This method therefore:
1. Parses the data-URL to obtain the raw bytes **and** MIME type.
2. Verifies the MIME type against the provider's allowed list
(if the provider restricts uploads).
3. Stores the file under
``tts/voice/<provider>/<slug(label)>.<extension>``
where *extension* is derived from the MIME type (e.g. ``audio/wav`` → ``wav``).
4. Returns the relative path ("provider_id") back to the frontend so
it can populate the voice's ``provider_id`` field.
"""
try:
payload = UploadVoiceFilePayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
# Check provider allows file uploads
from .providers import provider as get_provider
P = get_provider(payload.provider)
if not P.allow_file_upload:
await self.signal_operation_failed(
f"Provider '{payload.provider}' does not support file uploads"
)
return
# Build filename from label
def slugify(text: str) -> str:
text = text.lower().strip()
text = re.sub(r"[^a-z0-9]+", "-", text)
return text.strip("-")
filename_no_ext = slugify(payload.label or "voice") or "voice"
# Determine media type and validate against provider
try:
header, b64data = payload.content.split(",", 1)
media_type = header.split(":", 1)[1].split(";", 1)[0]
except Exception:
await self.signal_operation_failed("Invalid data URL format")
return
if P.upload_file_types and media_type not in P.upload_file_types:
await self.signal_operation_failed(
f"File type '{media_type}' not allowed for provider '{payload.provider}'"
)
return
extension = media_type.split("/")[1]
filename = f"{filename_no_ext}.{extension}"
# Determine target directory and path
if not payload.as_scene_asset:
target_dir = P.default_voice_dir
else:
target_dir = Path(self.scene.assets.asset_directory) / "tts"
os.makedirs(target_dir, exist_ok=True)
target_path = target_dir / filename
log.debug(
"Target path",
target_path=target_path,
as_scene_asset=payload.as_scene_asset,
)
# Decode base64 data URL
try:
file_bytes = base64.b64decode(b64data)
except Exception as e:
await self.signal_operation_failed(f"Invalid base64 data: {e}")
return
try:
with open(target_path, "wb") as f:
f.write(file_bytes)
except Exception as e:
await self.signal_operation_failed(f"Failed to save file: {e}")
return
provider_id = str(target_path.relative_to(TALEMATE_ROOT))
# Send response back to frontend so it can set provider_id
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "voice_file_uploaded",
"provider_id": provider_id,
}
)
await self.signal_operation_done(signal_only=True)

View File

@@ -14,10 +14,9 @@ from talemate.agents.base import (
)
from talemate.agents.registry import register
from talemate.agents.editor.revision import RevisionDisabled
from talemate.agents.summarize.analyze_scene import SceneAnalysisDisabled
from talemate.client.base import ClientBase
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers as signal_handlers
from talemate.prompts.base import Prompt
from .commands import * # noqa
@@ -152,16 +151,13 @@ class VisualBase(Agent):
return actions
def __init__(self, client: ClientBase, *kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.client = client
self.is_enabled = False
self.backend_ready = False
self.initialized = False
self.config = load_config()
self.actions = VisualBase.init_actions()
signal_handlers["config_saved"].connect(self.on_config_saved)
@property
def enabled(self):
return self.is_enabled
@@ -231,6 +227,10 @@ class VisualBase(Agent):
or f"{self.backend_name} is not ready for processing",
).model_dump()
backend_detail_fn = getattr(self, f"{self.backend.lower()}_agent_details", None)
if backend_detail_fn:
details.update(backend_detail_fn())
return details
@property
@@ -241,11 +241,6 @@ class VisualBase(Agent):
def allow_automatic_generation(self):
return self.actions["automatic_generation"].enabled
def on_config_saved(self, event):
config = event.data
self.config = config
asyncio.create_task(self.emit_status())
async def on_ready_check_success(self):
prev_ready = self.backend_ready
self.backend_ready = True
@@ -406,7 +401,11 @@ class VisualBase(Agent):
f"data:image/png;base64,{image}"
)
character.cover_image = asset.id
self.scene.assets.cover_image = asset.id
# Only set scene cover image if scene doesn't already have one
if not self.scene.assets.cover_image:
self.scene.assets.cover_image = asset.id
self.scene.emit_status()
async def emit_image(self, image: str):
@@ -538,7 +537,7 @@ class VisualBase(Agent):
@set_processing
async def generate_environment_prompt(self, instructions: str = None):
with RevisionDisabled():
with RevisionDisabled(), SceneAnalysisDisabled():
response = await Prompt.request(
"visual.generate-environment-prompt",
self.client,
@@ -557,7 +556,7 @@ class VisualBase(Agent):
):
character = self.scene.get_character(character_name)
with RevisionDisabled():
with RevisionDisabled(), SceneAnalysisDisabled():
response = await Prompt.request(
"visual.generate-character-prompt",
self.client,

View File

@@ -10,7 +10,12 @@ import httpx
import pydantic
import structlog
from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig
from talemate.agents.base import (
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
)
from .handlers import register
from .schema import RenderSettings, Resolution
@@ -164,11 +169,16 @@ class ComfyUIMixin:
label="Checkpoint",
choices=[],
description="The main checkpoint to use.",
note="If the agent is enabled and connected, but the checkpoint list is empty, try closing this window and opening it again.",
),
},
)
}
@property
def comfyui_checkpoint(self):
return self.actions["comfyui"].config["checkpoint"].value
@property
def comfyui_workflow_filename(self):
base_name = self.actions["comfyui"].config["workflow"].value
@@ -219,10 +229,27 @@ class ComfyUIMixin:
async def comfyui_checkpoints(self):
loader_node = (await self.comfyui_object_info)["CheckpointLoaderSimple"]
_checkpoints = loader_node["input"]["required"]["ckpt_name"][0]
log.debug("comfyui_checkpoints", _checkpoints=_checkpoints)
return [
{"label": checkpoint, "value": checkpoint} for checkpoint in _checkpoints
]
def comfyui_agent_details(self):
checkpoint: str = self.comfyui_checkpoint
if not checkpoint:
return {}
# remove .safetensors
checkpoint = checkpoint.replace(".safetensors", "")
return {
"checkpoint": AgentDetail(
icon="mdi-brain",
value=checkpoint,
description="The checkpoint to use for comfyui",
).model_dump()
}
async def comfyui_get_image(self, filename: str, subfolder: str, folder_type: str):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)

View File

@@ -55,7 +55,7 @@ class OpenAIImageMixin:
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
return self.config.openai.api_key
@property
def openai_model_type(self):

View File

@@ -91,7 +91,7 @@ class VisualWebsocketHandler(Plugin):
await visual.generate_character_portrait(
payload.context.character_name,
payload.context.instructions,
replace=True,
replace=payload.context.replace,
prompt_only=payload.context.prompt_only,
)

View File

@@ -13,6 +13,7 @@ import talemate.util as util
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.instance import get_agent
from talemate.client import ClientBase
from talemate.prompts import Prompt
from talemate.scene_message import (
ReinforcementMessage,
@@ -125,7 +126,7 @@ class WorldStateAgent(CharacterProgressionMixin, Agent):
CharacterProgressionMixin.add_actions(actions)
return actions
def __init__(self, client, **kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.client = client
self.is_enabled = True
self.next_update = 0

View File

@@ -1,16 +1,539 @@
from typing import TYPE_CHECKING, Union
import pydantic
import structlog
import random
import re
import traceback
from talemate.instance import get_agent
import talemate.util as util
import talemate.instance as instance
import talemate.scene_message as scene_message
import talemate.agents.base as agent_base
from talemate.agents.tts.schema import Voice
import talemate.emit.async_signals as async_signals
if TYPE_CHECKING:
from talemate.tale_mate import Character, Scene
from talemate.tale_mate import Scene, Actor
__all__ = [
"Character",
"VoiceChangedEvent",
"deactivate_character",
"activate_character",
"set_voice",
]
log = structlog.get_logger("talemate.character")
async_signals.register("character.voice_changed")
class Character(pydantic.BaseModel):
# core character information
name: str
description: str = ""
greeting_text: str = ""
color: str = "#fff"
is_player: bool = False
memory_dirty: bool = False
cover_image: str | None = None
voice: Voice | None = None
# dialogue instructions and examples
dialogue_instructions: str | None = None
example_dialogue: list[str] = pydantic.Field(default_factory=list)
# attribute and detail storage
base_attributes: dict[str, str | int | float | bool] = pydantic.Field(
default_factory=dict
)
details: dict[str, str] = pydantic.Field(default_factory=dict)
# helpful references
agent: agent_base.Agent | None = pydantic.Field(default=None, exclude=True)
actor: "Actor | None" = pydantic.Field(default=None, exclude=True)
class Config:
arbitrary_types_allowed = True
@property
def gender(self) -> str:
return self.base_attributes.get("gender", "")
@property
def sheet(self) -> str:
sheet = self.base_attributes or {
"name": self.name,
"description": self.description,
}
sheet_list = []
for key, value in sheet.items():
sheet_list.append(f"{key}: {value}")
return "\n".join(sheet_list)
@property
def random_dialogue_example(self):
"""
Get a random example dialogue line for this character.
Returns:
str: The random example dialogue line.
"""
if not self.example_dialogue:
return ""
return random.choice(self.example_dialogue)
def __str__(self):
return f"Character: {self.name}"
def __repr__(self):
return str(self)
def __hash__(self):
return hash(self.name)
def set_color(self, color: str | None = None):
# if no color provided, choose a random color
if color is None:
color = util.random_color()
self.color = color
def set_cover_image(self, asset_id: str, initial_only: bool = False):
if self.cover_image and initial_only:
return
self.cover_image = asset_id
def sheet_filtered(self, *exclude):
sheet = self.base_attributes or {
"name": self.name,
"gender": self.gender,
"description": self.description,
}
sheet_list = []
for key, value in sheet.items():
if key not in exclude:
sheet_list.append(f"{key}: {value}")
return "\n".join(sheet_list)
def random_dialogue_examples(
self,
scene: "Scene",
num: int = 3,
strip_name: bool = False,
max_backlog: int = 250,
max_length: int = 192,
history_threshold: int = 15,
) -> list[str]:
"""
Get multiple random example dialogue lines for this character.
Returns up to `num` examples with no duplicates.
"""
if len(scene.history) < history_threshold and self.example_dialogue:
# when history is too short, just use the prepared
# examples
return self._random_dialogue_examples(num, strip_name)
history_examples = self._random_dialogue_examples_from_history(
scene, num, max_backlog
)
if len(history_examples) < num:
random_examples = self._random_dialogue_examples(
num - len(history_examples), strip_name
)
for example in random_examples:
history_examples.append(example)
# ensure sane example lengths
history_examples = [
util.strip_partial_sentences(example[:max_length])
for example in history_examples
]
log.debug("random_dialogue_examples", history_examples=history_examples)
return history_examples
def _random_dialogue_examples_from_history(
self, scene: "Scene", num: int = 3, max_backlog: int = 250
) -> list[str]:
"""
Get multiple random example dialogue lines for this character from the scene's history.
Checks the last `max_backlog` messages in the scene's history and returns up to `num` examples.
"""
history = scene.history[-max_backlog:]
examples = []
for message in history:
if not isinstance(message, scene_message.CharacterMessage):
continue
if message.character_name != self.name:
continue
examples.append(message.without_name.strip())
if not examples:
return []
return random.sample(examples, min(num, len(examples)))
def _random_dialogue_examples(
self, num: int = 3, strip_name: bool = False
) -> list[str]:
"""
Get multiple random example dialogue lines for this character.
Returns up to `num` examples with no duplicates.
"""
if not self.example_dialogue:
return []
# create a copy of example_dialogue so we don't modify the original
examples = self.example_dialogue.copy()
# shuffle the examples so we get a random order
random.shuffle(examples)
# now pop examples until we have `num` examples or we run out of examples
if strip_name:
examples = [example.split(":", 1)[1].strip() for example in examples]
return [examples.pop() for _ in range(min(num, len(examples)))]
def filtered_sheet(self, attributes: list[str]):
"""
Same as sheet but only returns the attributes in the given list.
Attributes that don't exist will be ignored.
"""
sheet_list = []
for key, value in self.base_attributes.items():
if key.lower() not in attributes:
continue
sheet_list.append(f"{key}: {value}")
return "\n".join(sheet_list)
def rename(self, new_name: str):
"""
Rename the character.
Args:
new_name (str): The new name of the character.
Returns:
None
"""
orig_name = self.name
self.name = new_name
if orig_name.lower() == "you":
# we don't want to replace "you" in the description
# or anywhere else so we can just return here
return
if self.description:
self.description = self.description.replace(f"{orig_name}", self.name)
for k, v in self.base_attributes.items():
if isinstance(v, str):
self.base_attributes[k] = v.replace(f"{orig_name}", self.name)
for i, v in list(self.details.items()):
if isinstance(v, str):
self.details[i] = v.replace(f"{orig_name}", self.name)
self.memory_dirty = True
def introduce_main_character(self, character: "Character"):
"""
Makes this character aware of the main character's name in the scene.
This will replace all occurrences of {{user}} (case-insensitive) in all of the character's properties
with the main character's name.
"""
properties = ["description", "greeting_text"]
pattern = re.compile(re.escape("{{user}}"), re.IGNORECASE)
for prop in properties:
prop_value = getattr(self, prop)
try:
updated_prop_value = pattern.sub(character.name, prop_value)
except Exception as e:
log.error(
"introduce_main_character",
error=e,
traceback=traceback.format_exc(),
)
updated_prop_value = prop_value
setattr(self, prop, updated_prop_value)
# also replace in all example dialogue
for i, dialogue in enumerate(self.example_dialogue):
self.example_dialogue[i] = pattern.sub(character.name, dialogue)
def update(self, **kwargs):
"""
Update character properties with given key-value pairs.
"""
for key, value in kwargs.items():
setattr(self, key, value)
self.memory_dirty = True
async def purge_from_memory(self):
"""
Purges this character's details from memory.
"""
memory_agent = instance.get_agent("memory")
await memory_agent.delete({"character": self.name})
log.info("purged character from memory", character=self.name)
async def commit_to_memory(self, memory_agent):
"""
Commits this character's details to the memory agent. (vectordb)
"""
items = []
if not self.base_attributes or "description" not in self.base_attributes:
if not self.description:
self.description = ""
description_chunks = [
chunk.strip() for chunk in self.description.split("\n") if chunk.strip()
]
for idx in range(len(description_chunks)):
chunk = description_chunks[idx]
items.append(
{
"text": f"{self.name}: {chunk}",
"id": f"{self.name}.description.{idx}",
"meta": {
"character": self.name,
"attr": "description",
"typ": "base_attribute",
},
}
)
seen_attributes = set()
for attr, value in self.base_attributes.items():
if attr.startswith("_"):
continue
if attr.lower() in ["name", "scenario_context", "_prompt", "_template"]:
continue
seen_attributes.add(attr)
items.append(
{
"text": f"{self.name}'s {attr}: {value}",
"id": f"{self.name}.{attr}",
"meta": {
"character": self.name,
"attr": attr,
"typ": "base_attribute",
},
}
)
for key, detail in self.details.items():
# if colliding with attribute name, prefix with detail_
if key in seen_attributes:
key = f"detail_{key}"
items.append(
{
"text": f"{self.name} - {key}: {detail}",
"id": f"{self.name}.{key}",
"meta": {
"character": self.name,
"typ": "details",
"detail": key,
},
}
)
if items:
await memory_agent.add_many(items)
self.memory_dirty = False
async def commit_single_attribute_to_memory(
self, memory_agent, attribute: str, value: str
):
"""
Commits a single attribute to memory
"""
items = []
# remove old attribute if it exists
await memory_agent.delete(
{"character": self.name, "typ": "base_attribute", "attr": attribute}
)
self.base_attributes[attribute] = value
items.append(
{
"text": f"{self.name}'s {attribute}: {self.base_attributes[attribute]}",
"id": f"{self.name}.{attribute}",
"meta": {
"character": self.name,
"attr": attribute,
"typ": "base_attribute",
},
}
)
log.debug("commit_single_attribute_to_memory", items=items)
await memory_agent.add_many(items)
async def commit_single_detail_to_memory(
self, memory_agent, detail: str, value: str
):
"""
Commits a single detail to memory
"""
items = []
# remove old detail if it exists
await memory_agent.delete(
{"character": self.name, "typ": "details", "detail": detail}
)
self.details[detail] = value
items.append(
{
"text": f"{self.name} - {detail}: {value}",
"id": f"{self.name}.{detail}",
"meta": {
"character": self.name,
"typ": "details",
"detail": detail,
},
}
)
log.debug("commit_single_detail_to_memory", items=items)
await memory_agent.add_many(items)
async def set_detail(self, name: str, value):
memory_agent = instance.get_agent("memory")
if not value:
try:
del self.details[name]
await memory_agent.delete(
{"character": self.name, "typ": "details", "detail": name}
)
except KeyError:
pass
else:
self.details[name] = value
await self.commit_single_detail_to_memory(memory_agent, name, value)
def set_detail_defer(self, name: str, value):
self.details[name] = value
self.memory_dirty = True
def get_detail(self, name: str):
return self.details.get(name)
async def set_base_attribute(self, name: str, value):
memory_agent = instance.get_agent("memory")
if not value:
try:
del self.base_attributes[name]
await memory_agent.delete(
{"character": self.name, "typ": "base_attribute", "attr": name}
)
except KeyError:
pass
else:
self.base_attributes[name] = value
await self.commit_single_attribute_to_memory(memory_agent, name, value)
def set_base_attribute_defer(self, name: str, value):
self.base_attributes[name] = value
self.memory_dirty = True
def get_base_attribute(self, name: str):
return self.base_attributes.get(name)
async def set_description(self, description: str):
memory_agent = instance.get_agent("memory")
self.description = description
items = []
await memory_agent.delete(
{"character": self.name, "typ": "base_attribute", "attr": "description"}
)
description_chunks = [
chunk.strip() for chunk in self.description.split("\n") if chunk.strip()
]
for idx in range(len(description_chunks)):
chunk = description_chunks[idx]
items.append(
{
"text": f"{self.name}: {chunk}",
"id": f"{self.name}.description.{idx}",
"meta": {
"character": self.name,
"attr": "description",
"typ": "base_attribute",
},
}
)
await memory_agent.add_many(items)
class VoiceChangedEvent(pydantic.BaseModel):
character: "Character"
voice: Voice | None
auto: bool = False
async def deactivate_character(scene: "Scene", character: Union[str, "Character"]):
"""
@@ -51,9 +574,18 @@ async def activate_character(scene: "Scene", character: Union[str, "Character"])
return False
if not character.is_player:
actor = scene.Actor(character, get_agent("conversation"))
actor = scene.Actor(character, instance.get_agent("conversation"))
else:
actor = scene.Player(character, None)
await scene.add_actor(actor)
del scene.inactive_characters[character.name]
async def set_voice(character: "Character", voice: Voice | None, auto: bool = False):
character.voice = voice
emission: VoiceChangedEvent = VoiceChangedEvent(
character=character, voice=voice, auto=auto
)
await async_signals.get("character.voice_changed").send(emission)
return emission

View File

@@ -9,10 +9,8 @@ from talemate.client.remote import (
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"AnthropicClient",
@@ -33,10 +31,13 @@ SUPPORTED_MODELS = [
"claude-opus-4-20250514",
]
DEFAULT_MODEL = "claude-3-5-sonnet-latest"
MIN_THINKING_TOKENS = 1024
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "claude-3-5-sonnet-latest"
model: str = DEFAULT_MODEL
double_coercion: str = None
@@ -52,7 +53,6 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
client_type = "anthropic"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
config_cls = ClientConfig
@@ -66,22 +66,13 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="claude-3-5-sonnet-latest", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
return not self.reason_enabled
@property
def anthropic_api_key(self):
return self.config.get("anthropic", {}).get("api_key")
return self.config.anthropic.api_key
@property
def supported_parameters(self):
@@ -92,17 +83,25 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
"max_tokens",
]
@property
def min_reason_tokens(self) -> int:
return MIN_THINKING_TOKENS
@property
def requires_reasoning_pattern(self) -> bool:
return False
def emit_status(self, processing: bool = None):
error_action = None
error_message: str | None = None
if processing is not None:
self.processing = processing
if self.anthropic_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -115,7 +114,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
@@ -124,73 +123,18 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
"double_coercion": self.double_coercion,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
if (
not self.anthropic_api_key
and not self.endpoint_override_base_url_configured
):
self.client = AsyncAnthropic(api_key="sk-1111")
log.error("No anthropic API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "claude-3-opus-20240229"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"anthropic set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return response.usage.output_tokens
@@ -200,13 +144,6 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
"""
Anthropic handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -218,17 +155,35 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
):
raise Exception("No anthropic API key set")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)
if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None
system_message = self.get_system_message(kind)
messages = [{"role": "user", "content": prompt.strip()}]
if coercion_prompt:
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
if self.reason_enabled:
parameters["thinking"] = {
"type": "enabled",
"budget_tokens": self.validated_reason_tokens,
}
# thinking doesn't support temperature, top_p, or top_k
# and the API will error if they are set
parameters.pop("temperature", None)
parameters.pop("top_p", None)
parameters.pop("top_k", None)
self.log.debug(
"generate",
model=self.model_name,
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
@@ -238,7 +193,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
prompt_tokens = 0
try:
stream = await self.client.messages.create(
stream = await client.messages.create(
model=self.model_name,
system=system_message,
messages=messages,
@@ -247,13 +202,25 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
)
response = ""
reasoning = ""
async for event in stream:
if event.type == "content_block_delta":
if (
event.type == "content_block_delta"
and event.delta.type == "text_delta"
):
content = event.delta.text
response += content
self.update_request_tokens(self.count_tokens(content))
elif (
event.type == "content_block_delta"
and event.delta.type == "thinking_delta"
):
content = event.delta.thinking
reasoning += content
self.update_request_tokens(self.count_tokens(content))
elif event.type == "message_start":
prompt_tokens = event.message.usage.input_tokens
@@ -262,8 +229,9 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
self._reasoning_response = reasoning
log.debug("generated response", response=response)
log.debug("generated response", response=response, reasoning=reasoning)
return response
except PermissionDeniedError as e:

View File

@@ -4,6 +4,7 @@ A unified client base, based on the openai API
import ipaddress
import logging
import re
import random
import time
import traceback
@@ -14,20 +15,27 @@ import pydantic
import dataclasses
import structlog
import urllib3
from openai import AsyncOpenAI, PermissionDeniedError
from openai import PermissionDeniedError
import talemate.client.presets as presets
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import active_agent
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
from talemate.client.model_prompts import model_prompt, DEFAULT_TEMPLATE
from talemate.client.ratelimit import CounterRateLimiter
from talemate.context import active_scene
from talemate.prompts.base import Prompt
from talemate.emit import emit
from talemate.config import load_config, save_config, EmbeddingFunctionPreset
from talemate.config import get_config, Config
from talemate.config.schema import EmbeddingFunctionPreset, Client as ClientConfig
import talemate.emit.async_signals as async_signals
from talemate.exceptions import SceneInactiveError, GenerationCancelled
from talemate.exceptions import (
SceneInactiveError,
GenerationCancelled,
GenerationProcessingError,
ReasoningResponseError,
)
import talemate.ux.schema as ux_schema
from talemate.client.system_prompts import SystemPrompts
@@ -43,6 +51,11 @@ STOPPING_STRINGS = ["<|im_end|>", "</s>"]
REPLACE_SMART_QUOTES = True
INDIRECT_COERCION_PROMPT = "\nStart your response with: "
DEFAULT_REASONING_PATTERN = r".*?</think>"
class ClientDisabledError(OSError):
def __init__(self, client: "ClientBase"):
self.client = client
@@ -63,6 +76,7 @@ class PromptData(pydantic.BaseModel):
generation_parameters: dict = pydantic.Field(default_factory=dict)
inference_preset: str = None
preset_group: str | None = None
reasoning: str | None = None
class ErrorAction(pydantic.BaseModel):
@@ -76,6 +90,9 @@ class CommonDefaults(pydantic.BaseModel):
rate_limit: int | None = None
data_format: Literal["yaml", "json"] | None = None
preset_group: str | None = None
reason_enabled: bool = False
reason_tokens: int = 0
reason_response_pattern: str | None = None
class Defaults(CommonDefaults, pydantic.BaseModel):
@@ -99,6 +116,7 @@ class ExtraField(pydantic.BaseModel):
description: str
group: FieldGroup | None = None
note: ux_schema.Note | None = None
choices: list[str | int | float | bool] | None = None
class ParameterReroute(pydantic.BaseModel):
@@ -162,39 +180,36 @@ class RequestInformation(pydantic.BaseModel):
class ClientEmbeddingsStatus:
client: "ClientBase | None" = None
embedding_name: str | None = None
seen: bool = False
@dataclasses.dataclass
class ClientStatus:
client: "ClientBase | None" = None
enabled: bool = False
async_signals.register(
"client.embeddings_available",
"client.enabled",
"client.disabled",
)
class ClientBase:
api_url: str
model_name: str
api_key: str = None
name: str = None
enabled: bool = True
name: str
remote_model_name: str | None = None
remote_model_locked: bool = False
current_status: str = None
max_token_length: int = 8192
processing: bool = False
connected: bool = False
conversation_retries: int = 0
auto_break_repetition_enabled: bool = True
decensor_enabled: bool = True
auto_determine_prompt_template: bool = False
finalizers: list[str] = []
double_coercion: Union[str, None] = None
data_format: Literal["yaml", "json"] | None = None
rate_limit: int | None = None
client_type = "base"
request_information: RequestInformation | None = None
status_request_timeout: int = 2
system_prompts: SystemPrompts = SystemPrompts()
preset_group: str | None = ""
rate_limit_counter: CounterRateLimiter = None
class Meta(pydantic.BaseModel):
@@ -207,27 +222,92 @@ class ClientBase:
def __init__(
self,
api_url: str = None,
name: str = None,
**kwargs,
):
self.api_url = api_url
self.name = name or self.client_type
self.remote_model_name = None
self.auto_determine_prompt_template_attempt = None
self.log = structlog.get_logger(f"client.{self.client_type}")
self.double_coercion = kwargs.get("double_coercion", None)
self._reconfigure_common_parameters(**kwargs)
self.enabled = kwargs.get("enabled", True)
if "max_token_length" in kwargs:
self.max_token_length = (
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
)
self.set_client(max_token_length=self.max_token_length)
def __str__(self):
return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"
#####
# config getters
@property
def config(self) -> Config:
return get_config()
@property
def client_config(self) -> ClientConfig:
try:
return get_config().clients[self.name]
except KeyError:
return ClientConfig(type=self.client_type, name=self.name)
@property
def model(self) -> str | None:
return self.client_config.model
@property
def model_name(self) -> str | None:
if self.remote_model_locked:
return self.remote_model_name
return self.remote_model_name or self.model
@property
def api_key(self) -> str | None:
return self.client_config.api_key
@property
def api_url(self) -> str | None:
return self.client_config.api_url
@property
def max_token_length(self) -> int:
return self.client_config.max_token_length
@property
def double_coercion(self) -> str | None:
return self.client_config.double_coercion
@property
def rate_limit(self) -> int | None:
return self.client_config.rate_limit
@property
def data_format(self) -> Literal["yaml", "json"]:
return self.client_config.data_format
@property
def enabled(self) -> bool:
return self.client_config.enabled
@property
def system_prompts(self) -> SystemPrompts:
return self.client_config.system_prompts
@property
def preset_group(self) -> str | None:
return self.client_config.preset_group
@property
def reason_enabled(self) -> bool:
return self.client_config.reason_enabled
@property
def reason_tokens(self) -> int:
return self.client_config.reason_tokens
@property
def reason_response_pattern(self) -> str:
return self.client_config.reason_response_pattern or DEFAULT_REASONING_PATTERN
#####
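# Sketch of intended usage (assumption): settings are read through these
# properties on every access (e.g. self.client_config.api_key, self.enabled)
# rather than copied onto the instance, so a newly saved config is picked up
# on the next read without rebuilding the client.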
@property
def experimental(self):
return False
@@ -238,6 +318,9 @@ class ClientBase:
Determines whether or not this client can pass LLM coercion. (e.g., is able
to predefine partial LLM output in the prompt)
"""
if self.reason_enabled:
# We are not able to coerce via pre-filling if reasoning is enabled
return False
return self.Meta().requires_prompt_template
@property
@@ -283,7 +366,49 @@ class ClientBase:
def embeddings_identifier(self) -> str:
return f"client-api/{self.name}/{self.embeddings_model_name}"
async def destroy(self, config: dict):
@property
def reasoning_response(self) -> str | None:
return getattr(self, "_reasoning_response", None)
@property
def min_reason_tokens(self) -> int:
return 0
@property
def validated_reason_tokens(self) -> int:
return max(self.reason_tokens, self.min_reason_tokens)
@property
def default_prompt_template(self) -> str:
return DEFAULT_TEMPLATE
@property
def requires_reasoning_pattern(self) -> bool:
return True
async def enable(self):
self.client_config.enabled = True
self.emit_status()
await self.config.set_dirty()
await self.status()
await async_signals.get("client.enabled").send(
ClientStatus(client=self, enabled=True)
)
async def disable(self):
self.client_config.enabled = False
self.emit_status()
if self.supports_embeddings:
await self.reset_embeddings()
await self.config.set_dirty()
await self.status()
await async_signals.get("client.disabled").send(
ClientStatus(client=self, enabled=False)
)
async def destroy(self):
"""
This is called before the client is removed from talemate.instance.clients
@@ -294,16 +419,13 @@ class ClientBase:
"""
if self.supports_embeddings:
self.remove_embeddings(config)
await self.remove_embeddings()
def reset_embeddings(self):
async def reset_embeddings(self):
self._embeddings_model_name = None
self._embeddings_status = False
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def set_embeddings(self):
async def set_embeddings(self):
log.debug(
"setting embeddings",
client=self.name,
@@ -314,7 +436,7 @@ class ClientBase:
if not self.supports_embeddings or not self.embeddings_status:
return
config = load_config(as_model=True)
config: Config = get_config()
key = self.embeddings_identifier
@@ -334,30 +456,25 @@ class ClientBase:
custom=True,
)
save_config(config)
await config.set_dirty()
def remove_embeddings(self, config: dict | None = None):
async def remove_embeddings(self):
# remove all embeddings for this client
for key, value in list(config["presets"]["embeddings"].items()):
if value["client"] == self.name and value["embeddings"] == "client-api":
config: Config = get_config()
for key, value in list(config.presets.embeddings.items()):
if value.client == self.name and value.embeddings == "client-api":
log.warning("!!! removing embeddings", client=self.name, key=key)
config["presets"]["embeddings"].pop(key)
def set_system_prompts(self, system_prompts: dict | SystemPrompts):
if isinstance(system_prompts, dict):
self.system_prompts = SystemPrompts(**system_prompts)
elif not isinstance(system_prompts, SystemPrompts):
raise ValueError(
"system_prompts must be a `dict` or `SystemPrompts` instance"
)
else:
self.system_prompts = system_prompts
config.presets.embeddings.pop(key)
await config.set_dirty()
def prompt_template(self, sys_msg: str, prompt: str):
"""
Applies the appropriate prompt template for the model.
"""
if not self.Meta().requires_prompt_template:
return prompt
if not self.model_name:
self.log.warning("prompt template not applied", reason="no model loaded")
return f"{sys_msg}\n{prompt}"
@@ -372,13 +489,22 @@ class ClientBase:
else:
double_coercion = None
return model_prompt(self.model_name, sys_msg, prompt, double_coercion)[0]
return model_prompt(
self.model_name,
sys_msg,
prompt,
double_coercion,
default_template=self.default_prompt_template,
)[0]
def prompt_template_example(self):
if not getattr(self, "model_name", None):
return None, None
return model_prompt(
self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
self.model_name,
"{sysmsg}",
"{prompt}<|BOT|>{LLM coercion}",
default_template=self.default_prompt_template,
)
def split_prompt_for_coercion(self, prompt: str) -> tuple[str, str]:
@@ -386,59 +512,29 @@ class ClientBase:
Splits the prompt and the prefill/coercion prompt.
"""
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
prompt, coercion = prompt.split("<|BOT|>", 1)
if self.double_coercion:
right = f"{self.double_coercion}\n\n{right}"
coercion = f"{self.double_coercion}\n\n{coercion}"
return prompt, right
return prompt, coercion
return prompt, None
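# Example (sketch, illustrative values): "Write the reply.<|BOT|>Certainly,"
# splits into ("Write the reply.", "Certainly,"); with double_coercion set to
# "Sure thing." the coercion part becomes "Sure thing.\n\nCertainly,".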
def reconfigure(self, **kwargs):
def rate_limit_update(self):
"""
Reconfigures the client.
Updates the rate limit counter for the client.
Keyword Arguments:
- api_url: the API URL to use
- max_token_length: the max token length to use
- enabled: whether the client is enabled
If the rate limit is set to 0, the rate limit counter is set to None.
"""
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if kwargs.get("max_token_length"):
self.max_token_length = int(kwargs["max_token_length"])
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if not self.enabled and self.supports_embeddings and self.embeddings_status:
self.reset_embeddings()
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
def _reconfigure_common_parameters(self, **kwargs):
if "rate_limit" in kwargs:
self.rate_limit = kwargs["rate_limit"]
if self.rate_limit:
if not self.rate_limit_counter:
self.rate_limit_counter = CounterRateLimiter(
rate_per_minute=self.rate_limit
)
else:
self.rate_limit_counter.update_rate_limit(self.rate_limit)
if self.rate_limit:
if not self.rate_limit_counter:
self.rate_limit_counter = CounterRateLimiter(
rate_per_minute=self.rate_limit
)
else:
self.rate_limit_counter = None
if "data_format" in kwargs:
self.data_format = kwargs["data_format"]
if "preset_group" in kwargs:
self.preset_group = kwargs["preset_group"]
self.rate_limit_counter.update_rate_limit(self.rate_limit)
else:
self.rate_limit_counter = None
def host_is_remote(self, url: str) -> bool:
"""
@@ -491,43 +587,40 @@ class ClientBase:
- kind: the kind of generation
"""
app_config_system_prompts = client_context_attribute(
"app_config_system_prompts"
)
if app_config_system_prompts:
self.system_prompts.parent = SystemPrompts(**app_config_system_prompts)
return self.system_prompts.get(kind, self.decensor_enabled)
config: Config = get_config()
self.system_prompts.parent = config.system_prompts
sys_prompt = self.system_prompts.get(kind, self.decensor_enabled)
return sys_prompt
def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
error_message: str | None = None
if processing is not None:
self.processing = processing
if not self.enabled:
status = "disabled"
model_name = "Disabled"
error_message = "Disabled"
elif not self.connected:
status = "error"
model_name = "Could not connect"
error_message = "Could not connect"
elif self.model_name:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
model_name = "No model loaded"
error_message = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
default_prompt_template = self.default_prompt_template
prompt_template_example, prompt_template_file = self.prompt_template_example()
has_prompt_template = (
prompt_template_file and prompt_template_file != "default.jinja2"
prompt_template_file and prompt_template_file != default_prompt_template
)
if not has_prompt_template and self.auto_determine_prompt_template:
@@ -545,21 +638,28 @@ class ClientBase:
self.prompt_template_example()
)
has_prompt_template = (
prompt_template_file and prompt_template_file != "default.jinja2"
prompt_template_file
and prompt_template_file != default_prompt_template
)
dedicated_default_template = default_prompt_template != DEFAULT_TEMPLATE
data = {
"api_key": self.api_key,
"prompt_template_example": prompt_template_example,
"has_prompt_template": has_prompt_template,
"dedicated_default_template": dedicated_default_template,
"template_file": prompt_template_file,
"meta": self.Meta().model_dump(),
"error_action": None,
"double_coercion": self.double_coercion,
"enabled": self.enabled,
"system_prompts": self.system_prompts.model_dump(),
"error_message": error_message,
}
if self.Meta().enable_api_auth:
data["api_key"] = self.api_key
data.update(self._common_status_data())
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
@@ -571,7 +671,7 @@ class ClientBase:
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status,
data=data,
)
@@ -595,6 +695,11 @@ class ClientBase:
"supports_embeddings": self.supports_embeddings,
"embeddings_status": self.embeddings_status,
"embeddings_model_name": self.embeddings_model_name,
"reason_enabled": self.reason_enabled,
"reason_tokens": self.reason_tokens,
"min_reason_tokens": self.min_reason_tokens,
"reason_response_pattern": self.reason_response_pattern,
"requires_reasoning_pattern": self.requires_reasoning_pattern,
"request_information": self.request_information.model_dump()
if self.request_information
else None,
@@ -646,20 +751,16 @@ class ClientBase:
return
try:
self.model_name = await self.get_model_name()
self.remote_model_name = await self.get_model_name()
except Exception as e:
self.log.warning("client status error", e=e, client=self.name)
self.model_name = None
self.remote_model_name = None
self.connected = False
self.emit_status()
return
self.connected = True
if not self.model_name or self.model_name == "None":
self.emit_status()
return
self.emit_status()
def generate_prompt_parameters(self, kind: str):
@@ -682,6 +783,15 @@ class ClientBase:
parameters, kind, agent_context.action
)
if self.reason_enabled and self.reason_tokens > 0:
log.debug(
"padding for reasoning",
client=self.client_type,
reason_tokens=self.reason_tokens,
validated_reason_tokens=self.validated_reason_tokens,
)
parameters["max_tokens"] += self.validated_reason_tokens
if client_context_attribute(
"nuke_repetition"
) > 0.0 and self.jiggle_enabled_for(kind):
@@ -838,12 +948,89 @@ class ClientBase:
else:
self.request_information.tokens += tokens
def strip_coercion_prompt(self, response: str, coercion_prompt: str = None) -> str:
"""
Strips the coercion prompt from the response if it is present.
"""
if not coercion_prompt or not response.startswith(coercion_prompt):
return response
return response.replace(coercion_prompt, "").lstrip()
def strip_reasoning(self, response: str) -> tuple[str, str]:
"""
Strips the reasoning from the response if the model is reasoning.
"""
if not self.reason_enabled:
return response, None
if not self.requires_reasoning_pattern:
# reasoning handled automatically during streaming
return response, None
pattern = self.reason_response_pattern
if not pattern:
pattern = DEFAULT_REASONING_PATTERN
log.debug("reasoning pattern", pattern=pattern)
extract_reason = re.search(pattern, response, re.DOTALL)
if extract_reason:
reasoning_response = extract_reason.group(0)
return response.replace(reasoning_response, ""), reasoning_response
raise ReasoningResponseError()
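# Example (sketch): with the default pattern r".*?</think>" and re.DOTALL, a
# response like "<think>working through it...</think>\nFinal answer." returns
# ("\nFinal answer.", "<think>working through it...</think>").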
def attach_response_length_instruction(
self, prompt: str, response_length: int | None
) -> str:
"""
Attaches the response length instruction to the prompt.
"""
if not response_length or response_length < 0:
log.warning("response length instruction", response_length=response_length)
return prompt
instructions_prompt = Prompt.get(
"common.response-length",
vars={
"response_length": response_length,
"attach_response_length_instruction": True,
},
)
instructions_prompt = instructions_prompt.render()
if instructions_prompt.strip() in prompt:
log.debug(
"response length instruction already in prompt",
instructions_prompt=instructions_prompt,
)
return prompt
log.debug(
"response length instruction", instructions_prompt=instructions_prompt
)
if "<|RESPONSE_LENGTH_INSTRUCTIONS|>" in prompt:
return prompt.replace(
"<|RESPONSE_LENGTH_INSTRUCTIONS|>", instructions_prompt
)
elif "<|BOT|>" in prompt:
return prompt.replace("<|BOT|>", f"{instructions_prompt}<|BOT|>")
else:
return f"{prompt}{instructions_prompt}"
async def send_prompt(
self,
prompt: str,
kind: str = "conversation",
finalize: Callable = lambda x: x,
retries: int = 2,
data_expected: bool | None = None,
) -> str:
"""
Send a prompt to the AI and return its response.
@@ -852,7 +1039,9 @@ class ClientBase:
"""
try:
return await self._send_prompt(prompt, kind, finalize, retries)
return await self._send_prompt(
prompt, kind, finalize, retries, data_expected
)
except GenerationCancelled:
await self.abort_generation()
raise
@@ -863,6 +1052,7 @@ class ClientBase:
kind: str = "conversation",
finalize: Callable = lambda x: x,
retries: int = 2,
data_expected: bool | None = None,
) -> str:
"""
Send a prompt to the AI and return its response.
@@ -871,6 +1061,7 @@ class ClientBase:
"""
try:
self.rate_limit_update()
if self.rate_limit_counter:
aborted: bool = False
while not self.rate_limit_counter.increment():
@@ -927,12 +1118,27 @@ class ClientBase:
try:
self._returned_prompt_tokens = None
self._returned_response_tokens = None
self._reasoning_response = None
self.emit_status(processing=True)
await self.status()
prompt_param = self.generate_prompt_parameters(kind)
if self.reason_enabled and not data_expected:
prompt = self.attach_response_length_instruction(
prompt,
(prompt_param.get(self.max_tokens_param_name) or 0)
- self.reason_tokens,
)
if not self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
if coercion_prompt:
prompt += f"{INDIRECT_COERCION_PROMPT}{coercion_prompt}"
else:
coercion_prompt = None
finalized_prompt = self.prompt_template(
self.get_system_message(kind), prompt
).strip(" ")
@@ -954,11 +1160,26 @@ class ClientBase:
max_token_length=self.max_token_length,
parameters=prompt_param,
)
prompt_sent = self.repetition_adjustment(finalized_prompt)
if "<|RESPONSE_LENGTH_INSTRUCTIONS|>" in finalized_prompt:
finalized_prompt = finalized_prompt.replace(
"\n<|RESPONSE_LENGTH_INSTRUCTIONS|>", ""
)
self.new_request()
response = await self._cancelable_generate(prompt_sent, prompt_param, kind)
response = await self._cancelable_generate(
finalized_prompt, prompt_param, kind
)
response, reasoning_response = self.strip_reasoning(response)
if reasoning_response:
self._reasoning_response = reasoning_response
if coercion_prompt:
response = self.process_response_for_indirect_coercion(
finalized_prompt, response, coercion_prompt
)
self.end_request()
@@ -966,14 +1187,13 @@ class ClientBase:
# generation was cancelled
raise response
# response = await self.generate(prompt_sent, prompt_param, kind)
response, finalized_prompt = await self.auto_break_repetition(
finalized_prompt, prompt_param, response, kind, retries
)
if REPLACE_SMART_QUOTES:
response = response.replace("\u201c", '"').replace("\u201d", '"')
response = (
response.replace("\u201c", '"')
.replace("\u201d", '"')
.replace("\u2018", "'")
.replace("\u2019", "'")
)
time_end = time.time()
@@ -991,7 +1211,7 @@ class ClientBase:
"prompt_sent",
data=PromptData(
kind=kind,
prompt=prompt_sent,
prompt=finalized_prompt,
response=response,
prompt_tokens=self._returned_prompt_tokens or token_length,
response_tokens=self._returned_response_tokens
@@ -1003,12 +1223,17 @@ class ClientBase:
generation_parameters=prompt_param,
inference_preset=client_context_attribute("inference_preset"),
preset_group=self.preset_group,
reasoning=self._reasoning_response,
).model_dump(),
)
return response
except GenerationCancelled:
raise
except GenerationProcessingError as e:
self.log.error("send_prompt error", e=e)
emit("status", message=str(e), status="error")
return ""
except Exception:
self.log.error("send_prompt error", e=traceback.format_exc())
emit(
@@ -1023,130 +1248,6 @@ class ClientBase:
if self.rate_limit_counter:
self.rate_limit_counter.increment()
async def auto_break_repetition(
self,
finalized_prompt: str,
prompt_param: dict,
response: str,
kind: str,
retries: int,
pad_max_tokens: int = 32,
) -> str:
"""
If repetition breaking is enabled, this will retry the prompt if its
response is too similar to other messages in the prompt
This requires the agent to have the allow_repetition_break method
and the jiggle_enabled_for method and the client to have the
auto_break_repetition_enabled attribute set to True
Arguments:
- finalized_prompt: the prompt that was sent
- prompt_param: the parameters that were used
- response: the response that was received
- kind: the kind of generation
- retries: the number of retries left
- pad_max_tokens: increase response max_tokens by this amount per iteration
Returns:
- the response
"""
if not self.auto_break_repetition_enabled or not response.strip():
return response, finalized_prompt
agent_context = active_agent.get()
if self.jiggle_enabled_for(kind, auto=True):
# check if the response is a repetition
# using the default similarity threshold of 98, meaning it needs
# to be really similar to be considered a repetition
is_repetition, similarity_score, matched_line = util.similarity_score(
response, finalized_prompt.split("\n"), similarity_threshold=80
)
if not is_repetition:
# not a repetition, return the response
self.log.debug(
"send_prompt no similarity", similarity_score=similarity_score
)
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=False
)
return response, finalized_prompt
while is_repetition and retries > 0:
# it's a repetition, retry the prompt with adjusted parameters
self.log.warn(
"send_prompt similarity retry",
agent=agent_context.agent.agent_type,
similarity_score=similarity_score,
retries=retries,
)
# first we apply the client's randomness jiggle which will adjust
# parameters like temperature and repetition_penalty, depending
# on the client
#
# this is a cumulative adjustment, so it will add to the previous
# iteration's adjustment, this also means retries should be kept low
# otherwise it will get out of hand and start generating nonsense
self.jiggle_randomness(prompt_param, offset=0.5)
# then we pad the max_tokens by the pad_max_tokens amount
prompt_param[self.max_tokens_param_name] += pad_max_tokens
# send the prompt again
# we use the repetition_adjustment method to further encourage
# the AI to break the repetition on its own as well.
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=True
)
response = retried_response = await self.generate(
finalized_prompt, prompt_param, kind
)
self.log.debug(
"send_prompt dedupe sentences",
response=response,
matched_line=matched_line,
)
# a lot of the time the response will now contain the repetition + something new
# so we dedupe the response to remove the repetition on sentences level
response = util.dedupe_sentences(
response, matched_line, similarity_threshold=85, debug=True
)
self.log.debug(
"send_prompt dedupe sentences (after)", response=response
)
# deduping may have removed the entire response, so we check for that
if not util.strip_partial_sentences(response).strip():
# if the response is empty, we set the response to the original
# and try again next loop
response = retried_response
# check if the response is a repetition again
is_repetition, similarity_score, matched_line = util.similarity_score(
response, finalized_prompt.split("\n"), similarity_threshold=80
)
retries -= 1
return response, finalized_prompt
def count_tokens(self, content: str):
return util.count_tokens(content)
@@ -1169,31 +1270,9 @@ class ClientBase:
return agent.allow_repetition_break(kind, agent_context.action, auto=auto)
def repetition_adjustment(self, prompt: str, is_repetitive: bool = False):
"""
Breaks the prompt into lines and checks each line for a match with
[$REPETITION|{repetition_adjustment}].
On match and if is_repetitive is True, the line is removed from the prompt and
replaced with the repetition_adjustment.
On match and if is_repetitive is False, the line is removed from the prompt.
"""
lines = prompt.split("\n")
new_lines = []
for line in lines:
if line.startswith("[$REPETITION|"):
if is_repetitive:
new_lines.append(line.split("|")[1][:-1])
else:
new_lines.append("")
else:
new_lines.append(line)
return "\n".join(new_lines)
def process_response_for_indirect_coercion(self, prompt: str, response: str) -> str:
def process_response_for_indirect_coercion(
self, prompt: str, response: str, coercion_prompt: str
) -> str:
"""
A lot of remote APIs don't let us control the prompt template and we cannot directly
append the beginning of the desired response to the prompt.
@@ -1202,13 +1281,19 @@ class ClientBase:
and then hopefully it will adhere to it and we can strip it off the actual response.
"""
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response and expected_response.startswith("{"):
if coercion_prompt and coercion_prompt.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
log.debug(
"process_response_for_indirect_coercion",
response=f"|{response[:100]}...|",
coercion_prompt=f"|{coercion_prompt}|",
)
if coercion_prompt and response.startswith(coercion_prompt):
response = response[len(coercion_prompt) :].strip()
elif coercion_prompt and response.lstrip().startswith(coercion_prompt):
response = response.lstrip()[len(coercion_prompt) :].strip()
return response
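The refactored flow above carries the coercion prompt alongside the request instead of re-parsing it out of "Start your response with: " after the fact. A minimal sketch of just the stripping step in isolation (the helper name is illustrative, not part of the codebase):

def strip_coercion_prefix(response: str, coercion_prompt: str) -> str:
    # if the model echoed the coerced prefix, drop it from the visible response
    if coercion_prompt and response.lstrip().startswith(coercion_prompt):
        return response.lstrip()[len(coercion_prompt):].strip()
    return response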


@@ -15,9 +15,8 @@ from talemate.client.remote import (
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens
__all__ = [
@@ -54,7 +53,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
client_type = "cohere"
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig
@@ -67,18 +65,9 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
defaults: Defaults = Defaults()
def __init__(self, model="command-r-plus", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def cohere_api_key(self):
return self.config.get("cohere", {}).get("api_key")
return self.config.cohere.api_key
@property
def supported_parameters(self):
@@ -96,15 +85,15 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing
if self.cohere_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -117,7 +106,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
@@ -125,67 +114,18 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncClientV2("sk-1111")
log.error("No cohere API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "command-r-plus"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncClientV2(self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"cohere set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return count_tokens(response)
@@ -195,16 +135,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
@@ -228,13 +158,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No cohere API key set")
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
client = AsyncClientV2(self.api_key, base_url=self.base_url)
human_message = prompt.strip()
system_message = self.get_system_message(kind)
@@ -263,7 +187,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
# manager, so attempting to use `async with` raises a `TypeError` as seen
# in issue logs. We therefore iterate over the generator directly.
stream = self.client.chat_stream(
stream = client.chat_stream(
model=self.model_name,
messages=messages,
**parameters,
@@ -283,13 +207,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
log.debug("generated response", response=response)
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
# except PermissionDeniedError as e:
# self.log.error("generate error", e=e)


@@ -4,9 +4,7 @@ from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens
__all__ = [
@@ -40,7 +38,6 @@ class DeepSeekClient(ClientBase):
client_type = "deepseek"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
@@ -52,17 +49,9 @@ class DeepSeekClient(ClientBase):
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="deepseek-chat", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def deepseek_api_key(self):
return self.config.get("deepseek", {}).get("api_key")
return self.config.deepseek.api_key
@property
def supported_parameters(self):
@@ -75,15 +64,15 @@ class DeepSeekClient(ClientBase):
def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing
if self.deepseek_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -96,7 +85,7 @@ class DeepSeekClient(ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
@@ -104,66 +93,18 @@ class DeepSeekClient(ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
if not self.deepseek_api_key:
self.client = AsyncOpenAI(api_key="sk-1111", base_url=BASE_URL)
log.error("No DeepSeek API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "deepseek-chat"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncOpenAI(api_key=self.deepseek_api_key, base_url=BASE_URL)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"deepseek set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def count_tokens(self, content: str):
if not self.model_name:
return 0
@@ -172,18 +113,6 @@ class DeepSeekClient(ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
# only gpt-4-1106-preview supports json_object response coercion
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def response_tokens(self, response: str):
# Count tokens in a response string using the util.count_tokens helper
return self.count_tokens(response)
@@ -200,20 +129,7 @@ class DeepSeekClient(ClientBase):
if not self.deepseek_api_key:
raise Exception("No DeepSeek API key set")
# only gpt-4-* supports enforcing json object
supports_json_object = (
self.model_name.startswith("gpt-4-")
or self.model_name in JSON_OBJECT_RESPONSE_MODELS
)
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
client = AsyncOpenAI(api_key=self.deepseek_api_key, base_url=BASE_URL)
human_message = {"role": "user", "content": prompt.strip()}
system_message = {"role": "system", "content": self.get_system_message(kind)}
@@ -227,7 +143,7 @@ class DeepSeekClient(ClientBase):
try:
# Use streaming so we can update_request_tokens incrementally
stream = await self.client.chat.completions.create(
stream = await client.chat.completions.create(
model=self.model_name,
messages=[system_message, human_message],
stream=True,
@@ -251,20 +167,6 @@ class DeepSeekClient(ClientBase):
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)


@@ -21,10 +21,8 @@ from talemate.client.remote import (
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens
__all__ = [
@@ -41,10 +39,11 @@ SUPPORTED_MODELS = [
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash-lite-preview-06-17",
"gemini-2.5-flash-preview-05-20",
"gemini-2.5-pro-preview-03-25",
"gemini-2.5-flash",
"gemini-2.5-pro-preview-06-05",
"gemini-2.5-pro",
]
@@ -59,6 +58,9 @@ class ClientConfig(EndpointOverride, BaseClientConfig):
disable_safety_settings: bool = False
MIN_THINKING_TOKENS = 0
@register()
class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"""
@@ -67,7 +69,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
client_type = "google"
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig
@@ -90,21 +91,23 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
extra_fields.update(endpoint_override_extra_fields())
def __init__(self, model="gemini-2.0-flash", **kwargs):
self.model_name = model
self.setup_status = None
self.model_instance = None
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
self.google_credentials_read = False
self.google_project_id = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def disable_safety_settings(self):
return self.client_config.disable_safety_settings
@property
def min_reason_tokens(self) -> int:
return MIN_THINKING_TOKENS
@property
def can_be_coerced(self) -> bool:
return True
return not self.reason_enabled
@property
def google_credentials(self):
@@ -116,15 +119,15 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
@property
def google_credentials_path(self):
return self.config.get("google").get("gcloud_credentials_path")
return self.config.google.gcloud_credentials_path
@property
def google_location(self):
return self.config.get("google").get("gcloud_location")
return self.config.google.gcloud_location
@property
def google_api_key(self):
return self.config.get("google").get("api_key")
return self.config.google.api_key
@property
def vertexai_ready(self) -> bool:
@@ -197,6 +200,16 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
return genai_types.HttpOptions(base_url=self.base_url)
@property
def thinking_config(self) -> genai_types.ThinkingConfig | None:
if not self.reason_enabled:
return None
return genai_types.ThinkingConfig(
thinking_budget=self.validated_reason_tokens,
include_thoughts=True,
)
@property
def supported_parameters(self):
return [
@@ -211,6 +224,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
),
]
@property
def requires_reasoning_pattern(self) -> bool:
return False
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
@@ -269,46 +286,20 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"Error setting client base URL", error=e, client=self.client_type
)
def set_client(self, max_token_length: int = None, **kwargs):
if not self.ready:
log.error("Google cloud setup incomplete")
if self.setup_status:
self.setup_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "gemini-2.0-flash"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
def make_client(self) -> genai.Client:
if self.google_credentials_path:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_credentials_path
model = self.model_name
self.max_token_length = max_token_length or 16384
if self.vertexai_ready and not self.developer_api_ready:
self.client = genai.Client(
return genai.Client(
vertexai=True,
project=self.google_project_id,
location=self.google_location,
)
else:
self.client = genai.Client(
return genai.Client(
api_key=self.api_key or None, http_options=self.http_options
)
log.info(
"google set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def response_tokens(self, response: str):
"""Return token count for a response which may be a string or SDK object."""
return count_tokens(response)
@@ -316,22 +307,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
def prompt_tokens(self, prompt: str):
return count_tokens(prompt)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "disable_safety_settings" in kwargs:
self.disable_safety_settings = kwargs["disable_safety_settings"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
@@ -339,13 +314,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
if "top_k" in parameters and parameters["top_k"] == 0:
del parameters["top_k"]
def prompt_template(self, system_message: str, prompt: str):
"""
Google handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -354,7 +322,12 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
if not self.ready:
raise Exception("Google setup incomplete")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
client = self.make_client()
if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None
human_message = prompt.strip()
system_message = self.get_system_message(kind)
@@ -371,6 +344,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
]
if coercion_prompt:
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
contents.append(
genai_types.Content(
role="model",
@@ -384,49 +358,57 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
self.log.debug(
"generate",
model=self.model_name,
base_url=self.base_url,
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
disable_safety_settings=self.disable_safety_settings,
safety_settings=self.safety_settings,
thinking_config=self.thinking_config,
)
try:
# Use streaming so we can update_request_tokens incrementally
# stream = await chat.send_message_async(
# human_message,
# safety_settings=self.safety_settings,
# generation_config=parameters,
# stream=True
# )
stream = await self.client.aio.models.generate_content_stream(
stream = await client.aio.models.generate_content_stream(
model=self.model_name,
contents=contents,
config=genai_types.GenerateContentConfig(
safety_settings=self.safety_settings,
http_options=self.http_options,
thinking_config=self.thinking_config,
**parameters,
),
)
response = ""
reasoning = ""
# https://ai.google.dev/gemini-api/docs/thinking#summaries
async for chunk in stream:
# For each streamed chunk, append content and update token counts
content_piece = getattr(chunk, "text", None)
if not content_piece:
# Some SDK versions wrap text under candidates[0].text
try:
content_piece = chunk.candidates[0].text # type: ignore
except Exception:
content_piece = None
try:
if not chunk:
continue
if content_piece:
response += content_piece
# Incrementally update token usage
self.update_request_tokens(count_tokens(content_piece))
if not chunk.candidates:
continue
if not chunk.candidates[0].content.parts:
continue
for part in chunk.candidates[0].content.parts:
if not part.text:
continue
if part.thought:
reasoning += part.text
else:
response += part.text
self.update_request_tokens(count_tokens(part.text))
except Exception as e:
log.error("error processing chunk", e=e, chunk=chunk)
continue
if reasoning:
self._reasoning_response = reasoning
# Store total token accounting for prompt/response
self._returned_prompt_tokens = self.prompt_tokens(prompt)


@@ -4,9 +4,8 @@ from groq import AsyncGroq, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, ExtraField
from talemate.client.registry import register
from talemate.config import load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
@@ -23,6 +22,10 @@ SUPPORTED_MODELS = [
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
"llama-3.3-70b-versatile",
"qwen/qwen3-32b",
"moonshotai/kimi-k2-instruct",
"deepseek-r1-distill-llama-70b",
]
JSON_OBJECT_RESPONSE_MODELS = []
@@ -30,7 +33,11 @@ JSON_OBJECT_RESPONSE_MODELS = []
class Defaults(EndpointOverride, pydantic.BaseModel):
max_token_length: int = 8192
model: str = "llama3-70b-8192"
model: str = "moonshotai/kimi-k2-instruct"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
@@ -41,9 +48,9 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
client_type = "groq"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Groq"
@@ -54,19 +61,13 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="llama3-70b-8192", **kwargs):
self.model_name = model
self.api_key_status = None
# Apply any endpoint override parameters provided via kwargs before creating client
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return not self.reason_enabled
@property
def groq_api_key(self):
return self.config.get("groq", {}).get("api_key")
return self.config.groq.api_key
@property
def supported_parameters(self):
@@ -83,15 +84,15 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing
if self.groq_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -104,7 +105,7 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
@@ -112,6 +113,7 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
# Include shared/common status data (rate limit, etc.)
data.update(self._common_status_data())
@@ -120,66 +122,11 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
# Determine if we should use the globally configured API key or the override key
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
# No API key and no endpoint override; cannot initialize the client correctly
self.client = AsyncGroq(api_key="sk-1111")
log.error("No groq.ai API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "llama3-70b-8192"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
# Use the override values (if any) when constructing the Groq client
self.client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"groq.ai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
# Allow dynamic reconfiguration of endpoint override parameters
self._reconfigure_endpoint_override(**kwargs)
# Reconfigure any common parameters (rate limit, data format, etc.)
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return response.usage.completion_tokens
@@ -189,16 +136,6 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -207,16 +144,12 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No groq.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)
if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None
system_message = self.get_system_message(kind)
@@ -225,6 +158,10 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
{"role": "user", "content": prompt},
]
if coercion_prompt:
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
@@ -233,27 +170,25 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
)
try:
response = await self.client.chat.completions.create(
stream = await client.chat.completions.create(
model=self.model_name,
messages=messages,
stream=True,
**parameters,
)
response = response.choices[0].message.content
response = ""
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
# Iterate over streamed chunks
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta and getattr(delta, "content", None):
content_piece = delta.content
response += content_piece
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))
return response
except PermissionDeniedError as e:
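The coercion pre-fill used here mirrors the Google and MistralAI clients: when the client can be coerced, the prompt is split and the coerced fragment is sent as a partial assistant turn. A rough sketch of that shared shape (split_prompt_for_coercion's internals are not shown in this diff; it is assumed to return the prompt plus the optional pre-fill):

if self.can_be_coerced:
    prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
    coercion_prompt = None

messages = [
    {"role": "system", "content": self.get_system_message(kind)},
    {"role": "user", "content": prompt},
]
if coercion_prompt:
    # the model is nudged to continue from this partial assistant turn
    messages.append({"role": "assistant", "content": coercion_prompt.strip()})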


@@ -75,6 +75,7 @@ class KoboldEmbeddingFunction(EmbeddingFunction):
class KoboldCppClient(ClientBase):
auto_determine_prompt_template: bool = True
client_type = "koboldcpp"
remote_model_locked: bool = True
class Meta(ClientBase.Meta):
name_prefix: str = "KoboldCpp"
@@ -188,6 +189,10 @@ class KoboldCppClient(ClientBase):
def embeddings_function(self):
return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)
@property
def default_prompt_template(self) -> str:
return "KoboldAI.jinja2"
def api_endpoint_specified(self, url: str) -> bool:
return "/v1" in self.api_url
@@ -200,14 +205,9 @@ class KoboldCppClient(ClientBase):
self.api_url += "/"
def __init__(self, **kwargs):
self.api_key = kwargs.pop("api_key", "")
super().__init__(**kwargs)
self.ensure_api_endpoint_specified()
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)
self.ensure_api_endpoint_specified()
async def get_embeddings_model_name(self):
# if self._embeddings_model_name is set, return it
if self.embeddings_model_name:
@@ -245,15 +245,21 @@ class KoboldCppClient(ClientBase):
model_name=self.embeddings_model_name,
)
self.set_embeddings()
await self.set_embeddings()
await async_signals.get("client.embeddings_available").send(
ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)
emission = ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)
await async_signals.get("client.embeddings_available").send(emission)
if not emission.seen:
# the suggestion has not been seen by the memory agent
# yet, so we unset the embeddings model name so it will
# get suggested again
self._embeddings_model_name = None
async def get_model_name(self):
self.ensure_api_endpoint_specified()
@@ -437,12 +443,6 @@ class KoboldCppClient(ClientBase):
except KeyError:
pass
def reconfigure(self, **kwargs):
if "api_key" in kwargs:
self.api_key = kwargs.pop("api_key")
super().reconfigure(**kwargs)
async def visual_automatic1111_setup(self, visual_agent: "VisualBase") -> bool:
"""
Automatically configure the visual agent for automatic1111
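On the embeddings hunk above: the emission object is now kept around so the client can check whether any receiver actually handled the suggestion. A sketch of the receiving side, assuming a handler subscribed to client.embeddings_available (the memory-agent handler itself is not part of this diff; ClientEmbeddingsStatus does expose the seen flag checked above):

async def on_embeddings_available(emission: ClientEmbeddingsStatus):
    # a memory agent that accepts the suggested embeddings marks the emission as handled
    emission.seen = True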


@@ -14,6 +14,7 @@ class Defaults(CommonDefaults, pydantic.BaseModel):
class LMStudioClient(ClientBase):
auto_determine_prompt_template: bool = True
client_type = "lmstudio"
remote_model_locked: bool = True
class Meta(ClientBase.Meta):
name_prefix: str = "LMStudio"
@@ -32,17 +33,16 @@ class LMStudioClient(ClientBase):
),
]
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def reconfigure(self, **kwargs):
super().reconfigure(**kwargs)
if self.client and self.client.base_url != self.api_url:
self.set_client()
def make_client(self):
return AsyncOpenAI(base_url=self.api_url + "/v1", api_key=self.api_key)
async def get_model_name(self):
model_name = await super().get_model_name()
client = self.make_client()
models = await client.models.list(timeout=self.status_request_timeout)
try:
model_name = models.data[0].id
except IndexError:
return None
# model name comes back as a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash
@@ -65,9 +65,11 @@ class LMStudioClient(ClientBase):
parameters=parameters,
)
client = self.make_client()
try:
# Send the request in streaming mode so we can update token counts
stream = await self.client.completions.create(
stream = await client.completions.create(
model=self.model_name,
prompt=prompt,
stream=True,


@@ -15,9 +15,8 @@ from talemate.client.remote import (
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"MistralAIClient",
@@ -33,14 +32,7 @@ SUPPORTED_MODELS = [
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest",
]
JSON_OBJECT_RESPONSE_MODELS = [
"open-mixtral-8x22b",
"open-mistral-nemo",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest",
"magistral-medium-2506",
]
@@ -61,7 +53,6 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
client_type = "mistral"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = True
config_cls = ClientConfig
@@ -75,17 +66,13 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="open-mixtral-8x22b", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return not self.reason_enabled
@property
def mistral_api_key(self):
return self.config.get("mistralai", {}).get("api_key")
return self.config.mistralai.api_key
@property
def supported_parameters(self):
@@ -97,15 +84,15 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing
if self.mistral_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -118,74 +105,25 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
if not self.mistral_api_key and not self.endpoint_override_base_url_configured:
self.client = Mistral(api_key="sk-1111")
log.error("No mistral.ai API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "open-mixtral-8x22b"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = Mistral(api_key=self.api_key, server_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"mistral.ai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return response.usage.completion_tokens
@@ -195,16 +133,6 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
# clamp temperature to 0.1 and 1.0
@@ -220,16 +148,12 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
if not self.mistral_api_key:
raise Exception("No mistral.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
client = Mistral(api_key=self.api_key, server_url=self.base_url)
if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None
system_message = self.get_system_message(kind)
@@ -238,6 +162,16 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
{"role": "user", "content": prompt.strip()},
]
if coercion_prompt:
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
messages.append(
{
"role": "assistant",
"content": coercion_prompt.strip(),
"prefix": True,
}
)
self.log.debug(
"generate",
base_url=self.base_url,
@@ -247,7 +181,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
)
try:
event_stream = await self.client.chat.stream_async(
event_stream = await client.chat.stream_async(
model=self.model_name,
messages=messages,
**parameters,
@@ -271,22 +205,6 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
# response = response.choices[0].message.content
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
except SDKError as e:
self.log.error("generate error", e=e)


@@ -27,6 +27,8 @@ TALEMATE_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "talemate")
# user overrides
USER_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "user")
DEFAULT_TEMPLATE = "default.jinja2"
TEMPLATE_IDENTIFIERS = []
@@ -73,10 +75,11 @@ class ModelPrompt:
system_message: str,
prompt: str,
double_coercion: str = None,
default_template: str = DEFAULT_TEMPLATE,
):
template, template_file = self.get_template(model_name)
if not template:
template_file = "default.jinja2"
template_file = default_template
template = self.env.get_template(template_file)
if not double_coercion:
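Read together with the KoboldCpp default_prompt_template property earlier in this diff, the intent appears to be that a client can nominate its own fallback template and ModelPrompt uses it whenever no model-specific template matches. A hypothetical caller-side sketch (the exact call site is not shown in this diff):

rendered_prompt = model_prompt(
    system_message,
    prompt,
    double_coercion=double_coercion,
    default_template=client.default_prompt_template,  # e.g. "KoboldAI.jinja2"
)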


@@ -12,7 +12,7 @@ from talemate.client.base import (
ExtraField,
)
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
from talemate.config.schema import Client as BaseClientConfig
log = structlog.get_logger("talemate.client.ollama")
@@ -24,12 +24,10 @@ class OllamaClientDefaults(CommonDefaults):
api_url: str = "http://localhost:11434" # Default Ollama URL
model: str = "" # Allow empty default, will fetch from Ollama
api_handles_prompt_template: bool = False
allow_thinking: bool = False
class ClientConfig(BaseClientConfig):
api_handles_prompt_template: bool = False
allow_thinking: bool = False
@register()
@@ -58,13 +56,6 @@ class OllamaClient(ClientBase):
required=False,
description="Let Ollama handle the prompt template. Only do this if you don't know which prompt template to use. Letting talemate handle the prompt template will generally lead to improved responses.",
),
"allow_thinking": ExtraField(
name="allow_thinking",
type="bool",
label="Allow thinking",
required=False,
description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
),
}
@property
@@ -90,51 +81,25 @@ class OllamaClient(ClientBase):
"extra_stopping_strings",
]
def __init__(
self,
**kwargs,
):
self._available_models = []
self._models_last_fetched = 0
super().__init__(**kwargs)
@property
def can_be_coerced(self):
"""
Determines whether or not this client can pass LLM coercion (e.g., is able
to predefine partial LLM output in the prompt)
"""
return not self.api_handles_prompt_template
return not self.api_handles_prompt_template and not self.reason_enabled
@property
def can_think(self) -> bool:
"""
Allow reasoning models to think before responding.
"""
return self.allow_thinking
def __init__(
self,
model=None,
api_handles_prompt_template=False,
allow_thinking=False,
**kwargs,
):
self.model_name = model
self.api_handles_prompt_template = api_handles_prompt_template
self.allow_thinking = allow_thinking
self._available_models = []
self._models_last_fetched = 0
self.client = None
super().__init__(**kwargs)
def set_client(self, **kwargs):
"""
Initialize the Ollama client with the API URL.
"""
# Update model if provided
if kwargs.get("model"):
self.model_name = kwargs["model"]
# Create async client with the configured API URL
# Ollama's AsyncClient expects just the base URL without any path
self.client = ollama.AsyncClient(host=self.api_url)
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
self.allow_thinking = kwargs.get("allow_thinking", self.allow_thinking)
def api_handles_prompt_template(self) -> bool:
return self.client_config.api_handles_prompt_template
async def status(self):
"""
@@ -177,7 +142,9 @@ class OllamaClient(ClientBase):
if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
return self._available_models
response = await self.client.list()
client = ollama.AsyncClient(host=self.api_url)
response = await client.list()
models = response.get("models", [])
model_names = [model.model for model in models]
self._available_models = sorted(model_names)
@@ -192,19 +159,11 @@ class OllamaClient(ClientBase):
return data
async def get_model_name(self):
return self.model_name
return self.model
def prompt_template(self, system_message: str, prompt: str):
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
@@ -251,6 +210,8 @@ class OllamaClient(ClientBase):
if not self.model_name:
raise Exception("No model specified or available in Ollama")
client = ollama.AsyncClient(host=self.api_url)
# Prepare options for Ollama
options = parameters
@@ -258,12 +219,11 @@ class OllamaClient(ClientBase):
try:
# Use generate endpoint for completion
stream = await self.client.generate(
stream = await client.generate(
model=self.model_name,
prompt=prompt.strip(),
options=options,
raw=self.can_be_coerced,
think=self.can_think,
stream=True,
)
@@ -306,20 +266,3 @@ class OllamaClient(ClientBase):
prompt_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)
def reconfigure(self, **kwargs):
"""
Reconfigure the client with new settings.
"""
# Handle model update
if kwargs.get("model"):
self.model_name = kwargs["model"]
super().reconfigure(**kwargs)
# Re-initialize client if API URL changed or model changed
if "api_url" in kwargs or "model" in kwargs:
self.set_client(**kwargs)
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]


@@ -12,9 +12,8 @@ from talemate.client.remote import (
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"OpenAIClient",
@@ -44,22 +43,15 @@ SUPPORTED_MODELS = [
"o1-preview",
"o1-mini",
"o3-mini",
]
# any model starting with gpt-4- is assumed to support 'json_object'
# for others we need to explicitly state the model name
JSON_OBJECT_RESPONSE_MODELS = [
"gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
"gpt-4o-realtime-preview",
"gpt-4o-mini-realtime-preview",
"gpt-4o",
"gpt-4o-mini",
"gpt-3.5-turbo-0125",
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
]
def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
# TODO this whole function probably needs to be rewritten at this point
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
@@ -83,7 +75,7 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model or "o1" in model or "o3" in model:
elif "gpt-4" in model or "o1" in model or "o3" in model or "gpt-5" in model:
return num_tokens_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
@@ -104,9 +96,13 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
return num_tokens
DEFAULT_MODEL = "gpt-4o"
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gpt-4o"
model: str = DEFAULT_MODEL
reason_tokens: int = 1024
class ClientConfig(EndpointOverride, BaseClientConfig):
@@ -121,7 +117,6 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
client_type = "openai"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
config_cls = ClientConfig
@@ -135,18 +130,9 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="gpt-4o", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
return self.config.openai.api_key
@property
def supported_parameters(self):
@@ -157,17 +143,35 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
"max_tokens",
]
@property
def requires_reasoning_pattern(self) -> bool:
return False
def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing
# Auto-toggle reasoning based on selected model (OpenAI-specific)
# o1/o3/gpt-5 families are reasoning models
try:
if self.model_name:
is_reasoning_model = (
"o1" in self.model_name
or "o3" in self.model_name
or "gpt-5" in self.model_name
)
if self.client_config.reason_enabled != is_reasoning_model:
self.client_config.reason_enabled = is_reasoning_model
except Exception:
pass
if self.openai_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -180,7 +184,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
@@ -188,6 +192,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
@@ -195,74 +200,11 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
if not self.openai_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncOpenAI(api_key="sk-1111")
log.error("No OpenAI API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "gpt-3.5-turbo-16k"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
if model == "gpt-3.5-turbo":
self.max_token_length = min(max_token_length or 4096, 4096)
elif model == "gpt-4":
self.max_token_length = min(max_token_length or 8192, 8192)
elif model == "gpt-3.5-turbo-16k":
self.max_token_length = min(max_token_length or 16384, 16384)
elif model.startswith("gpt-4o") and model != "gpt-4o-2024-05-13":
self.max_token_length = min(max_token_length or 16384, 16384)
elif model == "gpt-4o-2024-05-13":
self.max_token_length = min(max_token_length or 4096, 4096)
elif model == "gpt-4-1106-preview":
self.max_token_length = min(max_token_length or 128000, 128000)
else:
self.max_token_length = max_token_length or 8192
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"openai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def count_tokens(self, content: str):
if not self.model_name:
return 0
@@ -271,18 +213,6 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
# only gpt-4-1106-preview supports json_object response coercion
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -291,26 +221,17 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
if not self.openai_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No OpenAI API key set")
# only gpt-4-* supports enforcing json object
supports_json_object = (
self.model_name.startswith("gpt-4-")
or self.model_name in JSON_OBJECT_RESPONSE_MODELS
)
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
human_message = {"role": "user", "content": prompt.strip()}
system_message = {"role": "system", "content": self.get_system_message(kind)}
# o1, o3 and gpt-5 models don't support system_message
if "o1" in self.model_name or "o3" in self.model_name:
if (
"o1" in self.model_name
or "o3" in self.model_name
or "gpt-5" in self.model_name
):
messages = [human_message]
# parameters need to be munged
# `max_tokens` becomes `max_completion_tokens`
@@ -339,13 +260,20 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
self.log.debug(
"generate",
model=self.model_name,
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
# GPT-5 models do not allow streaming for non-verified orgs; use non-streaming path
if "gpt-5" in self.model_name:
return await self._generate_non_streaming_completion(
client, messages, parameters
)
try:
stream = await self.client.chat.completions.create(
stream = await client.chat.completions.create(
model=self.model_name,
messages=messages,
stream=True,
@@ -365,23 +293,6 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))
# self._returned_prompt_tokens = self.prompt_tokens(prompt)
# self._returned_response_tokens = self.response_tokens(response)
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -389,3 +300,36 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
return ""
except Exception:
raise
async def _generate_non_streaming_completion(
self, client: AsyncOpenAI, messages: list[dict], parameters: dict
) -> str:
"""Perform a non-streaming chat completion request and return the content.
This is used for GPT-5 models which disallow streaming for non-verified orgs.
"""
try:
response = await client.chat.completions.create(
model=self.model_name,
messages=messages,
# No stream flag -> non-streaming
**parameters,
)
if not response.choices:
return ""
message = response.choices[0].message
content = getattr(message, "content", "") or ""
if content:
# Update token usage based on the full content
self.update_request_tokens(self.count_tokens(content))
return content
except PermissionDeniedError as e:
self.log.error("generate (non-streaming) error", e=e)
emit("status", message="OpenAI API: Permission Denied", status="error")
return ""
except Exception:
raise
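The parameter munging referenced in the comments above for o1/o3/gpt-5-style requests falls outside the hunks shown; a minimal sketch of the assumed shape, based only on the visible comments and condition (hypothetical, not part of the actual diff):

# hypothetical sketch -- the real munging code is not shown in this diff
if (
    "o1" in self.model_name
    or "o3" in self.model_name
    or "gpt-5" in self.model_name
):
    # these models reject a system message and expect
    # `max_completion_tokens` instead of `max_tokens`
    messages = [human_message]
    if "max_tokens" in parameters:
        parameters["max_completion_tokens"] = parameters.pop("max_tokens")
else:
    messages = [system_message, human_message]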

View File

@@ -6,7 +6,7 @@ from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
from talemate.config.schema import Client as BaseClientConfig
from talemate.emit import emit
log = structlog.get_logger("talemate.client.openai_compat")
@@ -51,13 +51,9 @@ class OpenAICompatibleClient(ClientBase):
)
}
def __init__(
self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
):
self.model_name = model
self.api_key = api_key
self.api_handles_prompt_template = api_handles_prompt_template
super().__init__(**kwargs)
@property
def api_handles_prompt_template(self) -> bool:
return self.client_config.api_handles_prompt_template
@property
def experimental(self):
@@ -69,7 +65,7 @@ class OpenAICompatibleClient(ClientBase):
Determines whether or not this client can pass LLM coercion (e.g., is able
to predefine partial LLM output in the prompt).
"""
return not self.api_handles_prompt_template
return not self.reason_enabled
@property
def supported_parameters(self):
@@ -80,43 +76,21 @@ class OpenAICompatibleClient(ClientBase):
"max_tokens",
]
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key", self.api_key)
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
url = self.api_url
self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
self.model_name = (
kwargs.get("model") or kwargs.get("model_name") or self.model_name
)
def prompt_template(self, system_message: str, prompt: str):
log.debug(
"IS API HANDLING PROMPT TEMPLATE",
api_handles_prompt_template=self.api_handles_prompt_template,
)
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
async def get_model_name(self):
return self.model_name
return self.model
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
client = AsyncOpenAI(base_url=self.api_url, api_key=self.api_key)
try:
if self.api_handles_prompt_template:
# OpenAI API handles prompt template
@@ -126,15 +100,37 @@ class OpenAICompatibleClient(ClientBase):
prompt=prompt[:128] + " ...",
parameters=parameters,
)
human_message = {"role": "user", "content": prompt.strip()}
response = await self.client.chat.completions.create(
if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None
messages = [
{"role": "system", "content": self.get_system_message(kind)},
{"role": "user", "content": prompt.strip()},
]
if coercion_prompt:
log.debug(
"Adding coercion pre-fill", coercion_prompt=coercion_prompt
)
messages.append(
{
"role": "assistant",
"content": coercion_prompt.strip(),
"prefix": True,
}
)
response = await client.chat.completions.create(
model=self.model_name,
messages=[human_message],
messages=messages,
stream=False,
**parameters,
)
response = response.choices[0].message.content
return self.process_response_for_indirect_coercion(prompt, response)
return response
else:
# Talemate handles prompt template
# Use the completions endpoint
@@ -144,7 +140,7 @@ class OpenAICompatibleClient(ClientBase):
parameters=parameters,
)
parameters["prompt"] = prompt
response = await self.client.completions.create(
response = await client.completions.create(
model=self.model_name, stream=False, **parameters
)
return response.choices[0].text
@@ -159,34 +155,6 @@ class OpenAICompatibleClient(ClientBase):
)
return ""
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = (
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
)
if "api_key" in kwargs:
self.api_key = kwargs["api_key"]
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
# TODO: why isn't this calling super()?
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
if "rate_limit" in kwargs:
self.rate_limit = kwargs["rate_limit"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self.set_client(**kwargs)
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and presence penalty
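For context on the pre-fill coercion used in generate() above: split_prompt_for_coercion lives on ClientBase and is not shown in this diff. A minimal sketch of its assumed behaviour, inferred from the call sites and the "Start your response with:" marker used by these clients (hypothetical, the real helper may differ):

# hypothetical sketch of the ClientBase helper
def split_prompt_for_coercion(self, prompt: str) -> tuple[str, str | None]:
    marker = "\nStart your response with: "
    if marker in prompt:
        prompt, coercion = prompt.split(marker, 1)
        return prompt, coercion.strip() or None
    return prompt, None

# with a coercion prompt present, the chat payload becomes:
#   {"role": "system", "content": <system message>}
#   {"role": "user", "content": <prompt>}
#   {"role": "assistant", "content": <coercion text>, "prefix": True}  # pre-fill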

View File

@@ -4,9 +4,17 @@ import httpx
import asyncio
import json
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.base import (
ClientBase,
ErrorAction,
CommonDefaults,
ExtraField,
FieldGroup,
)
from talemate.config.schema import Client as BaseClientConfig
from talemate.config import get_config
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
@@ -18,6 +26,91 @@ log = structlog.get_logger("talemate.client.openrouter")
# Available models will be populated when first client with API key is initialized
AVAILABLE_MODELS = []
# Static list of providers that are supported by OpenRouter
# https://openrouter.ai/docs/features/provider-routing#json-schema-for-provider-preferences
AVAILABLE_PROVIDERS = [
"AnyScale",
"Cent-ML",
"HuggingFace",
"Hyperbolic 2",
"Lepton",
"Lynn 2",
"Lynn",
"Mancer",
"Modal",
"OctoAI",
"Recursal",
"Reflection",
"Replicate",
"SambaNova 2",
"SF Compute",
"Together 2",
"01.AI",
"AI21",
"AionLabs",
"Alibaba",
"Amazon Bedrock",
"Anthropic",
"AtlasCloud",
"Atoma",
"Avian",
"Azure",
"BaseTen",
"Cerebras",
"Chutes",
"Cloudflare",
"Cohere",
"CrofAI",
"Crusoe",
"DeepInfra",
"DeepSeek",
"Enfer",
"Featherless",
"Fireworks",
"Friendli",
"GMICloud",
"Google",
"Google AI Studio",
"Groq",
"Hyperbolic",
"Inception",
"InferenceNet",
"Infermatic",
"Inflection",
"InoCloud",
"Kluster",
"Lambda",
"Liquid",
"Mancer 2",
"Meta",
"Minimax",
"Mistral",
"Moonshot AI",
"Morph",
"NCompass",
"Nebius",
"NextBit",
"Nineteen",
"Novita",
"OpenAI",
"OpenInference",
"Parasail",
"Perplexity",
"Phala",
"SambaNova",
"Stealth",
"Switchpoint",
"Targon",
"Together",
"Ubicloud",
"Venice",
"xAI",
]
AVAILABLE_PROVIDERS.sort()
DEFAULT_MODEL = ""
MODELS_FETCHED = False
@@ -25,7 +118,6 @@ MODELS_FETCHED = False
async def fetch_available_models(api_key: str = None):
"""Fetch available models from OpenRouter API"""
global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED
if not api_key:
return []
@@ -37,6 +129,7 @@ async def fetch_available_models(api_key: str = None):
return AVAILABLE_MODELS
try:
log.debug("Fetching models from OpenRouter")
async with httpx.AsyncClient() as client:
response = await client.get(
"https://openrouter.ai/api/v1/models", timeout=10.0
@@ -61,19 +154,36 @@ async def fetch_available_models(api_key: str = None):
return AVAILABLE_MODELS
def fetch_models_sync(event):
api_key = event.data.get("openrouter", {}).get("api_key")
def fetch_models_sync(api_key: str):
loop = asyncio.get_event_loop()
loop.run_until_complete(fetch_available_models(api_key))
handlers["config_saved"].connect(fetch_models_sync)
handlers["talemate_started"].connect(fetch_models_sync)
def on_talemate_started(event):
fetch_models_sync(get_config().openrouter.api_key)
handlers["talemate_started"].connect(on_talemate_started)
class Defaults(CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = DEFAULT_MODEL
provider_only: list[str] = pydantic.Field(default_factory=list)
provider_ignore: list[str] = pydantic.Field(default_factory=list)
class ClientConfig(BaseClientConfig):
provider_only: list[str] = pydantic.Field(default_factory=list)
provider_ignore: list[str] = pydantic.Field(default_factory=list)
PROVIDER_FIELD_GROUP = FieldGroup(
name="provider",
label="Provider",
description="Configure OpenRouter provider routing.",
icon="mdi-server-network",
)
@register()
@@ -84,9 +194,9 @@ class OpenRouterClient(ClientBase):
client_type = "openrouter"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "OpenRouter"
@@ -97,23 +207,46 @@ class OpenRouterClient(ClientBase):
)
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = {
"provider_only": ExtraField(
name="provider_only",
type="flags",
label="Only use these providers",
choices=AVAILABLE_PROVIDERS,
description="Manually limit the providers to use for the selected model. This will override the default provider selection for this model.",
group=PROVIDER_FIELD_GROUP,
required=False,
),
"provider_ignore": ExtraField(
name="provider_ignore",
type="flags",
label="Ignore these providers",
choices=AVAILABLE_PROVIDERS,
description="Ignore these providers for the selected model. This will override the default provider selection for this model.",
group=PROVIDER_FIELD_GROUP,
required=False,
),
}
def __init__(self, model=None, **kwargs):
self.model_name = model or DEFAULT_MODEL
self.api_key_status = None
self.config = load_config()
def __init__(self, **kwargs):
self._models_fetched = False
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def provider_only(self) -> list[str]:
return self.client_config.provider_only
@property
def provider_ignore(self) -> list[str]:
return self.client_config.provider_ignore
@property
def can_be_coerced(self) -> bool:
return True
return not self.reason_enabled
@property
def openrouter_api_key(self):
return self.config.get("openrouter", {}).get("api_key")
return self.config.openrouter.api_key
@property
def supported_parameters(self):
@@ -130,15 +263,15 @@ class OpenRouterClient(ClientBase):
def emit_status(self, processing: bool = None):
error_action = None
error_message = None
if processing is not None:
self.processing = processing
if self.openrouter_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_message = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
@@ -151,7 +284,7 @@ class OpenRouterClient(ClientBase):
if not self.model_name:
status = "error"
model_name = "No model loaded"
error_message = "No model loaded"
self.current_status = status
@@ -159,6 +292,7 @@ class OpenRouterClient(ClientBase):
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
"error_message": error_message,
}
data.update(self._common_status_data())
@@ -166,60 +300,11 @@ class OpenRouterClient(ClientBase):
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=self.model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
# Unlike other clients, we don't need to set up a client instance
# We'll use httpx directly in the generate method
if not self.openrouter_api_key:
log.error("No OpenRouter API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = DEFAULT_MODEL
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
# Set max token length (default to 16k if not specified)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"openrouter set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=self.model_name,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
async def status(self):
# Fetch models if we have an API key and haven't fetched yet
if self.openrouter_api_key and not self._models_fetched:
@@ -229,13 +314,6 @@ class OpenRouterClient(ClientBase):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
"""
OpenRouter handles the prompt template internally, so we just
pass the prompt through as-is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters using the OpenRouter API.
@@ -244,7 +322,10 @@ class OpenRouterClient(ClientBase):
if not self.openrouter_api_key:
raise Exception("No OpenRouter API key set")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
if self.can_be_coerced:
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
else:
coercion_prompt = None
# Prepare messages for chat completion
messages = [
@@ -253,7 +334,23 @@ class OpenRouterClient(ClientBase):
]
if coercion_prompt:
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
log.debug("Adding coercion pre-fill", coercion_prompt=coercion_prompt)
messages.append(
{
"role": "assistant",
"content": coercion_prompt.strip(),
"prefix": True,
}
)
provider = {}
if self.provider_only:
provider["only"] = self.provider_only
if self.provider_ignore:
provider["ignore"] = self.provider_ignore
if provider:
parameters["provider"] = provider
# Prepare request payload
payload = {
@@ -320,7 +417,7 @@ class OpenRouterClient(ClientBase):
self.count_tokens(content)
)
except json.JSONDecodeError:
except (json.JSONDecodeError, KeyError):
pass
# Extract the response content
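For reference, the provider_only / provider_ignore lists built above end up in the request body as the provider routing object described in the OpenRouter docs linked earlier; an illustrative payload with placeholder values:

# illustrative only -- model id and provider names are placeholders
payload = {
    "model": "anthropic/claude-3.5-sonnet",
    "messages": messages,
    "provider": {
        "only": ["Anthropic", "Amazon Bedrock"],  # from provider_only
        "ignore": ["DeepInfra"],                  # from provider_ignore
    },
}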

View File

@@ -3,8 +3,7 @@ from typing import TYPE_CHECKING
import structlog
from talemate.client.context import set_client_context_attribute
from talemate.config import InferencePresets, InferencePresetGroup, load_config
from talemate.emit.signals import handlers
from talemate.config import get_config
if TYPE_CHECKING:
from talemate.client.base import ClientBase
@@ -20,42 +19,19 @@ __all__ = [
log = structlog.get_logger("talemate.client.presets")
config = load_config(as_model=True)
# Load the config
CONFIG = {
"inference": config.presets.inference,
"inference_groups": config.presets.inference_groups,
}
# Sync the config when it is saved
def sync_config(event):
CONFIG["inference"] = InferencePresets(
**event.data.get("presets", {}).get("inference", {})
)
CONFIG["inference_groups"] = {
group: InferencePresetGroup(**data)
for group, data in event.data.get("presets", {})
.get("inference_groups", {})
.items()
}
handlers["config_saved"].connect(sync_config)
def get_inference_parameters(preset_name: str, group: str | None = None) -> dict:
"""
Returns the inference parameters for the given preset name.
"""
presets = CONFIG["inference"].model_dump()
config = get_config()
presets = config.presets.inference.model_dump()
if group:
try:
group_presets = CONFIG["inference_groups"].get(group).model_dump()
group_presets = config.presets.inference_groups.get(group).model_dump()
presets.update(group_presets["presets"])
except AttributeError:
log.warning(
@@ -74,6 +50,7 @@ def configure(parameters: dict, kind: str, total_budget: int, client: "ClientBas
"""
set_preset(parameters, kind, client)
set_max_tokens(parameters, kind, total_budget)
return parameters
@@ -141,7 +118,7 @@ def preset_for_kind(kind: str, client: "ClientBase") -> dict:
if not preset_name:
log.warning(
f"No preset found for kind {kind}, defaulting to 'scene_direction'",
presets=CONFIG["inference"],
presets=get_config().presets.inference,
)
preset_name = "scene_direction"

View File

@@ -69,17 +69,13 @@ class EndpointOverrideAPIKeyField(EndpointOverrideField):
class EndpointOverrideMixin:
override_base_url: str | None = None
override_api_key: str | None = None
@property
def override_base_url(self) -> str | None:
return self.client_config.override_base_url
def set_client_api_key(self, api_key: str | None):
if getattr(self, "client", None):
try:
self.client.api_key = api_key
except Exception as e:
log.error(
"Error setting client API key", error=e, client=self.client_type
)
@property
def override_api_key(self) -> str | None:
return self.client_config.override_api_key
@property
def api_key(self) -> str | None:
@@ -108,41 +104,7 @@ class EndpointOverrideMixin:
and self.endpoint_override_api_key_configured
)
def _reconfigure_endpoint_override(self, **kwargs):
if "override_base_url" in kwargs:
orig = getattr(self, "override_base_url", None)
self.override_base_url = kwargs["override_base_url"]
if getattr(self, "client", None) and orig != self.override_base_url:
log.info("Reconfiguring client base URL", new=self.override_base_url)
self.set_client(kwargs.get("max_token_length"))
if "override_api_key" in kwargs:
self.override_api_key = kwargs["override_api_key"]
self.set_client_api_key(self.override_api_key)
class RemoteServiceMixin:
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
async def status(self):
self.emit_status()

View File

@@ -9,7 +9,7 @@ import dotenv
import runpod
import structlog
from talemate.config import load_config
from talemate.config import get_config
from .bootstrap import ClientBootstrap, ClientType, register_list
@@ -17,7 +17,6 @@ log = structlog.get_logger("talemate.client.runpod")
dotenv.load_dotenv()
runpod.api_key = load_config().get("runpod", {}).get("api_key", "")
TEXTGEN_IDENTIFIERS = ["textgen", "thebloke llms", "text-generation-webui"]
@@ -35,6 +34,7 @@ async def _async_get_pods():
"""
asyncio wrapper around get_pods.
"""
runpod.api_key = get_config().runpod.api_key
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, runpod.get_pods)
@@ -44,6 +44,7 @@ async def get_textgen_pods():
"""
Return a list of text generation pods.
"""
runpod.api_key = get_config().runpod.api_key
if not runpod.api_key:
return
@@ -60,6 +61,8 @@ async def get_automatic1111_pods():
Return a list of automatic1111 pods.
"""
runpod.api_key = get_config().runpod.api_key
if not runpod.api_key:
return
