Compare commits

...

6 Commits

Author SHA1 Message Date
veguAI
25e646c56a 0.32.1 (#213)
* GLM 4.5 templates

* set 0.33 and relock

* fix issues with character creation

* relock

* prompt tweaks

* fix lmstudio

* fix issue with npm on windows failing on paths
set 0.32.1

* linting

* update what's new

* #214 (#215)

* max-height and overflow

* max-height and overflow

* v-tabs to list and offset new scrollbar at the top so it doesn't overlap into the divider

* tweaks

* tweaks

* prompt tweaks

---------

Co-authored-by: Iceman Oakenbear <89090218+IcemanOakenbear@users.noreply.github.com>
2025-08-23 01:16:18 +03:00
veguAI
ce4c302d73 0.32.0 (#208)
* separate other tts apis and improve chunking

* move old tts config to voice agent config and implement config widget ux elements for table editing

* elevenlabs updated to use their client and expose model selection

* linting

* separate character class into character.py and start on voice routing

* linting

* tts hot swapping and chunking improvements

* linting

* add support for piper-tts

* update gitignore

* linting

* support google tts
fix issue where quick_toggle agent config didn't work on standard config items

* linting

* only show agent quick toggles if the agent is enabled

* change elevenlabs to use a locally maintained voice list

* tts generate before / after events

* voice library refactor

* linting

* update openai model and voices

* tweak configs

* voice library ux

* linting

* add support for kokoro tts

* fix add / remove voice

* voice library tags

* linting

* linting

* tts api status

* api infos and add more kokoro voices

* allow voice testing before saving a new voice

* tweaks to voice library ux and some api info text

* linting

* voice mixer

* polish

* voice files go into /tts instead of templates/voice

* change default narrator voice

* xtts confirmation note

* character voice select

* koboldai format template

* polish

* skip empty chunks

* change default voice

* replace em-dash with normal dash

* adjust limit

* replace linebreaks

* chunk cleanup for whitespace

* info updated

* remove invalid endif tag

* sort voices by ready api

* Character hashable type

* clarify set_simulated_environment use to avoid unwanted character deactivation

* allow manual generation of tts and fix assorted issues with tts

* tts websocket handler router renamed

* voice mixer: when there are only 2 voices auto adjust the other weight as needed

* separate persist character functions into own mixin

* auto assign voices

* fix chara load and auto assign voice during chara load

* smart speaker separation

* tts speaker separation config

* generate tts for intro text

* fix prompting issues with anthropic, google and openrouter clients

* decensor flag off again

* only do AI assisted voice markup on narrator messages

* openrouter provider configuration

* linting

* improved sound controls

* add support for chatterbox

* fix info

* chatterbox dependencies

* remove piper and xtts2

* linting

* voice params

* linting

* tts model overrides and move tts info to tab

* reorg toolbar

* allow overriding of test text

* more tts fixes, apply intensity, chatterbox voices

* confirm voice delete

* linting

* groq updates

* reorg decorators

* tts fixes

* cancelable audio queue

* voice library uploads

* scene voice library

* Config refactor (#13)

* config refactor progress

* config nuke continues

* fix system prompts

* linting

* client fun

* client config refactor

* fix kcpp auto embedding selection

* linting

* fix proxy config

* remove cruft

* fix remaining client bugs from config refactor
always use get_config(), don't keep an instance reference

* support for reasoning models

* more reasoning tweaks

* only allow one frontend to connect at a time

* fix tests

* relock

* relock

* more client adjustments

* pattern prefill

* some tts agent fixes

* fix ai assist cond

* tts nodes

* fix config retrieval

* assign voice node and fixes

* sim suite char gen assign voice

* fix voice assign template to consider used voices

* get rid of auto break repetition which wasn't working right for a while anyhow

* linting

* generate tts node
as string node

* linting

* voice change on character event

* tweak chatterbox max length

* koboldai default template

* linting

* fix saving of existing voice

* relock

* adjust params of eva default voice

* f5tts support

* f5tts samples

* f5tts support

* f5tts tweaks

* chunk size per tts api and reorg default f5tts voices

* chatterbox default voice reorg to match f5-tts default voices

* voice library ux polish pass

* cleanup

* f5-tts tweaks

* missing samples

* get rid of old save cmd

* add chatterbox and f5tts

* housekeeping

* fix some issues with world entry editing

* remove cruft

* replace exclamation marks

* fix save immutable check

* fix replace_exclamation_marks

* better error handling in websocket plugins and fix issue with saves

* agent config save on dialog close

* ctrl click to disable / enable agents

* fix quick config

* allow modifying response size of focal requests

* sim suite set goal always sets story intent, encourage calling of set goal during simulation start

* allow setting of model

* voice param tweaks

* tts tweaks

* fix character card load

* fix note_on_value

* add mixed speaker_separation mode

* indicate which message the audio is for and provide way to stop audio from the message

* fix issue with some tts generation failing

* linting

* fix speaker separate modes

* bad idea

* linting

* refactor speaker separation prompt

* add kimi think pattern

* fix issue with unwanted cover image replacement

* no scene analysis for visual prompt generation (for now)

* linting

* tts for context investigation messages

* prompt tweaks

* tweak intro

* fix intro text tts not auto playing sometimes

* consider narrator voice when assigning voice to a character

* allow director log messages to go only into the director console

* linting

* startup performance fixes

* init time

* linting

* only show audio control for messages that can have it

* always create story intent and don't override existing saves during character card load

* fix history check in dynamic story line node
add HasHistory node

* linting

* fix intro message not having speaker separation

* voice library character manager

* sequential and cancelable auto assign all

* linting

* fix generation cancel handling

* tooltips

* fix auto assign voice from scene voices

* polish

* kokoro does not like lazy import

* update info text

* complete scene export / import

* linting

* wording

* remove cruft

* fix story intent generation during character card import

* fix generation cancelled emit status infinite loop

* prompt tweak

* reasoning quick toggle, reasoning token slider, tooltips

* improved reasoning pattern handling

* fix indirect coercion response parsing

* fix streaming issue

* response length instructions

* more robust streaming

* adjust default

* adjust formatting

* linting

* remove debug output

* director console log function calls

* install cuda script updated

* linting

* add another step

* adjust default

* update dialogue examples

* fix voice selection issues

* what's happening here

* third time's the charm?

* Vite migration (#207)

* add vite config

* replace babel, webpack, vue-cli deps with vite, switch to esm modules, separate eslint config

* change process.env to import.meta.env

* update index.html for vite and move to root

* update docs for vite

* remove vue cli config

* update example env with vite

* bump frontend deps after rebase to 32.0

---------

Co-authored-by: pax-co <Pax_801@proton.me>

* properly reference data type

* what's new

* better indication of dialogue example supporting multiple lines, improve dialogue example display

* fix potential issue with cached scene analysis being reused when it shouldn't

* fix character creation issues with player character toggle

* fix issue where editing a message would sometimes lose parts of the message

* fix slider ux thumb labels (vuetify update)

* relock

* narrative conversation format

* remove planning step

* linting

* tweaks

* don't overthink

* update dialogue examples and intro

* don't dictate response length instructions when data structures are expected

* prompt tweaks

* prompt tweaks

* linting

* fix edit message not handling : well

* prompt tweaks

* fix tests

* fix manual revision when character message was generated in new narrative mode

* fix issue with message editing

* Docker packages release (#204)

* add CI workflow for Docker image build and MkDocs deployment

* rename CI workflow from 'ci' to 'package'

* refactor CI workflow: consolidate container build and documentation deployment into a single file

* fix: correct indentation for permissions in CI workflow

* fix: correct indentation for steps in deploy-docs job in CI workflow

* build both cpu and cuda image

* docs

* docs

* expose writing style during state reinforcement

* prompt tweaks

* test container build

* test container image

* update docker compose

* docs

* test-container-build

* test container build

* test container build

* update docker build workflows

* fix guidance prompt prefix not being dropped

* mount tts dir

* add gpt-5

* remove debug output

* docs

* openai auto toggle reasoning based on model selection

* linting

---------

Co-authored-by: pax-co <123330830+pax-co@users.noreply.github.com>
Co-authored-by: pax-co <Pax_801@proton.me>
Co-authored-by: Luis Alexandre Deschamps Brandão <brandao_luis@yahoo.com>
2025-08-08 13:56:29 +03:00
vegu-ai-tools
685ca994f9 linting can be done at merge 2025-07-06 20:32:40 +03:00
vegu-ai-tools
285b0699ab contributing.md 2025-07-06 18:41:44 +03:00
vegu-ai-tools
7825489cfc add contributing.md 2025-07-06 18:37:00 +03:00
veguAI
fb2fa31f13 linting
* precommit

* linting

* add linting to workflow

* ruff.toml added
2025-06-29 19:51:08 +03:00
364 changed files with 29002 additions and 26523 deletions


@@ -1,30 +1,57 @@
name: ci
name: ci
on:
push:
branches:
- master
- main
- prep-0.26.0
- master
release:
types: [published]
permissions:
contents: write
packages: write
jobs:
deploy:
container-build:
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v5
- name: Log in to GHCR
uses: docker/login-action@v3
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build & push
uses: docker/build-push-action@v5
with:
context: .
file: Dockerfile
push: true
tags: |
ghcr.io/${{ github.repository }}:latest
ghcr.io/${{ github.repository }}:${{ github.ref_name }}
deploy-docs:
if: github.event_name == 'release'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git credentials
run: |
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
- uses: actions/setup-python@v5
with: { python-version: '3.x' }
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
restore-keys: mkdocs-material-
- run: pip install mkdocs-material mkdocs-awesome-pages-plugin mkdocs-glightbox
- run: mkdocs gh-deploy --force


@@ -0,0 +1,32 @@
name: test-container-build
on:
push:
branches: [ 'prep-*' ]
permissions:
contents: read
packages: write
jobs:
container-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build & push
uses: docker/build-push-action@v5
with:
context: .
file: Dockerfile
push: true
# Tag with prep suffix to avoid conflicts with production
tags: |
ghcr.io/${{ github.repository }}:${{ github.ref_name }}


@@ -42,6 +42,11 @@ jobs:
source .venv/bin/activate
uv pip install -e ".[dev]"
- name: Run linting
run: |
source .venv/bin/activate
uv run pre-commit run --all-files
- name: Setup configuration file
run: |
cp config.example.yaml config.yaml

.gitignore

@@ -8,11 +8,20 @@
talemate_env
chroma
config.yaml
.cursor
.claude
# uv
.venv/
templates/llm-prompt/user/*.jinja2
templates/world-state/*.yaml
tts/voice/piper/*.onnx
tts/voice/piper/*.json
tts/voice/kokoro/*.pt
tts/voice/xtts2/*.wav
tts/voice/chatterbox/*.wav
tts/voice/f5tts/*.wav
tts/voice/voice-library.json
scenes/
!scenes/infinity-quest-dynamic-scenario/
!scenes/infinity-quest-dynamic-scenario/assets/
@@ -21,4 +30,5 @@ scenes/
!scenes/infinity-quest/assets/
!scenes/infinity-quest/infinity-quest.json
tts_voice_samples/*.wav
third-party-docs/
third-party-docs/
legacy-state-reinforcements.yaml

.pre-commit-config.yaml (new file)

@@ -0,0 +1,16 @@
fail_fast: false
exclude: |
(?x)^(
tests/data/.*
|install-utils/.*
)$
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.12.1
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format

CONTRIBUTING.md (new file)

@@ -0,0 +1,64 @@
# Contributing to Talemate
## About This Project
Talemate is a **personal hobbyist project** that I maintain in my spare time. While I appreciate the community's interest and contributions, please understand that:
- This is primarily a passion project that I enjoy working on myself
- I have limited time for code reviews and prefer to spend that time developing fixes or new features myself
- Large contributions require significant review and testing time that takes away from my own development
For these reasons, I've established contribution guidelines that balance community involvement with my desire to actively develop the project myself.
## Contribution Policy
**I welcome small bugfix and small feature pull requests!** If you've found a bug and have a fix, or have a small feature improvement, I'd love to review it.
However, please note that **I am not accepting large refactors or major feature additions** at this time. This includes:
- Major architectural changes
- Large new features or significant functionality additions
- Large-scale code reorganization
- Breaking API changes
- Features that would require significant maintenance
## What is accepted
**Small bugfixes** - Fixes for specific, isolated bugs
**Small features** - Minor improvements that don't break existing functionality
**Documentation fixes** - Typo corrections, clarifications in existing docs
**Minor dependency updates** - Security patches or minor version bumps
## What is not accepted
**Major features** - Large new functionality or systems
**Large refactors** - Code reorganization or architectural changes
**Breaking changes** - Any changes that break existing functionality
**Major dependency changes** - Framework upgrades or replacements
## Submitting a PR
If you'd like to submit a bugfix or small feature:
1. **Open an issue first** - Describe the bug you've found or feature you'd like to add
2. **Keep it small** - Focus on one specific issue or small improvement
3. **Follow existing code style** - Match the project's current patterns
4. **Don't break existing functionality** - Ensure all existing tests pass
5. **Include tests** - Add or update tests that verify your fix or feature
6. **Update documentation** - If your changes affect behavior, update relevant docs
## Testing
Ensure all tests pass by running:
```bash
uv run pytest tests/ -p no:warnings
```
## Questions?
If you're unsure whether your contribution would be welcome, please open an issue to discuss it first. This saves everyone time and ensures alignment with the project's direction.


@@ -35,18 +35,9 @@ COPY pyproject.toml uv.lock /app/
# Copy the Python source code (needed for editable install)
COPY ./src /app/src
# Create virtual environment and install dependencies
# Create virtual environment and install dependencies (includes CUDA support via pyproject.toml)
RUN uv sync
# Conditional PyTorch+CUDA install
ARG CUDA_AVAILABLE=false
RUN . /app/.venv/bin/activate && \
if [ "$CUDA_AVAILABLE" = "true" ]; then \
echo "Installing PyTorch with CUDA support..." && \
uv pip uninstall torch torchaudio && \
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
fi
# Stage 3: Final image
FROM python:3.11-slim

docker-compose.manual.yml (new file)

@@ -0,0 +1,20 @@
version: '3.8'
services:
talemate:
build:
context: .
dockerfile: Dockerfile
ports:
- "${FRONTEND_PORT:-8080}:8080"
- "${BACKEND_PORT:-5050}:5050"
volumes:
- ./config.yaml:/app/config.yaml
- ./scenes:/app/scenes
- ./templates:/app/templates
- ./chroma:/app/chroma
- ./tts:/app/tts
environment:
- PYTHONUNBUFFERED=1
- PYTHONPATH=/app/src:$PYTHONPATH
command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]


@@ -2,11 +2,7 @@ version: '3.8'
services:
talemate:
build:
context: .
dockerfile: Dockerfile
args:
- CUDA_AVAILABLE=${CUDA_AVAILABLE:-false}
image: ghcr.io/vegu-ai/talemate:latest
ports:
- "${FRONTEND_PORT:-8080}:8080"
- "${BACKEND_PORT:-5050}:5050"
@@ -15,6 +11,7 @@ services:
- ./scenes:/app/scenes
- ./templates:/app/templates
- ./chroma:/app/chroma
- ./tts:/app/tts
environment:
- PYTHONUNBUFFERED=1
- PYTHONPATH=/app/src:$PYTHONPATH


@@ -1,60 +1,63 @@
import os
import re
import subprocess
from pathlib import Path
import argparse
def find_image_references(md_file):
"""Find all image references in a markdown file."""
with open(md_file, 'r', encoding='utf-8') as f:
with open(md_file, "r", encoding="utf-8") as f:
content = f.read()
pattern = r'!\[.*?\]\((.*?)\)'
pattern = r"!\[.*?\]\((.*?)\)"
matches = re.findall(pattern, content)
cleaned_paths = []
for match in matches:
path = match.lstrip('/')
if 'img/' in path:
path = path[path.index('img/') + 4:]
path = match.lstrip("/")
if "img/" in path:
path = path[path.index("img/") + 4 :]
# Only keep references to versioned images
parts = os.path.normpath(path).split(os.sep)
if len(parts) >= 2 and parts[0].replace('.', '').isdigit():
if len(parts) >= 2 and parts[0].replace(".", "").isdigit():
cleaned_paths.append(path)
return cleaned_paths
def scan_markdown_files(docs_dir):
"""Recursively scan all markdown files in the docs directory."""
md_files = []
for root, _, files in os.walk(docs_dir):
for file in files:
if file.endswith('.md'):
if file.endswith(".md"):
md_files.append(os.path.join(root, file))
return md_files
def find_all_images(img_dir):
"""Find all image files in version subdirectories."""
image_files = []
for root, _, files in os.walk(img_dir):
# Get the relative path from img_dir to current directory
rel_dir = os.path.relpath(root, img_dir)
# Skip if we're in the root img directory
if rel_dir == '.':
if rel_dir == ".":
continue
# Check if the immediate parent directory is a version number
parent_dir = rel_dir.split(os.sep)[0]
if not parent_dir.replace('.', '').isdigit():
if not parent_dir.replace(".", "").isdigit():
continue
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.svg')):
if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".svg")):
rel_path = os.path.relpath(os.path.join(root, file), img_dir)
image_files.append(rel_path)
return image_files
def grep_check_image(docs_dir, image_path):
"""
Check if versioned image is referenced anywhere using grep.
@@ -65,33 +68,46 @@ def grep_check_image(docs_dir, image_path):
parts = os.path.normpath(image_path).split(os.sep)
version = parts[0] # e.g., "0.29.0"
filename = parts[-1] # e.g., "world-state-suggestions-2.png"
# For versioned images, require both version and filename to match
version_pattern = f"{version}.*{filename}"
try:
result = subprocess.run(
['grep', '-r', '-l', version_pattern, docs_dir],
["grep", "-r", "-l", version_pattern, docs_dir],
capture_output=True,
text=True
text=True,
)
if result.stdout.strip():
print(f"Found reference to {image_path} with version pattern: {version_pattern}")
print(
f"Found reference to {image_path} with version pattern: {version_pattern}"
)
return True
except subprocess.CalledProcessError:
pass
except Exception as e:
print(f"Error during grep check for {image_path}: {e}")
return False
def main():
parser = argparse.ArgumentParser(description='Find and optionally delete unused versioned images in MkDocs project')
parser.add_argument('--docs-dir', type=str, required=True, help='Path to the docs directory')
parser.add_argument('--img-dir', type=str, required=True, help='Path to the images directory')
parser.add_argument('--delete', action='store_true', help='Delete unused images')
parser.add_argument('--verbose', action='store_true', help='Show all found references and files')
parser.add_argument('--skip-grep', action='store_true', help='Skip the additional grep validation')
parser = argparse.ArgumentParser(
description="Find and optionally delete unused versioned images in MkDocs project"
)
parser.add_argument(
"--docs-dir", type=str, required=True, help="Path to the docs directory"
)
parser.add_argument(
"--img-dir", type=str, required=True, help="Path to the images directory"
)
parser.add_argument("--delete", action="store_true", help="Delete unused images")
parser.add_argument(
"--verbose", action="store_true", help="Show all found references and files"
)
parser.add_argument(
"--skip-grep", action="store_true", help="Skip the additional grep validation"
)
args = parser.parse_args()
# Convert paths to absolute paths
@@ -118,7 +134,7 @@ def main():
print("\nAll versioned image references found in markdown:")
for img in sorted(used_images):
print(f"- {img}")
print("\nAll versioned images in directory:")
for img in sorted(all_images):
print(f"- {img}")
@@ -133,9 +149,11 @@ def main():
for img in unused_images:
if not grep_check_image(docs_dir, img):
actually_unused.add(img)
if len(actually_unused) != len(unused_images):
print(f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!")
print(
f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!"
)
unused_images = actually_unused
# Report findings
@@ -148,7 +166,7 @@ def main():
print("\nUnused versioned images:")
for img in sorted(unused_images):
print(f"- {img}")
if args.delete:
print("\nDeleting unused versioned images...")
for img in unused_images:
@@ -162,5 +180,6 @@ def main():
else:
print("\nNo unused versioned images found!")
if __name__ == "__main__":
main()
main()


@@ -4,12 +4,12 @@ from talemate.events import GameLoopEvent
import talemate.emit.async_signals
from talemate.emit import emit
@register()
class TestAgent(Agent):
agent_type = "test"
verbose_name = "Test"
def __init__(self, client):
self.client = client
self.is_enabled = True
@@ -20,7 +20,7 @@ class TestAgent(Agent):
description="Test",
),
}
@property
def enabled(self):
return self.is_enabled
@@ -36,7 +36,7 @@ class TestAgent(Agent):
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called on the beginning of every game loop
@@ -45,4 +45,8 @@ class TestAgent(Agent):
if not self.enabled:
return
emit("status", status="info", message="Annoying you with a test message every game loop.")
emit(
"status",
status="info",
message="Annoying you with a test message every game loop.",
)


@@ -19,14 +19,17 @@ from talemate.config import Client as BaseClientConfig
log = structlog.get_logger("talemate.client.runpod_vllm")
class Defaults(pydantic.BaseModel):
max_token_length: int = 4096
model: str = ""
runpod_id: str = ""
class ClientConfig(BaseClientConfig):
runpod_id: str = ""
@register()
class RunPodVLLMClient(ClientBase):
client_type = "runpod_vllm"
@@ -49,7 +52,6 @@ class RunPodVLLMClient(ClientBase):
)
}
def __init__(self, model=None, runpod_id=None, **kwargs):
self.model_name = model
self.runpod_id = runpod_id
@@ -59,12 +61,10 @@ class RunPodVLLMClient(ClientBase):
def experimental(self):
return False
def set_client(self, **kwargs):
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
@@ -88,32 +88,37 @@ class RunPodVLLMClient(ClientBase):
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
async with aiohttp.ClientSession() as session:
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
run_request = await endpoint.run({
"input": {
"prompt": prompt,
run_request = await endpoint.run(
{
"input": {
"prompt": prompt,
}
# "parameters": parameters
}
#"parameters": parameters
})
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
)
while (await run_request.status()) not in [
"COMPLETED",
"FAILED",
"CANCELLED",
]:
status = await run_request.status()
log.debug("generate", status=status)
await asyncio.sleep(0.1)
status = await run_request.status()
log.debug("generate", status=status)
response = await run_request.output()
log.debug("generate", response=response)
return response["choices"][0]["tokens"][0]
except Exception as e:
self.log.error("generate error", e=e)
emit(


@@ -9,6 +9,7 @@ class Defaults(pydantic.BaseModel):
api_url: str = "http://localhost:1234"
max_token_length: int = 4096
@register()
class TestClient(ClientBase):
client_type = "test"
@@ -22,14 +23,13 @@ class TestClient(ClientBase):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters: dict, kind: str):
"""
Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.
This method is called before the prompt is sent to the client, and it allows the client to remove
any parameters that it doesn't support.
"""
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
@@ -41,11 +41,10 @@ class TestClient(ClientBase):
del parameters[key]
async def get_model_name(self):
"""
This should return the name of the model that is being used.
"""
return "Mock test model"
async def generate(self, prompt: str, parameters: dict, kind: str):


@@ -27,10 +27,10 @@ uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234
### Letting the frontend know about the new host and port
Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL`.
Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VITE_TALEMATE_BACKEND_WEBSOCKET_URL`.
```env
VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
VITE_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
```
Next rebuild the frontend.


@@ -1,22 +1,15 @@
!!! example "Experimental"
Talemate through docker has not received a lot of testing from me, so please let me know if you encounter any issues.
You can do so by creating an issue on the [:material-github: GitHub repository](https://github.com/vegu-ai/talemate)
## Quick install instructions
1. `git clone https://github.com/vegu-ai/talemate.git`
1. `cd talemate`
1. copy config file
1. linux: `cp config.example.yaml config.yaml`
1. windows: `copy config.example.yaml config.yaml`
1. If your host has a CUDA compatible Nvidia GPU
1. Windows (via PowerShell): `$env:CUDA_AVAILABLE="true"; docker compose up`
1. Linux: `CUDA_AVAILABLE=true docker compose up`
1. If your host does **NOT** have a CUDA compatible Nvidia GPU
1. Windows: `docker compose up`
1. Linux: `docker compose up`
1. windows: `copy config.example.yaml config.yaml` (or just copy the file and rename it via the file explorer)
1. `docker compose up`
1. Navigate your browser to http://localhost:8080
!!! info "Pre-built Images"
The default setup uses pre-built images from GitHub Container Registry that include CUDA support by default. To manually build the container instead, use `docker compose -f docker-compose.manual.yml up --build`.
!!! note
When connecting to local APIs running on the host machine (e.g. text-generation-webui), you need to use `host.docker.internal` as the hostname.

(20 binary image files added (documentation screenshots, 3.0 KiB to 142 KiB); previews not shown.)


@@ -0,0 +1,58 @@
# Chatterbox
Local zero shot voice cloning from .wav files.
![Chatterbox API settings](/talemate/img/0.32.0/chatterbox-api-settings.png)
##### Device
Auto-detects best available option
##### Model
Default Chatterbox model optimized for speed
##### Chunk size
Split text into chunks of this size. Smaller values increase responsiveness at the cost of context lost between chunks (e.g. appropriate inflection). 0 = no chunking.
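As a rough sketch of what this kind of sentence-boundary chunking amounts to (the `chunk_text` helper below is hypothetical, not Talemate's actual implementation):

```python
import re

def chunk_text(text: str, chunk_size: int) -> list[str]:
    """Split text into chunks of roughly chunk_size characters,
    breaking on sentence boundaries so context is only lost
    between sentences. chunk_size == 0 disables chunking."""
    if chunk_size <= 0:
        return [text]
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        if current and len(current) + len(sentence) + 1 > chunk_size:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks
```

Each chunk is then sent to the API as its own generation request, which is why smaller chunks start playing sooner but lose cross-chunk inflection.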
## Adding Chatterbox Voices
### Voice Requirements
Chatterbox voices require:
- Reference audio file (.wav format, 5-15 seconds optimal)
- Clear speech with minimal background noise
- Single speaker throughout the sample
### Creating a Voice
1. Open the Voice Library
2. Click **:material-plus: New**
3. Select "Chatterbox" as the provider
4. Configure the voice:
![Add Chatterbox voice](/talemate/img/0.32.0/add-chatterbox-voice.png)
**Label:** Descriptive name (e.g., "Marcus - Deep Male")
**Voice ID / Upload File** Upload a .wav file containing the voice sample. The uploaded reference audio will also be the voice ID.
**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)
**Tags:** Add descriptive tags for organization
**Extra voice parameters**
There exist some optional parameters that can be set here on a per voice level.
![Chatterbox extra voice parameters](/talemate/img/0.32.0/chatterbox-parameters.png)
##### Exaggeration Level
Exaggeration (Neutral = 0.5, extreme values can be unstable). Higher exaggeration tends to speed up speech; reducing cfg helps compensate with slower, more deliberate pacing.
##### CFG / Pace
If the reference speaker has a fast speaking style, lowering cfg to around 0.3 can improve pacing.


@@ -1,7 +1,41 @@
# ElevenLabs
If you have not configured the ElevenLabs TTS API, the voice agent will show that the API key is missing.
Professional voice synthesis with voice cloning capabilities using ElevenLabs API.
![Elevenlabs api key missing](/talemate/img/0.26.0/voice-agent-missing-api-key.png)
![ElevenLabs API settings](/talemate/img/0.32.0/elevenlabs-api-settings.png)
See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on how to set up the API key.
## API Setup
ElevenLabs requires an API key. See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on obtaining and setting an API key.
## Configuration
**Model:** Select from available ElevenLabs models
!!! warning "Voice Limits"
Your ElevenLabs subscription allows you to maintain a set number of voices (10 for the cheapest plan). Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button. It is recommended to test voices via their voice library instead.
## Adding ElevenLabs Voices
### Getting Voice IDs
1. Go to [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab) to view your voices
2. Find or create the voice you want to use
3. Click "More Actions" -> "Copy Voice ID" for the desired voice
![Copy Voice ID](/talemate/img/0.32.0/elevenlabs-copy-voice-id.png)
### Creating a Voice in Talemate
![Add ElevenLabs voice](/talemate/img/0.32.0/add-elevenlabs-voice.png)
1. Open the Voice Library
2. Click "Add Voice"
3. Select "ElevenLabs" as the provider
4. Configure the voice:
**Label:** Descriptive name for the voice
**Provider ID:** Paste the ElevenLabs voice ID you copied
**Tags:** Add descriptive tags for organization


@@ -0,0 +1,78 @@
# F5-TTS
Local zero shot voice cloning from .wav files.
![F5-TTS configuration](/talemate/img/0.32.0/f5tts-api-settings.png)
##### Device
Auto-detects best available option (GPU preferred)
##### Model
- F5TTS_v1_Base (default, most recent model)
- F5TTS_Base
- E2TTS_Base
##### NFE Step
Number of steps to generate the voice. Higher values result in more detailed voices.
##### Chunk size
Split text into chunks of this size. Smaller values increase responsiveness at the cost of context lost between chunks (e.g. appropriate inflection). 0 = no chunking.
##### Replace exclamation marks
If checked, exclamation marks will be replaced with periods. This is recommended for `F5TTS_v1_Base` since it seems to over-exaggerate exclamation marks.
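The commit log above mentions a `replace_exclamation_marks` helper; its effect can be sketched as a one-line substitution (a hedged illustration, not the actual implementation):

```python
import re

def replace_exclamation_marks(text: str) -> str:
    """Collapse runs of exclamation marks into a single period."""
    return re.sub(r"!+", ".", text)

print(replace_exclamation_marks("Watch out!! It's right behind you!"))
# -> "Watch out. It's right behind you."
```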
## Adding F5-TTS Voices
### Voice Requirements
F5-TTS voices require:
- Reference audio file (.wav format, 10-30 seconds)
- Clear speech with minimal background noise
- Single speaker throughout the sample
- Reference text (optional but recommended)
### Creating a Voice
1. Open the Voice Library
2. Click "Add Voice"
3. Select "F5-TTS" as the provider
4. Configure the voice:
![Add F5-TTS voice](/talemate/img/0.32.0/add-f5tts-voice.png)
**Label:** Descriptive name (e.g., "Emma - Calm Female")
**Voice ID / Upload File** Upload a .wav file containing the **reference audio** voice sample. The uploaded reference audio will also be the voice ID.
- Use 6-10 second samples (longer doesn't improve quality)
- Ensure clear speech with minimal background noise
- Record at natural speaking pace
**Reference Text:** Enter the exact text spoken in the reference audio for improved quality
- Enter exactly what is spoken in the reference audio
- Include proper punctuation and capitalization
- Improves voice cloning accuracy significantly
**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)
**Tags:** Add descriptive tags (gender, age, style) for organization
**Extra voice parameters**
There exist some optional parameters that can be set here on a per voice level.
![F5-TTS extra voice parameters](/talemate/img/0.32.0/f5tts-parameters.png)
##### Speed
Allows you to adjust the speed of the voice.
##### CFG Strength
A higher CFG strength generally leads to more faithful reproduction of the input text, while a lower CFG strength can result in more varied or creative speech output, potentially at the cost of text-to-speech accuracy.


@@ -0,0 +1,15 @@
# Google Gemini-TTS
Google Gemini-TTS provides access to Google's text-to-speech service.
## API Setup
Google Gemini-TTS requires a Google Cloud API key.
See the [Google Cloud API setup](/talemate/user-guide/apis/google/) for instructions on obtaining an API key.
## Configuration
![Google TTS settings](/talemate/img/0.32.0/google-tts-api-settings.png)
**Model:** Select from available Google TTS models


@@ -1,6 +1,26 @@
# Overview
Talemate supports Text-to-Speech (TTS) functionality, allowing users to convert text into spoken audio. This document outlines the steps required to configure TTS for Talemate using different providers, including ElevenLabs and a local TTS API.
In 0.32.0 Talemate's TTS (Text-to-Speech) agent has been completely refactored to provide advanced voice capabilities including per-character voice assignment, speaker separation, and support for multiple local and remote APIs. The voice system now includes a comprehensive voice library for managing and organizing voices across all supported providers.
## Key Features
- **Per-character voice assignment** - Each character can have their own unique voice
- **Speaker separation** - Automatic detection and separation of dialogue from narration
- **Voice library management** - Centralized management of all voices across providers
- **Multiple API support** - Support for both local and remote TTS providers
- **Director integration** - Automatic voice assignment for new characters
## Supported APIs
### Local APIs
- **Kokoro** - Fastest generation with predefined voice models and mixing
- **F5-TTS** - Fast voice cloning with occasional mispronunciations
- **Chatterbox** - High-quality voice cloning (slower generation)
### Remote APIs
- **ElevenLabs** - Professional voice synthesis with voice cloning
- **Google Gemini-TTS** - Google's text-to-speech service
- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
## Enable the Voice agent
@@ -12,28 +32,30 @@ If your voice agent is disabled - indicated by the grey dot next to the agent -
![Agent disabled](/talemate/img/0.26.0/agent-disabled.png) ![Agent enabled](/talemate/img/0.26.0/agent-enabled.png)
!!! note "Ctrl click to toggle agent"
You can use Ctrl click to toggle the agent on and off.
!!! abstract "Next: Connect to a TTS api"
Next you need to decide which service / api to use for audio generation and configure the voice agent accordingly.
## Voice Library Management
- [OpenAI](openai.md)
- [ElevenLabs](elevenlabs.md)
- [Local TTS](local_tts.md)
Voices are managed through the Voice Library, accessible from the main application bar. The Voice Library allows you to:
You can also find more information about the various settings [here](settings.md).
- Add and organize voices from all supported providers
- Assign voices to specific characters
- Create mixed voices (Kokoro)
- Manage both global and scene-specific voice libraries
## Select a voice
See the [Voice Library Guide](voice-library.md) for detailed instructions.
![Elevenlabs voice missing](/talemate/img/0.26.0/voice-agent-no-voice-selected.png)
## Character Voice Assignment
Click on the agent to open the agent settings.
![Character voice assignment](/talemate/img/0.32.0/character-voice-assignment.png)
Then click on the `Narrator Voice` dropdown and select a voice.
Characters can have individual voices assigned through the Voice Library. When a character has a voice assigned:
![Elevenlabs voice selected](/talemate/img/0.26.0/voice-agent-select-voice.png)
1. Their dialogue will use their specific voice
2. The narrator voice is used for exposition in their messages (with speaker separation enabled)
3. If their assigned voice's API is not available, it falls back to the narrator voice
The selection is saved automatically; click anywhere outside the agent window to close it.
The Voice agent status will show all assigned character voices and their current status.
The Voice agent should now show that the voice is selected and be ready to use.
![Elevenlabs ready](/talemate/img/0.26.0/elevenlabs-ready.png)
![Voice agent status with characters](/talemate/img/0.32.0/voice-agent-status-characters.png)


@@ -0,0 +1,55 @@
# Kokoro
Kokoro provides predefined voice models and voice mixing capabilities for creating custom voices.
## Using Predefined Voices
Kokoro comes with built-in voice models that are ready to use immediately.
Available predefined voices include various male and female voices with different characteristics.
## Creating Mixed Voices
Kokoro allows you to mix voices together to create a new voice.
### Voice Mixing Interface
To create a mixed voice:
1. Open the Voice Library
2. Click ":material-plus: New"
3. Select "Kokoro" as the provider
4. Choose the ":material-tune: Mixer" option
5. Configure the mixed voice:
![Voice mixing interface](/talemate/img/0.32.0/kokoro-mixer.png)
**Label:** Descriptive name for the mixed voice
**Base Voices:** Select 2-4 existing Kokoro voices to combine
**Weights:** Set the influence of each voice (0.1 to 1.0)
**Tags:** Descriptive tags for organization
### Weight Configuration
Each selected voice can have its weight adjusted:
- Higher weights make that voice more prominent in the mix
- Lower weights make that voice more subtle
- Total weights need to sum to 1.0 (see the sketch after this list)
- Experiment with different combinations to achieve desired results
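Since the repository's `.gitignore` tracks Kokoro voices as `tts/voice/kokoro/*.pt` tensor files, mixing can be pictured as a normalized weighted average of voice packs. This is a hedged sketch under that assumption, not Talemate's mixer code:

```python
import torch

def mix_voices(packs: list[torch.Tensor], weights: list[float]) -> torch.Tensor:
    """Blend Kokoro voice packs with a weighted average; weights are
    normalized so they effectively sum to 1.0."""
    total = sum(weights)
    mixed = torch.zeros_like(packs[0])
    for pack, weight in zip(packs, weights):
        mixed += (weight / total) * pack
    return mixed
```

With only two voices, fixing one weight determines the other, which matches the auto-adjust behavior noted in the commit list above.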
### Saving Mixed Voices
Once configured, click "Add Voice". Mixed voices are saved to your voice library and can be:
- Assigned to characters
- Used as narrator voices
just like any other voice.
Saving a mixed voice may take a moment to complete.


@@ -1,53 +0,0 @@
# Local TTS
!!! warning
This has not been tested in a while and may not work as expected. It will likely be replaced with something different in the future. If this approach is currently broken, it's likely to remain so until it is replaced.
For running a local TTS API, Talemate requires specific dependencies to be installed.
### Windows Installation
Run `install-local-tts.bat` to install the necessary requirements.
### Linux Installation
Execute the following command:
```bash
pip install TTS
```
### Model and Device Configuration
1. Choose a TTS model from the [Coqui TTS model list](https://github.com/coqui-ai/TTS).
2. Decide whether to use `cuda` or `cpu` for the device setting.
3. The first time you run TTS through the local API, it will download the specified model. Please note that this may take some time, and the download progress will be visible in the Talemate backend output.
Example configuration snippet:
```yaml
tts:
device: cuda # or 'cpu'
model: tts_models/multilingual/multi-dataset/xtts_v2
```
### Voice Samples Configuration
Configure voice samples by setting the `value` field to the path of a .wav file voice sample. Official samples can be downloaded from [Coqui XTTS-v2 samples](https://huggingface.co/coqui/XTTS-v2/tree/main/samples).
Example configuration snippet:
```yaml
tts:
voices:
- label: English Male
value: path/to/english_male.wav
- label: English Female
value: path/to/english_female.wav
```
## Saving the Configuration
After configuring the `config.yaml` file, save your changes. Talemate will use the updated settings the next time it starts.
For more detailed information on configuring Talemate, refer to the `config.py` file in the Talemate source code and the `config.example.yaml` file for a barebone configuration example.


@@ -8,16 +8,12 @@ See the [OpenAI API setup](/apis/openai.md) for instructions on how to set up th
## Settings
![Voice agent openai settings](/talemate/img/0.26.0/voice-agent-openai-settings.png)
![Voice agent openai settings](/talemate/img/0.32.0/openai-tts-api-settings.png)
##### Model
Which model to use for generation.
- GPT-4o Mini TTS
- TTS-1
- TTS-1 HD
!!! quote "OpenAI API documentation on quality"
For real-time applications, the standard tts-1 model provides the lowest latency but at a lower quality than the tts-1-hd model. Due to the way the audio is generated, tts-1 is likely to generate content that has more static in certain situations than tts-1-hd. In some cases, the audio may not have noticeable differences depending on your listening device and the individual person.
Generally I have found that HD is fast enough for Talemate, so this is the default.
- TTS-1 HD


@@ -1,36 +1,65 @@
# Settings
![Voice agent settings](/talemate/img/0.26.0/voice-agent-settings.png)
![Voice agent settings](/talemate/img/0.32.0/voice-agent-settings.png)
##### API
##### Enabled APIs
The TTS API to use for voice generation.
Select which TTS APIs to enable. You can enable multiple APIs simultaneously:
- OpenAI
- ElevenLabs
- Local TTS
- **Kokoro** - Fastest generation with predefined voice models and mixing
- **F5-TTS** - Fast voice cloning with occasional mispronunciations
- **Chatterbox** - High-quality voice cloning (slower generation)
- **ElevenLabs** - Professional voice synthesis with voice cloning
- **Google Gemini-TTS** - Google's text-to-speech service
- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
!!! note "Multi-API Support"
You can enable multiple APIs and assign different voices from different providers to different characters. The system will automatically route voice generation to the appropriate API based on the voice assignment.
##### Narrator Voice
The voice to use for narration. Each API will come with its own set of voices.
The default voice used for narration and as a fallback for characters without assigned voices.
![Narrator voice](/talemate/img/0.26.0/voice-agent-select-voice.png)
The dropdown shows all available voices from all enabled APIs, with the format: "Voice Name (Provider)"
!!! note "Local TTS"
For local TTS, you will have to provide voice samples yourself. See [Local TTS Instructions](local_tts.md) for more information.
!!! info "Voice Management"
Voices are managed through the Voice Library, accessible from the main application bar. Adding, removing, or modifying voices should be done through the Voice Library interface.
##### Generate for player
##### Speaker Separation
Whether to generate voice for the player. If enabled, whenever the player speaks, the voice agent will generate audio for them.
Controls how dialogue is separated from exposition in messages:
##### Generate for NPCs
- **No separation** - Character messages use character voice entirely, narrator messages use narrator voice
- **Simple** - Basic separation of dialogue from exposition using punctuation analysis, with exposition being read by the narrator voice
- **Mixed** - Enables AI assisted separation for narrator messages and simple separation for character messages
- **AI assisted** - AI assisted separation for both narrator and character messages
Whether to generate voice for NPCs. If enabled, whenever a non player character speaks, the voice agent will generate audio for them.
!!! warning "AI Assisted Performance"
AI-assisted speaker separation sends additional prompts to your LLM, which may impact response time and API costs.
##### Generate for narration
##### Auto-generate for player
Whether to generate voice for narration. If enabled, whenever the narrator speaks, the voice agent will generate audio for them.
Generate voice automatically for player messages
##### Split generation
##### Auto-generate for AI characters
If enabled, the voice agent will generate audio in chunks, allowing for faster generation. This does, however, cause it to lose context between chunks, and inflection may not be as good.
Generate voice automatically for NPC/AI character messages
##### Auto-generate for narration
Generate voice automatically for narrator messages
##### Auto-generate for context investigation
Generate voice automatically for context investigation messages
## Advanced Settings
Advanced settings are configured per-API and can be found in the respective API configuration sections:
- **Chunk size** - Maximum text length per generation request
- **Model selection** - Choose specific models for each API
- **Voice parameters** - Provider-specific voice settings
!!! tip "Performance Optimization"
Each API has different optimal chunk sizes and parameters. The system automatically handles chunking and queuing for optimal performance across all enabled APIs.


@@ -0,0 +1,156 @@
# Voice Library
The Voice Library is the central hub for managing all voices across all TTS providers in Talemate. It provides a unified interface for organizing, creating, and assigning voices to characters.
## Accessing the Voice Library
The Voice Library can be accessed from the main application bar at the top of the Talemate interface.
![Voice Library access](/talemate/img/0.32.0/voice-library-access.png)
Click the voice icon to open the Voice Library dialog.
!!! note "Voice agent needs to be enabled"
The Voice agent needs to be enabled for the voice library to be available.
## Voice Library Interface
![Voice Library interface](/talemate/img/0.32.0/voice-library-interface.png)
The Voice Library interface consists of:
### Scope Tabs
- **Global** - Voices available across all scenes
- **Scene** - Voices specific to the current scene (only visible when a scene is loaded)
- **Characters** - Character voice assignments for the current scene (only visible when a scene is loaded)
### API Status
The toolbar shows the status of all TTS APIs:
- **Green** - API is enabled and ready
- **Orange** - API is enabled but not configured
- **Red** - API has configuration issues
- **Gray** - API is disabled
![API status](/talemate/img/0.32.0/voice-library-api-status.png)
## Managing Voices
### Global Voice Library
The global voice library contains voices that are available across all scenes. These include:
- Default voices provided by each TTS provider
- Custom voices you've added
#### Adding New Voices
To add a new voice:
1. Click the "+ New" button
2. Select the TTS provider
3. Configure the voice parameters:
- **Label** - Display name for the voice
- **Provider ID** - Provider-specific identifier
- **Tags** - Free-form descriptive tags you define (gender, age, style, etc.)
- **Parameters** - Provider-specific settings
Check the provider specific documentation for more information on how to configure the voice.
#### Voice Types by Provider
**F5-TTS & Chatterbox:**
- Upload .wav reference files for voice cloning
- Specify reference text for better quality
- Adjust speed and other parameters
**Kokoro:**
- Select from predefined voice models
- Create mixed voices by combining multiple models
- Adjust voice mixing weights
**ElevenLabs:**
- Select from available ElevenLabs voices
- Configure voice settings and stability
- Use custom cloned voices from your ElevenLabs account
**OpenAI:**
- Choose from available OpenAI voice models
- Configure model (GPT-4o Mini TTS, TTS-1, TTS-1-HD)
**Google Gemini-TTS:**
- Select from Google's voice models
- Configure language and gender settings
### Scene Voice Library
Scene-specific voices are only available within the current scene. This is useful for:
- Scene-specific characters
- Temporary voice experiments
- Custom voices for specific scenarios
Scene voices are saved with the scene and will be available when the scene is loaded.
## Character Voice Assignment
### Automatic Assignment
The Director agent can automatically assign voices to new characters based on:
- Character tags and attributes
- Voice tags matching character personality
- Available voices in the voice library
This feature can be enabled in the Director agent settings.
### Manual Assignment
![Character voice assignment](/talemate/img/0.32.0/character-voice-assignment.png)
To manually assign a voice to a character:
1. Go to the "Characters" tab in the Voice Library
2. Find the character in the list
3. Click the voice dropdown for that character
4. Select a voice from the available options
5. The assignment is saved automatically
### Character Voice Status
The character list shows:
- **Character name**
- **Currently assigned voice** (if any)
- **Voice status** - whether the voice's API is available
- **Quick assignment controls**
## Voice Tags and Organization
### Tagging System
Voices can be tagged with any descriptive attributes you choose. Tags are completely free-form and user-defined. Common examples include:
- **Gender**: male, female, neutral
- **Age**: young, mature, elderly
- **Style**: calm, energetic, dramatic, mysterious
- **Quality**: deep, high, raspy, smooth
- **Character types**: narrator, villain, hero, comic relief
- **Custom tags**: You can create any tags that help you organize your voices
### Filtering and Search
Use the search bar to filter voices by:
- Voice label/name
- Provider
- Tags
- Character assignments
This makes it easy to find the right voice for specific characters or situations.
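As a rough illustration, this kind of filtering reduces to a case-insensitive substring match across those fields (the `filter_voices` helper and the voice dict shape are assumptions, not Talemate's actual search code):

```python
def filter_voices(voices: list[dict], query: str) -> list[dict]:
    """Keep voices whose label, provider, or any tag contains the query.
    Assumed shape: {"label": "...", "provider": "...", "tags": [...]}."""
    q = query.lower()
    return [
        voice for voice in voices
        if q in voice["label"].lower()
        or q in voice["provider"].lower()
        or any(q in tag.lower() for tag in voice["tags"])
    ]
```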


@@ -0,0 +1,82 @@
# Reasoning Model Support
Talemate supports reasoning models that can perform step-by-step thinking before generating their final response. This feature allows models to work through complex problems internally before providing an answer.
## Enabling Reasoning Support
To enable reasoning support for a client:
1. Open the **Clients** dialog from the main toolbar
2. Select the client you want to configure
3. Navigate to the **Reasoning** tab in the client configuration
![Client reasoning configuration](/talemate/img/0.32.0/client-reasoning-2.png)
4. Check the **Enable Reasoning** checkbox
## Configuring Reasoning Tokens
Once reasoning is enabled, you can configure the **Reasoning Tokens** setting using the slider:
![Reasoning tokens configuration](/talemate/img/0.32.0/client-reasoning.png)
### Recommended Token Amounts
**For local reasoning models:** Use a high token allocation (recommended: 4096 tokens) to give the model sufficient space for complex reasoning.
**For remote APIs:** Start with lower amounts (512-1024 tokens) and adjust based on your needs and token costs.
### Token Allocation Behavior
The behavior of the reasoning tokens setting depends on your API provider:
**For APIs that support direct reasoning token specification:**
- The specified tokens will be allocated specifically for reasoning
- The model will use these tokens for internal thinking before generating the response
**For APIs that do NOT support reasoning token specification:**
- The tokens are added as extra allowance to the response token limit for ALL requests
- This may lead to more verbose responses than usual since Talemate normally uses response token limits to control verbosity
!!! warning "Increased Verbosity"
For providers without direct reasoning token support, enabling reasoning may result in more verbose responses since the extra tokens are added to all requests.
## Response Pattern Configuration
When reasoning is enabled, you may need to configure a **Pattern to strip from the response** to remove the thinking process from the final output.
### Default Patterns
Talemate provides quick-access buttons for common reasoning patterns:
- **Default** - Uses the built-in pattern: `.*?</think>`
- **`.*?◁/think▷`** - For models using arrow-style thinking delimiters
- **`.*?</think>`** - For models using XML-style think tags
### Custom Patterns
You can also specify a custom regular expression pattern that matches your model's reasoning format. This pattern will be used to strip the thinking tokens from the response before displaying it to the user.
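In practice the strip is a single anchored regex substitution over the raw completion. A minimal sketch using the default pattern (the `strip_reasoning` helper is hypothetical):

```python
import re

def strip_reasoning(response: str, pattern: str = r".*?</think>") -> str:
    """Remove the model's thinking block using the configured pattern.
    re.DOTALL lets '.' span the multi-line reasoning text."""
    return re.sub(pattern, "", response, count=1, flags=re.DOTALL).strip()

raw = "<think>The scene needs more tension...</think>She opens the door slowly."
print(strip_reasoning(raw))  # -> "She opens the door slowly."
```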
## Model Compatibility
Not all models support reasoning. This feature works best with:
- Models specifically trained for chain-of-thought reasoning
- Models that support structured thinking patterns
- APIs that provide reasoning token specification
## Important Notes
- **Coercion Disabled**: When reasoning is enabled, LLM coercion (pre-filling responses) is automatically disabled since reasoning models need to generate their complete thought process
- **Response Time**: Reasoning models may take longer to respond as they work through their thinking process
## Troubleshooting
### Pattern Not Working
If the reasoning pattern isn't properly stripping the thinking process:
1. Check your model's actual reasoning output format
2. Adjust the regular expression pattern to match your model's specific format
3. Test with the default pattern first to see if it works


@@ -35,4 +35,19 @@ A unique name for the client that makes sense to you.
Which model to use. Currently defaults to `gpt-4o`.
!!! note "Talemate lags behind OpenAI"
When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
##### Reasoning models (o1, o3, gpt-5)
!!! important "Enable reasoning and allocate tokens"
The `o1`, `o3`, and `gpt-5` families are reasoning models. They always perform internal thinking before producing the final answer. To use them effectively in Talemate:
- Enable the **Reasoning** option in the client configuration.
- Set **Reasoning Tokens** to a sufficiently high value to make room for the model's thinking process.
A good starting range is 512-1024 tokens. Increase if your tasks are complex. Without enabling reasoning and allocating tokens, these models may return minimal or empty visible content because the token budget is consumed by internal reasoning.
See the detailed guide: [Reasoning Model Support](/talemate/user-guide/clients/reasoning/).
!!! tip "Getting empty responses?"
If these models return empty or very short answers, it usually means the reasoning budget was exhausted. Increase **Reasoning Tokens** and try again.
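As a rough illustration of why the budget matters, here is a hedged sketch of a direct request to a reasoning model with the official OpenAI Python client; the model name and numbers are examples only, and reasoning tokens are drawn from the same completion budget as the visible answer:

```python
from openai import OpenAI

client = OpenAI()

response = client.chat.completions.create(
    model="o3",
    messages=[{"role": "user", "content": "Summarize the scene so far."}],
    # Visible response budget plus reasoning allowance; too small a value
    # here is what produces empty or truncated answers.
    max_completion_tokens=512 + 1024,
)
print(response.choices[0].message.content)
```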

View File

@@ -1,6 +1,6 @@
import os
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from starlette.exceptions import HTTPException
@@ -14,6 +14,7 @@ app = FastAPI()
# Serve static files, but exclude the root path
app.mount("/", StaticFiles(directory=dist_dir, html=True), name="static")
@app.get("/", response_class=HTMLResponse)
async def serve_root():
index_path = os.path.join(dist_dir, "index.html")
@@ -24,5 +25,6 @@ async def serve_root():
else:
raise HTTPException(status_code=404, detail="index.html not found")
# This is the ASGI application
application = app

View File

@@ -4,4 +4,4 @@
uv pip uninstall torch torchaudio
# install torch and torchaudio with CUDA support
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
uv pip install torch~=2.7.1 torchaudio~=2.7.1 --index-url https://download.pytorch.org/whl/cu128

View File

@@ -1,11 +1,12 @@
[project]
name = "talemate"
version = "0.31.0"
version = "0.32.1"
description = "AI-backed roleplay and narrative tools"
authors = [{name = "VeguAITools"}]
license = {text = "GNU Affero General Public License v3.0"}
requires-python = ">=3.10,<3.14"
requires-python = ">=3.11,<3.14"
dependencies = [
"pip",
"astroid>=2.8",
"jedi>=0.18",
"black",
@@ -50,11 +51,20 @@ dependencies = [
# ChromaDB
"chromadb>=1.0.12",
"InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
"torch>=2.7.0",
"torchaudio>=2.7.0",
# locked for instructor embeddings
#sentence-transformers==2.2.2
"torch>=2.7.1",
"torchaudio>=2.7.1",
"sentence_transformers>=2.7.0",
# TTS
"elevenlabs>=2.7.1",
# Local TTS
# Chatterbox TTS
#"chatterbox-tts @ https://github.com/rsxdalv/chatterbox/archive/refs/heads/fast.zip",
"chatterbox-tts==0.1.2",
# kokoro TTS
"kokoro>=0.9.4",
"soundfile>=0.13.1",
# F5-TTS
"f5-tts>=1.1.7",
]
[project.optional-dependencies]
@@ -65,6 +75,7 @@ dev = [
"mkdocs-material>=9.5.27",
"mkdocs-awesome-pages-plugin>=2.9.2",
"mkdocs-glightbox>=0.4.0",
"pre-commit>=2.13",
]
[project.scripts]
@@ -103,4 +114,27 @@ include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
[tool.uv]
override-dependencies = [
# chatterbox wants torch 2.6.0, but is confirmed working with 2.7.1
"torchaudio>=2.7.1",
"torch>=2.7.1",
# numba needs numpy < 2.3
"numpy>=2,<2.3",
"pydantic>=2.11",
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]
torchaudio = [
{ index = "pytorch-cu128" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

ruff.toml (new file)
View File

@@ -0,0 +1,5 @@
[lint]
# Disable automatic fix for unused imports (`F401`). We check these manually.
unfixable = ["F401"]
# Ignore E402
extend-ignore = ["E402"]

View File

@@ -1,6 +1,6 @@
{
"description": "Captain Elmer Farstield and his trusty first officer, Kaira, embark upon a daring mission into uncharted space. Their small but mighty exploration vessel, the Starlight Nomad, is equipped with state-of-the-art technology and crewed by an elite team of scientists, engineers, and pilots. Together they brave the vast cosmos seeking answers to humanity's most pressing questions about life beyond our solar system.",
"intro": "You awaken aboard your ship, the Starlight Nomad, surrounded by darkness. A soft hum resonates throughout the vessel indicating its systems are online. Your mind struggles to recall what brought you here - where 'here' actually is. You remember nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there was a bright light followed by this endless void.\n\nGingerly, you make your way through the dimly lit corridors of the ship. It seems smaller than you expected given the magnitude of the mission ahead. However, each room reveals intricate technology designed specifically for long-term space travel and exploration. There appears to be no other living soul besides yourself. An eerie silence fills every corner.",
"intro": "Elmer awoke aboard his ship, the Starlight Nomad, surrounded by darkness. A soft hum resonated throughout the vessel indicating its systems were online. His mind struggled to recall what had brought him here - where here actually was. He remembered nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there had been a bright light followed by this endless void.\n\nGingerly, he made his way through the dimly lit corridors of the ship. It seemed smaller than he had expected given the magnitude of the mission ahead. However, each room revealed intricate technology designed specifically for long-term space travel and exploration. There appeared to be no other living soul besides himself. An eerie silence filled every corner.",
"name": "Infinity Quest",
"history": [],
"environment": "scene",
@@ -90,11 +90,11 @@
"gender": "female",
"color": "red",
"example_dialogue": [
"Kaira: \"Yes Captain, I believe that is the best course of action\" She nods slightly, as if to punctuate her approval of the decision*",
"Kaira: \"This device appears to have multiple functions, Captain. Allow me to analyze its capabilities and determine if it could be useful in our exploration efforts.\"",
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" Excitement bubbles beneath her calm exterior as she shares the news",
"Kaira: \"Captain, I understand why you would want us to pursue this course of action based on our current data, but I cannot shake the feeling that there might be unforeseen consequences if we proceed without further investigation into potential hazards.\"",
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world... But then again, perhaps it was fate that led me here, onto this ship bound for destinations unknown...\""
"Kaira: \"Yes Captain, I believe that is the best course of action.\" Kaira glanced at the navigation display, then back at him with a slight nod. \"The numbers check out on my end too. If we adjust our trajectory by 2.7 degrees and increase thrust by fifteen percent, we should reach the nebula's outer edge within six hours.\" Her violet fingers moved efficiently across the controls as she pulled up the gravitational readings.",
"Kaira: The scanner hummed as it analyzed the alien artifact. Kaira knelt beside the strange object, frowning in concentration at the shifting symbols on its warm metal surface. \"This device appears to have multiple functions, Captain,\" she said, adjusting her scanner's settings. \"Give me a few minutes to run a full analysis and I'll know what we're dealing with. The material composition is fascinating - it's responding to our ship's systems in ways that shouldn't be possible.\"",
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" The excitement in her voice was unmistakable as Kaira looked up from her console. \"These readings are incredible - I've never seen anything like this before. There are structures beneath the surface that predate our oldest sites by millions of years.\" She paused, processing the implications. \"If these readings are accurate, we may have found something truly significant.\"",
"Kaira: Something felt off about the proposed course of action. Kaira moved from her station to stand beside the Captain, organizing her thoughts carefully. \"Captain, I understand why you would want us to pursue this based on our current data,\" she began respectfully, clasping her hands behind her back. \"But something feels wrong about this. The quantum signatures have subtle variations that remind me of Hegemony cloaking technology. Maybe we should run a few more scans before we commit to a full approach.\"",
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world,\" she said softly, not turning from the observation deck viewport as footsteps approached. The stars wheeled slowly past in their eternal dance. \"Sometimes I dream of Altrusia's crystal gardens, the way our twin suns would set over the mountains.\" Kaira finally turned, her expression thoughtful. \"Then again, I suppose I was always meant to end up out here somehow. Perhaps this journey is exactly where I'm supposed to be.\""
],
"history_events": [],
"is_player": false,

View File

@@ -494,34 +494,6 @@
"registry": "state/GetState",
"base_type": "core/Node"
},
"8a050403-5c69-46f7-abe2-f65db4553942": {
"title": "TRUE",
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
"properties": {
"value": true
},
"x": 460,
"y": 670,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
},
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
"title": "Create Character",
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
"properties": {},
"x": 580,
"y": 550,
"width": 245,
"height": 186,
"collapsed": false,
"inherited": false,
"registry": "agents/creator/CreateCharacter",
"base_type": "core/Graph"
},
"970f12d0-330e-41b3-b025-9a53bcf2fc6f": {
"title": "SET - created_character",
"id": "970f12d0-330e-41b3-b025-9a53bcf2fc6f",
@@ -628,6 +600,34 @@
"inherited": false,
"registry": "data/string/MakeText",
"base_type": "core/Node"
},
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
"title": "Create Character",
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
"properties": {},
"x": 580,
"y": 550,
"width": 245,
"height": 206,
"collapsed": false,
"inherited": false,
"registry": "agents/creator/CreateCharacter",
"base_type": "core/Graph"
},
"8a050403-5c69-46f7-abe2-f65db4553942": {
"title": "TRUE",
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
"properties": {
"value": true
},
"x": 400,
"y": 720,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
}
},
"edges": {
@@ -714,17 +714,6 @@
"fd4cd318-121b-47de-84a0-b1ab62c5601b.value": [
"41c0c2cd-d39b-4528-a5fe-459094393ba3.list"
],
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
],
"a2457116-35cb-4571-ba1a-cbf63851544e.value": [
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.state",
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.character",
@@ -738,6 +727,18 @@
],
"5041f12d-507f-4f5d-a26f-048625974602.value": [
"b50e23d4-b456-4385-b80e-c2b6884c7855.template"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
],
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
],
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.assign_voice"
]
},
"groups": [
@@ -808,13 +809,14 @@
"inputs": [],
"outputs": [
{
"id": "f78fa84b-8b6f-4c8a-83c0-754537bb9060",
"id": "92423e0a-89d8-4ad8-a42d-fdcba75b31d1",
"name": "fn",
"optional": false,
"group": null,
"socket_type": "function"
}
],
"module_properties": {},
"style": {
"title_color": "#573a2e",
"node_color": "#392f2c",

View File

@@ -116,8 +116,8 @@
"context_aware": true,
"history_aware": true
},
"x": 568,
"y": 861,
"x": 372,
"y": 857,
"width": 249,
"height": 406,
"collapsed": false,
@@ -130,7 +130,7 @@
"id": "a8110d74-0fb5-4601-b883-6c63ceaa9d31",
"properties": {},
"x": 24,
"y": 2084,
"y": 2088,
"width": 140,
"height": 106,
"collapsed": false,
@@ -145,7 +145,7 @@
"attribute": "title"
},
"x": 213,
"y": 2094,
"y": 2098,
"width": 210,
"height": 98,
"collapsed": false,
@@ -160,7 +160,7 @@
"value": "The Simulation Suite"
},
"x": 213,
"y": 2284,
"y": 2288,
"width": 210,
"height": 58,
"collapsed": false,
@@ -177,7 +177,7 @@
"case_sensitive": true
},
"x": 523,
"y": 2184,
"y": 2188,
"width": 210,
"height": 126,
"collapsed": false,
@@ -192,7 +192,7 @@
"pass_through": true
},
"x": 774,
"y": 2185,
"y": 2189,
"width": 210,
"height": 78,
"collapsed": false,
@@ -207,7 +207,7 @@
"attribute": "0"
},
"x": 1629,
"y": 2411,
"y": 2415,
"width": 210,
"height": 98,
"collapsed": false,
@@ -222,7 +222,7 @@
"attribute": "title"
},
"x": 2189,
"y": 2321,
"y": 2325,
"width": 210,
"height": 98,
"collapsed": false,
@@ -237,7 +237,7 @@
"stage": 5
},
"x": 2459,
"y": 2331,
"y": 2335,
"width": 210,
"height": 118,
"collapsed": true,
@@ -250,7 +250,7 @@
"id": "8208d05c-1822-4f4a-ba75-cfd18d2de8ca",
"properties": {},
"x": 1989,
"y": 2331,
"y": 2335,
"width": 140,
"height": 106,
"collapsed": true,
@@ -266,7 +266,7 @@
"chars": "\"'*"
},
"x": 1899,
"y": 2411,
"y": 2415,
"width": 210,
"height": 102,
"collapsed": false,
@@ -282,7 +282,7 @@
"max_splits": 1
},
"x": 1359,
"y": 2411,
"y": 2415,
"width": 210,
"height": 102,
"collapsed": false,
@@ -304,7 +304,7 @@
"history_aware": true
},
"x": 1014,
"y": 2185,
"y": 2189,
"width": 276,
"height": 406,
"collapsed": false,
@@ -376,8 +376,8 @@
"name": "arg_goal",
"scope": "local"
},
"x": 228,
"y": 1041,
"x": 32,
"y": 1037,
"width": 256,
"height": 122,
"collapsed": false,
@@ -393,7 +393,7 @@
"max_scene_types": 2
},
"x": 671,
"y": 1490,
"y": 1494,
"width": 210,
"height": 122,
"collapsed": false,
@@ -409,7 +409,7 @@
"scope": "local"
},
"x": 30,
"y": 1551,
"y": 1555,
"width": 256,
"height": 122,
"collapsed": false,
@@ -448,37 +448,6 @@
"registry": "state/GetState",
"base_type": "core/Node"
},
"59d31050-f61d-4798-9790-e22d34ecbd4b": {
"title": "GET local.auto_direct_enabled",
"id": "59d31050-f61d-4798-9790-e22d34ecbd4b",
"properties": {
"name": "auto_direct_enabled",
"scope": "local"
},
"x": 20,
"y": 850,
"width": 256,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8": {
"title": "Switch",
"id": "e8f19a05-43fe-4e4a-9cc4-bec0a29779d8",
"properties": {
"pass_through": true
},
"x": 310,
"y": 870,
"width": 210,
"height": 78,
"collapsed": false,
"inherited": false,
"registry": "core/Switch",
"base_type": "core/Node"
},
"b03fa942-c48e-4c04-b9ae-a009a7e0f947": {
"title": "GET local.auto_direct_enabled",
"id": "b03fa942-c48e-4c04-b9ae-a009a7e0f947",
@@ -487,7 +456,7 @@
"scope": "local"
},
"x": 24,
"y": 1367,
"y": 1371,
"width": 256,
"height": 122,
"collapsed": false,
@@ -502,7 +471,7 @@
"pass_through": true
},
"x": 360,
"y": 1390,
"y": 1394,
"width": 210,
"height": 78,
"collapsed": false,
@@ -518,7 +487,7 @@
"scope": "local"
},
"x": 30,
"y": 1821,
"y": 1826,
"width": 256,
"height": 122,
"collapsed": false,
@@ -533,7 +502,7 @@
"stage": 4
},
"x": 1100,
"y": 1870,
"y": 1875,
"width": 210,
"height": 118,
"collapsed": true,
@@ -546,7 +515,7 @@
"id": "9db37d1e-3cf8-49bd-bdc5-8663494e5657",
"properties": {},
"x": 670,
"y": 1840,
"y": 1845,
"width": 226,
"height": 62,
"collapsed": false,
@@ -561,7 +530,7 @@
"stage": 3
},
"x": 1080,
"y": 1520,
"y": 1524,
"width": 210,
"height": 118,
"collapsed": true,
@@ -576,7 +545,7 @@
"pass_through": true
},
"x": 370,
"y": 1840,
"y": 1845,
"width": 210,
"height": 78,
"collapsed": false,
@@ -590,8 +559,8 @@
"properties": {
"stage": 2
},
"x": 1280,
"y": 890,
"x": 1084,
"y": 886,
"width": 210,
"height": 118,
"collapsed": true,
@@ -599,14 +568,29 @@
"registry": "core/Stage",
"base_type": "core/Node"
},
"5559196c-f6b1-4223-8e13-2bf64e3cfef0": {
"title": "true",
"id": "5559196c-f6b1-4223-8e13-2bf64e3cfef0",
"properties": {
"value": true
},
"x": 24,
"y": 876,
"width": 210,
"height": 58,
"collapsed": false,
"inherited": false,
"registry": "core/MakeBool",
"base_type": "core/Node"
},
"6ef94917-f9b1-4c18-af15-617430e50cfe": {
"title": "Set Scene Intent",
"id": "6ef94917-f9b1-4c18-af15-617430e50cfe",
"properties": {
"intent": ""
},
"x": 930,
"y": 860,
"x": 731,
"y": 853,
"width": 210,
"height": 78,
"collapsed": false,
@@ -693,14 +677,8 @@
"e4cd1391-daed-4951-a6c6-438d993c07a9.state"
],
"c66bdaeb-4166-4835-9415-943af547c926.value": [
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description",
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state"
],
"59d31050-f61d-4798-9790-e22d34ecbd4b.value": [
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.value"
],
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.yes": [
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state",
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description"
],
"b03fa942-c48e-4c04-b9ae-a009a7e0f947.value": [
"8ad7c42c-110e-46ae-b649-4a1e6d055e25.value"
@@ -717,6 +695,9 @@
"6a8762c4-16cf-4e8c-9d10-8af7597c4097.yes": [
"9db37d1e-3cf8-49bd-bdc5-8663494e5657.state"
],
"5559196c-f6b1-4223-8e13-2bf64e3cfef0.value": [
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
],
"6ef94917-f9b1-4c18-af15-617430e50cfe.state": [
"f4cd34d9-0628-4145-a3da-ec1215cd356c.state"
]
@@ -745,7 +726,7 @@
{
"title": "Generate Scene Types",
"x": -1,
"y": 1287,
"y": 1290,
"width": 1298,
"height": 408,
"color": "#8AA",
@@ -756,8 +737,8 @@
"title": "Set story intention",
"x": -1,
"y": 773,
"width": 1539,
"height": 512,
"width": 1320,
"height": 514,
"color": "#8AA",
"font_size": 24,
"inherited": false
@@ -765,7 +746,7 @@
{
"title": "Evaluate Scene Intent",
"x": -1,
"y": 1697,
"y": 1701,
"width": 1293,
"height": 302,
"color": "#8AA",
@@ -775,7 +756,7 @@
{
"title": "Set title",
"x": -1,
"y": 2003,
"y": 2006,
"width": 2618,
"height": 637,
"color": "#8AA",
@@ -794,7 +775,7 @@
{
"text": "Some times the AI will produce more text after the title, we only care about the title on the first line.",
"x": 1359,
"y": 2311,
"y": 2315,
"width": 471,
"inherited": false
}
@@ -804,13 +785,14 @@
"inputs": [],
"outputs": [
{
"id": "dede1a38-2107-4475-9db5-358c09cb0d12",
"id": "5c8dee64-5832-40ba-b1e2-2a411d913cc7",
"name": "fn",
"optional": false,
"group": null,
"socket_type": "function"
}
],
"module_properties": {},
"style": {
"title_color": "#573a2e",
"node_color": "#392f2c",

View File

@@ -666,23 +666,6 @@
"registry": "data/ListAppend",
"base_type": "core/Node"
},
"4eb36f21-1020-4609-85e5-d16b42019c66": {
"title": "AI Function Calling",
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
"properties": {
"template": "computer",
"max_calls": 5,
"retries": 1
},
"x": 1000,
"y": 2581,
"width": 212,
"height": 206,
"collapsed": false,
"inherited": false,
"registry": "focal/Focal",
"base_type": "core/Node"
},
"3da498c0-55de-4ec3-9943-6486279b9826": {
"title": "GET scene loop.user_message",
"id": "3da498c0-55de-4ec3-9943-6486279b9826",
@@ -1167,6 +1150,24 @@
"inherited": false,
"registry": "data/MakeList",
"base_type": "core/Node"
},
"4eb36f21-1020-4609-85e5-d16b42019c66": {
"title": "AI Function Calling",
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
"properties": {
"template": "computer",
"max_calls": 5,
"retries": 1,
"response_length": 1408
},
"x": 1000,
"y": 2581,
"width": 210,
"height": 230,
"collapsed": false,
"inherited": false,
"registry": "focal/Focal",
"base_type": "core/Node"
}
},
"edges": {
@@ -1296,9 +1297,6 @@
"d45f07ff-c3ce-4aac-bae6-b9c77089cb69.list": [
"4eb36f21-1020-4609-85e5-d16b42019c66.callbacks"
],
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
],
"3da498c0-55de-4ec3-9943-6486279b9826.value": [
"6d707f1c-af55-481a-970a-eb7a9f9c45dd.message"
],
@@ -1380,6 +1378,9 @@
],
"632d4a0e-3327-409b-aaa4-ed38b932286b.list": [
"06175a92-abbe-4483-9ab2-abaee2104728.value"
],
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
]
},
"groups": [

View File

@@ -19,7 +19,7 @@ You have access to the following functions you must call to fulfill the user's r
focal.callbacks.set_simulated_environment.render(
"Create or change the simulated environment. This means the location, time, specific conditions, or any other aspect of the simulation that is not directly related to the characters.",
instructions="Instructions on how to change the simulated environment. These will be given to the simulation computer to setup the new environment. REQUIRED.",
reset="If true, the environment should be reset and all simulated characters are removed. If false, the environment should be changed but the characters should remain. REQUIRED.",
reset="If true, the environment should be reset and ALL simulated characters are removed. IMPORTANT: If you set reset=true, this function MUST be the FIRST call in your stack; otherwise, set reset=false to avoid deactivating characters added earlier. REQUIRED.",
examples=[
{"instructions": "Change the location to a lush forest, with a river running through it.", "reset":true},
{"instructions": "The simulation suite flickers and changes to a bustling city street.", "reset":true},
@@ -123,7 +123,7 @@ You have access to the following functions you must call to fulfill the user's r
{{
focal.callbacks.set_simulation_goal.render(
"Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives.",
"Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives. You MUST call this on new simulations or if the user has requested a change in the simulation's goal.",
goal="The overall goal of the simulation. This should be a clear and concise statement that outlines the user's objective. REQUIRED.",
examples=[
{"goal": "The user is exploring a mysterious alien planet to uncover the secrets of an ancient civilization."},

View File

@@ -1,111 +1,112 @@
def game(TM):
MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"
MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
MSG_HELP = 'Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating "Computer," followed by an instruction. For example ... "Computer, i want to experience being on a derelict spaceship."'
PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."
PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice. Remind the user that this is an old version of the simulation suite and they should check out version two for a more advanced experience."
CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
CTX_PIN_UNAWARE = (
"Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
)
AUTO_NARRATE_INTERVAL = 10
def parse_sim_call_arguments(call:str) -> str:
def parse_sim_call_arguments(call: str) -> str:
"""
Returns the value between the parentheses of a simulation call
Example:
call = 'change_environment("a house")'
parse_sim_call_arguments(call) -> "a house"
"""
try:
return call.split("(", 1)[1].split(")")[0]
except Exception:
return ""
class SimulationSuite:
def __init__(self):
"""
This is initialized at the beginning of each round of the simulation suite
"""
# do we update the world state at the end of the round
self.update_world_state = False
self.simulation_reset = False
# will keep track of any npcs added during the current round
self.added_npcs = []
TM.log.debug("SIMULATION SUITE INIT!", scene=TM.scene)
self.player_message = TM.scene.last_player_message
self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
self.last_processed_call = TM.game_state.get_var(
"instr.lastprocessed_call", -1
)
# determine whether the player / user input is an instruction
# to the simulation computer
#
#
# we do this by checking if the message starts with "Computer,"
self.player_message_is_instruction = (
self.player_message and
self.player_message.raw.lower().startswith("computer") and
not self.player_message.hidden and
not self.last_processed_call > self.player_message.id
self.player_message
and self.player_message.raw.lower().startswith("computer")
and not self.player_message.hidden
and not self.last_processed_call > self.player_message.id
)
def run(self):
"""
Main entry point for the simulation suite
"""
if not TM.game_state.has_var("instr.simulation_stopped"):
# simulation is still running
self.simulation()
self.finalize_round()
def simulation(self):
"""
Simulation suite logic
"""
if not TM.game_state.has_var("instr.simulation_started"):
self.startup()
else:
self.simulation_calls()
if self.update_world_state:
self.run_update_world_state(force=True)
def startup(self):
"""
Scene startup logic
"""
# we are at the beginning of the simulation
TM.signals.status("busy", "Simulation suite powering up.", as_scene_message=True)
TM.signals.status(
"busy", "Simulation suite powering up.", as_scene_message=True
)
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
# add narration for the introduction
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=PROMPT_STARTUP,
emit_message=False
emit_message=False,
)
# add narration for the instructions on how to interact with the simulation
# this is a passthrough since we don't want the AI to paraphrase this
TM.agents.narrator.action_to_narration(
action_name="passthrough",
narration=MSG_HELP
action_name="passthrough", narration=MSG_HELP
)
# create a world state entry letting the AI know that characters
# interacting in the simulation are not aware of the computer or the simulation
TM.agents.world_state.save_world_entry(
@@ -113,37 +114,43 @@ def game(TM):
text=CTX_PIN_UNAWARE,
meta={},
# this should always be pinned
pin=True
pin=True,
)
# set flag that we have started the simulation
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
# signal to the UX that the simulation suite is ready
TM.signals.status("success", "Simulation suite ready", as_scene_message=True)
TM.signals.status(
"success", "Simulation suite ready", as_scene_message=True
)
# we want to update the world state at the end of the round
self.update_world_state = True
def simulation_calls(self):
"""
Calls the simulation suite main prompt to determine the appropriate
simulation calls
"""
# we only process instructions that are not hidden and are not the last processed call
if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
if (
not self.player_message_is_instruction
or self.player_message.id == self.last_processed_call
):
return
# First instruction?
if not TM.game_state.has_var("instr.has_issued_instructions"):
# determine the context of the simulation
context_context = TM.agents.creator.determine_content_context_for_description(
description=self.player_message.raw,
context_context = (
TM.agents.creator.determine_content_context_for_description(
description=self.player_message.raw,
)
)
TM.scene.set_content_context(context_context)
# Render the `computer` template and send it to the LLM for processing
# The LLM will return a list of calls that the simulation suite will process
# The calls are pseudo code that the simulation suite will interpret and execute
@@ -153,90 +160,98 @@ def game(TM):
player_instruction=self.player_message.raw,
scene=TM.scene,
)
self.calls = calls = calls.split("\n")
calls = self.prepare_calls(calls)
TM.log.debug("SIMULATION SUITE CALLS", callse=calls)
# calls that are processed
processed = []
for call in calls:
processed_call = self.process_call(call)
if processed_call:
processed.append(processed_call)
if processed:
TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
TM.signals.status("busy", "Simulation suite altering environment.", as_scene_message=True)
TM.game_state.set_var(
"instr.has_issued_instructions", "yes", commit=False
)
TM.signals.status(
"busy", "Simulation suite altering environment.", as_scene_message=True
)
compiled = "\n".join(processed)
if not self.simulation_reset and compiled:
# send the compiled calls to the narrator to generate a narrative based
# on them
narration = TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
emit_message=True
emit_message=True,
)
# on the first narration we update the scene description and remove any mention of the computer
# or the simulation from the previous narration
is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
is_initial_narration = TM.game_state.get_var(
"instr.intro_narration", False
)
if not is_initial_narration:
TM.scene.set_description(narration.raw)
TM.scene.set_intro(narration.raw)
TM.log.debug("SIMULATION SUITE: initial narration", intro=narration.raw)
TM.log.debug(
"SIMULATION SUITE: initial narration", intro=narration.raw
)
TM.scene.pop_history(typ="narrator", all=True, reverse=True)
TM.scene.pop_history(typ="director", all=True, reverse=True)
TM.game_state.set_var("instr.intro_narration", True, commit=False)
self.update_world_state = True
self.set_simulation_title(compiled)
def set_simulation_title(self, compiled_calls):
"""
Generates a fitting title for the simulation based on the user's instructions
"""
TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
TM.log.debug(
"SIMULATION SUITE: set simulation title",
name=TM.scene.title,
compiled_calls=compiled_calls,
)
if not compiled_calls:
return
if TM.scene.title != "Simulation Suite":
# name already changed, no need to do it again
return
title = TM.agents.creator.contextual_generate_from_args(
"scene:simulation title",
"Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
length=75
length=75,
)
title = title.strip('"').strip()
TM.scene.set_title(title)
def prepare_calls(self, calls):
"""
Loops through calls and, if both a `set_player_name` call and a `set_player_persona` call are
found, ensures that the `set_player_name` call is processed first by moving it in front of the
`set_player_persona` call.
"""
set_player_name_call_exists = -1
set_player_persona_call_exists = -1
i = 0
for call in calls:
if "set_player_name" in call:
@@ -244,351 +259,445 @@ def game(TM):
elif "set_player_persona" in call:
set_player_persona_call_exists = i
i = i + 1
if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
if set_player_name_call_exists > set_player_persona_call_exists:
calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists))
TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls)
calls.insert(
set_player_persona_call_exists,
calls.pop(set_player_name_call_exists),
)
TM.log.debug(
"SIMULATION SUITE: prepare calls - moved set_player_persona call",
calls=calls,
)
return calls
def process_call(self, call:str) -> str:
def process_call(self, call: str) -> str:
"""
Processes a simulation call
Simulation calls are pseudo functions that are called by the simulation suite
We grab the function name by splitting against ( and taking the first element
if the SimulationSuite has a method with the name call_{function_name} then we call it
if a function name could be found but we do not have a method to call we don't do anything
but we still return it as processed so the AI can still interpret it as something later on
"""
if "(" not in call:
return None
function_name = call.split("(")[0]
if hasattr(self, f"call_{function_name}"):
TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
TM.log.debug(
"SIMULATION SUITE CALL", call=call, function_name=function_name
)
inject = f"The computer executes the function `{call}`"
return getattr(self, f"call_{function_name}")(call, inject)
return call
def call_set_simulation_goal(self, call:str, inject:str) -> str:
def call_set_simulation_goal(self, call: str, inject: str) -> str:
"""
Sets the simulation goal as a permanent pin
"""
TM.signals.status("busy", "Simulation suite setting goal.", as_scene_message=True)
TM.agents.world_state.save_world_entry(
entry_id="sim.goal",
text=self.player_message.raw,
meta={},
pin=True
TM.signals.status(
"busy", "Simulation suite setting goal.", as_scene_message=True
)
TM.agents.world_state.save_world_entry(
entry_id="sim.goal", text=self.player_message.raw, meta={}, pin=True
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer sets the goal for the simulation.",
)
return call
def call_change_environment(self, call:str, inject:str) -> str:
def call_change_environment(self, call: str, inject: str) -> str:
"""
Simulation changes the environment. This is entirely interpreted by the AI
and we don't need to do any logic on our end, so we just return the call
"""
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer changes the environment of the simulation."
action_description="The computer changes the environment of the simulation.",
)
return call
def call_answer_question(self, call:str, inject:str) -> str:
def call_answer_question(self, call: str, inject: str) -> str:
"""
The player asked the simulation a query, we need to process this and have
the AI produce an answer
"""
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
emit_message=True
emit_message=True,
)
def call_set_player_persona(self, call:str, inject:str) -> str:
def call_set_player_persona(self, call: str, inject: str) -> str:
"""
The simulation suite is altering the player persona
"""
player_character = TM.scene.get_player_character()
TM.signals.status("busy", "Simulation suite altering user persona.", as_scene_message=True)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=player_character.name, text=inject, alteration_instructions=self.player_message.raw
TM.signals.status(
"busy", "Simulation suite altering user persona.", as_scene_message=True
)
TM.scene.set_character_attributes(player_character.name, character_attributes)
character_description = TM.agents.creator.determine_character_description(player_character.name)
TM.scene.set_character_description(player_character.name, character_description)
TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=player_character.name,
text=inject,
alteration_instructions=self.player_message.raw,
)
TM.scene.set_character_attributes(
player_character.name, character_attributes
)
character_description = TM.agents.creator.determine_character_description(
player_character.name
)
TM.scene.set_character_description(
player_character.name, character_description
)
TM.log.debug(
"SIMULATION SUITE: transform player",
attributes=character_attributes,
description=character_description,
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer transforms the player persona."
action_description="The computer transforms the player persona.",
)
return call
def call_set_player_name(self, call:str, inject:str) -> str:
def call_set_player_name(self, call: str, inject: str) -> str:
"""
The simulation suite is altering the player name
"""
player_character = TM.scene.get_player_character()
TM.signals.status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
TM.signals.status(
"busy",
"Simulation suite adjusting user identity.",
as_scene_message=True,
)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits."
)
TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
if character_name != player_character.name:
TM.scene.set_character_name(player_character.name, character_name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer changes the player's identity to {character_name}."
action_description=f"The computer changes the player's identity to {character_name}.",
)
return call
def call_add_ai_character(self, call:str, inject:str) -> str:
return call
def call_add_ai_character(self, call: str, inject: str) -> str:
# sometimes the AI will call this function and pass an inanimate object as the parameter
# we need to determine if this is the case and just ignore it
is_inanimate = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
is_inanimate = TM.agents.world_state.answer_query_true_or_false(
f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)",
call,
)
if is_inanimate:
TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
TM.log.debug(
"SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped",
call=call,
)
return
# the function may add a group of characters rather than a single one,
# so we ask the AI to determine if this is the case
adds_group = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add MULTIPLE ai characters?", call)
adds_group = TM.agents.world_state.answer_query_true_or_false(
f"does the function `{call}` add MULTIPLE ai characters?", call
)
TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)
TM.signals.status("busy", "Simulation suite adding character.", as_scene_message=True)
TM.signals.status(
"busy", "Simulation suite adding character.", as_scene_message=True
)
if not adds_group:
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.")
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name."
)
else:
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.",
group=True,
)
# sometimes add_ai_character and change_ai_character are called in the same instruction targeting
# the same character; if this happens we need to combine them into a single add_ai_character call
has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(
f"Are there any calls to `change_ai_character` in the instruction for {character_name}?",
"\n".join(self.calls),
)
if has_change_ai_character_call:
combined_arg = TM.prompt.request(
"combine-add-and-alter-ai-character",
dedupe_enabled=False,
calls="\n".join(self.calls),
character_name=character_name,
scene=TM.scene,
).replace("COMBINED ARGUMENT:", "").strip()
combined_arg = (
TM.prompt.request(
"combine-add-and-alter-ai-character",
dedupe_enabled=False,
calls="\n".join(self.calls),
character_name=character_name,
scene=TM.scene,
)
.replace("COMBINED ARGUMENT:", "")
.strip()
)
call = f"add_ai_character({combined_arg})"
inject = f"The computer executes the function `{call}`"
TM.signals.status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
TM.signals.status(
"busy",
f"Simulation suite adding character: {character_name}",
as_scene_message=True,
)
TM.log.debug("SIMULATION SUITE: add npc", name=character_name)
npc = TM.agents.director.persist_character(character_name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
npc = TM.agents.director.persist_character(
character_name=character_name,
content=self.player_message.raw + f"\n\n{inject}",
determine_name=False,
)
self.added_npcs.append(npc.name)
TM.agents.world_state.add_detail_reinforcement(
character_name=npc.name,
detail="Goal",
instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
interval=25,
run_immediately=True
run_immediately=True,
)
TM.log.debug("SIMULATION SUITE: added npc", npc=npc)
TM.agents.visual.generate_character_portrait(character_name=npc.name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer adds {npc.name} to the simulation."
action_description=f"The computer adds {npc.name} to the simulation.",
)
return call
return call
####
def call_remove_ai_character(self, call:str, inject:str) -> str:
TM.signals.status("busy", "Simulation suite removing character.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names)
def call_remove_ai_character(self, call: str, inject: str) -> str:
TM.signals.status(
"busy", "Simulation suite removing character.", as_scene_message=True
)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character being removed?",
allowed_names=TM.scene.npc_character_names,
)
npc = TM.scene.get_character(character_name)
if npc:
TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
TM.agents.world_state.deactivate_character(action_name="deactivate_character", character_name=npc.name)
TM.agents.world_state.deactivate_character(
action_name="deactivate_character", character_name=npc.name
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer removes {npc.name} from the simulation."
action_description=f"The computer removes {npc.name} from the simulation.",
)
return call
def call_change_ai_character(self, call:str, inject:str) -> str:
TM.signals.status("busy", "Simulation suite altering character.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names)
def call_change_ai_character(self, call: str, inject: str) -> str:
TM.signals.status(
"busy", "Simulation suite altering character.", as_scene_message=True
)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?",
allowed_names=TM.scene.npc_character_names,
)
if character_name in self.added_npcs:
# we don't want to change the character if it was just added
return
character_name_after = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
character_name_after = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?"
)
npc = TM.scene.get_character(character_name)
if npc:
TM.signals.status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
TM.signals.status(
"busy",
f"Changing {character_name} -> {character_name_after}",
as_scene_message=True,
)
TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=npc.name,
text=inject,
alteration_instructions=self.player_message.raw
alteration_instructions=self.player_message.raw,
)
TM.scene.set_character_attributes(npc.name, character_attributes)
character_description = TM.agents.creator.determine_character_description(npc.name)
character_description = (
TM.agents.creator.determine_character_description(npc.name)
)
TM.scene.set_character_description(npc.name, character_description)
TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
TM.log.debug(
"SIMULATION SUITE: transform npc",
attributes=character_attributes,
description=character_description,
)
if character_name_after != character_name:
TM.scene.set_character_name(npc.name, character_name_after)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer transforms {npc.name}."
action_description=f"The computer transforms {npc.name}.",
)
return call
def call_end_simulation(self, call:str, inject:str) -> str:
def call_end_simulation(self, call: str, inject: str) -> str:
player_character = TM.scene.get_player_character()
explicit_command = TM.agents.world_state.answer_query_true_or_false("has the player explicitly asked to end the simulation?", self.player_message.raw)
explicit_command = TM.agents.world_state.answer_query_true_or_false(
"has the player explicitly asked to end the simulation?",
self.player_message.raw,
)
if explicit_command:
TM.signals.status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
TM.signals.status(
"busy",
"Simulation suite ending current simulation.",
as_scene_message=True,
)
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names)}. The player is also transformed back to their normal, non-descript persona as the form of {player_character.name} ceases to exist.",
emit_message=True
emit_message=True,
)
TM.scene.restore()
self.simulation_reset = True
TM.game_state.unset_var("instr.has_issued_instructions")
TM.game_state.unset_var("instr.lastprocessed_call")
TM.game_state.unset_var("instr.simulation_started")
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer ends the simulation."
action_description="The computer ends the simulation.",
)
def finalize_round(self):
# track rounds
rounds = TM.game_state.get_var("instr.rounds", 0)
# increase rounds
TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)
has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
has_issued_instructions = TM.game_state.has_var(
"instr.has_issued_instructions"
)
if self.update_world_state:
self.run_update_world_state()
if self.player_message_is_instruction:
TM.scene.hide_message(self.player_message.id)
TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
TM.signals.status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
TM.game_state.set_var(
"instr.lastprocessed_call", self.player_message.id, commit=False
)
TM.signals.status(
"success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True
)
elif self.player_message and not has_issued_instructions:
# simulation started, player message is NOT an instruction, and player has not given
# any instructions
self.guide_player()
elif self.player_message and not TM.scene.npc_character_names:
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
self.narrate_round()
elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names and has_issued_instructions:
elif (
rounds % AUTO_NARRATE_INTERVAL == 0
and rounds
and TM.scene.npc_character_names
and has_issued_instructions
):
# every N rounds, narrate the round
self.narrate_round()
def guide_player(self):
TM.agents.narrator.action_to_narration(
action_name="paraphrase",
narration=MSG_HELP,
emit_message=True
action_name="paraphrase", narration=MSG_HELP, emit_message=True
)
def narrate_round(self):
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=PROMPT_NARRATE_ROUND,
emit_message=True
emit_message=True,
)
def run_update_world_state(self, force=False):
TM.log.debug("SIMULATION SUITE: update world state", force=force)
TM.signals.status("busy", "Simulation suite updating world state.", as_scene_message=True)
TM.signals.status(
"busy", "Simulation suite updating world state.", as_scene_message=True
)
TM.agents.world_state.update_world_state(force=force)
TM.signals.status("success", "Simulation suite updated world state.", as_scene_message=True)
TM.signals.status(
"success",
"Simulation suite updated world state.",
as_scene_message=True,
)
SimulationSuite().run()
def on_generation_cancelled(TM, exc):
"""
Called when user pressed the cancel button during the simulation suite
loop.
"""
TM.signals.status("success", "Simulation suite instructions cancelled", as_scene_message=True)
TM.signals.status(
"success", "Simulation suite instructions cancelled", as_scene_message=True
)
rounds = TM.game_state.get_var("instr.rounds", 0)
TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)
TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)

View File

@@ -1,5 +1,5 @@
from .tale_mate import *
from .tale_mate import * # noqa: F401, F403
from .version import VERSION
__version__ = VERSION

View File

@@ -1,12 +1,12 @@
from .base import Agent
from .conversation import ConversationAgent
from .creator import CreatorAgent
from .director import DirectorAgent
from .editor import EditorAgent
from .memory import ChromaDBMemoryAgent, MemoryAgent
from .narrator import NarratorAgent
from .registry import AGENT_CLASSES, get_agent_class, register
from .summarize import SummarizeAgent
from .tts import TTSAgent
from .visual import VisualAgent
from .world_state import WorldStateAgent
from .base import Agent # noqa: F401
from .conversation import ConversationAgent # noqa: F401
from .creator import CreatorAgent # noqa: F401
from .director import DirectorAgent # noqa: F401
from .editor import EditorAgent # noqa: F401
from .memory import ChromaDBMemoryAgent, MemoryAgent # noqa: F401
from .narrator import NarratorAgent # noqa: F401
from .registry import AGENT_CLASSES, get_agent_class, register # noqa: F401
from .summarize import SummarizeAgent # noqa: F401
from .tts import TTSAgent # noqa: F401
from .visual import VisualAgent # noqa: F401
from .world_state import WorldStateAgent # noqa: F401

View File

@@ -6,11 +6,10 @@ from inspect import signature
import re
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from typing import Callable, Union
import uuid
import pydantic
import structlog
from blinker import signal
import talemate.emit.async_signals
import talemate.instance as instance
@@ -19,10 +18,11 @@ from talemate.agents.context import ActiveAgent, active_agent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
from talemate.context import active_scene
import talemate.config as config
from talemate.ux.schema import Column
from talemate.config import get_config, Config
import talemate.config.schema as config_schema
from talemate.client.context import (
ClientContext,
set_client_context_attribute,
)
__all__ = [
@@ -39,6 +39,7 @@ __all__ = [
log = structlog.get_logger("talemate.agents.base")
class AgentActionConditional(pydantic.BaseModel):
attribute: str
value: int | float | str | bool | list[int | float | str | bool] | None = None
@@ -48,24 +49,26 @@ class AgentActionNote(pydantic.BaseModel):
type: str
text: str
class AgentActionConfig(pydantic.BaseModel):
type: str
label: str
description: str = ""
value: Union[int, float, str, bool, list, None] = None
default_value: Union[int, float, str, bool] = None
max: Union[int, float, None] = None
min: Union[int, float, None] = None
step: Union[int, float, None] = None
value: int | float | str | bool | list | None = None
default_value: int | float | str | bool | None = None
max: int | float | None = None
min: int | float | None = None
step: int | float | None = None
scope: str = "global"
choices: Union[list[dict[str, str]], None] = None
choices: list[dict[str, str | int | float | bool]] | None = None
note: Union[str, None] = None
expensive: bool = False
quick_toggle: bool = False
condition: Union[AgentActionConditional, None] = None
title: Union[str, None] = None
value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True)
condition: AgentActionConditional | None = None
title: str | None = None
value_migration: Callable | None = pydantic.Field(default=None, exclude=True)
columns: list[Column] | None = None
note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict)
class Config:
@@ -77,45 +80,46 @@ class AgentAction(pydantic.BaseModel):
label: str
description: str = ""
warning: str = ""
config: Union[dict[str, AgentActionConfig], None] = None
condition: Union[AgentActionConditional, None] = None
config: dict[str, AgentActionConfig] | None = None
condition: AgentActionConditional | None = None
container: bool = False
icon: Union[str, None] = None
icon: str | None = None
can_be_disabled: bool = False
quick_toggle: bool = False
experimental: bool = False
class AgentDetail(pydantic.BaseModel):
value: Union[str, None] = None
description: Union[str, None] = None
icon: Union[str, None] = None
value: str | None = None
description: str | None = None
icon: str | None = None
color: str = "grey"
hidden: bool = False
class DynamicInstruction(pydantic.BaseModel):
title: str
content: str
def __str__(self) -> str:
return "\n".join(
[
f"<|SECTION:{self.title}|>",
self.content,
"<|CLOSE_SECTION|>"
]
[f"<|SECTION:{self.title}|>", self.content, "<|CLOSE_SECTION|>"]
)
def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = None) -> dict:
def args_and_kwargs_to_dict(
fn, args: list, kwargs: dict, filter: list[str] = None
) -> dict:
"""
Takes a list of arguments and a dict of keyword arguments and returns
a dict mapping parameter names to their values.
Args:
fn: The function whose parameters we want to map
args: List of positional arguments
kwargs: Dictionary of keyword arguments
filter: List of parameter names to include in the result, if None all parameters are included
Returns:
Dict mapping parameter names to their values
"""
@@ -124,34 +128,36 @@ def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = Non
bound_args.apply_defaults()
rv = dict(bound_args.arguments)
rv.pop("self", None)
if filter:
for key in list(rv.keys()):
if key not in filter:
rv.pop(key)
return rv
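A small usage sketch (hypothetical function and values): positional and keyword arguments are bound to parameter names, `self` is dropped, and `filter` whitelists keys after defaults have been applied.

```python
def greet(self, name, punctuation="!"):
    return f"{name}{punctuation}"

args_and_kwargs_to_dict(greet, ["agent", "Elara"], {"punctuation": "?"})
# -> {"name": "Elara", "punctuation": "?"}

args_and_kwargs_to_dict(greet, ["agent", "Elara"], {}, filter=["name"])
# -> {"name": "Elara"}   (the defaulted "punctuation" is filtered out)
```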
class store_context_state:
"""
Flag to store a function's arguments in the agent's context state.
Any argument names passed to the decorator will be stored in the agent's context state;
if none are passed, all of the function's arguments are stored.
Keyword arguments passed to the decorator are stored as additional values in the context state.
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __call__(self, fn):
fn.store_context_state = self.args
fn.store_context_state_kwargs = self.kwargs
return fn
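In combination with `set_processing` below, the decorator marks which call arguments should land in the active agent context; a hedged sketch with hypothetical names:

```python
class ExampleAgent(Agent):           # hypothetical subclass
    @set_processing
    @store_context_state("topic", verbosity="high")
    async def summarize(self, topic: str, depth: int = 1):
        # set_processing will store {"topic": ..., "verbosity": "high",
        # "fn_summarize": True} in the active agent context state.
        ...
```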
def set_processing(fn):
"""
decorator that emits the agent status as processing while the function
@@ -165,31 +171,33 @@ def set_processing(fn):
async def wrapper(self, *args, **kwargs):
with ClientContext():
scene = active_scene.get()
if scene:
scene.continue_actions()
if getattr(scene, "config", None):
set_client_context_attribute("app_config_system_prompts", scene.config.get("system_prompts", {}))
with ActiveAgent(self, fn, args, kwargs) as active_agent_context:
try:
await self.emit_status(processing=True)
# Now pass the complete args list
if getattr(fn, "store_context_state", None) is not None:
all_args = args_and_kwargs_to_dict(
fn, [self] + list(args), kwargs, getattr(fn, "store_context_state", [])
fn,
[self] + list(args),
kwargs,
getattr(fn, "store_context_state", []),
)
if getattr(fn, "store_context_state_kwargs", None) is not None:
all_args.update(getattr(fn, "store_context_state_kwargs", {}))
all_args.update(
getattr(fn, "store_context_state_kwargs", {})
)
all_args[f"fn_{fn.__name__}"] = True
active_agent_context.state_params = all_args
self.set_context_states(**all_args)
return await fn(self, *args, **kwargs)
finally:
try:
@@ -211,18 +219,23 @@ class Agent(ABC):
verbose_name = None
set_processing = set_processing
requires_llm_client = True
auto_break_repetition = False
websocket_handler = None
essential = True
ready_check_error = None
@classmethod
def init_actions(cls, actions: dict[str, AgentAction] | None = None) -> dict[str, AgentAction]:
def init_actions(
cls, actions: dict[str, AgentAction] | None = None
) -> dict[str, AgentAction]:
if actions is None:
actions = {}
return actions
@property
def config(self) -> Config:
return get_config()
@property
def agent_details(self):
if hasattr(self, "client"):
@@ -230,12 +243,14 @@ class Agent(ABC):
return self.client.name
return None
@property
def verbose_name(self):
return self.agent_type.capitalize()
@property
def ready(self):
if not self.requires_llm_client:
return True
if not hasattr(self, "client"):
return False
if not getattr(self.client, "enabled", True):
return False
@@ -315,67 +330,69 @@ class Agent(ABC):
return {}
return {k: v.model_dump() for k, v in self.actions.items()}
# scene state
def context_fingerpint(self, extra: list[str] = []) -> str | None:
def context_fingerprint(self, extra: list[str] = None) -> str | None:
active_agent_context = active_agent.get()
if not active_agent_context:
return None
if self.scene.history:
fingerprint = f"{self.scene.history[-1].fingerprint}-{active_agent_context.first.fingerprint}"
else:
fingerprint = f"START-{active_agent_context.first.fingerprint}"
for extra_key in extra:
fingerprint += f"-{hash(extra_key)}"
if extra:
for extra_key in extra:
fingerprint += f"-{hash(extra_key)}"
return fingerprint
def get_scene_state(self, key:str, default=None):
def get_scene_state(self, key: str, default=None):
agent_state = self.scene.agent_state.get(self.agent_type, {})
return agent_state.get(key, default)
def set_scene_states(self, **kwargs):
agent_state = self.scene.agent_state.get(self.agent_type, {})
for key, value in kwargs.items():
agent_state[key] = value
self.scene.agent_state[self.agent_type] = agent_state
def dump_scene_state(self):
return self.scene.agent_state.get(self.agent_type, {})
# active agent context state
def get_context_state(self, key:str, default=None):
def get_context_state(self, key: str, default=None):
key = f"{self.agent_type}__{key}"
try:
return active_agent.get().state.get(key, default)
except AttributeError:
log.warning("get_context_state error", agent=self.agent_type, key=key)
return default
def set_context_states(self, **kwargs):
try:
items = {f"{self.agent_type}__{k}": v for k, v in kwargs.items()}
active_agent.get().state.update(items)
log.debug("set_context_states", agent=self.agent_type, state=active_agent.get().state)
log.debug(
"set_context_states",
agent=self.agent_type,
state=active_agent.get().state,
)
except AttributeError:
log.error("set_context_states error", agent=self.agent_type, kwargs=kwargs)
def dump_context_state(self):
try:
return active_agent.get().state
except AttributeError:
return {}
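Taken together, the context-state helpers namespace every key with the agent type and fall back gracefully when no active agent context exists; a hedged sketch (hypothetical agent and values):

```python
class ExampleAgent(Agent):           # hypothetical subclass
    agent_type = "conversation"

    @set_processing
    async def remember(self):
        self.set_context_states(last_topic="weather")
        # stored under the namespaced key "conversation__last_topic"
        assert self.get_context_state("last_topic") == "weather"
        # missing keys (or a missing ActiveAgent context) yield the default
        assert self.get_context_state("missing", default=0) == 0
```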
###
async def _handle_ready_check(self, fut: asyncio.Future):
callback_failure = getattr(self, "on_ready_check_failure", None)
if fut.cancelled():
@@ -425,42 +442,53 @@ class Agent(ABC):
if not action.config:
continue
for config_key, config in action.config.items():
for config_key, _config in action.config.items():
try:
config.value = (
_config.value = (
kwargs.get("actions", {})
.get(action_key, {})
.get("config", {})
.get(config_key, {})
.get("value", config.value)
.get("value", _config.value)
)
if config.value_migration and callable(config.value_migration):
config.value = config.value_migration(config.value)
if _config.value_migration and callable(_config.value_migration):
_config.value = _config.value_migration(_config.value)
except AttributeError:
pass
async def save_config(self, app_config: config.Config | None = None):
async def save_config(self):
"""
Writes the agent's current settings into the shared application config and
marks the config as dirty so it will be persisted.
"""
if not app_config:
app_config:config.Config = config.load_config(as_model=True)
app_config.agents[self.agent_type] = config.Agent(
app_config: Config = get_config()
app_config.agents[self.agent_type] = config_schema.Agent(
name=self.agent_type,
client=self.client.name if self.client else None,
client=self.client.name if getattr(self, "client", None) else None,
enabled=self.enabled,
actions={action_key: config.AgentAction(
enabled=action.enabled,
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
) for action_key, action in self.actions.items()}
actions={
action_key: config_schema.AgentAction(
enabled=action.enabled,
config={
config_key: config_schema.AgentActionConfig(
value=config_obj.value
)
for config_key, config_obj in action.config.items()
},
)
for action_key, action in self.actions.items()
},
)
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
config.save_config(app_config)
log.debug(
"saving agent config",
agent=self.agent_type,
config=app_config.agents[self.agent_type],
)
app_config.dirty = True
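Under the config refactor, `save_config` no longer writes to disk itself: it mutates the shared config object and raises the `dirty` flag. Roughly equivalent to the following (agent and client names hypothetical; persistence is assumed to happen elsewhere once the flag is set):

```python
app_config = get_config()
app_config.agents["conversation"] = config_schema.Agent(
    name="conversation",
    client="my-client",   # hypothetical client name
    enabled=True,
    actions={},           # AgentAction entries mirroring the runtime actions
)
app_config.dirty = True   # persistence assumed to happen elsewhere
```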
async def on_game_loop_start(self, event: GameLoopStartEvent):
"""
@@ -474,19 +502,19 @@ class Agent(ABC):
if not action.config:
continue
for _, config in action.config.items():
if config.scope == "scene":
for _, _config in action.config.items():
if _config.scope == "scene":
# if default_value is None, just use the `type` of the current
# value
if config.default_value is None:
default_value = type(config.value)()
if _config.default_value is None:
default_value = type(_config.value)()
else:
default_value = config.default_value
default_value = _config.default_value
log.debug(
"resetting config", config=config, default_value=default_value
"resetting config", config=_config, default_value=default_value
)
config.value = default_value
_config.value = default_value
await self.emit_status()
@@ -518,7 +546,9 @@ class Agent(ABC):
await asyncio.sleep(0.01)
async def _handle_background_processing(self, fut: asyncio.Future, error_handler = None):
async def _handle_background_processing(
self, fut: asyncio.Future, error_handler=None
):
try:
if fut.cancelled():
return
@@ -541,7 +571,7 @@ class Agent(ABC):
self.processing_bg -= 1
await self.emit_status()
async def set_background_processing(self, task: asyncio.Task, error_handler = None):
async def set_background_processing(self, task: asyncio.Task, error_handler=None):
log.info("set_background_processing", agent=self.agent_type)
if not hasattr(self, "processing_bg"):
self.processing_bg = 0
@@ -550,7 +580,9 @@ class Agent(ABC):
await self.emit_status()
task.add_done_callback(
lambda fut: asyncio.create_task(self._handle_background_processing(fut, error_handler))
lambda fut: asyncio.create_task(
self._handle_background_processing(fut, error_handler)
)
)
def connect(self, scene):
@@ -580,7 +612,7 @@ class Agent(ABC):
exclude_fn: Callable = None,
):
current_memory_context = []
memory_helper = self.scene.get_helper("memory")
memory_helper = instance.get_agent("memory")
if memory_helper:
history_messages = "\n".join(
self.scene.recent_history(memory_history_context_max)
@@ -623,7 +655,6 @@ class Agent(ABC):
"""
return False
@set_processing
async def delegate(self, fn: Callable, *args, **kwargs):
"""
@@ -631,20 +662,22 @@ class Agent(ABC):
by the agent.
"""
return await fn(*args, **kwargs)
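Because `delegate` is wrapped in `set_processing`, it gives arbitrary coroutines the same status emission and context tracking as native agent methods; a hedged usage sketch:

```python
# Hypothetical helper coroutine, not part of this diff:
async def build_summary(scene):
    return "..."

# Runs build_summary while the agent reports itself as processing:
summary = await agent.delegate(build_summary, scene)
```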
async def emit_message(self, header:str, message:str | list[dict], meta: dict = None, **data):
async def emit_message(
self, header: str, message: str | list[dict], meta: dict = None, **data
):
if not data:
data = {}
if not meta:
meta = {}
if "uuid" not in data:
data["uuid"] = str(uuid.uuid4())
if "agent" not in data:
data["agent"] = self.agent_type
data["header"] = header
emit(
"agent_message",
@@ -653,17 +686,22 @@ class Agent(ABC):
meta=meta,
websocket_passthrough=True,
)
@dataclasses.dataclass
class AgentEmission:
agent: Agent
@dataclasses.dataclass
class AgentTemplateEmission(AgentEmission):
template_vars: dict = dataclasses.field(default_factory=dict)
response: str = None
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
@dataclasses.dataclass
class RagBuildSubInstructionEmission(AgentEmission):
sub_instruction: str | None = None

View File

@@ -1,13 +1,9 @@
import contextvars
import uuid
import hashlib
from typing import TYPE_CHECKING, Callable
from typing import Callable
import pydantic
if TYPE_CHECKING:
from talemate.tale_mate import Character
__all__ = [
"active_agent",
]
@@ -25,7 +21,6 @@ class ActiveAgentContext(pydantic.BaseModel):
state: dict = pydantic.Field(default_factory=dict)
state_params: dict = pydantic.Field(default_factory=dict)
previous: "ActiveAgentContext" = None
class Config:
arbitrary_types_allowed = True
@@ -35,29 +30,30 @@ class ActiveAgentContext(pydantic.BaseModel):
return self.previous.first if self.previous else self
@property
def action(self):
name = self.fn.__name__
if name == "delegate":
return self.fn_args[0].__name__
return name
@property
def fingerprint(self) -> int:
if hasattr(self, "_fingerprint"):
return self._fingerprint
self._fingerprint = hash(frozenset(self.state_params.items()))
return self._fingerprint
def __str__(self):
return f"{self.agent.verbose_name}.{self.action}"
class ActiveAgent:
def __init__(self, agent, fn, args=None, kwargs=None):
self.agent = ActiveAgentContext(agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {})
self.agent = ActiveAgentContext(
agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {}
)
def __enter__(self):
previous_agent = active_agent.get()
if previous_agent:
@@ -70,7 +66,7 @@ class ActiveAgent:
self.agent.agent_stack_uid = str(uuid.uuid4())
self.token = active_agent.set(self.agent)
return self.agent
def __exit__(self, *args, **kwargs):

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import dataclasses
import random
import re
from datetime import datetime
from typing import TYPE_CHECKING, Optional
@@ -10,14 +9,12 @@ import structlog
import talemate.client as client
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.client.context import (
client_context_attribute,
set_client_context_attribute,
set_conversation_context_attribute,
)
from talemate.events import GameLoopEvent
from talemate.exceptions import LLMAccuracyError
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage
@@ -26,6 +23,7 @@ from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
AgentActionNote,
AgentDetail,
AgentEmission,
DynamicInstruction,
@@ -37,7 +35,7 @@ from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.agents.context import active_agent
from .websocket_handler import ConversationWebsocketHandler
import talemate.agents.conversation.nodes
import talemate.agents.conversation.nodes # noqa: F401
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Character
@@ -50,21 +48,20 @@ class ConversationAgentEmission(AgentEmission):
actor: Actor
character: Character
response: str
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
talemate.emit.async_signals.register(
"agent.conversation.before_generate",
"agent.conversation.before_generate",
"agent.conversation.inject_instructions",
"agent.conversation.generated"
"agent.conversation.generated",
)
@register()
class ConversationAgent(
MemoryRAGMixin,
Agent
):
class ConversationAgent(MemoryRAGMixin, Agent):
"""
An agent that can be used to have a conversation with the AI
@@ -89,12 +86,22 @@ class ConversationAgent(
"format": AgentActionConfig(
type="text",
label="Format",
description="The generation format of the scene context, as seen by the AI.",
description="The generation format of the scene progression, as seen by the AI. Has no direct effect on your view of the scene, but will affect the way the AI perceives the scene and its characters, leading to changes in the response, for better or worse.",
choices=[
{"label": "Screenplay", "value": "movie_script"},
{"label": "Chat (legacy)", "value": "chat"},
{
"label": "Narrative (NEW, experimental)",
"value": "narrative",
},
],
value="movie_script",
note_on_value={
"narrative": AgentActionNote(
type="primary",
text="Will attempt to generate flowing, novel-like prose with scene intent awareness and character goal consideration. A reasoning model is STRONGLY recommended. Experimental and more prone to generate out of turn character actions and dialogue.",
)
},
),
"length": AgentActionConfig(
type="number",
@@ -135,16 +142,8 @@ class ConversationAgent(
max=20,
step=1,
),
},
),
"auto_break_repetition": AgentAction(
enabled=True,
can_be_disabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"content": AgentAction(
enabled=True,
can_be_disabled=False,
@@ -159,7 +158,7 @@ class ConversationAgent(
description="Use the writing style selected in the scene settings",
value=True,
),
}
},
),
}
MemoryRAGMixin.add_actions(actions)
@@ -167,7 +166,7 @@ class ConversationAgent(
def __init__(
self,
client: client.TaleMateClient,
client: client.ClientBase | None = None,
kind: Optional[str] = "pygmalion",
logging_enabled: Optional[bool] = True,
**kwargs,
@@ -205,7 +204,6 @@ class ConversationAgent(
@property
def agent_details(self) -> dict:
details = {
"client": AgentDetail(
icon="mdi-network-outline",
@@ -231,22 +229,24 @@ class ConversationAgent(
@property
def generation_settings_actor_instructions_offset(self):
return self.actions["generation_override"].config["actor_instructions_offset"].value
return (
self.actions["generation_override"]
.config["actor_instructions_offset"]
.value
)
@property
def generation_settings_response_length(self):
return self.actions["generation_override"].config["length"].value
@property
def generation_settings_override_enabled(self):
return self.actions["generation_override"].enabled
@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value
def connect(self, scene):
super().connect(scene)
@@ -276,7 +276,7 @@ class ConversationAgent(
main_character = scene.main_character.character
character_names = [c.name for c in scene.characters]
if main_character:
try:
character_names.remove(main_character.name)
@@ -296,21 +296,22 @@ class ConversationAgent(
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
except IndexError:
director_message = False
inject_instructions_emission = ConversationAgentEmission(
agent=self,
response="",
actor=None,
character=character,
response="",
actor=None,
character=character,
)
await talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).send(inject_instructions_emission)
agent_context = active_agent.get()
agent_context.state["dynamic_instructions"] = inject_instructions_emission.dynamic_instructions
agent_context.state["dynamic_instructions"] = (
inject_instructions_emission.dynamic_instructions
)
conversation_format = self.conversation_format
prompt = Prompt.get(
f"conversation.dialogue-{conversation_format}",
@@ -319,26 +320,30 @@ class ConversationAgent(
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": None, # DEPRECATED VARIABLE
"memory": None, # DEPRECATED VARIABLE
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": self.generation_settings_task_instructions, #backward compatibility
"extra_instructions": self.generation_settings_task_instructions, # backward compatibility
"task_instructions": self.generation_settings_task_instructions,
"actor_instructions": self.generation_settings_actor_instructions,
"actor_instructions_offset": self.generation_settings_actor_instructions_offset,
"direct_instruction": instruction,
"decensor": self.client.decensor_enabled,
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
"response_length": self.generation_settings_response_length
if self.generation_settings_override_enabled
else None,
},
)
return str(prompt)
async def build_prompt(self, character, char_message: str = "", instruction:str = None):
async def build_prompt(
self, character, char_message: str = "", instruction: str = None
):
fn = self.build_prompt_default
return await fn(character, char_message=char_message, instruction=instruction)
@@ -376,12 +381,12 @@ class ConversationAgent(
set_client_context_attribute("nuke_repetition", nuke_repetition)
@set_processing
@store_context_state('instruction')
@store_context_state("instruction")
async def converse(
self,
actor,
instruction:str = None,
emit_signals:bool = True,
instruction: str = None,
emit_signals: bool = True,
) -> list[CharacterMessage]:
"""
Have a conversation with the AI
@@ -398,7 +403,9 @@ class ConversationAgent(
self.set_generation_overrides()
result = await self.client.send_prompt(await self.build_prompt(character, instruction=instruction))
result = await self.client.send_prompt(
await self.build_prompt(character, instruction=instruction)
)
result = self.clean_result(result, character)
@@ -451,21 +458,31 @@ class ConversationAgent(
if total_result.startswith(":\n") or total_result.startswith(": "):
total_result = total_result[2:]
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
conversation_format = self.conversation_format
# chat format
# {character name}: {dialogue}
total_result = total_result.replace(f"{character.name}:", "")
if conversation_format == "narrative":
# For narrative format, the LLM generates pure prose without character name prefixes
# We need to store it internally in the standard {name}: {text} format
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Only add character name if it's not already there
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"
else:
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", "")
# Removes partial sentence at the end
total_result = util.clean_dialogue(total_result, main_name=character.name)
# chat format
# {character name}: {dialogue}
total_result = total_result.replace(f"{character.name}:", "")
# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name+":"):
total_result = f"{character.name}: {total_result}"
# Removes partial sentence at the end
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"
total_result = total_result.strip()
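Both branches converge on the internal `{name}: {text}` storage format. A simplified, self-contained sketch of the normalization above (a plain `strip` stands in for `util.clean_dialogue`, which does more):

```python
def normalize(result: str, name: str, fmt: str) -> str:
    if fmt == "narrative":
        result = result.strip()                            # prose comes back bare
    else:
        result = result.replace(f"{name.upper()}\n", "")   # screenplay heading
        result = result.replace(f"{name}:", "")            # chat prefix
        result = result.strip()
    if not result.startswith(f"{name}:"):
        result = f"{name}: {result}"
    return result

normalize("ELARA\nHello there.", "Elara", "movie_script")
# -> "Elara: Hello there."
normalize('Elara glanced up. "Hello."', "Elara", "narrative")
# -> 'Elara: Elara glanced up. "Hello."'
```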
@@ -481,11 +498,11 @@ class ConversationAgent(
log.debug("conversation agent", response=response)
emission = ConversationAgentEmission(
agent=self,
actor=actor,
character=character,
response=response,
)
if emit_signals:
await talemate.emit.async_signals.get("agent.conversation.generated").send(
emission
@@ -497,9 +514,6 @@ class ConversationAgent(
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return agent_function_name == "converse"
def inject_prompt_paramters(

View File

@@ -1,74 +1,75 @@
import structlog
from typing import TYPE_CHECKING, ClassVar
from talemate.game.engine.nodes.core import Node, GraphState, UNRESOLVED
from talemate.game.engine.nodes.core import GraphState, UNRESOLVED
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentNode, AgentSettingsNode
from talemate.context import active_scene
from talemate.client.context import ConversationContext, ClientContext
import talemate.events as events
if TYPE_CHECKING:
from talemate.tale_mate import Scene, Character
log = structlog.get_logger("talemate.game.engine.nodes.agents.conversation")
@register("agents/conversation/Settings")
class ConversationSettings(AgentSettingsNode):
"""
Base node to render conversation agent settings.
"""
_agent_name:ClassVar[str] = "conversation"
_agent_name: ClassVar[str] = "conversation"
def __init__(self, title="Conversation Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/conversation/Generate")
class GenerateConversation(AgentNode):
"""
Generate a conversation between two characters
"""
_agent_name:ClassVar[str] = "conversation"
_agent_name: ClassVar[str] = "conversation"
def __init__(self, title="Generate Conversation", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instruction", socket_type="str", optional=True)
self.set_property("trigger_conversation_generated", True)
self.add_output("generated", socket_type="str")
self.add_output("message", socket_type="message_object")
async def run(self, state: GraphState):
character:"Character" = self.get_input_value("character")
scene:"Scene" = active_scene.get()
character: "Character" = self.get_input_value("character")
scene: "Scene" = active_scene.get()
instruction = self.get_input_value("instruction")
trigger_conversation_generated = self.get_property("trigger_conversation_generated")
trigger_conversation_generated = self.get_property(
"trigger_conversation_generated"
)
other_characters = [c.name for c in scene.characters if c != character]
conversation_context = ConversationContext(
talking_character=character.name,
other_characters=other_characters,
)
if instruction == UNRESOLVED:
instruction = None
with ClientContext(conversation=conversation_context):
messages = await self.agent.converse(
character.actor,
instruction=instruction,
emit_signals=trigger_conversation_generated,
)
message = messages[0]
self.set_output_values({
"generated": message.message,
"message": message
})
self.set_output_values({"generated": message.message, "message": message})

View File

@@ -18,23 +18,25 @@ __all__ = [
log = structlog.get_logger("talemate.server.conversation")
class RequestActorActionPayload(pydantic.BaseModel):
character:str = ""
instructions:str = ""
emit_signals:bool = True
instructions_through_director:bool = True
character: str = ""
instructions: str = ""
emit_signals: bool = True
instructions_through_director: bool = True
class ConversationWebsocketHandler(Plugin):
"""
Handles conversation actions
"""
router = "conversation"
@property
def agent(self) -> "ConversationAgent":
return get_agent("conversation")
@set_loading("Generating actor action", cancellable=True, as_async=True)
async def handle_request_actor_action(self, data: dict):
"""
@@ -43,38 +45,37 @@ class ConversationWebsocketHandler(Plugin):
payload = RequestActorActionPayload(**data)
character = None
actor = None
if payload.character:
character = self.scene.get_character(payload.character)
actor = character.actor
else:
actor = random.choice(list(self.scene.get_npc_characters())).actor
if not actor:
log.error("handle_request_actor_action: No actor found")
return
character = actor.character
if payload.instructions_through_director:
director_message = DirectorMessage(
payload.instructions,
source="player",
meta={"character": character.name}
meta={"character": character.name},
)
emit("director", message=director_message, character=character)
self.scene.push_history(director_message)
generated_messages = await self.agent.converse(
actor,
emit_signals=payload.emit_signals
actor, emit_signals=payload.emit_signals
)
else:
generated_messages = await self.agent.converse(
actor,
instruction=payload.instructions,
emit_signals=payload.emit_signals
emit_signals=payload.emit_signals,
)
for message in generated_messages:
self.scene.push_history(message)
emit("character", message=message, character=character)
emit("character", message=message, character=character)

View File

@@ -1,13 +1,9 @@
from __future__ import annotations
import json
import os
import talemate.client as client
from talemate.agents.base import Agent, set_processing
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.emit import emit
from talemate.prompts import Prompt
from .assistant import AssistantMixin
@@ -16,7 +12,8 @@ from .scenario import ScenarioCreatorMixin
from talemate.agents.base import AgentAction
import talemate.agents.creator.nodes
import talemate.agents.creator.nodes # noqa: F401
@register()
class CreatorAgent(
@@ -42,7 +39,7 @@ class CreatorAgent(
def __init__(
self,
client: client.ClientBase,
client: client.ClientBase | None = None,
**kwargs,
):
self.client = client
@@ -51,7 +48,7 @@ class CreatorAgent(
@set_processing
async def generate_title(self, text: str):
title = await Prompt.request(
f"creator.generate-title",
"creator.generate-title",
self.client,
"create_short",
vars={

View File

@@ -40,6 +40,7 @@ async_signals.register(
"agent.creator.autocomplete.after",
)
@dataclasses.dataclass
class ContextualGenerateEmission(AgentTemplateEmission):
"""
@@ -48,15 +49,16 @@ class ContextualGenerateEmission(AgentTemplateEmission):
content_generation_context: "ContentGenerationContext | None" = None
character: "Character | None" = None
@property
def context_type(self) -> str:
return self.content_generation_context.computed_context[0]
@property
def context_name(self) -> str:
return self.content_generation_context.computed_context[1]
@dataclasses.dataclass
class AutocompleteEmission(AgentTemplateEmission):
"""
@@ -67,6 +69,7 @@ class AutocompleteEmission(AgentTemplateEmission):
type: str = ""
character: "Character | None" = None
class ContentGenerationContext(pydantic.BaseModel):
"""
A context for generating content.
@@ -104,7 +107,6 @@ class ContentGenerationContext(pydantic.BaseModel):
@property
def spice(self) -> str:
spice_level = self.generation_options.spice_level
if self.template and not getattr(self.template, "supports_spice", False):
@@ -148,7 +150,6 @@ class ContentGenerationContext(pydantic.BaseModel):
@property
def style(self):
if self.template and not getattr(self.template, "supports_style", False):
# template supplied that doesn't support style
return ""
@@ -165,11 +166,12 @@ class ContentGenerationContext(pydantic.BaseModel):
def get_state(self, key: str) -> str | int | float | bool | None:
return self.state.get(key)
class AssistantMixin:
"""
Creator mixin that allows quick contextual generation of content.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["autocomplete"] = AgentAction(
@@ -198,15 +200,15 @@ class AssistantMixin:
max=256,
step=16,
),
}
},
)
# property helpers
@property
def autocomplete_dialogue_suggestion_length(self):
return self.actions["autocomplete"].config["dialogue_suggestion_length"].value
@property
def autocomplete_narrative_suggestion_length(self):
return self.actions["autocomplete"].config["narrative_suggestion_length"].value
@@ -255,7 +257,7 @@ class AssistantMixin:
history_aware=history_aware,
information=information,
)
for key, value in kwargs.items():
generation_context.set_state(key, value)
@@ -279,9 +281,13 @@ class AssistantMixin:
f"Contextual generate: {context_typ} - {context_name}",
generation_context=generation_context,
)
character = self.scene.get_character(generation_context.character) if generation_context.character else None
character = (
self.scene.get_character(generation_context.character)
if generation_context.character
else None
)
template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -295,27 +301,29 @@ class AssistantMixin:
"character": character,
"template": generation_context.template,
}
emission = ContextualGenerateEmission(
agent=self,
content_generation_context=generation_context,
character=character,
template_vars=template_vars,
)
await async_signals.get("agent.creator.contextual_generate.before").send(emission)
await async_signals.get("agent.creator.contextual_generate.before").send(
emission
)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
content = await Prompt.request(
f"creator.contextual-generate",
"creator.contextual-generate",
self.client,
kind,
vars=template_vars,
)
emission.response = content
if not generation_context.partial:
content = util.strip_partial_sentences(content)
@@ -329,22 +337,29 @@ class AssistantMixin:
if not content.startswith(generation_context.character + ":"):
content = generation_context.character + ": " + content
content = util.strip_partial_sentences(content)
character = self.scene.get_character(generation_context.character)
if not character:
log.warning("Character not found", character=generation_context.character)
return content
emission.response = await editor.cleanup_character_message(content, character)
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response
emission.response = content.strip().strip("*").strip()
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response
character = self.scene.get_character(generation_context.character)
if not character:
log.warning(
"Character not found", character=generation_context.character
)
return content
emission.response = await editor.cleanup_character_message(
content, character
)
await async_signals.get("agent.creator.contextual_generate.after").send(
emission
)
return emission.response
emission.response = content.strip().strip("*").strip()
await async_signals.get("agent.creator.contextual_generate.after").send(
emission
)
return emission.response
@set_processing
async def generate_character_attribute(
@@ -357,10 +372,10 @@ class AssistantMixin:
) -> str:
"""
Wrapper for contextual_generate that generates a character attribute.
"""
"""
if not generation_options:
generation_options = GenerationOptions()
return await self.contextual_generate_from_args(
context=f"character attribute:{attribute_name}",
character=character.name,
@@ -368,7 +383,7 @@ class AssistantMixin:
original=original,
**generation_options.model_dump(),
)
@set_processing
async def generate_character_detail(
self,
@@ -381,11 +396,11 @@ class AssistantMixin:
) -> str:
"""
Wrapper for contextual_generate that generates a character detail.
"""
"""
if not generation_options:
generation_options = GenerationOptions()
return await self.contextual_generate_from_args(
context=f"character detail:{detail_name}",
character=character.name,
@@ -394,7 +409,7 @@ class AssistantMixin:
length=length,
**generation_options.model_dump(),
)
@set_processing
async def generate_thematic_list(
self,
@@ -408,11 +423,11 @@ class AssistantMixin:
"""
if not generation_options:
generation_options = GenerationOptions()
i = 0
result = []
while i < iterations:
i += 1
_result = await self.contextual_generate_from_args(
@@ -420,14 +435,14 @@ class AssistantMixin:
instructions=instructions,
length=length,
original="\n".join(result) if result else None,
extend=i>1,
extend=i > 1,
**generation_options.model_dump(),
)
_result = json.loads(_result)
result = list(set(result + _result))
return result
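The accumulation pattern above in isolation: each iteration extends the previous results, and the `set()` round-trip removes duplicates at the cost of ordering (which the JSON-list use case evidently tolerates). A minimal sketch with hypothetical data:

```python
result: list[str] = []
for batch in (["sword", "shield"], ["shield", "lantern"]):
    result = list(set(result + batch))

sorted(result)  # -> ["lantern", "shield", "sword"]
```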
@set_processing
@@ -443,34 +458,34 @@ class AssistantMixin:
"""
if not response_length:
response_length = self.autocomplete_dialogue_suggestion_length
# continuing recent character message
non_anchor, anchor = util.split_anchor_text(input, 10)
self.scene.log.debug(
"autocomplete_anchor",
anchor=anchor,
non_anchor=non_anchor,
input=input
"autocomplete_anchor", anchor=anchor, non_anchor=non_anchor, input=input
)
continuing_message = False
message = None
try:
message = self.scene.history[-1]
if isinstance(message, CharacterMessage) and message.character_name == character.name:
if (
isinstance(message, CharacterMessage)
and message.character_name == character.name
):
continuing_message = input.strip() == message.without_name.strip()
except IndexError:
pass
if input.strip().endswith('"'):
prefix = ' *'
elif input.strip().endswith('*'):
prefix = " *"
elif input.strip().endswith("*"):
prefix = ' "'
else:
prefix = ''
prefix = ""
template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -484,7 +499,7 @@ class AssistantMixin:
"non_anchor": non_anchor,
"prefix": prefix,
}
emission = AutocompleteEmission(
agent=self,
input=input,
@@ -492,13 +507,13 @@ class AssistantMixin:
character=character,
template_vars=template_vars,
)
await async_signals.get("agent.creator.autocomplete.before").send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"creator.autocomplete-dialogue",
"creator.autocomplete-dialogue",
self.client,
f"create_{response_length}",
vars=template_vars,
@@ -506,21 +521,22 @@ class AssistantMixin:
dedupe_enabled=False,
)
response = response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
response = (
response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
)
if prefix:
response = prefix + response
emission.response = response
await async_signals.get("agent.creator.autocomplete.after").send(emission)
if not response:
if emit_signal:
emit("autocomplete_suggestion", "")
return ""
response = util.strip_partial_sentences(response).replace("*", "")
if response.startswith(input):
@@ -550,14 +566,14 @@ class AssistantMixin:
# Split the input text into non-anchor and anchor parts
non_anchor, anchor = util.split_anchor_text(input, 10)
self.scene.log.debug(
"autocomplete_narrative_anchor",
"autocomplete_narrative_anchor",
anchor=anchor,
non_anchor=non_anchor,
input=input
input=input,
)
template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -567,20 +583,20 @@ class AssistantMixin:
"anchor": anchor,
"non_anchor": non_anchor,
}
emission = AutocompleteEmission(
agent=self,
input=input,
type="narrative",
template_vars=template_vars,
)
await async_signals.get("agent.creator.autocomplete.before").send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"creator.autocomplete-narrative",
"creator.autocomplete-narrative",
self.client,
f"create_{response_length}",
vars=template_vars,
@@ -593,7 +609,7 @@ class AssistantMixin:
response = response[len(input) :]
emission.response = response
await async_signals.get("agent.creator.autocomplete.after").send(emission)
self.scene.log.debug(
@@ -614,78 +630,75 @@ class AssistantMixin:
"""
Allows forking a new scene from a specific message
in the current scene.
All content after the message will be removed and the
context database will be re-imported, ensuring a clean state.
All state reinforcements will be reset to their most recent
state before the message.
"""
emit("status", "Creating scene fork ...", status="busy")
try:
if not save_name:
# build a save name
uuid_str = str(uuid.uuid4())[:8]
save_name = f"{uuid_str}-forked"
log.info(f"Forking scene", message_id=message_id, save_name=save_name)
log.info("Forking scene", message_id=message_id, save_name=save_name)
world_state = get_agent("world_state")
# does a message with the given id exist?
index = self.scene.message_index(message_id)
if index is None:
raise ValueError(f"Message with id {message_id} not found.")
# truncate scene.history keeping index as the last element
self.scene.history = self.scene.history[:index + 1]
self.scene.history = self.scene.history[: index + 1]
# truncate scene.archived_history keeping the element where `end` is < `index`
# as the last element
self.scene.archived_history = [
x for x in self.scene.archived_history if "end" not in x or x["end"] < index
x
for x in self.scene.archived_history
if "end" not in x or x["end"] < index
]
# the same needs to be done for layered history
# where each layer is truncated based on what's left in the previous layer
# using similar logic as above (checking `end` vs `index`)
# layer 0 checks archived_history
new_layered_history = []
for layer_number, layer in enumerate(self.scene.layered_history):
if layer_number == 0:
index = len(self.scene.archived_history) - 1
else:
index = len(new_layered_history[layer_number - 1]) - 1
new_layer = [
x for x in layer if x["end"] < index
]
new_layer = [x for x in layer if x["end"] < index]
new_layered_history.append(new_layer)
self.scene.layered_history = new_layered_history
# save the scene
await self.scene.save(copy_name=save_name)
log.info(f"Scene forked", save_name=save_name)
log.info("Scene forked", save_name=save_name)
# re-emit history
await self.scene.emit_history()
emit("status", f"Updating world state ...", status="busy")
emit("status", "Updating world state ...", status="busy")
# reset state reinforcements
await world_state.update_reinforcements(force = True, reset= True)
await world_state.update_reinforcements(force=True, reset=True)
# update world state
await self.scene.world_state.request_update()
emit("status", f"Scene forked", status="success")
except Exception as e:
emit("status", "Scene forked", status="success")
except Exception:
log.error("Scene fork failed", exc=traceback.format_exc())
emit("status", "Scene fork failed", status="error")

View File

@@ -7,8 +7,6 @@ import structlog
from talemate.agents.base import set_processing
from talemate.prompts import Prompt
import talemate.game.focal as focal
if TYPE_CHECKING:
from talemate.tale_mate import Character
@@ -18,14 +16,13 @@ DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audien
class CharacterCreatorMixin:
@set_processing
async def determine_content_context_for_character(
self,
character: Character,
):
content_context = await Prompt.request(
f"creator.determine-content-context",
"creator.determine-content-context",
self.client,
"create_192",
vars={
@@ -42,7 +39,7 @@ class CharacterCreatorMixin:
information: str = "",
):
instructions = await Prompt.request(
f"creator.determine-character-dialogue-instructions",
"creator.determine-character-dialogue-instructions",
self.client,
"create_concise",
vars={
@@ -63,7 +60,7 @@ class CharacterCreatorMixin:
character: Character,
):
attributes = await Prompt.request(
f"creator.determine-character-attributes",
"creator.determine-character-attributes",
self.client,
"analyze_long",
vars={
@@ -81,7 +78,7 @@ class CharacterCreatorMixin:
instructions: str = "",
) -> str:
name = await Prompt.request(
f"creator.determine-character-name",
"creator.determine-character-name",
self.client,
"analyze_freeform_short",
vars={
@@ -97,14 +94,14 @@ class CharacterCreatorMixin:
@set_processing
async def determine_character_description(
self,
character: Character,
text: str = "",
instructions: str = "",
information: str = "",
):
description = await Prompt.request(
f"creator.determine-character-description",
"creator.determine-character-description",
self.client,
"create",
vars={
@@ -125,7 +122,7 @@ class CharacterCreatorMixin:
goal_instructions: str,
):
goals = await Prompt.request(
f"creator.determine-character-goals",
"creator.determine-character-goals",
self.client,
"create",
vars={
@@ -141,4 +138,4 @@ class CharacterCreatorMixin:
log.debug("determine_character_goals", goals=goals, character=character)
await character.set_detail("goals", goals.strip())
return goals.strip()

View File

@@ -21,7 +21,7 @@
"num": 3
},
"x": 39,
"y": 507,
"y": 236,
"width": 210,
"height": 154,
"collapsed": false,
@@ -40,7 +40,7 @@
"num": 1
},
"x": 38,
"y": 107,
"y": -164,
"width": 210,
"height": 154,
"collapsed": false,
@@ -59,7 +59,7 @@
"num": 0
},
"x": 38,
"y": -93,
"y": -364,
"width": 210,
"height": 154,
"collapsed": false,
@@ -76,7 +76,7 @@
"num": 0
},
"x": 288,
"y": -63,
"y": -334,
"width": 210,
"height": 106,
"collapsed": true,
@@ -89,7 +89,7 @@
"id": "553125be-2c2b-4404-98b5-d6333a4f9655",
"properties": {},
"x": 348,
"y": 507,
"y": 236,
"width": 140,
"height": 26,
"collapsed": false,
@@ -102,7 +102,7 @@
"id": "207e357e-5e83-4d40-a331-d0041b9dfa49",
"properties": {},
"x": 348,
"y": 307,
"y": 36,
"width": 140,
"height": 26,
"collapsed": false,
@@ -115,7 +115,7 @@
"id": "bad7d2ed-f6fa-452b-8b47-d753ae7a45e0",
"properties": {},
"x": 348,
"y": 107,
"y": -164,
"width": 140,
"height": 26,
"collapsed": false,
@@ -131,7 +131,7 @@
"scope": "local"
},
"x": 598,
"y": 107,
"y": -164,
"width": 210,
"height": 122,
"collapsed": false,
@@ -147,7 +147,7 @@
"scope": "local"
},
"x": 598,
"y": 307,
"y": 36,
"width": 210,
"height": 122,
"collapsed": false,
@@ -163,7 +163,7 @@
"scope": "local"
},
"x": 598,
"y": 507,
"y": 236,
"width": 210,
"height": 122,
"collapsed": false,
@@ -176,7 +176,7 @@
"id": "1765a51c-82ac-4d96-9c60-d0adc7faaa68",
"properties": {},
"x": 498,
"y": 707,
"y": 436,
"width": 171,
"height": 26,
"collapsed": false,
@@ -192,7 +192,7 @@
"scope": "local"
},
"x": 798,
"y": 706,
"y": 435,
"width": 244,
"height": 122,
"collapsed": false,
@@ -205,7 +205,7 @@
"id": "f3dc37df-1748-4235-b575-60e55b8bec73",
"properties": {},
"x": 498,
"y": 906,
"y": 635,
"width": 171,
"height": 26,
"collapsed": false,
@@ -220,7 +220,7 @@
"stage": 0
},
"x": 1208,
"y": 366,
"y": 95,
"width": 210,
"height": 118,
"collapsed": true,
@@ -238,7 +238,7 @@
"icon": "F1719"
},
"x": 1178,
"y": -144,
"y": -415,
"width": 210,
"height": 130,
"collapsed": false,
@@ -253,7 +253,7 @@
"default": true
},
"x": 298,
"y": 736,
"y": 465,
"width": 210,
"height": 58,
"collapsed": true,
@@ -268,7 +268,7 @@
"default": false
},
"x": 298,
"y": 936,
"y": 665,
"width": 210,
"height": 58,
"collapsed": true,
@@ -287,7 +287,7 @@
"num": 2
},
"x": 38,
"y": 306,
"y": 35,
"width": 210,
"height": 154,
"collapsed": false,
@@ -303,7 +303,7 @@
"scope": "local"
},
"x": 798,
"y": 906,
"y": 635,
"width": 244,
"height": 122,
"collapsed": false,
@@ -322,7 +322,7 @@
"num": 5
},
"x": 38,
"y": 706,
"y": 435,
"width": 237,
"height": 154,
"collapsed": false,
@@ -341,7 +341,7 @@
"num": 7
},
"x": 38,
"y": 906,
"y": 635,
"width": 210,
"height": 154,
"collapsed": false,
@@ -360,7 +360,7 @@
"num": 4
},
"x": 48,
"y": 1326,
"y": 1055,
"width": 237,
"height": 154,
"collapsed": false,
@@ -375,7 +375,7 @@
"default": true
},
"x": 318,
"y": 1356,
"y": 1085,
"width": 210,
"height": 58,
"collapsed": true,
@@ -388,7 +388,7 @@
"id": "6d387c67-6b32-4435-b984-4760f0f1f8d2",
"properties": {},
"x": 498,
"y": 1336,
"y": 1065,
"width": 171,
"height": 26,
"collapsed": false,
@@ -404,7 +404,7 @@
"scope": "local"
},
"x": 798,
"y": 1316,
"y": 1045,
"width": 244,
"height": 122,
"collapsed": false,
@@ -455,7 +455,7 @@
"num": 6
},
"x": 38,
"y": 1106,
"y": 835,
"width": 252,
"height": 154,
"collapsed": false,
@@ -471,7 +471,7 @@
"apply_on_unresolved": true
},
"x": 518,
"y": 1136,
"y": 865,
"width": 210,
"height": 102,
"collapsed": true,
@@ -1101,7 +1101,7 @@
"num": 8
},
"x": 49,
"y": 1546,
"y": 1275,
"width": 210,
"height": 154,
"collapsed": false,
@@ -1116,7 +1116,7 @@
"default": false
},
"x": 329,
"y": 1586,
"y": 1315,
"width": 210,
"height": 58,
"collapsed": true,
@@ -1129,7 +1129,7 @@
"id": "8acfe789-fbb5-4e29-8fd8-2217b987c086",
"properties": {},
"x": 499,
"y": 1566,
"y": 1295,
"width": 171,
"height": 26,
"collapsed": false,
@@ -1145,7 +1145,7 @@
"scope": "local"
},
"x": 799,
"y": 1526,
"y": 1255,
"width": 244,
"height": 122,
"collapsed": false,
@@ -1258,8 +1258,8 @@
"output_name": "character",
"num": 0
},
"x": 331,
"y": 5739,
"x": 332,
"y": 6178,
"width": 210,
"height": 106,
"collapsed": false,
@@ -1274,8 +1274,8 @@
"name": "character",
"scope": "local"
},
"x": 51,
"y": 5729,
"x": 52,
"y": 6168,
"width": 210,
"height": 122,
"collapsed": false,
@@ -1291,8 +1291,8 @@
"output_name": "actor",
"num": 0
},
"x": 331,
"y": 5949,
"x": 332,
"y": 6388,
"width": 210,
"height": 106,
"collapsed": false,
@@ -1307,8 +1307,8 @@
"name": "actor",
"scope": "local"
},
"x": 51,
"y": 5949,
"x": 52,
"y": 6388,
"width": 210,
"height": 122,
"collapsed": false,
@@ -1323,7 +1323,7 @@
"stage": 0
},
"x": 1320,
"y": 1160,
"y": 889,
"width": 210,
"height": 118,
"collapsed": true,
@@ -1414,7 +1414,7 @@
"writing_style": null
},
"x": 320,
"y": 1200,
"y": 929,
"width": 270,
"height": 122,
"collapsed": true,
@@ -1430,7 +1430,7 @@
"scope": "local"
},
"x": 790,
"y": 1110,
"y": 839,
"width": 244,
"height": 122,
"collapsed": false,
@@ -1475,6 +1475,159 @@
"inherited": false,
"registry": "agents/creator/ContextualGenerate",
"base_type": "core/Node"
},
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1": {
"title": "assign_voice",
"id": "d61de1ad-6f2a-447f-918a-dce7e76ea3a1",
"properties": {},
"x": 509,
"y": 1544,
"width": 171,
"height": 26,
"collapsed": false,
"inherited": false,
"registry": "core/Watch",
"base_type": "core/Node"
},
"3d655827-b66b-4355-910d-96097e7f2f13": {
"title": "SET local.assign_voice",
"id": "3d655827-b66b-4355-910d-96097e7f2f13",
"properties": {
"name": "assign_voice",
"scope": "local"
},
"x": 810,
"y": 1506,
"width": 244,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/SetState",
"base_type": "core/Node"
},
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1": {
"title": "Stage 0",
"id": "6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1",
"properties": {
"stage": 0
},
"x": 1170,
"y": 1546,
"width": 210,
"height": 118,
"collapsed": true,
"inherited": false,
"registry": "core/Stage",
"base_type": "core/Node"
},
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f": {
"title": "IN assign_voice",
"id": "9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f",
"properties": {
"input_type": "bool",
"input_name": "assign_voice",
"input_optional": true,
"input_group": "",
"num": 9
},
"x": 60,
"y": 1527,
"width": 210,
"height": 154,
"collapsed": false,
"inherited": false,
"registry": "core/Input",
"base_type": "core/Node"
},
"de11206d-13db-44a5-befd-de559fb68d09": {
"title": "GET local.assign_voice",
"id": "de11206d-13db-44a5-befd-de559fb68d09",
"properties": {
"name": "assign_voice",
"scope": "local"
},
"x": 25,
"y": 5720,
"width": 240,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"97b196a3-e7a6-4cfa-905e-686a744890b7": {
"title": "Switch",
"id": "97b196a3-e7a6-4cfa-905e-686a744890b7",
"properties": {
"pass_through": true
},
"x": 355,
"y": 5740,
"width": 210,
"height": 78,
"collapsed": false,
"inherited": false,
"registry": "core/Switch",
"base_type": "core/Node"
},
"3956711a-a3df-4213-b739-104cbe704964": {
"title": "GET local.character",
"id": "3956711a-a3df-4213-b739-104cbe704964",
"properties": {
"name": "character",
"scope": "local"
},
"x": 25,
"y": 5930,
"width": 210,
"height": 122,
"collapsed": false,
"inherited": false,
"registry": "state/GetState",
"base_type": "core/Node"
},
"47b2b492-4178-4254-b468-3877a5341f66": {
"title": "Assign Voice",
"id": "47b2b492-4178-4254-b468-3877a5341f66",
"properties": {},
"x": 665,
"y": 5830,
"width": 161,
"height": 66,
"collapsed": false,
"inherited": false,
"registry": "agents/director/AssignVoice",
"base_type": "core/Node"
},
"44d78795-2d41-468b-94d5-399b9b655888": {
"title": "Stage 6",
"id": "44d78795-2d41-468b-94d5-399b9b655888",
"properties": {
"stage": 6
},
"x": 885,
"y": 5860,
"width": 210,
"height": 118,
"collapsed": true,
"inherited": false,
"registry": "core/Stage",
"base_type": "core/Node"
},
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85": {
"title": "As Bool",
"id": "f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85",
"properties": {
"default": true
},
"x": 330,
"y": 1560,
"width": 210,
"height": 58,
"collapsed": true,
"inherited": false,
"registry": "core/AsBool",
"base_type": "core/Node"
}
},
"edges": {
@@ -1726,15 +1879,39 @@
],
"4acb67ea-68ee-43ae-a8c6-98a2b0e0f053.text": [
"d41f0d98-14d5-49dd-8e57-7812fb9fee94.value"
],
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value": [
"3d655827-b66b-4355-910d-96097e7f2f13.value"
],
"3d655827-b66b-4355-910d-96097e7f2f13.value": [
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1.state"
],
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f.value": [
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value"
],
"de11206d-13db-44a5-befd-de559fb68d09.value": [
"97b196a3-e7a6-4cfa-905e-686a744890b7.value"
],
"97b196a3-e7a6-4cfa-905e-686a744890b7.yes": [
"47b2b492-4178-4254-b468-3877a5341f66.state"
],
"3956711a-a3df-4213-b739-104cbe704964.value": [
"47b2b492-4178-4254-b468-3877a5341f66.character"
],
"47b2b492-4178-4254-b468-3877a5341f66.state": [
"44d78795-2d41-468b-94d5-399b9b655888.state"
],
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value": [
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value"
]
},
"groups": [
{
"title": "Process Arguments - Stage 0",
"x": 1,
"y": -218,
"width": 1446,
"height": 1948,
"y": -490,
"width": 1432,
"height": 2216,
"color": "#3f789e",
"font_size": 24,
"inherited": false
@@ -1792,12 +1969,22 @@
{
"title": "Outputs",
"x": 0,
"y": 5642,
"y": 6080,
"width": 595,
"height": 472,
"color": "#8A8",
"font_size": 24,
"inherited": false
},
{
"title": "Assign Voice - Stage 6",
"x": 0,
"y": 5640,
"width": 1120,
"height": 437,
"color": "#3f789e",
"font_size": 24,
"inherited": false
}
],
"comments": [],
@@ -1805,6 +1992,7 @@
"base_type": "core/Graph",
"inputs": [],
"outputs": [],
"module_properties": {},
"style": {
"title_color": "#572e44",
"node_color": "#392c34",

View File

@@ -1,145 +1,153 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from typing import ClassVar
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
UNRESOLVED,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
if TYPE_CHECKING:
from talemate.tale_mate import Scene
log = structlog.get_logger("talemate.game.engine.nodes.agents.creator")
@register("agents/creator/Settings")
class CreatorSettings(AgentSettingsNode):
"""
Base node to render creator agent settings.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
def __init__(self, title="Creator Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/creator/DetermineContentContext")
class DetermineContentContext(AgentNode):
"""
Determines the context for the content creation.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
class Fields:
description = PropertyField(
name="description",
description="Description of the context",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
def __init__(self, title="Determine Content Context", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("description", socket_type="str", optional=True)
self.set_property("description", UNRESOLVED)
self.add_output("content_context", socket_type="str")
async def run(self, state: GraphState):
context = await self.agent.determine_content_context_for_description(
self.require_input("description")
)
self.set_output_values({
"content_context": context
})
self.set_output_values({"content_context": context})
@register("agents/creator/DetermineCharacterDescription")
class DetermineCharacterDescription(AgentNode):
"""
Determines the description for a character.
Inputs:
- state: The current state of the graph
- character: The character to determine the description for
- extra_context: Extra context to use in determining the description
Outputs:
- description: The determined description
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
def __init__(self, title="Determine Character Description", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("extra_context", socket_type="str", optional=True)
self.add_output("description", socket_type="str")
async def run(self, state: GraphState):
character = self.require_input("character")
extra_context = self.get_input_value("extra_context")
if extra_context is UNRESOLVED:
extra_context = ""
description = await self.agent.determine_character_description(character, extra_context)
self.set_output_values({
"description": description
})
description = await self.agent.determine_character_description(
character, extra_context
)
self.set_output_values({"description": description})
@register("agents/creator/DetermineCharacterDialogueInstructions")
class DetermineCharacterDialogueInstructions(AgentNode):
"""
Determines the dialogue instructions for a character.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
class Fields:
instructions = PropertyField(
name="instructions",
description="Any additional instructions to use in determining the dialogue instructions",
type="text",
default=""
default="",
)
def __init__(self, title="Determine Character Dialogue Instructions", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instructions", socket_type="str", optional=True)
self.set_property("instructions", "")
self.add_output("dialogue_instructions", socket_type="str")
async def run(self, state: GraphState):
character = self.require_input("character")
instructions = self.normalized_input_value("instructions")
dialogue_instructions = await self.agent.determine_character_dialogue_instructions(character, instructions)
self.set_output_values({
"dialogue_instructions": dialogue_instructions
})
dialogue_instructions = (
await self.agent.determine_character_dialogue_instructions(
character, instructions
)
)
self.set_output_values({"dialogue_instructions": dialogue_instructions})
@register("agents/creator/ContextualGenerate")
class ContextualGenerate(AgentNode):
"""
Generates text based on the given context.
Inputs:
- state: The current state of the graph
- context_type: The type of context to use in generating the text
- context_name: The name of the context to use in generating the text
@@ -150,9 +158,9 @@ class ContextualGenerate(AgentNode):
- partial: The partial text to use in generating the text
- uid: The uid to use in generating the text
- generation_options: The generation options to use in generating the text
Properties:
- context_type: The type of context to use in generating the text
- context_name: The name of the context to use in generating the text
- instructions: The instructions to use in generating the text
@@ -161,82 +169,82 @@ class ContextualGenerate(AgentNode):
- uid: The uid to use in generating the text
- context_aware: Whether to use the context in generating the text
- history_aware: Whether to use the history in generating the text
Outputs:
- state: The updated state of the graph
- text: The generated text
"""
_agent_name: ClassVar[str] = "creator"
class Fields:
context_type = PropertyField(
name="context_type",
description="The type of context to use in generating the text",
type="str",
choices=[
"character attribute",
"character detail",
"character dialogue",
"scene intro",
"scene intent",
"scene phase intent",
"character attribute",
"character detail",
"character dialogue",
"scene intro",
"scene intent",
"scene phase intent",
"scene type description",
"scene type instructions",
"general",
"list",
"scene",
"scene type instructions",
"general",
"list",
"scene",
"world context",
],
default="general"
default="general",
)
context_name = PropertyField(
name="context_name",
description="The name of the context to use in generating the text",
type="str",
default=UNRESOLVED,
)
instructions = PropertyField(
name="instructions",
description="The instructions to use in generating the text",
type="str",
default=UNRESOLVED,
)
length = PropertyField(
name="length",
description="The length of the text to generate",
type="int",
default=100,
)
character = PropertyField(
name="character",
description="The character to generate the text for",
type="str",
default=UNRESOLVED,
)
uid = PropertyField(
name="uid",
description="The uid to use in generating the text",
type="str",
default=UNRESOLVED,
)
context_aware = PropertyField(
name="context_aware",
description="Whether to use the context in generating the text",
type="bool",
default=True,
)
history_aware = PropertyField(
name="history_aware",
description="Whether to use the history in generating the text",
type="bool",
default=True,
)
def __init__(self, title="Contextual Generate", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("context_type", socket_type="str", optional=True)
@@ -247,8 +255,10 @@ class ContextualGenerate(AgentNode):
self.add_input("original", socket_type="str", optional=True)
self.add_input("partial", socket_type="str", optional=True)
self.add_input("uid", socket_type="str", optional=True)
self.add_input("generation_options", socket_type="generation_options", optional=True)
self.add_input(
"generation_options", socket_type="generation_options", optional=True
)
self.set_property("context_type", "general")
self.set_property("context_name", UNRESOLVED)
self.set_property("instructions", UNRESOLVED)
@@ -257,10 +267,10 @@ class ContextualGenerate(AgentNode):
self.set_property("uid", UNRESOLVED)
self.set_property("context_aware", True)
self.set_property("history_aware", True)
self.add_output("state")
self.add_output("text", socket_type="str")
async def run(self, state: GraphState):
scene = active_scene.get()
context_type = self.require_input("context_type")
@@ -274,12 +284,12 @@ class ContextualGenerate(AgentNode):
generation_options = self.normalized_input_value("generation_options")
context_aware = self.normalized_input_value("context_aware")
history_aware = self.normalized_input_value("history_aware")
context = f"{context_type}:{context_name}" if context_name else context_type
if isinstance(character, scene.Character):
character = character.name
text = await self.agent.contextual_generate_from_args(
context=context,
instructions=instructions,
@@ -288,31 +298,32 @@ class ContextualGenerate(AgentNode):
original=original,
partial=partial or "",
uid=uid,
writing_style=generation_options.writing_style
if generation_options
else None,
spices=generation_options.spices if generation_options else None,
spice_level=generation_options.spice_level if generation_options else 0.0,
context_aware=context_aware,
history_aware=history_aware,
)
self.set_output_values({"state": state, "text": text})
@register("agents/creator/GenerateThematicList")
class GenerateThematicList(AgentNode):
"""
Generates a list of thematic items based on the instructions.
"""
_agent_name: ClassVar[str] = "creator"
class Fields:
instructions = PropertyField(
name="instructions",
description="The instructions to use in generating the list",
type="str",
default=UNRESOLVED,
)
iterations = PropertyField(
name="iterations",
@@ -323,27 +334,24 @@ class GenerateThematicList(AgentNode):
min=1,
max=10,
)
def __init__(self, title="Generate Thematic List", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("instructions", socket_type="str", optional=True)
self.set_property("instructions", UNRESOLVED)
self.set_property("iterations", 1)
self.add_output("state")
self.add_output("list", socket_type="list")
async def run(self, state: GraphState):
instructions = self.normalized_input_value("instructions")
iterations = self.require_number_input("iterations")
list = await self.agent.generate_thematic_list(instructions, iterations)
self.set_output_values({"state": state, "list": list})

View File

@@ -10,7 +10,7 @@ class ScenarioCreatorMixin:
@set_processing
async def determine_scenario_description(self, text: str):
description = await Prompt.request(
f"creator.determine-scenario-description",
"creator.determine-scenario-description",
self.client,
"analyze_long",
vars={
@@ -25,7 +25,7 @@ class ScenarioCreatorMixin:
description: str,
):
content_context = await Prompt.request(
f"creator.determine-content-context",
"creator.determine-content-context",
self.client,
"create_short",
vars={

View File

@@ -1,36 +1,23 @@
from __future__ import annotations
import random
from typing import TYPE_CHECKING, List
import structlog
import traceback
import talemate.emit.async_signals
import talemate.instance as instance
from talemate.agents.conversation import ConversationAgentEmission
from talemate.emit import emit
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, Flags
from talemate.agents.base import Agent, AgentAction, AgentActionConfig
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.client import ClientBase
from talemate.game.focal.schema import Call
from .guide import GuideSceneMixin
from .generate_choices import GenerateChoicesMixin
from .legacy_scene_instructions import LegacySceneInstructionsMixin
from .auto_direct import AutoDirectMixin
from .character_management import CharacterManagementMixin
from .websocket_handler import DirectorWebsocketHandler
import talemate.agents.director.nodes  # noqa: F401
if TYPE_CHECKING:
from talemate import Character, Scene
log = structlog.get_logger("talemate.agent.director")
@@ -42,7 +29,8 @@ class DirectorAgent(
GenerateChoicesMixin,
AutoDirectMixin,
LegacySceneInstructionsMixin,
CharacterManagementMixin,
Agent,
):
agent_type = "director"
verbose_name = "Director"
@@ -74,15 +62,16 @@ class DirectorAgent(
],
),
},
),
}
MemoryRAGMixin.add_actions(actions)
GenerateChoicesMixin.add_actions(actions)
GuideSceneMixin.add_actions(actions)
AutoDirectMixin.add_actions(actions)
CharacterManagementMixin.add_actions(actions)
return actions
def __init__(self, client: ClientBase | None = None, **kwargs):
self.is_enabled = True
self.client = client
self.next_direct_character = {}
@@ -105,165 +94,24 @@ class DirectorAgent(
def actor_direction_mode(self):
return self.actions["direct"].config["actor_direction_mode"].value
async def log_function_call(self, call: Call):
log.debug("director.log_function_call", call=call)
message = DirectorMessage(
message=f"Called {call.name}",
action=call.name,
flags=Flags.HIDDEN,
subtype="function_call",
)
emit("director", message, data={"function_call": call.model_dump()})
async def log_action(
self, action: str, action_description: str, console_only: bool = False
):
message = DirectorMessage(
message=action_description,
action=action,
flags=Flags.HIDDEN if console_only else Flags.NONE,
)
self.scene.push_history(message)
emit("director", message)
@@ -276,4 +124,4 @@ class DirectorAgent(
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
return False

View File

@@ -1,23 +1,22 @@
from typing import TYPE_CHECKING
import structlog
import pydantic
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig,
AgentEmission,
AgentTemplateEmission,
)
from talemate.status import set_loading
import talemate.game.focal as focal
from talemate.prompts import Prompt
import talemate.emit.async_signals as async_signals
from talemate.scene_message import (
CharacterMessage,
TimePassageMessage,
NarratorMessage,
)
from talemate.scene.schema import ScenePhase, SceneType
from talemate.scene.intent import set_scene_phase
from talemate.world_state.manager import WorldStateManager
from talemate.world_state.templates.scene import SceneType as TemplateSceneType
import talemate.agents.director.auto_direct_nodes  # noqa: F401
from talemate.world_state.templates.base import TypedCollection
if TYPE_CHECKING:
@@ -26,15 +25,15 @@ if TYPE_CHECKING:
log = structlog.get_logger("talemate.agents.conversation.direct")
# talemate.emit.async_signals.register(
# )
class AutoDirectMixin:
"""
Director agent mixin for automatic scene direction.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["auto_direct"] = AgentAction(
@@ -108,113 +107,117 @@ class AutoDirectMixin:
),
},
)
# config property helpers
@property
def auto_direct_enabled(self) -> bool:
return self.actions["auto_direct"].enabled
@property
def auto_direct_max_auto_turns(self) -> int:
return self.actions["auto_direct"].config["max_auto_turns"].value
@property
def auto_direct_max_idle_turns(self) -> int:
return self.actions["auto_direct"].config["max_idle_turns"].value
@property
def auto_direct_max_repeat_turns(self) -> int:
return self.actions["auto_direct"].config["max_repeat_turns"].value
@property
def auto_direct_instruct_actors(self) -> bool:
return self.actions["auto_direct"].config["instruct_actors"].value
@property
def auto_direct_instruct_narrator(self) -> bool:
return self.actions["auto_direct"].config["instruct_narrator"].value
@property
def auto_direct_instruct_frequency(self) -> int:
return self.actions["auto_direct"].config["instruct_frequency"].value
@property
def auto_direct_evaluate_scene_intention(self) -> int:
return self.actions["auto_direct"].config["evaluate_scene_intention"].value
@property
def auto_direct_instruct_any(self) -> bool:
"""
Will check whether actor or narrator instructions are enabled.
For narrator instructions to be enabled instruct_narrator needs to be enabled as well.
Returns:
bool: True if either actor or narrator instructions are enabled.
"""
return self.auto_direct_instruct_actors or self.auto_direct_instruct_narrator
# signal connect
def connect(self, scene):
super().connect(scene)
async_signals.get("game_loop").connect(self.on_game_loop)
async def on_game_loop(self, event):
if not self.auto_direct_enabled:
return
if self.auto_direct_evaluate_scene_intention > 0:
evaluation_due = self.get_scene_state("evaluated_scene_intention", 0)
if evaluation_due == 0:
await self.auto_direct_set_scene_intent()
self.set_scene_states(
evaluated_scene_intention=self.auto_direct_evaluate_scene_intention
)
else:
self.set_scene_states(evaluated_scene_intention=evaluation_due - 1)
# helpers
def auto_direct_is_due_for_instruction(self, actor_name: str) -> bool:
"""
Check if the actor is due for instruction.
"""
if self.auto_direct_instruct_frequency == 1:
return True
messages_since_last_instruction = 0
def count_messages(message):
nonlocal messages_since_last_instruction
if message.typ in ["character", "narrator"]:
messages_since_last_instruction += 1
last_instruction = self.scene.last_message_of_type(
"director",
character_name=actor_name,
max_iterations=25,
on_iterate=count_messages,
)
log.debug("auto_direct_is_due_for_instruction", messages_since_last_instruction=messages_since_last_instruction, last_instruction=last_instruction.id if last_instruction else None)
log.debug(
"auto_direct_is_due_for_instruction",
messages_since_last_instruction=messages_since_last_instruction,
last_instruction=last_instruction.id if last_instruction else None,
)
if not last_instruction:
return True
return messages_since_last_instruction >= self.auto_direct_instruct_frequency
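Read concretely: with an `instruct_frequency` of 3, an actor becomes due only after at least three character or narrator messages have landed since their last director instruction, and an actor with no prior instruction is always due. A standalone restatement of that decision (a sketch, not Talemate code):

def is_due(messages_since_last: int, frequency: int, has_last_instruction: bool) -> bool:
    # mirrors auto_direct_is_due_for_instruction above
    if frequency == 1:
        return True
    if not has_last_instruction:
        return True
    return messages_since_last >= frequency

assert is_due(2, 3, True) is False
assert is_due(3, 3, True) is True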
def auto_direct_candidates(self) -> list["Character"]:
"""
Returns a list of characters who are valid candidates to speak next,
based on max_idle_turns, max_repeat_turns, and the most recent character.
"""
scene:"Scene" = self.scene
scene: "Scene" = self.scene
most_recent_character = None
repeat_count = 0
last_player_turn = None
@@ -223,89 +226,105 @@ class AutoDirectMixin:
active_charcters = list(scene.characters)
active_character_names = [character.name for character in active_charcters]
instruct_narrator = self.auto_direct_instruct_narrator
# if there is only one character then they are the only candidate
if len(active_charcters) == 1:
return active_charcters
BACKLOG_LIMIT = 50
player_character_active = scene.player_character_exists
# check the last BACKLOG_LIMIT entries in the scene history and collect into
# a dictionary of character names and the number of turns since they last spoke.
len_history = len(scene.history) - 1
num = 0
for idx in range(len_history, -1, -1):
message = scene.history[idx]
turns = len_history - idx
num += 1
if num > BACKLOG_LIMIT:
break
if isinstance(message, TimePassageMessage):
break
if not isinstance(message, (CharacterMessage, NarratorMessage)):
continue
# if character message but character is not in the active characters list then skip
if (
isinstance(message, CharacterMessage)
and message.character_name not in active_character_names
):
continue
if isinstance(message, NarratorMessage):
if not instruct_narrator:
continue
character = scene.narrator_character_object
else:
character = scene.get_character(message.character_name)
if not character:
continue
if character.is_player and last_player_turn is None:
last_player_turn = turns
elif not character.is_player and last_player_turn is None:
consecutive_auto_turns += 1
if not most_recent_character:
most_recent_character = character
repeat_count += 1
elif character == most_recent_character:
repeat_count += 1
if character.name not in candidates:
candidates[character.name] = turns
# add any characters that have not spoken yet
for character in active_charcters:
if character.name not in candidates:
candidates[character.name] = 0
# explicitly add narrator if enabled and not already in candidates
if instruct_narrator and scene.narrator_character_object:
narrator = scene.narrator_character_object
if narrator.name not in candidates:
candidates[narrator.name] = 0
log.debug(f"auto_direct_candidates: {candidates}", most_recent_character=most_recent_character, repeat_count=repeat_count, last_player_turn=last_player_turn, consecutive_auto_turns=consecutive_auto_turns)
log.debug(
f"auto_direct_candidates: {candidates}",
most_recent_character=most_recent_character,
repeat_count=repeat_count,
last_player_turn=last_player_turn,
consecutive_auto_turns=consecutive_auto_turns,
)
if not most_recent_character:
log.debug("auto_direct_candidates: No most recent character found.")
return list(scene.characters)
# if player has not spoken in a while then they are favored
if (
player_character_active
and consecutive_auto_turns >= self.auto_direct_max_auto_turns
):
log.debug(
"auto_direct_candidates: User controlled character has not spoken in a while."
)
return [scene.get_player_character()]
# check if most recent character has spoken too many times in a row
# if so then remove from candidates
if repeat_count >= self.auto_direct_max_repeat_turns:
log.debug("auto_direct_candidates: Most recent character has spoken too many times in a row.", most_recent_character=most_recent_character
log.debug(
"auto_direct_candidates: Most recent character has spoken too many times in a row.",
most_recent_character=most_recent_character,
)
candidates.pop(most_recent_character.name, None)
@@ -314,27 +333,34 @@ class AutoDirectMixin:
favored_candidates = []
for name, turns in candidates.items():
if turns > self.auto_direct_max_idle_turns:
log.debug("auto_direct_candidates: Character has gone too long without speaking.", character_name=name, turns=turns)
log.debug(
"auto_direct_candidates: Character has gone too long without speaking.",
character_name=name,
turns=turns,
)
favored_candidates.append(scene.get_character(name))
if favored_candidates:
return favored_candidates
return [
scene.get_character(character_name) for character_name in candidates.keys()
]
# actions
@set_processing
async def auto_direct_set_scene_intent(
self, require: bool = False
) -> ScenePhase | None:
async def set_scene_intention(type: str, intention: str) -> ScenePhase:
await set_scene_phase(self.scene, type, intention)
self.scene.emit_status()
return self.scene.intent_state.phase
async def do_nothing(*args, **kwargs) -> None:
return None
focal_handler = focal.Focal(
self.client,
callbacks=[
@@ -355,57 +381,63 @@ class AutoDirectMixin:
],
max_calls=1,
scene=self.scene,
scene_type_ids=", ".join([f'"{scene_type.id}"' for scene_type in self.scene.intent_state.scene_types.values()]),
scene_type_ids=", ".join(
[
f'"{scene_type.id}"'
for scene_type in self.scene.intent_state.scene_types.values()
]
),
retries=1,
require=require,
)
await focal_handler.request(
"director.direct-determine-scene-intent",
)
return self.scene.intent_state.phase
@set_processing
async def auto_direct_generate_scene_types(
self,
instructions: str,
max_scene_types: int = 1,
):
world_state_manager: WorldStateManager = self.scene.world_state_manager
scene_type_templates: TypedCollection = await world_state_manager.get_templates(
types=["scene_type"]
)
async def add_from_template(id: str) -> SceneType:
template: TemplateSceneType | None = scene_type_templates.find_by_name(id)
if not template:
log.warning("auto_direct_generate_scene_types: Template not found.", name=id)
log.warning(
"auto_direct_generate_scene_types: Template not found.", name=id
)
return None
return template.apply_to_scene(self.scene)
async def generate_scene_type(
id: str = None,
name: str = None,
description: str = None,
instructions: str = None,
) -> SceneType:
if not id or not name:
return None
scene_type = SceneType(
id=id,
name=name,
description=description,
instructions=instructions,
)
self.scene.intent_state.scene_types[id] = scene_type
return scene_type
focal_handler = focal.Focal(
self.client,
callbacks=[
@@ -435,7 +467,7 @@ class AutoDirectMixin:
instructions=instructions,
scene_type_templates=scene_type_templates.templates,
)
await focal_handler.request(
"director.generate-scene-types",
)

View File

@@ -1,17 +1,13 @@
import structlog
from typing import ClassVar
from talemate.game.engine.nodes.core import GraphState, PropertyField
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentNode
from talemate.scene.schema import ScenePhase
from talemate.context import active_scene
log = structlog.get_logger("talemate.game.engine.nodes.agents.director")
@register("agents/director/auto-direct/Candidates")
class AutoDirectCandidates(AgentNode):
"""
@@ -19,52 +15,51 @@ class AutoDirectCandidates(AgentNode):
next action, based on the director's auto-direct settings and
the recent scene history.
"""
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Auto Direct Candidates", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_output("characters", socket_type="list")
async def run(self, state: GraphState):
candidates = self.agent.auto_direct_candidates()
self.set_output_values({"characters": candidates})
@register("agents/director/auto-direct/DetermineSceneIntent")
class DetermineSceneIntent(AgentNode):
"""
Determines the scene intent based on the current scene state.
"""
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Determine Scene Intent", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_output("state")
self.add_output("scene_phase", socket_type="scene_intent/scene_phase")
async def run(self, state: GraphState):
phase: ScenePhase = await self.agent.auto_direct_set_scene_intent()
self.set_output_values({"state": state, "scene_phase": phase})
@register("agents/director/auto-direct/GenerateSceneTypes")
class GenerateSceneTypes(AgentNode):
"""
Generates scene types based on the current scene state.
"""
_agent_name: ClassVar[str] = "director"
class Fields:
instructions = PropertyField(
name="instructions",
@@ -72,17 +67,17 @@ class GenerateSceneTypes(AgentNode):
description="The instructions for the scene types",
default="",
)
max_scene_types = PropertyField(
name="max_scene_types",
type="int",
description="The maximum number of scene types to generate",
default=1,
)
def __init__(self, title="Generate Scene Types", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("instructions", socket_type="str", optional=True)
@@ -90,29 +85,26 @@ class GenerateSceneTypes(AgentNode):
self.set_property("instructions", "")
self.set_property("max_scene_types", 1)
self.add_output("state")
async def run(self, state: GraphState):
instructions = self.normalized_input_value("instructions")
max_scene_types = self.normalized_input_value("max_scene_types")
scene_types = await self.agent.auto_direct_generate_scene_types(
instructions=instructions, max_scene_types=max_scene_types
)
self.set_output_values({"state": state, "scene_types": scene_types})
@register("agents/director/auto-direct/IsDueForInstruction")
class IsDueForInstruction(AgentNode):
"""
Checks if the actor is due for instruction based on the auto-direct settings.
"""
_agent_name: ClassVar[str] = "director"
class Fields:
actor_name = PropertyField(
name="actor_name",
@@ -120,24 +112,21 @@ class IsDueForInstruction(AgentNode):
description="The name of the actor to check instruction timing for",
default="",
)
def __init__(self, title="Is Due For Instruction", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("actor_name", socket_type="str")
self.set_property("actor_name", "")
self.add_output("is_due", socket_type="bool")
self.add_output("actor_name", socket_type="str")
async def run(self, state: GraphState):
actor_name = self.require_input("actor_name")
is_due = self.agent.auto_direct_is_due_for_instruction(actor_name)
self.set_output_values({"is_due": is_due, "actor_name": actor_name})

View File

@@ -0,0 +1,336 @@
from typing import TYPE_CHECKING
import traceback
import structlog
import talemate.instance as instance
import talemate.agents.tts.voice_library as voice_library
from talemate.agents.tts.schema import Voice
from talemate.util import random_color
from talemate.character import deactivate_character, set_voice
from talemate.status import LoadingStatus
from talemate.exceptions import GenerationCancelled
from talemate.agents.base import AgentAction, AgentActionConfig, set_processing
import talemate.game.focal as focal
__all__ = [
"CharacterManagementMixin",
]
log = structlog.get_logger()
if TYPE_CHECKING:
from talemate import Character, Scene
from talemate.agents.tts import TTSAgent
class VoiceCandidate(Voice):
used: bool = False
class CharacterManagementMixin:
"""
Director agent mixin that provides character management functionality,
such as persisting characters into the scene and assigning text-to-speech voices.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["character_management"] = AgentAction(
enabled=True,
container=True,
can_be_disabled=False,
label="Character Management",
icon="mdi-account",
description="Configure how the director manages characters.",
config={
"assign_voice": AgentActionConfig(
type="bool",
label="Assign Voice (TTS)",
description="If enabled, the director is allowed to assign a text-to-speech voice when persisting a character.",
value=True,
title="Persisting Characters",
),
},
)
# config property helpers
@property
def cm_assign_voice(self) -> bool:
return self.actions["character_management"].config["assign_voice"].value
@property
def cm_should_assign_voice(self) -> bool:
if not self.cm_assign_voice:
return False
tts_agent: "TTSAgent" = instance.get_agent("tts")
if not tts_agent.enabled:
return False
if not tts_agent.ready_apis:
return False
return True
# actions
@set_processing
async def persist_characters_from_worldstate(
self, exclude: list[str] = None
) -> list["Character"]:
created_characters = []
for character_name in self.scene.world_state.characters.keys():
if exclude and character_name.lower() in exclude:
continue
if character_name in self.scene.character_names:
continue
character = await self.persist_character(name=character_name)
created_characters.append(character)
self.scene.emit_status()
return created_characters
@set_processing
async def persist_character(
self,
name: str,
content: str = None,
attributes: str = None,
determine_name: bool = True,
templates: list[str] = None,
active: bool = True,
narrate_entry: bool = False,
narrate_entry_direction: str = "",
augment_attributes: str = "",
generate_attributes: bool = True,
description: str = "",
assign_voice: bool = True,
is_player: bool = False,
) -> "Character":
world_state = instance.get_agent("world_state")
creator = instance.get_agent("creator")
narrator = instance.get_agent("narrator")
memory = instance.get_agent("memory")
scene: "Scene" = self.scene
any_attribute_templates = False
loading_status = LoadingStatus(max_steps=None, cancellable=True)
# Start of character creation
log.debug("persist_character", name=name)
# Determine the character's name (or clarify if it's already set)
if determine_name:
loading_status("Determining character name")
name = await creator.determine_character_name(name, instructions=content)
log.debug("persist_character", adjusted_name=name)
if name in self.scene.all_character_names:
raise ValueError(f'Name "{name}" already exists.')
# Create the blank character
character: "Character" = self.scene.Character(name=name, is_player=is_player)
# Add the character to the scene
character.color = random_color()
if is_player:
actor = self.scene.Player(
character=character, agent=instance.get_agent("conversation")
)
else:
actor = self.scene.Actor(
character=character, agent=instance.get_agent("conversation")
)
await self.scene.add_actor(actor)
try:
# Apply any character generation templates
if templates:
loading_status("Applying character generation templates")
templates = scene.world_state_manager.template_collection.collect_all(
templates
)
log.debug("persist_character", applying_templates=templates)
await scene.world_state_manager.apply_templates(
templates.values(),
character_name=character.name,
information=content,
)
# if any of the templates are attribute templates, then we no longer need to
# generate a character sheet
any_attribute_templates = any(
template.template_type == "character_attribute"
for template in templates.values()
)
log.debug(
"persist_character", any_attribute_templates=any_attribute_templates
)
if (
any_attribute_templates
and augment_attributes
and generate_attributes
):
log.debug(
"persist_character", augmenting_attributes=augment_attributes
)
loading_status("Augmenting character attributes")
additional_attributes = await world_state.extract_character_sheet(
name=name,
text=content,
augmentation_instructions=augment_attributes,
)
character.base_attributes.update(additional_attributes)
# Generate a character sheet if there are no attribute templates
if not any_attribute_templates and generate_attributes:
loading_status("Generating character sheet")
log.debug("persist_character", extracting_character_sheet=True)
if not attributes:
attributes = await world_state.extract_character_sheet(
name=name, text=content
)
else:
attributes = world_state._parse_character_sheet(attributes)
log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
# Generate a description for the character
if not description:
loading_status("Generating character description")
description = await creator.determine_character_description(
character, information=content
)
character.description = description
log.debug("persist_character", description=description)
# Generate dialogue instructions for the character
loading_status("Generating acting instructions")
dialogue_instructions = (
await creator.determine_character_dialogue_instructions(
character, information=content
)
)
character.dialogue_instructions = dialogue_instructions
log.debug("persist_character", dialogue_instructions=dialogue_instructions)
# Narrate the character's entry if the option is selected
if active and narrate_entry:
loading_status("Narrating character entry")
is_present = await world_state.is_character_present(name)
if not is_present:
await narrator.action_to_narration(
"narrate_character_entry",
emit_message=True,
character=character,
narrative_direction=narrate_entry_direction,
)
if assign_voice:
await self.assign_voice_to_character(character)
# Deactivate the character if not active
if not active:
await deactivate_character(scene, character)
# Commit the character's details to long term memory
await character.commit_to_memory(memory)
self.scene.emit_status()
self.scene.world_state.emit()
loading_status.done(
message=f"{character.name} added to scene", status="success"
)
return character
except GenerationCancelled:
loading_status.done(message="Character creation cancelled", status="idle")
await scene.remove_actor(actor)
except Exception:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.error("Error persisting character", error=traceback.format_exc())
@set_processing
async def assign_voice_to_character(
self, character: "Character"
) -> list[focal.Call]:
tts_agent: "TTSAgent" = instance.get_agent("tts")
if not self.cm_should_assign_voice:
log.debug("assign_voice_to_character", skip=True, reason="not enabled")
return
vl: voice_library.VoiceLibrary = voice_library.get_instance()
ready_tts_apis = tts_agent.ready_apis
voices_global = voice_library.voices_for_apis(ready_tts_apis, vl)
voices_scene = voice_library.voices_for_apis(
ready_tts_apis, self.scene.voice_library
)
voices = voices_global + voices_scene
if not voices:
log.debug(
"assign_voice_to_character", skip=True, reason="no voices available"
)
return
voice_candidates = {
voice.id: VoiceCandidate(**voice.model_dump()) for voice in voices
}
for scene_character in self.scene.all_characters:
if scene_character.voice:
voice_candidates[scene_character.voice.id].used = True
async def assign_voice(voice_id: str):
voice = vl.get_voice(voice_id) or self.scene.voice_library.get_voice(
voice_id
)
if not voice:
log.error(
"assign_voice_to_character",
skip=True,
reason="voice not found",
voice_id=voice_id,
)
return
await set_voice(character, voice, auto=True)
await self.log_action(
f"Assigned voice `{voice.label}` to `{character.name}`",
"Assigned voice",
console_only=True,
)
focal_handler = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="assign_voice",
arguments=[focal.Argument(name="voice_id", type="str")],
fn=assign_voice,
),
],
max_calls=1,
character=character,
voices=list(voice_candidates.values()),
scene=self.scene,
narrator_voice=tts_agent.narrator_voice,
)
await focal_handler.request("director.cm-assign-voice")
log.debug("assign_voice_to_character", calls=focal_handler.state.calls)
return focal_handler.state.calls
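`assign_voice_to_character` uses the same `focal.Focal` shape as `auto_direct_set_scene_intent` earlier in this diff: register named async callbacks, cap `max_calls`, pass template variables, then `request()` a prompt. A minimal sketch of that pattern under the API as it appears here; the callback, prompt id, and argument values are hypothetical:

import structlog
import talemate.game.focal as focal

log = structlog.get_logger()

async def example_focal_request(client, scene):
    async def choose_title(title: str):
        # act on the model's single allowed call (illustrative)
        log.debug("choose_title", title=title)

    focal_handler = focal.Focal(
        client,
        callbacks=[
            focal.Callback(
                name="choose_title",
                arguments=[focal.Argument(name="title", type="str")],
                fn=choose_title,
            ),
        ],
        max_calls=1,
        scene=scene,
    )
    # "director.example-prompt" is a hypothetical template id
    await focal_handler.request("director.example-prompt")
    return focal_handler.state.calls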

View File

@@ -1,14 +1,12 @@
from typing import TYPE_CHECKING
import random
import structlog
from functools import wraps
import dataclasses
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig,
AgentTemplateEmission,
DynamicInstruction,
)
from talemate.events import GameLoopStartEvent
from talemate.scene_message import NarratorMessage, CharacterMessage
@@ -24,7 +22,7 @@ __all__ = [
log = structlog.get_logger()
talemate.emit.async_signals.register(
"agent.director.generate_choices.before_generate",
"agent.director.generate_choices.before_generate",
"agent.director.generate_choices.inject_instructions",
"agent.director.generate_choices.generated",
)
@@ -32,18 +30,19 @@ talemate.emit.async_signals.register(
if TYPE_CHECKING:
from talemate.tale_mate import Character
@dataclasses.dataclass
class GenerateChoicesEmission(AgentTemplateEmission):
character: "Character | None" = None
choices: list[str] = dataclasses.field(default_factory=list)
class GenerateChoicesMixin:
"""
Director agent mixin that generates action choices for the player
character during their turn.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_generate_choices"] = AgentAction(
@@ -65,7 +64,6 @@ class GenerateChoicesMixin:
max=1,
step=0.1,
),
"num_choices": AgentActionConfig(
type="number",
label="Number of Actions",
@@ -75,33 +73,31 @@ class GenerateChoicesMixin:
max=10,
step=1,
),
"never_auto_progress": AgentActionConfig(
type="bool",
label="Never Auto Progress on Action Selection",
description="If enabled, the scene will not auto progress after you select an action.",
value=False,
),
"instructions": AgentActionConfig(
type="blob",
label="Instructions",
description="Provide some instructions to the director for generating actions.",
value="",
),
},
)
# config property helpers
@property
def generate_choices_enabled(self):
return self.actions["_generate_choices"].enabled
@property
def generate_choices_chance(self):
return self.actions["_generate_choices"].config["chance"].value
@property
def generate_choices_num_choices(self):
return self.actions["_generate_choices"].config["num_choices"].value
@@ -115,24 +111,25 @@ class GenerateChoicesMixin:
return self.actions["_generate_choices"].config["instructions"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("player_turn_start").connect(self.on_player_turn_start)
talemate.emit.async_signals.get("player_turn_start").connect(
self.on_player_turn_start
)
async def on_player_turn_start(self, event: GameLoopStartEvent):
if not self.enabled:
return
if self.generate_choices_enabled:
# look backwards through history and abort if we encounter
# a character message with source "player" before either
# a character message with a different source or a narrator message
#
# this is so choices aren't generated when the player message was
# the most recent content in the scene
for i in range(len(self.scene.history) - 1, -1, -1):
message = self.scene.history[i]
if isinstance(message, NarratorMessage):
@@ -141,12 +138,11 @@ class GenerateChoicesMixin:
if message.source == "player":
return
break
if random.random() < self.generate_choices_chance:
await self.generate_choices()
# methods
@set_processing
async def generate_choices(
@@ -154,20 +150,23 @@ class GenerateChoicesMixin:
instructions: str = None,
character: "Character | str | None" = None,
):
emission: GenerateChoicesEmission = GenerateChoicesEmission(agent=self)
if isinstance(character, str):
character = self.scene.get_character(character)
if not character:
character = self.scene.get_player_character()
emission.character = character
await talemate.emit.async_signals.get("agent.director.generate_choices.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.director.generate_choices.inject_instructions").send(emission)
await talemate.emit.async_signals.get(
"agent.director.generate_choices.before_generate"
).send(emission)
await talemate.emit.async_signals.get(
"agent.director.generate_choices.inject_instructions"
).send(emission)
response = await Prompt.request(
"director.generate-choices",
self.client,
@@ -178,7 +177,9 @@ class GenerateChoicesMixin:
"character": character,
"num_choices": self.generate_choices_num_choices,
"instructions": instructions or self.generate_choices_instructions,
"dynamic_instructions": emission.dynamic_instructions if emission else None,
"dynamic_instructions": emission.dynamic_instructions
if emission
else None,
},
)
@@ -187,10 +188,10 @@ class GenerateChoicesMixin:
choices = util.extract_list(choice_text)
# strip quotes
choices = [choice.strip().strip('"') for choice in choices]
# limit to num_choices
choices = choices[: self.generate_choices_num_choices]
except Exception as e:
log.error("generate_choices failed", error=str(e), response=response)
return
@@ -198,15 +199,17 @@ class GenerateChoicesMixin:
emit(
"player_choice",
response,
data={
"choices": choices,
"character": character.name,
},
websocket_passthrough=True,
)
emission.response = response
emission.choices = choices
await talemate.emit.async_signals.get("agent.director.generate_choices.generated").send(emission)
return emission.response
await talemate.emit.async_signals.get(
"agent.director.generate_choices.generated"
).send(emission)
return emission.response
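Because `generate_choices` now emits `before_generate`, `inject_instructions`, and `generated`, other components can observe or extend it without subclassing. A sketch of a listener on the `generated` signal, assuming the `GenerateChoicesEmission` fields defined above; the handler itself is hypothetical:

import structlog
import talemate.emit.async_signals

log = structlog.get_logger()

async def on_choices_generated(emission):
    # GenerateChoicesEmission carries `character` and `choices`
    # (see the dataclass earlier in this file)
    log.debug(
        "choices_generated",
        character=emission.character.name if emission.character else None,
        choices=emission.choices,
    )

talemate.emit.async_signals.get(
    "agent.director.generate_choices.generated"
).connect(on_choices_generated)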

View File

@@ -17,12 +17,11 @@ from talemate.util import strip_partial_sentences
if TYPE_CHECKING:
from talemate.tale_mate import Character
from talemate.agents.summarize.analyze_scene import SceneAnalysisEmission
from talemate.agents.editor.revision import RevisionAnalysisEmission
log = structlog.get_logger()
talemate.emit.async_signals.register(
"agent.director.guide.before_generate",
"agent.director.guide.before_generate",
"agent.director.guide.inject_instructions",
"agent.director.guide.generated",
)
@@ -32,6 +31,7 @@ talemate.emit.async_signals.register(
class DirectorGuidanceEmission(AgentTemplateEmission):
pass
def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
@@ -42,29 +42,33 @@ def set_processing(fn):
@wraps(fn)
async def wrapper(self, *args, **kwargs):
emission: DirectorGuidanceEmission = DirectorGuidanceEmission(agent=self)
await talemate.emit.async_signals.get("agent.director.guide.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.director.guide.inject_instructions").send(emission)
await talemate.emit.async_signals.get(
"agent.director.guide.before_generate"
).send(emission)
await talemate.emit.async_signals.get(
"agent.director.guide.inject_instructions"
).send(emission)
agent_context = active_agent.get()
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions
response = await fn(self, *args, **kwargs)
emission.response = response
await talemate.emit.async_signals.get("agent.director.guide.generated").send(emission)
await talemate.emit.async_signals.get("agent.director.guide.generated").send(
emission
)
return emission.response
return wrapper
class GuideSceneMixin:
"""
Director agent mixin that provides functionality for automatically guiding
the actors or the narrator during the scene progression.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["guide_scene"] = AgentAction(
@@ -81,13 +85,13 @@ class GuideSceneMixin:
type="bool",
label="Guide actors",
description="Guide the actors in the scene. This happens during every actor turn.",
value=True,
),
"guide_narrator": AgentActionConfig(
type="bool",
label="Guide narrator",
description="Guide the narrator during the scene. This happens during the narrator's turn.",
value=True,
),
"guidance_length": AgentActionConfig(
type="text",
@@ -101,7 +105,7 @@ class GuideSceneMixin:
{"label": "Medium (512)", "value": "512"},
{"label": "Medium Long (768)", "value": "768"},
{"label": "Long (1024)", "value": "1024"},
],
),
"cache_guidance": AgentActionConfig(
type="bool",
@@ -109,57 +113,57 @@ class GuideSceneMixin:
description="Will not regenerate the guidance until the scene moves forward or the analysis changes.",
value=False,
quick_toggle=True,
),
},
)
# config property helpers
@property
def guide_scene(self) -> bool:
return self.actions["guide_scene"].enabled
@property
def guide_actors(self) -> bool:
return self.actions["guide_scene"].config["guide_actors"].value
@property
def guide_narrator(self) -> bool:
return self.actions["guide_scene"].config["guide_narrator"].value
@property
def guide_scene_guidance_length(self) -> int:
return int(self.actions["guide_scene"].config["guidance_length"].value)
@property
def guide_scene_cache_guidance(self) -> bool:
return self.actions["guide_scene"].config["cache_guidance"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").connect(
self.on_summarization_scene_analysis_after
)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").connect(
self.on_summarization_scene_analysis_after
)
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
self.on_editor_revision_analysis_before
)
async def on_summarization_scene_analysis_after(self, emission: "SceneAnalysisEmission"):
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after"
).connect(self.on_summarization_scene_analysis_after)
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.cached"
).connect(self.on_summarization_scene_analysis_after)
talemate.emit.async_signals.get(
"agent.editor.revision-analysis.before"
).connect(self.on_editor_revision_analysis_before)
async def on_summarization_scene_analysis_after(
self, emission: "SceneAnalysisEmission"
):
if not self.guide_scene:
return
guidance = None
cached_guidance = await self.get_cached_guidance(emission.response)
if emission.analysis_type == "narration" and self.guide_narrator:
if cached_guidance:
guidance = cached_guidance
else:
@@ -167,15 +171,14 @@ class GuideSceneMixin:
emission.response,
response_length=self.guide_scene_guidance_length,
)
if not guidance:
log.warning("director.guide_scene.narration: Empty resonse")
return
self.set_context_states(narrator_guidance=guidance)
elif emission.analysis_type == "conversation" and self.guide_actors:
if cached_guidance:
guidance = cached_guidance
else:
@@ -184,94 +187,112 @@ class GuideSceneMixin:
emission.template_vars.get("character"),
response_length=self.guide_scene_guidance_length,
)
if not guidance:
log.warning("director.guide_scene.conversation: Empty resonse")
return
self.set_context_states(actor_guidance=guidance)
if guidance:
await self.set_cached_guidance(
emission.response,
guidance,
emission.analysis_type,
emission.template_vars.get("character"),
)
async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
cached_guidance = await self.get_cached_guidance(emission.response)
if cached_guidance:
emission.dynamic_instructions.append(
DynamicInstruction(title="Guidance", content=cached_guidance)
)
# helpers
def _cache_key(self) -> str:
return f"cached_guidance"
async def get_cached_guidance(self, analysis:str | None = None) -> str | None:
return "cached_guidance"
async def get_cached_guidance(self, analysis: str | None = None) -> str | None:
"""
Returns the cached guidance for the given analysis.
If analysis is not provided, it will return the cached guidance for the last analysis regardless
of the fingerprint.
"""
if not self.guide_scene_cache_guidance:
return None
key = self._cache_key()
cached_guidance = self.get_scene_state(key)
if cached_guidance:
if not analysis:
return cached_guidance.get("guidance")
elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]):
elif cached_guidance.get("fp") == self.context_fingerprint(
extra=[analysis]
):
return cached_guidance.get("guidance")
return None
async def set_cached_guidance(
self,
analysis: str,
guidance: str,
analysis_type: str,
character: "Character | None" = None,
):
"""
Sets the cached guidance for the given analysis.
"""
key = self._cache_key()
self.set_scene_states(
**{
key: {
"fp": self.context_fingerprint(extra=[analysis]),
"guidance": guidance,
"analysis_type": analysis_type,
"character": character.name if character else None,
}
}
)
async def get_cached_character_guidance(self, character_name: str) -> str | None:
"""
Returns the cached guidance for the given character.
"""
key = self._cache_key()
cached_guidance = self.get_scene_state(key)
if not cached_guidance:
return None
if cached_guidance.get("character") == character_name and cached_guidance.get("analysis_type") == "conversation":
if (
cached_guidance.get("character") == character_name
and cached_guidance.get("analysis_type") == "conversation"
):
return cached_guidance.get("guidance")
return None
# methods
@set_processing
async def guide_actor_off_of_scene_analysis(
self, analysis: str, character: "Character", response_length: int = 256
):
"""
Guides the actor based on the scene analysis.
"""
log.debug("director.guide_actor_off_of_scene_analysis", analysis=analysis, character=character)
log.debug(
"director.guide_actor_off_of_scene_analysis",
analysis=analysis,
character=character,
)
response = await Prompt.request(
"director.guide-conversation",
self.client,
@@ -285,19 +306,17 @@ class GuideSceneMixin:
},
)
return strip_partial_sentences(response).strip()
@set_processing
async def guide_narrator_off_of_scene_analysis(
self, analysis: str, response_length: int = 256
):
"""
Guides the narrator based on the scene analysis.
"""
log.debug("director.guide_narrator_off_of_scene_analysis", analysis=analysis)
response = await Prompt.request(
"director.guide-narration",
self.client,
@@ -309,4 +328,4 @@ class GuideSceneMixin:
"max_tokens": self.client.max_token_length,
},
)
return strip_partial_sentences(response).strip()

View File

@@ -1,4 +1,3 @@
import structlog
from talemate.agents.base import (
set_processing,
@@ -9,24 +8,27 @@ from talemate.events import GameLoopActorIterEvent, SceneStateEvent
log = structlog.get_logger("talemate.agents.conversation.legacy_scene_instructions")
class LegacySceneInstructionsMixin(
GameInstructionsMixin,
):
"""
Legacy support for scoped api instructions in scenes.
This is being replaced by node based instructions, but kept for backwards compatibility.
THIS WILL BE DEPRECATED IN THE FUTURE.
"""
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.LSI_on_player_dialog)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
self.LSI_on_player_dialog
)
talemate.emit.async_signals.get("scene_init").connect(self.LSI_on_scene_init)
async def LSI_on_scene_init(self, event: SceneStateEvent):
"""
LEGACY: If game state instructions specify to be run at the start of the game loop
@@ -57,8 +59,10 @@ class LegacySceneInstructionsMixin(
if not event.actor.character.is_player:
return
log.warning(f"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
log.warning(
"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
)
if event.game_loop.had_passive_narration:
log.debug(
@@ -77,17 +81,17 @@ class LegacySceneInstructionsMixin(
not self.scene.npc_character_names
or self.scene.game_state.ops.always_direct
)
log.warning(f"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.", always_direct=always_direct)
log.warning(
"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.",
always_direct=always_direct,
)
next_direct = self.next_direct_scene
TURNS = 5
if next_direct % TURNS != 0 or next_direct == 0:
if not always_direct:
log.info("direct", skip=True, next_direct=next_direct)
self.next_direct_scene += 1
@@ -112,8 +116,10 @@ class LegacySceneInstructionsMixin(
async def LSI_direct_scene(self):
"""
LEGACY: Direct the scene based on scoped api scene instructions.
This is being replaced by node based instructions, but kept for
backwards compatibility.
"""
log.warning(f"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
await self.run_scene_instructions(self.scene)
log.warning(
"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
)
await self.run_scene_instructions(self.scene)

View File

@@ -1,29 +1,34 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from typing import ClassVar
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
TYPE_CHOICES,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
from talemate.character import Character
if TYPE_CHECKING:
from talemate.tale_mate import Scene
TYPE_CHOICES.extend([
"director/direction",
])
TYPE_CHOICES.extend(
[
"director/direction",
]
)
log = structlog.get_logger("talemate.game.engine.nodes.agents.director")
@register("agents/director/Settings")
class DirectorSettings(AgentSettingsNode):
"""
Base node to render director agent settings.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Director Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/director/PersistCharacter")
class PersistCharacter(AgentNode):
@@ -31,45 +36,131 @@ class PersistCharacter(AgentNode):
Persists a character that currently only exists as part of the given context
as a real character that can actively participate in the scene.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
class Fields:
determine_name = PropertyField(
name="determine_name",
type="bool",
description="Whether to determine the name of the character",
default=True
default=True,
)
def __init__(self, title="Persist Character", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character_name", socket_type="str")
self.add_input("context", socket_type="str", optional=True)
self.add_input("attributes", socket_type="dict,str", optional=True)
self.set_property("determine_name", True)
self.add_output("state")
self.add_output("character", socket_type="character")
async def run(self, state: GraphState):
character_name = self.get_input_value("character_name")
context = self.normalized_input_value("context")
attributes = self.normalized_input_value("attributes")
determine_name = self.normalized_input_value("determine_name")
character = await self.agent.persist_character(
name=character_name,
content=context,
attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()]) if attributes else None,
determine_name=determine_name
attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()])
if attributes
else None,
determine_name=determine_name,
)
self.set_output_values({
"state": state,
"character": character
})
self.set_output_values({"state": state, "character": character})
@register("agents/director/AssignVoice")
class AssignVoice(AgentNode):
"""
Assigns a voice to a character.
"""
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Assign Voice", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_output("state")
self.add_output("character", socket_type="character")
self.add_output("voice", socket_type="tts/voice")
async def run(self, state: GraphState):
character: "Character" = self.require_input("character")
await self.agent.assign_voice_to_character(character)
voice = character.voice
self.set_output_values({"state": state, "character": character, "voice": voice})
@register("agents/director/LogAction")
class LogAction(AgentNode):
"""
Logs an action to the console.
"""
_agent_name: ClassVar[str] = "director"
class Fields:
action = PropertyField(
name="action",
type="str",
description="The action to log",
default="",
)
action_description = PropertyField(
name="action_description",
type="str",
description="The description of the action",
default="",
)
console_only = PropertyField(
name="console_only",
type="bool",
description="Whether to log the action to the console only",
default=False,
)
def __init__(self, title="Log Director Action", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("action", socket_type="str")
self.add_input("action_description", socket_type="str")
self.add_input("console_only", socket_type="bool", optional=True)
self.set_property("action", "")
self.set_property("action_description", "")
self.set_property("console_only", False)
self.add_output("state")
async def run(self, state: GraphState):
state = self.require_input("state")
action = self.require_input("action")
action_description = self.require_input("action_description")
console_only = self.normalized_input_value("console_only") or False
await self.agent.log_action(
action=action,
action_description=action_description,
console_only=console_only,
)
self.set_output_values({"state": state})

View File

@@ -1,13 +1,13 @@
import pydantic
import asyncio
import structlog
import traceback
from typing import TYPE_CHECKING
from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
from talemate.context import interaction
from talemate.context import interaction, handle_generation_cancelled
from talemate.status import set_loading
from talemate.exceptions import GenerationCancelled
if TYPE_CHECKING:
from talemate.tale_mate import Scene
@@ -18,42 +18,52 @@ __all__ = [
log = structlog.get_logger("talemate.server.director")
class InstructionPayload(pydantic.BaseModel):
instructions:str = ""
instructions: str = ""
class SelectChoicePayload(pydantic.BaseModel):
choice: str
character:str = ""
character: str = ""
class CharacterPayload(InstructionPayload):
character:str = ""
character: str = ""
class PersistCharacterPayload(pydantic.BaseModel):
name: str
templates: list[str] | None = None
narrate_entry: bool = True
narrate_entry_direction: str = ""
active: bool = True
determine_name: bool = True
augment_attributes: str = ""
generate_attributes: bool = True
content: str = ""
description: str = ""
is_player: bool = False
class AssignVoiceToCharacterPayload(pydantic.BaseModel):
character_name: str
class DirectorWebsocketHandler(Plugin):
"""
Handles director actions
"""
router = "director"
@property
def director(self):
return get_agent("director")
@set_loading("Generating dynamic actions", cancellable=True, as_async=True)
async def handle_request_dynamic_choices(self, data: dict):
"""
@@ -61,21 +71,21 @@ class DirectorWebsocketHandler(Plugin):
"""
payload = CharacterPayload(**data)
await self.director.generate_choices(**payload.model_dump())
async def handle_select_choice(self, data: dict):
payload = SelectChoicePayload(**data)
log.debug("selecting choice", payload=payload)
if payload.character:
character = self.scene.get_character(payload.character)
else:
character = self.scene.get_player_character()
if not character:
log.error("handle_select_choice: could not find character", payload=payload)
return
# hijack the interaction state
try:
interaction_state = interaction.get()
@@ -83,24 +93,33 @@ class DirectorWebsocketHandler(Plugin):
# no interaction state
log.error("handle_select_choice: no interaction state", payload=payload)
return
interaction_state.from_choice = payload.choice
interaction_state.act_as = character.name if not character.is_player else None
interaction_state.input = f"@{payload.choice}"
async def handle_persist_character(self, data: dict):
payload = PersistCharacterPayload(**data)
scene: "Scene" = self.scene
if not payload.content:
payload.content = scene.snapshot(lines=15)
# add as asyncio task
task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
task = asyncio.create_task(
self.director.persist_character(**payload.model_dump())
)
async def handle_task_done(task):
if task.exception():
log.error("Error persisting character", error=task.exception())
await self.signal_operation_failed("Error persisting character")
exc = task.exception()
log.error("Error persisting character", error=exc)
# Handle GenerationCancelled properly to reset cancel_requested flag
if isinstance(exc, GenerationCancelled):
handle_generation_cancelled(exc)
await self.signal_operation_failed(f"Error persisting character: {exc}")
else:
self.websocket_handler.queue_put(
{
@@ -113,3 +132,62 @@ class DirectorWebsocketHandler(Plugin):
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
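The handler above launches the agent call as a fire-and-forget task and inspects failures in a done callback; because the callback itself needs to await coroutines, it schedules another task. A self-contained sketch of the same pattern (names are illustrative, not the Talemate API):

    import asyncio

    async def long_running():
        raise RuntimeError("boom")

    async def on_done(task: asyncio.Task):
        exc = task.exception()  # safe: the task is already done when the callback fires
        if exc:
            print("task failed:", exc)

    async def main():
        task = asyncio.create_task(long_running())
        # done callbacks are synchronous, so spawn a task to run the async handler
        task.add_done_callback(lambda t: asyncio.create_task(on_done(t)))
        await asyncio.sleep(0.1)

    asyncio.run(main())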
async def handle_assign_voice_to_character(self, data: dict):
"""
Assign a voice to a character using the director agent
"""
try:
payload = AssignVoiceToCharacterPayload(**data)
except pydantic.ValidationError as e:
await self.signal_operation_failed(str(e))
return
scene: "Scene" = self.scene
if not scene:
await self.signal_operation_failed("No scene active")
return
character = scene.get_character(payload.character_name)
if not character:
await self.signal_operation_failed(
f"Character '{payload.character_name}' not found"
)
return
character.voice = None
# Add as asyncio task
task = asyncio.create_task(self.director.assign_voice_to_character(character))
async def handle_task_done(task):
if task.exception():
exc = task.exception()
log.error("Error assigning voice to character", error=exc)
# Handle GenerationCancelled properly to reset cancel_requested flag
if isinstance(exc, GenerationCancelled):
handle_generation_cancelled(exc)
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "assign_voice_to_character_failed",
"character_name": payload.character_name,
"error": str(exc),
}
)
await self.signal_operation_failed(
f"Error assigning voice to character: {exc}"
)
else:
self.websocket_handler.queue_put(
{
"type": self.router,
"action": "assign_voice_to_character_done",
"character_name": payload.character_name,
}
)
await self.signal_operation_done()
self.scene.emit_status()
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))

View File

@@ -6,6 +6,7 @@ import structlog
import talemate.emit.async_signals
import talemate.util as util
from talemate.client import ClientBase
from talemate.prompts import Prompt
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
@@ -40,7 +41,7 @@ class EditorAgent(
agent_type = "editor"
verbose_name = "Editor"
websocket_handler = EditorWebsocketHandler
@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
@@ -56,9 +57,9 @@ class EditorAgent(
description="The formatting to use for exposition.",
value="novel",
choices=[
{"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
{"label": "Novel: \"Speech\" narration", "value": "novel"},
]
{"label": 'Chat RP: "Speech" *narration*', "value": "chat"},
{"label": 'Novel: "Speech" narration', "value": "novel"},
],
),
"narrator": AgentActionConfig(
type="bool",
@@ -81,16 +82,16 @@ class EditorAgent(
description="Attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
),
}
MemoryRAGMixin.add_actions(actions)
RevisionMixin.add_actions(actions)
return actions
def __init__(self, client, **kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.client = client
self.is_enabled = True
self.actions = EditorAgent.init_actions()
@property
def enabled(self):
return self.is_enabled
@@ -102,11 +103,11 @@ class EditorAgent(
@property
def experimental(self):
return True
@property
def fix_exposition_enabled(self):
return self.actions["fix_exposition"].enabled
@property
def fix_exposition_formatting(self):
return self.actions["fix_exposition"].config["formatting"].value
@@ -114,11 +115,10 @@ class EditorAgent(
@property
def fix_exposition_narrator(self):
return self.actions["fix_exposition"].config["narrator"].value
@property
def fix_exposition_user_input(self):
return self.actions["fix_exposition"].config["user_input"].value
def connect(self, scene):
super().connect(scene)
@@ -134,7 +134,7 @@ class EditorAgent(
formatting = "md"
else:
formatting = None
if self.fix_exposition_formatting == "chat":
text = text.replace("**", "*")
text = text.replace("[", "*").replace("]", "*")
@@ -143,15 +143,14 @@ class EditorAgent(
text = text.replace("*", "")
text = text.replace("[", "").replace("]", "")
text = text.replace("(", "").replace(")", "")
cleaned = util.ensure_dialog_format(
text,
talking_character=character.name if character else None,
formatting=formatting
)
return cleaned
cleaned = util.ensure_dialog_format(
text,
talking_character=character.name if character else None,
formatting=formatting,
)
return cleaned
async def on_conversation_generated(self, emission: ConversationAgentEmission):
"""
@@ -181,11 +180,13 @@ class EditorAgent(
emission.response = edit
@set_processing
async def cleanup_character_message(self, content: str, character: Character, force: bool = False):
async def cleanup_character_message(
self, content: str, character: Character, force: bool = False
):
"""
Edits a text to make sure all narrative exposition and emotes are encased in *
"""
# if no content was generated, return it as is
if not content:
return content
@@ -210,14 +211,14 @@ class EditorAgent(
content = util.clean_dialogue(content, main_name=character.name)
content = util.strip_partial_sentences(content)
# if there are uneven quotation marks, fix them by adding a closing quote
if '"' in content and content.count('"') % 2 != 0:
content += '"'
if not self.fix_exposition_enabled and not exposition_fixed:
return content
content = self.fix_exposition_in_text(content, character)
return content
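The unbalanced-quote guard works by parity: an odd count of double quotes means a dangling opening quote, so a closing quote is appended. For example:

    content = 'She said, "wait'
    if '"' in content and content.count('"') % 2 != 0:
        content += '"'  # balance the dangling quotation mark
    print(content)  # She said, "wait"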
@@ -225,7 +226,7 @@ class EditorAgent(
@set_processing
async def clean_up_narration(self, content: str, force: bool = False):
content = util.strip_partial_sentences(content)
if (self.fix_exposition_enabled and self.fix_exposition_narrator or force):
if self.fix_exposition_enabled and self.fix_exposition_narrator or force:
content = self.fix_exposition_in_text(content, None)
if self.fix_exposition_formatting == "chat":
if '"' not in content and "*" not in content:
@@ -234,25 +235,27 @@ class EditorAgent(
return content
@set_processing
async def cleanup_user_input(self, text: str, as_narration: bool = False, force: bool = False):
async def cleanup_user_input(
self, text: str, as_narration: bool = False, force: bool = False
):
# special prefix characters - when found, never edit
PREFIX_CHARACTERS = ("!", "@", "/")
if text.startswith(PREFIX_CHARACTERS):
return text
if (not self.fix_exposition_user_input or not self.fix_exposition_enabled) and not force:
if (
not self.fix_exposition_user_input or not self.fix_exposition_enabled
) and not force:
return text
if not as_narration:
if self.fix_exposition_formatting == "chat":
if '"' not in text and "*" not in text:
text = f'"{text}"'
else:
return await self.clean_up_narration(text)
return self.fix_exposition_in_text(text)
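Note that str.startswith accepts a tuple, so a single call covers every special prefix (the inputs below are illustrative):

    PREFIX_CHARACTERS = ("!", "@", "/")
    for text in ("!advance", "@narrate", "/help", "hello"):
        print(text, text.startswith(PREFIX_CHARACTERS))
    # !advance True, @narrate True, /help True, hello False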
@set_processing
async def add_detail(self, content: str, character: Character):
@@ -279,4 +282,4 @@ class EditorAgent(
response = util.clean_dialogue(response, main_name=character.name)
response = util.strip_partial_sentences(response)
return response

View File

@@ -1,70 +1,37 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
if TYPE_CHECKING:
from talemate.tale_mate import Scene
from talemate.agents.editor import EditorAgent
log = structlog.get_logger("talemate.game.engine.nodes.agents.editor")
@register("agents/editor/Settings")
class EditorSettings(AgentSettingsNode):
"""
Base node to render editor agent settings.
"""
_agent_name:ClassVar[str] = "editor"
_agent_name: ClassVar[str] = "editor"
def __init__(self, title="Editor Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/editor/CleanUpUserInput")
class CleanUpUserInput(AgentNode):
"""
Cleans up user input.
"""
_agent_name:ClassVar[str] = "editor"
class Fields:
force = PropertyField(
name="force",
description="Force clean up",
type="bool",
default=False,
)
def __init__(self, title="Clean Up User Input", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("user_input", socket_type="str")
self.add_input("as_narration", socket_type="bool")
self.set_property("force", False)
self.add_output("cleaned_user_input", socket_type="str")
async def run(self, state: GraphState):
editor:"EditorAgent" = self.agent
user_input = self.get_input_value("user_input")
force = self.get_property("force")
as_narration = self.get_input_value("as_narration")
cleaned_user_input = await editor.cleanup_user_input(user_input, as_narration=as_narration, force=force)
self.set_output_values({
"cleaned_user_input": cleaned_user_input,
})
@register("agents/editor/CleanUpNarration")
class CleanUpNarration(AgentNode):
"""
Cleans up narration.
"""
_agent_name:ClassVar[str] = "editor"
_agent_name: ClassVar[str] = "editor"
class Fields:
force = PropertyField(
@@ -73,31 +40,41 @@ class CleanUpNarration(AgentNode):
type="bool",
default=False,
)
def __init__(self, title="Clean Up Narration", **kwargs):
def __init__(self, title="Clean Up User Input", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("narration", socket_type="str")
self.add_input("user_input", socket_type="str")
self.add_input("as_narration", socket_type="bool")
self.set_property("force", False)
self.add_output("cleaned_narration", socket_type="str")
self.add_output("cleaned_user_input", socket_type="str")
async def run(self, state: GraphState):
editor:"EditorAgent" = self.agent
narration = self.get_input_value("narration")
editor: "EditorAgent" = self.agent
user_input = self.get_input_value("user_input")
force = self.get_property("force")
cleaned_narration = await editor.cleanup_narration(narration, force=force)
self.set_output_values({
"cleaned_narration": cleaned_narration,
})
@register("agents/editor/CleanUoCharacterMessage")
class CleanUpCharacterMessage(AgentNode):
as_narration = self.get_input_value("as_narration")
cleaned_user_input = await editor.cleanup_user_input(
user_input, as_narration=as_narration, force=force
)
self.set_output_values(
{
"cleaned_user_input": cleaned_user_input,
}
)
@register("agents/editor/CleanUpNarration")
class CleanUpNarration(AgentNode):
"""
Cleans up character message.
Cleans up narration.
"""
_agent_name:ClassVar[str] = "editor"
_agent_name: ClassVar[str] = "editor"
class Fields:
force = PropertyField(
name="force",
@@ -105,22 +82,62 @@ class CleanUpCharacterMessage(AgentNode):
type="bool",
default=False,
)
def __init__(self, title="Clean Up Narration", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("narration", socket_type="str")
self.set_property("force", False)
self.add_output("cleaned_narration", socket_type="str")
async def run(self, state: GraphState):
editor: "EditorAgent" = self.agent
narration = self.get_input_value("narration")
force = self.get_property("force")
cleaned_narration = await editor.cleanup_narration(narration, force=force)
self.set_output_values(
{
"cleaned_narration": cleaned_narration,
}
)
@register("agents/editor/CleanUoCharacterMessage")
class CleanUpCharacterMessage(AgentNode):
"""
Cleans up character message.
"""
_agent_name: ClassVar[str] = "editor"
class Fields:
force = PropertyField(
name="force",
description="Force clean up",
type="bool",
default=False,
)
def __init__(self, title="Clean Up Character Message", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("text", socket_type="str")
self.add_input("character", socket_type="character")
self.set_property("force", False)
self.add_output("cleaned_character_message", socket_type="str")
async def run(self, state: GraphState):
editor:"EditorAgent" = self.agent
editor: "EditorAgent" = self.agent
text = self.get_input_value("text")
force = self.get_property("force")
character = self.get_input_value("character")
cleaned_character_message = await editor.cleanup_character_message(text, character, force=force)
self.set_output_values({
"cleaned_character_message": cleaned_character_message,
})
cleaned_character_message = await editor.cleanup_character_message(
text, character, force=force
)
self.set_output_values(
{
"cleaned_character_message": cleaned_character_message,
}
)

File diff suppressed because it is too large

View File

@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
from talemate.status import set_loading
from talemate.scene_message import CharacterMessage
from talemate.agents.editor.revision import RevisionContext, RevisionInformation
@@ -17,47 +16,52 @@ __all__ = [
log = structlog.get_logger("talemate.server.editor")
class RevisionPayload(pydantic.BaseModel):
message_id: int
class EditorWebsocketHandler(Plugin):
"""
Handles editor actions
"""
router = "editor"
@property
def editor(self):
return get_agent("editor")
async def handle_request_revision(self, data: dict):
"""
Request a revision of the given message
"""
editor = self.editor
scene:"Scene" = self.scene
scene: "Scene" = self.scene
if not editor.revision_enabled:
raise Exception("Revision is not enabled")
payload = RevisionPayload(**data)
message = scene.get_message(payload.message_id)
character = None
if isinstance(message, CharacterMessage):
character = scene.get_character(message.character_name)
if not message:
raise Exception("Message not found")
with RevisionContext(message.id):
info = RevisionInformation(
text=message.message,
character=character,
)
revised = await editor.revision_revise(info)
scene.edit_message(message.id, revised)
if isinstance(message, CharacterMessage):
if not revised.startswith(character.name + ":"):
revised = f"{character.name}: {revised}"
scene.edit_message(message.id, revised)

View File

@@ -23,10 +23,9 @@ from talemate.agents.base import (
AgentDetail,
set_processing,
)
from talemate.config import load_config
from talemate.config.schema import EmbeddingFunctionPreset
from talemate.context import scene_is_loading, active_scene
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.emit.async_signals as async_signals
from talemate.agents.memory.context import memory_request, MemoryRequest
from talemate.agents.memory.exceptions import (
@@ -75,10 +74,9 @@ class MemoryAgent(Agent):
@classmethod
def init_actions(cls, presets: list[dict] | None = None) -> dict[str, AgentAction]:
if presets is None:
presets = []
actions = {
"_config": AgentAction(
enabled=True,
@@ -101,23 +99,24 @@ class MemoryAgent(Agent):
choices=[
{"value": "cpu", "label": "CPU"},
{"value": "cuda", "label": "CUDA"},
]
],
),
},
),
}
return actions
def __init__(self, scene, **kwargs):
def __init__(self, **kwargs):
self.db = None
self.scene = scene
self.memory_tracker = {}
self.config = load_config()
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)
async_signals.get("config.changed").connect(self.on_config_changed)
async_signals.get("client.embeddings_available").connect(
self.on_client_embeddings_available
)
self.actions = MemoryAgent.init_actions(presets=self.get_presets)
@property
@@ -132,30 +131,33 @@ class MemoryAgent(Agent):
@property
def db_name(self):
raise NotImplementedError()
@property
def get_presets(self):
def _label(embedding:dict):
prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
if embedding['model']:
return f"{prefix}: {embedding['model']}"
def _label(embedding: EmbeddingFunctionPreset):
prefix = embedding.client if embedding.client else embedding.embeddings
if embedding.model:
return f"{prefix}: {embedding.model}"
else:
return f"{prefix}"
return [
{"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
{"value": k, "label": _label(v)}
for k, v in self.config.presets.embeddings.items()
]
@property
def embeddings_config(self):
_embeddings = self.actions["_config"].config["embeddings"].value
return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {})
return self.config.presets.embeddings.get(_embeddings)
@property
def embeddings(self):
return self.embeddings_config.get("embeddings", "sentence-transformer")
try:
return self.embeddings_config.embeddings
except AttributeError:
return None
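embeddings_config can be None when the selected preset no longer exists in the config, so these properties guard attribute access instead of assuming a preset object. The same guard in isolation:

    embeddings_config = None  # selected preset was removed from the config
    try:
        embeddings = embeddings_config.embeddings
    except AttributeError:
        embeddings = None  # callers then fall back (see fix_broken_embeddings)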
@property
def using_openai_embeddings(self):
return self.embeddings == "openai"
@@ -163,42 +165,46 @@ class MemoryAgent(Agent):
@property
def using_instructor_embeddings(self):
return self.embeddings == "instructor"
@property
def using_sentence_transformer_embeddings(self):
return self.embeddings == "default" or self.embeddings == "sentence-transformer"
@property
def using_client_api_embeddings(self):
return self.embeddings == "client-api"
@property
def using_local_embeddings(self):
return self.embeddings in [
"instructor",
"sentence-transformer",
"default"
]
return self.embeddings in ["instructor", "sentence-transformer", "default"]
@property
def embeddings_client(self):
return self.embeddings_config.get("client")
try:
return self.embeddings_config.client
except AttributeError:
return None
@property
def max_distance(self) -> float:
distance = float(self.embeddings_config.get("distance", 1.0))
distance_mod = float(self.embeddings_config.get("distance_mod", 1.0))
distance = float(self.embeddings_config.distance)
distance_mod = float(self.embeddings_config.distance_mod)
return distance * distance_mod
@property
def model(self):
return self.embeddings_config.get("model")
try:
return self.embeddings_config.model
except AttributeError:
return None
@property
def distance_function(self):
return self.embeddings_config.get("distance_function", "l2")
try:
return self.embeddings_config.distance_function
except AttributeError:
return None
@property
def device(self) -> str:
@@ -206,32 +212,38 @@ class MemoryAgent(Agent):
@property
def trust_remote_code(self) -> bool:
return self.embeddings_config.get("trust_remote_code", False)
try:
return self.embeddings_config.trust_remote_code
except AttributeError:
return False
@property
def fingerprint(self) -> str:
"""
Returns a unique fingerprint for the current configuration
"""
model_name = self.model.replace('/','-') if self.model else "none"
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
model_name = self.model.replace("/", "-") if self.model else "none"
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
async def apply_config(self, *args, **kwargs):
_fingerprint = self.fingerprint
await super().apply_config(*args, **kwargs)
fingerprint_changed = _fingerprint != self.fingerprint
# have embeddings or device changed?
if fingerprint_changed:
log.warning("memory agent", status="embedding function changed", old=_fingerprint, new=self.fingerprint)
log.warning(
"memory agent",
status="embedding function changed",
old=_fingerprint,
new=self.fingerprint,
)
await self.handle_embeddings_change()
@set_processing
async def handle_embeddings_change(self):
scene = active_scene.get()
@@ -239,10 +251,10 @@ class MemoryAgent(Agent):
# if sentence-transformer and no model-name, set embeddings to default
if self.using_sentence_transformer_embeddings and not self.model:
self.actions["_config"].config["embeddings"].value = "default"
if not scene or not scene.get_helper("memory"):
if not scene or not scene.active:
return
self.close_db(scene)
emit("status", "Re-importing context database", status="busy")
await scene.commit_to_memory()
@@ -254,45 +266,66 @@ class MemoryAgent(Agent):
self.actions["_config"].config["embeddings"].choices = self.get_presets
return self.actions["_config"].config["embeddings"].choices
def on_config_saved(self, event):
loop = asyncio.get_running_loop()
openai_key = self.openai_api_key
fingerprint = self.fingerprint
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
self.config = load_config()
new_presets = self.sync_presets()
if fingerprint != self.fingerprint:
log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
loop.run_until_complete(self.handle_embeddings_change())
emit_status = False
if openai_key != self.openai_api_key:
emit_status = True
if old_presets != new_presets:
emit_status = True
if emit_status:
loop.run_until_complete(self.emit_status())
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
current_embeddings = self.actions["_config"].config["embeddings"].value
if current_embeddings == event.client.embeddings_identifier:
return
if not self.using_client_api_embeddings or not self.ready:
log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
async def fix_broken_embeddings(self):
if not self.embeddings_config:
self.actions["_config"].config["embeddings"].value = "default"
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
async def on_config_changed(self, event):
loop = asyncio.get_running_loop()
openai_key = self.openai_api_key
fingerprint = self.fingerprint
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
new_presets = self.sync_presets()
if fingerprint != self.fingerprint:
log.warning(
"memory agent",
status="embedding function changed",
old=fingerprint,
new=self.fingerprint,
)
loop.run_until_complete(self.handle_embeddings_change())
emit_status = False
if openai_key != self.openai_api_key:
emit_status = True
if old_presets != new_presets:
emit_status = True
if emit_status:
loop.run_until_complete(self.emit_status())
await self.fix_broken_embeddings()
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
current_embeddings = self.actions["_config"].config["embeddings"].value
if current_embeddings == event.client.embeddings_identifier:
event.seen = True
return
if not self.using_client_api_embeddings or not self.ready:
log.warning(
"memory agent - client embeddings available",
status="changing embeddings",
old=current_embeddings,
new=event.client.embeddings_identifier,
)
self.actions["_config"].config[
"embeddings"
].value = event.client.embeddings_identifier
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
event.seen = True
@set_processing
async def set_db(self):
loop = asyncio.get_running_loop()
@@ -301,13 +334,16 @@ class MemoryAgent(Agent):
except EmbeddingsModelLoadError:
raise
except Exception as e:
log.error("memory agent", error="failed to set db", details=traceback.format_exc())
if "torchvision::nms does not exist" in str(e):
raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")
raise SetDBError(str(e))
log.error(
"memory agent", error="failed to set db", details=traceback.format_exc()
)
if "torchvision::nms does not exist" in str(e):
raise SetDBError(
"The embeddings you are trying to use require the `torchvision` package to be installed"
)
raise SetDBError(str(e))
def close_db(self):
raise NotImplementedError()
@@ -426,9 +462,9 @@ class MemoryAgent(Agent):
with MemoryRequest(query=text, query_params=query) as active_memory_request:
active_memory_request.max_distance = self.max_distance
return await asyncio.to_thread(self._get, text, character, **query)
#return await loop.run_in_executor(
# return await loop.run_in_executor(
# None, functools.partial(self._get, text, character, **query)
#)
# )
def _get(self, text, character=None, **query):
raise NotImplementedError()
@@ -528,9 +564,7 @@ class MemoryAgent(Agent):
continue
# Fetch potential memories for this query.
raw_results = await self.get(
formatter(query), limit=limit, **where
)
raw_results = await self.get(formatter(query), limit=limit, **where)
# Apply filter and respect the `iterate` limit for this query.
accepted: list[str] = []
@@ -591,9 +625,9 @@ class MemoryAgent(Agent):
Returns a dictionary with 'cosine_similarity' and 'euclidean_distance'.
"""
embed_fn = self.embedding_function
# Embed the two strings
vec1 = np.array(embed_fn([string1])[0])
vec2 = np.array(embed_fn([string2])[0])
@@ -604,17 +638,14 @@ class MemoryAgent(Agent):
# Compute Euclidean distance
euclidean_dist = np.linalg.norm(vec1 - vec2)
return {
"cosine_similarity": cosine_sim,
"euclidean_distance": euclidean_dist
}
return {"cosine_similarity": cosine_sim, "euclidean_distance": euclidean_dist}
async def compare_string_lists(
self,
list_a: list[str],
list_b: list[str],
similarity_threshold: float = None,
distance_threshold: float = None
distance_threshold: float = None,
) -> dict:
"""
Compare two lists of strings using the current embedding function without touching the database.
@@ -625,15 +656,21 @@ class MemoryAgent(Agent):
- 'similarity_matches': list of (i, j, score) (filtered if threshold set, otherwise all)
- 'distance_matches': list of (i, j, distance) (filtered if threshold set, otherwise all)
"""
if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
if (
not self.db
or not hasattr(self.db, "_embedding_function")
or self.db._embedding_function is None
):
raise RuntimeError(
"Embedding function is not initialized. Make sure the database is set."
)
if not list_a or not list_b:
return {
"cosine_similarity_matrix": np.array([]),
"euclidean_distance_matrix": np.array([]),
"similarity_matches": [],
"distance_matches": []
"distance_matches": [],
}
embed_fn = self.db._embedding_function
@@ -653,21 +690,29 @@ class MemoryAgent(Agent):
cosine_similarity_matrix = np.dot(vecs_a_norm, vecs_b_norm.T)
# Euclidean distance matrix
a_squared = np.sum(vecs_a ** 2, axis=1).reshape(-1, 1)
b_squared = np.sum(vecs_b ** 2, axis=1).reshape(1, -1)
euclidean_distance_matrix = np.sqrt(a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T))
a_squared = np.sum(vecs_a**2, axis=1).reshape(-1, 1)
b_squared = np.sum(vecs_b**2, axis=1).reshape(1, -1)
euclidean_distance_matrix = np.sqrt(
a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T)
)
# Prepare matches
similarity_matches = []
distance_matches = []
# Populate similarity matches
sim_indices = np.argwhere(cosine_similarity_matrix >= (similarity_threshold if similarity_threshold is not None else -np.inf))
sim_indices = np.argwhere(
cosine_similarity_matrix
>= (similarity_threshold if similarity_threshold is not None else -np.inf)
)
for i, j in sim_indices:
similarity_matches.append((i, j, cosine_similarity_matrix[i, j]))
# Populate distance matches
dist_indices = np.argwhere(euclidean_distance_matrix <= (distance_threshold if distance_threshold is not None else np.inf))
dist_indices = np.argwhere(
euclidean_distance_matrix
<= (distance_threshold if distance_threshold is not None else np.inf)
)
for i, j in dist_indices:
distance_matches.append((i, j, euclidean_distance_matrix[i, j]))
@@ -675,9 +720,10 @@ class MemoryAgent(Agent):
"cosine_similarity_matrix": cosine_similarity_matrix,
"euclidean_distance_matrix": euclidean_distance_matrix,
"similarity_matches": similarity_matches,
"distance_matches": distance_matches
"distance_matches": distance_matches,
}
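The pairwise Euclidean matrix above avoids an explicit double loop via the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2(a . b); a standalone sketch:

    import numpy as np

    vecs_a = np.random.rand(3, 4)  # 3 embeddings of dimension 4
    vecs_b = np.random.rand(5, 4)  # 5 embeddings of dimension 4

    a_squared = np.sum(vecs_a**2, axis=1).reshape(-1, 1)  # shape (3, 1)
    b_squared = np.sum(vecs_b**2, axis=1).reshape(1, -1)  # shape (1, 5)
    dist = np.sqrt(a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T))  # (3, 5)

    # cross-check one entry against the direct definition
    assert np.isclose(dist[0, 0], np.linalg.norm(vecs_a[0] - vecs_b[0]))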
@register(condition=lambda: chromadb is not None)
class ChromaDBMemoryAgent(MemoryAgent):
requires_llm_client = False
@@ -690,32 +736,34 @@ class ChromaDBMemoryAgent(MemoryAgent):
if getattr(self, "db_client", None):
return True
return False
@property
def client_api_ready(self) -> bool:
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
embeddings_client: ClientBase | None = instance.get_client(
self.embeddings_client
)
if not embeddings_client:
return False
if not embeddings_client.supports_embeddings:
return False
if not embeddings_client.embeddings_status:
return False
if embeddings_client.current_status not in ["idle", "busy"]:
return False
return True
return False
@property
def status(self):
if self.using_client_api_embeddings and not self.client_api_ready:
return "error"
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
@@ -726,7 +774,6 @@ class ChromaDBMemoryAgent(MemoryAgent):
@property
def agent_details(self):
details = {
"backend": AgentDetail(
icon="mdi-server-outline",
@@ -738,23 +785,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
value=self.embeddings,
description="The embeddings type.",
).model_dump(),
}
if self.model:
details["model"] = AgentDetail(
icon="mdi-brain",
value=self.model,
description="The embeddings model.",
).model_dump()
if self.embeddings_client:
details["client"] = AgentDetail(
icon="mdi-network-outline",
value=self.embeddings_client,
description="The client to use for embeddings.",
).model_dump()
if self.using_local_embeddings:
details["device"] = AgentDetail(
icon="mdi-memory",
@@ -770,9 +816,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
"color": "error",
}
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
embeddings_client: ClientBase | None = instance.get_client(
self.embeddings_client
)
if not embeddings_client:
details["error"] = {
@@ -784,7 +832,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
return details
client_name = embeddings_client.name
if not embeddings_client.supports_embeddings:
error_message = f"{client_name} does not support embeddings"
elif embeddings_client.current_status not in ["idle", "busy"]:
@@ -793,7 +841,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
error_message = f"{client_name} has no embeddings model loaded"
else:
error_message = None
if error_message:
details["error"] = {
"icon": "mdi-alert",
@@ -810,12 +858,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
return self.config.openai.api_key
@property
def embedding_function(self) -> Callable:
if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
if (
not self.db
or not hasattr(self.db, "_embedding_function")
or self.db._embedding_function is None
):
raise RuntimeError(
"Embedding function is not initialized. Make sure the database is set."
)
embed_fn = self.db._embedding_function
return embed_fn
@@ -823,18 +877,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
def make_collection_name(self, scene) -> str:
# generate plain text collection name
collection_name = f"{self.fingerprint}"
# chromadb collection names have the following rules:
# Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address
# Step 1: Hash the input string using MD5
md5_hash = hashlib.md5(collection_name.encode()).hexdigest()
# Step 2: Ensure the result is exactly 32 characters long
hashed_collection_name = md5_hash[:32]
return f"{scene.memory_id}-tm-{hashed_collection_name}"
async def count(self):
await asyncio.sleep(0)
return self.db.count()
@@ -853,14 +907,17 @@ class ChromaDBMemoryAgent(MemoryAgent):
self.collection_name = collection_name = self.make_collection_name(self.scene)
log.info(
"chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
"chromadb agent",
status="setting up db",
collection_name=collection_name,
embeddings=self.embeddings,
)
distance_function = self.distance_function
collection_metadata = {"hnsw:space": distance_function}
device = self.actions["_config"].config["device"].value
model_name = self.model
if self.using_openai_embeddings:
if not openai_key:
raise ValueError(
@@ -878,7 +935,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
model_name=model_name,
)
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=openai_ef, metadata=collection_metadata
collection_name,
embedding_function=openai_ef,
metadata=collection_metadata,
)
elif self.using_client_api_embeddings:
log.info(
@@ -886,20 +945,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
embeddings="Client API",
client=self.embeddings_client,
)
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
embeddings_client: ClientBase | None = instance.get_client(
self.embeddings_client
)
if not embeddings_client:
raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")
raise ValueError(
f"Client API embeddings client {self.embeddings_client} not found"
)
if not embeddings_client.supports_embeddings:
raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")
raise ValueError(
f"Client API embeddings client {self.embeddings_client} does not support embeddings"
)
ef = embeddings_client.embeddings_function
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef, metadata=collection_metadata
)
elif self.using_instructor_embeddings:
log.info(
"chromadb",
@@ -909,7 +974,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
ef = embedding_functions.InstructorEmbeddingFunction(
model_name=model_name, device=device, instruction="Represent the document for retrieval:"
model_name=model_name,
device=device,
instruction="Represent the document for retrieval:",
)
log.info("chromadb", status="embedding function ready")
@@ -919,25 +986,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
log.info("chromadb", status="instructor db ready")
else:
log.info(
"chromadb",
embeddings="SentenceTransformer",
"chromadb",
embeddings="SentenceTransformer",
model=model_name,
device=device,
distance_function=distance_function
distance_function=distance_function,
)
try:
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=model_name,
trust_remote_code=self.trust_remote_code,
device=device
device=device,
)
except ValueError as e:
if "`trust_remote_code=True` to remove this error" in str(e):
raise EmbeddingsModelLoadError(model_name, "Model requires `Trust remote code` to be enabled")
raise EmbeddingsModelLoadError(
model_name, "Model requires `Trust remote code` to be enabled"
)
raise EmbeddingsModelLoadError(model_name, str(e))
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef, metadata=collection_metadata
)
@@ -989,9 +1057,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
try:
self.db_client.delete_collection(collection_name)
except chromadb.errors.NotFoundError as exc:
log.error(
"chromadb agent", error="collection not found", details=exc
)
log.error("chromadb agent", error="collection not found", details=exc)
except ValueError as exc:
log.error(
"chromadb agent", error="failed to delete collection", details=exc
@@ -1049,16 +1115,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
if not objects:
return
# track seen documents by id
seen_ids = set()
for obj in objects:
if obj["id"] in seen_ids:
log.warning("chromadb agent", status="duplicate id discarded", id=obj["id"])
log.warning(
"chromadb agent", status="duplicate id discarded", id=obj["id"]
)
continue
seen_ids.add(obj["id"])
documents.append(obj["text"])
meta = obj.get("meta", {})
source = meta.get("source", "talemate")
@@ -1071,7 +1139,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
metadatas.append(meta)
uid = obj.get("id", f"{character}-{self.memory_tracker[character]}")
ids.append(uid)
self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)
def _delete(self, meta: dict):
@@ -1081,11 +1149,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
return
where = {"$and": [{k: v} for k, v in meta.items()]}
# if there is only one item in $and reduce it to the key value pair
if len(where["$and"]) == 1:
where = where["$and"][0]
self.db.delete(where=where)
log.debug("chromadb agent delete", meta=meta, where=where)
@@ -1122,16 +1190,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
log.error("chromadb agent", error="failed to query", details=e)
return []
#import json
#print(json.dumps(_results["ids"], indent=2))
#print(json.dumps(_results["distances"], indent=2))
# import json
# print(json.dumps(_results["ids"], indent=2))
# print(json.dumps(_results["distances"], indent=2))
results = []
max_distance = self.max_distance
closest = None
active_memory_request = memory_request.get()
for i in range(len(_results["distances"][0])):
@@ -1139,7 +1207,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
doc = _results["documents"][0][i]
meta = _results["metadatas"][0][i]
active_memory_request.add_result(doc, distance, meta)
if not meta:
@@ -1150,7 +1218,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
# skip pin_only entries
if meta.get("pin_only", False):
continue
if closest is None:
closest = {"distance": distance, "doc": doc}
elif distance < closest["distance"]:
@@ -1172,13 +1240,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
if len(results) > limit:
break
log.debug("chromadb agent get", closest=closest, max_distance=max_distance)
self.last_query = {
"query": text,
"closest": closest,
}
return results
def convert_ts_to_date_prefix(self, ts):
@@ -1197,10 +1265,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
return None
def _get_document(self, id) -> dict:
if not id:
return {}
result = self.db.get(ids=[id] if isinstance(id, str) else id)
documents = {}

View File

@@ -1,5 +1,5 @@
"""
Context manager that collects and tracks memory agent requests
for profiling and debugging purposes
"""
@@ -11,91 +11,118 @@ import time
from talemate.emit import emit
from talemate.agents.context import active_agent
__all__ = [
"MemoryRequest",
"start_memory_request"
"MemoryRequestState"
"memory_request"
]
__all__ = ["MemoryRequest"]
log = structlog.get_logger()
DEBUG_MEMORY_REQUESTS = False
class MemoryRequestResult(pydantic.BaseModel):
doc: str
distance: float
meta: dict = pydantic.Field(default_factory=dict)
class MemoryRequestState(pydantic.BaseModel):
query:str
query: str
results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
accepted_results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
query_params: dict = pydantic.Field(default_factory=dict)
closest_distance: float | None = None
furthest_distance: float | None = None
max_distance: float | None = None
def add_result(self, doc:str, distance:float, meta:dict):
def add_result(self, doc: str, distance: float, meta: dict):
if doc is None:
return
self.results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))
self.closest_distance = min(self.closest_distance, distance) if self.closest_distance is not None else distance
self.furthest_distance = max(self.furthest_distance, distance) if self.furthest_distance is not None else distance
def accept_result(self, doc:str, distance:float, meta:dict):
self.closest_distance = (
min(self.closest_distance, distance)
if self.closest_distance is not None
else distance
)
self.furthest_distance = (
max(self.furthest_distance, distance)
if self.furthest_distance is not None
else distance
)
def accept_result(self, doc: str, distance: float, meta: dict):
if doc is None:
return
self.accepted_results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))
self.accepted_results.append(
MemoryRequestResult(doc=doc, distance=distance, meta=meta)
)
@property
def closest_text(self):
return str(self.results[0].doc) if self.results else None
memory_request = contextvars.ContextVar("memory_request", default=None)
class MemoryRequest:
def __init__(self, query:str, query_params:dict=None):
def __init__(self, query: str, query_params: dict = None):
self.query = query
self.query_params = query_params
def __enter__(self):
self.state = MemoryRequestState(query=self.query, query_params=self.query_params)
self.state = MemoryRequestState(
query=self.query, query_params=self.query_params
)
self.token = memory_request.set(self.state)
self.time_start = time.time()
return self.state
def __exit__(self, *args):
self.time_end = time.time()
if DEBUG_MEMORY_REQUESTS:
max_length = 50
query = self.state.query[:max_length]+"..." if len(self.state.query) > max_length else self.state.query
log.debug("MemoryRequest", number_of_results=len(self.state.results), query=query)
log.debug("MemoryRequest", number_of_accepted_results=len(self.state.accepted_results), query=query)
query = (
self.state.query[:max_length] + "..."
if len(self.state.query) > max_length
else self.state.query
)
log.debug(
"MemoryRequest", number_of_results=len(self.state.results), query=query
)
log.debug(
"MemoryRequest",
number_of_accepted_results=len(self.state.accepted_results),
query=query,
)
for result in self.state.results:
# distance to 2 decimal places
log.debug("MemoryRequest RESULT", distance=f"{result.distance:.2f}", doc=result.doc[:max_length]+"...")
log.debug(
"MemoryRequest RESULT",
distance=f"{result.distance:.2f}",
doc=result.doc[:max_length] + "...",
)
agent_context = active_agent.get()
emit("memory_request", data=self.state.model_dump(), meta={
"agent_stack": agent_context.agent_stack if agent_context else [],
"agent_stack_uid": agent_context.agent_stack_uid if agent_context else None,
"duration": self.time_end - self.time_start,
}, websocket_passthrough=True)
emit(
"memory_request",
data=self.state.model_dump(),
meta={
"agent_stack": agent_context.agent_stack if agent_context else [],
"agent_stack_uid": agent_context.agent_stack_uid
if agent_context
else None,
"duration": self.time_end - self.time_start,
},
websocket_passthrough=True,
)
memory_request.reset(self.token)
return False
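MemoryRequest is a plain synchronous context manager: entering sets the contextvar and returns the request state, exiting emits the profiling data and resets the var. A usage sketch (the query text and metadata are illustrative):

    with MemoryRequest(query="who is the innkeeper?", query_params={"limit": 10}) as req:
        req.max_distance = 1.5
        req.add_result("The innkeeper is Bram.", 0.42, {"source": "talemate"})
        req.accept_result("The innkeeper is Bram.", 0.42, {"source": "talemate"})
    # afterwards req.closest_distance == 0.42 and the memory_request emit has fired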
# decorator that opens a memory request context
async def start_memory_request(query):
@@ -103,5 +130,7 @@ async def start_memory_request(query):
async def wrapper(*args, **kwargs):
with MemoryRequest(query):
return await fn(*args, **kwargs)
return wrapper
return decorator

View File

@@ -1,18 +1,17 @@
__all__ = [
'EmbeddingsModelLoadError',
'MemoryAgentError',
'SetDBError'
]
__all__ = ["EmbeddingsModelLoadError", "MemoryAgentError", "SetDBError"]
class MemoryAgentError(Exception):
pass
class SetDBError(OSError, MemoryAgentError):
def __init__(self, details:str):
def __init__(self, details: str):
super().__init__(f"Memory Agent - Failed to set up the database: {details}")
class EmbeddingsModelLoadError(ValueError, MemoryAgentError):
def __init__(self, model_name:str, details:str):
super().__init__(f"Memory Agent - Failed to load embeddings model {model_name}: {details}")
def __init__(self, model_name: str, details: str):
super().__init__(
f"Memory Agent - Failed to load embeddings model {model_name}: {details}"
)

View File

@@ -4,7 +4,6 @@ from talemate.agents.base import (
AgentAction,
AgentActionConfig,
)
from talemate.emit import emit
import talemate.instance as instance
if TYPE_CHECKING:
@@ -14,11 +13,10 @@ __all__ = ["MemoryRAGMixin"]
log = structlog.get_logger()
class MemoryRAGMixin:
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["use_long_term_memory"] = AgentAction(
enabled=True,
container=True,
@@ -44,7 +42,7 @@ class MemoryRAGMixin:
{
"label": "AI compiled question and answers (slow)",
"value": "questions",
}
},
],
),
"number_of_queries": AgentActionConfig(
@@ -65,7 +63,7 @@ class MemoryRAGMixin:
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
]
],
),
"cache": AgentActionConfig(
type="bool",
@@ -73,16 +71,16 @@ class MemoryRAGMixin:
description="Cache the long term memory for faster retrieval.",
note="This is a cross-agent cache, assuming they use the same options.",
value=True,
)
),
},
)
# config property helpers
@property
def long_term_memory_enabled(self):
return self.actions["use_long_term_memory"].enabled
@property
def long_term_memory_retrieval_method(self):
return self.actions["use_long_term_memory"].config["retrieval_method"].value
@@ -90,60 +88,60 @@ class MemoryRAGMixin:
@property
def long_term_memory_number_of_queries(self):
return self.actions["use_long_term_memory"].config["number_of_queries"].value
@property
def long_term_memory_answer_length(self):
return int(self.actions["use_long_term_memory"].config["answer_length"].value)
@property
def long_term_memory_cache(self):
return self.actions["use_long_term_memory"].config["cache"].value
@property
def long_term_memory_cache_key(self):
"""
Build the key from the various options
"""
parts = [
self.long_term_memory_retrieval_method,
self.long_term_memory_number_of_queries,
self.long_term_memory_answer_length
self.long_term_memory_answer_length,
]
return "-".join(map(str, parts))
def connect(self, scene):
super().connect(scene)
# new scene, reset cache
scene.rag_cache = {}
# methods
async def rag_set_cache(self, content:list[str]):
async def rag_set_cache(self, content: list[str]):
self.scene.rag_cache[self.long_term_memory_cache_key] = {
"content": content,
"fingerprint": self.scene.history[-1].fingerprint if self.scene.history else 0
"fingerprint": self.scene.history[-1].fingerprint
if self.scene.history
else 0,
}
async def rag_get_cache(self) -> list[str] | None:
if not self.long_term_memory_cache:
return None
fingerprint = self.scene.history[-1].fingerprint if self.scene.history else 0
cache = self.scene.rag_cache.get(self.long_term_memory_cache_key)
if cache and cache["fingerprint"] == fingerprint:
return cache["content"]
return None
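Cache invalidation is implicit: the stored fingerprint (the latest history entry's fingerprint, or 0 for an empty history) must equal the current one, so any new message forces a rebuild. The check in isolation:

    cache = {"content": ["..."], "fingerprint": 41}
    current_fingerprint = 42  # a new history entry changed the fingerprint

    content = cache["content"] if cache["fingerprint"] == current_fingerprint else None
    print(content)  # None -> rebuild the long term memory context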
async def rag_build(
self,
character: "Character | None" = None,
self,
character: "Character | None" = None,
prompt: str = "",
sub_instruction: str = "",
) -> list[str]:
@@ -153,37 +151,41 @@ class MemoryRAGMixin:
if not self.long_term_memory_enabled:
return []
cached = await self.rag_get_cache()
if cached:
log.debug(f"Using cached long term memory", agent=self.agent_type, key=self.long_term_memory_cache_key)
log.debug(
"Using cached long term memory",
agent=self.agent_type,
key=self.long_term_memory_cache_key,
)
return cached
memory_context = ""
retrieval_method = self.long_term_memory_retrieval_method
if not sub_instruction:
if character:
sub_instruction = f"continue the scene as {character.name}"
elif hasattr(self, "rag_build_sub_instruction"):
sub_instruction = await self.rag_build_sub_instruction()
if not sub_instruction:
sub_instruction = "continue the scene"
if retrieval_method != "direct":
world_state = instance.get_agent("world_state")
if not prompt:
prompt = self.scene.context_history(
keep_director=False,
budget=int(self.client.max_token_length * 0.75),
)
if isinstance(prompt, list):
prompt = "\n".join(prompt)
log.debug(
"memory_rag_mixin.build_prompt_default_memory",
direct=False,
@@ -193,20 +195,21 @@ class MemoryRAGMixin:
if retrieval_method == "questions":
memory_context = (
await world_state.analyze_text_and_extract_context(
prompt, sub_instruction,
prompt,
sub_instruction,
include_character_context=True,
response_length=self.long_term_memory_answer_length,
num_queries=self.long_term_memory_number_of_queries
num_queries=self.long_term_memory_number_of_queries,
)
).split("\n")
elif retrieval_method == "queries":
memory_context = (
await world_state.analyze_text_and_extract_context_via_queries(
prompt, sub_instruction,
prompt,
sub_instruction,
include_character_context=True,
response_length=self.long_term_memory_answer_length,
num_queries=self.long_term_memory_number_of_queries
num_queries=self.long_term_memory_number_of_queries,
)
)
@@ -223,4 +226,4 @@ class MemoryRAGMixin:
await self.rag_set_cache(memory_context)
return memory_context

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import dataclasses
import random
from functools import wraps
from inspect import signature
from typing import TYPE_CHECKING
import structlog
@@ -50,15 +49,18 @@ log = structlog.get_logger("talemate.agents.narrator")
class NarratorAgentEmission(AgentEmission):
generation: list[str] = dataclasses.field(default_factory=list)
response: str = dataclasses.field(default="")
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
talemate.emit.async_signals.register(
"agent.narrator.before_generate",
"agent.narrator.before_generate",
"agent.narrator.inject_instructions",
"agent.narrator.generated",
)
def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
@@ -74,11 +76,15 @@ def set_processing(fn):
if self.content_use_writing_style:
self.set_context_states(writing_style=self.scene.writing_style)
await talemate.emit.async_signals.get("agent.narrator.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.narrator.inject_instructions").send(emission)
await talemate.emit.async_signals.get("agent.narrator.before_generate").send(
emission
)
await talemate.emit.async_signals.get(
"agent.narrator.inject_instructions"
).send(emission)
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions
response = await fn(self, *args, **kwargs)
emission.response = response
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
@@ -88,10 +94,7 @@ def set_processing(fn):
@register()
class NarratorAgent(
MemoryRAGMixin,
Agent
):
class NarratorAgent(MemoryRAGMixin, Agent):
"""
Handles narration of the story
"""
@@ -99,7 +102,7 @@ class NarratorAgent(
agent_type = "narrator"
verbose_name = "Narrator"
set_processing = set_processing
websocket_handler = NarratorWebsocketHandler
@classmethod
@@ -117,7 +120,7 @@ class NarratorAgent(
min=32,
max=1024,
step=32,
),
),
"instructions": AgentActionConfig(
type="text",
label="Instructions",
@@ -135,11 +138,6 @@ class NarratorAgent(
),
},
),
"auto_break_repetition": AgentAction(
enabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"content": AgentAction(
enabled=True,
can_be_disabled=False,
@@ -154,7 +152,7 @@ class NarratorAgent(
description="Use the writing style selected in the scene settings",
value=True,
),
}
},
),
"narrate_time_passage": AgentAction(
enabled=True,
@@ -201,13 +199,13 @@ class NarratorAgent(
},
),
}
MemoryRAGMixin.add_actions(actions)
return actions
def __init__(
self,
client: client.TaleMateClient,
client: client.ClientBase | None = None,
**kwargs,
):
self.client = client
@@ -232,28 +230,33 @@ class NarratorAgent(
if self.actions["generation_override"].enabled:
return self.actions["generation_override"].config["length"].value
return 128
@property
def narrate_time_passage_enabled(self) -> bool:
return self.actions["narrate_time_passage"].enabled
@property
def narrate_dialogue_enabled(self) -> bool:
return self.actions["narrate_dialogue"].enabled
@property
def narrate_dialogue_ai_chance(self) -> float:
def narrate_dialogue_ai_chance(self) -> float:
return self.actions["narrate_dialogue"].config["ai_dialog"].value
@property
def narrate_dialogue_player_chance(self) -> float:
return self.actions["narrate_dialogue"].config["player_dialog"].value
@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value
def clean_result(self, result:str, ensure_dialog_format:bool=True, force_narrative:bool=True) -> str:
def clean_result(
self,
result: str,
ensure_dialog_format: bool = True,
force_narrative: bool = True,
) -> str:
"""
Cleans the result of a narration
"""
@@ -264,15 +267,14 @@ class NarratorAgent(
cleaned = []
for line in result.split("\n"):
# skip lines that start with a #
if line.startswith("#"):
continue
log.debug("clean_result", line=line)
character_dialogue_detected = False
for character_name in character_names:
if not character_name:
continue
@@ -280,25 +282,24 @@ class NarratorAgent(
character_dialogue_detected = True
elif line.startswith(f"{character_name.upper()}"):
character_dialogue_detected = True
if character_dialogue_detected:
break
if character_dialogue_detected:
break
cleaned.append(line)
result = "\n".join(cleaned)
result = util.strip_partial_sentences(result)
editor = get_agent("editor")
if ensure_dialog_format or force_narrative:
if editor.fix_exposition_enabled and editor.fix_exposition_narrator:
result = editor.fix_exposition_in_text(result)
return result
def connect(self, scene):
@@ -324,16 +325,16 @@ class NarratorAgent(
event.duration, event.human_duration, event.narrative
)
narrator_message = NarratorMessage(
response,
meta = {
response,
meta={
"agent": "narrator",
"function": "narrate_time_passage",
"arguments": {
"duration": event.duration,
"time_passed": event.human_duration,
"narrative_direction": event.narrative,
}
}
},
},
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@@ -345,7 +346,7 @@ class NarratorAgent(
if not self.narrate_dialogue_enabled:
return
if event.game_loop.had_passive_narration:
log.debug(
"narrate on dialog",
@@ -375,14 +376,14 @@ class NarratorAgent(
response = await self.narrate_after_dialogue(event.actor.character)
narrator_message = NarratorMessage(
response,
response,
meta={
"agent": "narrator",
"function": "narrate_after_dialogue",
"arguments": {
"character": event.actor.character.name,
}
}
},
},
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@@ -390,7 +391,7 @@ class NarratorAgent(
event.game_loop.had_passive_narration = True
@set_processing
@store_context_state('narrative_direction', visual_narration=True)
@store_context_state("narrative_direction", visual_narration=True)
async def narrate_scene(self, narrative_direction: str | None = None):
"""
Narrate the scene
@@ -413,13 +414,13 @@ class NarratorAgent(
return response
@set_processing
@store_context_state('narrative_direction')
@store_context_state("narrative_direction")
async def progress_story(self, narrative_direction: str | None = None):
"""
Narrate scene progression, moving the plot forward.
Arguments:
- narrative_direction: A string describing the direction the narrative should take. If not provided, will attempt to subtly move the story forward.
"""
@@ -431,10 +432,8 @@ class NarratorAgent(
if narrative_direction is None:
narrative_direction = "Slightly move the current scene forward."
log.debug(
"narrative_direction", narrative_direction=narrative_direction
)
log.debug("narrative_direction", narrative_direction=narrative_direction)
response = await Prompt.request(
"narrator.narrate-progress",
self.client,
@@ -453,13 +452,17 @@ class NarratorAgent(
log.debug("progress_story", response=response)
response = self.clean_result(response.strip())
return response
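Note: store_context_state is used throughout this file to expose call arguments (e.g. narrative_direction) to other agents, but its internals are not shown in this diff. A speculative sketch of what such a decorator might do with a contextvar; the variable name and storage shape are assumptions, not talemate's code:

import contextvars
import functools
import inspect

agent_state: contextvars.ContextVar[dict] = contextvars.ContextVar("agent_state")

def store_context_state(*arg_names, **flags):
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(self, *args, **kwargs):
            # capture the named call arguments so other agents can inspect them
            bound = inspect.signature(fn).bind(self, *args, **kwargs)
            bound.apply_defaults()
            state = {name: bound.arguments.get(name) for name in arg_names}
            state.update(flags)
            token = agent_state.set(state)
            try:
                return await fn(self, *args, **kwargs)
            finally:
                agent_state.reset(token)
        return wrapper
    return decorator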
@set_processing
@store_context_state('query', query_narration=True)
@store_context_state("query", query_narration=True)
async def narrate_query(
self, query: str, at_the_end: bool = False, as_narrative: bool = True, extra_context: str = None
self,
query: str,
at_the_end: bool = False,
as_narrative: bool = True,
extra_context: str = None,
):
"""
Narrate a specific query
@@ -479,20 +482,20 @@ class NarratorAgent(
},
)
response = self.clean_result(
response.strip(),
ensure_dialog_format=False,
force_narrative=as_narrative
response.strip(), ensure_dialog_format=False, force_narrative=as_narrative
)
return response
@set_processing
@store_context_state('character', 'narrative_direction', visual_narration=True)
async def narrate_character(self, character:"Character", narrative_direction: str = None):
@store_context_state("character", "narrative_direction", visual_narration=True)
async def narrate_character(
self, character: "Character", narrative_direction: str = None
):
"""
Narrate a specific character
"""
response = await Prompt.request(
"narrator.narrate-character",
self.client,
@@ -506,12 +509,14 @@ class NarratorAgent(
},
)
response = self.clean_result(response.strip(), ensure_dialog_format=False, force_narrative=True)
response = self.clean_result(
response.strip(), ensure_dialog_format=False, force_narrative=True
)
return response
@set_processing
@store_context_state('narrative_direction', time_narration=True)
@store_context_state("narrative_direction", time_narration=True)
async def narrate_time_passage(
self, duration: str, time_passed: str, narrative_direction: str
):
@@ -528,7 +533,7 @@ class NarratorAgent(
"max_tokens": self.client.max_token_length,
"duration": duration,
"time_passed": time_passed,
"narrative": narrative_direction, # backwards compatibility
"narrative": narrative_direction, # backwards compatibility
"narrative_direction": narrative_direction,
"extra_instructions": self.extra_instructions,
},
@@ -541,7 +546,7 @@ class NarratorAgent(
return response
@set_processing
@store_context_state('narrative_direction', sensory_narration=True)
@store_context_state("narrative_direction", sensory_narration=True)
async def narrate_after_dialogue(
self,
character: Character,
@@ -572,16 +577,16 @@ class NarratorAgent(
async def narrate_environment(self, narrative_direction: str = None):
"""
Narrate the environment
Wraps narrate_after_dialogue with the player character
as the perspective character
"""
pc = self.scene.get_player_character()
return await self.narrate_after_dialogue(pc, narrative_direction)
@set_processing
@store_context_state('narrative_direction', 'character')
@store_context_state("narrative_direction", "character")
async def narrate_character_entry(
self, character: Character, narrative_direction: str = None
):
@@ -607,11 +612,9 @@ class NarratorAgent(
return response
@set_processing
@store_context_state('narrative_direction', 'character')
@store_context_state("narrative_direction", "character")
async def narrate_character_exit(
self,
character: Character,
narrative_direction: str = None
self, character: Character, narrative_direction: str = None
):
"""
Narrate a character exiting the scene
@@ -679,8 +682,10 @@ class NarratorAgent(
later on
"""
args = parameters.copy()
if args.get("character") and isinstance(args["character"], self.scene.Character):
if args.get("character") and isinstance(
args["character"], self.scene.Character
):
args["character"] = args["character"].name
return {
@@ -701,7 +706,9 @@ class NarratorAgent(
fn = getattr(self, action_name)
narration = await fn(**kwargs)
narrator_message = NarratorMessage(narration, meta=self.action_to_meta(action_name, kwargs))
narrator_message = NarratorMessage(
narration, meta=self.action_to_meta(action_name, kwargs)
)
self.scene.push_history(narrator_message)
if emit_message:
@@ -720,35 +727,36 @@ class NarratorAgent(
kind=kind,
agent_function_name=agent_function_name,
)
# depending on conversation format in the context, stopping strings
# for character names may change format
conversation_agent = get_agent("conversation")
if conversation_agent.conversation_format == "movie_script":
character_names = [f"\n{c.name.upper()}\n" for c in self.scene.get_characters()]
else:
character_names = [
f"\n{c.name.upper()}\n" for c in self.scene.get_characters()
]
else:
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names
self.set_generation_overrides(prompt_param)
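Note: a standalone sketch of the stopping-string construction above; movie-script format emits speaker names as upper-cased lines, chat format as "Name:" prefixes. The Character stub is hypothetical:

class Character:
    def __init__(self, name: str):
        self.name = name

def stopping_strings(characters: list[Character], fmt: str) -> list[str]:
    if fmt == "movie_script":
        return [f"\n{c.name.upper()}\n" for c in characters]
    return [f"\n{c.name}:" for c in characters]

print(stopping_strings([Character("Alice")], "movie_script"))  # ['\nALICE\n']
print(stopping_strings([Character("Alice")], "chat"))          # ['\nAlice:']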
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return True
def set_generation_overrides(self, prompt_param: dict):
if not self.actions["generation_override"].enabled:
return
prompt_param["max_tokens"] = min(prompt_param.get("max_tokens", 256), self.max_generation_length)
prompt_param["max_tokens"] = min(
prompt_param.get("max_tokens", 256), self.max_generation_length
)
if self.jiggle > 0.0:
nuke_repetition = client_context_attribute("nuke_repetition")
@@ -756,4 +764,4 @@ class NarratorAgent(
# we only apply the agent override if some other mechanism isn't already
# setting the nuke_repetition value
nuke_repetition = self.jiggle
set_client_context_attribute("nuke_repetition", nuke_repetition)
set_client_context_attribute("nuke_repetition", nuke_repetition)
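Note: the override logic above clamps max_tokens to the configured ceiling and only applies the agent's jiggle when nothing else has set nuke_repetition. The same logic with plain dicts standing in for the client context:

def apply_overrides(prompt_param: dict, max_length: int, jiggle: float,
                    context: dict) -> None:
    prompt_param["max_tokens"] = min(prompt_param.get("max_tokens", 256), max_length)
    if jiggle > 0.0 and not context.get("nuke_repetition"):
        # only apply the agent override if some other mechanism isn't
        # already setting the value
        context["nuke_repetition"] = jiggle

params = {"max_tokens": 512}
ctx = {}
apply_overrides(params, 128, 0.3, ctx)
print(params, ctx)  # {'max_tokens': 128} {'nuke_repetition': 0.3}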

View File

@@ -8,15 +8,16 @@ from talemate.util import iso8601_duration_to_human
log = structlog.get_logger("talemate.game.engine.nodes.agents.narrator")
class GenerateNarrationBase(AgentNode):
"""
Generate a narration message
"""
_agent_name:ClassVar[str] = "narrator"
_action_name:ClassVar[str] = ""
_title:ClassVar[str] = "Generate Narration"
_agent_name: ClassVar[str] = "narrator"
_action_name: ClassVar[str] = ""
_title: ClassVar[str] = "Generate Narration"
class Fields:
narrative_direction = PropertyField(
name="narrative_direction",
@@ -24,151 +25,170 @@ class GenerateNarrationBase(AgentNode):
default="",
type="str",
)
def __init__(self, **kwargs):
if "title" not in kwargs:
kwargs["title"] = self._title
super().__init__(**kwargs)
def setup(self):
self.add_input("state")
self.add_input("narrative_direction", socket_type="str", optional=True)
self.add_output("generated", socket_type="str")
self.add_output("message", socket_type="message_object")
async def prepare_input_values(self) -> dict:
input_values = self.get_input_values()
input_values.pop("state", None)
return input_values
async def run(self, state: GraphState):
input_values = await self.prepare_input_values()
try:
agent_fn = getattr(self.agent, self._action_name)
except AttributeError:
raise InputValueError(self, "_action_name", f"Agent does not have a function named {self._action_name}")
raise InputValueError(
self,
"_action_name",
f"Agent does not have a function named {self._action_name}",
)
narration = await agent_fn(**input_values)
message = NarratorMessage(
message=narration,
meta=self.agent.action_to_meta(self._action_name, input_values),
)
self.set_output_values({
"generated": narration,
"message": message
})
self.set_output_values({"generated": narration, "message": message})
@register("agents/narrator/GenerateProgress")
class GenerateProgressNarration(GenerateNarrationBase):
"""
Generate a progress narration message
"""
_action_name:ClassVar[str] = "progress_story"
_title:ClassVar[str] = "Generate Progress Narration"
"""
_action_name: ClassVar[str] = "progress_story"
_title: ClassVar[str] = "Generate Progress Narration"
@register("agents/narrator/GenerateSceneNarration")
class GenerateSceneNarration(GenerateNarrationBase):
"""
Generate a scene narration message
"""
_action_name:ClassVar[str] = "narrate_scene"
_title:ClassVar[str] = "Generate Scene Narration"
"""
@register("agents/narrator/GenerateAfterDialogNarration")
_action_name: ClassVar[str] = "narrate_scene"
_title: ClassVar[str] = "Generate Scene Narration"
@register("agents/narrator/GenerateAfterDialogNarration")
class GenerateAfterDialogNarration(GenerateNarrationBase):
"""
Generate an after dialog narration message
"""
_action_name:ClassVar[str] = "narrate_after_dialogue"
_title:ClassVar[str] = "Generate After Dialog Narration"
"""
_action_name: ClassVar[str] = "narrate_after_dialogue"
_title: ClassVar[str] = "Generate After Dialog Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/GenerateEnvironmentNarration")
class GenerateEnvironmentNarration(GenerateNarrationBase):
"""
Generate an environment narration message
"""
_action_name:ClassVar[str] = "narrate_environment"
_title:ClassVar[str] = "Generate Environment Narration"
"""
_action_name: ClassVar[str] = "narrate_environment"
_title: ClassVar[str] = "Generate Environment Narration"
@register("agents/narrator/GenerateQueryNarration")
class GenerateQueryNarration(GenerateNarrationBase):
"""
Generate a query narration message
"""
_action_name:ClassVar[str] = "narrate_query"
_title:ClassVar[str] = "Generate Query Narration"
"""
_action_name: ClassVar[str] = "narrate_query"
_title: ClassVar[str] = "Generate Query Narration"
def setup(self):
super().setup()
self.add_input("query", socket_type="str")
self.add_input("extra_context", socket_type="str", optional=True)
self.remove_input("narrative_direction")
@register("agents/narrator/GenerateCharacterNarration")
class GenerateCharacterNarration(GenerateNarrationBase):
"""
Generate a character narration message
"""
_action_name:ClassVar[str] = "narrate_character"
_title:ClassVar[str] = "Generate Character Narration"
"""
_action_name: ClassVar[str] = "narrate_character"
_title: ClassVar[str] = "Generate Character Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/GenerateTimeNarration")
class GenerateTimeNarration(GenerateNarrationBase):
"""
Generate a time narration message
"""
_action_name:ClassVar[str] = "narrate_time_passage"
_title:ClassVar[str] = "Generate Time Narration"
"""
_action_name: ClassVar[str] = "narrate_time_passage"
_title: ClassVar[str] = "Generate Time Narration"
def setup(self):
super().setup()
self.add_input("duration", socket_type="str")
self.set_property("duration", "P0T1S")
async def prepare_input_values(self) -> dict:
input_values = await super().prepare_input_values()
input_values["time_passed"] = iso8601_duration_to_human(input_values["duration"])
input_values["time_passed"] = iso8601_duration_to_human(
input_values["duration"]
)
return input_values
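Note: iso8601_duration_to_human is talemate's own helper; a rough, simplified stand-in for plain day/time durations might look like the sketch below. It deliberately ignores years/months and the nonstandard "P0T1S" default above (which falls through unchanged):

import re

def duration_to_human(iso: str) -> str:
    m = re.fullmatch(
        r"P(?:(\d+)D)?(?:T(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?)?", iso)
    if not m:
        return iso
    days, hours, minutes, seconds = (int(g) if g else 0 for g in m.groups())
    parts = [(days, "day"), (hours, "hour"), (minutes, "minute"), (seconds, "second")]
    human = [f"{n} {unit}{'s' if n != 1 else ''}" for n, unit in parts if n]
    return ", ".join(human) or "0 seconds"

print(duration_to_human("PT90S"))   # 90 seconds
print(duration_to_human("P1DT2H"))  # 1 day, 2 hours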
@register("agents/narrator/GenerateCharacterEntryNarration")
class GenerateCharacterEntryNarration(GenerateNarrationBase):
"""
Generate a character entry narration message
"""
_action_name:ClassVar[str] = "narrate_character_entry"
_title:ClassVar[str] = "Generate Character Entry Narration"
"""
_action_name: ClassVar[str] = "narrate_character_entry"
_title: ClassVar[str] = "Generate Character Entry Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/GenerateCharacterExitNarration")
class GenerateCharacterExitNarration(GenerateNarrationBase):
"""
Generate a character exit narration message
"""
_action_name:ClassVar[str] = "narrate_character_exit"
_title:ClassVar[str] = "Generate Character Exit Narration"
"""
_action_name: ClassVar[str] = "narrate_character_exit"
_title: ClassVar[str] = "Generate Character Exit Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/UnpackSource")
class UnpackSource(AgentNode):
"""
@@ -176,25 +196,19 @@ class UnpackSource(AgentNode):
into action name and arguments
DEPRECATED
"""
_agent_name:ClassVar[str] = "narrator"
_agent_name: ClassVar[str] = "narrator"
def __init__(self, title="Unpack Source", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("source", socket_type="str")
self.add_output("action_name", socket_type="str")
self.add_output("arguments", socket_type="dict")
async def run(self, state: GraphState):
source = self.get_input_value("source")
action_name = ""
arguments = {}
self.set_output_values({
"action_name": action_name,
"arguments": arguments
})
self.set_output_values({"action_name": action_name, "arguments": arguments})

View File

@@ -15,27 +15,31 @@ __all__ = [
log = structlog.get_logger("talemate.server.narrator")
class QueryPayload(pydantic.BaseModel):
query:str
at_the_end:bool=True
query: str
at_the_end: bool = True
class NarrativeDirectionPayload(pydantic.BaseModel):
narrative_direction:str = ""
narrative_direction: str = ""
class CharacterPayload(NarrativeDirectionPayload):
character:str = ""
character: str = ""
class NarratorWebsocketHandler(Plugin):
"""
Handles narrator actions
"""
router = "narrator"
@property
def narrator(self):
return get_agent("narrator")
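Note: the payloads above are pydantic v2 models (the handlers call model_dump), so incoming websocket dicts are validated and defaulted before dispatch, e.g.:

import pydantic

class QueryPayload(pydantic.BaseModel):
    query: str
    at_the_end: bool = True

data = {"query": "What time is it?"}
payload = QueryPayload(**data)
print(payload.model_dump())  # {'query': 'What time is it?', 'at_the_end': True}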
@set_loading("Progressing the story", cancellable=True, as_async=True)
async def handle_progress(self, data: dict):
"""
@@ -47,7 +51,7 @@ class NarratorWebsocketHandler(Plugin):
narrative_direction=payload.narrative_direction,
emit_message=True,
)
@set_loading("Narrating the environment", cancellable=True, as_async=True)
async def handle_narrate_environment(self, data: dict):
"""
@@ -59,8 +63,7 @@ class NarratorWebsocketHandler(Plugin):
narrative_direction=payload.narrative_direction,
emit_message=True,
)
@set_loading("Working on a query", cancellable=True, as_async=True)
async def handle_query(self, data: dict):
"""
@@ -68,56 +71,55 @@ class NarratorWebsocketHandler(Plugin):
message.
"""
payload = QueryPayload(**data)
narration = await self.narrator.narrate_query(**payload.model_dump())
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="query"
narration, sub_type="query"
)
message.set_source("narrator", "narrate_query", **payload.model_dump())
emit("context_investigation", message=message)
self.scene.push_history(message)
@set_loading("Looking at the scene", cancellable=True, as_async=True)
async def handle_look_at_scene(self, data: dict):
"""
Look at the scene (optionally to a specific direction)
This will result in a context investigation message.
"""
payload = NarrativeDirectionPayload(**data)
narration = await self.narrator.narrate_scene(narrative_direction=payload.narrative_direction)
narration = await self.narrator.narrate_scene(
narrative_direction=payload.narrative_direction
)
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="visual-scene"
)
message.set_source("narrator", "narrate_scene", **payload.model_dump())
emit("context_investigation", message=message)
self.scene.push_history(message)
@set_loading("Looking at a character", cancellable=True, as_async=True)
async def handle_look_at_character(self, data: dict):
"""
Look at a character (optionally to a specific direction)
This will result in a context investigation message.
"""
payload = CharacterPayload(**data)
narration = await self.narrator.narrate_character(
character = self.scene.get_character(payload.character),
character=self.scene.get_character(payload.character),
narrative_direction=payload.narrative_direction,
)
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="visual-character"
)
message.set_source("narrator", "narrate_character", **payload.model_dump())
emit("context_investigation", message=message)
self.scene.push_history(message)

View File

@@ -24,4 +24,4 @@ def get_agent_class(name):
def get_agent_types() -> list[str]:
return list(AGENT_CLASSES.keys())
return list(AGENT_CLASSES.keys())

View File

@@ -7,26 +7,21 @@ import structlog
from typing import TYPE_CHECKING, Literal
import talemate.emit.async_signals
import talemate.util as util
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.prompts import Prompt
from talemate.scene_message import (
DirectorMessage,
TimePassageMessage,
ContextInvestigationMessage,
DirectorMessage,
TimePassageMessage,
ContextInvestigationMessage,
ReinforcementMessage,
)
from talemate.world_state.templates import GenerationOptions
from talemate.instance import get_agent
from talemate.exceptions import GenerationCancelled
import talemate.game.focal as focal
import talemate.emit.async_signals
from talemate.client import ClientBase
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
set_processing,
Agent,
AgentAction,
AgentActionConfig,
set_processing,
AgentEmission,
AgentTemplateEmission,
RagBuildSubInstructionEmission,
@@ -39,6 +34,7 @@ from talemate.history import ArchiveEntry
from .analyze_scene import SceneAnalyzationMixin
from .context_investigation import ContextInvestigationMixin
from .layered_history import LayeredHistoryMixin
from .tts_utils import TTSUtilsMixin
if TYPE_CHECKING:
from talemate.tale_mate import Character
@@ -53,10 +49,12 @@ talemate.emit.async_signals.register(
"agent.summarization.summarize.after",
)
@dataclasses.dataclass
class BuildArchiveEmission(AgentEmission):
generation_options: GenerationOptions | None = None
@dataclasses.dataclass
class SummarizeEmission(AgentTemplateEmission):
text: str = ""
@@ -66,6 +64,7 @@ class SummarizeEmission(AgentTemplateEmission):
summarization_history: list[str] | None = None
summarization_type: Literal["dialogue", "events"] = "dialogue"
@register()
class SummarizeAgent(
MemoryRAGMixin,
@@ -73,7 +72,8 @@ class SummarizeAgent(
ContextInvestigationMixin,
# Needs to be after ContextInvestigationMixin so signals are connected in the right order
SceneAnalyzationMixin,
Agent
TTSUtilsMixin,
Agent,
):
"""
An agent that can be used to summarize text
@@ -82,7 +82,7 @@ class SummarizeAgent(
agent_type = "summarizer"
verbose_name = "Summarizer"
auto_squish = False
@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
@@ -131,7 +131,7 @@ class SummarizeAgent(
ContextInvestigationMixin.add_actions(actions)
return actions
def __init__(self, client, **kwargs):
def __init__(self, client: ClientBase | None = None, **kwargs):
self.client = client
self.actions = SummarizeAgent.init_actions()
@@ -148,11 +148,11 @@ class SummarizeAgent(
@property
def archive_threshold(self):
return self.actions["archive"].config["threshold"].value
@property
def archive_method(self):
return self.actions["archive"].config["method"].value
@property
def archive_include_previous(self):
return self.actions["archive"].config["include_previous"].value
@@ -179,7 +179,7 @@ class SummarizeAgent(
return result
# RAG HELPERS
async def rag_build_sub_instruction(self):
# Fire event to get the sub instruction from mixins
emission = RagBuildSubInstructionEmission(
@@ -188,37 +188,42 @@ class SummarizeAgent(
await talemate.emit.async_signals.get(
"agent.summarization.rag_build_sub_instruction"
).send(emission)
return emission.sub_instruction
# SUMMARIZATION HELPERS
async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:
num_previous = self.archive_include_previous
# find entry by .id
entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None)
entry_index = next(
(
i
for i, e in enumerate(self.scene.archived_history)
if e["id"] == entry.id
),
None,
)
if entry_index is None:
raise ValueError("Entry not found")
end = entry_index - 1
previous_summaries = []
if entry and num_previous > 0:
if self.layered_history_available:
previous_summaries = self.compile_layered_history(
include_base_layer=True,
base_layer_end_id=entry.id
include_base_layer=True, base_layer_end_id=entry.id
)[-num_previous:]
else:
previous_summaries = [
entry.text for entry in self.scene.archived_history[end-num_previous:end]
entry.text
for entry in self.scene.archived_history[end - num_previous : end]
]
return previous_summaries
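Note: a standalone sketch of the previous-summaries windowing above: locate the entry by id, then take up to num_previous summaries preceding it. ArchiveEntry is stubbed, and the clamping at 0 is an assumption to keep the slice from wrapping around:

import dataclasses

@dataclasses.dataclass
class ArchiveEntry:
    id: int
    text: str

def previous_summaries(history: list[ArchiveEntry], entry_id: int,
                       num_previous: int) -> list[str]:
    entry_index = next(
        (i for i, e in enumerate(history) if e.id == entry_id), None)
    if entry_index is None:
        raise ValueError("Entry not found")
    end = entry_index - 1
    if num_previous <= 0 or end < 0:
        return []
    return [e.text for e in history[max(end - num_previous, 0):end]]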
# SUMMARIZE
@set_processing
@@ -226,13 +231,15 @@ class SummarizeAgent(
self, scene, generation_options: GenerationOptions | None = None
):
end = None
emission = BuildArchiveEmission(
agent=self,
generation_options=generation_options,
)
await talemate.emit.async_signals.get("agent.summarization.before_build_archive").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.before_build_archive"
).send(emission)
if not self.actions["archive"].enabled:
return
@@ -260,7 +267,7 @@ class SummarizeAgent(
extra_context = [
entry["text"] for entry in scene.archived_history[-num_previous:]
]
else:
extra_context = None
@@ -283,7 +290,10 @@ class SummarizeAgent(
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
if isinstance(dialogue, (DirectorMessage, ContextInvestigationMessage, ReinforcementMessage)):
if isinstance(
dialogue,
(DirectorMessage, ContextInvestigationMessage, ReinforcementMessage),
):
# these messages are not part of the dialogue and should not be summarized
if i == start:
start += 1
@@ -292,7 +302,7 @@ class SummarizeAgent(
if isinstance(dialogue, TimePassageMessage):
log.debug("build_archive", time_passage_message=dialogue)
ts = util.iso8601_add(ts, dialogue.ts)
if i == start:
log.debug(
"build_archive",
@@ -347,23 +357,28 @@ class SummarizeAgent(
if str(line) in terminating_line:
break
adjusted_dialogue.append(line)
# if the difference between start and end is less than 4, ignore the termination
if len(adjusted_dialogue) > 4:
dialogue_entries = adjusted_dialogue
end = start + len(dialogue_entries) - 1
else:
log.debug("build_archive", message="Ignoring termination", start=start, end=end, adjusted_dialogue=adjusted_dialogue)
log.debug(
"build_archive",
message="Ignoring termination",
start=start,
end=end,
adjusted_dialogue=adjusted_dialogue,
)
if dialogue_entries:
if not extra_context:
# prepend scene intro to dialogue
dialogue_entries.insert(0, scene.intro)
summarized = None
retries = 5
while not summarized and retries > 0:
summarized = await self.summarize(
"\n".join(map(str, dialogue_entries)),
@@ -371,7 +386,7 @@ class SummarizeAgent(
generation_options=generation_options,
)
retries -= 1
if not summarized:
raise IOError(f"Failed to summarize dialogue: {dialogue_entries}")
@@ -382,13 +397,17 @@ class SummarizeAgent(
# determine the appropriate timestamp for the summarization
await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts))
scene.ts=ts
await scene.push_archive(
ArchiveEntry(text=summarized, start=start, end=end, ts=ts)
)
scene.ts = ts
scene.emit_status()
await talemate.emit.async_signals.get("agent.summarization.after_build_archive").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.after_build_archive"
).send(emission)
return True
@set_processing
@@ -408,23 +427,23 @@ class SummarizeAgent(
return response
@set_processing
async def find_natural_scene_termination(self, event_chunks:list[str]) -> list[list[str]]:
async def find_natural_scene_termination(
self, event_chunks: list[str]
) -> list[list[str]]:
"""
Will analyze a list of events and return them split into
sub-lists at natural scene termination points.
"""
# scan through event chunks and split into paragraphs
rebuilt_chunks = []
for chunk in event_chunks:
paragraphs = [
p.strip() for p in chunk.split("\n") if p.strip()
]
paragraphs = [p.strip() for p in chunk.split("\n") if p.strip()]
rebuilt_chunks.extend(paragraphs)
event_chunks = rebuilt_chunks
response = await Prompt.request(
"summarizer.find-natural-scene-termination-events",
self.client,
@@ -436,38 +455,42 @@ class SummarizeAgent(
},
)
response = response.strip()
items = util.extract_list(response)
# will be a list of
# will be a list of
# ["Progress 1", "Progress 12", "Progress 323", ...]
# convert to a list of just numbers
# convert to a list of just numbers
numbers = []
for item in items:
match = re.match(r"Progress (\d+)", item.strip())
if match:
numbers.append(int(match.group(1)))
# make sure it's unique and sorted
numbers = sorted(list(set(numbers)))
result = []
prev_number = 0
for number in numbers:
result.append(event_chunks[prev_number:number+1])
prev_number = number+1
#result = {
result.append(event_chunks[prev_number : number + 1])
prev_number = number + 1
# result = {
# "selected": event_chunks[:number+1],
# "remaining": event_chunks[number+1:]
#}
log.debug("find_natural_scene_termination", response=response, result=result, numbers=numbers)
# }
log.debug(
"find_natural_scene_termination",
response=response,
result=result,
numbers=numbers,
)
return result
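Note: the "Progress N" parsing and chunk splitting above, isolated into a runnable sketch:

import re

def split_at_terminations(chunks: list[str], items: list[str]) -> list[list[str]]:
    # extract the numbers from "Progress N" items, unique and sorted
    numbers = sorted({int(m.group(1))
                      for item in items
                      if (m := re.match(r"Progress (\d+)", item.strip()))})
    result, prev = [], 0
    for number in numbers:
        result.append(chunks[prev:number + 1])
        prev = number + 1
    return result

print(split_at_terminations(["a", "b", "c", "d"], ["Progress 1", "Progress 3"]))
# [['a', 'b'], ['c', 'd']]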
@set_processing
async def summarize(
@@ -481,9 +504,9 @@ class SummarizeAgent(
"""
Summarize the given text
"""
response_length = 1024
template_vars = {
"dialogue": text,
"scene": self.scene,
@@ -500,7 +523,7 @@ class SummarizeAgent(
"analyze_chunks": self.layered_history_analyze_chunks,
"response_length": response_length,
}
emission = SummarizeEmission(
agent=self,
text=text,
@@ -511,43 +534,46 @@ class SummarizeAgent(
summarization_history=extra_context or [],
summarization_type="dialogue",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.before"
).send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"summarizer.summarize-dialogue",
"summarizer.summarize-dialogue",
self.client,
f"summarize_{response_length}",
vars=template_vars,
dedupe_enabled=False
dedupe_enabled=False,
)
log.debug(
"summarize", dialogue_length=len(text), summarized_length=len(response)
)
try:
summary = response.split("SUMMARY:")[1].strip()
except Exception as e:
log.error("summarize failed", response=response, exc=e)
return ""
# capitalize first letter
try:
summary = summary[0].upper() + summary[1:]
except IndexError:
pass
emission.response = self.clean_result(summary)
await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.after"
).send(emission)
summary = emission.response
return self.clean_result(summary)
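Note: the SUMMARY: extraction and capitalization above, as a standalone sketch; everything after the marker is kept, and a missing marker falls back to an empty string:

def extract_summary(response: str) -> str:
    try:
        summary = response.split("SUMMARY:")[1].strip()
    except IndexError:
        return ""
    # capitalize first letter, safe on empty strings
    return summary[:1].upper() + summary[1:]

print(extract_summary("ANALYSIS: ...\nSUMMARY: the heroes regroup."))
# The heroes regroup.
print(extract_summary("no marker here"))  # ""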
@set_processing
async def summarize_events(
@@ -563,15 +589,14 @@ class SummarizeAgent(
"""
Summarize the given text
"""
if not extra_context:
extra_context = ""
mentioned_characters: list["Character"] = self.scene.parse_characters_from_text(
text + extra_context,
exclude_active=True
text + extra_context, exclude_active=True
)
template_vars = {
"dialogue": text,
"scene": self.scene,
@@ -585,7 +610,7 @@ class SummarizeAgent(
"response_length": response_length,
"mentioned_characters": mentioned_characters,
}
emission = SummarizeEmission(
agent=self,
text=text,
@@ -596,46 +621,62 @@ class SummarizeAgent(
summarization_history=[extra_context] if extra_context else [],
summarization_type="events",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.before"
).send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"summarizer.summarize-events",
"summarizer.summarize-events",
self.client,
f"summarize_{response_length}",
vars=template_vars,
dedupe_enabled=False
dedupe_enabled=False,
)
response = response.strip()
response = response.replace('"', "")
log.debug(
"layered_history_summarize", original_length=len(text), summarized_length=len(response)
"layered_history_summarize",
original_length=len(text),
summarized_length=len(response),
)
# clean up analyzation (remove analyzation text)
if self.layered_history_analyze_chunks:
# remove all lines that begin with "ANALYSIS OF CHUNK \d+:"
response = "\n".join([line for line in response.split("\n") if not line.startswith("ANALYSIS OF CHUNK")])
response = "\n".join(
[
line
for line in response.split("\n")
if not line.startswith("ANALYSIS OF CHUNK")
]
)
# strip all occurrences of "CHUNK \d+: " from the summary
response = re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response)
# capitalize first letter
try:
response = response[0].upper() + response[1:]
except IndexError:
pass
emission.response = self.clean_result(response)
await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.after"
).send(emission)
response = emission.response
log.debug("summarize_events", original_length=len(text), summarized_length=len(response))
log.debug(
"summarize_events",
original_length=len(text),
summarized_length=len(response),
)
return self.clean_result(response)
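Note: the analysis-chunk cleanup above, isolated: drop analysis lines, then strip any "CHUNK n:" / "CHAPTER n:" prefixes from the remaining summary.

import re

def clean_chunk_markers(response: str) -> str:
    lines = [line for line in response.split("\n")
             if not line.startswith("ANALYSIS OF CHUNK")]
    response = "\n".join(lines)
    return re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response)

print(clean_chunk_markers("ANALYSIS OF CHUNK 1: ...\nCHUNK 1: It rains."))
# It rains.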

View File

@@ -16,25 +16,44 @@ from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.agents.context import active_agent
from talemate.agents.base import RagBuildSubInstructionEmission
from contextvars import ContextVar
if TYPE_CHECKING:
from talemate.tale_mate import Character
log = structlog.get_logger()
## CONTEXT
scene_analysis_disabled_context = ContextVar("scene_analysis_disabled", default=False)
class SceneAnalysisDisabled:
"""
Context manager to disable scene analysis during specific agent actions.
"""
def __enter__(self):
self.token = scene_analysis_disabled_context.set(True)
def __exit__(self, _exc_type, _exc_value, _traceback):
scene_analysis_disabled_context.reset(self.token)
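Note: a usage sketch for the context manager above; any narration triggered inside the with-block sees the contextvar set and skips analysis:

from contextvars import ContextVar

disabled = ContextVar("scene_analysis_disabled", default=False)

class SceneAnalysisDisabled:
    def __enter__(self):
        self.token = disabled.set(True)
    def __exit__(self, _exc_type, _exc_value, _traceback):
        disabled.reset(self.token)

with SceneAnalysisDisabled():
    print(disabled.get())  # True
print(disabled.get())      # False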
talemate.emit.async_signals.register(
"agent.summarization.scene_analysis.before",
"agent.summarization.scene_analysis.after",
"agent.summarization.scene_analysis.cached",
"agent.summarization.scene_analysis.before_deep_analysis",
"agent.summarization.scene_analysis.after_deep_analysis",
"agent.summarization.scene_analysis.after_deep_analysis",
)
@dataclasses.dataclass
class SceneAnalysisEmission(AgentTemplateEmission):
analysis_type: str | None = None
@dataclasses.dataclass
class SceneAnalysisDeepAnalysisEmission(AgentEmission):
analysis: str
@@ -42,15 +61,16 @@ class SceneAnalysisDeepAnalysisEmission(AgentEmission):
analysis_sub_type: str | None = None
max_content_investigations: int = 1
character: "Character" = None
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
class SceneAnalyzationMixin:
"""
Summarizer agent mixin that provides functionality for scene analyzation.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["analyze_scene"] = AgentAction(
@@ -71,8 +91,8 @@ class SceneAnalyzationMixin:
choices=[
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"}
]
{"label": "Long (1024)", "value": "1024"},
],
),
"for_conversation": AgentActionConfig(
type="bool",
@@ -109,196 +129,221 @@ class SceneAnalyzationMixin:
value=True,
quick_toggle=True,
),
}
},
)
# config property helpers
@property
def analyze_scene(self) -> bool:
return self.actions["analyze_scene"].enabled
@property
def analysis_length(self) -> int:
return int(self.actions["analyze_scene"].config["analysis_length"].value)
@property
def cache_analysis(self) -> bool:
return self.actions["analyze_scene"].config["cache_analysis"].value
@property
def deep_analysis(self) -> bool:
return self.actions["analyze_scene"].config["deep_analysis"].value
@property
def deep_analysis_max_context_investigations(self) -> int:
return self.actions["analyze_scene"].config["deep_analysis_max_context_investigations"].value
return (
self.actions["analyze_scene"]
.config["deep_analysis_max_context_investigations"]
.value
)
@property
def analyze_scene_for_conversation(self) -> bool:
return self.actions["analyze_scene"].config["for_conversation"].value
@property
def analyze_scene_for_narration(self) -> bool:
return self.actions["analyze_scene"].config["for_narration"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
self.on_inject_instructions
)
talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).connect(self.on_inject_instructions)
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
self.on_inject_instructions
)
talemate.emit.async_signals.get("agent.summarization.rag_build_sub_instruction").connect(
self.on_rag_build_sub_instruction
)
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
self.on_editor_revision_analysis_before
)
talemate.emit.async_signals.get(
"agent.summarization.rag_build_sub_instruction"
).connect(self.on_rag_build_sub_instruction)
talemate.emit.async_signals.get(
"agent.editor.revision-analysis.before"
).connect(self.on_editor_revision_analysis_before)
async def on_inject_instructions(
self,
emission:ConversationAgentEmission | NarratorAgentEmission,
self,
emission: ConversationAgentEmission | NarratorAgentEmission,
):
"""
Injects instructions into the conversation.
"""
if isinstance(emission, ConversationAgentEmission):
emission_type = "conversation"
elif isinstance(emission, NarratorAgentEmission):
emission_type = "narration"
else:
raise ValueError("Invalid emission type.")
if not self.analyze_scene:
return
try:
if scene_analysis_disabled_context.get():
log.debug(
"on_inject_instructions: scene analysis disabled through context",
emission=emission,
)
return
except LookupError:
pass
analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}")
if not analyze_scene_for_type:
return
analysis = None
# use self.set_scene_states and self.get_scene_state to store
# cached analysis in scene states
if self.cache_analysis:
analysis = await self.get_cached_analysis(emission_type)
if analysis:
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").send(
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.cached"
).send(
SceneAnalysisEmission(
agent=self,
analysis_type=emission_type,
response=analysis,
agent=self,
analysis_type=emission_type,
response=analysis,
template_vars={
"character": emission.character if hasattr(emission, "character") else None,
"character": emission.character
if hasattr(emission, "character")
else None,
},
dynamic_instructions=emission.dynamic_instructions
dynamic_instructions=emission.dynamic_instructions,
)
)
if not analysis and self.analyze_scene:
# analyze the scene for the next action
analysis = await self.analyze_scene_for_next_action(
emission_type,
emission.character if hasattr(emission, "character") else None,
self.analysis_length
self.analysis_length,
)
await self.set_cached_analysis(emission_type, analysis)
if not analysis:
return
emission.dynamic_instructions.append(
DynamicInstruction(
title="Scene Analysis",
content=analysis
)
DynamicInstruction(title="Scene Analysis", content=analysis)
)
async def on_rag_build_sub_instruction(self, emission:"RagBuildSubInstructionEmission"):
async def on_rag_build_sub_instruction(
self, emission: "RagBuildSubInstructionEmission"
):
"""
Injects the sub instruction into the analysis.
"""
sub_instruction = await self.analyze_scene_rag_build_sub_instruction()
if sub_instruction:
emission.sub_instruction = sub_instruction
async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
last_analysis = self.get_scene_state("scene_analysis")
if last_analysis:
emission.dynamic_instructions.append(DynamicInstruction(
title="Scene Analysis",
content=last_analysis
))
emission.dynamic_instructions.append(
DynamicInstruction(title="Scene Analysis", content=last_analysis)
)
# helpers
async def get_cached_analysis(self, typ:str) -> str | None:
async def get_cached_analysis(self, typ: str) -> str | None:
"""
Returns the cached analysis for the given type.
"""
cached_analysis = self.get_scene_state(f"cached_analysis_{typ}")
if not cached_analysis:
return None
fingerprint = self.context_fingerpint()
fingerprint = self.context_fingerprint()
log.debug(
"get_cached_analysis",
fingerprint=fingerprint,
cached_analysis_fp=cached_analysis.get("fp"),
match=cached_analysis.get("fp") == fingerprint,
)
if cached_analysis.get("fp") == fingerprint:
return cached_analysis["guidance"]
return None
async def set_cached_analysis(self, typ:str, analysis:str):
async def set_cached_analysis(self, typ: str, analysis: str):
"""
Sets the cached analysis for the given type.
"""
fingerprint = self.context_fingerpint()
fingerprint = self.context_fingerprint()
self.set_scene_states(
**{f"cached_analysis_{typ}": {
"fp": fingerprint,
"guidance": analysis,
}}
**{
f"cached_analysis_{typ}": {
"fp": fingerprint,
"guidance": analysis,
}
}
)
async def analyze_scene_sub_type(self, analysis_type:str) -> str:
async def analyze_scene_sub_type(self, analysis_type: str) -> str:
"""
Analyzes the active agent context to figure out the appropriate sub type
"""
fn = getattr(self, f"analyze_scene_{analysis_type}_sub_type", None)
if fn:
return await fn()
return ""
async def analyze_scene_narration_sub_type(self) -> str:
"""
Analyzes the active agent context to figure out the appropriate sub type
for narration analysis. (progress, query etc.)
"""
active_agent_context = active_agent.get()
if not active_agent_context:
return "progress"
state = active_agent_context.state
if state.get("narrator__query_narration"):
return "query"
if state.get("narrator__sensory_narration"):
return "sensory"
@@ -306,70 +351,85 @@ class SceneAnalyzationMixin:
if state.get("narrator__character"):
return "visual-character"
return "visual"
if state.get("narrator__fn_narrate_character_entry"):
return "progress-character-entry"
if state.get("narrator__fn_narrate_character_exit"):
return "progress-character-exit"
return "progress"
async def analyze_scene_rag_build_sub_instruction(self):
"""
Analyzes the active agent context to figure out the appropriate sub type
for rag build sub instruction.
"""
active_agent_context = active_agent.get()
if not active_agent_context:
return ""
state = active_agent_context.state
if state.get("narrator__query_narration"):
query = state["narrator__query"]
if query.endswith("?"):
return "Answer the following question: " + query
else:
return query
narrative_direction = state.get("narrator__narrative_direction")
if state.get("narrator__sensory_narration") and narrative_direction:
return "Collect information that aids in describing the following sensory experience: " + narrative_direction
return (
"Collect information that aids in describing the following sensory experience: "
+ narrative_direction
)
if state.get("narrator__visual_narration") and narrative_direction:
return "Collect information that aids in describing the following visual experience: " + narrative_direction
return (
"Collect information that aids in describing the following visual experience: "
+ narrative_direction
)
if state.get("narrator__fn_narrate_character_entry") and narrative_direction:
return "Collect information that aids in describing the following character entry: " + narrative_direction
return (
"Collect information that aids in describing the following character entry: "
+ narrative_direction
)
if state.get("narrator__fn_narrate_character_exit") and narrative_direction:
return "Collect information that aids in describing the following character exit: " + narrative_direction
return (
"Collect information that aids in describing the following character exit: "
+ narrative_direction
)
if state.get("narrator__fn_narrate_progress"):
return "Collect information that aids in progressing the story: " + narrative_direction
return (
"Collect information that aids in progressing the story: "
+ narrative_direction
)
return ""
# actions
@set_processing
async def analyze_scene_for_next_action(self, typ:str, character:"Character"=None, length:int=1024) -> str:
async def analyze_scene_for_next_action(
self, typ: str, character: "Character" = None, length: int = 1024
) -> str:
"""
Analyzes the current scene progress and gives a suggestion for the next action
to be taken by the given actor.
"""
# deep analysis is only available if the scene has a layered history
# and context investigation is enabled
deep_analysis = (self.deep_analysis and self.context_investigation_available)
deep_analysis = self.deep_analysis and self.context_investigation_available
analysis_sub_type = await self.analyze_scene_sub_type(typ)
template_vars = {
"max_tokens": self.client.max_token_length,
"scene": self.scene,
@@ -381,58 +441,60 @@ class SceneAnalyzationMixin:
"analysis_type": typ,
"analysis_sub_type": analysis_sub_type,
}
emission = SceneAnalysisEmission(agent=self, template_vars=template_vars, analysis_type=typ)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before").send(
emission
emission = SceneAnalysisEmission(
agent=self, template_vars=template_vars, analysis_type=typ
)
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before"
).send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"summarizer.analyze-scene-for-next-{typ}",
self.client,
f"investigate_{length}",
vars=template_vars,
)
response = strip_partial_sentences(response)
if not response.strip():
return response
if deep_analysis:
emission = SceneAnalysisDeepAnalysisEmission(
agent=self,
agent=self,
analysis=response,
analysis_type=typ,
analysis_sub_type=analysis_sub_type,
character=character,
max_content_investigations=self.deep_analysis_max_context_investigations
max_content_investigations=self.deep_analysis_max_context_investigations,
)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").send(
emission
)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after_deep_analysis").send(
emission
)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").send(
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before_deep_analysis"
).send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after_deep_analysis"
).send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after"
).send(
SceneAnalysisEmission(
agent=self,
template_vars=template_vars,
response=response,
agent=self,
template_vars=template_vars,
response=response,
analysis_type=typ,
dynamic_instructions=emission.dynamic_instructions
dynamic_instructions=emission.dynamic_instructions,
)
)
self.set_context_states(scene_analysis=response)
self.set_scene_states(scene_analysis=response)
return response
return response

View File

@@ -22,14 +22,12 @@ if TYPE_CHECKING:
log = structlog.get_logger()
class ContextInvestigationMixin:
"""
Summarizer agent mixin that provides functionality for context investigation
through the layered history of the scene.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["context_investigation"] = AgentAction(
@@ -52,7 +50,7 @@ class ContextInvestigationMixin:
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
]
],
),
"update_method": AgentActionConfig(
type="text",
@@ -62,57 +60,56 @@ class ContextInvestigationMixin:
choices=[
{"label": "Replace", "value": "replace"},
{"label": "Smart Merge", "value": "merge"},
]
)
}
],
),
},
)
# config property helpers
@property
def context_investigation_enabled(self):
return self.actions["context_investigation"].enabled
@property
def context_investigation_available(self):
return (
self.context_investigation_enabled and
self.layered_history_available
)
return self.context_investigation_enabled and self.layered_history_available
@property
def context_investigation_answer_length(self) -> int:
return int(self.actions["context_investigation"].config["answer_length"].value)
@property
def context_investigation_update_method(self) -> str:
return self.actions["context_investigation"].config["update_method"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).connect(self.on_inject_context_investigation)
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get("agent.director.guide.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").connect(
self.on_summarization_scene_analysis_before_deep_analysis
)
async def on_summarization_scene_analysis_before_deep_analysis(self, emission:SceneAnalysisDeepAnalysisEmission):
talemate.emit.async_signals.get(
"agent.director.guide.inject_instructions"
).connect(self.on_inject_context_investigation)
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before_deep_analysis"
).connect(self.on_summarization_scene_analysis_before_deep_analysis)
async def on_summarization_scene_analysis_before_deep_analysis(
self, emission: SceneAnalysisDeepAnalysisEmission
):
"""
Handles context investigation for deep scene analysis.
"""
if not self.context_investigation_enabled:
return
suggested_investigations = await self.suggest_context_investigations(
emission.analysis,
emission.analysis_type,
@@ -120,67 +117,72 @@ class ContextInvestigationMixin:
max_calls=emission.max_content_investigations,
character=emission.character,
)
response = emission.analysis
ci_calls:list[focal.Call] = await self.request_context_investigations(
suggested_investigations,
max_calls=emission.max_content_investigations
ci_calls: list[focal.Call] = await self.request_context_investigations(
suggested_investigations, max_calls=emission.max_content_investigations
)
log.debug("analyze_scene_for_next_action", ci_calls=ci_calls)
# append call queries and answers to the response
ci_text = []
for ci_call in ci_calls:
try:
ci_text.append(f"{ci_call.arguments['query']}\n{ci_call.result}")
except KeyError as e:
log.error("analyze_scene_for_next_action", error="Missing key in call", ci_call=ci_call)
context_investigation="\n\n".join(ci_text if ci_text else [])
except KeyError:
log.error(
"analyze_scene_for_next_action",
error="Missing key in call",
ci_call=ci_call,
)
context_investigation = "\n\n".join(ci_text if ci_text else [])
current_context_investigation = self.get_scene_state("context_investigation")
if current_context_investigation and context_investigation:
if self.context_investigation_update_method == "merge":
context_investigation = await self.update_context_investigation(
current_context_investigation, context_investigation, response
)
self.set_scene_states(context_investigation=context_investigation)
self.set_context_states(context_investigation=context_investigation)
async def on_inject_context_investigation(self, emission:ConversationAgentEmission | NarratorAgentEmission):
async def on_inject_context_investigation(
self, emission: ConversationAgentEmission | NarratorAgentEmission
):
"""
Injects context investigation into the conversation.
"""
if not self.context_investigation_enabled:
return
context_investigation = self.get_scene_state("context_investigation")
log.debug("summarizer.on_inject_context_investigation", context_investigation=context_investigation, emission=emission)
log.debug(
"summarizer.on_inject_context_investigation",
context_investigation=context_investigation,
emission=emission,
)
if context_investigation:
emission.dynamic_instructions.append(
DynamicInstruction(
title="Context Investigation",
content=context_investigation
title="Context Investigation", content=context_investigation
)
)
# methods
@set_processing
async def suggest_context_investigations(
self,
analysis:str,
analysis_type:str,
analysis_sub_type:str="",
max_calls:int=3,
character:"Character"=None,
analysis: str,
analysis_type: str,
analysis_sub_type: str = "",
max_calls: int = 3,
character: "Character" = None,
) -> str:
template_vars = {
"max_tokens": self.client.max_token_length,
"scene": self.scene,
@@ -192,111 +194,119 @@ class ContextInvestigationMixin:
"analysis_type": analysis_type,
"analysis_sub_type": analysis_sub_type,
}
if not analysis_sub_type:
template = f"summarizer.suggest-context-investigations-for-{analysis_type}"
else:
template = f"summarizer.suggest-context-investigations-for-{analysis_type}-{analysis_sub_type}"
log.debug("summarizer.suggest_context_investigations", template=template, template_vars=template_vars)
log.debug(
"summarizer.suggest_context_investigations",
template=template,
template_vars=template_vars,
)
response = await Prompt.request(
template,
self.client,
"investigate_512",
vars=template_vars,
)
return response.strip()
@set_processing
async def investigate_context(
self,
layer:int,
index:int,
query:str,
analysis:str="",
max_calls:int=3,
pad_entries:int=5,
self,
layer: int,
index: int,
query: str,
analysis: str = "",
max_calls: int = 3,
pad_entries: int = 5,
) -> str:
"""
Processes a context investigation.
Arguments:
- layer: The layer to investigate
- index: The index in the layer to investigate
- query: The query to investigate
- analysis: Scene analysis text
- pad_entries: if > 0 will pad the entries with the given number of entries before and after the start and end index
"""
log.debug("summarizer.investigate_context", layer=layer, index=index, query=query)
log.debug(
"summarizer.investigate_context", layer=layer, index=index, query=query
)
entry = self.scene.layered_history[layer][index]
layer_to_investigate = layer - 1
start = max(entry["start"] - pad_entries, 0)
end = entry["end"] + pad_entries + 1
if layer_to_investigate == -1:
entries = self.scene.archived_history[start:end]
else:
entries = self.scene.layered_history[layer_to_investigate][start:end]
async def answer(query:str, instructions:str) -> str:
log.debug("Answering context investigation", query=query, instructions=answer)
async def answer(query: str, instructions: str) -> str:
log.debug(
"Answering context investigation", query=query, instructions=answer
)
world_state = get_agent("world_state")
return await world_state.analyze_history_and_follow_instructions(
entries,
f"{query}\n{instructions}",
analysis=analysis,
response_length=self.context_investigation_answer_length
response_length=self.context_investigation_answer_length,
)
async def investigate_context(chapter_number:str, query:str) -> str:
async def investigate_context(chapter_number: str, query: str) -> str:
# look for \d.\d in the chapter number, extract as layer and index
match = re.match(r"(\d+)\.(\d+)", chapter_number)
if not match:
log.error("summarizer.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
log.error(
"summarizer.investigate_context",
error="Invalid chapter number",
chapter_number=chapter_number,
)
return ""
layer = int(match.group(1))
index = int(match.group(2))
return await self.investigate_context(layer-1, index-1, query, analysis=analysis, max_calls=max_calls)
return await self.investigate_context(
layer - 1, index - 1, query, analysis=analysis, max_calls=max_calls
)
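# e.g. a chapter_number of "2.3" (hypothetical) parses to layer=2, index=3
# and is forwarded zero-based as self.investigate_context(1, 2, ...).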
async def abort():
log.debug("Aborting context investigation")
focal_handler: focal.Focal = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="investigate_context",
arguments = [
arguments=[
focal.Argument(name="chapter_number", type="str"),
focal.Argument(name="query", type="str")
focal.Argument(name="query", type="str"),
],
fn=investigate_context
fn=investigate_context,
),
focal.Callback(
name="answer",
arguments = [
arguments=[
focal.Argument(name="instructions", type="str"),
focal.Argument(name="query", type="str")
focal.Argument(name="query", type="str"),
],
fn=answer
fn=answer,
),
focal.Callback(
name="abort",
fn=abort
)
focal.Callback(name="abort", fn=abort),
],
max_calls=max_calls,
scene=self.scene,
@@ -307,84 +317,86 @@ class ContextInvestigationMixin:
entries=entries,
analysis=analysis,
)
await focal_handler.request(
"summarizer.investigate-context",
)
log.debug("summarizer.investigate_context", calls=focal_handler.state.calls)
return focal_handler.state.calls
@set_processing
async def request_context_investigations(
self,
analysis:str,
max_calls:int=3,
self,
analysis: str,
max_calls: int = 3,
) -> list[focal.Call]:
"""
Requests context investigations for the given analysis.
"""
async def abort():
log.debug("Aborting context investigations")
async def investigate_context(chapter_number:str, query:str) -> str:
async def investigate_context(chapter_number: str, query: str) -> str:
# look for \d.\d in the chapter number, extract as layer and index
match = re.match(r"(\d+)\.(\d+)", chapter_number)
if not match:
log.error("summarizer.request_context_investigations.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
log.error(
"summarizer.request_context_investigations.investigate_context",
error="Invalid chapter number",
chapter_number=chapter_number,
)
return ""
layer = int(match.group(1))
index = int(match.group(2))
num_layers = len(self.scene.layered_history)
return await self.investigate_context(num_layers - layer, index-1, query, analysis, max_calls=max_calls)
return await self.investigate_context(
num_layers - layer, index - 1, query, analysis, max_calls=max_calls
)
focal_handler: focal.Focal = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="investigate_context",
arguments = [
arguments=[
focal.Argument(name="chapter_number", type="str"),
focal.Argument(name="query", type="str")
focal.Argument(name="query", type="str"),
],
fn=investigate_context
fn=investigate_context,
),
focal.Callback(
name="abort",
fn=abort
)
focal.Callback(name="abort", fn=abort),
],
max_calls=max_calls,
scene=self.scene,
text=analysis
text=analysis,
)
await focal_handler.request(
"summarizer.request-context-investigation",
)
log.debug("summarizer.request_context_investigations", calls=focal_handler.state.calls)
return focal.collect_calls(
focal_handler.state.calls,
nested=True,
filter=lambda c: c.name == "answer"
log.debug(
"summarizer.request_context_investigations", calls=focal_handler.state.calls
)
# return focal_handler.state.calls
return focal.collect_calls(
focal_handler.state.calls, nested=True, filter=lambda c: c.name == "answer"
)
# return focal_handler.state.calls
@set_processing
async def update_context_investigation(
self,
current_context_investigation:str,
new_context_investigation:str,
analysis:str,
current_context_investigation: str,
new_context_investigation: str,
analysis: str,
):
response = await Prompt.request(
"summarizer.update-context-investigation",
@@ -398,5 +410,5 @@ class ContextInvestigationMixin:
"max_tokens": self.client.max_token_length,
},
)
return response.strip()

View File

@@ -1,5 +1,5 @@
import structlog
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
from talemate.agents.base import (
set_processing,
AgentAction,
@@ -24,11 +24,12 @@ talemate.emit.async_signals.register(
"agent.summarization.layered_history.finalize",
)
@dataclasses.dataclass
class LayeredHistoryFinalizeEmission(AgentEmission):
entry: LayeredArchiveEntry | None = None
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])
@property
def response(self) -> str | None:
return self.entry.text if self.entry else None
@@ -38,22 +39,23 @@ class LayeredHistoryFinalizeEmission(AgentEmission):
if self.entry:
self.entry.text = value
class SummaryLongerThanOriginalError(ValueError):
def __init__(self, original_length:int, summarized_length:int):
def __init__(self, original_length: int, summarized_length: int):
self.original_length = original_length
self.summarized_length = summarized_length
super().__init__(f"Summarized text is longer than original text: {summarized_length} > {original_length}")
super().__init__(
f"Summarized text is longer than original text: {summarized_length} > {original_length}"
)
class LayeredHistoryMixin:
"""
Summarizer agent mixin that provides functionality for maintaining a layered history.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["layered_history"] = AgentAction(
enabled=True,
container=True,
@@ -80,7 +82,7 @@ class LayeredHistoryMixin:
max=5,
step=1,
value=3,
),
),
"max_process_tokens": AgentActionConfig(
type="number",
label="Maximum tokens to process",
@@ -116,69 +118,71 @@ class LayeredHistoryMixin:
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
{"label": "Exhaustive (2048)", "value": "2048"},
]
],
),
},
)
# config property helpers
@property
def layered_history_enabled(self):
return self.actions["layered_history"].enabled
@property
def layered_history_threshold(self):
return self.actions["layered_history"].config["threshold"].value
@property
def layered_history_max_process_tokens(self):
return self.actions["layered_history"].config["max_process_tokens"].value
@property
def layered_history_max_layers(self):
return self.actions["layered_history"].config["max_layers"].value
@property
def layered_history_chunk_size(self) -> int:
return self.actions["layered_history"].config["chunk_size"].value
@property
def layered_history_analyze_chunks(self) -> bool:
return self.actions["layered_history"].config["analyze_chunks"].value
@property
def layered_history_response_length(self) -> int:
return int(self.actions["layered_history"].config["response_length"].value)
@property
def layered_history_available(self):
return self.layered_history_enabled and self.scene.layered_history and self.scene.layered_history[0]
return (
self.layered_history_enabled
and self.scene.layered_history
and self.scene.layered_history[0]
)
# signals
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.summarization.after_build_archive").connect(
self.on_after_build_archive
)
async def on_after_build_archive(self, emission:"BuildArchiveEmission"):
talemate.emit.async_signals.get(
"agent.summarization.after_build_archive"
).connect(self.on_after_build_archive)
async def on_after_build_archive(self, emission: "BuildArchiveEmission"):
"""
After the archive has been built, we will update the layered history.
"""
if self.layered_history_enabled:
await self.summarize_to_layered_history(
generation_options=emission.generation_options
)
# helpers
async def _lh_split_and_summarize_chunks(
self,
self,
chunks: list[dict],
extra_context: str,
generation_options: GenerationOptions | None = None,
@@ -189,21 +193,29 @@ class LayeredHistoryMixin:
"""
summaries = []
current_chunk = chunks.copy()
while current_chunk:
partial_chunk = []
max_process_tokens = self.layered_history_max_process_tokens
# Build partial chunk up to max_process_tokens
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
while (
current_chunk
and util.count_tokens(
"\n\n".join(chunk["text"] for chunk in partial_chunk)
)
< max_process_tokens
):
partial_chunk.append(current_chunk.pop(0))
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
log.debug("_split_and_summarize_chunks",
tokens_in_chunk=util.count_tokens(text_to_summarize),
max_process_tokens=max_process_tokens)
text_to_summarize = "\n\n".join(chunk["text"] for chunk in partial_chunk)
log.debug(
"_split_and_summarize_chunks",
tokens_in_chunk=util.count_tokens(text_to_summarize),
max_process_tokens=max_process_tokens,
)
summary_text = await self.summarize_events(
text_to_summarize,
extra_context=extra_context + "\n\n".join(summaries),
@@ -213,9 +225,9 @@ class LayeredHistoryMixin:
chunk_size=self.layered_history_chunk_size,
)
summaries.append(summary_text)
return summaries
def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
"""
Validates that the summarized text is not longer than the original.
@@ -224,17 +236,19 @@ class LayeredHistoryMixin:
summarized_length = util.count_tokens(summaries)
if summarized_length > original_length:
raise SummaryLongerThanOriginalError(original_length, summarized_length)
log.debug("_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length)
log.debug(
"_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length,
)
def _lh_build_extra_context(self, layer_index: int) -> str:
"""
Builds extra context from compiled layered history for the given layer.
"""
return "\n\n".join(self.compile_layered_history(layer_index))
def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
"""
Extracts timestamps from a chunk of entries.
@@ -242,144 +256,156 @@ class LayeredHistoryMixin:
"""
if not chunk:
return "PT1S", "PT1S", "PT1S"
ts = chunk[0].get('ts', 'PT1S')
ts_start = chunk[0].get('ts_start', ts)
ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts))
ts = chunk[0].get("ts", "PT1S")
ts_start = chunk[0].get("ts_start", ts)
ts_end = chunk[-1].get("ts_end", chunk[-1].get("ts", ts))
return ts, ts_start, ts_end
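# Sketch of the expected shape (hypothetical chunk): for
# [{"ts": "PT5M", "ts_start": "PT4M"}, {"ts": "PT9M"}] this returns
# ("PT5M", "PT4M", "PT9M"); missing keys fall back to neighbouring
# timestamps, and an empty chunk yields ("PT1S", "PT1S", "PT1S").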
async def _lh_finalize_archive_entry(
self,
self,
entry: LayeredArchiveEntry,
summarization_history: list[str] | None = None,
) -> LayeredArchiveEntry:
"""
Finalizes an archive entry by summarizing it and adding it to the layered history.
"""
emission = LayeredHistoryFinalizeEmission(
agent=self,
entry=entry,
summarization_history=summarization_history,
)
await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.layered_history.finalize"
).send(emission)
return emission.entry
# methods
def compile_layered_history(
self,
for_layer_index:int = None,
as_objects:bool=False,
include_base_layer:bool=False,
max:int = None,
self,
for_layer_index: int = None,
as_objects: bool = False,
include_base_layer: bool = False,
max: int = None,
base_layer_end_id: str | None = None,
) -> list[str]:
"""
Starts at the last layer and compiles the layered history into a single
list of events.
We iterate backwards from the most condensed layer, so the most granular
entries end up last in the compiled list.
Each preceding layer starts from the end of the next layer.
"""
layered_history = self.scene.layered_history
compiled = []
next_layer_start = None
len_layered_history = len(layered_history)
for i in range(len_layered_history - 1, -1, -1):
if for_layer_index is not None:
if i < for_layer_index:
break
log.debug("compilelayered history", i=i, next_layer_start=next_layer_start)
if not layered_history[i]:
continue
entry_num = 1
for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]:
for layered_history_entry in layered_history[i][
next_layer_start if next_layer_start is not None else 0 :
]:
if base_layer_end_id:
contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry(
index=0,
layer=i+1,
**layered_history_entry)
contained = entry_contained(
self.scene,
base_layer_end_id,
HistoryEntry(index=0, layer=i + 1, **layered_history_entry),
)
if contained:
log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id)
log.debug(
"compile_layered_history",
contained=True,
base_layer_end_id=base_layer_end_id,
)
break
text = f"{layered_history_entry['text']}"
if for_layer_index == i and max is not None and max <= layered_history_entry["end"]:
if (
for_layer_index == i
and max is not None
and max <= layered_history_entry["end"]
):
break
if as_objects:
compiled.append({
"text": text,
"start": layered_history_entry["start"],
"end": layered_history_entry["end"],
"layer": i,
"layer_r": len_layered_history - i,
"ts_start": layered_history_entry["ts_start"],
"index": entry_num,
})
compiled.append(
{
"text": text,
"start": layered_history_entry["start"],
"end": layered_history_entry["end"],
"layer": i,
"layer_r": len_layered_history - i,
"ts_start": layered_history_entry["ts_start"],
"index": entry_num,
}
)
entry_num += 1
else:
compiled.append(text)
next_layer_start = layered_history_entry["end"] + 1
if i == 0 and include_base_layer:
# we are at layered history layer zero and inclusion of the base layer (archived history) is requested
# so we append the base layer to the compiled list, starting from
# index `next_layer_start`
entry_num = 1
for ah in self.scene.archived_history[next_layer_start or 0:]:
for ah in self.scene.archived_history[next_layer_start or 0 :]:
if base_layer_end_id and ah["id"] == base_layer_end_id:
break
text = f"{ah['text']}"
if as_objects:
compiled.append({
"text": text,
"start": ah["start"],
"end": ah["end"],
"layer": -1,
"layer_r": 1,
"ts": ah["ts"],
"index": entry_num,
})
compiled.append(
{
"text": text,
"start": ah["start"],
"end": ah["end"],
"layer": -1,
"layer_r": 1,
"ts": ah["ts"],
"index": entry_num,
}
)
entry_num += 1
else:
compiled.append(text)
return compiled
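# Minimal usage sketch (assumes a populated layered history):
# compiled = self.compile_layered_history(as_objects=True, include_base_layer=True)
# for entry in compiled:
#     print(entry["layer"], entry["start"], entry["end"], entry["text"][:40])
# Entries arrive oldest/most-condensed first and newest/most-granular last.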
@set_processing
async def summarize_to_layered_history(self, generation_options: GenerationOptions | None = None):
async def summarize_to_layered_history(
self, generation_options: GenerationOptions | None = None
):
"""
The layered history is a summarized archive with dynamic layers that
will get less and less granular as the scene progresses.
The most granular is still self.scene.archived_history, which holds
all the base layer summarizations.
self.scene.layered_history = [
# first layer after archived_history
[
@@ -391,7 +417,7 @@ class LayeredHistoryMixin:
},
...
],
# second layer
[
{
@@ -402,29 +428,29 @@ class LayeredHistoryMixin:
},
...
],
# additional layers
...
]
The same token threshold as for the base layer will be used for the
layers.
The same summarization function will be used for the layers.
The next level layer will be generated automatically when the token
threshold is reached.
"""
if not self.scene.archived_history:
return # No base layer summaries to work with
token_threshold = self.layered_history_threshold
max_layers = self.layered_history_max_layers
if not hasattr(self.scene, 'layered_history'):
if not hasattr(self.scene, "layered_history"):
self.scene.layered_history = []
layered_history = self.scene.layered_history
async def summarize_layer(source_layer, next_layer_index, start_from) -> bool:
@@ -432,147 +458,192 @@ class LayeredHistoryMixin:
current_tokens = 0
start_index = start_from
noop = True
total_tokens_in_previous_layer = util.count_tokens([
entry['text'] for entry in source_layer
])
total_tokens_in_previous_layer = util.count_tokens(
[entry["text"] for entry in source_layer]
)
estimated_entries = total_tokens_in_previous_layer // token_threshold
for i in range(start_from, len(source_layer)):
entry = source_layer[i]
entry_tokens = util.count_tokens(entry['text'])
log.debug("summarize_to_layered_history", entry=entry["text"][:100]+"...", tokens=entry_tokens, current_layer=next_layer_index-1)
entry_tokens = util.count_tokens(entry["text"])
log.debug(
"summarize_to_layered_history",
entry=entry["text"][:100] + "...",
tokens=entry_tokens,
current_layer=next_layer_index - 1,
)
if current_tokens + entry_tokens > token_threshold:
if current_chunk:
try:
# check if the next layer exists
next_layer = layered_history[next_layer_index]
except IndexError:
# create the next layer
layered_history.append([])
log.debug("summarize_to_layered_history", created_layer=next_layer_index)
log.debug(
"summarize_to_layered_history",
created_layer=next_layer_index,
)
next_layer = layered_history[next_layer_index]
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
ts, ts_start, ts_end = self._lh_extract_timestamps(
current_chunk
)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
text_length = util.count_tokens(
"\n\n".join(chunk["text"] for chunk in current_chunk)
)
num_entries_in_layer = len(layered_history[next_layer_index])
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})
emit(
"status",
status="busy",
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}",
data={"cancellable": True},
)
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
noop = False
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
next_layer.append(LayeredArchiveEntry(**{
"start": start_index,
"end": i,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
}).model_dump(exclude_none=True))
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}")
next_layer.append(
LayeredArchiveEntry(
**{
"start": start_index,
"end": i,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
}
).model_dump(exclude_none=True)
)
emit(
"status",
status="busy",
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer + 1} / {estimated_entries}",
)
current_chunk = []
current_tokens = 0
start_index = i
current_chunk.append(entry)
current_tokens += entry_tokens
log.debug("summarize_to_layered_history", tokens=current_tokens, threshold=token_threshold, next_layer=next_layer_index)
log.debug(
"summarize_to_layered_history",
tokens=current_tokens,
threshold=token_threshold,
next_layer=next_layer_index,
)
return not noop
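# Worked example of the chunking above (hypothetical sizes): with
# token_threshold=1536 and consecutive entries of 900 and 800 tokens, the
# second entry would push the running total past the threshold, so the
# 900-token chunk is summarized on its own and the 800-token entry starts
# the next chunk.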
# First layer (always the base layer)
has_been_updated = False
try:
if not layered_history:
layered_history.append([])
log.debug("summarize_to_layered_history", layer="base", new_layer=True)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
has_been_updated = await summarize_layer(
self.scene.archived_history, 0, 0
)
elif layered_history[0]:
# determine starting point by checking for `end` in the last entry
last_entry = layered_history[0][-1]
end = last_entry["end"]
log.debug("summarize_to_layered_history", layer="base", start=end)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end)
has_been_updated = await summarize_layer(
self.scene.archived_history, 0, end
)
else:
log.debug("summarize_to_layered_history", layer="base", empty=True)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
has_been_updated = await summarize_layer(
self.scene.archived_history, 0, 0
)
except SummaryLongerThanOriginalError as exc:
log.error("summarize_to_layered_history", error=exc, layer="base")
emit("status", status="error", message="Layered history update failed.")
return
except GenerationCancelled as e:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
log.info(
"Generation cancelled, stopping rebuild of historical layered history"
)
emit(
"status",
message="Rebuilding of layered history cancelled",
status="info",
)
handle_generation_cancelled(e)
return
# process layers
async def update_layers() -> bool:
noop = True
for index in range(0, len(layered_history)):
# check against max layers
if index + 1 > max_layers:
return False
try:
# check if the next layer exists
next_layer = layered_history[index + 1]
except IndexError:
next_layer = None
end = next_layer[-1]["end"] if next_layer else 0
log.debug("summarize_to_layered_history", layer=index, start=end)
summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0)
summarized = await summarize_layer(
layered_history[index], index + 1, end if end else 0
)
if summarized:
noop = False
return not noop
try:
while await update_layers():
has_been_updated = True
if has_been_updated:
emit("status", status="success", message="Layered history updated.")
except SummaryLongerThanOriginalError as exc:
log.error("summarize_to_layered_history", error=exc, layer="subsequent")
emit("status", status="error", message="Layered history update failed.")
return
except GenerationCancelled as e:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
log.info(
"Generation cancelled, stopping rebuild of historical layered history"
)
emit(
"status",
message="Rebuilding of layered history cancelled",
status="info",
)
handle_generation_cancelled(e)
return
async def summarize_entries_to_layered_history(
self,
entries: list[dict],
self,
entries: list[dict],
next_layer_index: int,
start_index: int,
end_index: int,
@@ -580,11 +651,11 @@ class LayeredHistoryMixin:
) -> list[LayeredArchiveEntry]:
"""
Summarizes a list of entries into layered history entries.
This method is used for regenerating specific history entries by processing
their source entries. It chunks the entries based on the token threshold and
summarizes each chunk into a LayeredArchiveEntry.
Args:
entries: List of dictionaries containing the text entries to summarize.
Each entry should have at least a 'text' field and optionally
@@ -597,12 +668,12 @@ class LayeredHistoryMixin:
correspond to.
generation_options: Optional generation options to pass to the summarization
process.
Returns:
List of LayeredArchiveEntry objects containing the summarized text along
with timestamp and index information. Currently returns a list with a
single entry, but the structure supports multiple entries if needed.
Notes:
- The method respects the layered_history_threshold for chunking
- Uses helper methods for timestamp extraction, context building, and
@@ -611,63 +682,73 @@ class LayeredHistoryMixin:
- The last entry is always included in the final chunk if it doesn't
exceed the token threshold
"""
token_threshold = self.layered_history_threshold
archive_entries = []
summaries = []
current_chunk = []
current_tokens = 0
ts = "PT1S"
ts_start = "PT1S"
ts_end = "PT1S"
for entry_index, entry in enumerate(entries):
is_last_entry = entry_index == len(entries) - 1
entry_tokens = util.count_tokens(entry['text'])
log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens)
entry_tokens = util.count_tokens(entry["text"])
log.debug(
"summarize_entries_to_layered_history",
entry=entry["text"][:100] + "...",
entry_tokens=entry_tokens,
current_layer=next_layer_index - 1,
current_tokens=current_tokens,
)
if current_tokens + entry_tokens > token_threshold or is_last_entry:
if is_last_entry and current_tokens + entry_tokens <= token_threshold:
# if we are here because this is the last entry and adding it to
# the current chunk would not exceed the token threshold, we will
# add it to the current chunk
current_chunk.append(entry)
if current_chunk:
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
text_length = util.count_tokens(
"\n\n".join(chunk["text"] for chunk in current_chunk)
)
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
archive_entry = LayeredArchiveEntry(**{
"start": start_index,
"end": end_index,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
})
archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n"))
archive_entry = LayeredArchiveEntry(
**{
"start": start_index,
"end": end_index,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
}
)
archive_entry = await self._lh_finalize_archive_entry(
archive_entry, extra_context.split("\n\n")
)
archive_entries.append(archive_entry)
current_chunk.append(entry)
current_tokens += entry_tokens
return archive_entries

View File

@@ -0,0 +1,59 @@
import structlog
from talemate.agents.base import (
set_processing,
)
from talemate.prompts import Prompt
from talemate.status import set_loading
from talemate.util.dialogue import separate_dialogue_from_exposition
log = structlog.get_logger("talemate.agents.summarize.tts_utils")
class TTSUtilsMixin:
"""
Summarizer Mixin for text-to-speech utilities.
"""
@set_loading("Preparing TTS context")
@set_processing
async def markup_context_for_tts(self, text: str) -> str:
"""
Markup the context for text-to-speech.
"""
original_text = text
log.debug("Markup context for TTS", text=text)
# if there are no quotes in the text, there is nothing to separate
if '"' not in text:
return original_text
# here we separate dialogue from exposition into obvious
# segments. This seems to have a positive effect on some
# LLMs returning the complete text.
separate_chunks = separate_dialogue_from_exposition(text)
numbered_chunks = []
for i, chunk in enumerate(separate_chunks):
numbered_chunks.append(f"[{i + 1}] {chunk.text.strip()}")
text = "\n".join(numbered_chunks)
response = await Prompt.request(
"summarizer.markup-context-for-tts",
self.client,
"investigate_1024",
vars={
"text": text,
"max_tokens": self.client.max_token_length,
"scene": self.scene,
},
)
try:
response = response.split("<MARKUP>")[1].split("</MARKUP>")[0].strip()
return response
except IndexError:
log.error("Failed to extract markup from response", response=response)
return original_text
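# Sketch of the numbered input the prompt receives (hypothetical text):
#   [1] The rain hammered the window.
#   [2] "We should leave," she whispered.
#   [3] He nodded slowly.
# The template is expected to return the result wrapped in <MARKUP> tags,
# which is what the extraction above relies on.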

View File

@@ -1,673 +0,0 @@
from __future__ import annotations
import asyncio
import base64
import functools
import io
import os
import tempfile
import time
import uuid
from typing import Union
import httpx
import nltk
import pydantic
import structlog
from nltk.tokenize import sent_tokenize
from openai import AsyncOpenAI
import talemate.config as config
import talemate.emit.async_signals
import talemate.instance as instance
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage
from .base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .registry import register
try:
from TTS.api import TTS
except ImportError:
TTS = None
log = structlog.get_logger("talemate.agents.tts") #
if not TTS:
# TTS installation is massive and requires a lot of dependencies
# so we don't want to require it unless the user wants to use it
log.info(
"TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
)
def parse_chunks(text: str) -> list[str]:
"""
Takes a string and splits it into chunks based on punctuation.
In case of an error it will return the original text as a single chunk and
the error will be logged.
"""
try:
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
chunk = chunk.replace("*", "")
if not chunk:
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
except Exception as e:
log.error("chunking error", error=e, text=text)
return [text.replace("__ellipsis__", "...").replace("*", "")]
def clean_quotes(chunk: str):
# if there is an uneven number of quotes, remove the last one if it's
# at the end of the chunk. If it's in the middle, add a quote to the end
if chunk.count('"') % 2 == 1:
if chunk.endswith('"'):
chunk = chunk[:-1]
else:
chunk += '"'
return chunk
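# Examples of the quote balancing (hypothetical inputs):
#   clean_quotes('She said "hi"')  -> unchanged (even quote count)
#   clean_quotes('She said "hi')   -> 'She said "hi"' (closing quote appended)
#   clean_quotes('She said hi"')   -> 'She said hi'   (dangling quote dropped)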
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(clean_quotes(current_chunk))
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(clean_quotes(current_chunk))
return joined_chunks
class Voice(pydantic.BaseModel):
value: str
label: str
class VoiceLibrary(pydantic.BaseModel):
api: str
voices: list[Voice] = pydantic.Field(default_factory=list)
last_synced: float = None
@register()
class TTSAgent(Agent):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Voice"
requires_llm_client = False
essential = False
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if agent:
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
voice.model_dump() for voice in agent.list_voices_sync()
]
return config_options
def __init__(self, **kwargs):
self.is_enabled = False
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
try:
nltk.download("punkt", quiet=True)
except Exception as e:
log.error("nltk download error", error=e)
except Exception as e:
log.error("nltk find error", error=e)
self.voices = {
"elevenlabs": VoiceLibrary(api="elevenlabs"),
"tts": VoiceLibrary(api="tts"),
"openai": VoiceLibrary(api="openai"),
}
self.config = config.load_config()
self.playback_done_event = asyncio.Event()
self.preselect_voice = None
self.actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"api": AgentActionConfig(
type="text",
choices=[
{"value": "tts", "label": "TTS (Local)"},
{"value": "elevenlabs", "label": "Eleven Labs"},
{"value": "openai", "label": "OpenAI"},
],
value="tts",
label="API",
description="Which TTS API to use",
onchange="emit",
),
"voice_id": AgentActionConfig(
type="text",
value="default",
label="Narrator Voice",
description="Voice ID/Name to use for TTS",
choices=[],
),
"generate_for_player": AgentActionConfig(
type="bool",
value=False,
label="Generate for player",
description="Generate audio for player messages",
),
"generate_for_npc": AgentActionConfig(
type="bool",
value=True,
label="Generate for NPCs",
description="Generate audio for NPC messages",
),
"generate_for_narration": AgentActionConfig(
type="bool",
value=True,
label="Generate for narration",
description="Generate audio for narration messages",
),
"generate_chunks": AgentActionConfig(
type="bool",
value=False,
label="Split generation",
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
),
},
),
"openai": AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
condition=AgentActionConditional(
attribute="_config.config.api", value="openai"
),
label="OpenAI",
config={
"model": AgentActionConfig(
type="text",
value="tts-1",
choices=[
{"value": "tts-1", "label": "TTS 1"},
{"value": "tts-1-hd", "label": "TTS 1 HD"},
],
label="Model",
description="TTS model to use",
),
},
),
}
self.actions["_config"].model_dump()
handlers["config_saved"].connect(self.on_config_saved)
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def not_ready_reason(self) -> str:
"""
Returns a string explaining why the agent is not ready
"""
if self.ready:
return ""
if self.api == "tts":
if not TTS:
return "TTS not installed"
elif self.requires_token and not self.token:
return "No API token"
elif not self.default_voice_id:
return "No voice selected"
@property
def agent_details(self):
details = {
"api": AgentDetail(
icon="mdi-server-outline",
value=self.api_label,
description="The backend to use for TTS",
).model_dump(),
}
if self.ready and self.enabled:
details["voice"] = AgentDetail(
icon="mdi-account-voice",
value=self.voice_id_to_label(self.default_voice_id) or "",
description="The voice to use for TTS",
color="info",
).model_dump()
elif self.enabled:
details["error"] = AgentDetail(
icon="mdi-alert",
value=self.not_ready_reason,
description=self.not_ready_reason,
color="error",
).model_dump()
return details
@property
def api(self):
return self.actions["_config"].config["api"].value
@property
def api_label(self):
choices = self.actions["_config"].config["api"].choices
api = self.api
for choice in choices:
if choice["value"] == api:
return choice["label"]
return api
@property
def token(self):
api = self.api
return self.config.get(api, {}).get("api_key")
@property
def default_voice_id(self):
return self.actions["_config"].config["voice_id"].value
@property
def requires_token(self):
return self.api != "tts"
@property
def ready(self):
if self.api == "tts":
if not TTS:
return False
return True
return (not self.requires_token or self.token) and self.default_voice_id
@property
def status(self):
if not self.enabled:
return "disabled"
if self.ready:
if getattr(self, "processing_bg", 0) > 0:
return "busy_bg" if not getattr(self, "processing", False) else "busy"
return "active" if not getattr(self, "processing", False) else "busy"
if self.requires_token and not self.token:
return "error"
if self.api == "tts":
if not TTS:
return "error"
return "uninitialized"
@property
def max_generation_length(self):
if self.api == "elevenlabs":
return 1024
elif self.api == "coqui":
return 250
return 250
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
async def apply_config(self, *args, **kwargs):
try:
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
except KeyError:
api = self.api
api_changed = api != self.api
# log.debug(
# "apply_config",
# api=api,
# api_changed=api != self.api,
# current_api=self.api,
# args=args,
# kwargs=kwargs,
# )
try:
self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
"value"
]
except KeyError:
self.preselect_voice = self.default_voice_id
await super().apply_config(*args, **kwargs)
if api_changed:
try:
self.actions["_config"].config["voice_id"].value = (
self.voices[api].voices[0].value
)
except IndexError:
self.actions["_config"].config["voice_id"].value = ""
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_new_message").connect(
self.on_game_loop_new_message
)
def on_config_saved(self, event):
config = event.data
self.config = config
instance.emit_agent_status(self.__class__, self)
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if not self.enabled or not self.ready:
return
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
return
if (
isinstance(emission.message, NarratorMessage)
and not self.actions["_config"].config["generate_for_narration"].value
):
return
if isinstance(emission.message, CharacterMessage):
if (
emission.message.source == "player"
and not self.actions["_config"].config["generate_for_player"].value
):
return
elif (
emission.message.source == "ai"
and not self.actions["_config"].config["generate_for_npc"].value
):
return
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
else:
character_prefix = ""
log.info(
"reactive tts", message=emission.message, character_prefix=character_prefix
)
await self.generate(str(emission.message).replace(character_prefix + ": ", ""))
def voice(self, voice_id: str) -> Union[Voice, None]:
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice
return None
def voice_id_to_label(self, voice_id: str):
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice.label
return None
def list_voices_sync(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.list_voices())
async def list_voices(self):
if self.requires_token and not self.token:
return []
library = self.voices[self.api]
# TODO: allow re-syncing voices
if library.last_synced:
return library.voices
list_fn = getattr(self, f"_list_voices_{self.api}")
log.info("Listing voices", api=self.api)
library.voices = await list_fn()
library.last_synced = time.time()
if self.preselect_voice:
if self.voice(self.preselect_voice):
self.actions["_config"].config["voice_id"].value = self.preselect_voice
self.preselect_voice = None
# if the current voice cannot be found, reset it
if not self.voice(self.default_voice_id):
self.actions["_config"].config["voice_id"].value = ""
# set loading to false
return library.voices
@set_processing
async def generate(self, text: str):
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
generate_fn = getattr(self, f"_generate_{self.api}")
if self.actions["_config"].config["generate_chunks"].value:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks)
else:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks, chunk_size=self.max_generation_length)
# Start generating audio chunks in the background
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
await self.set_background_processing(generation_task)
# Wait for both tasks to complete
# await asyncio.gather(generation_task)
async def generate_chunks(self, generate_fn, chunks):
for chunk in chunks:
chunk = chunk.replace("*", "").strip()
log.info("Generating audio", api=self.api, chunk=chunk)
audio_data = await generate_fn(chunk)
self.play_audio(audio_data)
def play_audio(self, audio_data):
# play audio through the python audio player
# play(audio_data)
emit(
"audio_queue",
data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
)
self.playback_done_event.set() # Signal that playback is finished
# LOCAL
async def _generate_tts(self, text: str) -> Union[bytes, None]:
if not TTS:
return
tts_config = self.config.get("tts", {})
model = tts_config.get("model")
device = tts_config.get("device", "cpu")
log.debug("tts local", model=model, device=device)
if not hasattr(self, "tts_instance"):
self.tts_instance = TTS(model).to(device)
tts = self.tts_instance
loop = asyncio.get_event_loop()
voice = self.voice(self.default_voice_id)
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(
None,
functools.partial(
tts.tts_to_file,
text=text,
speaker_wav=voice.value,
language="en",
file_path=file_path,
),
)
# tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
with open(file_path, "rb") as f:
return f.read()
async def _list_voices_tts(self) -> dict[str, str]:
return [
Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
]
# ELEVENLABS
async def _generate_elevenlabs(
self, text: str, chunk_size: int = 1024
) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
async with httpx.AsyncClient() as client:
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.default_voice_id}"
headers = {
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": api_key,
}
data = {
"text": text,
"model_id": self.config.get("elevenlabs", {}).get("model"),
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
}
response = await client.post(url, json=data, headers=headers, timeout=300)
if response.status_code == 200:
bytes_io = io.BytesIO()
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
# Put the audio data in the queue for playback
return bytes_io.getvalue()
else:
log.error(f"Error generating audio: {response.text}")
async def _list_voices_elevenlabs(self) -> dict[str, str]:
url_voices = "https://api.elevenlabs.io/v1/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Accept": "application/json",
"xi-api-key": self.token,
}
response = await client.get(
url_voices, headers=headers, params={"per_page": 1000}
)
speakers = response.json()["voices"]
voices.extend(
[
Voice(value=speaker["voice_id"], label=speaker["name"])
for speaker in speakers
]
)
# sort by name
voices.sort(key=lambda x: x.label)
return voices
# OPENAI
async def _generate_openai(self, text: str, chunk_size: int = 1024):
client = AsyncOpenAI(api_key=self.openai_api_key)
model = self.actions["openai"].config["model"].value
response = await client.audio.speech.create(
model=model, voice=self.default_voice_id, input=text
)
bytes_io = io.BytesIO()
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
# Put the audio data in the queue for playback
return bytes_io.getvalue()
async def _list_voices_openai(self) -> dict[str, str]:
return [
Voice(value="alloy", label="Alloy"),
Voice(value="echo", label="Echo"),
Voice(value="fable", label="Fable"),
Voice(value="onyx", label="Onyx"),
Voice(value="nova", label="Nova"),
Voice(value="shimmer", label="Shimmer"),
]

View File

@@ -0,0 +1,995 @@
from __future__ import annotations
import asyncio
import base64
import re
import traceback
from typing import TYPE_CHECKING
import uuid
from collections import deque
import structlog
from nltk.tokenize import sent_tokenize
import talemate.util.dialogue as dialogue_utils
import talemate.emit.async_signals as async_signals
import talemate.instance as instance
from talemate.ux.schema import Note
from talemate.emit import emit
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import (
CharacterMessage,
NarratorMessage,
ContextInvestigationMessage,
)
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
AgentDetail,
AgentActionNote,
set_processing,
)
from talemate.agents.registry import register
from .schema import (
APIStatus,
Voice,
VoiceLibrary,
GenerationContext,
Chunk,
VoiceGenerationEmission,
)
from .providers import provider
import talemate.agents.tts.voice_library as voice_library
from .elevenlabs import ElevenLabsMixin
from .openai import OpenAIMixin
from .google import GoogleMixin
from .kokoro import KokoroMixin
from .chatterbox import ChatterboxMixin
from .websocket_handler import TTSWebsocketHandler
from .f5tts import F5TTSMixin
import talemate.agents.tts.nodes as tts_nodes # noqa: F401
if TYPE_CHECKING:
from talemate.character import Character, VoiceChangedEvent
from talemate.agents.summarize import SummarizeAgent
from talemate.game.engine.nodes.scene import SceneLoopEvent
log = structlog.get_logger("talemate.agents.tts")
HOT_SWAP_NOTIFICATION_TIME = 60
VOICE_LIBRARY_NOTE = "Voices are not managed here, but in the voice library which can be accessed through the Talemate application bar at the top. When disabling/enabling APIS, close and open this window to refresh the choices."
async_signals.register(
"agent.tts.prepare.before",
"agent.tts.prepare.after",
"agent.tts.generate.before",
"agent.tts.generate.after",
)
def parse_chunks(text: str) -> list[str]:
"""
Takes a string and splits it into chunks based on punctuation.
In case of an error it will return the original text as a single chunk and
the error will be logged.
"""
try:
text = text.replace("*", "")
# ensure sentence terminators are before quotes
# otherwise the beginning of dialog will bleed into narration
text = re.sub(r'([^.?!]+) "', r'\1. "', text)
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
if not chunk.strip():
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
except Exception as e:
log.error("chunking error", error=e, text=text)
return [text.replace("__ellipsis__", "...").replace("*", "")]
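# The sentence-terminator rewrite above means (hypothetical input):
#   'He froze "Run."'  ->  'He froze. "Run."'
# so sent_tokenize splits the narration and the quoted dialogue into
# separate chunks instead of letting one bleed into the other.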
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(current_chunk)
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(current_chunk)
return joined_chunks
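# Minimal pipeline sketch (hypothetical text and chunk size):
# chunks = parse_chunks('A long paragraph. "With dialogue." And more narration.')
# chunks = rejoin_chunks(chunks, chunk_size=250)
# Sentences are split on punctuation first, then greedily recombined until a
# chunk would exceed 250 characters, keeping each TTS request reasonably sized.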
@register()
class TTSAgent(
ElevenLabsMixin,
OpenAIMixin,
GoogleMixin,
KokoroMixin,
ChatterboxMixin,
F5TTSMixin,
Agent,
):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Voice"
requires_llm_client = False
essential = False
# websocket handler for frontend voice library management
websocket_handler = TTSWebsocketHandler
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if not agent:
return config_options
narrator_voice_id = config_options["actions"]["_config"]["config"][
"narrator_voice_id"
]
narrator_voice_id["choices"] = cls.narrator_voice_id_choices(agent)
return config_options
@classmethod
def narrator_voice_id_choices(cls, agent: "TTSAgent") -> list[dict[str, str]]:
choices = voice_library.voices_for_apis(agent.ready_apis, agent.voice_library)
choices.sort(key=lambda x: x.label)
return [
{
"label": f"{voice.label} ({voice.provider})",
"value": voice.id,
}
for voice in choices
]
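# e.g. a voice with the (hypothetical) label "Adam" from the kokoro provider
# and id "kokoro:am_adam" (the default below) surfaces as
# {"label": "Adam (kokoro)", "value": "kokoro:am_adam"}.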
@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"apis": AgentActionConfig(
type="flags",
value=[
"kokoro",
],
label="Enabled APIs",
description="APIs to use for TTS",
choices=[],
),
"narrator_voice_id": AgentActionConfig(
type="autocomplete",
value="kokoro:am_adam",
label="Narrator Voice",
description="Voice to use for narration",
choices=[],
note=VOICE_LIBRARY_NOTE,
),
"speaker_separation": AgentActionConfig(
type="text",
value="simple",
label="Speaker separation",
description="How to separate speaker dialogue from exposition",
choices=[
{"label": "No separation", "value": "none"},
{"label": "Simple", "value": "simple"},
{"label": "Mixed", "value": "mixed"},
{"label": "AI assisted", "value": "ai_assisted"},
],
note_on_value={
"none": AgentActionNote(
type="primary",
text="Character messages will be voiced entirely by the character's voice with a fallback to the narrator voice if the character has no voice selecte. Narrator messages will be voiced exclusively by the narrator voice.",
),
"simple": AgentActionNote(
type="primary",
text="Exposition and dialogue will be separated in character messages. Narrator messages will be voiced exclusively by the narrator voice. This means",
),
"mixed": AgentActionNote(
type="primary",
text="A mix of `simple` and `ai_assisted`. Character messages are separated into narrator and the character's voice. Narrator messages that have dialogue are analyzed by the Summarizer agent to determine the appropriate speaker(s).",
),
"ai_assisted": AgentActionNote(
type="primary",
text="Appropriate speaker separation will be attempted based on the content of the message with help from the Summarizer agent. This sends an extra prompt to the LLM to determine the appropriate speaker(s).",
),
},
),
"generate_for_player": AgentActionConfig(
type="bool",
value=False,
label="Auto-generate for player",
description="Generate audio for player messages",
),
"generate_for_npc": AgentActionConfig(
type="bool",
value=True,
label="Auto-generate for AI characters",
description="Generate audio for NPC messages",
),
"generate_for_narration": AgentActionConfig(
type="bool",
value=True,
label="Auto-generate for narration",
description="Generate audio for narration messages",
),
"generate_for_context_investigation": AgentActionConfig(
type="bool",
value=True,
label="Auto-generate for context investigation",
description="Generate audio for context investigation messages",
),
},
),
}
KokoroMixin.add_actions(actions)
ChatterboxMixin.add_actions(actions)
GoogleMixin.add_actions(actions)
ElevenLabsMixin.add_actions(actions)
OpenAIMixin.add_actions(actions)
F5TTSMixin.add_actions(actions)
return actions
def __init__(self, **kwargs):
self.is_enabled = False # tts agent is disabled by default
self.actions = TTSAgent.init_actions()
self.playback_done_event = asyncio.Event()
# Queue management for voice generation
# Each queue instance gets a unique id so it can later be referenced
# (e.g. for cancellation of all remaining items).
# Only one queue can be active at a time. New generation requests that
# arrive while a queue is processing will be appended to the same
# queue. Once the queue is fully processed it is discarded and a new
# one will be created for subsequent generation requests.
# Queue now holds individual (context, chunk) pairs so interruption can
# happen between chunks even when a single context produced many.
self._generation_queue: deque[tuple[GenerationContext, Chunk]] = deque()
self._queue_id: str | None = None
self._queue_task: asyncio.Task | None = None
self._queue_lock = asyncio.Lock()
# general helpers
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def voice_library(self) -> VoiceLibrary:
return voice_library.get_instance()
# config helpers
@property
def narrator_voice_id(self) -> str:
return self.actions["_config"].config["narrator_voice_id"].value
@property
def generate_for_player(self) -> bool:
return self.actions["_config"].config["generate_for_player"].value
@property
def generate_for_npc(self) -> bool:
return self.actions["_config"].config["generate_for_npc"].value
@property
def generate_for_narration(self) -> bool:
return self.actions["_config"].config["generate_for_narration"].value
@property
def generate_for_context_investigation(self) -> bool:
return (
self.actions["_config"].config["generate_for_context_investigation"].value
)
@property
def speaker_separation(self) -> str:
return self.actions["_config"].config["speaker_separation"].value
@property
def apis(self) -> list[str]:
return self.actions["_config"].config["apis"].value
@property
def all_apis(self) -> list[str]:
return [api["value"] for api in self.actions["_config"].config["apis"].choices]
@property
def agent_details(self):
details = {}
self.actions["_config"].config[
"narrator_voice_id"
].choices = self.narrator_voice_id_choices(self)
if not self.enabled:
return details
used_apis: set[str] = set()
used_disabled_apis: set[str] = set()
if self.narrator_voice:
label = self.narrator_voice.label
color = "primary"
used_apis.add(self.narrator_voice.provider)
if not self.api_enabled(self.narrator_voice.provider):
used_disabled_apis.add(self.narrator_voice.provider)
if not self.api_ready(self.narrator_voice.provider):
color = "error"
details["narrator_voice"] = AgentDetail(
icon="mdi-script-text",
value=label,
description="Default voice",
color=color,
).model_dump()
scene = getattr(self, "scene", None)
if scene:
for character in scene.characters:
if character.voice:
label = character.voice.label
color = "primary"
used_apis.add(character.voice.provider)
if not self.api_enabled(character.voice.provider):
used_disabled_apis.add(character.voice.provider)
if not self.api_ready(character.voice.provider):
color = "error"
details[f"{character.name}_voice"] = AgentDetail(
icon="mdi-account-voice",
value=f"{character.name}",
description=f"{character.name}'s voice: {label} ({character.voice.provider})",
color=color,
).model_dump()
for api in used_disabled_apis:
details[f"{api}_disabled"] = AgentDetail(
icon="mdi-alert-circle",
value=f"{api} disabled",
description=f"{api} disabled - at least one voice is attempting to use this api but is not enabled",
color="error",
).model_dump()
for api in used_apis:
fn = getattr(self, f"{api}_agent_details", None)
if fn:
details.update(fn)
return details
@property
def status(self):
if not self.enabled:
return "disabled"
if self.ready:
if getattr(self, "processing_bg", 0) > 0:
return "busy_bg" if not getattr(self, "processing", False) else "busy"
return "idle" if not getattr(self, "processing", False) else "busy"
return "uninitialized"
@property
def narrator_voice(self) -> Voice | None:
return self.voice_library.get_voice(self.narrator_voice_id)
@property
def api_status(self) -> list[APIStatus]:
api_status: list[APIStatus] = []
for api in self.all_apis:
not_configured_reason = getattr(self, f"{api}_not_configured_reason", None)
not_configured_action = getattr(self, f"{api}_not_configured_action", None)
api_info: str | None = getattr(self, f"{api}_info", None)
messages: list[Note] = []
if not_configured_reason:
messages.append(
Note(
text=not_configured_reason,
color="error",
icon="mdi-alert-circle-outline",
actions=[not_configured_action]
if not_configured_action
else None,
)
)
if api_info:
messages.append(
Note(
text=api_info.strip(),
color="muted",
icon="mdi-information-outline",
)
)
_status = APIStatus(
api=api,
enabled=self.api_enabled(api),
ready=self.api_ready(api),
configured=self.api_configured(api),
messages=messages,
supports_mixing=getattr(self, f"{api}_supports_mixing", False),
provider=provider(api),
default_model=getattr(self, f"{api}_model", None),
model_choices=getattr(self, f"{api}_model_choices", []),
)
api_status.append(_status)
# order by api
api_status.sort(key=lambda x: x.api)
return api_status
# events
def connect(self, scene):
super().connect(scene)
async_signals.get("game_loop_new_message").connect(
self.on_game_loop_new_message
)
async_signals.get("voice_library.update.after").connect(
self.on_voice_library_update
)
async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init)
async_signals.get("character.voice_changed").connect(
self.on_character_voice_changed
)
async def on_scene_loop_init(self, event: "SceneLoopEvent"):
if not self.enabled or not self.ready or not self.generate_for_narration:
return
if self.scene.environment == "creative":
return
content_messages = self.scene.last_message_of_type(
["character", "narrator", "context_investigation"]
)
if content_messages:
# we already have a history, so we don't need to generate TTS for the intro
return
await self.generate(self.scene.get_intro(), character=None)
async def on_voice_library_update(self, voice_library: VoiceLibrary):
log.debug("Voice library updated - refreshing narrator voice choices")
self.actions["_config"].config[
"narrator_voice_id"
].choices = self.narrator_voice_id_choices(self)
await self.emit_status()
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if self.scene.environment == "creative":
return
character: Character | None = None
if not self.enabled or not self.ready:
return
if not isinstance(
emission.message,
(CharacterMessage, NarratorMessage, ContextInvestigationMessage),
):
return
if (
isinstance(emission.message, NarratorMessage)
and not self.generate_for_narration
):
return
if (
isinstance(emission.message, ContextInvestigationMessage)
and not self.generate_for_context_investigation
):
return
if isinstance(emission.message, CharacterMessage):
if emission.message.source == "player" and not self.generate_for_player:
return
elif emission.message.source == "ai" and not self.generate_for_npc:
return
character = self.scene.get_character(emission.message.character_name)
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
text_to_generate = str(emission.message).replace(
character_prefix + ": ", ""
)
elif isinstance(emission.message, ContextInvestigationMessage):
character_prefix = ""
text_to_generate = (
emission.message.message
) # Use just the message content, not the title prefix
else:
character_prefix = ""
text_to_generate = str(emission.message)
log.info(
"reactive tts", message=emission.message, character_prefix=character_prefix
)
await self.generate(
text_to_generate,
character=character,
message=emission.message,
)
async def on_character_voice_changed(self, event: "VoiceChangedEvent"):
log.debug(
"Character voice changed", character=event.character, voice=event.voice
)
await self.emit_status()
# voice helpers
@property
def ready_apis(self) -> list[str]:
"""
Returns a list of apis that are ready
"""
return [api for api in self.apis if self.api_ready(api)]
@property
def used_apis(self) -> list[str]:
"""
Returns a list of apis that are in use
The api is in use if the narrator voice or any of the active characters in the scene use a voice from the api.
"""
return [api for api in self.apis if self.api_used(api)]
def api_enabled(self, api: str) -> bool:
"""
Returns whether the api is currently in the .apis list, which means it is enabled.
"""
return api in self.apis
def api_ready(self, api: str) -> bool:
"""
Returns whether the api is ready.
The api must be enabled and configured.
"""
if not self.api_enabled(api):
return False
return self.api_configured(api)
def api_configured(self, api: str) -> bool:
return getattr(self, f"{api}_configured", True)
def api_used(self, api: str) -> bool:
"""
Returns whether the narrator or any of the active characters in the scene
use a voice from the given api
Args:
api (str): The api to check
Returns:
bool: Whether the api is in use
"""
if self.narrator_voice and self.narrator_voice.provider == api:
return True
if not getattr(self, "scene", None):
return False
for character in self.scene.characters:
if not character.voice:
continue
voice = self.voice_library.get_voice(character.voice.id)
if voice and voice.provider == api:
return True
return False
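# Example (provider names purely illustrative): with a narrator voice
# from "openai" and one active character voiced via "elevenlabs",
# api_used("openai") and api_used("elevenlabs") return True, while
# api_used("chatterbox") returns False.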
def use_ai_assisted_speaker_separation(
self,
text: str,
message: CharacterMessage
| NarratorMessage
| ContextInvestigationMessage
| None,
) -> bool:
"""
Returns whether the ai assisted speaker separation should be used for the given text.
"""
try:
if not message and '"' not in text:
return False
if not message and '"' in text:
return self.speaker_separation in ["ai_assisted", "mixed"]
if message.source == "player":
return False
if self.speaker_separation == "ai_assisted":
return True
if (
isinstance(message, NarratorMessage)
and self.speaker_separation == "mixed"
):
return True
return False
except Exception as e:
log.error(
"Error using ai assisted speaker separation",
error=e,
traceback=traceback.format_exc(),
)
return False
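# Decision summary of the checks above:
# - no message, no quoted dialogue in the text -> False
# - no message, quoted dialogue present -> True in "ai_assisted"/"mixed" mode
# - player-sourced message -> False
# - "ai_assisted" mode -> True for any remaining message
# - "mixed" mode -> True only for narrator messages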
# tts markup cache
async def get_tts_markup_cache(self, text: str) -> str | None:
"""
Returns the cached tts markup for the given text.
"""
fp = hash(text)
cached_markup = self.get_scene_state("tts_markup_cache")
if cached_markup and cached_markup.get("fp") == fp:
return cached_markup.get("markup")
return None
async def set_tts_markup_cache(self, text: str, markup: str):
fp = hash(text)
self.set_scene_states(
tts_markup_cache={
"fp": fp,
"markup": markup,
}
)
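# Round-trip sketch (assuming a TTS agent instance `tts`; the markup
# string is a placeholder):
# await tts.set_tts_markup_cache("Some narration.", "marked-up narration")
# await tts.get_tts_markup_cache("Some narration.")  # -> "marked-up narration"
# await tts.get_tts_markup_cache("Other text.")  # -> None (fingerprint mismatch)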
# generation
@set_processing
async def generate(
self,
text: str,
character: Character | None = None,
force_voice: Voice | None = None,
message: CharacterMessage | NarratorMessage | None = None,
):
"""
Public entry-point for voice generation.
The actual audio generation happens sequentially inside a single
background queue. If a queue is currently active, we simply append the
new request to it; if not, we create a new queue (with its own unique
id) and start processing.
"""
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
summarizer: "SummarizeAgent" = instance.get_agent("summarizer")
context = GenerationContext(voice_id=self.narrator_voice_id)
character_voice: Voice = force_voice or self.narrator_voice
if character and character.voice:
voice = character.voice
if voice and self.api_ready(voice.provider):
character_voice = voice
else:
log.warning(
"Character voice not available",
character=character.name,
voice=character.voice,
)
log.debug("Voice routing", character=character, voice=character_voice)
# initial chunking by separating dialogue from exposition
chunks: list[Chunk] = []
if self.speaker_separation != "none":
if self.use_ai_assisted_speaker_separation(text, message):
markup = await self.get_tts_markup_cache(text)
if not markup:
log.debug("No markup cache found, generating markup")
markup = await summarizer.markup_context_for_tts(text)
await self.set_tts_markup_cache(text, markup)
else:
log.debug("Using markup cache")
# Use the new markup parser for AI-assisted format
dlg_chunks = dialogue_utils.parse_tts_markup(markup)
else:
# Use the original parser for non-AI-assisted format
dlg_chunks = dialogue_utils.separate_dialogue_from_exposition(text)
for _dlg_chunk in dlg_chunks:
_voice = (
character_voice
if _dlg_chunk.type == "dialogue"
else self.narrator_voice
)
if _dlg_chunk.speaker is not None:
# speaker name has been identified
_character = self.scene.get_character(_dlg_chunk.speaker)
log.debug(
"Identified speaker",
speaker=_dlg_chunk.speaker,
character=_character,
)
if (
_character
and _character.voice
and self.api_ready(_character.voice.provider)
):
log.debug(
"Using character voice",
character=_character.name,
voice=_character.voice,
)
_voice = _character.voice
_api: str = _voice.provider if _voice else self.api
chunk = Chunk(
api=_api,
voice=Voice(**_voice.model_dump()),
model=_voice.provider_model,
generate_fn=getattr(self, f"{_api}_generate"),
prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
character_name=character.name if character else None,
text=[_dlg_chunk.text],
type=_dlg_chunk.type,
message_id=message.id if message else None,
)
chunks.append(chunk)
else:
_voice = character_voice if character else self.narrator_voice
_api: str = _voice.provider if _voice else self.api
chunks = [
Chunk(
api=_api,
voice=Voice(**_voice.model_dump()),
model=_voice.provider_model,
generate_fn=getattr(self, f"{_api}_generate"),
prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
character_name=character.name if character else None,
text=[text],
type="dialogue" if character else "exposition",
message_id=message.id if message else None,
)
]
# second chunking by splitting into chunks of max_generation_length
for chunk in chunks:
api_chunk_size = getattr(self, f"{chunk.api}_chunk_size", 0)
log.debug("chunking", api=chunk.api, api_chunk_size=api_chunk_size)
_text = []
max_generation_length = getattr(self, f"{chunk.api}_max_generation_length")
if api_chunk_size > 0:
max_generation_length = min(max_generation_length, api_chunk_size)
for _chunk_text in chunk.text:
if len(_chunk_text) <= max_generation_length:
_text.append(_chunk_text)
continue
_parsed = parse_chunks(_chunk_text)
_joined = rejoin_chunks(_parsed, chunk_size=max_generation_length)
_text.extend(_joined)
log.debug("chunked for size", before=chunk.text, after=_text)
chunk.text = _text
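# Illustration (hypothetical numbers): with max_generation_length=256, a
# single 700-character exposition string is split by parse_chunks and
# rejoined by rejoin_chunks into roughly three strings of at most 256
# characters each, keeping every generate_fn call within the api limit.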
context.chunks = chunks
# Enqueue each chunk individually for fine-grained interruptibility
async with self._queue_lock:
if self._queue_id is None:
self._queue_id = str(uuid.uuid4())
for chunk in context.chunks:
self._generation_queue.append((context, chunk))
# Start processing task if needed
if self._queue_task is None or self._queue_task.done():
self._queue_task = asyncio.create_task(
self._process_queue(self._queue_id)
)
log.debug(
"tts queue enqueue",
queue_id=self._queue_id,
total_items=len(self._generation_queue),
)
# The caller doesn't need to wait for the queue to finish; it runs in
# the background. We still register the task with Talemate's
# background-processing tracking so that UI can reflect activity.
await self.set_background_processing(self._queue_task)
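# Usage sketch (assuming an enabled, ready agent `tts` and a scene
# character `alice`); both calls return quickly, the audio is produced
# by the background queue task:
# await tts.generate("The wind howls outside.")  # narrator voice
# await tts.generate('"Who is there?"', character=alice)  # character voice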
# ---------------------------------------------------------------------
# Queue helpers
# ---------------------------------------------------------------------
async def _process_queue(self, queue_id: str):
"""Sequentially processes all GenerationContext objects in the queue.
Once the last context has been processed the queue state is reset so a
future generation call will create a new queue (and therefore a new
id). The *queue_id* argument allows us to later add cancellation logic
that can target a specific queue instance.
"""
try:
while True:
async with self._queue_lock:
if not self._generation_queue:
break
context, chunk = self._generation_queue.popleft()
log.debug(
"tts queue dequeue",
queue_id=queue_id,
total_items=len(self._generation_queue),
chunk_type=chunk.type,
)
# Process outside lock so other coroutines can enqueue
await self._generate_chunk(chunk, context)
except Exception as e:
log.error(
"Error processing queue", error=e, traceback=traceback.format_exc()
)
finally:
# Clean up queue state after finishing (or on cancellation)
async with self._queue_lock:
if queue_id == self._queue_id:
self._queue_id = None
self._queue_task = None
self._generation_queue.clear()
# Public helper so external code (e.g. later cancellation UI) can find the current queue id
def current_queue_id(self) -> str | None:
return self._queue_id
async def _generate_chunk(self, chunk: Chunk, context: GenerationContext):
"""Generate audio for a single chunk (all its sub-chunks)."""
for _chunk in chunk.sub_chunks:
if not _chunk.cleaned_text.strip():
continue
emission: VoiceGenerationEmission = VoiceGenerationEmission(
chunk=_chunk, context=context
)
if _chunk.prepare_fn:
await async_signals.get("agent.tts.prepare.before").send(emission)
await _chunk.prepare_fn(_chunk)
await async_signals.get("agent.tts.prepare.after").send(emission)
log.info(
"Generating audio",
api=chunk.api,
text=_chunk.cleaned_text,
parameters=_chunk.voice.parameters,
prepare_fn=_chunk.prepare_fn,
)
await async_signals.get("agent.tts.generate.before").send(emission)
try:
emission.wav_bytes = await _chunk.generate_fn(_chunk, context)
except Exception as e:
log.error("Error generating audio", error=e, chunk=_chunk)
continue
await async_signals.get("agent.tts.generate.after").send(emission)
self.play_audio(emission.wav_bytes, chunk.message_id)
await asyncio.sleep(0.1)
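# Listener sketch: the per-chunk signals emitted above can be observed
# the same way connect() wires handlers (handler name is hypothetical):
# async def on_tts_generated(emission: VoiceGenerationEmission):
#     log.debug("tts chunk done", chars=len(emission.chunk.cleaned_text))
# async_signals.get("agent.tts.generate.after").connect(on_tts_generated)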
# Deprecated: kept for backward compatibility but no longer used.
async def generate_chunks(self, context: GenerationContext):
for chunk in context.chunks:
await self._generate_chunk(chunk, context)
def play_audio(self, audio_data, message_id: int | None = None):
# play audio through the websocket (browser)
audio_data_encoded: str = base64.b64encode(audio_data).decode("utf-8")
emit(
"audio_queue",
data={"audio_data": audio_data_encoded, "message_id": message_id},
)
self.playback_done_event.set() # Signal that playback is finished
async def stop_and_clear_queue(self):
"""Cancel any ongoing generation and clear the pending queue.
This is triggered by UI actions that request immediate stop of TTS
synthesis and playback. It cancels the background task (if still
running) and clears all queued items in a thread-safe manner.
"""
async with self._queue_lock:
# Clear all queued items
self._generation_queue.clear()
# Cancel the background task if it is still running
if self._queue_task and not self._queue_task.done():
self._queue_task.cancel()
# Reset queue identifiers/state
self._queue_id = None
self._queue_task = None
# Ensure downstream components know playback is finished
self.playback_done_event.set()
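# Cancellation sketch (hypothetical caller, e.g. a websocket "stop"
# action): awaiting stop_and_clear_queue() cancels the background task
# and drops all pending chunks; nothing further is dequeued.
# await tts.stop_and_clear_queue()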


@@ -0,0 +1,317 @@
import os
import functools
import tempfile
import uuid
import asyncio
import structlog
import pydantic
import torch
# Lazy imports for heavy dependencies
def _import_heavy_deps():
global ta, ChatterboxTTS
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
CUDA_AVAILABLE = torch.cuda.is_available()
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
AgentDetail,
)
from talemate.ux.schema import Field
from .schema import Voice, Chunk, GenerationContext, VoiceProvider, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
from .providers import register, provider
from .util import voice_is_talemate_asset
log = structlog.get_logger("talemate.agents.tts.chatterbox")
add_default_voices(
[
Voice(
label="Eva",
provider="chatterbox",
provider_id="tts/voice/chatterbox/eva.wav",
tags=["female", "calm", "mature", "thoughtful"],
),
Voice(
label="Lisa",
provider="chatterbox",
provider_id="tts/voice/chatterbox/lisa.wav",
tags=["female", "energetic", "young"],
),
Voice(
label="Adam",
provider="chatterbox",
provider_id="tts/voice/chatterbox/adam.wav",
tags=["male", "calm", "mature", "thoughtful", "deep"],
),
Voice(
label="Bradford",
provider="chatterbox",
provider_id="tts/voice/chatterbox/bradford.wav",
tags=["male", "calm", "mature", "thoughtful", "deep"],
),
Voice(
label="Julia",
provider="chatterbox",
provider_id="tts/voice/chatterbox/julia.wav",
tags=["female", "calm", "mature"],
),
Voice(
label="Zoe",
provider="chatterbox",
provider_id="tts/voice/chatterbox/zoe.wav",
tags=["female"],
),
Voice(
label="William",
provider="chatterbox",
provider_id="tts/voice/chatterbox/william.wav",
tags=["male", "young"],
),
]
)
CHATTERBOX_INFO = """
Chatterbox is a local text to speech model.
The voice id is the path to the .wav file for the voice.
The path can be relative to the talemate root directory, and you can put new *.wav samples
in the `tts/voice/chatterbox` directory. Loading files from elsewhere also works, as long as the file path is accessible to the talemate backend.
First generation will download the models (2.13GB + 1.06GB).
Uses about 4GB of VRAM.
"""
@register()
class ChatterboxProvider(VoiceProvider):
name: str = "chatterbox"
allow_model_override: bool = False
allow_file_upload: bool = True
upload_file_types: list[str] = ["audio/wav"]
voice_parameters: list[Field] = [
Field(
name="exaggeration",
type="number",
label="Exaggeration level",
value=0.5,
min=0.25,
max=2.0,
step=0.05,
),
Field(
name="cfg_weight",
type="number",
label="CFG/Pace",
value=0.5,
min=0.2,
max=1.0,
step=0.1,
),
Field(
name="temperature",
type="number",
label="Temperature",
value=0.8,
min=0.05,
max=5.0,
step=0.05,
),
]
class ChatterboxInstance(pydantic.BaseModel):
model: "ChatterboxTTS"
device: str
class Config:
arbitrary_types_allowed = True
class ChatterboxMixin:
"""
Chatterbox agent mixin for local text to speech.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "chatterbox",
"label": "Chatterbox (Local)",
"help": "Chatterbox is a local text to speech model.",
}
)
actions["chatterbox"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="Chatterbox",
description="Chatterbox is a local text to speech model.",
config={
"device": AgentActionConfig(
type="text",
value="cuda" if CUDA_AVAILABLE else "cpu",
label="Device",
choices=[
{"value": "cpu", "label": "CPU"},
{"value": "cuda", "label": "CUDA"},
],
description="Device to use for TTS",
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=256,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@property
def chatterbox_configured(self) -> bool:
return True
@property
def chatterbox_max_generation_length(self) -> int:
return 512
@property
def chatterbox_device(self) -> str:
return self.actions["chatterbox"].config["device"].value
@property
def chatterbox_chunk_size(self) -> int:
return self.actions["chatterbox"].config["chunk_size"].value
@property
def chatterbox_info(self) -> str:
return CHATTERBOX_INFO
@property
def chatterbox_agent_details(self) -> dict:
if not self.chatterbox_configured:
return {}
details = {}
details["chatterbox_device"] = AgentDetail(
icon="mdi-memory",
value=f"Chatterbox: {self.chatterbox_device}",
description="The device to use for Chatterbox",
).model_dump()
return details
def chatterbox_delete_voice(self, voice: Voice):
"""
Remove the voice from the file system.
Only do this if the path is within TALEMATE_ROOT.
"""
is_talemate_asset, resolved = voice_is_talemate_asset(
voice, provider(voice.provider)
)
log.debug(
"chatterbox_delete_voice",
voice_id=voice.provider_id,
is_talemate_asset=is_talemate_asset,
resolved=resolved,
)
if not is_talemate_asset:
return
try:
if resolved.exists() and resolved.is_file():
resolved.unlink()
log.debug("Deleted chatterbox voice file", path=str(resolved))
except Exception as e:
log.error(
"Failed to delete chatterbox voice file", error=e, path=str(resolved)
)
def _chatterbox_generate_file(
self,
model: "ChatterboxTTS",
text: str,
audio_prompt_path: str,
output_path: str,
**kwargs,
):
wav = model.generate(text=text, audio_prompt_path=audio_prompt_path, **kwargs)
ta.save(output_path, wav, model.sr)
return output_path
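# The helper above is intentionally synchronous: model.generate() and
# torchaudio's save() block, so chatterbox_generate below pushes it onto
# a thread executor to keep the event loop responsive.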
async def chatterbox_generate(
self, chunk: Chunk, context: GenerationContext
) -> bytes | None:
chatterbox_instance: ChatterboxInstance | None = getattr(
self, "chatterbox_instance", None
)
reload: bool = False
if not chatterbox_instance:
reload = True
elif chatterbox_instance.device != self.chatterbox_device:
reload = True
if reload:
log.debug(
"chatterbox - reinitializing tts instance",
device=self.chatterbox_device,
)
# Lazy import heavy dependencies only when needed
_import_heavy_deps()
self.chatterbox_instance = ChatterboxInstance(
model=ChatterboxTTS.from_pretrained(device=self.chatterbox_device),
device=self.chatterbox_device,
)
model: "ChatterboxTTS" = self.chatterbox_instance.model
loop = asyncio.get_event_loop()
voice = chunk.voice
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(
None,
functools.partial(
self._chatterbox_generate_file,
model=model,
text=chunk.cleaned_text,
audio_prompt_path=voice.provider_id,
output_path=file_path,
**voice.parameters,
),
)
with open(file_path, "rb") as f:
return f.read()
async def chatterbox_prepare_chunk(self, chunk: Chunk):
voice = chunk.voice
P = provider(voice.provider)
exaggeration = P.voice_parameter(voice, "exaggeration")
voice.parameters["exaggeration"] = exaggeration


@@ -0,0 +1,248 @@
import io
from typing import Union
import structlog
# Lazy imports for heavy dependencies
def _import_heavy_deps():
global AsyncElevenLabs, ApiError
from elevenlabs.client import AsyncElevenLabs
# Added explicit ApiError import for clearer error handling
from elevenlabs.core.api_error import ApiError
from talemate.ux.schema import Action
from talemate.agents.base import (
AgentAction,
AgentActionConfig,
AgentDetail,
)
from .schema import Voice, VoiceLibrary, GenerationContext, Chunk, INFO_CHUNK_SIZE
from .voice_library import add_default_voices
# emit helper to propagate status messages to the UX
from talemate.emit import emit
log = structlog.get_logger("talemate.agents.tts.elevenlabs")
add_default_voices(
[
Voice(
label="Adam",
provider="elevenlabs",
provider_id="wBXNqKUATyqu0RtYt25i",
tags=["male", "deep"],
),
Voice(
label="Amy",
provider="elevenlabs",
provider_id="oGn4Ha2pe2vSJkmIJgLQ",
tags=["female"],
),
]
)
ELEVENLABS_INFO = """
ElevenLabs is a cloud-based text to speech API.
To add new voices, head to their voice library at [https://elevenlabs.io/app/voice-library](https://elevenlabs.io/app/voice-library) and note the voice id of the voice you want to use. (Click 'More Actions -> Copy Voice ID')
**About elevenlabs voices**
Your ElevenLabs subscription allows you to maintain a set number of voices (10 on the cheapest plan).
Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button above. It is recommended to test voices in their voice library instead.
"""
class ElevenLabsMixin:
"""
ElevenLabs TTS agent mixin for cloud-based text to speech.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_config"].config["apis"].choices.append(
{
"value": "elevenlabs",
"label": "ElevenLabs",
"help": "ElevenLabs is a cloud-based text to speech model that uses the ElevenLabs API. (API key required)",
}
)
actions["elevenlabs"] = AgentAction(
enabled=True,
container=True,
icon="mdi-server-outline",
label="ElevenLabs",
description="ElevenLabs is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> ElevenLabs)",
config={
"model": AgentActionConfig(
type="text",
value="eleven_flash_v2_5",
label="Model",
description="Model to use for TTS",
choices=[
{
"value": "eleven_multilingual_v2",
"label": "Eleven Multilingual V2",
},
{"value": "eleven_flash_v2_5", "label": "Eleven Flash V2.5"},
{"value": "eleven_turbo_v2_5", "label": "Eleven Turbo V2.5"},
],
),
"chunk_size": AgentActionConfig(
type="number",
min=0,
step=64,
max=2048,
value=0,
label="Chunk size",
note=INFO_CHUNK_SIZE,
),
},
)
return actions
@classmethod
def add_voices(cls, voices: dict[str, VoiceLibrary]):
voices["elevenlabs"] = VoiceLibrary(api="elevenlabs", local=True)
@property
def elevenlabs_chunk_size(self) -> int:
return self.actions["elevenlabs"].config["chunk_size"].value
@property
def elevenlabs_configured(self) -> bool:
api_key_set = bool(self.elevenlabs_api_key)
model_set = bool(self.elevenlabs_model)
return api_key_set and model_set
@property
def elevenlabs_not_configured_reason(self) -> str | None:
if not self.elevenlabs_api_key:
return "ElevenLabs API key not set"
if not self.elevenlabs_model:
return "ElevenLabs model not set"
return None
@property
def elevenlabs_not_configured_action(self) -> Action | None:
if not self.elevenlabs_api_key:
return Action(
action_name="openAppConfig",
arguments=["application", "elevenlabs_api"],
label="Set API Key",
icon="mdi-key",
)
if not self.elevenlabs_model:
return Action(
action_name="openAgentSettings",
arguments=["tts", "elevenlabs"],
label="Set Model",
icon="mdi-brain",
)
return None
@property
def elevenlabs_max_generation_length(self) -> int:
return 1024
@property
def elevenlabs_model(self) -> str:
return self.actions["elevenlabs"].config["model"].value
@property
def elevenlabs_model_choices(self) -> list[dict[str, str]]:
return [
{"label": choice["label"], "value": choice["value"]}
for choice in self.actions["elevenlabs"].config["model"].choices
]
@property
def elevenlabs_info(self) -> str:
return ELEVENLABS_INFO
@property
def elevenlabs_agent_details(self) -> dict:
details = {}
if not self.elevenlabs_configured:
details["elevenlabs_api_key"] = AgentDetail(
icon="mdi-key",
value="ElevenLabs API key not set",
description="ElevenLabs API key not set. You can set it in the Talemate Settings -> Application -> ElevenLabs",
color="error",
).model_dump()
else:
details["elevenlabs_model"] = AgentDetail(
icon="mdi-brain",
value=self.elevenlabs_model,
description="The model to use for ElevenLabs",
).model_dump()
return details
@property
def elevenlabs_api_key(self) -> str:
return self.config.elevenlabs.api_key
async def elevenlabs_generate(
self, chunk: Chunk, context: GenerationContext, chunk_size: int = 1024
) -> Union[bytes, None]:
api_key = self.elevenlabs_api_key
if not api_key:
return
# Lazy import heavy dependencies only when needed
_import_heavy_deps()
client = AsyncElevenLabs(api_key=api_key)
try:
response_async_iter = client.text_to_speech.convert(
text=chunk.cleaned_text,
voice_id=chunk.voice.provider_id,
model_id=chunk.model or self.elevenlabs_model,
)
bytes_io = io.BytesIO()
async for _chunk_bytes in response_async_iter:
if _chunk_bytes:
bytes_io.write(_chunk_bytes)
return bytes_io.getvalue()
except ApiError as e:
# Emit detailed status message to the frontend UI
error_message = "ElevenLabs API Error"
try:
# The ElevenLabs ApiError often contains a JSON body with details
detail = e.body.get("detail", {}) if hasattr(e, "body") else {}
error_message = detail.get("message", str(e)) or str(e)
except Exception:
error_message = str(e)
log.error("ElevenLabs API error", error=str(e))
emit(
"status",
message=f"ElevenLabs TTS: {error_message}",
status="error",
)
raise e
except Exception as e:
# Catch-all to ensure the app does not crash on unexpected errors
log.error("ElevenLabs TTS generation error", error=str(e))
emit(
"status",
message=f"ElevenLabs TTS: Unexpected error {str(e)}",
status="error",
)
raise e
