Compare commits
19 Commits
0.31.0
...
prep-0.32.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
028ca360a4 | ||
|
|
bbb66d63fd | ||
|
|
0dc696301a | ||
|
|
23998628ad | ||
|
|
35f872f94d | ||
|
|
6c082cd4e3 | ||
|
|
30ec3038c3 | ||
|
|
62e1c4f653 | ||
|
|
daadf6f1d0 | ||
|
|
ec512512e6 | ||
|
|
990ad4c285 | ||
|
|
0f72a7ab86 | ||
|
|
e19ad23f1d | ||
|
|
8733d54735 | ||
|
|
ce4c302d73 | ||
|
|
685ca994f9 | ||
|
|
285b0699ab | ||
|
|
7825489cfc | ||
|
|
fb2fa31f13 |
53
.github/workflows/ci.yml
vendored
@@ -1,30 +1,57 @@
|
||||
name: ci
|
||||
name: ci
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- main
|
||||
- prep-0.26.0
|
||||
- master
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
container-build:
|
||||
if: github.event_name == 'release'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Configure Git Credentials
|
||||
run: |
|
||||
git config user.name github-actions[bot]
|
||||
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
|
||||
- uses: actions/setup-python@v5
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
python-version: 3.x
|
||||
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build & push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
tags: |
|
||||
ghcr.io/${{ github.repository }}:latest
|
||||
ghcr.io/${{ github.repository }}:${{ github.ref_name }}
|
||||
|
||||
deploy-docs:
|
||||
if: github.event_name == 'release'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Configure Git credentials
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
- uses: actions/setup-python@v5
|
||||
with: { python-version: '3.x' }
|
||||
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
||||
- uses: actions/cache@v4
|
||||
with:
|
||||
key: mkdocs-material-${{ env.cache_id }}
|
||||
path: .cache
|
||||
restore-keys: |
|
||||
mkdocs-material-
|
||||
restore-keys: mkdocs-material-
|
||||
- run: pip install mkdocs-material mkdocs-awesome-pages-plugin mkdocs-glightbox
|
||||
- run: mkdocs gh-deploy --force
|
||||
32
.github/workflows/test-container-build.yml
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
name: test-container-build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ 'prep-*' ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
container-build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build & push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
# Tag with prep suffix to avoid conflicts with production
|
||||
tags: |
|
||||
ghcr.io/${{ github.repository }}:${{ github.ref_name }}
|
||||
5
.github/workflows/test.yml
vendored
@@ -42,6 +42,11 @@ jobs:
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Run linting
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv run pre-commit run --all-files
|
||||
|
||||
- name: Setup configuration file
|
||||
run: |
|
||||
cp config.example.yaml config.yaml
|
||||
|
||||
12
.gitignore
vendored
@@ -8,11 +8,20 @@
|
||||
talemate_env
|
||||
chroma
|
||||
config.yaml
|
||||
.cursor
|
||||
.claude
|
||||
|
||||
# uv
|
||||
.venv/
|
||||
templates/llm-prompt/user/*.jinja2
|
||||
templates/world-state/*.yaml
|
||||
tts/voice/piper/*.onnx
|
||||
tts/voice/piper/*.json
|
||||
tts/voice/kokoro/*.pt
|
||||
tts/voice/xtts2/*.wav
|
||||
tts/voice/chatterbox/*.wav
|
||||
tts/voice/f5tts/*.wav
|
||||
tts/voice/voice-library.json
|
||||
scenes/
|
||||
!scenes/infinity-quest-dynamic-scenario/
|
||||
!scenes/infinity-quest-dynamic-scenario/assets/
|
||||
@@ -21,4 +30,5 @@ scenes/
|
||||
!scenes/infinity-quest/assets/
|
||||
!scenes/infinity-quest/infinity-quest.json
|
||||
tts_voice_samples/*.wav
|
||||
third-party-docs/
|
||||
third-party-docs/
|
||||
legacy-state-reinforcements.yaml
|
||||
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
fail_fast: false
|
||||
exclude: |
|
||||
(?x)^(
|
||||
tests/data/.*
|
||||
|install-utils/.*
|
||||
)$
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.12.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
args: [ --fix ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
64
CONTRIBUTING.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Contributing to Talemate
|
||||
|
||||
## About This Project
|
||||
|
||||
Talemate is a **personal hobbyist project** that I maintain in my spare time. While I appreciate the community's interest and contributions, please understand that:
|
||||
|
||||
- This is primarily a passion project that I enjoy working on myself
|
||||
- I have limited time for code reviews and prefer to spend that time developing fixes or new features myself
|
||||
- Large contributions require significant review and testing time that takes away from my own development
|
||||
|
||||
For these reasons, I've established contribution guidelines that balance community involvement with my desire to actively develop the project myself.
|
||||
|
||||
## Contribution Policy
|
||||
|
||||
**I welcome small bugfix and small feature pull requests!** If you've found a bug and have a fix, or have a small feature improvement, I'd love to review it.
|
||||
|
||||
However, please note that **I am not accepting large refactors or major feature additions** at this time. This includes:
|
||||
- Major architectural changes
|
||||
- Large new features or significant functionality additions
|
||||
- Large-scale code reorganization
|
||||
- Breaking API changes
|
||||
- Features that would require significant maintenance
|
||||
|
||||
## What is accepted
|
||||
|
||||
✅ **Small bugfixes** - Fixes for specific, isolated bugs
|
||||
|
||||
✅ **Small features** - Minor improvements that don't break existing functionality
|
||||
|
||||
✅ **Documentation fixes** - Typo corrections, clarifications in existing docs
|
||||
|
||||
✅ **Minor dependency updates** - Security patches or minor version bumps
|
||||
|
||||
## What is not accepted
|
||||
|
||||
❌ **Major features** - Large new functionality or systems
|
||||
|
||||
❌ **Large refactors** - Code reorganization or architectural changes
|
||||
|
||||
❌ **Breaking changes** - Any changes that break existing functionality
|
||||
|
||||
❌ **Major dependency changes** - Framework upgrades or replacements
|
||||
|
||||
## Submitting a PR
|
||||
|
||||
If you'd like to submit a bugfix or small feature:
|
||||
|
||||
1. **Open an issue first** - Describe the bug you've found or feature you'd like to add
|
||||
2. **Keep it small** - Focus on one specific issue or small improvement
|
||||
3. **Follow existing code style** - Match the project's current patterns
|
||||
4. **Don't break existing functionality** - Ensure all existing tests pass
|
||||
5. **Include tests** - Add or update tests that verify your fix or feature
|
||||
6. **Update documentation** - If your changes affect behavior, update relevant docs
|
||||
|
||||
## Testing
|
||||
|
||||
Ensure all tests pass by running:
|
||||
```bash
|
||||
uv run pytest tests/ -p no:warnings
|
||||
```
|
||||
|
||||
## Questions?
|
||||
|
||||
If you're unsure whether your contribution would be welcome, please open an issue to discuss it first. This saves everyone time and ensures alignment with the project's direction.
|
||||
11
Dockerfile
@@ -35,18 +35,9 @@ COPY pyproject.toml uv.lock /app/
|
||||
# Copy the Python source code (needed for editable install)
|
||||
COPY ./src /app/src
|
||||
|
||||
# Create virtual environment and install dependencies
|
||||
# Create virtual environment and install dependencies (includes CUDA support via pyproject.toml)
|
||||
RUN uv sync
|
||||
|
||||
# Conditional PyTorch+CUDA install
|
||||
ARG CUDA_AVAILABLE=false
|
||||
RUN . /app/.venv/bin/activate && \
|
||||
if [ "$CUDA_AVAILABLE" = "true" ]; then \
|
||||
echo "Installing PyTorch with CUDA support..." && \
|
||||
uv pip uninstall torch torchaudio && \
|
||||
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
|
||||
fi
|
||||
|
||||
# Stage 3: Final image
|
||||
FROM python:3.11-slim
|
||||
|
||||
|
||||
20
docker-compose.manual.yml
Normal file
@@ -0,0 +1,20 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
talemate:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-8080}:8080"
|
||||
- "${BACKEND_PORT:-5050}:5050"
|
||||
volumes:
|
||||
- ./config.yaml:/app/config.yaml
|
||||
- ./scenes:/app/scenes
|
||||
- ./templates:/app/templates
|
||||
- ./chroma:/app/chroma
|
||||
- ./tts:/app/tts
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- PYTHONPATH=/app/src:$PYTHONPATH
|
||||
command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]
|
||||
@@ -2,11 +2,7 @@ version: '3.8'
|
||||
|
||||
services:
|
||||
talemate:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- CUDA_AVAILABLE=${CUDA_AVAILABLE:-false}
|
||||
image: ghcr.io/vegu-ai/talemate:latest
|
||||
ports:
|
||||
- "${FRONTEND_PORT:-8080}:8080"
|
||||
- "${BACKEND_PORT:-5050}:5050"
|
||||
@@ -15,6 +11,7 @@ services:
|
||||
- ./scenes:/app/scenes
|
||||
- ./templates:/app/templates
|
||||
- ./chroma:/app/chroma
|
||||
- ./tts:/app/tts
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- PYTHONPATH=/app/src:$PYTHONPATH
|
||||
|
||||
@@ -1,60 +1,63 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
def find_image_references(md_file):
|
||||
"""Find all image references in a markdown file."""
|
||||
with open(md_file, 'r', encoding='utf-8') as f:
|
||||
with open(md_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
pattern = r'!\[.*?\]\((.*?)\)'
|
||||
|
||||
pattern = r"!\[.*?\]\((.*?)\)"
|
||||
matches = re.findall(pattern, content)
|
||||
|
||||
|
||||
cleaned_paths = []
|
||||
for match in matches:
|
||||
path = match.lstrip('/')
|
||||
if 'img/' in path:
|
||||
path = path[path.index('img/') + 4:]
|
||||
path = match.lstrip("/")
|
||||
if "img/" in path:
|
||||
path = path[path.index("img/") + 4 :]
|
||||
# Only keep references to versioned images
|
||||
parts = os.path.normpath(path).split(os.sep)
|
||||
if len(parts) >= 2 and parts[0].replace('.', '').isdigit():
|
||||
if len(parts) >= 2 and parts[0].replace(".", "").isdigit():
|
||||
cleaned_paths.append(path)
|
||||
|
||||
|
||||
return cleaned_paths
|
||||
|
||||
|
||||
def scan_markdown_files(docs_dir):
|
||||
"""Recursively scan all markdown files in the docs directory."""
|
||||
md_files = []
|
||||
for root, _, files in os.walk(docs_dir):
|
||||
for file in files:
|
||||
if file.endswith('.md'):
|
||||
if file.endswith(".md"):
|
||||
md_files.append(os.path.join(root, file))
|
||||
return md_files
|
||||
|
||||
|
||||
def find_all_images(img_dir):
|
||||
"""Find all image files in version subdirectories."""
|
||||
image_files = []
|
||||
for root, _, files in os.walk(img_dir):
|
||||
# Get the relative path from img_dir to current directory
|
||||
rel_dir = os.path.relpath(root, img_dir)
|
||||
|
||||
|
||||
# Skip if we're in the root img directory
|
||||
if rel_dir == '.':
|
||||
if rel_dir == ".":
|
||||
continue
|
||||
|
||||
|
||||
# Check if the immediate parent directory is a version number
|
||||
parent_dir = rel_dir.split(os.sep)[0]
|
||||
if not parent_dir.replace('.', '').isdigit():
|
||||
if not parent_dir.replace(".", "").isdigit():
|
||||
continue
|
||||
|
||||
|
||||
for file in files:
|
||||
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.svg')):
|
||||
if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".svg")):
|
||||
rel_path = os.path.relpath(os.path.join(root, file), img_dir)
|
||||
image_files.append(rel_path)
|
||||
return image_files
|
||||
|
||||
|
||||
def grep_check_image(docs_dir, image_path):
|
||||
"""
|
||||
Check if versioned image is referenced anywhere using grep.
|
||||
@@ -65,33 +68,46 @@ def grep_check_image(docs_dir, image_path):
|
||||
parts = os.path.normpath(image_path).split(os.sep)
|
||||
version = parts[0] # e.g., "0.29.0"
|
||||
filename = parts[-1] # e.g., "world-state-suggestions-2.png"
|
||||
|
||||
|
||||
# For versioned images, require both version and filename to match
|
||||
version_pattern = f"{version}.*{filename}"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['grep', '-r', '-l', version_pattern, docs_dir],
|
||||
["grep", "-r", "-l", version_pattern, docs_dir],
|
||||
capture_output=True,
|
||||
text=True
|
||||
text=True,
|
||||
)
|
||||
if result.stdout.strip():
|
||||
print(f"Found reference to {image_path} with version pattern: {version_pattern}")
|
||||
print(
|
||||
f"Found reference to {image_path} with version pattern: {version_pattern}"
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during grep check for {image_path}: {e}")
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Find and optionally delete unused versioned images in MkDocs project')
|
||||
parser.add_argument('--docs-dir', type=str, required=True, help='Path to the docs directory')
|
||||
parser.add_argument('--img-dir', type=str, required=True, help='Path to the images directory')
|
||||
parser.add_argument('--delete', action='store_true', help='Delete unused images')
|
||||
parser.add_argument('--verbose', action='store_true', help='Show all found references and files')
|
||||
parser.add_argument('--skip-grep', action='store_true', help='Skip the additional grep validation')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Find and optionally delete unused versioned images in MkDocs project"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--docs-dir", type=str, required=True, help="Path to the docs directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--img-dir", type=str, required=True, help="Path to the images directory"
|
||||
)
|
||||
parser.add_argument("--delete", action="store_true", help="Delete unused images")
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Show all found references and files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-grep", action="store_true", help="Skip the additional grep validation"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert paths to absolute paths
|
||||
@@ -118,7 +134,7 @@ def main():
|
||||
print("\nAll versioned image references found in markdown:")
|
||||
for img in sorted(used_images):
|
||||
print(f"- {img}")
|
||||
|
||||
|
||||
print("\nAll versioned images in directory:")
|
||||
for img in sorted(all_images):
|
||||
print(f"- {img}")
|
||||
@@ -133,9 +149,11 @@ def main():
|
||||
for img in unused_images:
|
||||
if not grep_check_image(docs_dir, img):
|
||||
actually_unused.add(img)
|
||||
|
||||
|
||||
if len(actually_unused) != len(unused_images):
|
||||
print(f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!")
|
||||
print(
|
||||
f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!"
|
||||
)
|
||||
unused_images = actually_unused
|
||||
|
||||
# Report findings
|
||||
@@ -148,7 +166,7 @@ def main():
|
||||
print("\nUnused versioned images:")
|
||||
for img in sorted(unused_images):
|
||||
print(f"- {img}")
|
||||
|
||||
|
||||
if args.delete:
|
||||
print("\nDeleting unused versioned images...")
|
||||
for img in unused_images:
|
||||
@@ -162,5 +180,6 @@ def main():
|
||||
else:
|
||||
print("\nNo unused versioned images found!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -4,12 +4,12 @@ from talemate.events import GameLoopEvent
|
||||
import talemate.emit.async_signals
|
||||
from talemate.emit import emit
|
||||
|
||||
|
||||
@register()
|
||||
class TestAgent(Agent):
|
||||
|
||||
agent_type = "test"
|
||||
verbose_name = "Test"
|
||||
|
||||
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
@@ -20,7 +20,7 @@ class TestAgent(Agent):
|
||||
description="Test",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
@@ -36,7 +36,7 @@ class TestAgent(Agent):
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
|
||||
async def on_game_loop(self, emission: GameLoopEvent):
|
||||
"""
|
||||
Called on the beginning of every game loop
|
||||
@@ -45,4 +45,8 @@ class TestAgent(Agent):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
emit("status", status="info", message="Annoying you with a test message every game loop.")
|
||||
emit(
|
||||
"status",
|
||||
status="info",
|
||||
message="Annoying you with a test message every game loop.",
|
||||
)
|
||||
|
||||
@@ -19,14 +19,17 @@ from talemate.config import Client as BaseClientConfig
|
||||
|
||||
log = structlog.get_logger("talemate.client.runpod_vllm")
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 4096
|
||||
model: str = ""
|
||||
runpod_id: str = ""
|
||||
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
runpod_id: str = ""
|
||||
|
||||
|
||||
@register()
|
||||
class RunPodVLLMClient(ClientBase):
|
||||
client_type = "runpod_vllm"
|
||||
@@ -49,7 +52,6 @@ class RunPodVLLMClient(ClientBase):
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, model=None, runpod_id=None, **kwargs):
|
||||
self.model_name = model
|
||||
self.runpod_id = runpod_id
|
||||
@@ -59,12 +61,10 @@ class RunPodVLLMClient(ClientBase):
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
|
||||
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
|
||||
|
||||
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
@@ -88,32 +88,37 @@ class RunPodVLLMClient(ClientBase):
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
|
||||
|
||||
run_request = await endpoint.run({
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
|
||||
run_request = await endpoint.run(
|
||||
{
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
}
|
||||
# "parameters": parameters
|
||||
}
|
||||
#"parameters": parameters
|
||||
})
|
||||
|
||||
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
|
||||
)
|
||||
|
||||
while (await run_request.status()) not in [
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
]:
|
||||
status = await run_request.status()
|
||||
log.debug("generate", status=status)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
status = await run_request.status()
|
||||
|
||||
|
||||
log.debug("generate", status=status)
|
||||
|
||||
|
||||
response = await run_request.output()
|
||||
|
||||
|
||||
log.debug("generate", response=response)
|
||||
|
||||
|
||||
return response["choices"][0]["tokens"][0]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
|
||||
@@ -9,6 +9,7 @@ class Defaults(pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:1234"
|
||||
max_token_length: int = 4096
|
||||
|
||||
|
||||
@register()
|
||||
class TestClient(ClientBase):
|
||||
client_type = "test"
|
||||
@@ -22,14 +23,13 @@ class TestClient(ClientBase):
|
||||
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
|
||||
"""
|
||||
Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.
|
||||
|
||||
|
||||
This method is called before the prompt is sent to the client, and it allows the client to remove
|
||||
any parameters that it doesn't support.
|
||||
"""
|
||||
|
||||
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
@@ -41,11 +41,10 @@ class TestClient(ClientBase):
|
||||
del parameters[key]
|
||||
|
||||
async def get_model_name(self):
|
||||
|
||||
"""
|
||||
This should return the name of the model that is being used.
|
||||
"""
|
||||
|
||||
|
||||
return "Mock test model"
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
|
||||
@@ -27,10 +27,10 @@ uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234
|
||||
|
||||
### Letting the frontend know about the new host and port
|
||||
|
||||
Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL`.
|
||||
Copy `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.production.local` and edit the `VITE_TALEMATE_BACKEND_WEBSOCKET_URL`.
|
||||
|
||||
```env
|
||||
VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
|
||||
VITE_TALEMATE_BACKEND_WEBSOCKET_URL=ws://localhost:1234
|
||||
```
|
||||
|
||||
Next rebuild the frontend.
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
!!! example "Experimental"
|
||||
Talemate through docker has not received a lot of testing from me, so please let me know if you encounter any issues.
|
||||
|
||||
You can do so by creating an issue on the [:material-github: GitHub repository](https://github.com/vegu-ai/talemate)
|
||||
|
||||
## Quick install instructions
|
||||
|
||||
1. `git clone https://github.com/vegu-ai/talemate.git`
|
||||
1. `cd talemate`
|
||||
1. copy config file
|
||||
1. linux: `cp config.example.yaml config.yaml`
|
||||
1. windows: `copy config.example.yaml config.yaml`
|
||||
1. If your host has a CUDA compatible Nvidia GPU
|
||||
1. Windows (via PowerShell): `$env:CUDA_AVAILABLE="true"; docker compose up`
|
||||
1. Linux: `CUDA_AVAILABLE=true docker compose up`
|
||||
1. If your host does **NOT** have a CUDA compatible Nvidia GPU
|
||||
1. Windows: `docker compose up`
|
||||
1. Linux: `docker compose up`
|
||||
1. windows: `copy config.example.yaml config.yaml` (or just copy the file and rename it via the file explorer)
|
||||
1. `docker compose up`
|
||||
1. Navigate your browser to http://localhost:8080
|
||||
|
||||
!!! info "Pre-built Images"
|
||||
The default setup uses pre-built images from GitHub Container Registry that include CUDA support by default. To manually build the container instead, use `docker compose -f docker-compose.manual.yml up --build`.
|
||||
|
||||
!!! note
|
||||
When connecting local APIs running on the hostmachine (e.g. text-generation-webui), you need to use `host.docker.internal` as the hostname.
|
||||
|
||||
BIN
docs/img/0.32.0/add-chatterbox-voice.png
Normal file
|
After Width: | Height: | Size: 35 KiB |
BIN
docs/img/0.32.0/add-elevenlabs-voice.png
Normal file
|
After Width: | Height: | Size: 29 KiB |
BIN
docs/img/0.32.0/add-f5tts-voice.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
docs/img/0.32.0/character-voice-assignment.png
Normal file
|
After Width: | Height: | Size: 65 KiB |
BIN
docs/img/0.32.0/chatterbox-api-settings.png
Normal file
|
After Width: | Height: | Size: 54 KiB |
BIN
docs/img/0.32.0/chatterbox-parameters.png
Normal file
|
After Width: | Height: | Size: 18 KiB |
BIN
docs/img/0.32.0/client-reasoning-2.png
Normal file
|
After Width: | Height: | Size: 18 KiB |
BIN
docs/img/0.32.0/client-reasoning.png
Normal file
|
After Width: | Height: | Size: 75 KiB |
BIN
docs/img/0.32.0/elevenlabs-api-settings.png
Normal file
|
After Width: | Height: | Size: 61 KiB |
BIN
docs/img/0.32.0/elevenlabs-copy-voice-id.png
Normal file
|
After Width: | Height: | Size: 9.6 KiB |
BIN
docs/img/0.32.0/f5tts-api-settings.png
Normal file
|
After Width: | Height: | Size: 72 KiB |
BIN
docs/img/0.32.0/f5tts-parameters.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
docs/img/0.32.0/google-tts-api-settings.png
Normal file
|
After Width: | Height: | Size: 63 KiB |
BIN
docs/img/0.32.0/kokoro-mixer.png
Normal file
|
After Width: | Height: | Size: 33 KiB |
BIN
docs/img/0.32.0/openai-tts-api-settings.png
Normal file
|
After Width: | Height: | Size: 61 KiB |
BIN
docs/img/0.32.0/voice-agent-settings.png
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
docs/img/0.32.0/voice-agent-status-characters.png
Normal file
|
After Width: | Height: | Size: 3.0 KiB |
BIN
docs/img/0.32.0/voice-library-access.png
Normal file
|
After Width: | Height: | Size: 9.3 KiB |
BIN
docs/img/0.32.0/voice-library-api-status.png
Normal file
|
After Width: | Height: | Size: 6.6 KiB |
BIN
docs/img/0.32.0/voice-library-interface.png
Normal file
|
After Width: | Height: | Size: 142 KiB |
58
docs/user-guide/agents/voice/chatterbox.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Chatterbox
|
||||
|
||||
Local zero shot voice cloning from .wav files.
|
||||
|
||||

|
||||
|
||||
##### Device
|
||||
|
||||
Auto-detects best available option
|
||||
|
||||
##### Model
|
||||
|
||||
Default Chatterbox model optimized for speed
|
||||
|
||||
##### Chunk size
|
||||
|
||||
Split text into chunks of this size. Smaller values will increase responsiveness at the cost of lost context between chunks. (Stuff like appropriate inflection, etc.). 0 = no chunking
|
||||
|
||||
## Adding Chatterbox Voices
|
||||
|
||||
### Voice Requirements
|
||||
|
||||
Chatterbox voices require:
|
||||
|
||||
- Reference audio file (.wav format, 5-15 seconds optimal)
|
||||
- Clear speech with minimal background noise
|
||||
- Single speaker throughout the sample
|
||||
|
||||
### Creating a Voice
|
||||
|
||||
1. Open the Voice Library
|
||||
2. Click **:material-plus: New**
|
||||
3. Select "Chatterbox" as the provider
|
||||
4. Configure the voice:
|
||||
|
||||

|
||||
|
||||
**Label:** Descriptive name (e.g., "Marcus - Deep Male")
|
||||
|
||||
**Voice ID / Upload File** Upload a .wav file containing the voice sample. The uploaded reference audio will also be the voice ID.
|
||||
|
||||
**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)
|
||||
|
||||
**Tags:** Add descriptive tags for organization
|
||||
|
||||
**Extra voice parameters**
|
||||
|
||||
There exist some optional parameters that can be set here on a per voice level.
|
||||
|
||||

|
||||
|
||||
##### Exaggeration Level
|
||||
|
||||
Exaggeration (Neutral = 0.5, extreme values can be unstable). Higher exaggeration tends to speed up speech; reducing cfg helps compensate with slower, more deliberate pacing.
|
||||
|
||||
##### CFG / Pace
|
||||
|
||||
If the reference speaker has a fast speaking style, lowering cfg to around 0.3 can improve pacing.
|
||||
@@ -1,7 +1,41 @@
|
||||
# ElevenLabs
|
||||
|
||||
If you have not configured the ElevenLabs TTS API, the voice agent will show that the API key is missing.
|
||||
Professional voice synthesis with voice cloning capabilities using ElevenLabs API.
|
||||
|
||||

|
||||

|
||||
|
||||
See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on how to set up the API key.
|
||||
## API Setup
|
||||
|
||||
ElevenLabs requires an API key. See the [ElevenLabs API setup](/talemate/user-guide/apis/elevenlabs/) for instructions on obtaining and setting an API key.
|
||||
|
||||
## Configuration
|
||||
|
||||
**Model:** Select from available ElevenLabs models
|
||||
|
||||
!!! warning "Voice Limits"
|
||||
Your ElevenLabs subscription allows you to maintain a set number of voices (10 for the cheapest plan). Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button. It is recommended to test voices via their voice library instead.
|
||||
|
||||
## Adding ElevenLabs Voices
|
||||
|
||||
### Getting Voice IDs
|
||||
|
||||
1. Go to [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab) to view your voices
|
||||
2. Find or create the voice you want to use
|
||||
3. Click "More Actions" -> "Copy Voice ID" for the desired voice
|
||||
|
||||

|
||||
|
||||
### Creating a Voice in Talemate
|
||||
|
||||

|
||||
|
||||
1. Open the Voice Library
|
||||
2. Click "Add Voice"
|
||||
3. Select "ElevenLabs" as the provider
|
||||
4. Configure the voice:
|
||||
|
||||
**Label:** Descriptive name for the voice
|
||||
|
||||
**Provider ID:** Paste the ElevenLabs voice ID you copied
|
||||
|
||||
**Tags:** Add descriptive tags for organization
|
||||
78
docs/user-guide/agents/voice/f5tts.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# F5-TTS
|
||||
|
||||
Local zero shot voice cloning from .wav files.
|
||||
|
||||

|
||||
|
||||
##### Device
|
||||
Auto-detects best available option (GPU preferred)
|
||||
|
||||
##### Model
|
||||
|
||||
- F5TTS_v1_Base (default, most recent model)
|
||||
- F5TTS_Base
|
||||
- E2TTS_Base
|
||||
|
||||
##### NFE Step
|
||||
|
||||
Number of steps to generate the voice. Higher values result in more detailed voices.
|
||||
|
||||
##### Chunk size
|
||||
|
||||
Split text into chunks of this size. Smaller values will increase responsiveness at the cost of lost context between chunks. (Stuff like appropriate inflection, etc.). 0 = no chunking
|
||||
|
||||
##### Replace exclamation marks
|
||||
|
||||
If checked, exclamation marks will be replaced with periods. This is recommended for `F5TTS_v1_Base` since it seems to over exaggerate exclamation marks.
|
||||
|
||||
## Adding F5-TTS Voices
|
||||
|
||||
### Voice Requirements
|
||||
|
||||
F5-TTS voices require:
|
||||
|
||||
- Reference audio file (.wav format, 10-30 seconds)
|
||||
- Clear speech with minimal background noise
|
||||
- Single speaker throughout the sample
|
||||
- Reference text (optional but recommended)
|
||||
|
||||
### Creating a Voice
|
||||
|
||||
1. Open the Voice Library
|
||||
2. Click "Add Voice"
|
||||
3. Select "F5-TTS" as the provider
|
||||
4. Configure the voice:
|
||||
|
||||

|
||||
|
||||
**Label:** Descriptive name (e.g., "Emma - Calm Female")
|
||||
|
||||
**Voice ID / Upload File** Upload a .wav file containing the **reference audio** voice sample. The uploaded reference audio will also be the voice ID.
|
||||
|
||||
- Use 6-10 second samples (longer doesn't improve quality)
|
||||
- Ensure clear speech with minimal background noise
|
||||
- Record at natural speaking pace
|
||||
|
||||
**Reference Text:** Enter the exact text spoken in the reference audio for improved quality
|
||||
|
||||
- Enter exactly what is spoken in the reference audio
|
||||
- Include proper punctuation and capitalization
|
||||
- Improves voice cloning accuracy significantly
|
||||
|
||||
**Speed:** Adjust playback speed (0.5 to 2.0, default 1.0)
|
||||
|
||||
**Tags:** Add descriptive tags (gender, age, style) for organization
|
||||
|
||||
**Extra voice parameters**
|
||||
|
||||
There exist some optional parameters that can be set here on a per voice level.
|
||||
|
||||

|
||||
|
||||
##### Speed
|
||||
|
||||
Allows you to adjust the speed of the voice.
|
||||
|
||||
##### CFG Strength
|
||||
|
||||
A higher CFG strength generally leads to more faithful reproduction of the input text, while a lower CFG strength can result in more varied or creative speech output, potentially at the cost of text-to-speech accuracy.
|
||||
15
docs/user-guide/agents/voice/google.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Google Gemini-TTS
|
||||
|
||||
Google Gemini-TTS provides access to Google's text-to-speech service.
|
||||
|
||||
## API Setup
|
||||
|
||||
Google Gemini-TTS requires a Google Cloud API key.
|
||||
|
||||
See the [Google Cloud API setup](/talemate/user-guide/apis/google/) for instructions on obtaining an API key.
|
||||
|
||||
## Configuration
|
||||
|
||||

|
||||
|
||||
**Model:** Select from available Google TTS models
|
||||
@@ -1,6 +1,26 @@
|
||||
# Overview
|
||||
|
||||
Talemate supports Text-to-Speech (TTS) functionality, allowing users to convert text into spoken audio. This document outlines the steps required to configure TTS for Talemate using different providers, including ElevenLabs and a local TTS API.
|
||||
In 0.32.0 Talemate's TTS (Text-to-Speech) agent has been completely refactored to provide advanced voice capabilities including per-character voice assignment, speaker separation, and support for multiple local and remote APIs. The voice system now includes a comprehensive voice library for managing and organizing voices across all supported providers.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Per-character voice assignment** - Each character can have their own unique voice
|
||||
- **Speaker separation** - Automatic detection and separation of dialogue from narration
|
||||
- **Voice library management** - Centralized management of all voices across providers
|
||||
- **Multiple API support** - Support for both local and remote TTS providers
|
||||
- **Director integration** - Automatic voice assignment for new characters
|
||||
|
||||
## Supported APIs
|
||||
|
||||
### Local APIs
|
||||
- **Kokoro** - Fastest generation with predefined voice models and mixing
|
||||
- **F5-TTS** - Fast voice cloning with occasional mispronunciations
|
||||
- **Chatterbox** - High-quality voice cloning (slower generation)
|
||||
|
||||
### Remote APIs
|
||||
- **ElevenLabs** - Professional voice synthesis with voice cloning
|
||||
- **Google Gemini-TTS** - Google's text-to-speech service
|
||||
- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
|
||||
|
||||
## Enable the Voice agent
|
||||
|
||||
@@ -12,28 +32,30 @@ If your voice agent is disabled - indicated by the grey dot next to the agent -
|
||||
|
||||
 
|
||||
|
||||
!!! note "Ctrl click to toggle agent"
|
||||
You can use Ctrl click to toggle the agent on and off.
|
||||
|
||||
!!! abstract "Next: Connect to a TTS api"
|
||||
Next you need to decide which service / api to use for audio generation and configure the voice agent accordingly.
|
||||
## Voice Library Management
|
||||
|
||||
- [OpenAI](openai.md)
|
||||
- [ElevenLabs](elevenlabs.md)
|
||||
- [Local TTS](local_tts.md)
|
||||
Voices are managed through the Voice Library, accessible from the main application bar. The Voice Library allows you to:
|
||||
|
||||
You can also find more information about the various settings [here](settings.md).
|
||||
- Add and organize voices from all supported providers
|
||||
- Assign voices to specific characters
|
||||
- Create mixed voices (Kokoro)
|
||||
- Manage both global and scene-specific voice libraries
|
||||
|
||||
## Select a voice
|
||||
See the [Voice Library Guide](voice-library.md) for detailed instructions.
|
||||
|
||||

|
||||
## Character Voice Assignment
|
||||
|
||||
Click on the agent to open the agent settings.
|
||||

|
||||
|
||||
Then click on the `Narrator Voice` dropdown and select a voice.
|
||||
Characters can have individual voices assigned through the Voice Library. When a character has a voice assigned:
|
||||
|
||||

|
||||
1. Their dialogue will use their specific voice
|
||||
2. The narrator voice is used for exposition in their messages (with speaker separation enabled)
|
||||
3. If their assigned voice's API is not available, it falls back to the narrator voice
|
||||
|
||||
The selection is saved automatically, click anywhere outside the agent window to close it.
|
||||
The Voice agent status will show all assigned character voices and their current status.
|
||||
|
||||
The Voice agent should now show that the voice is selected and be ready to use.
|
||||
|
||||

|
||||

|
||||
55
docs/user-guide/agents/voice/kokoro.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Kokoro
|
||||
|
||||
Kokoro provides predefined voice models and voice mixing capabilities for creating custom voices.
|
||||
|
||||
## Using Predefined Voices
|
||||
|
||||
Kokoro comes with built-in voice models that are ready to use immediately
|
||||
|
||||
Available predefined voices include various male and female voices with different characteristics.
|
||||
|
||||
## Creating Mixed Voices
|
||||
|
||||
Kokor allows you to mix voices together to create a new voice.
|
||||
|
||||
### Voice Mixing Interface
|
||||
|
||||
|
||||
To create a mixed voice:
|
||||
|
||||
1. Open the Voice Library
|
||||
2. Click ":material-plus: New"
|
||||
3. Select "Kokoro" as the provider
|
||||
4. Choose ":material-tune:Mixer" option
|
||||
5. Configure the mixed voice:
|
||||
|
||||

|
||||
|
||||
|
||||
**Label:** Descriptive name for the mixed voice
|
||||
|
||||
**Base Voices:** Select 2-4 existing Kokoro voices to combine
|
||||
|
||||
**Weights:** Set the influence of each voice (0.1 to 1.0)
|
||||
|
||||
**Tags:** Descriptive tags for organization
|
||||
|
||||
### Weight Configuration
|
||||
|
||||
Each selected voice can have its weight adjusted:
|
||||
|
||||
- Higher weights make that voice more prominent in the mix
|
||||
- Lower weights make that voice more subtle
|
||||
- Total weights need to sum to 1.0
|
||||
- Experiment with different combinations to achieve desired results
|
||||
|
||||
### Saving Mixed Voices
|
||||
|
||||
Once configured click "Add Voice", mixed voices are saved to your voice library and can be:
|
||||
|
||||
- Assigned to characters
|
||||
- Used as narrator voices
|
||||
|
||||
just like any other voice.
|
||||
|
||||
Saving a mixed cvoice may take a moment to complete.
|
||||
@@ -1,53 +0,0 @@
|
||||
# Local TTS
|
||||
|
||||
!!! warning
|
||||
This has not been tested in a while and may not work as expected. It will likely be replaced with something different in the future. If this approach is currently broken its likely to remain so until it is replaced.
|
||||
|
||||
For running a local TTS API, Talemate requires specific dependencies to be installed.
|
||||
|
||||
### Windows Installation
|
||||
|
||||
Run `install-local-tts.bat` to install the necessary requirements.
|
||||
|
||||
### Linux Installation
|
||||
|
||||
Execute the following command:
|
||||
|
||||
```bash
|
||||
pip install TTS
|
||||
```
|
||||
|
||||
### Model and Device Configuration
|
||||
|
||||
1. Choose a TTS model from the [Coqui TTS model list](https://github.com/coqui-ai/TTS).
|
||||
2. Decide whether to use `cuda` or `cpu` for the device setting.
|
||||
3. The first time you run TTS through the local API, it will download the specified model. Please note that this may take some time, and the download progress will be visible in the Talemate backend output.
|
||||
|
||||
Example configuration snippet:
|
||||
|
||||
```yaml
|
||||
tts:
|
||||
device: cuda # or 'cpu'
|
||||
model: tts_models/multilingual/multi-dataset/xtts_v2
|
||||
```
|
||||
|
||||
### Voice Samples Configuration
|
||||
|
||||
Configure voice samples by setting the `value` field to the path of a .wav file voice sample. Official samples can be downloaded from [Coqui XTTS-v2 samples](https://huggingface.co/coqui/XTTS-v2/tree/main/samples).
|
||||
|
||||
Example configuration snippet:
|
||||
|
||||
```yaml
|
||||
tts:
|
||||
voices:
|
||||
- label: English Male
|
||||
value: path/to/english_male.wav
|
||||
- label: English Female
|
||||
value: path/to/english_female.wav
|
||||
```
|
||||
|
||||
## Saving the Configuration
|
||||
|
||||
After configuring the `config.yaml` file, save your changes. Talemate will use the updated settings the next time it starts.
|
||||
|
||||
For more detailed information on configuring Talemate, refer to the `config.py` file in the Talemate source code and the `config.example.yaml` file for a barebone configuration example.
|
||||
@@ -8,16 +8,12 @@ See the [OpenAI API setup](/apis/openai.md) for instructions on how to set up th
|
||||
|
||||
## Settings
|
||||
|
||||

|
||||

|
||||
|
||||
##### Model
|
||||
|
||||
Which model to use for generation.
|
||||
|
||||
- GPT-4o Mini TTS
|
||||
- TTS-1
|
||||
- TTS-1 HD
|
||||
|
||||
!!! quote "OpenAI API documentation on quality"
|
||||
For real-time applications, the standard tts-1 model provides the lowest latency but at a lower quality than the tts-1-hd model. Due to the way the audio is generated, tts-1 is likely to generate content that has more static in certain situations than tts-1-hd. In some cases, the audio may not have noticeable differences depending on your listening device and the individual person.
|
||||
|
||||
Generally i have found that HD is fast enough for talemate, so this is the default.
|
||||
- TTS-1 HD
|
||||
@@ -1,36 +1,65 @@
|
||||
# Settings
|
||||
|
||||

|
||||

|
||||
|
||||
##### API
|
||||
##### Enabled APIs
|
||||
|
||||
The TTS API to use for voice generation.
|
||||
Select which TTS APIs to enable. You can enable multiple APIs simultaneously:
|
||||
|
||||
- OpenAI
|
||||
- ElevenLabs
|
||||
- Local TTS
|
||||
- **Kokoro** - Fastest generation with predefined voice models and mixing
|
||||
- **F5-TTS** - Fast voice cloning with occasional mispronunciations
|
||||
- **Chatterbox** - High-quality voice cloning (slower generation)
|
||||
- **ElevenLabs** - Professional voice synthesis with voice cloning
|
||||
- **Google Gemini-TTS** - Google's text-to-speech service
|
||||
- **OpenAI** - OpenAI's TTS-1 and TTS-1-HD models
|
||||
|
||||
!!! note "Multi-API Support"
|
||||
You can enable multiple APIs and assign different voices from different providers to different characters. The system will automatically route voice generation to the appropriate API based on the voice assignment.
|
||||
|
||||
##### Narrator Voice
|
||||
|
||||
The voice to use for narration. Each API will come with its own set of voices.
|
||||
The default voice used for narration and as a fallback for characters without assigned voices.
|
||||
|
||||

|
||||
The dropdown shows all available voices from all enabled APIs, with the format: "Voice Name (Provider)"
|
||||
|
||||
!!! note "Local TTS"
|
||||
For local TTS, you will have to provide voice samples yourself. See [Local TTS Instructions](local_tts.md) for more information.
|
||||
!!! info "Voice Management"
|
||||
Voices are managed through the Voice Library, accessible from the main application bar. Adding, removing, or modifying voices should be done through the Voice Library interface.
|
||||
|
||||
##### Generate for player
|
||||
##### Speaker Separation
|
||||
|
||||
Whether to generate voice for the player. If enabled, whenever the player speaks, the voice agent will generate audio for them.
|
||||
Controls how dialogue is separated from exposition in messages:
|
||||
|
||||
##### Generate for NPCs
|
||||
- **No separation** - Character messages use character voice entirely, narrator messages use narrator voice
|
||||
- **Simple** - Basic separation of dialogue from exposition using punctuation analysis, with exposition being read by the narrator voice
|
||||
- **Mixed** - Enables AI assisted separation for narrator messages and simple separation for character messages
|
||||
- **AI assisted** - AI assisted separation for both narrator and character messages
|
||||
|
||||
Whether to generate voice for NPCs. If enabled, whenever a non player character speaks, the voice agent will generate audio for them.
|
||||
!!! warning "AI Assisted Performance"
|
||||
AI-assisted speaker separation sends additional prompts to your LLM, which may impact response time and API costs.
|
||||
|
||||
##### Generate for narration
|
||||
##### Auto-generate for player
|
||||
|
||||
Whether to generate voice for narration. If enabled, whenever the narrator speaks, the voice agent will generate audio for them.
|
||||
Generate voice automatically for player messages
|
||||
|
||||
##### Split generation
|
||||
##### Auto-generate for AI characters
|
||||
|
||||
If enabled, the voice agent will generate audio in chunks, allowing for faster generation. This does however cause it lose context between chunks, and inflection may not be as good.
|
||||
Generate voice automatically for NPC/AI character messages
|
||||
|
||||
##### Auto-generate for narration
|
||||
|
||||
Generate voice automatically for narrator messages
|
||||
|
||||
##### Auto-generate for context investigation
|
||||
|
||||
Generate voice automatically for context investigation messages
|
||||
|
||||
## Advanced Settings
|
||||
|
||||
Advanced settings are configured per-API and can be found in the respective API configuration sections:
|
||||
|
||||
- **Chunk size** - Maximum text length per generation request
|
||||
- **Model selection** - Choose specific models for each API
|
||||
- **Voice parameters** - Provider-specific voice settings
|
||||
|
||||
!!! tip "Performance Optimization"
|
||||
Each API has different optimal chunk sizes and parameters. The system automatically handles chunking and queuing for optimal performance across all enabled APIs.
|
||||
156
docs/user-guide/agents/voice/voice-library.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# Voice Library
|
||||
|
||||
The Voice Library is the central hub for managing all voices across all TTS providers in Talemate. It provides a unified interface for organizing, creating, and assigning voices to characters.
|
||||
|
||||
## Accessing the Voice Library
|
||||
|
||||
The Voice Library can be accessed from the main application bar at the top of the Talemate interface.
|
||||
|
||||

|
||||
|
||||
Click the voice icon to open the Voice Library dialog.
|
||||
|
||||
!!! note "Voice agent needs to be enabled"
|
||||
The Voice agent needs to be enabled for the voice library to be available.
|
||||
|
||||
## Voice Library Interface
|
||||
|
||||

|
||||
|
||||
The Voice Library interface consists of:
|
||||
|
||||
### Scope Tabs
|
||||
|
||||
- **Global** - Voices available across all scenes
|
||||
- **Scene** - Voices specific to the current scene (only visible when a scene is loaded)
|
||||
- **Characters** - Character voice assignments for the current scene (only visible when a scene is loaded)
|
||||
|
||||
### API Status
|
||||
|
||||
The toolbar shows the status of all TTS APIs:
|
||||
|
||||
- **Green** - API is enabled and ready
|
||||
- **Orange** - API is enabled but not configured
|
||||
- **Red** - API has configuration issues
|
||||
- **Gray** - API is disabled
|
||||
|
||||

|
||||
|
||||
## Managing Voices
|
||||
|
||||
### Global Voice Library
|
||||
|
||||
The global voice library contains voices that are available across all scenes. These include:
|
||||
|
||||
- Default voices provided by each TTS provider
|
||||
- Custom voices you've added
|
||||
|
||||
#### Adding New Voices
|
||||
|
||||
To add a new voice:
|
||||
|
||||
1. Click the "+ New" button
|
||||
2. Select the TTS provider
|
||||
3. Configure the voice parameters:
|
||||
- **Label** - Display name for the voice
|
||||
- **Provider ID** - Provider-specific identifier
|
||||
- **Tags** - Free-form descriptive tags you define (gender, age, style, etc.)
|
||||
- **Parameters** - Provider-specific settings
|
||||
|
||||
Check the provider specific documentation for more information on how to configure the voice.
|
||||
|
||||
#### Voice Types by Provider
|
||||
|
||||
**F5-TTS & Chatterbox:**
|
||||
|
||||
- Upload .wav reference files for voice cloning
|
||||
- Specify reference text for better quality
|
||||
- Adjust speed and other parameters
|
||||
|
||||
**Kokoro:**
|
||||
|
||||
- Select from predefined voice models
|
||||
- Create mixed voices by combining multiple models
|
||||
- Adjust voice mixing weights
|
||||
|
||||
**ElevenLabs:**
|
||||
|
||||
- Select from available ElevenLabs voices
|
||||
- Configure voice settings and stability
|
||||
- Use custom cloned voices from your ElevenLabs account
|
||||
|
||||
**OpenAI:**
|
||||
|
||||
- Choose from available OpenAI voice models
|
||||
- Configure model (GPT-4o Mini TTS, TTS-1, TTS-1-HD)
|
||||
|
||||
**Google Gemini-TTS:**
|
||||
|
||||
- Select from Google's voice models
|
||||
- Configure language and gender settings
|
||||
|
||||
### Scene Voice Library
|
||||
|
||||
Scene-specific voices are only available within the current scene. This is useful for:
|
||||
|
||||
- Scene-specific characters
|
||||
- Temporary voice experiments
|
||||
- Custom voices for specific scenarios
|
||||
|
||||
Scene voices are saved with the scene and will be available when the scene is loaded.
|
||||
|
||||
## Character Voice Assignment
|
||||
|
||||
### Automatic Assignment
|
||||
|
||||
The Director agent can automatically assign voices to new characters based on:
|
||||
|
||||
- Character tags and attributes
|
||||
- Voice tags matching character personality
|
||||
- Available voices in the voice library
|
||||
|
||||
This feature can be enabled in the Director agent settings.
|
||||
|
||||
### Manual Assignment
|
||||
|
||||

|
||||
|
||||
To manually assign a voice to a character:
|
||||
|
||||
1. Go to the "Characters" tab in the Voice Library
|
||||
2. Find the character in the list
|
||||
3. Click the voice dropdown for that character
|
||||
4. Select a voice from the available options
|
||||
5. The assignment is saved automatically
|
||||
|
||||
### Character Voice Status
|
||||
|
||||
The character list shows:
|
||||
|
||||
- **Character name**
|
||||
- **Currently assigned voice** (if any)
|
||||
- **Voice status** - whether the voice's API is available
|
||||
- **Quick assignment controls**
|
||||
|
||||
## Voice Tags and Organization
|
||||
|
||||
### Tagging System
|
||||
|
||||
Voices can be tagged with any descriptive attributes you choose. Tags are completely free-form and user-defined. Common examples include:
|
||||
|
||||
- **Gender**: male, female, neutral
|
||||
- **Age**: young, mature, elderly
|
||||
- **Style**: calm, energetic, dramatic, mysterious
|
||||
- **Quality**: deep, high, raspy, smooth
|
||||
- **Character types**: narrator, villain, hero, comic relief
|
||||
- **Custom tags**: You can create any tags that help you organize your voices
|
||||
|
||||
### Filtering and Search
|
||||
|
||||
Use the search bar to filter voices by:
|
||||
- Voice label/name
|
||||
- Provider
|
||||
- Tags
|
||||
- Character assignments
|
||||
|
||||
This makes it easy to find the right voice for specific characters or situations.
|
||||
82
docs/user-guide/clients/reasoning.md
Normal file
@@ -0,0 +1,82 @@
|
||||
# Reasoning Model Support
|
||||
|
||||
Talemate supports reasoning models that can perform step-by-step thinking before generating their final response. This feature allows models to work through complex problems internally before providing an answer.
|
||||
|
||||
## Enabling Reasoning Support
|
||||
|
||||
To enable reasoning support for a client:
|
||||
|
||||
1. Open the **Clients** dialog from the main toolbar
|
||||
2. Select the client you want to configure
|
||||
3. Navigate to the **Reasoning** tab in the client configuration
|
||||
|
||||

|
||||
|
||||
4. Check the **Enable Reasoning** checkbox
|
||||
|
||||
## Configuring Reasoning Tokens
|
||||
|
||||
Once reasoning is enabled, you can configure the **Reasoning Tokens** setting using the slider:
|
||||
|
||||

|
||||
|
||||
### Recommended Token Amounts
|
||||
|
||||
**For local reasoning models:** Use a high token allocation (recommended: 4096 tokens) to give the model sufficient space for complex reasoning.
|
||||
|
||||
**For remote APIs:** Start with lower amounts (512-1024 tokens) and adjust based on your needs and token costs.
|
||||
|
||||
### Token Allocation Behavior
|
||||
|
||||
The behavior of the reasoning tokens setting depends on your API provider:
|
||||
|
||||
**For APIs that support direct reasoning token specification:**
|
||||
|
||||
- The specified tokens will be allocated specifically for reasoning
|
||||
- The model will use these tokens for internal thinking before generating the response
|
||||
|
||||
**For APIs that do NOT support reasoning token specification:**
|
||||
|
||||
- The tokens are added as extra allowance to the response token limit for ALL requests
|
||||
- This may lead to more verbose responses than usual since Talemate normally uses response token limits to control verbosity
|
||||
|
||||
!!! warning "Increased Verbosity"
|
||||
For providers without direct reasoning token support, enabling reasoning may result in more verbose responses since the extra tokens are added to all requests.
|
||||
|
||||
## Response Pattern Configuration
|
||||
|
||||
When reasoning is enabled, you may need to configure a **Pattern to strip from the response** to remove the thinking process from the final output.
|
||||
|
||||
### Default Patterns
|
||||
|
||||
Talemate provides quick-access buttons for common reasoning patterns:
|
||||
|
||||
- **Default** - Uses the built-in pattern: `.*?</think>`
|
||||
- **`.*?◁/think▷`** - For models using arrow-style thinking delimiters
|
||||
- **`.*?</think>`** - For models using XML-style think tags
|
||||
|
||||
### Custom Patterns
|
||||
|
||||
You can also specify a custom regular expression pattern that matches your model's reasoning format. This pattern will be used to strip the thinking tokens from the response before displaying it to the user.
|
||||
|
||||
## Model Compatibility
|
||||
|
||||
Not all models support reasoning. This feature works best with:
|
||||
|
||||
- Models specifically trained for chain-of-thought reasoning
|
||||
- Models that support structured thinking patterns
|
||||
- APIs that provide reasoning token specification
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Coercion Disabled**: When reasoning is enabled, LLM coercion (pre-filling responses) is automatically disabled since reasoning models need to generate their complete thought process
|
||||
- **Response Time**: Reasoning models may take longer to respond as they work through their thinking process
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Pattern Not Working
|
||||
If the reasoning pattern isn't properly stripping the thinking process:
|
||||
|
||||
1. Check your model's actual reasoning output format
|
||||
2. Adjust the regular expression pattern to match your model's specific format
|
||||
3. Test with the default pattern first to see if it works
|
||||
@@ -35,4 +35,19 @@ A unique name for the client that makes sense to you.
|
||||
Which model to use. Currently defaults to `gpt-4o`.
|
||||
|
||||
!!! note "Talemate lags behind OpenAI"
|
||||
When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
|
||||
When OpenAI adds a new model, it currently requires a Talemate update to add it to the list of available models. We are working on making this more dynamic.
|
||||
|
||||
##### Reasoning models (o1, o3, gpt-5)
|
||||
|
||||
!!! important "Enable reasoning and allocate tokens"
|
||||
The `o1`, `o3`, and `gpt-5` families are reasoning models. They always perform internal thinking before producing the final answer. To use them effectively in Talemate:
|
||||
|
||||
- Enable the **Reasoning** option in the client configuration.
|
||||
- Set **Reasoning Tokens** to a sufficiently high value to make room for the model's thinking process.
|
||||
|
||||
A good starting range is 512–1024 tokens. Increase if your tasks are complex. Without enabling reasoning and allocating tokens, these models may return minimal or empty visible content because the token budget is consumed by internal reasoning.
|
||||
|
||||
See the detailed guide: [Reasoning Model Support](/talemate/user-guide/clients/reasoning/).
|
||||
|
||||
!!! tip "Getting empty responses?"
|
||||
If these models return empty or very short answers, it usually means the reasoning budget was exhausted. Increase **Reasoning Tokens** and try again.
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import HTMLResponse, FileResponse
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
@@ -14,6 +14,7 @@ app = FastAPI()
|
||||
# Serve static files, but exclude the root path
|
||||
app.mount("/", StaticFiles(directory=dist_dir, html=True), name="static")
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def serve_root():
|
||||
index_path = os.path.join(dist_dir, "index.html")
|
||||
@@ -24,5 +25,6 @@ async def serve_root():
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="index.html not found")
|
||||
|
||||
|
||||
# This is the ASGI application
|
||||
application = app
|
||||
application = app
|
||||
|
||||
@@ -4,4 +4,4 @@
|
||||
uv pip uninstall torch torchaudio
|
||||
|
||||
# install torch and torchaudio with CUDA support
|
||||
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
uv pip install torch~=2.7.1 torchaudio~=2.7.1 --index-url https://download.pytorch.org/whl/cu128
|
||||
@@ -1,11 +1,12 @@
|
||||
[project]
|
||||
name = "talemate"
|
||||
version = "0.31.0"
|
||||
version = "0.32.1"
|
||||
description = "AI-backed roleplay and narrative tools"
|
||||
authors = [{name = "VeguAITools"}]
|
||||
license = {text = "GNU Affero General Public License v3.0"}
|
||||
requires-python = ">=3.10,<3.14"
|
||||
requires-python = ">=3.11,<3.14"
|
||||
dependencies = [
|
||||
"pip",
|
||||
"astroid>=2.8",
|
||||
"jedi>=0.18",
|
||||
"black",
|
||||
@@ -50,11 +51,20 @@ dependencies = [
|
||||
# ChromaDB
|
||||
"chromadb>=1.0.12",
|
||||
"InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
|
||||
"torch>=2.7.0",
|
||||
"torchaudio>=2.7.0",
|
||||
# locked for instructor embeddings
|
||||
#sentence-transformers==2.2.2
|
||||
"torch>=2.7.1",
|
||||
"torchaudio>=2.7.1",
|
||||
"sentence_transformers>=2.7.0",
|
||||
# TTS
|
||||
"elevenlabs>=2.7.1",
|
||||
# Local TTS
|
||||
# Chatterbox TTS
|
||||
#"chatterbox-tts @ https://github.com/rsxdalv/chatterbox/archive/refs/heads/fast.zip",
|
||||
"chatterbox-tts==0.1.2",
|
||||
# kokoro TTS
|
||||
"kokoro>=0.9.4",
|
||||
"soundfile>=0.13.1",
|
||||
# F5-TTS
|
||||
"f5-tts>=1.1.7",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -65,6 +75,7 @@ dev = [
|
||||
"mkdocs-material>=9.5.27",
|
||||
"mkdocs-awesome-pages-plugin>=2.9.2",
|
||||
"mkdocs-glightbox>=0.4.0",
|
||||
"pre-commit>=2.13",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -103,4 +114,27 @@ include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
ensure_newline_before_comments = true
|
||||
line_length = 88
|
||||
line_length = 88
|
||||
|
||||
[tool.uv]
|
||||
override-dependencies = [
|
||||
# chatterbox wants torch 2.6.0, but is confirmed working with 2.7.1
|
||||
"torchaudio>=2.7.1",
|
||||
"torch>=2.7.1",
|
||||
# numba needs numpy < 2.3
|
||||
"numpy>=2,<2.3",
|
||||
"pydantic>=2.11",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
5
ruff.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
[lint]
|
||||
# Disable automatic fix for unused imports (`F401`). We check these manually.
|
||||
unfixable = ["F401"]
|
||||
# Ignore E402
|
||||
extend-ignore = ["E402"]
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"description": "Captain Elmer Farstield and his trusty first officer, Kaira, embark upon a daring mission into uncharted space. Their small but mighty exploration vessel, the Starlight Nomad, is equipped with state-of-the-art technology and crewed by an elite team of scientists, engineers, and pilots. Together they brave the vast cosmos seeking answers to humanity's most pressing questions about life beyond our solar system.",
|
||||
"intro": "You awaken aboard your ship, the Starlight Nomad, surrounded by darkness. A soft hum resonates throughout the vessel indicating its systems are online. Your mind struggles to recall what brought you here - where 'here' actually is. You remember nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there was a bright light followed by this endless void.\n\nGingerly, you make your way through the dimly lit corridors of the ship. It seems smaller than you expected given the magnitude of the mission ahead. However, each room reveals intricate technology designed specifically for long-term space travel and exploration. There appears to be no other living soul besides yourself. An eerie silence fills every corner.",
|
||||
"intro": "Elmer awoke aboard his ship, the Starlight Nomad, surrounded by darkness. A soft hum resonated throughout the vessel indicating its systems were online. His mind struggled to recall what had brought him here - where here actually was. He remembered nothing more than flashes of images; swirling nebulae, foreign constellations, alien life forms... Then there had been a bright light followed by this endless void.\n\nGingerly, he made his way through the dimly lit corridors of the ship. It seemed smaller than he had expected given the magnitude of the mission ahead. However, each room revealed intricate technology designed specifically for long-term space travel and exploration. There appeared to be no other living soul besides himself. An eerie silence filled every corner.",
|
||||
"name": "Infinity Quest",
|
||||
"history": [],
|
||||
"environment": "scene",
|
||||
@@ -90,11 +90,11 @@
|
||||
"gender": "female",
|
||||
"color": "red",
|
||||
"example_dialogue": [
|
||||
"Kaira: \"Yes Captain, I believe that is the best course of action\" She nods slightly, as if to punctuate her approval of the decision*",
|
||||
"Kaira: \"This device appears to have multiple functions, Captain. Allow me to analyze its capabilities and determine if it could be useful in our exploration efforts.\"",
|
||||
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" Excitement bubbles beneath her calm exterior as she shares the news",
|
||||
"Kaira: \"Captain, I understand why you would want us to pursue this course of action based on our current data, but I cannot shake the feeling that there might be unforeseen consequences if we proceed without further investigation into potential hazards.\"",
|
||||
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world... But then again, perhaps it was fate that led me here, onto this ship bound for destinations unknown...\""
|
||||
"Kaira: \"Yes Captain, I believe that is the best course of action.\" Kaira glanced at the navigation display, then back at him with a slight nod. \"The numbers check out on my end too. If we adjust our trajectory by 2.7 degrees and increase thrust by fifteen percent, we should reach the nebula's outer edge within six hours.\" Her violet fingers moved efficiently across the controls as she pulled up the gravitational readings.",
|
||||
"Kaira: The scanner hummed as it analyzed the alien artifact. Kaira knelt beside the strange object, frowning in concentration at the shifting symbols on its warm metal surface. \"This device appears to have multiple functions, Captain,\" she said, adjusting her scanner's settings. \"Give me a few minutes to run a full analysis and I'll know what we're dealing with. The material composition is fascinating - it's responding to our ship's systems in ways that shouldn't be possible.\"",
|
||||
"Kaira: \"Captain, it appears that this newly discovered planet harbors an ancient civilization whose technological advancements rival those found back home on Altrusia!\" The excitement in her voice was unmistakable as Kaira looked up from her console. \"These readings are incredible - I've never seen anything like this before. There are structures beneath the surface that predate our oldest sites by millions of years.\" She paused, processing the implications. \"If these readings are accurate, we may have found something truly significant.\"",
|
||||
"Kaira: Something felt off about the proposed course of action. Kaira moved from her station to stand beside the Captain, organizing her thoughts carefully. \"Captain, I understand why you would want us to pursue this based on our current data,\" she began respectfully, clasping her hands behind her back. \"But something feels wrong about this. The quantum signatures have subtle variations that remind me of Hegemony cloaking technology. Maybe we should run a few more scans before we commit to a full approach.\"",
|
||||
"Kaira: \"I often find myself wondering what it would have been like if I had never left my home world,\" she said softly, not turning from the observation deck viewport as footsteps approached. The stars wheeled slowly past in their eternal dance. \"Sometimes I dream of Altrusia's crystal gardens, the way our twin suns would set over the mountains.\" Kaira finally turned, her expression thoughtful. \"Then again, I suppose I was always meant to end up out here somehow. Perhaps this journey is exactly where I'm supposed to be.\""
|
||||
],
|
||||
"history_events": [],
|
||||
"is_player": false,
|
||||
|
||||
@@ -494,34 +494,6 @@
|
||||
"registry": "state/GetState",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"8a050403-5c69-46f7-abe2-f65db4553942": {
|
||||
"title": "TRUE",
|
||||
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
|
||||
"properties": {
|
||||
"value": true
|
||||
},
|
||||
"x": 460,
|
||||
"y": 670,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
"inherited": false,
|
||||
"registry": "core/MakeBool",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
|
||||
"title": "Create Character",
|
||||
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
|
||||
"properties": {},
|
||||
"x": 580,
|
||||
"y": 550,
|
||||
"width": 245,
|
||||
"height": 186,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "agents/creator/CreateCharacter",
|
||||
"base_type": "core/Graph"
|
||||
},
|
||||
"970f12d0-330e-41b3-b025-9a53bcf2fc6f": {
|
||||
"title": "SET - created_character",
|
||||
"id": "970f12d0-330e-41b3-b025-9a53bcf2fc6f",
|
||||
@@ -628,6 +600,34 @@
|
||||
"inherited": false,
|
||||
"registry": "data/string/MakeText",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c": {
|
||||
"title": "Create Character",
|
||||
"id": "72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c",
|
||||
"properties": {},
|
||||
"x": 580,
|
||||
"y": 550,
|
||||
"width": 245,
|
||||
"height": 206,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "agents/creator/CreateCharacter",
|
||||
"base_type": "core/Graph"
|
||||
},
|
||||
"8a050403-5c69-46f7-abe2-f65db4553942": {
|
||||
"title": "TRUE",
|
||||
"id": "8a050403-5c69-46f7-abe2-f65db4553942",
|
||||
"properties": {
|
||||
"value": true
|
||||
},
|
||||
"x": 400,
|
||||
"y": 720,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
"inherited": false,
|
||||
"registry": "core/MakeBool",
|
||||
"base_type": "core/Node"
|
||||
}
|
||||
},
|
||||
"edges": {
|
||||
@@ -714,17 +714,6 @@
|
||||
"fd4cd318-121b-47de-84a0-b1ab62c5601b.value": [
|
||||
"41c0c2cd-d39b-4528-a5fe-459094393ba3.list"
|
||||
],
|
||||
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active"
|
||||
],
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
|
||||
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
|
||||
],
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
|
||||
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
|
||||
],
|
||||
"a2457116-35cb-4571-ba1a-cbf63851544e.value": [
|
||||
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.state",
|
||||
"a9f17ddc-fc8f-4257-8e32-45ac111fd50d.character",
|
||||
@@ -738,6 +727,18 @@
|
||||
],
|
||||
"5041f12d-507f-4f5d-a26f-048625974602.value": [
|
||||
"b50e23d4-b456-4385-b80e-c2b6884c7855.template"
|
||||
],
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.state": [
|
||||
"90610e90-3d00-4b4b-96de-1f6aa3f4f795.state"
|
||||
],
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.character": [
|
||||
"970f12d0-330e-41b3-b025-9a53bcf2fc6f.value"
|
||||
],
|
||||
"8a050403-5c69-46f7-abe2-f65db4553942.value": [
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate_attributes",
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.is_active",
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.generate",
|
||||
"72943c1c-d2a1-40b4-bb28-8bcc5f02aa5c.assign_voice"
|
||||
]
|
||||
},
|
||||
"groups": [
|
||||
@@ -808,13 +809,14 @@
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "f78fa84b-8b6f-4c8a-83c0-754537bb9060",
|
||||
"id": "92423e0a-89d8-4ad8-a42d-fdcba75b31d1",
|
||||
"name": "fn",
|
||||
"optional": false,
|
||||
"group": null,
|
||||
"socket_type": "function"
|
||||
}
|
||||
],
|
||||
"module_properties": {},
|
||||
"style": {
|
||||
"title_color": "#573a2e",
|
||||
"node_color": "#392f2c",
|
||||
|
||||
@@ -116,8 +116,8 @@
|
||||
"context_aware": true,
|
||||
"history_aware": true
|
||||
},
|
||||
"x": 568,
|
||||
"y": 861,
|
||||
"x": 372,
|
||||
"y": 857,
|
||||
"width": 249,
|
||||
"height": 406,
|
||||
"collapsed": false,
|
||||
@@ -130,7 +130,7 @@
|
||||
"id": "a8110d74-0fb5-4601-b883-6c63ceaa9d31",
|
||||
"properties": {},
|
||||
"x": 24,
|
||||
"y": 2084,
|
||||
"y": 2088,
|
||||
"width": 140,
|
||||
"height": 106,
|
||||
"collapsed": false,
|
||||
@@ -145,7 +145,7 @@
|
||||
"attribute": "title"
|
||||
},
|
||||
"x": 213,
|
||||
"y": 2094,
|
||||
"y": 2098,
|
||||
"width": 210,
|
||||
"height": 98,
|
||||
"collapsed": false,
|
||||
@@ -160,7 +160,7 @@
|
||||
"value": "The Simulation Suite"
|
||||
},
|
||||
"x": 213,
|
||||
"y": 2284,
|
||||
"y": 2288,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": false,
|
||||
@@ -177,7 +177,7 @@
|
||||
"case_sensitive": true
|
||||
},
|
||||
"x": 523,
|
||||
"y": 2184,
|
||||
"y": 2188,
|
||||
"width": 210,
|
||||
"height": 126,
|
||||
"collapsed": false,
|
||||
@@ -192,7 +192,7 @@
|
||||
"pass_through": true
|
||||
},
|
||||
"x": 774,
|
||||
"y": 2185,
|
||||
"y": 2189,
|
||||
"width": 210,
|
||||
"height": 78,
|
||||
"collapsed": false,
|
||||
@@ -207,7 +207,7 @@
|
||||
"attribute": "0"
|
||||
},
|
||||
"x": 1629,
|
||||
"y": 2411,
|
||||
"y": 2415,
|
||||
"width": 210,
|
||||
"height": 98,
|
||||
"collapsed": false,
|
||||
@@ -222,7 +222,7 @@
|
||||
"attribute": "title"
|
||||
},
|
||||
"x": 2189,
|
||||
"y": 2321,
|
||||
"y": 2325,
|
||||
"width": 210,
|
||||
"height": 98,
|
||||
"collapsed": false,
|
||||
@@ -237,7 +237,7 @@
|
||||
"stage": 5
|
||||
},
|
||||
"x": 2459,
|
||||
"y": 2331,
|
||||
"y": 2335,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
@@ -250,7 +250,7 @@
|
||||
"id": "8208d05c-1822-4f4a-ba75-cfd18d2de8ca",
|
||||
"properties": {},
|
||||
"x": 1989,
|
||||
"y": 2331,
|
||||
"y": 2335,
|
||||
"width": 140,
|
||||
"height": 106,
|
||||
"collapsed": true,
|
||||
@@ -266,7 +266,7 @@
|
||||
"chars": "\"'*"
|
||||
},
|
||||
"x": 1899,
|
||||
"y": 2411,
|
||||
"y": 2415,
|
||||
"width": 210,
|
||||
"height": 102,
|
||||
"collapsed": false,
|
||||
@@ -282,7 +282,7 @@
|
||||
"max_splits": 1
|
||||
},
|
||||
"x": 1359,
|
||||
"y": 2411,
|
||||
"y": 2415,
|
||||
"width": 210,
|
||||
"height": 102,
|
||||
"collapsed": false,
|
||||
@@ -304,7 +304,7 @@
|
||||
"history_aware": true
|
||||
},
|
||||
"x": 1014,
|
||||
"y": 2185,
|
||||
"y": 2189,
|
||||
"width": 276,
|
||||
"height": 406,
|
||||
"collapsed": false,
|
||||
@@ -376,8 +376,8 @@
|
||||
"name": "arg_goal",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 228,
|
||||
"y": 1041,
|
||||
"x": 32,
|
||||
"y": 1037,
|
||||
"width": 256,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -393,7 +393,7 @@
|
||||
"max_scene_types": 2
|
||||
},
|
||||
"x": 671,
|
||||
"y": 1490,
|
||||
"y": 1494,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -409,7 +409,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 30,
|
||||
"y": 1551,
|
||||
"y": 1555,
|
||||
"width": 256,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -448,37 +448,6 @@
|
||||
"registry": "state/GetState",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"59d31050-f61d-4798-9790-e22d34ecbd4b": {
|
||||
"title": "GET local.auto_direct_enabled",
|
||||
"id": "59d31050-f61d-4798-9790-e22d34ecbd4b",
|
||||
"properties": {
|
||||
"name": "auto_direct_enabled",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 20,
|
||||
"y": 850,
|
||||
"width": 256,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "state/GetState",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8": {
|
||||
"title": "Switch",
|
||||
"id": "e8f19a05-43fe-4e4a-9cc4-bec0a29779d8",
|
||||
"properties": {
|
||||
"pass_through": true
|
||||
},
|
||||
"x": 310,
|
||||
"y": 870,
|
||||
"width": 210,
|
||||
"height": 78,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "core/Switch",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"b03fa942-c48e-4c04-b9ae-a009a7e0f947": {
|
||||
"title": "GET local.auto_direct_enabled",
|
||||
"id": "b03fa942-c48e-4c04-b9ae-a009a7e0f947",
|
||||
@@ -487,7 +456,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 24,
|
||||
"y": 1367,
|
||||
"y": 1371,
|
||||
"width": 256,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -502,7 +471,7 @@
|
||||
"pass_through": true
|
||||
},
|
||||
"x": 360,
|
||||
"y": 1390,
|
||||
"y": 1394,
|
||||
"width": 210,
|
||||
"height": 78,
|
||||
"collapsed": false,
|
||||
@@ -518,7 +487,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 30,
|
||||
"y": 1821,
|
||||
"y": 1826,
|
||||
"width": 256,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -533,7 +502,7 @@
|
||||
"stage": 4
|
||||
},
|
||||
"x": 1100,
|
||||
"y": 1870,
|
||||
"y": 1875,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
@@ -546,7 +515,7 @@
|
||||
"id": "9db37d1e-3cf8-49bd-bdc5-8663494e5657",
|
||||
"properties": {},
|
||||
"x": 670,
|
||||
"y": 1840,
|
||||
"y": 1845,
|
||||
"width": 226,
|
||||
"height": 62,
|
||||
"collapsed": false,
|
||||
@@ -561,7 +530,7 @@
|
||||
"stage": 3
|
||||
},
|
||||
"x": 1080,
|
||||
"y": 1520,
|
||||
"y": 1524,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
@@ -576,7 +545,7 @@
|
||||
"pass_through": true
|
||||
},
|
||||
"x": 370,
|
||||
"y": 1840,
|
||||
"y": 1845,
|
||||
"width": 210,
|
||||
"height": 78,
|
||||
"collapsed": false,
|
||||
@@ -590,8 +559,8 @@
|
||||
"properties": {
|
||||
"stage": 2
|
||||
},
|
||||
"x": 1280,
|
||||
"y": 890,
|
||||
"x": 1084,
|
||||
"y": 886,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
@@ -599,14 +568,29 @@
|
||||
"registry": "core/Stage",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"5559196c-f6b1-4223-8e13-2bf64e3cfef0": {
|
||||
"title": "true",
|
||||
"id": "5559196c-f6b1-4223-8e13-2bf64e3cfef0",
|
||||
"properties": {
|
||||
"value": true
|
||||
},
|
||||
"x": 24,
|
||||
"y": 876,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "core/MakeBool",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"6ef94917-f9b1-4c18-af15-617430e50cfe": {
|
||||
"title": "Set Scene Intent",
|
||||
"id": "6ef94917-f9b1-4c18-af15-617430e50cfe",
|
||||
"properties": {
|
||||
"intent": ""
|
||||
},
|
||||
"x": 930,
|
||||
"y": 860,
|
||||
"x": 731,
|
||||
"y": 853,
|
||||
"width": 210,
|
||||
"height": 78,
|
||||
"collapsed": false,
|
||||
@@ -693,14 +677,8 @@
|
||||
"e4cd1391-daed-4951-a6c6-438d993c07a9.state"
|
||||
],
|
||||
"c66bdaeb-4166-4835-9415-943af547c926.value": [
|
||||
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description",
|
||||
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state"
|
||||
],
|
||||
"59d31050-f61d-4798-9790-e22d34ecbd4b.value": [
|
||||
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.value"
|
||||
],
|
||||
"e8f19a05-43fe-4e4a-9cc4-bec0a29779d8.yes": [
|
||||
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
|
||||
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.state",
|
||||
"24ac670b-4648-4915-9dbb-b6bf35ee6d80.description"
|
||||
],
|
||||
"b03fa942-c48e-4c04-b9ae-a009a7e0f947.value": [
|
||||
"8ad7c42c-110e-46ae-b649-4a1e6d055e25.value"
|
||||
@@ -717,6 +695,9 @@
|
||||
"6a8762c4-16cf-4e8c-9d10-8af7597c4097.yes": [
|
||||
"9db37d1e-3cf8-49bd-bdc5-8663494e5657.state"
|
||||
],
|
||||
"5559196c-f6b1-4223-8e13-2bf64e3cfef0.value": [
|
||||
"bb43a68e-bdf6-4b02-9cc0-102742b14f5d.state"
|
||||
],
|
||||
"6ef94917-f9b1-4c18-af15-617430e50cfe.state": [
|
||||
"f4cd34d9-0628-4145-a3da-ec1215cd356c.state"
|
||||
]
|
||||
@@ -745,7 +726,7 @@
|
||||
{
|
||||
"title": "Generate Scene Types",
|
||||
"x": -1,
|
||||
"y": 1287,
|
||||
"y": 1290,
|
||||
"width": 1298,
|
||||
"height": 408,
|
||||
"color": "#8AA",
|
||||
@@ -756,8 +737,8 @@
|
||||
"title": "Set story intention",
|
||||
"x": -1,
|
||||
"y": 773,
|
||||
"width": 1539,
|
||||
"height": 512,
|
||||
"width": 1320,
|
||||
"height": 514,
|
||||
"color": "#8AA",
|
||||
"font_size": 24,
|
||||
"inherited": false
|
||||
@@ -765,7 +746,7 @@
|
||||
{
|
||||
"title": "Evaluate Scene Intent",
|
||||
"x": -1,
|
||||
"y": 1697,
|
||||
"y": 1701,
|
||||
"width": 1293,
|
||||
"height": 302,
|
||||
"color": "#8AA",
|
||||
@@ -775,7 +756,7 @@
|
||||
{
|
||||
"title": "Set title",
|
||||
"x": -1,
|
||||
"y": 2003,
|
||||
"y": 2006,
|
||||
"width": 2618,
|
||||
"height": 637,
|
||||
"color": "#8AA",
|
||||
@@ -794,7 +775,7 @@
|
||||
{
|
||||
"text": "Some times the AI will produce more text after the title, we only care about the title on the first line.",
|
||||
"x": 1359,
|
||||
"y": 2311,
|
||||
"y": 2315,
|
||||
"width": 471,
|
||||
"inherited": false
|
||||
}
|
||||
@@ -804,13 +785,14 @@
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "dede1a38-2107-4475-9db5-358c09cb0d12",
|
||||
"id": "5c8dee64-5832-40ba-b1e2-2a411d913cc7",
|
||||
"name": "fn",
|
||||
"optional": false,
|
||||
"group": null,
|
||||
"socket_type": "function"
|
||||
}
|
||||
],
|
||||
"module_properties": {},
|
||||
"style": {
|
||||
"title_color": "#573a2e",
|
||||
"node_color": "#392f2c",
|
||||
|
||||
@@ -666,23 +666,6 @@
|
||||
"registry": "data/ListAppend",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"4eb36f21-1020-4609-85e5-d16b42019c66": {
|
||||
"title": "AI Function Calling",
|
||||
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
|
||||
"properties": {
|
||||
"template": "computer",
|
||||
"max_calls": 5,
|
||||
"retries": 1
|
||||
},
|
||||
"x": 1000,
|
||||
"y": 2581,
|
||||
"width": 212,
|
||||
"height": 206,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "focal/Focal",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"3da498c0-55de-4ec3-9943-6486279b9826": {
|
||||
"title": "GET scene loop.user_message",
|
||||
"id": "3da498c0-55de-4ec3-9943-6486279b9826",
|
||||
@@ -1167,6 +1150,24 @@
|
||||
"inherited": false,
|
||||
"registry": "data/MakeList",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"4eb36f21-1020-4609-85e5-d16b42019c66": {
|
||||
"title": "AI Function Calling",
|
||||
"id": "4eb36f21-1020-4609-85e5-d16b42019c66",
|
||||
"properties": {
|
||||
"template": "computer",
|
||||
"max_calls": 5,
|
||||
"retries": 1,
|
||||
"response_length": 1408
|
||||
},
|
||||
"x": 1000,
|
||||
"y": 2581,
|
||||
"width": 210,
|
||||
"height": 230,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "focal/Focal",
|
||||
"base_type": "core/Node"
|
||||
}
|
||||
},
|
||||
"edges": {
|
||||
@@ -1296,9 +1297,6 @@
|
||||
"d45f07ff-c3ce-4aac-bae6-b9c77089cb69.list": [
|
||||
"4eb36f21-1020-4609-85e5-d16b42019c66.callbacks"
|
||||
],
|
||||
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
|
||||
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
|
||||
],
|
||||
"3da498c0-55de-4ec3-9943-6486279b9826.value": [
|
||||
"6d707f1c-af55-481a-970a-eb7a9f9c45dd.message"
|
||||
],
|
||||
@@ -1380,6 +1378,9 @@
|
||||
],
|
||||
"632d4a0e-3327-409b-aaa4-ed38b932286b.list": [
|
||||
"06175a92-abbe-4483-9ab2-abaee2104728.value"
|
||||
],
|
||||
"4eb36f21-1020-4609-85e5-d16b42019c66.state": [
|
||||
"3008c0f4-105d-444a-8fde-0ac11a21f40c.state"
|
||||
]
|
||||
},
|
||||
"groups": [
|
||||
|
||||
@@ -19,7 +19,7 @@ You have access to the following functions you must call to fulfill the user's r
|
||||
focal.callbacks.set_simulated_environment.render(
|
||||
"Create or change the simulated environment. This means the location, time, specific conditions, or any other aspect of the simulation that is not directly related to the characters.",
|
||||
instructions="Instructions on how to change the simulated environment. These will be given to the simulation computer to setup the new environment. REQUIRED.",
|
||||
reset="If true, the environment should be reset and all simulated characters are removed. If false, the environment should be changed but the characters should remain. REQUIRED.",
|
||||
reset="If true, the environment should be reset and ALL simulated characters are removed. IMPORTANT: If you set reset=true, this function MUST be the FIRST call in your stack; otherwise, set reset=false to avoid deactivating characters added earlier. REQUIRED.",
|
||||
examples=[
|
||||
{"instructions": "Change the location to a lush forest, with a river running through it.", "reset":true},
|
||||
{"instructions": "The simulation suite flickers and changes to a bustling city street.", "reset":true},
|
||||
@@ -123,7 +123,7 @@ You have access to the following functions you must call to fulfill the user's r
|
||||
|
||||
{{
|
||||
focal.callbacks.set_simulation_goal.render(
|
||||
"Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives.",
|
||||
"Briefly describe the overall goal of the simulation. What is the user looking to experience? What needs to happen for the simulation to be considered complete? This function is used to provide context and direction for the simulation. It should be clear, specific and detailed, and focused on the user's objectives. You MUST call this on new simulations or if the user has requested a change in the simulation's goal.",
|
||||
goal="The overall goal of the simulation. This should be a clear and concise statement that outlines the user's objective. REQUIRED.",
|
||||
examples=[
|
||||
{"goal": "The user is exploring a mysterious alien planet to uncover the secrets of an ancient civilization."},
|
||||
|
||||
@@ -1,111 +1,112 @@
|
||||
|
||||
def game(TM):
|
||||
|
||||
MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"
|
||||
|
||||
MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
|
||||
|
||||
|
||||
MSG_HELP = 'Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating "Computer," followed by an instruction. For example ... "Computer, i want to experience being on a derelict spaceship."'
|
||||
|
||||
PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."
|
||||
|
||||
|
||||
PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice. Remind the user that this is an old version of the simulation suite and they should check out version two for a more advanced experience."
|
||||
|
||||
CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
|
||||
|
||||
|
||||
CTX_PIN_UNAWARE = (
|
||||
"Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
|
||||
)
|
||||
|
||||
AUTO_NARRATE_INTERVAL = 10
|
||||
|
||||
def parse_sim_call_arguments(call:str) -> str:
|
||||
|
||||
def parse_sim_call_arguments(call: str) -> str:
|
||||
"""
|
||||
Returns the value between the parentheses of a simulation call
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
call = 'change_environment("a house")'
|
||||
|
||||
|
||||
parse_sim_call_arguments(call) -> "a house"
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
return call.split("(", 1)[1].split(")")[0]
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class SimulationSuite:
|
||||
def __init__(self):
|
||||
|
||||
"""
|
||||
This is initialized at the beginning of each round of the simulation suite
|
||||
"""
|
||||
|
||||
|
||||
# do we update the world state at the end of the round
|
||||
self.update_world_state = False
|
||||
self.simulation_reset = False
|
||||
|
||||
|
||||
# will keep track of any npcs added during the current round
|
||||
self.added_npcs = []
|
||||
|
||||
|
||||
TM.log.debug("SIMULATION SUITE INIT!", scene=TM.scene)
|
||||
self.player_message = TM.scene.last_player_message
|
||||
self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
|
||||
|
||||
self.last_processed_call = TM.game_state.get_var(
|
||||
"instr.lastprocessed_call", -1
|
||||
)
|
||||
|
||||
# determine whether the player / user input is an instruction
|
||||
# to the simulation computer
|
||||
#
|
||||
#
|
||||
# we do this by checking if the message starts with "Computer,"
|
||||
self.player_message_is_instruction = (
|
||||
self.player_message and
|
||||
self.player_message.raw.lower().startswith("computer") and
|
||||
not self.player_message.hidden and
|
||||
not self.last_processed_call > self.player_message.id
|
||||
self.player_message
|
||||
and self.player_message.raw.lower().startswith("computer")
|
||||
and not self.player_message.hidden
|
||||
and not self.last_processed_call > self.player_message.id
|
||||
)
|
||||
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Main entry point for the simulation suite
|
||||
"""
|
||||
|
||||
|
||||
if not TM.game_state.has_var("instr.simulation_stopped"):
|
||||
# simulation is still running
|
||||
self.simulation()
|
||||
|
||||
|
||||
self.finalize_round()
|
||||
|
||||
|
||||
def simulation(self):
|
||||
"""
|
||||
Simulation suite logic
|
||||
"""
|
||||
|
||||
|
||||
if not TM.game_state.has_var("instr.simulation_started"):
|
||||
self.startup()
|
||||
else:
|
||||
self.simulation_calls()
|
||||
|
||||
|
||||
if self.update_world_state:
|
||||
self.run_update_world_state(force=True)
|
||||
|
||||
|
||||
def startup(self):
|
||||
|
||||
"""
|
||||
Scene startup logic
|
||||
"""
|
||||
|
||||
|
||||
# we are at the beginning of the simulation
|
||||
TM.signals.status("busy", "Simulation suite powering up.", as_scene_message=True)
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite powering up.", as_scene_message=True
|
||||
)
|
||||
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
|
||||
|
||||
|
||||
# add narration for the introduction
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=PROMPT_STARTUP,
|
||||
emit_message=False
|
||||
emit_message=False,
|
||||
)
|
||||
|
||||
|
||||
# add narration for the instructions on how to interact with the simulation
|
||||
# this is a passthrough since we don't want the AI to paraphrase this
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="passthrough",
|
||||
narration=MSG_HELP
|
||||
action_name="passthrough", narration=MSG_HELP
|
||||
)
|
||||
|
||||
|
||||
# create a world state entry letting the AI know that characters
|
||||
# interacting in the simulation are not aware of the computer or the simulation
|
||||
TM.agents.world_state.save_world_entry(
|
||||
@@ -113,37 +114,43 @@ def game(TM):
|
||||
text=CTX_PIN_UNAWARE,
|
||||
meta={},
|
||||
# this should always be pinned
|
||||
pin=True
|
||||
pin=True,
|
||||
)
|
||||
|
||||
|
||||
# set flag that we have started the simulation
|
||||
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
|
||||
|
||||
|
||||
# signal to the UX that the simulation suite is ready
|
||||
TM.signals.status("success", "Simulation suite ready", as_scene_message=True)
|
||||
|
||||
TM.signals.status(
|
||||
"success", "Simulation suite ready", as_scene_message=True
|
||||
)
|
||||
|
||||
# we want to update the world state at the end of the round
|
||||
self.update_world_state = True
|
||||
|
||||
|
||||
def simulation_calls(self):
|
||||
"""
|
||||
Calls the simulation suite main prompt to determine the appropriate
|
||||
simulation calls
|
||||
"""
|
||||
|
||||
|
||||
# we only process instructions that are not hidden and are not the last processed call
|
||||
if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
|
||||
if (
|
||||
not self.player_message_is_instruction
|
||||
or self.player_message.id == self.last_processed_call
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
# First instruction?
|
||||
if not TM.game_state.has_var("instr.has_issued_instructions"):
|
||||
|
||||
# determine the context of the simulation
|
||||
context_context = TM.agents.creator.determine_content_context_for_description(
|
||||
description=self.player_message.raw,
|
||||
context_context = (
|
||||
TM.agents.creator.determine_content_context_for_description(
|
||||
description=self.player_message.raw,
|
||||
)
|
||||
)
|
||||
TM.scene.set_content_context(context_context)
|
||||
|
||||
|
||||
# Render the `computer` template and send it to the LLM for processing
|
||||
# The LLM will return a list of calls that the simulation suite will process
|
||||
# The calls are pseudo code that the simulation suite will interpret and execute
|
||||
@@ -153,90 +160,98 @@ def game(TM):
|
||||
player_instruction=self.player_message.raw,
|
||||
scene=TM.scene,
|
||||
)
|
||||
|
||||
|
||||
self.calls = calls = calls.split("\n")
|
||||
|
||||
|
||||
calls = self.prepare_calls(calls)
|
||||
|
||||
|
||||
TM.log.debug("SIMULATION SUITE CALLS", callse=calls)
|
||||
|
||||
|
||||
# calls that are processed
|
||||
processed = []
|
||||
|
||||
|
||||
for call in calls:
|
||||
processed_call = self.process_call(call)
|
||||
if processed_call:
|
||||
processed.append(processed_call)
|
||||
|
||||
|
||||
if processed:
|
||||
TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
|
||||
TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
|
||||
|
||||
TM.signals.status("busy", "Simulation suite altering environment.", as_scene_message=True)
|
||||
TM.game_state.set_var(
|
||||
"instr.has_issued_instructions", "yes", commit=False
|
||||
)
|
||||
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite altering environment.", as_scene_message=True
|
||||
)
|
||||
compiled = "\n".join(processed)
|
||||
|
||||
|
||||
if not self.simulation_reset and compiled:
|
||||
|
||||
# send the compiled calls to the narrator to generate a narrative based
|
||||
# on them
|
||||
narration = TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
|
||||
emit_message=True
|
||||
emit_message=True,
|
||||
)
|
||||
|
||||
|
||||
# on the first narration we update the scene description and remove any mention of the computer
|
||||
# or the simulation from the previous narration
|
||||
is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
|
||||
is_initial_narration = TM.game_state.get_var(
|
||||
"instr.intro_narration", False
|
||||
)
|
||||
if not is_initial_narration:
|
||||
TM.scene.set_description(narration.raw)
|
||||
TM.scene.set_intro(narration.raw)
|
||||
TM.log.debug("SIMULATION SUITE: initial narration", intro=narration.raw)
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE: initial narration", intro=narration.raw
|
||||
)
|
||||
TM.scene.pop_history(typ="narrator", all=True, reverse=True)
|
||||
TM.scene.pop_history(typ="director", all=True, reverse=True)
|
||||
TM.game_state.set_var("instr.intro_narration", True, commit=False)
|
||||
|
||||
|
||||
self.update_world_state = True
|
||||
|
||||
|
||||
self.set_simulation_title(compiled)
|
||||
|
||||
|
||||
def set_simulation_title(self, compiled_calls):
|
||||
|
||||
"""
|
||||
Generates a fitting title for the simulation based on the user's instructions
|
||||
"""
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
|
||||
|
||||
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE: set simulation title",
|
||||
name=TM.scene.title,
|
||||
compiled_calls=compiled_calls,
|
||||
)
|
||||
|
||||
if not compiled_calls:
|
||||
return
|
||||
|
||||
|
||||
if TM.scene.title != "Simulation Suite":
|
||||
# name already changed, no need to do it again
|
||||
return
|
||||
|
||||
|
||||
title = TM.agents.creator.contextual_generate_from_args(
|
||||
"scene:simulation title",
|
||||
"Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
|
||||
length=75
|
||||
length=75,
|
||||
)
|
||||
|
||||
|
||||
title = title.strip('"').strip()
|
||||
|
||||
|
||||
TM.scene.set_title(title)
|
||||
|
||||
|
||||
def prepare_calls(self, calls):
|
||||
"""
|
||||
Loops through calls and if a `set_player_name` call and a `set_player_persona` call are both
|
||||
found, ensure that the `set_player_name` call is processed first by moving it in front of the
|
||||
`set_player_persona` call.
|
||||
"""
|
||||
|
||||
|
||||
set_player_name_call_exists = -1
|
||||
set_player_persona_call_exists = -1
|
||||
|
||||
|
||||
i = 0
|
||||
for call in calls:
|
||||
if "set_player_name" in call:
|
||||
@@ -244,351 +259,445 @@ def game(TM):
|
||||
elif "set_player_persona" in call:
|
||||
set_player_persona_call_exists = i
|
||||
i = i + 1
|
||||
|
||||
if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
|
||||
|
||||
if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
|
||||
if set_player_name_call_exists > set_player_persona_call_exists:
|
||||
calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists))
|
||||
TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls)
|
||||
|
||||
calls.insert(
|
||||
set_player_persona_call_exists,
|
||||
calls.pop(set_player_name_call_exists),
|
||||
)
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE: prepare calls - moved set_player_persona call",
|
||||
calls=calls,
|
||||
)
|
||||
|
||||
return calls
|
||||
|
||||
def process_call(self, call:str) -> str:
|
||||
def process_call(self, call: str) -> str:
|
||||
"""
|
||||
Processes a simulation call
|
||||
|
||||
|
||||
Simulation alls are pseudo functions that are called by the simulation suite
|
||||
|
||||
|
||||
We grab the function name by splitting against ( and taking the first element
|
||||
if the SimulationSuite has a method with the name _call_{function_name} then we call it
|
||||
|
||||
|
||||
if a function name could be found but we do not have a method to call we dont do anything
|
||||
but we still return it as procssed as the AI can still interpret it as something later on
|
||||
"""
|
||||
|
||||
|
||||
if "(" not in call:
|
||||
return None
|
||||
|
||||
|
||||
function_name = call.split("(")[0]
|
||||
|
||||
|
||||
if hasattr(self, f"call_{function_name}"):
|
||||
TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
|
||||
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE CALL", call=call, function_name=function_name
|
||||
)
|
||||
|
||||
inject = f"The computer executes the function `{call}`"
|
||||
|
||||
|
||||
return getattr(self, f"call_{function_name}")(call, inject)
|
||||
|
||||
|
||||
return call
|
||||
|
||||
def call_set_simulation_goal(self, call:str, inject:str) -> str:
|
||||
def call_set_simulation_goal(self, call: str, inject: str) -> str:
|
||||
"""
|
||||
Set's the simulation goal as a permanent pin
|
||||
"""
|
||||
TM.signals.status("busy", "Simulation suite setting goal.", as_scene_message=True)
|
||||
TM.agents.world_state.save_world_entry(
|
||||
entry_id="sim.goal",
|
||||
text=self.player_message.raw,
|
||||
meta={},
|
||||
pin=True
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite setting goal.", as_scene_message=True
|
||||
)
|
||||
|
||||
TM.agents.world_state.save_world_entry(
|
||||
entry_id="sim.goal", text=self.player_message.raw, meta={}, pin=True
|
||||
)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer sets the goal for the simulation.",
|
||||
)
|
||||
|
||||
|
||||
return call
|
||||
|
||||
def call_change_environment(self, call:str, inject:str) -> str:
|
||||
|
||||
def call_change_environment(self, call: str, inject: str) -> str:
|
||||
"""
|
||||
Simulation changes the environment, this is entirely interpreted by the AI
|
||||
and we dont need to do any logic on our end, so we just return the call
|
||||
"""
|
||||
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer changes the environment of the simulation."
|
||||
action_description="The computer changes the environment of the simulation.",
|
||||
)
|
||||
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_answer_question(self, call:str, inject:str) -> str:
|
||||
|
||||
def call_answer_question(self, call: str, inject: str) -> str:
|
||||
"""
|
||||
The player asked the simulation a query, we need to process this and have
|
||||
the AI produce an answer
|
||||
"""
|
||||
|
||||
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
|
||||
emit_message=True
|
||||
emit_message=True,
|
||||
)
|
||||
|
||||
|
||||
def call_set_player_persona(self, call:str, inject:str) -> str:
|
||||
|
||||
|
||||
def call_set_player_persona(self, call: str, inject: str) -> str:
|
||||
"""
|
||||
The simulation suite is altering the player persona
|
||||
"""
|
||||
|
||||
|
||||
player_character = TM.scene.get_player_character()
|
||||
|
||||
TM.signals.status("busy", "Simulation suite altering user persona.", as_scene_message=True)
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(
|
||||
name=player_character.name, text=inject, alteration_instructions=self.player_message.raw
|
||||
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite altering user persona.", as_scene_message=True
|
||||
)
|
||||
TM.scene.set_character_attributes(player_character.name, character_attributes)
|
||||
|
||||
character_description = TM.agents.creator.determine_character_description(player_character.name)
|
||||
|
||||
TM.scene.set_character_description(player_character.name, character_description)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
|
||||
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(
|
||||
name=player_character.name,
|
||||
text=inject,
|
||||
alteration_instructions=self.player_message.raw,
|
||||
)
|
||||
TM.scene.set_character_attributes(
|
||||
player_character.name, character_attributes
|
||||
)
|
||||
|
||||
character_description = TM.agents.creator.determine_character_description(
|
||||
player_character.name
|
||||
)
|
||||
|
||||
TM.scene.set_character_description(
|
||||
player_character.name, character_description
|
||||
)
|
||||
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE: transform player",
|
||||
attributes=character_attributes,
|
||||
description=character_description,
|
||||
)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer transforms the player persona."
|
||||
action_description="The computer transforms the player persona.",
|
||||
)
|
||||
|
||||
|
||||
return call
|
||||
|
||||
def call_set_player_name(self, call:str, inject:str) -> str:
|
||||
|
||||
|
||||
def call_set_player_name(self, call: str, inject: str) -> str:
|
||||
"""
|
||||
The simulation suite is altering the player name
|
||||
"""
|
||||
player_character = TM.scene.get_player_character()
|
||||
|
||||
TM.signals.status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
|
||||
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
|
||||
|
||||
TM.signals.status(
|
||||
"busy",
|
||||
"Simulation suite adjusting user identity.",
|
||||
as_scene_message=True,
|
||||
)
|
||||
character_name = TM.agents.creator.determine_character_name(
|
||||
instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits."
|
||||
)
|
||||
TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
|
||||
if character_name != player_character.name:
|
||||
TM.scene.set_character_name(player_character.name, character_name)
|
||||
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer changes the player's identity to {character_name}."
|
||||
action_description=f"The computer changes the player's identity to {character_name}.",
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_add_ai_character(self, call:str, inject:str) -> str:
|
||||
|
||||
return call
|
||||
|
||||
def call_add_ai_character(self, call: str, inject: str) -> str:
|
||||
# sometimes the AI will call this function an pass an inanimate object as the parameter
|
||||
# we need to determine if this is the case and just ignore it
|
||||
is_inanimate = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
|
||||
|
||||
is_inanimate = TM.agents.world_state.answer_query_true_or_false(
|
||||
f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)",
|
||||
call,
|
||||
)
|
||||
|
||||
if is_inanimate:
|
||||
TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped",
|
||||
call=call,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# sometimes the AI will ask if the function adds a group of characters, we need to
|
||||
# determine if this is the case
|
||||
adds_group = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add MULTIPLE ai characters?", call)
|
||||
|
||||
adds_group = TM.agents.world_state.answer_query_true_or_false(
|
||||
f"does the function `{call}` add MULTIPLE ai characters?", call
|
||||
)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)
|
||||
|
||||
TM.signals.status("busy", "Simulation suite adding character.", as_scene_message=True)
|
||||
|
||||
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite adding character.", as_scene_message=True
|
||||
)
|
||||
|
||||
if not adds_group:
|
||||
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.")
|
||||
character_name = TM.agents.creator.determine_character_name(
|
||||
instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name."
|
||||
)
|
||||
else:
|
||||
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(
|
||||
instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.",
|
||||
group=True,
|
||||
)
|
||||
|
||||
# sometimes add_ai_character and change_ai_character are called in the same instruction targeting
|
||||
# the same character, if this happens we need to combine into a single add_ai_character call
|
||||
|
||||
has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
|
||||
|
||||
|
||||
has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(
|
||||
f"Are there any calls to `change_ai_character` in the instruction for {character_name}?",
|
||||
"\n".join(self.calls),
|
||||
)
|
||||
|
||||
if has_change_ai_character_call:
|
||||
|
||||
combined_arg = TM.prompt.request(
|
||||
"combine-add-and-alter-ai-character",
|
||||
dedupe_enabled=False,
|
||||
calls="\n".join(self.calls),
|
||||
character_name=character_name,
|
||||
scene=TM.scene,
|
||||
).replace("COMBINED ARGUMENT:", "").strip()
|
||||
|
||||
combined_arg = (
|
||||
TM.prompt.request(
|
||||
"combine-add-and-alter-ai-character",
|
||||
dedupe_enabled=False,
|
||||
calls="\n".join(self.calls),
|
||||
character_name=character_name,
|
||||
scene=TM.scene,
|
||||
)
|
||||
.replace("COMBINED ARGUMENT:", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
call = f"add_ai_character({combined_arg})"
|
||||
inject = f"The computer executes the function `{call}`"
|
||||
|
||||
|
||||
TM.signals.status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
|
||||
|
||||
|
||||
TM.signals.status(
|
||||
"busy",
|
||||
f"Simulation suite adding character: {character_name}",
|
||||
as_scene_message=True,
|
||||
)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: add npc", name=character_name)
|
||||
|
||||
npc = TM.agents.director.persist_character(character_name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
|
||||
|
||||
|
||||
npc = TM.agents.director.persist_character(
|
||||
character_name=character_name,
|
||||
content=self.player_message.raw + f"\n\n{inject}",
|
||||
determine_name=False,
|
||||
)
|
||||
|
||||
self.added_npcs.append(npc.name)
|
||||
|
||||
|
||||
TM.agents.world_state.add_detail_reinforcement(
|
||||
character_name=npc.name,
|
||||
detail="Goal",
|
||||
instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
|
||||
interval=25,
|
||||
run_immediately=True
|
||||
run_immediately=True,
|
||||
)
|
||||
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: added npc", npc=npc)
|
||||
|
||||
|
||||
TM.agents.visual.generate_character_portrait(character_name=npc.name)
|
||||
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer adds {npc.name} to the simulation."
|
||||
action_description=f"The computer adds {npc.name} to the simulation.",
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
return call
|
||||
|
||||
####
|
||||
|
||||
def call_remove_ai_character(self, call:str, inject:str) -> str:
|
||||
TM.signals.status("busy", "Simulation suite removing character.", as_scene_message=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names)
|
||||
|
||||
def call_remove_ai_character(self, call: str, inject: str) -> str:
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite removing character.", as_scene_message=True
|
||||
)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(
|
||||
instructions=f"{inject} - what is the name of the character being removed?",
|
||||
allowed_names=TM.scene.npc_character_names,
|
||||
)
|
||||
|
||||
npc = TM.scene.get_character(character_name)
|
||||
|
||||
|
||||
if npc:
|
||||
TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
|
||||
TM.agents.world_state.deactivate_character(action_name="deactivate_character", character_name=npc.name)
|
||||
|
||||
TM.agents.world_state.deactivate_character(
|
||||
action_name="deactivate_character", character_name=npc.name
|
||||
)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer removes {npc.name} from the simulation."
|
||||
action_description=f"The computer removes {npc.name} from the simulation.",
|
||||
)
|
||||
|
||||
|
||||
return call
|
||||
|
||||
def call_change_ai_character(self, call:str, inject:str) -> str:
|
||||
TM.signals.status("busy", "Simulation suite altering character.", as_scene_message=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names)
|
||||
|
||||
def call_change_ai_character(self, call: str, inject: str) -> str:
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite altering character.", as_scene_message=True
|
||||
)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(
|
||||
instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?",
|
||||
allowed_names=TM.scene.npc_character_names,
|
||||
)
|
||||
|
||||
if character_name in self.added_npcs:
|
||||
# we dont want to change the character if it was just added
|
||||
return
|
||||
|
||||
character_name_after = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
|
||||
|
||||
|
||||
character_name_after = TM.agents.creator.determine_character_name(
|
||||
instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?"
|
||||
)
|
||||
|
||||
npc = TM.scene.get_character(character_name)
|
||||
|
||||
|
||||
if npc:
|
||||
TM.signals.status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
|
||||
|
||||
TM.signals.status(
|
||||
"busy",
|
||||
f"Changing {character_name} -> {character_name_after}",
|
||||
as_scene_message=True,
|
||||
)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)
|
||||
|
||||
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(
|
||||
name=npc.name,
|
||||
text=inject,
|
||||
alteration_instructions=self.player_message.raw
|
||||
alteration_instructions=self.player_message.raw,
|
||||
)
|
||||
|
||||
|
||||
TM.scene.set_character_attributes(npc.name, character_attributes)
|
||||
character_description = TM.agents.creator.determine_character_description(npc.name)
|
||||
|
||||
character_description = (
|
||||
TM.agents.creator.determine_character_description(npc.name)
|
||||
)
|
||||
|
||||
TM.scene.set_character_description(npc.name, character_description)
|
||||
TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
|
||||
|
||||
TM.log.debug(
|
||||
"SIMULATION SUITE: transform npc",
|
||||
attributes=character_attributes,
|
||||
description=character_description,
|
||||
)
|
||||
|
||||
if character_name_after != character_name:
|
||||
TM.scene.set_character_name(npc.name, character_name_after)
|
||||
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer transforms {npc.name}."
|
||||
action_description=f"The computer transforms {npc.name}.",
|
||||
)
|
||||
|
||||
|
||||
return call
|
||||
|
||||
def call_end_simulation(self, call:str, inject:str) -> str:
|
||||
|
||||
|
||||
def call_end_simulation(self, call: str, inject: str) -> str:
|
||||
player_character = TM.scene.get_player_character()
|
||||
|
||||
explicit_command = TM.agents.world_state.answer_query_true_or_false("has the player explicitly asked to end the simulation?", self.player_message.raw)
|
||||
|
||||
|
||||
explicit_command = TM.agents.world_state.answer_query_true_or_false(
|
||||
"has the player explicitly asked to end the simulation?",
|
||||
self.player_message.raw,
|
||||
)
|
||||
|
||||
if explicit_command:
|
||||
TM.signals.status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
|
||||
TM.signals.status(
|
||||
"busy",
|
||||
"Simulation suite ending current simulation.",
|
||||
as_scene_message=True,
|
||||
)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names)}. The player is also transformed back to their normal, non-descript persona as the form of {player_character.name} ceases to exist.",
|
||||
emit_message=True
|
||||
emit_message=True,
|
||||
)
|
||||
TM.scene.restore()
|
||||
|
||||
self.simulation_reset = True
|
||||
|
||||
|
||||
TM.game_state.unset_var("instr.has_issued_instructions")
|
||||
TM.game_state.unset_var("instr.lastprocessed_call")
|
||||
TM.game_state.unset_var("instr.simulation_started")
|
||||
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer ends the simulation."
|
||||
action_description="The computer ends the simulation.",
|
||||
)
|
||||
|
||||
|
||||
def finalize_round(self):
|
||||
|
||||
# track rounds
|
||||
rounds = TM.game_state.get_var("instr.rounds", 0)
|
||||
|
||||
|
||||
# increase rounds
|
||||
TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)
|
||||
|
||||
has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
|
||||
|
||||
|
||||
has_issued_instructions = TM.game_state.has_var(
|
||||
"instr.has_issued_instructions"
|
||||
)
|
||||
|
||||
if self.update_world_state:
|
||||
self.run_update_world_state()
|
||||
|
||||
|
||||
if self.player_message_is_instruction:
|
||||
TM.scene.hide_message(self.player_message.id)
|
||||
TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
|
||||
TM.signals.status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
|
||||
|
||||
TM.game_state.set_var(
|
||||
"instr.lastprocessed_call", self.player_message.id, commit=False
|
||||
)
|
||||
TM.signals.status(
|
||||
"success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True
|
||||
)
|
||||
|
||||
elif self.player_message and not has_issued_instructions:
|
||||
# simulation started, player message is NOT an instruction, and player has not given
|
||||
# any instructions
|
||||
self.guide_player()
|
||||
|
||||
elif self.player_message and not TM.scene.npc_character_names:
|
||||
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
|
||||
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
|
||||
self.narrate_round()
|
||||
|
||||
elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names and has_issued_instructions:
|
||||
|
||||
elif (
|
||||
rounds % AUTO_NARRATE_INTERVAL == 0
|
||||
and rounds
|
||||
and TM.scene.npc_character_names
|
||||
and has_issued_instructions
|
||||
):
|
||||
# every N rounds, narrate the round
|
||||
self.narrate_round()
|
||||
|
||||
|
||||
def guide_player(self):
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="paraphrase",
|
||||
narration=MSG_HELP,
|
||||
emit_message=True
|
||||
action_name="paraphrase", narration=MSG_HELP, emit_message=True
|
||||
)
|
||||
|
||||
|
||||
def narrate_round(self):
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=PROMPT_NARRATE_ROUND,
|
||||
emit_message=True
|
||||
emit_message=True,
|
||||
)
|
||||
|
||||
|
||||
def run_update_world_state(self, force=False):
|
||||
TM.log.debug("SIMULATION SUITE: update world state", force=force)
|
||||
TM.signals.status("busy", "Simulation suite updating world state.", as_scene_message=True)
|
||||
TM.signals.status(
|
||||
"busy", "Simulation suite updating world state.", as_scene_message=True
|
||||
)
|
||||
TM.agents.world_state.update_world_state(force=force)
|
||||
TM.signals.status("success", "Simulation suite updated world state.", as_scene_message=True)
|
||||
TM.signals.status(
|
||||
"success",
|
||||
"Simulation suite updated world state.",
|
||||
as_scene_message=True,
|
||||
)
|
||||
|
||||
SimulationSuite().run()
|
||||
|
||||
|
||||
|
||||
def on_generation_cancelled(TM, exc):
|
||||
|
||||
"""
|
||||
Called when user pressed the cancel button during the simulation suite
|
||||
loop.
|
||||
"""
|
||||
|
||||
TM.signals.status("success", "Simulation suite instructions cancelled", as_scene_message=True)
|
||||
|
||||
TM.signals.status(
|
||||
"success", "Simulation suite instructions cancelled", as_scene_message=True
|
||||
)
|
||||
rounds = TM.game_state.get_var("instr.rounds", 0)
|
||||
TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)
|
||||
TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .tale_mate import *
|
||||
from .tale_mate import * # noqa: F401, F403
|
||||
|
||||
from .version import VERSION
|
||||
|
||||
__version__ = VERSION
|
||||
__version__ = VERSION
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from .base import Agent
|
||||
from .conversation import ConversationAgent
|
||||
from .creator import CreatorAgent
|
||||
from .director import DirectorAgent
|
||||
from .editor import EditorAgent
|
||||
from .memory import ChromaDBMemoryAgent, MemoryAgent
|
||||
from .narrator import NarratorAgent
|
||||
from .registry import AGENT_CLASSES, get_agent_class, register
|
||||
from .summarize import SummarizeAgent
|
||||
from .tts import TTSAgent
|
||||
from .visual import VisualAgent
|
||||
from .world_state import WorldStateAgent
|
||||
from .base import Agent # noqa: F401
|
||||
from .conversation import ConversationAgent # noqa: F401
|
||||
from .creator import CreatorAgent # noqa: F401
|
||||
from .director import DirectorAgent # noqa: F401
|
||||
from .editor import EditorAgent # noqa: F401
|
||||
from .memory import ChromaDBMemoryAgent, MemoryAgent # noqa: F401
|
||||
from .narrator import NarratorAgent # noqa: F401
|
||||
from .registry import AGENT_CLASSES, get_agent_class, register # noqa: F401
|
||||
from .summarize import SummarizeAgent # noqa: F401
|
||||
from .tts import TTSAgent # noqa: F401
|
||||
from .visual import VisualAgent # noqa: F401
|
||||
from .world_state import WorldStateAgent # noqa: F401
|
||||
|
||||
@@ -6,11 +6,10 @@ from inspect import signature
|
||||
import re
|
||||
from abc import ABC
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
from typing import Callable, Union
|
||||
import uuid
|
||||
import pydantic
|
||||
import structlog
|
||||
from blinker import signal
|
||||
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
@@ -19,10 +18,11 @@ from talemate.agents.context import ActiveAgent, active_agent
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopStartEvent
|
||||
from talemate.context import active_scene
|
||||
import talemate.config as config
|
||||
from talemate.ux.schema import Column
|
||||
from talemate.config import get_config, Config
|
||||
import talemate.config.schema as config_schema
|
||||
from talemate.client.context import (
|
||||
ClientContext,
|
||||
set_client_context_attribute,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -39,6 +39,7 @@ __all__ = [
|
||||
|
||||
log = structlog.get_logger("talemate.agents.base")
|
||||
|
||||
|
||||
class AgentActionConditional(pydantic.BaseModel):
|
||||
attribute: str
|
||||
value: int | float | str | bool | list[int | float | str | bool] | None = None
|
||||
@@ -48,24 +49,26 @@ class AgentActionNote(pydantic.BaseModel):
|
||||
type: str
|
||||
text: str
|
||||
|
||||
|
||||
class AgentActionConfig(pydantic.BaseModel):
|
||||
type: str
|
||||
label: str
|
||||
description: str = ""
|
||||
value: Union[int, float, str, bool, list, None] = None
|
||||
default_value: Union[int, float, str, bool] = None
|
||||
max: Union[int, float, None] = None
|
||||
min: Union[int, float, None] = None
|
||||
step: Union[int, float, None] = None
|
||||
value: int | float | str | bool | list | None = None
|
||||
default_value: int | float | str | bool | None = None
|
||||
max: int | float | None = None
|
||||
min: int | float | None = None
|
||||
step: int | float | None = None
|
||||
scope: str = "global"
|
||||
choices: Union[list[dict[str, str]], None] = None
|
||||
choices: list[dict[str, str | int | float | bool]] | None = None
|
||||
note: Union[str, None] = None
|
||||
expensive: bool = False
|
||||
quick_toggle: bool = False
|
||||
condition: Union[AgentActionConditional, None] = None
|
||||
title: Union[str, None] = None
|
||||
value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True)
|
||||
|
||||
condition: AgentActionConditional | None = None
|
||||
title: str | None = None
|
||||
value_migration: Callable | None = pydantic.Field(default=None, exclude=True)
|
||||
columns: list[Column] | None = None
|
||||
|
||||
note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
@@ -77,45 +80,46 @@ class AgentAction(pydantic.BaseModel):
|
||||
label: str
|
||||
description: str = ""
|
||||
warning: str = ""
|
||||
config: Union[dict[str, AgentActionConfig], None] = None
|
||||
condition: Union[AgentActionConditional, None] = None
|
||||
config: dict[str, AgentActionConfig] | None = None
|
||||
condition: AgentActionConditional | None = None
|
||||
container: bool = False
|
||||
icon: Union[str, None] = None
|
||||
icon: str | None = None
|
||||
can_be_disabled: bool = False
|
||||
quick_toggle: bool = False
|
||||
experimental: bool = False
|
||||
|
||||
|
||||
class AgentDetail(pydantic.BaseModel):
|
||||
value: Union[str, None] = None
|
||||
description: Union[str, None] = None
|
||||
icon: Union[str, None] = None
|
||||
value: str | None = None
|
||||
description: str | None = None
|
||||
icon: str | None = None
|
||||
color: str = "grey"
|
||||
hidden: bool = False
|
||||
|
||||
|
||||
class DynamicInstruction(pydantic.BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
f"<|SECTION:{self.title}|>",
|
||||
self.content,
|
||||
"<|CLOSE_SECTION|>"
|
||||
]
|
||||
[f"<|SECTION:{self.title}|>", self.content, "<|CLOSE_SECTION|>"]
|
||||
)
|
||||
|
||||
|
||||
def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = None) -> dict:
|
||||
|
||||
def args_and_kwargs_to_dict(
|
||||
fn, args: list, kwargs: dict, filter: list[str] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Takes a list of arguments and a dict of keyword arguments and returns
|
||||
a dict mapping parameter names to their values.
|
||||
|
||||
|
||||
Args:
|
||||
fn: The function whose parameters we want to map
|
||||
args: List of positional arguments
|
||||
kwargs: Dictionary of keyword arguments
|
||||
filter: List of parameter names to include in the result, if None all parameters are included
|
||||
|
||||
|
||||
Returns:
|
||||
Dict mapping parameter names to their values
|
||||
"""
|
||||
@@ -124,34 +128,36 @@ def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = Non
|
||||
bound_args.apply_defaults()
|
||||
rv = dict(bound_args.arguments)
|
||||
rv.pop("self", None)
|
||||
|
||||
|
||||
if filter:
|
||||
for key in list(rv.keys()):
|
||||
if key not in filter:
|
||||
rv.pop(key)
|
||||
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
class store_context_state:
|
||||
"""
|
||||
Flag to store a function's arguments in the agent's context state.
|
||||
|
||||
|
||||
Any arguments passed to the function will be stored in the agent's context
|
||||
|
||||
|
||||
If no arguments are passed, all arguments will be stored.
|
||||
|
||||
|
||||
Keyword arguments can be passed to store additional values in the context state.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
def __call__(self, fn):
|
||||
fn.store_context_state = self.args
|
||||
fn.store_context_state_kwargs = self.kwargs
|
||||
return fn
|
||||
|
||||
|
||||
def set_processing(fn):
|
||||
"""
|
||||
decorator that emits the agent status as processing while the function
|
||||
@@ -165,31 +171,33 @@ def set_processing(fn):
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
with ClientContext():
|
||||
scene = active_scene.get()
|
||||
|
||||
|
||||
if scene:
|
||||
scene.continue_actions()
|
||||
|
||||
if getattr(scene, "config", None):
|
||||
set_client_context_attribute("app_config_system_prompts", scene.config.get("system_prompts", {}))
|
||||
|
||||
|
||||
with ActiveAgent(self, fn, args, kwargs) as active_agent_context:
|
||||
try:
|
||||
await self.emit_status(processing=True)
|
||||
|
||||
|
||||
# Now pass the complete args list
|
||||
if getattr(fn, "store_context_state", None) is not None:
|
||||
all_args = args_and_kwargs_to_dict(
|
||||
fn, [self] + list(args), kwargs, getattr(fn, "store_context_state", [])
|
||||
fn,
|
||||
[self] + list(args),
|
||||
kwargs,
|
||||
getattr(fn, "store_context_state", []),
|
||||
)
|
||||
if getattr(fn, "store_context_state_kwargs", None) is not None:
|
||||
all_args.update(getattr(fn, "store_context_state_kwargs", {}))
|
||||
|
||||
all_args.update(
|
||||
getattr(fn, "store_context_state_kwargs", {})
|
||||
)
|
||||
|
||||
all_args[f"fn_{fn.__name__}"] = True
|
||||
|
||||
|
||||
active_agent_context.state_params = all_args
|
||||
|
||||
|
||||
self.set_context_states(**all_args)
|
||||
|
||||
|
||||
return await fn(self, *args, **kwargs)
|
||||
finally:
|
||||
try:
|
||||
@@ -211,18 +219,23 @@ class Agent(ABC):
|
||||
verbose_name = None
|
||||
set_processing = set_processing
|
||||
requires_llm_client = True
|
||||
auto_break_repetition = False
|
||||
websocket_handler = None
|
||||
essential = True
|
||||
ready_check_error = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def init_actions(cls, actions: dict[str, AgentAction] | None = None) -> dict[str, AgentAction]:
|
||||
def init_actions(
|
||||
cls, actions: dict[str, AgentAction] | None = None
|
||||
) -> dict[str, AgentAction]:
|
||||
if actions is None:
|
||||
actions = {}
|
||||
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return get_config()
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
if hasattr(self, "client"):
|
||||
@@ -230,12 +243,14 @@ class Agent(ABC):
|
||||
return self.client.name
|
||||
return None
|
||||
|
||||
@property
|
||||
def verbose_name(self):
|
||||
return self.agent_type.capitalize()
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
if not self.requires_llm_client:
|
||||
return True
|
||||
|
||||
if not hasattr(self, "client"):
|
||||
return False
|
||||
|
||||
if not getattr(self.client, "enabled", True):
|
||||
return False
|
||||
|
||||
@@ -315,67 +330,69 @@ class Agent(ABC):
|
||||
return {}
|
||||
|
||||
return {k: v.model_dump() for k, v in self.actions.items()}
|
||||
|
||||
|
||||
# scene state
|
||||
|
||||
|
||||
def context_fingerpint(self, extra: list[str] = []) -> str | None:
|
||||
|
||||
def context_fingerprint(self, extra: list[str] = None) -> str | None:
|
||||
active_agent_context = active_agent.get()
|
||||
|
||||
|
||||
if not active_agent_context:
|
||||
return None
|
||||
|
||||
|
||||
if self.scene.history:
|
||||
fingerprint = f"{self.scene.history[-1].fingerprint}-{active_agent_context.first.fingerprint}"
|
||||
else:
|
||||
fingerprint = f"START-{active_agent_context.first.fingerprint}"
|
||||
|
||||
for extra_key in extra:
|
||||
fingerprint += f"-{hash(extra_key)}"
|
||||
|
||||
|
||||
if extra:
|
||||
for extra_key in extra:
|
||||
fingerprint += f"-{hash(extra_key)}"
|
||||
|
||||
return fingerprint
|
||||
|
||||
|
||||
def get_scene_state(self, key:str, default=None):
|
||||
|
||||
def get_scene_state(self, key: str, default=None):
|
||||
agent_state = self.scene.agent_state.get(self.agent_type, {})
|
||||
return agent_state.get(key, default)
|
||||
|
||||
|
||||
def set_scene_states(self, **kwargs):
|
||||
agent_state = self.scene.agent_state.get(self.agent_type, {})
|
||||
for key, value in kwargs.items():
|
||||
agent_state[key] = value
|
||||
self.scene.agent_state[self.agent_type] = agent_state
|
||||
|
||||
|
||||
def dump_scene_state(self):
|
||||
return self.scene.agent_state.get(self.agent_type, {})
|
||||
|
||||
|
||||
# active agent context state
|
||||
|
||||
def get_context_state(self, key:str, default=None):
|
||||
|
||||
def get_context_state(self, key: str, default=None):
|
||||
key = f"{self.agent_type}__{key}"
|
||||
try:
|
||||
return active_agent.get().state.get(key, default)
|
||||
except AttributeError:
|
||||
log.warning("get_context_state error", agent=self.agent_type, key=key)
|
||||
return default
|
||||
|
||||
|
||||
def set_context_states(self, **kwargs):
|
||||
try:
|
||||
|
||||
items = {f"{self.agent_type}__{k}": v for k, v in kwargs.items()}
|
||||
active_agent.get().state.update(items)
|
||||
log.debug("set_context_states", agent=self.agent_type, state=active_agent.get().state)
|
||||
log.debug(
|
||||
"set_context_states",
|
||||
agent=self.agent_type,
|
||||
state=active_agent.get().state,
|
||||
)
|
||||
except AttributeError:
|
||||
log.error("set_context_states error", agent=self.agent_type, kwargs=kwargs)
|
||||
|
||||
|
||||
def dump_context_state(self):
|
||||
try:
|
||||
return active_agent.get().state
|
||||
except AttributeError:
|
||||
return {}
|
||||
|
||||
|
||||
###
|
||||
|
||||
|
||||
async def _handle_ready_check(self, fut: asyncio.Future):
|
||||
callback_failure = getattr(self, "on_ready_check_failure", None)
|
||||
if fut.cancelled():
|
||||
@@ -425,42 +442,53 @@ class Agent(ABC):
|
||||
if not action.config:
|
||||
continue
|
||||
|
||||
for config_key, config in action.config.items():
|
||||
for config_key, _config in action.config.items():
|
||||
try:
|
||||
config.value = (
|
||||
_config.value = (
|
||||
kwargs.get("actions", {})
|
||||
.get(action_key, {})
|
||||
.get("config", {})
|
||||
.get(config_key, {})
|
||||
.get("value", config.value)
|
||||
.get("value", _config.value)
|
||||
)
|
||||
if config.value_migration and callable(config.value_migration):
|
||||
config.value = config.value_migration(config.value)
|
||||
if _config.value_migration and callable(_config.value_migration):
|
||||
_config.value = _config.value_migration(_config.value)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
async def save_config(self, app_config: config.Config | None = None):
|
||||
async def save_config(self):
|
||||
"""
|
||||
Saves the agent config to the config file.
|
||||
|
||||
|
||||
If no config object is provided, the config is loaded from the config file.
|
||||
"""
|
||||
|
||||
if not app_config:
|
||||
app_config:config.Config = config.load_config(as_model=True)
|
||||
|
||||
app_config.agents[self.agent_type] = config.Agent(
|
||||
|
||||
app_config: Config = get_config()
|
||||
|
||||
app_config.agents[self.agent_type] = config_schema.Agent(
|
||||
name=self.agent_type,
|
||||
client=self.client.name if self.client else None,
|
||||
client=self.client.name if getattr(self, "client", None) else None,
|
||||
enabled=self.enabled,
|
||||
actions={action_key: config.AgentAction(
|
||||
enabled=action.enabled,
|
||||
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
|
||||
) for action_key, action in self.actions.items()}
|
||||
actions={
|
||||
action_key: config_schema.AgentAction(
|
||||
enabled=action.enabled,
|
||||
config={
|
||||
config_key: config_schema.AgentActionConfig(
|
||||
value=config_obj.value
|
||||
)
|
||||
for config_key, config_obj in action.config.items()
|
||||
},
|
||||
)
|
||||
for action_key, action in self.actions.items()
|
||||
},
|
||||
)
|
||||
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
|
||||
config.save_config(app_config)
|
||||
log.debug(
|
||||
"saving agent config",
|
||||
agent=self.agent_type,
|
||||
config=app_config.agents[self.agent_type],
|
||||
)
|
||||
|
||||
app_config.dirty = True
|
||||
|
||||
async def on_game_loop_start(self, event: GameLoopStartEvent):
|
||||
"""
|
||||
@@ -474,19 +502,19 @@ class Agent(ABC):
|
||||
if not action.config:
|
||||
continue
|
||||
|
||||
for _, config in action.config.items():
|
||||
if config.scope == "scene":
|
||||
for _, _config in action.config.items():
|
||||
if _config.scope == "scene":
|
||||
# if default_value is None, just use the `type` of the current
|
||||
# value
|
||||
if config.default_value is None:
|
||||
default_value = type(config.value)()
|
||||
if _config.default_value is None:
|
||||
default_value = type(_config.value)()
|
||||
else:
|
||||
default_value = config.default_value
|
||||
default_value = _config.default_value
|
||||
|
||||
log.debug(
|
||||
"resetting config", config=config, default_value=default_value
|
||||
"resetting config", config=_config, default_value=default_value
|
||||
)
|
||||
config.value = default_value
|
||||
_config.value = default_value
|
||||
|
||||
await self.emit_status()
|
||||
|
||||
@@ -518,7 +546,9 @@ class Agent(ABC):
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def _handle_background_processing(self, fut: asyncio.Future, error_handler = None):
|
||||
async def _handle_background_processing(
|
||||
self, fut: asyncio.Future, error_handler=None
|
||||
):
|
||||
try:
|
||||
if fut.cancelled():
|
||||
return
|
||||
@@ -541,7 +571,7 @@ class Agent(ABC):
|
||||
self.processing_bg -= 1
|
||||
await self.emit_status()
|
||||
|
||||
async def set_background_processing(self, task: asyncio.Task, error_handler = None):
|
||||
async def set_background_processing(self, task: asyncio.Task, error_handler=None):
|
||||
log.info("set_background_processing", agent=self.agent_type)
|
||||
if not hasattr(self, "processing_bg"):
|
||||
self.processing_bg = 0
|
||||
@@ -550,7 +580,9 @@ class Agent(ABC):
|
||||
|
||||
await self.emit_status()
|
||||
task.add_done_callback(
|
||||
lambda fut: asyncio.create_task(self._handle_background_processing(fut, error_handler))
|
||||
lambda fut: asyncio.create_task(
|
||||
self._handle_background_processing(fut, error_handler)
|
||||
)
|
||||
)
|
||||
|
||||
def connect(self, scene):
|
||||
@@ -580,7 +612,7 @@ class Agent(ABC):
|
||||
exclude_fn: Callable = None,
|
||||
):
|
||||
current_memory_context = []
|
||||
memory_helper = self.scene.get_helper("memory")
|
||||
memory_helper = instance.get_agent("memory")
|
||||
if memory_helper:
|
||||
history_messages = "\n".join(
|
||||
self.scene.recent_history(memory_history_context_max)
|
||||
@@ -623,7 +655,6 @@ class Agent(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@set_processing
|
||||
async def delegate(self, fn: Callable, *args, **kwargs):
|
||||
"""
|
||||
@@ -631,20 +662,22 @@ class Agent(ABC):
|
||||
by the agent.
|
||||
"""
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
async def emit_message(self, header:str, message:str | list[dict], meta: dict = None, **data):
|
||||
|
||||
async def emit_message(
|
||||
self, header: str, message: str | list[dict], meta: dict = None, **data
|
||||
):
|
||||
if not data:
|
||||
data = {}
|
||||
|
||||
|
||||
if not meta:
|
||||
meta = {}
|
||||
|
||||
|
||||
if "uuid" not in data:
|
||||
data["uuid"] = str(uuid.uuid4())
|
||||
|
||||
|
||||
if "agent" not in data:
|
||||
data["agent"] = self.agent_type
|
||||
|
||||
|
||||
data["header"] = header
|
||||
emit(
|
||||
"agent_message",
|
||||
@@ -653,17 +686,22 @@ class Agent(ABC):
|
||||
meta=meta,
|
||||
websocket_passthrough=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentEmission:
|
||||
agent: Agent
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentTemplateEmission(AgentEmission):
|
||||
template_vars: dict = dataclasses.field(default_factory=dict)
|
||||
response: str = None
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
|
||||
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RagBuildSubInstructionEmission(AgentEmission):
|
||||
sub_instruction: str | None = None
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import contextvars
|
||||
import uuid
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import Callable
|
||||
|
||||
import pydantic
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
|
||||
__all__ = [
|
||||
"active_agent",
|
||||
]
|
||||
@@ -25,7 +21,6 @@ class ActiveAgentContext(pydantic.BaseModel):
|
||||
state: dict = pydantic.Field(default_factory=dict)
|
||||
state_params: dict = pydantic.Field(default_factory=dict)
|
||||
previous: "ActiveAgentContext" = None
|
||||
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -35,29 +30,30 @@ class ActiveAgentContext(pydantic.BaseModel):
|
||||
return self.previous.first if self.previous else self
|
||||
|
||||
@property
|
||||
def action(self):
|
||||
def action(self):
|
||||
name = self.fn.__name__
|
||||
if name == "delegate":
|
||||
return self.fn_args[0].__name__
|
||||
return name
|
||||
|
||||
|
||||
@property
|
||||
def fingerprint(self) -> int:
|
||||
if hasattr(self, "_fingerprint"):
|
||||
return self._fingerprint
|
||||
self._fingerprint = hash(frozenset(self.state_params.items()))
|
||||
return self._fingerprint
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.agent.verbose_name}.{self.action}"
|
||||
|
||||
|
||||
|
||||
class ActiveAgent:
|
||||
def __init__(self, agent, fn, args=None, kwargs=None):
|
||||
self.agent = ActiveAgentContext(agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {})
|
||||
self.agent = ActiveAgentContext(
|
||||
agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {}
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
previous_agent = active_agent.get()
|
||||
|
||||
if previous_agent:
|
||||
@@ -70,7 +66,7 @@ class ActiveAgent:
|
||||
self.agent.agent_stack_uid = str(uuid.uuid4())
|
||||
|
||||
self.token = active_agent.set(self.agent)
|
||||
|
||||
|
||||
return self.agent
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import random
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
@@ -10,14 +9,12 @@ import structlog
|
||||
|
||||
import talemate.client as client
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
import talemate.util as util
|
||||
from talemate.client.context import (
|
||||
client_context_attribute,
|
||||
set_client_context_attribute,
|
||||
set_conversation_context_attribute,
|
||||
)
|
||||
from talemate.events import GameLoopEvent
|
||||
from talemate.exceptions import LLMAccuracyError
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import CharacterMessage, DirectorMessage
|
||||
@@ -26,6 +23,7 @@ from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentActionNote,
|
||||
AgentDetail,
|
||||
AgentEmission,
|
||||
DynamicInstruction,
|
||||
@@ -37,7 +35,7 @@ from talemate.agents.memory.rag import MemoryRAGMixin
|
||||
from talemate.agents.context import active_agent
|
||||
|
||||
from .websocket_handler import ConversationWebsocketHandler
|
||||
import talemate.agents.conversation.nodes
|
||||
import talemate.agents.conversation.nodes # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Actor, Character
|
||||
@@ -50,21 +48,20 @@ class ConversationAgentEmission(AgentEmission):
|
||||
actor: Actor
|
||||
character: Character
|
||||
response: str
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.conversation.before_generate",
|
||||
"agent.conversation.before_generate",
|
||||
"agent.conversation.inject_instructions",
|
||||
"agent.conversation.generated"
|
||||
"agent.conversation.generated",
|
||||
)
|
||||
|
||||
|
||||
@register()
|
||||
class ConversationAgent(
|
||||
MemoryRAGMixin,
|
||||
Agent
|
||||
):
|
||||
class ConversationAgent(MemoryRAGMixin, Agent):
|
||||
"""
|
||||
An agent that can be used to have a conversation with the AI
|
||||
|
||||
@@ -89,12 +86,22 @@ class ConversationAgent(
|
||||
"format": AgentActionConfig(
|
||||
type="text",
|
||||
label="Format",
|
||||
description="The generation format of the scene context, as seen by the AI.",
|
||||
description="The generation format of the scene progression, as seen by the AI. Has no direct effect on your view of the scene, but will affect the way the AI perceives the scene and its characters, leading to changes in the response, for better or worse.",
|
||||
choices=[
|
||||
{"label": "Screenplay", "value": "movie_script"},
|
||||
{"label": "Chat (legacy)", "value": "chat"},
|
||||
{
|
||||
"label": "Narrative (NEW, experimental)",
|
||||
"value": "narrative",
|
||||
},
|
||||
],
|
||||
value="movie_script",
|
||||
note_on_value={
|
||||
"narrative": AgentActionNote(
|
||||
type="primary",
|
||||
text="Will attempt to generate flowing, novel-like prose with scene intent awareness and character goal consideration. A reasoning model is STRONGLY recommended. Experimental and more prone to generate out of turn character actions and dialogue.",
|
||||
)
|
||||
},
|
||||
),
|
||||
"length": AgentActionConfig(
|
||||
type="number",
|
||||
@@ -135,16 +142,8 @@ class ConversationAgent(
|
||||
max=20,
|
||||
step=1,
|
||||
),
|
||||
|
||||
|
||||
},
|
||||
),
|
||||
"auto_break_repetition": AgentAction(
|
||||
enabled=True,
|
||||
can_be_disabled=True,
|
||||
label="Auto Break Repetition",
|
||||
description="Will attempt to automatically break AI repetition.",
|
||||
),
|
||||
"content": AgentAction(
|
||||
enabled=True,
|
||||
can_be_disabled=False,
|
||||
@@ -159,7 +158,7 @@ class ConversationAgent(
|
||||
description="Use the writing style selected in the scene settings",
|
||||
value=True,
|
||||
),
|
||||
}
|
||||
},
|
||||
),
|
||||
}
|
||||
MemoryRAGMixin.add_actions(actions)
|
||||
@@ -167,7 +166,7 @@ class ConversationAgent(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: client.TaleMateClient,
|
||||
client: client.ClientBase | None = None,
|
||||
kind: Optional[str] = "pygmalion",
|
||||
logging_enabled: Optional[bool] = True,
|
||||
**kwargs,
|
||||
@@ -205,7 +204,6 @@ class ConversationAgent(
|
||||
|
||||
@property
|
||||
def agent_details(self) -> dict:
|
||||
|
||||
details = {
|
||||
"client": AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
@@ -231,22 +229,24 @@ class ConversationAgent(
|
||||
|
||||
@property
|
||||
def generation_settings_actor_instructions_offset(self):
|
||||
return self.actions["generation_override"].config["actor_instructions_offset"].value
|
||||
|
||||
return (
|
||||
self.actions["generation_override"]
|
||||
.config["actor_instructions_offset"]
|
||||
.value
|
||||
)
|
||||
|
||||
@property
|
||||
def generation_settings_response_length(self):
|
||||
return self.actions["generation_override"].config["length"].value
|
||||
|
||||
|
||||
@property
|
||||
def generation_settings_override_enabled(self):
|
||||
return self.actions["generation_override"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def content_use_writing_style(self) -> bool:
|
||||
return self.actions["content"].config["use_writing_style"].value
|
||||
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
|
||||
@@ -276,7 +276,7 @@ class ConversationAgent(
|
||||
main_character = scene.main_character.character
|
||||
|
||||
character_names = [c.name for c in scene.characters]
|
||||
|
||||
|
||||
if main_character:
|
||||
try:
|
||||
character_names.remove(main_character.name)
|
||||
@@ -296,21 +296,22 @@ class ConversationAgent(
|
||||
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
|
||||
except IndexError:
|
||||
director_message = False
|
||||
|
||||
|
||||
|
||||
inject_instructions_emission = ConversationAgentEmission(
|
||||
agent=self,
|
||||
response="",
|
||||
actor=None,
|
||||
character=character,
|
||||
response="",
|
||||
actor=None,
|
||||
character=character,
|
||||
)
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.conversation.inject_instructions"
|
||||
).send(inject_instructions_emission)
|
||||
|
||||
|
||||
agent_context = active_agent.get()
|
||||
agent_context.state["dynamic_instructions"] = inject_instructions_emission.dynamic_instructions
|
||||
|
||||
agent_context.state["dynamic_instructions"] = (
|
||||
inject_instructions_emission.dynamic_instructions
|
||||
)
|
||||
|
||||
conversation_format = self.conversation_format
|
||||
prompt = Prompt.get(
|
||||
f"conversation.dialogue-{conversation_format}",
|
||||
@@ -319,26 +320,30 @@ class ConversationAgent(
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene_and_dialogue_budget": scene_and_dialogue_budget,
|
||||
"scene_and_dialogue": scene_and_dialogue,
|
||||
"memory": None, # DEPRECATED VARIABLE
|
||||
"memory": None, # DEPRECATED VARIABLE
|
||||
"characters": list(scene.get_characters()),
|
||||
"main_character": main_character,
|
||||
"formatted_names": formatted_names,
|
||||
"talking_character": character,
|
||||
"partial_message": char_message,
|
||||
"director_message": director_message,
|
||||
"extra_instructions": self.generation_settings_task_instructions, #backward compatibility
|
||||
"extra_instructions": self.generation_settings_task_instructions, # backward compatibility
|
||||
"task_instructions": self.generation_settings_task_instructions,
|
||||
"actor_instructions": self.generation_settings_actor_instructions,
|
||||
"actor_instructions_offset": self.generation_settings_actor_instructions_offset,
|
||||
"direct_instruction": instruction,
|
||||
"decensor": self.client.decensor_enabled,
|
||||
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
|
||||
"response_length": self.generation_settings_response_length
|
||||
if self.generation_settings_override_enabled
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
return str(prompt)
|
||||
|
||||
async def build_prompt(self, character, char_message: str = "", instruction:str = None):
|
||||
async def build_prompt(
|
||||
self, character, char_message: str = "", instruction: str = None
|
||||
):
|
||||
fn = self.build_prompt_default
|
||||
|
||||
return await fn(character, char_message=char_message, instruction=instruction)
|
||||
@@ -376,12 +381,12 @@ class ConversationAgent(
|
||||
set_client_context_attribute("nuke_repetition", nuke_repetition)
|
||||
|
||||
@set_processing
|
||||
@store_context_state('instruction')
|
||||
@store_context_state("instruction")
|
||||
async def converse(
|
||||
self,
|
||||
self,
|
||||
actor,
|
||||
instruction:str = None,
|
||||
emit_signals:bool = True,
|
||||
instruction: str = None,
|
||||
emit_signals: bool = True,
|
||||
) -> list[CharacterMessage]:
|
||||
"""
|
||||
Have a conversation with the AI
|
||||
@@ -398,7 +403,9 @@ class ConversationAgent(
|
||||
|
||||
self.set_generation_overrides()
|
||||
|
||||
result = await self.client.send_prompt(await self.build_prompt(character, instruction=instruction))
|
||||
result = await self.client.send_prompt(
|
||||
await self.build_prompt(character, instruction=instruction)
|
||||
)
|
||||
|
||||
result = self.clean_result(result, character)
|
||||
|
||||
@@ -451,21 +458,31 @@ class ConversationAgent(
|
||||
if total_result.startswith(":\n") or total_result.startswith(": "):
|
||||
total_result = total_result[2:]
|
||||
|
||||
# movie script format
|
||||
# {uppercase character name}
|
||||
# {dialogue}
|
||||
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
|
||||
conversation_format = self.conversation_format
|
||||
|
||||
# chat format
|
||||
# {character name}: {dialogue}
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
if conversation_format == "narrative":
|
||||
# For narrative format, the LLM generates pure prose without character name prefixes
|
||||
# We need to store it internally in the standard {name}: {text} format
|
||||
total_result = util.clean_dialogue(total_result, main_name=character.name)
|
||||
# Only add character name if it's not already there
|
||||
if not total_result.startswith(character.name + ":"):
|
||||
total_result = f"{character.name}: {total_result}"
|
||||
else:
|
||||
# movie script format
|
||||
# {uppercase character name}
|
||||
# {dialogue}
|
||||
total_result = total_result.replace(f"{character.name.upper()}\n", "")
|
||||
|
||||
# Removes partial sentence at the end
|
||||
total_result = util.clean_dialogue(total_result, main_name=character.name)
|
||||
# chat format
|
||||
# {character name}: {dialogue}
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Check if total_result starts with character name, if not, prepend it
|
||||
if not total_result.startswith(character.name+":"):
|
||||
total_result = f"{character.name}: {total_result}"
|
||||
# Removes partial sentence at the end
|
||||
total_result = util.clean_dialogue(total_result, main_name=character.name)
|
||||
|
||||
# Check if total_result starts with character name, if not, prepend it
|
||||
if not total_result.startswith(character.name + ":"):
|
||||
total_result = f"{character.name}: {total_result}"
|
||||
|
||||
total_result = total_result.strip()
|
||||
|
||||
@@ -481,11 +498,11 @@ class ConversationAgent(
|
||||
|
||||
log.debug("conversation agent", response=response)
|
||||
emission = ConversationAgentEmission(
|
||||
agent=self,
|
||||
actor=actor,
|
||||
character=character,
|
||||
agent=self,
|
||||
actor=actor,
|
||||
character=character,
|
||||
response=response,
|
||||
)
|
||||
)
|
||||
if emit_signals:
|
||||
await talemate.emit.async_signals.get("agent.conversation.generated").send(
|
||||
emission
|
||||
@@ -497,9 +514,6 @@ class ConversationAgent(
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
if auto and not self.actions["auto_break_repetition"].enabled:
|
||||
return False
|
||||
|
||||
return agent_function_name == "converse"
|
||||
|
||||
def inject_prompt_paramters(
|
||||
|
||||
@@ -1,74 +1,75 @@
|
||||
import structlog
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from talemate.game.engine.nodes.core import Node, GraphState, UNRESOLVED
|
||||
from talemate.game.engine.nodes.core import GraphState, UNRESOLVED
|
||||
from talemate.game.engine.nodes.registry import register
|
||||
from talemate.game.engine.nodes.agent import AgentNode, AgentSettingsNode
|
||||
from talemate.context import active_scene
|
||||
from talemate.client.context import ConversationContext, ClientContext
|
||||
import talemate.events as events
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene, Character
|
||||
|
||||
log = structlog.get_logger("talemate.game.engine.nodes.agents.conversation")
|
||||
|
||||
|
||||
@register("agents/conversation/Settings")
|
||||
class ConversationSettings(AgentSettingsNode):
|
||||
"""
|
||||
Base node to render conversation agent settings.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "conversation"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "conversation"
|
||||
|
||||
def __init__(self, title="Conversation Settings", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
@register("agents/conversation/Generate")
|
||||
class GenerateConversation(AgentNode):
|
||||
"""
|
||||
Generate a conversation between two characters
|
||||
"""
|
||||
|
||||
_agent_name:ClassVar[str] = "conversation"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "conversation"
|
||||
|
||||
def __init__(self, title="Generate Conversation", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("character", socket_type="character")
|
||||
self.add_input("instruction", socket_type="str", optional=True)
|
||||
|
||||
|
||||
self.set_property("trigger_conversation_generated", True)
|
||||
|
||||
|
||||
self.add_output("generated", socket_type="str")
|
||||
self.add_output("message", socket_type="message_object")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
character:"Character" = self.get_input_value("character")
|
||||
scene:"Scene" = active_scene.get()
|
||||
character: "Character" = self.get_input_value("character")
|
||||
scene: "Scene" = active_scene.get()
|
||||
instruction = self.get_input_value("instruction")
|
||||
trigger_conversation_generated = self.get_property("trigger_conversation_generated")
|
||||
|
||||
trigger_conversation_generated = self.get_property(
|
||||
"trigger_conversation_generated"
|
||||
)
|
||||
|
||||
other_characters = [c.name for c in scene.characters if c != character]
|
||||
|
||||
|
||||
conversation_context = ConversationContext(
|
||||
talking_character=character.name,
|
||||
other_characters=other_characters,
|
||||
)
|
||||
|
||||
|
||||
if instruction == UNRESOLVED:
|
||||
instruction = None
|
||||
|
||||
|
||||
with ClientContext(conversation=conversation_context):
|
||||
messages = await self.agent.converse(
|
||||
character.actor,
|
||||
character.actor,
|
||||
instruction=instruction,
|
||||
emit_signals=trigger_conversation_generated,
|
||||
)
|
||||
|
||||
|
||||
message = messages[0]
|
||||
|
||||
self.set_output_values({
|
||||
"generated": message.message,
|
||||
"message": message
|
||||
})
|
||||
|
||||
self.set_output_values({"generated": message.message, "message": message})
|
||||
|
||||
@@ -18,23 +18,25 @@ __all__ = [
|
||||
|
||||
log = structlog.get_logger("talemate.server.conversation")
|
||||
|
||||
|
||||
class RequestActorActionPayload(pydantic.BaseModel):
|
||||
character:str = ""
|
||||
instructions:str = ""
|
||||
emit_signals:bool = True
|
||||
instructions_through_director:bool = True
|
||||
character: str = ""
|
||||
instructions: str = ""
|
||||
emit_signals: bool = True
|
||||
instructions_through_director: bool = True
|
||||
|
||||
|
||||
class ConversationWebsocketHandler(Plugin):
|
||||
"""
|
||||
Handles narrator actions
|
||||
"""
|
||||
|
||||
|
||||
router = "conversation"
|
||||
|
||||
|
||||
@property
|
||||
def agent(self) -> "ConversationAgent":
|
||||
return get_agent("conversation")
|
||||
|
||||
|
||||
@set_loading("Generating actor action", cancellable=True, as_async=True)
|
||||
async def handle_request_actor_action(self, data: dict):
|
||||
"""
|
||||
@@ -43,38 +45,37 @@ class ConversationWebsocketHandler(Plugin):
|
||||
payload = RequestActorActionPayload(**data)
|
||||
character = None
|
||||
actor = None
|
||||
|
||||
|
||||
if payload.character:
|
||||
character = self.scene.get_character(payload.character)
|
||||
actor = character.actor
|
||||
else:
|
||||
actor = random.choice(list(self.scene.get_npc_characters())).actor
|
||||
|
||||
|
||||
if not actor:
|
||||
log.error("handle_request_actor_action: No actor found")
|
||||
return
|
||||
|
||||
|
||||
character = actor.character
|
||||
|
||||
|
||||
if payload.instructions_through_director:
|
||||
director_message = DirectorMessage(
|
||||
payload.instructions,
|
||||
source="player",
|
||||
meta={"character": character.name}
|
||||
meta={"character": character.name},
|
||||
)
|
||||
emit("director", message=director_message, character=character)
|
||||
self.scene.push_history(director_message)
|
||||
generated_messages = await self.agent.converse(
|
||||
actor,
|
||||
emit_signals=payload.emit_signals
|
||||
actor, emit_signals=payload.emit_signals
|
||||
)
|
||||
else:
|
||||
generated_messages = await self.agent.converse(
|
||||
actor,
|
||||
actor,
|
||||
instruction=payload.instructions,
|
||||
emit_signals=payload.emit_signals
|
||||
emit_signals=payload.emit_signals,
|
||||
)
|
||||
|
||||
|
||||
for message in generated_messages:
|
||||
self.scene.push_history(message)
|
||||
emit("character", message=message, character=character)
|
||||
emit("character", message=message, character=character)
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import talemate.client as client
|
||||
from talemate.agents.base import Agent, set_processing
|
||||
from talemate.agents.registry import register
|
||||
from talemate.agents.memory.rag import MemoryRAGMixin
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
from .assistant import AssistantMixin
|
||||
@@ -16,7 +12,8 @@ from .scenario import ScenarioCreatorMixin
|
||||
|
||||
from talemate.agents.base import AgentAction
|
||||
|
||||
import talemate.agents.creator.nodes
|
||||
import talemate.agents.creator.nodes # noqa: F401
|
||||
|
||||
|
||||
@register()
|
||||
class CreatorAgent(
|
||||
@@ -42,7 +39,7 @@ class CreatorAgent(
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: client.ClientBase,
|
||||
client: client.ClientBase | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.client = client
|
||||
@@ -51,7 +48,7 @@ class CreatorAgent(
|
||||
@set_processing
|
||||
async def generate_title(self, text: str):
|
||||
title = await Prompt.request(
|
||||
f"creator.generate-title",
|
||||
"creator.generate-title",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
|
||||
@@ -40,6 +40,7 @@ async_signals.register(
|
||||
"agent.creator.autocomplete.after",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ContextualGenerateEmission(AgentTemplateEmission):
|
||||
"""
|
||||
@@ -48,15 +49,16 @@ class ContextualGenerateEmission(AgentTemplateEmission):
|
||||
|
||||
content_generation_context: "ContentGenerationContext | None" = None
|
||||
character: "Character | None" = None
|
||||
|
||||
|
||||
@property
|
||||
def context_type(self) -> str:
|
||||
return self.content_generation_context.computed_context[0]
|
||||
|
||||
|
||||
@property
|
||||
def context_name(self) -> str:
|
||||
return self.content_generation_context.computed_context[1]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AutocompleteEmission(AgentTemplateEmission):
|
||||
"""
|
||||
@@ -67,6 +69,7 @@ class AutocompleteEmission(AgentTemplateEmission):
|
||||
type: str = ""
|
||||
character: "Character | None" = None
|
||||
|
||||
|
||||
class ContentGenerationContext(pydantic.BaseModel):
|
||||
"""
|
||||
A context for generating content.
|
||||
@@ -104,7 +107,6 @@ class ContentGenerationContext(pydantic.BaseModel):
|
||||
|
||||
@property
|
||||
def spice(self) -> str:
|
||||
|
||||
spice_level = self.generation_options.spice_level
|
||||
|
||||
if self.template and not getattr(self.template, "supports_spice", False):
|
||||
@@ -148,7 +150,6 @@ class ContentGenerationContext(pydantic.BaseModel):
|
||||
|
||||
@property
|
||||
def style(self):
|
||||
|
||||
if self.template and not getattr(self.template, "supports_style", False):
|
||||
# template supplied that doesn't support style
|
||||
return ""
|
||||
@@ -165,11 +166,12 @@ class ContentGenerationContext(pydantic.BaseModel):
|
||||
def get_state(self, key: str) -> str | int | float | bool | None:
|
||||
return self.state.get(key)
|
||||
|
||||
|
||||
class AssistantMixin:
|
||||
"""
|
||||
Creator mixin that allows quick contextual generation of content.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["autocomplete"] = AgentAction(
|
||||
@@ -198,15 +200,15 @@ class AssistantMixin:
|
||||
max=256,
|
||||
step=16,
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# property helpers
|
||||
|
||||
|
||||
@property
|
||||
def autocomplete_dialogue_suggestion_length(self):
|
||||
return self.actions["autocomplete"].config["dialogue_suggestion_length"].value
|
||||
|
||||
|
||||
@property
|
||||
def autocomplete_narrative_suggestion_length(self):
|
||||
return self.actions["autocomplete"].config["narrative_suggestion_length"].value
|
||||
@@ -255,7 +257,7 @@ class AssistantMixin:
|
||||
history_aware=history_aware,
|
||||
information=information,
|
||||
)
|
||||
|
||||
|
||||
for key, value in kwargs.items():
|
||||
generation_context.set_state(key, value)
|
||||
|
||||
@@ -279,9 +281,13 @@ class AssistantMixin:
|
||||
f"Contextual generate: {context_typ} - {context_name}",
|
||||
generation_context=generation_context,
|
||||
)
|
||||
|
||||
character = self.scene.get_character(generation_context.character) if generation_context.character else None
|
||||
|
||||
|
||||
character = (
|
||||
self.scene.get_character(generation_context.character)
|
||||
if generation_context.character
|
||||
else None
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
@@ -295,27 +301,29 @@ class AssistantMixin:
|
||||
"character": character,
|
||||
"template": generation_context.template,
|
||||
}
|
||||
|
||||
|
||||
emission = ContextualGenerateEmission(
|
||||
agent=self,
|
||||
content_generation_context=generation_context,
|
||||
character=character,
|
||||
template_vars=template_vars,
|
||||
)
|
||||
|
||||
await async_signals.get("agent.creator.contextual_generate.before").send(emission)
|
||||
|
||||
await async_signals.get("agent.creator.contextual_generate.before").send(
|
||||
emission
|
||||
)
|
||||
|
||||
template_vars["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
content = await Prompt.request(
|
||||
f"creator.contextual-generate",
|
||||
"creator.contextual-generate",
|
||||
self.client,
|
||||
kind,
|
||||
vars=template_vars,
|
||||
)
|
||||
|
||||
|
||||
emission.response = content
|
||||
|
||||
|
||||
if not generation_context.partial:
|
||||
content = util.strip_partial_sentences(content)
|
||||
|
||||
@@ -329,22 +337,29 @@ class AssistantMixin:
|
||||
if not content.startswith(generation_context.character + ":"):
|
||||
content = generation_context.character + ": " + content
|
||||
content = util.strip_partial_sentences(content)
|
||||
|
||||
character = self.scene.get_character(generation_context.character)
|
||||
|
||||
if not character:
|
||||
log.warning("Character not found", character=generation_context.character)
|
||||
return content
|
||||
|
||||
emission.response = await editor.cleanup_character_message(content, character)
|
||||
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
|
||||
return emission.response
|
||||
|
||||
emission.response = content.strip().strip("*").strip()
|
||||
|
||||
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
|
||||
return emission.response
|
||||
|
||||
character = self.scene.get_character(generation_context.character)
|
||||
|
||||
if not character:
|
||||
log.warning(
|
||||
"Character not found", character=generation_context.character
|
||||
)
|
||||
return content
|
||||
|
||||
emission.response = await editor.cleanup_character_message(
|
||||
content, character
|
||||
)
|
||||
await async_signals.get("agent.creator.contextual_generate.after").send(
|
||||
emission
|
||||
)
|
||||
return emission.response
|
||||
|
||||
emission.response = content.strip().strip("*").strip()
|
||||
|
||||
await async_signals.get("agent.creator.contextual_generate.after").send(
|
||||
emission
|
||||
)
|
||||
return emission.response
|
||||
|
||||
@set_processing
|
||||
async def generate_character_attribute(
|
||||
@@ -357,10 +372,10 @@ class AssistantMixin:
|
||||
) -> str:
|
||||
"""
|
||||
Wrapper for contextual_generate that generates a character attribute.
|
||||
"""
|
||||
"""
|
||||
if not generation_options:
|
||||
generation_options = GenerationOptions()
|
||||
|
||||
|
||||
return await self.contextual_generate_from_args(
|
||||
context=f"character attribute:{attribute_name}",
|
||||
character=character.name,
|
||||
@@ -368,7 +383,7 @@ class AssistantMixin:
|
||||
original=original,
|
||||
**generation_options.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@set_processing
|
||||
async def generate_character_detail(
|
||||
self,
|
||||
@@ -381,11 +396,11 @@ class AssistantMixin:
|
||||
) -> str:
|
||||
"""
|
||||
Wrapper for contextual_generate that generates a character detail.
|
||||
"""
|
||||
"""
|
||||
|
||||
if not generation_options:
|
||||
generation_options = GenerationOptions()
|
||||
|
||||
|
||||
return await self.contextual_generate_from_args(
|
||||
context=f"character detail:{detail_name}",
|
||||
character=character.name,
|
||||
@@ -394,7 +409,7 @@ class AssistantMixin:
|
||||
length=length,
|
||||
**generation_options.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@set_processing
|
||||
async def generate_thematic_list(
|
||||
self,
|
||||
@@ -408,11 +423,11 @@ class AssistantMixin:
|
||||
"""
|
||||
if not generation_options:
|
||||
generation_options = GenerationOptions()
|
||||
|
||||
|
||||
i = 0
|
||||
|
||||
|
||||
result = []
|
||||
|
||||
|
||||
while i < iterations:
|
||||
i += 1
|
||||
_result = await self.contextual_generate_from_args(
|
||||
@@ -420,14 +435,14 @@ class AssistantMixin:
|
||||
instructions=instructions,
|
||||
length=length,
|
||||
original="\n".join(result) if result else None,
|
||||
extend=i>1,
|
||||
extend=i > 1,
|
||||
**generation_options.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
_result = json.loads(_result)
|
||||
|
||||
|
||||
result = list(set(result + _result))
|
||||
|
||||
|
||||
return result
|
||||
|
||||
@set_processing
|
||||
@@ -443,34 +458,34 @@ class AssistantMixin:
|
||||
"""
|
||||
if not response_length:
|
||||
response_length = self.autocomplete_dialogue_suggestion_length
|
||||
|
||||
|
||||
# continuing recent character message
|
||||
non_anchor, anchor = util.split_anchor_text(input, 10)
|
||||
|
||||
|
||||
self.scene.log.debug(
|
||||
"autocomplete_anchor",
|
||||
anchor=anchor,
|
||||
non_anchor=non_anchor,
|
||||
input=input
|
||||
"autocomplete_anchor", anchor=anchor, non_anchor=non_anchor, input=input
|
||||
)
|
||||
|
||||
continuing_message = False
|
||||
message = None
|
||||
|
||||
|
||||
try:
|
||||
message = self.scene.history[-1]
|
||||
if isinstance(message, CharacterMessage) and message.character_name == character.name:
|
||||
if (
|
||||
isinstance(message, CharacterMessage)
|
||||
and message.character_name == character.name
|
||||
):
|
||||
continuing_message = input.strip() == message.without_name.strip()
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
if input.strip().endswith('"'):
|
||||
prefix = ' *'
|
||||
elif input.strip().endswith('*'):
|
||||
prefix = " *"
|
||||
elif input.strip().endswith("*"):
|
||||
prefix = ' "'
|
||||
else:
|
||||
prefix = ''
|
||||
|
||||
prefix = ""
|
||||
|
||||
template_vars = {
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
@@ -484,7 +499,7 @@ class AssistantMixin:
|
||||
"non_anchor": non_anchor,
|
||||
"prefix": prefix,
|
||||
}
|
||||
|
||||
|
||||
emission = AutocompleteEmission(
|
||||
agent=self,
|
||||
input=input,
|
||||
@@ -492,13 +507,13 @@ class AssistantMixin:
|
||||
character=character,
|
||||
template_vars=template_vars,
|
||||
)
|
||||
|
||||
|
||||
await async_signals.get("agent.creator.autocomplete.before").send(emission)
|
||||
|
||||
template_vars["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
response = await Prompt.request(
|
||||
f"creator.autocomplete-dialogue",
|
||||
"creator.autocomplete-dialogue",
|
||||
self.client,
|
||||
f"create_{response_length}",
|
||||
vars=template_vars,
|
||||
@@ -506,21 +521,22 @@ class AssistantMixin:
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
response = response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
|
||||
|
||||
|
||||
response = (
|
||||
response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
|
||||
)
|
||||
|
||||
if prefix:
|
||||
response = prefix + response
|
||||
|
||||
|
||||
emission.response = response
|
||||
|
||||
|
||||
await async_signals.get("agent.creator.autocomplete.after").send(emission)
|
||||
|
||||
if not response:
|
||||
if emit_signal:
|
||||
emit("autocomplete_suggestion", "")
|
||||
return ""
|
||||
|
||||
|
||||
response = util.strip_partial_sentences(response).replace("*", "")
|
||||
|
||||
if response.startswith(input):
|
||||
@@ -550,14 +566,14 @@ class AssistantMixin:
|
||||
|
||||
# Split the input text into non-anchor and anchor parts
|
||||
non_anchor, anchor = util.split_anchor_text(input, 10)
|
||||
|
||||
|
||||
self.scene.log.debug(
|
||||
"autocomplete_narrative_anchor",
|
||||
"autocomplete_narrative_anchor",
|
||||
anchor=anchor,
|
||||
non_anchor=non_anchor,
|
||||
input=input
|
||||
input=input,
|
||||
)
|
||||
|
||||
|
||||
template_vars = {
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
@@ -567,20 +583,20 @@ class AssistantMixin:
|
||||
"anchor": anchor,
|
||||
"non_anchor": non_anchor,
|
||||
}
|
||||
|
||||
|
||||
emission = AutocompleteEmission(
|
||||
agent=self,
|
||||
input=input,
|
||||
type="narrative",
|
||||
template_vars=template_vars,
|
||||
)
|
||||
|
||||
|
||||
await async_signals.get("agent.creator.autocomplete.before").send(emission)
|
||||
|
||||
template_vars["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
response = await Prompt.request(
|
||||
f"creator.autocomplete-narrative",
|
||||
"creator.autocomplete-narrative",
|
||||
self.client,
|
||||
f"create_{response_length}",
|
||||
vars=template_vars,
|
||||
@@ -593,7 +609,7 @@ class AssistantMixin:
|
||||
response = response[len(input) :]
|
||||
|
||||
emission.response = response
|
||||
|
||||
|
||||
await async_signals.get("agent.creator.autocomplete.after").send(emission)
|
||||
|
||||
self.scene.log.debug(
|
||||
@@ -614,78 +630,75 @@ class AssistantMixin:
|
||||
"""
|
||||
Allows to fork a new scene from a specific message
|
||||
in the current scene.
|
||||
|
||||
|
||||
All content after the message will be removed and the
|
||||
context database will be re imported ensuring a clean state.
|
||||
|
||||
|
||||
All state reinforcements will be reset to their most recent
|
||||
state before the message.
|
||||
"""
|
||||
|
||||
|
||||
emit("status", "Creating scene fork ...", status="busy")
|
||||
try:
|
||||
if not save_name:
|
||||
# build a save name
|
||||
uuid_str = str(uuid.uuid4())[:8]
|
||||
save_name = f"{uuid_str}-forked"
|
||||
|
||||
log.info(f"Forking scene", message_id=message_id, save_name=save_name)
|
||||
|
||||
|
||||
log.info("Forking scene", message_id=message_id, save_name=save_name)
|
||||
|
||||
world_state = get_agent("world_state")
|
||||
|
||||
|
||||
# does a message with the given id exist?
|
||||
index = self.scene.message_index(message_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Message with id {message_id} not found.")
|
||||
|
||||
|
||||
# truncate scene.history keeping index as the last element
|
||||
self.scene.history = self.scene.history[:index + 1]
|
||||
|
||||
self.scene.history = self.scene.history[: index + 1]
|
||||
|
||||
# truncate scene.archived_history keeping the element where `end` is < `index`
|
||||
# as the last element
|
||||
self.scene.archived_history = [
|
||||
x for x in self.scene.archived_history if "end" not in x or x["end"] < index
|
||||
x
|
||||
for x in self.scene.archived_history
|
||||
if "end" not in x or x["end"] < index
|
||||
]
|
||||
|
||||
|
||||
# the same needs to be done for layered history
|
||||
# where each layer is truncated based on what's left in the previous layer
|
||||
# using similar logic as above (checking `end` vs `index`)
|
||||
# layer 0 checks archived_history
|
||||
|
||||
|
||||
new_layered_history = []
|
||||
for layer_number, layer in enumerate(self.scene.layered_history):
|
||||
|
||||
if layer_number == 0:
|
||||
index = len(self.scene.archived_history) - 1
|
||||
else:
|
||||
index = len(new_layered_history[layer_number - 1]) - 1
|
||||
|
||||
new_layer = [
|
||||
x for x in layer if x["end"] < index
|
||||
]
|
||||
|
||||
new_layer = [x for x in layer if x["end"] < index]
|
||||
new_layered_history.append(new_layer)
|
||||
|
||||
|
||||
self.scene.layered_history = new_layered_history
|
||||
|
||||
# save the scene
|
||||
await self.scene.save(copy_name=save_name)
|
||||
|
||||
log.info(f"Scene forked", save_name=save_name)
|
||||
|
||||
|
||||
log.info("Scene forked", save_name=save_name)
|
||||
|
||||
# re-emit history
|
||||
await self.scene.emit_history()
|
||||
|
||||
emit("status", f"Updating world state ...", status="busy")
|
||||
|
||||
emit("status", "Updating world state ...", status="busy")
|
||||
|
||||
# reset state reinforcements
|
||||
await world_state.update_reinforcements(force = True, reset= True)
|
||||
|
||||
await world_state.update_reinforcements(force=True, reset=True)
|
||||
|
||||
# update world state
|
||||
await self.scene.world_state.request_update()
|
||||
|
||||
emit("status", f"Scene forked", status="success")
|
||||
except Exception as e:
|
||||
|
||||
emit("status", "Scene forked", status="success")
|
||||
except Exception:
|
||||
log.error("Scene fork failed", exc=traceback.format_exc())
|
||||
emit("status", "Scene fork failed", status="error")
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ import structlog
|
||||
from talemate.agents.base import set_processing
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
import talemate.game.focal as focal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
|
||||
@@ -18,14 +16,13 @@ DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audien
|
||||
|
||||
|
||||
class CharacterCreatorMixin:
|
||||
|
||||
@set_processing
|
||||
async def determine_content_context_for_character(
|
||||
self,
|
||||
character: Character,
|
||||
):
|
||||
content_context = await Prompt.request(
|
||||
f"creator.determine-content-context",
|
||||
"creator.determine-content-context",
|
||||
self.client,
|
||||
"create_192",
|
||||
vars={
|
||||
@@ -42,7 +39,7 @@ class CharacterCreatorMixin:
|
||||
information: str = "",
|
||||
):
|
||||
instructions = await Prompt.request(
|
||||
f"creator.determine-character-dialogue-instructions",
|
||||
"creator.determine-character-dialogue-instructions",
|
||||
self.client,
|
||||
"create_concise",
|
||||
vars={
|
||||
@@ -63,7 +60,7 @@ class CharacterCreatorMixin:
|
||||
character: Character,
|
||||
):
|
||||
attributes = await Prompt.request(
|
||||
f"creator.determine-character-attributes",
|
||||
"creator.determine-character-attributes",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars={
|
||||
@@ -81,7 +78,7 @@ class CharacterCreatorMixin:
|
||||
instructions: str = "",
|
||||
) -> str:
|
||||
name = await Prompt.request(
|
||||
f"creator.determine-character-name",
|
||||
"creator.determine-character-name",
|
||||
self.client,
|
||||
"analyze_freeform_short",
|
||||
vars={
|
||||
@@ -97,14 +94,14 @@ class CharacterCreatorMixin:
|
||||
|
||||
@set_processing
|
||||
async def determine_character_description(
|
||||
self,
|
||||
self,
|
||||
character: Character,
|
||||
text: str = "",
|
||||
instructions: str = "",
|
||||
information: str = "",
|
||||
):
|
||||
description = await Prompt.request(
|
||||
f"creator.determine-character-description",
|
||||
"creator.determine-character-description",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
@@ -125,7 +122,7 @@ class CharacterCreatorMixin:
|
||||
goal_instructions: str,
|
||||
):
|
||||
goals = await Prompt.request(
|
||||
f"creator.determine-character-goals",
|
||||
"creator.determine-character-goals",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
@@ -141,4 +138,4 @@ class CharacterCreatorMixin:
|
||||
log.debug("determine_character_goals", goals=goals, character=character)
|
||||
await character.set_detail("goals", goals.strip())
|
||||
|
||||
return goals.strip()
|
||||
return goals.strip()
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
"num": 3
|
||||
},
|
||||
"x": 39,
|
||||
"y": 507,
|
||||
"y": 236,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -40,7 +40,7 @@
|
||||
"num": 1
|
||||
},
|
||||
"x": 38,
|
||||
"y": 107,
|
||||
"y": -164,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -59,7 +59,7 @@
|
||||
"num": 0
|
||||
},
|
||||
"x": 38,
|
||||
"y": -93,
|
||||
"y": -364,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -76,7 +76,7 @@
|
||||
"num": 0
|
||||
},
|
||||
"x": 288,
|
||||
"y": -63,
|
||||
"y": -334,
|
||||
"width": 210,
|
||||
"height": 106,
|
||||
"collapsed": true,
|
||||
@@ -89,7 +89,7 @@
|
||||
"id": "553125be-2c2b-4404-98b5-d6333a4f9655",
|
||||
"properties": {},
|
||||
"x": 348,
|
||||
"y": 507,
|
||||
"y": 236,
|
||||
"width": 140,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -102,7 +102,7 @@
|
||||
"id": "207e357e-5e83-4d40-a331-d0041b9dfa49",
|
||||
"properties": {},
|
||||
"x": 348,
|
||||
"y": 307,
|
||||
"y": 36,
|
||||
"width": 140,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -115,7 +115,7 @@
|
||||
"id": "bad7d2ed-f6fa-452b-8b47-d753ae7a45e0",
|
||||
"properties": {},
|
||||
"x": 348,
|
||||
"y": 107,
|
||||
"y": -164,
|
||||
"width": 140,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -131,7 +131,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 598,
|
||||
"y": 107,
|
||||
"y": -164,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -147,7 +147,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 598,
|
||||
"y": 307,
|
||||
"y": 36,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -163,7 +163,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 598,
|
||||
"y": 507,
|
||||
"y": 236,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -176,7 +176,7 @@
|
||||
"id": "1765a51c-82ac-4d96-9c60-d0adc7faaa68",
|
||||
"properties": {},
|
||||
"x": 498,
|
||||
"y": 707,
|
||||
"y": 436,
|
||||
"width": 171,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -192,7 +192,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 798,
|
||||
"y": 706,
|
||||
"y": 435,
|
||||
"width": 244,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -205,7 +205,7 @@
|
||||
"id": "f3dc37df-1748-4235-b575-60e55b8bec73",
|
||||
"properties": {},
|
||||
"x": 498,
|
||||
"y": 906,
|
||||
"y": 635,
|
||||
"width": 171,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -220,7 +220,7 @@
|
||||
"stage": 0
|
||||
},
|
||||
"x": 1208,
|
||||
"y": 366,
|
||||
"y": 95,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
@@ -238,7 +238,7 @@
|
||||
"icon": "F1719"
|
||||
},
|
||||
"x": 1178,
|
||||
"y": -144,
|
||||
"y": -415,
|
||||
"width": 210,
|
||||
"height": 130,
|
||||
"collapsed": false,
|
||||
@@ -253,7 +253,7 @@
|
||||
"default": true
|
||||
},
|
||||
"x": 298,
|
||||
"y": 736,
|
||||
"y": 465,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
@@ -268,7 +268,7 @@
|
||||
"default": false
|
||||
},
|
||||
"x": 298,
|
||||
"y": 936,
|
||||
"y": 665,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
@@ -287,7 +287,7 @@
|
||||
"num": 2
|
||||
},
|
||||
"x": 38,
|
||||
"y": 306,
|
||||
"y": 35,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -303,7 +303,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 798,
|
||||
"y": 906,
|
||||
"y": 635,
|
||||
"width": 244,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -322,7 +322,7 @@
|
||||
"num": 5
|
||||
},
|
||||
"x": 38,
|
||||
"y": 706,
|
||||
"y": 435,
|
||||
"width": 237,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -341,7 +341,7 @@
|
||||
"num": 7
|
||||
},
|
||||
"x": 38,
|
||||
"y": 906,
|
||||
"y": 635,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -360,7 +360,7 @@
|
||||
"num": 4
|
||||
},
|
||||
"x": 48,
|
||||
"y": 1326,
|
||||
"y": 1055,
|
||||
"width": 237,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -375,7 +375,7 @@
|
||||
"default": true
|
||||
},
|
||||
"x": 318,
|
||||
"y": 1356,
|
||||
"y": 1085,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
@@ -388,7 +388,7 @@
|
||||
"id": "6d387c67-6b32-4435-b984-4760f0f1f8d2",
|
||||
"properties": {},
|
||||
"x": 498,
|
||||
"y": 1336,
|
||||
"y": 1065,
|
||||
"width": 171,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -404,7 +404,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 798,
|
||||
"y": 1316,
|
||||
"y": 1045,
|
||||
"width": 244,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -455,7 +455,7 @@
|
||||
"num": 6
|
||||
},
|
||||
"x": 38,
|
||||
"y": 1106,
|
||||
"y": 835,
|
||||
"width": 252,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -471,7 +471,7 @@
|
||||
"apply_on_unresolved": true
|
||||
},
|
||||
"x": 518,
|
||||
"y": 1136,
|
||||
"y": 865,
|
||||
"width": 210,
|
||||
"height": 102,
|
||||
"collapsed": true,
|
||||
@@ -1101,7 +1101,7 @@
|
||||
"num": 8
|
||||
},
|
||||
"x": 49,
|
||||
"y": 1546,
|
||||
"y": 1275,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
@@ -1116,7 +1116,7 @@
|
||||
"default": false
|
||||
},
|
||||
"x": 329,
|
||||
"y": 1586,
|
||||
"y": 1315,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
@@ -1129,7 +1129,7 @@
|
||||
"id": "8acfe789-fbb5-4e29-8fd8-2217b987c086",
|
||||
"properties": {},
|
||||
"x": 499,
|
||||
"y": 1566,
|
||||
"y": 1295,
|
||||
"width": 171,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
@@ -1145,7 +1145,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 799,
|
||||
"y": 1526,
|
||||
"y": 1255,
|
||||
"width": 244,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -1258,8 +1258,8 @@
|
||||
"output_name": "character",
|
||||
"num": 0
|
||||
},
|
||||
"x": 331,
|
||||
"y": 5739,
|
||||
"x": 332,
|
||||
"y": 6178,
|
||||
"width": 210,
|
||||
"height": 106,
|
||||
"collapsed": false,
|
||||
@@ -1274,8 +1274,8 @@
|
||||
"name": "character",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 51,
|
||||
"y": 5729,
|
||||
"x": 52,
|
||||
"y": 6168,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -1291,8 +1291,8 @@
|
||||
"output_name": "actor",
|
||||
"num": 0
|
||||
},
|
||||
"x": 331,
|
||||
"y": 5949,
|
||||
"x": 332,
|
||||
"y": 6388,
|
||||
"width": 210,
|
||||
"height": 106,
|
||||
"collapsed": false,
|
||||
@@ -1307,8 +1307,8 @@
|
||||
"name": "actor",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 51,
|
||||
"y": 5949,
|
||||
"x": 52,
|
||||
"y": 6388,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -1323,7 +1323,7 @@
|
||||
"stage": 0
|
||||
},
|
||||
"x": 1320,
|
||||
"y": 1160,
|
||||
"y": 889,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
@@ -1414,7 +1414,7 @@
|
||||
"writing_style": null
|
||||
},
|
||||
"x": 320,
|
||||
"y": 1200,
|
||||
"y": 929,
|
||||
"width": 270,
|
||||
"height": 122,
|
||||
"collapsed": true,
|
||||
@@ -1430,7 +1430,7 @@
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 790,
|
||||
"y": 1110,
|
||||
"y": 839,
|
||||
"width": 244,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
@@ -1475,6 +1475,159 @@
|
||||
"inherited": false,
|
||||
"registry": "agents/creator/ContextualGenerate",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1": {
|
||||
"title": "assign_voice",
|
||||
"id": "d61de1ad-6f2a-447f-918a-dce7e76ea3a1",
|
||||
"properties": {},
|
||||
"x": 509,
|
||||
"y": 1544,
|
||||
"width": 171,
|
||||
"height": 26,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "core/Watch",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"3d655827-b66b-4355-910d-96097e7f2f13": {
|
||||
"title": "SET local.assign_voice",
|
||||
"id": "3d655827-b66b-4355-910d-96097e7f2f13",
|
||||
"properties": {
|
||||
"name": "assign_voice",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 810,
|
||||
"y": 1506,
|
||||
"width": 244,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "state/SetState",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1": {
|
||||
"title": "Stage 0",
|
||||
"id": "6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1",
|
||||
"properties": {
|
||||
"stage": 0
|
||||
},
|
||||
"x": 1170,
|
||||
"y": 1546,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
"inherited": false,
|
||||
"registry": "core/Stage",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f": {
|
||||
"title": "IN assign_voice",
|
||||
"id": "9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f",
|
||||
"properties": {
|
||||
"input_type": "bool",
|
||||
"input_name": "assign_voice",
|
||||
"input_optional": true,
|
||||
"input_group": "",
|
||||
"num": 9
|
||||
},
|
||||
"x": 60,
|
||||
"y": 1527,
|
||||
"width": 210,
|
||||
"height": 154,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "core/Input",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"de11206d-13db-44a5-befd-de559fb68d09": {
|
||||
"title": "GET local.assign_voice",
|
||||
"id": "de11206d-13db-44a5-befd-de559fb68d09",
|
||||
"properties": {
|
||||
"name": "assign_voice",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 25,
|
||||
"y": 5720,
|
||||
"width": 240,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "state/GetState",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"97b196a3-e7a6-4cfa-905e-686a744890b7": {
|
||||
"title": "Switch",
|
||||
"id": "97b196a3-e7a6-4cfa-905e-686a744890b7",
|
||||
"properties": {
|
||||
"pass_through": true
|
||||
},
|
||||
"x": 355,
|
||||
"y": 5740,
|
||||
"width": 210,
|
||||
"height": 78,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "core/Switch",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"3956711a-a3df-4213-b739-104cbe704964": {
|
||||
"title": "GET local.character",
|
||||
"id": "3956711a-a3df-4213-b739-104cbe704964",
|
||||
"properties": {
|
||||
"name": "character",
|
||||
"scope": "local"
|
||||
},
|
||||
"x": 25,
|
||||
"y": 5930,
|
||||
"width": 210,
|
||||
"height": 122,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "state/GetState",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"47b2b492-4178-4254-b468-3877a5341f66": {
|
||||
"title": "Assign Voice",
|
||||
"id": "47b2b492-4178-4254-b468-3877a5341f66",
|
||||
"properties": {},
|
||||
"x": 665,
|
||||
"y": 5830,
|
||||
"width": 161,
|
||||
"height": 66,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "agents/director/AssignVoice",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"44d78795-2d41-468b-94d5-399b9b655888": {
|
||||
"title": "Stage 6",
|
||||
"id": "44d78795-2d41-468b-94d5-399b9b655888",
|
||||
"properties": {
|
||||
"stage": 6
|
||||
},
|
||||
"x": 885,
|
||||
"y": 5860,
|
||||
"width": 210,
|
||||
"height": 118,
|
||||
"collapsed": true,
|
||||
"inherited": false,
|
||||
"registry": "core/Stage",
|
||||
"base_type": "core/Node"
|
||||
},
|
||||
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85": {
|
||||
"title": "As Bool",
|
||||
"id": "f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85",
|
||||
"properties": {
|
||||
"default": true
|
||||
},
|
||||
"x": 330,
|
||||
"y": 1560,
|
||||
"width": 210,
|
||||
"height": 58,
|
||||
"collapsed": true,
|
||||
"inherited": false,
|
||||
"registry": "core/AsBool",
|
||||
"base_type": "core/Node"
|
||||
}
|
||||
},
|
||||
"edges": {
|
||||
@@ -1726,15 +1879,39 @@
|
||||
],
|
||||
"4acb67ea-68ee-43ae-a8c6-98a2b0e0f053.text": [
|
||||
"d41f0d98-14d5-49dd-8e57-7812fb9fee94.value"
|
||||
],
|
||||
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value": [
|
||||
"3d655827-b66b-4355-910d-96097e7f2f13.value"
|
||||
],
|
||||
"3d655827-b66b-4355-910d-96097e7f2f13.value": [
|
||||
"6aa5c32a-8dfb-48a9-96ec-5ad9ed6aa5d1.state"
|
||||
],
|
||||
"9ac9b12d-1b97-4f42-92d8-0d4f883ffb2f.value": [
|
||||
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value"
|
||||
],
|
||||
"de11206d-13db-44a5-befd-de559fb68d09.value": [
|
||||
"97b196a3-e7a6-4cfa-905e-686a744890b7.value"
|
||||
],
|
||||
"97b196a3-e7a6-4cfa-905e-686a744890b7.yes": [
|
||||
"47b2b492-4178-4254-b468-3877a5341f66.state"
|
||||
],
|
||||
"3956711a-a3df-4213-b739-104cbe704964.value": [
|
||||
"47b2b492-4178-4254-b468-3877a5341f66.character"
|
||||
],
|
||||
"47b2b492-4178-4254-b468-3877a5341f66.state": [
|
||||
"44d78795-2d41-468b-94d5-399b9b655888.state"
|
||||
],
|
||||
"f5e5ec03-cc12-4a45-aa15-6ca3e5e4bc85.value": [
|
||||
"d61de1ad-6f2a-447f-918a-dce7e76ea3a1.value"
|
||||
]
|
||||
},
|
||||
"groups": [
|
||||
{
|
||||
"title": "Process Arguments - Stage 0",
|
||||
"x": 1,
|
||||
"y": -218,
|
||||
"width": 1446,
|
||||
"height": 1948,
|
||||
"y": -490,
|
||||
"width": 1432,
|
||||
"height": 2216,
|
||||
"color": "#3f789e",
|
||||
"font_size": 24,
|
||||
"inherited": false
|
||||
@@ -1792,12 +1969,22 @@
|
||||
{
|
||||
"title": "Outputs",
|
||||
"x": 0,
|
||||
"y": 5642,
|
||||
"y": 6080,
|
||||
"width": 595,
|
||||
"height": 472,
|
||||
"color": "#8A8",
|
||||
"font_size": 24,
|
||||
"inherited": false
|
||||
},
|
||||
{
|
||||
"title": "Assign Voice - Stage 6",
|
||||
"x": 0,
|
||||
"y": 5640,
|
||||
"width": 1120,
|
||||
"height": 437,
|
||||
"color": "#3f789e",
|
||||
"font_size": 24,
|
||||
"inherited": false
|
||||
}
|
||||
],
|
||||
"comments": [],
|
||||
@@ -1805,6 +1992,7 @@
|
||||
"base_type": "core/Graph",
|
||||
"inputs": [],
|
||||
"outputs": [],
|
||||
"module_properties": {},
|
||||
"style": {
|
||||
"title_color": "#572e44",
|
||||
"node_color": "#392c34",
|
||||
|
||||
@@ -1,145 +1,153 @@
|
||||
import structlog
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from typing import ClassVar
|
||||
from talemate.context import active_scene
|
||||
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
|
||||
from talemate.game.engine.nodes.core import (
|
||||
GraphState,
|
||||
PropertyField,
|
||||
UNRESOLVED,
|
||||
)
|
||||
from talemate.game.engine.nodes.registry import register
|
||||
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene
|
||||
|
||||
log = structlog.get_logger("talemate.game.engine.nodes.agents.creator")
|
||||
|
||||
|
||||
@register("agents/creator/Settings")
|
||||
class CreatorSettings(AgentSettingsNode):
|
||||
"""
|
||||
Base node to render creator agent settings.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
def __init__(self, title="Creator Settings", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
@register("agents/creator/DetermineContentContext")
|
||||
class DetermineContentContext(AgentNode):
|
||||
"""
|
||||
Determines the context for the content creation.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
class Fields:
|
||||
description = PropertyField(
|
||||
name="description",
|
||||
description="Description of the context",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Determine Content Context", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("description", socket_type="str", optional=True)
|
||||
|
||||
|
||||
self.set_property("description", UNRESOLVED)
|
||||
|
||||
|
||||
self.add_output("content_context", socket_type="str")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
context = await self.agent.determine_content_context_for_description(
|
||||
self.require_input("description")
|
||||
)
|
||||
self.set_output_values({
|
||||
"content_context": context
|
||||
})
|
||||
|
||||
self.set_output_values({"content_context": context})
|
||||
|
||||
|
||||
@register("agents/creator/DetermineCharacterDescription")
|
||||
class DetermineCharacterDescription(AgentNode):
|
||||
"""
|
||||
Determines the description for a character.
|
||||
|
||||
|
||||
Inputs:
|
||||
|
||||
|
||||
- state: The current state of the graph
|
||||
- character: The character to determine the description for
|
||||
- extra_context: Extra context to use in determining the
|
||||
|
||||
|
||||
Outputs:
|
||||
|
||||
|
||||
- description: The determined description
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
def __init__(self, title="Determine Character Description", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("character", socket_type="character")
|
||||
self.add_input("extra_context", socket_type="str", optional=True)
|
||||
|
||||
|
||||
self.add_output("description", socket_type="str")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
|
||||
character = self.require_input("character")
|
||||
extra_context = self.get_input_value("extra_context")
|
||||
|
||||
|
||||
if extra_context is UNRESOLVED:
|
||||
extra_context = ""
|
||||
|
||||
description = await self.agent.determine_character_description(character, extra_context)
|
||||
|
||||
self.set_output_values({
|
||||
"description": description
|
||||
})
|
||||
|
||||
|
||||
description = await self.agent.determine_character_description(
|
||||
character, extra_context
|
||||
)
|
||||
|
||||
self.set_output_values({"description": description})
|
||||
|
||||
|
||||
@register("agents/creator/DetermineCharacterDialogueInstructions")
|
||||
class DetermineCharacterDialogueInstructions(AgentNode):
|
||||
"""
|
||||
Determines the dialogue instructions for a character.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
class Fields:
|
||||
instructions = PropertyField(
|
||||
name="instructions",
|
||||
description="Any additional instructions to use in determining the dialogue instructions",
|
||||
type="text",
|
||||
default=""
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Determine Character Dialogue Instructions", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("character", socket_type="character")
|
||||
self.add_input("instructions", socket_type="str", optional=True)
|
||||
|
||||
|
||||
self.set_property("instructions", "")
|
||||
|
||||
|
||||
self.add_output("dialogue_instructions", socket_type="str")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
character = self.require_input("character")
|
||||
instructions = self.normalized_input_value("instructions")
|
||||
|
||||
dialogue_instructions = await self.agent.determine_character_dialogue_instructions(character, instructions)
|
||||
|
||||
self.set_output_values({
|
||||
"dialogue_instructions": dialogue_instructions
|
||||
})
|
||||
|
||||
dialogue_instructions = (
|
||||
await self.agent.determine_character_dialogue_instructions(
|
||||
character, instructions
|
||||
)
|
||||
)
|
||||
|
||||
self.set_output_values({"dialogue_instructions": dialogue_instructions})
|
||||
|
||||
|
||||
@register("agents/creator/ContextualGenerate")
|
||||
class ContextualGenerate(AgentNode):
|
||||
"""
|
||||
Generates text based on the given context.
|
||||
|
||||
|
||||
Inputs:
|
||||
|
||||
|
||||
- state: The current state of the graph
|
||||
- context_type: The type of context to use in generating the text
|
||||
- context_name: The name of the context to use in generating the text
|
||||
@@ -150,9 +158,9 @@ class ContextualGenerate(AgentNode):
|
||||
- partial: The partial text to use in generating the text
|
||||
- uid: The uid to use in generating the text
|
||||
- generation_options: The generation options to use in generating the text
|
||||
|
||||
|
||||
Properties:
|
||||
|
||||
|
||||
- context_type: The type of context to use in generating the text
|
||||
- context_name: The name of the context to use in generating the text
|
||||
- instructions: The instructions to use in generating the text
|
||||
@@ -161,82 +169,82 @@ class ContextualGenerate(AgentNode):
|
||||
- uid: The uid to use in generating the text
|
||||
- context_aware: Whether to use the context in generating the text
|
||||
- history_aware: Whether to use the history in generating the text
|
||||
|
||||
|
||||
Outputs:
|
||||
|
||||
|
||||
- state: The updated state of the graph
|
||||
- text: The generated text
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
class Fields:
|
||||
context_type = PropertyField(
|
||||
name="context_type",
|
||||
description="The type of context to use in generating the text",
|
||||
type="str",
|
||||
choices=[
|
||||
"character attribute",
|
||||
"character detail",
|
||||
"character dialogue",
|
||||
"scene intro",
|
||||
"scene intent",
|
||||
"scene phase intent",
|
||||
"character attribute",
|
||||
"character detail",
|
||||
"character dialogue",
|
||||
"scene intro",
|
||||
"scene intent",
|
||||
"scene phase intent",
|
||||
"scene type description",
|
||||
"scene type instructions",
|
||||
"general",
|
||||
"list",
|
||||
"scene",
|
||||
"scene type instructions",
|
||||
"general",
|
||||
"list",
|
||||
"scene",
|
||||
"world context",
|
||||
],
|
||||
default="general"
|
||||
default="general",
|
||||
)
|
||||
context_name = PropertyField(
|
||||
name="context_name",
|
||||
description="The name of the context to use in generating the text",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
instructions = PropertyField(
|
||||
name="instructions",
|
||||
description="The instructions to use in generating the text",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
length = PropertyField(
|
||||
name="length",
|
||||
description="The length of the text to generate",
|
||||
type="int",
|
||||
default=100
|
||||
default=100,
|
||||
)
|
||||
character = PropertyField(
|
||||
name="character",
|
||||
description="The character to generate the text for",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
uid = PropertyField(
|
||||
name="uid",
|
||||
description="The uid to use in generating the text",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
context_aware = PropertyField(
|
||||
name="context_aware",
|
||||
description="Whether to use the context in generating the text",
|
||||
type="bool",
|
||||
default=True
|
||||
default=True,
|
||||
)
|
||||
history_aware = PropertyField(
|
||||
name="history_aware",
|
||||
description="Whether to use the history in generating the text",
|
||||
type="bool",
|
||||
default=True
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __init__(self, title="Contextual Generate", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("context_type", socket_type="str", optional=True)
|
||||
@@ -247,8 +255,10 @@ class ContextualGenerate(AgentNode):
|
||||
self.add_input("original", socket_type="str", optional=True)
|
||||
self.add_input("partial", socket_type="str", optional=True)
|
||||
self.add_input("uid", socket_type="str", optional=True)
|
||||
self.add_input("generation_options", socket_type="generation_options", optional=True)
|
||||
|
||||
self.add_input(
|
||||
"generation_options", socket_type="generation_options", optional=True
|
||||
)
|
||||
|
||||
self.set_property("context_type", "general")
|
||||
self.set_property("context_name", UNRESOLVED)
|
||||
self.set_property("instructions", UNRESOLVED)
|
||||
@@ -257,10 +267,10 @@ class ContextualGenerate(AgentNode):
|
||||
self.set_property("uid", UNRESOLVED)
|
||||
self.set_property("context_aware", True)
|
||||
self.set_property("history_aware", True)
|
||||
|
||||
|
||||
self.add_output("state")
|
||||
self.add_output("text", socket_type="str")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
scene = active_scene.get()
|
||||
context_type = self.require_input("context_type")
|
||||
@@ -274,12 +284,12 @@ class ContextualGenerate(AgentNode):
|
||||
generation_options = self.normalized_input_value("generation_options")
|
||||
context_aware = self.normalized_input_value("context_aware")
|
||||
history_aware = self.normalized_input_value("history_aware")
|
||||
|
||||
|
||||
context = f"{context_type}:{context_name}" if context_name else context_type
|
||||
|
||||
|
||||
if isinstance(character, scene.Character):
|
||||
character = character.name
|
||||
|
||||
|
||||
text = await self.agent.contextual_generate_from_args(
|
||||
context=context,
|
||||
instructions=instructions,
|
||||
@@ -288,31 +298,32 @@ class ContextualGenerate(AgentNode):
|
||||
original=original,
|
||||
partial=partial or "",
|
||||
uid=uid,
|
||||
writing_style=generation_options.writing_style if generation_options else None,
|
||||
writing_style=generation_options.writing_style
|
||||
if generation_options
|
||||
else None,
|
||||
spices=generation_options.spices if generation_options else None,
|
||||
spice_level=generation_options.spice_level if generation_options else 0.0,
|
||||
context_aware=context_aware,
|
||||
history_aware=history_aware
|
||||
history_aware=history_aware,
|
||||
)
|
||||
|
||||
self.set_output_values({
|
||||
"state": state,
|
||||
"text": text
|
||||
})
|
||||
|
||||
|
||||
self.set_output_values({"state": state, "text": text})
|
||||
|
||||
|
||||
@register("agents/creator/GenerateThematicList")
|
||||
class GenerateThematicList(AgentNode):
|
||||
"""
|
||||
Generates a list of thematic items based on the instructions.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
class Fields:
|
||||
instructions = PropertyField(
|
||||
name="instructions",
|
||||
description="The instructions to use in generating the list",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
iterations = PropertyField(
|
||||
name="iterations",
|
||||
@@ -323,27 +334,24 @@ class GenerateThematicList(AgentNode):
|
||||
min=1,
|
||||
max=10,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Generate Thematic List", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("instructions", socket_type="str", optional=True)
|
||||
|
||||
self.set_property("instructions", UNRESOLVED)
|
||||
self.set_property("iterations", 1)
|
||||
|
||||
|
||||
self.add_output("state")
|
||||
self.add_output("list", socket_type="list")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
instructions = self.normalized_input_value("instructions")
|
||||
iterations = self.require_number_input("iterations")
|
||||
|
||||
|
||||
list = await self.agent.generate_thematic_list(instructions, iterations)
|
||||
|
||||
self.set_output_values({
|
||||
"state": state,
|
||||
"list": list
|
||||
})
|
||||
|
||||
self.set_output_values({"state": state, "list": list})
|
||||
|
||||
@@ -10,7 +10,7 @@ class ScenarioCreatorMixin:
|
||||
@set_processing
|
||||
async def determine_scenario_description(self, text: str):
|
||||
description = await Prompt.request(
|
||||
f"creator.determine-scenario-description",
|
||||
"creator.determine-scenario-description",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars={
|
||||
@@ -25,7 +25,7 @@ class ScenarioCreatorMixin:
|
||||
description: str,
|
||||
):
|
||||
content_context = await Prompt.request(
|
||||
f"creator.determine-content-context",
|
||||
"creator.determine-content-context",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
|
||||
@@ -1,36 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import structlog
|
||||
import traceback
|
||||
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
from talemate.agents.conversation import ConversationAgentEmission
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import DirectorMessage
|
||||
from talemate.util import random_color
|
||||
from talemate.character import deactivate_character
|
||||
from talemate.status import LoadingStatus
|
||||
from talemate.exceptions import GenerationCancelled
|
||||
from talemate.scene_message import DirectorMessage, Flags
|
||||
|
||||
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
|
||||
from talemate.agents.base import Agent, AgentAction, AgentActionConfig
|
||||
from talemate.agents.registry import register
|
||||
from talemate.agents.memory.rag import MemoryRAGMixin
|
||||
from talemate.client import ClientBase
|
||||
from talemate.game.focal.schema import Call
|
||||
|
||||
from .guide import GuideSceneMixin
|
||||
from .generate_choices import GenerateChoicesMixin
|
||||
from .legacy_scene_instructions import LegacySceneInstructionsMixin
|
||||
from .auto_direct import AutoDirectMixin
|
||||
from .websocket_handler import DirectorWebsocketHandler
|
||||
|
||||
import talemate.agents.director.nodes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate import Character, Scene
|
||||
from .character_management import CharacterManagementMixin
|
||||
import talemate.agents.director.nodes # noqa: F401
|
||||
|
||||
log = structlog.get_logger("talemate.agent.director")
|
||||
|
||||
@@ -42,7 +29,8 @@ class DirectorAgent(
|
||||
GenerateChoicesMixin,
|
||||
AutoDirectMixin,
|
||||
LegacySceneInstructionsMixin,
|
||||
Agent
|
||||
CharacterManagementMixin,
|
||||
Agent,
|
||||
):
|
||||
agent_type = "director"
|
||||
verbose_name = "Director"
|
||||
@@ -74,15 +62,16 @@ class DirectorAgent(
|
||||
],
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
}
|
||||
MemoryRAGMixin.add_actions(actions)
|
||||
GenerateChoicesMixin.add_actions(actions)
|
||||
GuideSceneMixin.add_actions(actions)
|
||||
AutoDirectMixin.add_actions(actions)
|
||||
CharacterManagementMixin.add_actions(actions)
|
||||
return actions
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
def __init__(self, client: ClientBase | None = None, **kwargs):
|
||||
self.is_enabled = True
|
||||
self.client = client
|
||||
self.next_direct_character = {}
|
||||
@@ -105,165 +94,24 @@ class DirectorAgent(
|
||||
def actor_direction_mode(self):
|
||||
return self.actions["direct"].config["actor_direction_mode"].value
|
||||
|
||||
@set_processing
|
||||
async def persist_characters_from_worldstate(
|
||||
self, exclude: list[str] = None
|
||||
) -> List[Character]:
|
||||
created_characters = []
|
||||
|
||||
for character_name in self.scene.world_state.characters.keys():
|
||||
|
||||
if exclude and character_name.lower() in exclude:
|
||||
continue
|
||||
|
||||
if character_name in self.scene.character_names:
|
||||
continue
|
||||
|
||||
character = await self.persist_character(name=character_name)
|
||||
|
||||
created_characters.append(character)
|
||||
|
||||
self.scene.emit_status()
|
||||
|
||||
return created_characters
|
||||
|
||||
@set_processing
|
||||
async def persist_character(
|
||||
self,
|
||||
name: str,
|
||||
content: str = None,
|
||||
attributes: str = None,
|
||||
determine_name: bool = True,
|
||||
templates: list[str] = None,
|
||||
active: bool = True,
|
||||
narrate_entry: bool = False,
|
||||
narrate_entry_direction: str = "",
|
||||
augment_attributes: str = "",
|
||||
generate_attributes: bool = True,
|
||||
description: str = "",
|
||||
) -> Character:
|
||||
world_state = instance.get_agent("world_state")
|
||||
creator = instance.get_agent("creator")
|
||||
narrator = instance.get_agent("narrator")
|
||||
memory = instance.get_agent("memory")
|
||||
scene: "Scene" = self.scene
|
||||
any_attribute_templates = False
|
||||
|
||||
loading_status = LoadingStatus(max_steps=None, cancellable=True)
|
||||
|
||||
# Start of character creation
|
||||
log.debug("persist_character", name=name)
|
||||
|
||||
|
||||
# Determine the character's name (or clarify if it's already set)
|
||||
if determine_name:
|
||||
loading_status("Determining character name")
|
||||
name = await creator.determine_character_name(name, instructions=content)
|
||||
log.debug("persist_character", adjusted_name=name)
|
||||
|
||||
# Create the blank character
|
||||
character:Character = self.scene.Character(name=name)
|
||||
|
||||
# Add the character to the scene
|
||||
character.color = random_color()
|
||||
actor = self.scene.Actor(
|
||||
character=character, agent=instance.get_agent("conversation")
|
||||
async def log_function_call(self, call: Call):
|
||||
log.debug("director.log_function_call", call=call)
|
||||
message = DirectorMessage(
|
||||
message=f"Called {call.name}",
|
||||
action=call.name,
|
||||
flags=Flags.HIDDEN,
|
||||
subtype="function_call",
|
||||
)
|
||||
await self.scene.add_actor(actor)
|
||||
|
||||
try:
|
||||
emit("director", message, data={"function_call": call.model_dump()})
|
||||
|
||||
# Apply any character generation templates
|
||||
if templates:
|
||||
loading_status("Applying character generation templates")
|
||||
templates = scene.world_state_manager.template_collection.collect_all(templates)
|
||||
log.debug("persist_character", applying_templates=templates)
|
||||
await scene.world_state_manager.apply_templates(
|
||||
templates.values(),
|
||||
character_name=character.name,
|
||||
information=content
|
||||
)
|
||||
|
||||
# if any of the templates are attribute templates, then we no longer need to
|
||||
# generate a character sheet
|
||||
any_attribute_templates = any(template.template_type == "character_attribute" for template in templates.values())
|
||||
log.debug("persist_character", any_attribute_templates=any_attribute_templates)
|
||||
|
||||
if any_attribute_templates and augment_attributes and generate_attributes:
|
||||
log.debug("persist_character", augmenting_attributes=augment_attributes)
|
||||
loading_status("Augmenting character attributes")
|
||||
additional_attributes = await world_state.extract_character_sheet(
|
||||
name=name,
|
||||
text=content,
|
||||
augmentation_instructions=augment_attributes
|
||||
)
|
||||
character.base_attributes.update(additional_attributes)
|
||||
|
||||
# Generate a character sheet if there are no attribute templates
|
||||
if not any_attribute_templates and generate_attributes:
|
||||
loading_status("Generating character sheet")
|
||||
log.debug("persist_character", extracting_character_sheet=True)
|
||||
if not attributes:
|
||||
attributes = await world_state.extract_character_sheet(
|
||||
name=name, text=content
|
||||
)
|
||||
else:
|
||||
attributes = world_state._parse_character_sheet(attributes)
|
||||
|
||||
log.debug("persist_character", attributes=attributes)
|
||||
character.base_attributes = attributes
|
||||
|
||||
# Generate a description for the character
|
||||
if not description:
|
||||
loading_status("Generating character description")
|
||||
description = await creator.determine_character_description(character, information=content)
|
||||
character.description = description
|
||||
log.debug("persist_character", description=description)
|
||||
|
||||
# Generate a dialogue instructions for the character
|
||||
loading_status("Generating acting instructions")
|
||||
dialogue_instructions = await creator.determine_character_dialogue_instructions(
|
||||
character,
|
||||
information=content
|
||||
)
|
||||
character.dialogue_instructions = dialogue_instructions
|
||||
log.debug(
|
||||
"persist_character", dialogue_instructions=dialogue_instructions
|
||||
)
|
||||
|
||||
# Narrate the character's entry if the option is selected
|
||||
if active and narrate_entry:
|
||||
loading_status("Narrating character entry")
|
||||
is_present = await world_state.is_character_present(name)
|
||||
if not is_present:
|
||||
await narrator.action_to_narration(
|
||||
"narrate_character_entry",
|
||||
emit_message=True,
|
||||
character=character,
|
||||
narrative_direction=narrate_entry_direction
|
||||
)
|
||||
|
||||
# Deactivate the character if not active
|
||||
if not active:
|
||||
await deactivate_character(scene, character)
|
||||
|
||||
# Commit the character's details to long term memory
|
||||
await character.commit_to_memory(memory)
|
||||
self.scene.emit_status()
|
||||
self.scene.world_state.emit()
|
||||
|
||||
loading_status.done(message=f"{character.name} added to scene", status="success")
|
||||
return character
|
||||
except GenerationCancelled:
|
||||
loading_status.done(message="Character creation cancelled", status="idle")
|
||||
await scene.remove_actor(actor)
|
||||
except Exception as e:
|
||||
loading_status.done(message="Character creation failed", status="error")
|
||||
await scene.remove_actor(actor)
|
||||
log.error("Error persisting character", error=traceback.format_exc())
|
||||
|
||||
async def log_action(self, action: str, action_description: str):
|
||||
message = DirectorMessage(message=action_description, action=action)
|
||||
async def log_action(
|
||||
self, action: str, action_description: str, console_only: bool = False
|
||||
):
|
||||
message = DirectorMessage(
|
||||
message=action_description,
|
||||
action=action,
|
||||
flags=Flags.HIDDEN if console_only else Flags.NONE,
|
||||
)
|
||||
self.scene.push_history(message)
|
||||
emit("director", message)
|
||||
|
||||
@@ -276,4 +124,4 @@ class DirectorAgent(
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import structlog
|
||||
import pydantic
|
||||
from talemate.agents.base import (
|
||||
set_processing,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentEmission,
|
||||
AgentTemplateEmission,
|
||||
)
|
||||
from talemate.status import set_loading
|
||||
import talemate.game.focal as focal
|
||||
from talemate.prompts import Prompt
|
||||
import talemate.emit.async_signals as async_signals
|
||||
from talemate.scene_message import CharacterMessage, TimePassageMessage, DirectorMessage, NarratorMessage
|
||||
from talemate.scene.schema import ScenePhase, SceneType, SceneIntent
|
||||
from talemate.scene_message import (
|
||||
CharacterMessage,
|
||||
TimePassageMessage,
|
||||
NarratorMessage,
|
||||
)
|
||||
from talemate.scene.schema import ScenePhase, SceneType
|
||||
from talemate.scene.intent import set_scene_phase
|
||||
from talemate.world_state.manager import WorldStateManager
|
||||
from talemate.world_state.templates.scene import SceneType as TemplateSceneType
|
||||
import talemate.agents.director.auto_direct_nodes
|
||||
import talemate.agents.director.auto_direct_nodes # noqa: F401
|
||||
from talemate.world_state.templates.base import TypedCollection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -26,15 +25,15 @@ if TYPE_CHECKING:
|
||||
log = structlog.get_logger("talemate.agents.conversation.direct")
|
||||
|
||||
|
||||
#talemate.emit.async_signals.register(
|
||||
#)
|
||||
# talemate.emit.async_signals.register(
|
||||
# )
|
||||
|
||||
|
||||
class AutoDirectMixin:
|
||||
|
||||
"""
|
||||
Director agent mixin for automatic scene direction.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["auto_direct"] = AgentAction(
|
||||
@@ -108,113 +107,117 @@ class AutoDirectMixin:
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_enabled(self) -> bool:
|
||||
return self.actions["auto_direct"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_max_auto_turns(self) -> int:
|
||||
return self.actions["auto_direct"].config["max_auto_turns"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_max_idle_turns(self) -> int:
|
||||
return self.actions["auto_direct"].config["max_idle_turns"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_max_repeat_turns(self) -> int:
|
||||
return self.actions["auto_direct"].config["max_repeat_turns"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_instruct_actors(self) -> bool:
|
||||
return self.actions["auto_direct"].config["instruct_actors"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_instruct_narrator(self) -> bool:
|
||||
return self.actions["auto_direct"].config["instruct_narrator"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_instruct_frequency(self) -> int:
|
||||
return self.actions["auto_direct"].config["instruct_frequency"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_evaluate_scene_intention(self) -> int:
|
||||
return self.actions["auto_direct"].config["evaluate_scene_intention"].value
|
||||
|
||||
|
||||
@property
|
||||
def auto_direct_instruct_any(self) -> bool:
|
||||
"""
|
||||
Will check whether actor or narrator instructions are enabled.
|
||||
|
||||
|
||||
For narrator instructions to be enabled instruct_narrator needs to be enabled as well.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if either actor or narrator instructions are enabled.
|
||||
"""
|
||||
|
||||
|
||||
return self.auto_direct_instruct_actors or self.auto_direct_instruct_narrator
|
||||
|
||||
|
||||
|
||||
# signal connect
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
|
||||
async def on_game_loop(self, event):
|
||||
if not self.auto_direct_enabled:
|
||||
return
|
||||
|
||||
|
||||
if self.auto_direct_evaluate_scene_intention > 0:
|
||||
evaluation_due = self.get_scene_state("evaluated_scene_intention", 0)
|
||||
if evaluation_due == 0:
|
||||
await self.auto_direct_set_scene_intent()
|
||||
self.set_scene_states(evaluated_scene_intention=self.auto_direct_evaluate_scene_intention)
|
||||
self.set_scene_states(
|
||||
evaluated_scene_intention=self.auto_direct_evaluate_scene_intention
|
||||
)
|
||||
else:
|
||||
self.set_scene_states(evaluated_scene_intention=evaluation_due - 1)
|
||||
|
||||
# helpers
|
||||
|
||||
def auto_direct_is_due_for_instruction(self, actor_name:str) -> bool:
|
||||
|
||||
def auto_direct_is_due_for_instruction(self, actor_name: str) -> bool:
|
||||
"""
|
||||
Check if the actor is due for instruction.
|
||||
"""
|
||||
|
||||
|
||||
if self.auto_direct_instruct_frequency == 1:
|
||||
return True
|
||||
|
||||
|
||||
messages_since_last_instruction = 0
|
||||
|
||||
|
||||
def count_messages(message):
|
||||
nonlocal messages_since_last_instruction
|
||||
if message.typ in ["character", "narrator"]:
|
||||
messages_since_last_instruction += 1
|
||||
|
||||
|
||||
|
||||
last_instruction = self.scene.last_message_of_type(
|
||||
"director",
|
||||
character_name=actor_name,
|
||||
max_iterations=25,
|
||||
on_iterate=count_messages,
|
||||
)
|
||||
|
||||
log.debug("auto_direct_is_due_for_instruction", messages_since_last_instruction=messages_since_last_instruction, last_instruction=last_instruction.id if last_instruction else None)
|
||||
|
||||
|
||||
log.debug(
|
||||
"auto_direct_is_due_for_instruction",
|
||||
messages_since_last_instruction=messages_since_last_instruction,
|
||||
last_instruction=last_instruction.id if last_instruction else None,
|
||||
)
|
||||
|
||||
if not last_instruction:
|
||||
return True
|
||||
|
||||
|
||||
return messages_since_last_instruction >= self.auto_direct_instruct_frequency
|
||||
|
||||
|
||||
def auto_direct_candidates(self) -> list["Character"]:
|
||||
"""
|
||||
Returns a list of characters who are valid candidates to speak next.
|
||||
based on the max_idle_turns, max_repeat_turns, and the most recent character.
|
||||
"""
|
||||
|
||||
scene:"Scene" = self.scene
|
||||
|
||||
|
||||
scene: "Scene" = self.scene
|
||||
|
||||
most_recent_character = None
|
||||
repeat_count = 0
|
||||
last_player_turn = None
|
||||
@@ -223,89 +226,105 @@ class AutoDirectMixin:
|
||||
active_charcters = list(scene.characters)
|
||||
active_character_names = [character.name for character in active_charcters]
|
||||
instruct_narrator = self.auto_direct_instruct_narrator
|
||||
|
||||
|
||||
# if there is only one character then they are the only candidate
|
||||
if len(active_charcters) == 1:
|
||||
return active_charcters
|
||||
|
||||
|
||||
BACKLOG_LIMIT = 50
|
||||
|
||||
|
||||
player_character_active = scene.player_character_exists
|
||||
|
||||
|
||||
# check the last BACKLOG_LIMIT entries in the scene history and collect into
|
||||
# a dictionary of character names and the number of turns since they last spoke.
|
||||
|
||||
|
||||
len_history = len(scene.history) - 1
|
||||
num = 0
|
||||
for idx in range(len_history, -1, -1):
|
||||
message = scene.history[idx]
|
||||
turns = len_history - idx
|
||||
|
||||
|
||||
num += 1
|
||||
|
||||
|
||||
if num > BACKLOG_LIMIT:
|
||||
break
|
||||
|
||||
|
||||
if isinstance(message, TimePassageMessage):
|
||||
break
|
||||
|
||||
|
||||
if not isinstance(message, (CharacterMessage, NarratorMessage)):
|
||||
continue
|
||||
|
||||
|
||||
# if character message but character is not in the active characters list then skip
|
||||
if isinstance(message, CharacterMessage) and message.character_name not in active_character_names:
|
||||
if (
|
||||
isinstance(message, CharacterMessage)
|
||||
and message.character_name not in active_character_names
|
||||
):
|
||||
continue
|
||||
|
||||
|
||||
if isinstance(message, NarratorMessage):
|
||||
if not instruct_narrator:
|
||||
continue
|
||||
character = scene.narrator_character_object
|
||||
else:
|
||||
character = scene.get_character(message.character_name)
|
||||
|
||||
|
||||
if not character:
|
||||
continue
|
||||
|
||||
|
||||
if character.is_player and last_player_turn is None:
|
||||
last_player_turn = turns
|
||||
elif not character.is_player and last_player_turn is None:
|
||||
consecutive_auto_turns += 1
|
||||
|
||||
|
||||
if not most_recent_character:
|
||||
most_recent_character = character
|
||||
repeat_count += 1
|
||||
elif character == most_recent_character:
|
||||
repeat_count += 1
|
||||
|
||||
|
||||
if character.name not in candidates:
|
||||
candidates[character.name] = turns
|
||||
|
||||
|
||||
# add any characters that have not spoken yet
|
||||
for character in active_charcters:
|
||||
if character.name not in candidates:
|
||||
candidates[character.name] = 0
|
||||
|
||||
|
||||
# explicitly add narrator if enabled and not already in candidates
|
||||
if instruct_narrator and scene.narrator_character_object:
|
||||
narrator = scene.narrator_character_object
|
||||
if narrator.name not in candidates:
|
||||
candidates[narrator.name] = 0
|
||||
|
||||
log.debug(f"auto_direct_candidates: {candidates}", most_recent_character=most_recent_character, repeat_count=repeat_count, last_player_turn=last_player_turn, consecutive_auto_turns=consecutive_auto_turns)
|
||||
|
||||
|
||||
log.debug(
|
||||
f"auto_direct_candidates: {candidates}",
|
||||
most_recent_character=most_recent_character,
|
||||
repeat_count=repeat_count,
|
||||
last_player_turn=last_player_turn,
|
||||
consecutive_auto_turns=consecutive_auto_turns,
|
||||
)
|
||||
|
||||
if not most_recent_character:
|
||||
log.debug("auto_direct_candidates: No most recent character found.")
|
||||
return list(scene.characters)
|
||||
|
||||
|
||||
# if player has not spoken in a while then they are favored
|
||||
if player_character_active and consecutive_auto_turns >= self.auto_direct_max_auto_turns:
|
||||
log.debug("auto_direct_candidates: User controlled character has not spoken in a while.")
|
||||
if (
|
||||
player_character_active
|
||||
and consecutive_auto_turns >= self.auto_direct_max_auto_turns
|
||||
):
|
||||
log.debug(
|
||||
"auto_direct_candidates: User controlled character has not spoken in a while."
|
||||
)
|
||||
return [scene.get_player_character()]
|
||||
|
||||
|
||||
# check if most recent character has spoken too many times in a row
|
||||
# if so then remove from candidates
|
||||
if repeat_count >= self.auto_direct_max_repeat_turns:
|
||||
log.debug("auto_direct_candidates: Most recent character has spoken too many times in a row.", most_recent_character=most_recent_character
|
||||
log.debug(
|
||||
"auto_direct_candidates: Most recent character has spoken too many times in a row.",
|
||||
most_recent_character=most_recent_character,
|
||||
)
|
||||
candidates.pop(most_recent_character.name, None)
|
||||
|
||||
@@ -314,27 +333,34 @@ class AutoDirectMixin:
|
||||
favored_candidates = []
|
||||
for name, turns in candidates.items():
|
||||
if turns > self.auto_direct_max_idle_turns:
|
||||
log.debug("auto_direct_candidates: Character has gone too long without speaking.", character_name=name, turns=turns)
|
||||
log.debug(
|
||||
"auto_direct_candidates: Character has gone too long without speaking.",
|
||||
character_name=name,
|
||||
turns=turns,
|
||||
)
|
||||
favored_candidates.append(scene.get_character(name))
|
||||
|
||||
|
||||
if favored_candidates:
|
||||
return favored_candidates
|
||||
|
||||
return [scene.get_character(character_name) for character_name in candidates.keys()]
|
||||
|
||||
return [
|
||||
scene.get_character(character_name) for character_name in candidates.keys()
|
||||
]
|
||||
|
||||
# actions
|
||||
|
||||
|
||||
@set_processing
|
||||
async def auto_direct_set_scene_intent(self, require:bool=False) -> ScenePhase | None:
|
||||
|
||||
async def set_scene_intention(type:str, intention:str) -> ScenePhase:
|
||||
async def auto_direct_set_scene_intent(
|
||||
self, require: bool = False
|
||||
) -> ScenePhase | None:
|
||||
async def set_scene_intention(type: str, intention: str) -> ScenePhase:
|
||||
await set_scene_phase(self.scene, type, intention)
|
||||
self.scene.emit_status()
|
||||
return self.scene.intent_state.phase
|
||||
|
||||
|
||||
async def do_nothing(*args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
|
||||
focal_handler = focal.Focal(
|
||||
self.client,
|
||||
callbacks=[
|
||||
@@ -355,57 +381,63 @@ class AutoDirectMixin:
|
||||
],
|
||||
max_calls=1,
|
||||
scene=self.scene,
|
||||
scene_type_ids=", ".join([f'"{scene_type.id}"' for scene_type in self.scene.intent_state.scene_types.values()]),
|
||||
scene_type_ids=", ".join(
|
||||
[
|
||||
f'"{scene_type.id}"'
|
||||
for scene_type in self.scene.intent_state.scene_types.values()
|
||||
]
|
||||
),
|
||||
retries=1,
|
||||
require=require,
|
||||
)
|
||||
|
||||
|
||||
await focal_handler.request(
|
||||
"director.direct-determine-scene-intent",
|
||||
)
|
||||
|
||||
|
||||
return self.scene.intent_state.phase
|
||||
|
||||
|
||||
@set_processing
|
||||
async def auto_direct_generate_scene_types(
|
||||
self,
|
||||
instructions:str,
|
||||
max_scene_types:int=1,
|
||||
self,
|
||||
instructions: str,
|
||||
max_scene_types: int = 1,
|
||||
):
|
||||
|
||||
world_state_manager:WorldStateManager = self.scene.world_state_manager
|
||||
|
||||
scene_type_templates:TypedCollection = await world_state_manager.get_templates(types=["scene_type"])
|
||||
|
||||
async def add_from_template(id:str) -> SceneType:
|
||||
template:TemplateSceneType | None = scene_type_templates.find_by_name(id)
|
||||
world_state_manager: WorldStateManager = self.scene.world_state_manager
|
||||
|
||||
scene_type_templates: TypedCollection = await world_state_manager.get_templates(
|
||||
types=["scene_type"]
|
||||
)
|
||||
|
||||
async def add_from_template(id: str) -> SceneType:
|
||||
template: TemplateSceneType | None = scene_type_templates.find_by_name(id)
|
||||
if not template:
|
||||
log.warning("auto_direct_generate_scene_types: Template not found.", name=id)
|
||||
log.warning(
|
||||
"auto_direct_generate_scene_types: Template not found.", name=id
|
||||
)
|
||||
return None
|
||||
return template.apply_to_scene(self.scene)
|
||||
|
||||
|
||||
async def generate_scene_type(
|
||||
id:str = None,
|
||||
name:str = None,
|
||||
description:str = None,
|
||||
instructions:str = None,
|
||||
id: str = None,
|
||||
name: str = None,
|
||||
description: str = None,
|
||||
instructions: str = None,
|
||||
) -> SceneType:
|
||||
|
||||
if not id or not name:
|
||||
return None
|
||||
|
||||
|
||||
scene_type = SceneType(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
|
||||
self.scene.intent_state.scene_types[id] = scene_type
|
||||
|
||||
|
||||
return scene_type
|
||||
|
||||
|
||||
|
||||
focal_handler = focal.Focal(
|
||||
self.client,
|
||||
callbacks=[
|
||||
@@ -435,7 +467,7 @@ class AutoDirectMixin:
|
||||
instructions=instructions,
|
||||
scene_type_templates=scene_type_templates.templates,
|
||||
)
|
||||
|
||||
|
||||
await focal_handler.request(
|
||||
"director.generate-scene-types",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
import structlog
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
from typing import ClassVar
|
||||
from talemate.game.engine.nodes.core import GraphState, PropertyField
|
||||
from talemate.game.engine.nodes.registry import register
|
||||
from talemate.game.engine.nodes.agent import AgentNode
|
||||
from talemate.scene.schema import ScenePhase
|
||||
from talemate.context import active_scene
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene
|
||||
from talemate.agents.director import DirectorAgent
|
||||
|
||||
log = structlog.get_logger("talemate.game.engine.nodes.agents.director")
|
||||
|
||||
|
||||
@register("agents/director/auto-direct/Candidates")
|
||||
class AutoDirectCandidates(AgentNode):
|
||||
"""
|
||||
@@ -19,52 +15,51 @@ class AutoDirectCandidates(AgentNode):
|
||||
next action, based on the director's auto-direct settings and
|
||||
the recent scene history.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "director"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
def __init__(self, title="Auto Direct Candidates", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_output("characters", socket_type="list")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
candidates = self.agent.auto_direct_candidates()
|
||||
self.set_output_values({
|
||||
"characters": candidates
|
||||
})
|
||||
|
||||
|
||||
self.set_output_values({"characters": candidates})
|
||||
|
||||
|
||||
@register("agents/director/auto-direct/DetermineSceneIntent")
|
||||
class DetermineSceneIntent(AgentNode):
|
||||
"""
|
||||
Determines the scene intent based on the current scene state.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "director"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
def __init__(self, title="Determine Scene Intent", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_output("state")
|
||||
self.add_output("scene_phase", socket_type="scene_intent/scene_phase")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
phase:ScenePhase = await self.agent.auto_direct_set_scene_intent()
|
||||
|
||||
self.set_output_values({
|
||||
"state": state,
|
||||
"scene_phase": phase
|
||||
})
|
||||
phase: ScenePhase = await self.agent.auto_direct_set_scene_intent()
|
||||
|
||||
self.set_output_values({"state": state, "scene_phase": phase})
|
||||
|
||||
|
||||
@register("agents/director/auto-direct/GenerateSceneTypes")
|
||||
class GenerateSceneTypes(AgentNode):
|
||||
"""
|
||||
Generates scene types based on the current scene state.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "director"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
class Fields:
|
||||
instructions = PropertyField(
|
||||
name="instructions",
|
||||
@@ -72,17 +67,17 @@ class GenerateSceneTypes(AgentNode):
|
||||
description="The instructions for the scene types",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
max_scene_types = PropertyField(
|
||||
name="max_scene_types",
|
||||
type="int",
|
||||
description="The maximum number of scene types to generate",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Generate Scene Types", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("instructions", socket_type="str", optional=True)
|
||||
@@ -90,29 +85,26 @@ class GenerateSceneTypes(AgentNode):
|
||||
self.set_property("instructions", "")
|
||||
self.set_property("max_scene_types", 1)
|
||||
self.add_output("state")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
instructions = self.normalized_input_value("instructions")
|
||||
max_scene_types = self.normalized_input_value("max_scene_types")
|
||||
|
||||
|
||||
scene_types = await self.agent.auto_direct_generate_scene_types(
|
||||
instructions=instructions,
|
||||
max_scene_types=max_scene_types
|
||||
instructions=instructions, max_scene_types=max_scene_types
|
||||
)
|
||||
|
||||
self.set_output_values({
|
||||
"state": state,
|
||||
"scene_types": scene_types
|
||||
})
|
||||
|
||||
|
||||
self.set_output_values({"state": state, "scene_types": scene_types})
|
||||
|
||||
|
||||
@register("agents/director/auto-direct/IsDueForInstruction")
|
||||
class IsDueForInstruction(AgentNode):
|
||||
"""
|
||||
Checks if the actor is due for instruction based on the auto-direct settings.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "director"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
class Fields:
|
||||
actor_name = PropertyField(
|
||||
name="actor_name",
|
||||
@@ -120,24 +112,21 @@ class IsDueForInstruction(AgentNode):
|
||||
description="The name of the actor to check instruction timing for",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Is Due For Instruction", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("actor_name", socket_type="str")
|
||||
|
||||
|
||||
self.set_property("actor_name", "")
|
||||
|
||||
|
||||
self.add_output("is_due", socket_type="bool")
|
||||
self.add_output("actor_name", socket_type="str")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
actor_name = self.require_input("actor_name")
|
||||
|
||||
|
||||
is_due = self.agent.auto_direct_is_due_for_instruction(actor_name)
|
||||
|
||||
self.set_output_values({
|
||||
"is_due": is_due,
|
||||
"actor_name": actor_name
|
||||
})
|
||||
|
||||
self.set_output_values({"is_due": is_due, "actor_name": actor_name})
|
||||
|
||||
336
src/talemate/agents/director/character_management.py
Normal file
@@ -0,0 +1,336 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import traceback
|
||||
import structlog
|
||||
import talemate.instance as instance
|
||||
import talemate.agents.tts.voice_library as voice_library
|
||||
from talemate.agents.tts.schema import Voice
|
||||
from talemate.util import random_color
|
||||
from talemate.character import deactivate_character, set_voice
|
||||
from talemate.status import LoadingStatus
|
||||
from talemate.exceptions import GenerationCancelled
|
||||
from talemate.agents.base import AgentAction, AgentActionConfig, set_processing
|
||||
import talemate.game.focal as focal
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CharacterManagementMixin",
|
||||
]
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate import Character, Scene
|
||||
from talemate.agents.tts import TTSAgent
|
||||
|
||||
|
||||
class VoiceCandidate(Voice):
|
||||
used: bool = False
|
||||
|
||||
|
||||
class CharacterManagementMixin:
|
||||
"""
|
||||
Director agent mixin that provides functionality for automatically guiding
|
||||
the actors or the narrator during the scene progression.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["character_management"] = AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
can_be_disabled=False,
|
||||
label="Character Management",
|
||||
icon="mdi-account",
|
||||
description="Configure how the director manages characters.",
|
||||
config={
|
||||
"assign_voice": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Assign Voice (TTS)",
|
||||
description="If enabled, the director is allowed to assign a text-to-speech voice when persisting a character.",
|
||||
value=True,
|
||||
title="Persisting Characters",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# config property helpers
|
||||
|
||||
@property
|
||||
def cm_assign_voice(self) -> bool:
|
||||
return self.actions["character_management"].config["assign_voice"].value
|
||||
|
||||
@property
|
||||
def cm_should_assign_voice(self) -> bool:
|
||||
if not self.cm_assign_voice:
|
||||
return False
|
||||
|
||||
tts_agent: "TTSAgent" = instance.get_agent("tts")
|
||||
if not tts_agent.enabled:
|
||||
return False
|
||||
|
||||
if not tts_agent.ready_apis:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# actions
|
||||
|
||||
@set_processing
|
||||
async def persist_characters_from_worldstate(
|
||||
self, exclude: list[str] = None
|
||||
) -> list["Character"]:
|
||||
created_characters = []
|
||||
|
||||
for character_name in self.scene.world_state.characters.keys():
|
||||
if exclude and character_name.lower() in exclude:
|
||||
continue
|
||||
|
||||
if character_name in self.scene.character_names:
|
||||
continue
|
||||
|
||||
character = await self.persist_character(name=character_name)
|
||||
|
||||
created_characters.append(character)
|
||||
|
||||
self.scene.emit_status()
|
||||
|
||||
return created_characters
|
||||
|
||||
@set_processing
|
||||
async def persist_character(
|
||||
self,
|
||||
name: str,
|
||||
content: str = None,
|
||||
attributes: str = None,
|
||||
determine_name: bool = True,
|
||||
templates: list[str] = None,
|
||||
active: bool = True,
|
||||
narrate_entry: bool = False,
|
||||
narrate_entry_direction: str = "",
|
||||
augment_attributes: str = "",
|
||||
generate_attributes: bool = True,
|
||||
description: str = "",
|
||||
assign_voice: bool = True,
|
||||
is_player: bool = False,
|
||||
) -> "Character":
|
||||
world_state = instance.get_agent("world_state")
|
||||
creator = instance.get_agent("creator")
|
||||
narrator = instance.get_agent("narrator")
|
||||
memory = instance.get_agent("memory")
|
||||
scene: "Scene" = self.scene
|
||||
any_attribute_templates = False
|
||||
|
||||
loading_status = LoadingStatus(max_steps=None, cancellable=True)
|
||||
|
||||
# Start of character creation
|
||||
log.debug("persist_character", name=name)
|
||||
|
||||
# Determine the character's name (or clarify if it's already set)
|
||||
if determine_name:
|
||||
loading_status("Determining character name")
|
||||
name = await creator.determine_character_name(name, instructions=content)
|
||||
log.debug("persist_character", adjusted_name=name)
|
||||
|
||||
if name in self.scene.all_character_names:
|
||||
raise ValueError(f'Name "{name}" already exists.')
|
||||
|
||||
# Create the blank character
|
||||
character: "Character" = self.scene.Character(name=name, is_player=is_player)
|
||||
|
||||
# Add the character to the scene
|
||||
character.color = random_color()
|
||||
|
||||
if is_player:
|
||||
actor = self.scene.Player(
|
||||
character=character, agent=instance.get_agent("conversation")
|
||||
)
|
||||
else:
|
||||
actor = self.scene.Actor(
|
||||
character=character, agent=instance.get_agent("conversation")
|
||||
)
|
||||
|
||||
await self.scene.add_actor(actor)
|
||||
|
||||
try:
|
||||
# Apply any character generation templates
|
||||
if templates:
|
||||
loading_status("Applying character generation templates")
|
||||
templates = scene.world_state_manager.template_collection.collect_all(
|
||||
templates
|
||||
)
|
||||
log.debug("persist_character", applying_templates=templates)
|
||||
await scene.world_state_manager.apply_templates(
|
||||
templates.values(),
|
||||
character_name=character.name,
|
||||
information=content,
|
||||
)
|
||||
|
||||
# if any of the templates are attribute templates, then we no longer need to
|
||||
# generate a character sheet
|
||||
any_attribute_templates = any(
|
||||
template.template_type == "character_attribute"
|
||||
for template in templates.values()
|
||||
)
|
||||
log.debug(
|
||||
"persist_character", any_attribute_templates=any_attribute_templates
|
||||
)
|
||||
|
||||
if (
|
||||
any_attribute_templates
|
||||
and augment_attributes
|
||||
and generate_attributes
|
||||
):
|
||||
log.debug(
|
||||
"persist_character", augmenting_attributes=augment_attributes
|
||||
)
|
||||
loading_status("Augmenting character attributes")
|
||||
additional_attributes = await world_state.extract_character_sheet(
|
||||
name=name,
|
||||
text=content,
|
||||
augmentation_instructions=augment_attributes,
|
||||
)
|
||||
character.base_attributes.update(additional_attributes)
|
||||
|
||||
# Generate a character sheet if there are no attribute templates
|
||||
if not any_attribute_templates and generate_attributes:
|
||||
loading_status("Generating character sheet")
|
||||
log.debug("persist_character", extracting_character_sheet=True)
|
||||
if not attributes:
|
||||
attributes = await world_state.extract_character_sheet(
|
||||
name=name, text=content
|
||||
)
|
||||
else:
|
||||
attributes = world_state._parse_character_sheet(attributes)
|
||||
|
||||
log.debug("persist_character", attributes=attributes)
|
||||
character.base_attributes = attributes
|
||||
|
||||
# Generate a description for the character
|
||||
if not description:
|
||||
loading_status("Generating character description")
|
||||
description = await creator.determine_character_description(
|
||||
character, information=content
|
||||
)
|
||||
character.description = description
|
||||
log.debug("persist_character", description=description)
|
||||
|
||||
# Generate a dialogue instructions for the character
|
||||
loading_status("Generating acting instructions")
|
||||
dialogue_instructions = (
|
||||
await creator.determine_character_dialogue_instructions(
|
||||
character, information=content
|
||||
)
|
||||
)
|
||||
character.dialogue_instructions = dialogue_instructions
|
||||
log.debug("persist_character", dialogue_instructions=dialogue_instructions)
|
||||
|
||||
# Narrate the character's entry if the option is selected
|
||||
if active and narrate_entry:
|
||||
loading_status("Narrating character entry")
|
||||
is_present = await world_state.is_character_present(name)
|
||||
if not is_present:
|
||||
await narrator.action_to_narration(
|
||||
"narrate_character_entry",
|
||||
emit_message=True,
|
||||
character=character,
|
||||
narrative_direction=narrate_entry_direction,
|
||||
)
|
||||
|
||||
if assign_voice:
|
||||
await self.assign_voice_to_character(character)
|
||||
|
||||
# Deactivate the character if not active
|
||||
if not active:
|
||||
await deactivate_character(scene, character)
|
||||
|
||||
# Commit the character's details to long term memory
|
||||
await character.commit_to_memory(memory)
|
||||
self.scene.emit_status()
|
||||
self.scene.world_state.emit()
|
||||
|
||||
loading_status.done(
|
||||
message=f"{character.name} added to scene", status="success"
|
||||
)
|
||||
return character
|
||||
except GenerationCancelled:
|
||||
loading_status.done(message="Character creation cancelled", status="idle")
|
||||
await scene.remove_actor(actor)
|
||||
except Exception:
|
||||
loading_status.done(message="Character creation failed", status="error")
|
||||
await scene.remove_actor(actor)
|
||||
log.error("Error persisting character", error=traceback.format_exc())
|
||||
|
||||
@set_processing
|
||||
async def assign_voice_to_character(
|
||||
self, character: "Character"
|
||||
) -> list[focal.Call]:
|
||||
tts_agent: "TTSAgent" = instance.get_agent("tts")
|
||||
if not self.cm_should_assign_voice:
|
||||
log.debug("assign_voice_to_character", skip=True, reason="not enabled")
|
||||
return
|
||||
|
||||
vl: voice_library.VoiceLibrary = voice_library.get_instance()
|
||||
|
||||
ready_tts_apis = tts_agent.ready_apis
|
||||
|
||||
voices_global = voice_library.voices_for_apis(ready_tts_apis, vl)
|
||||
voices_scene = voice_library.voices_for_apis(
|
||||
ready_tts_apis, self.scene.voice_library
|
||||
)
|
||||
|
||||
voices = voices_global + voices_scene
|
||||
|
||||
if not voices:
|
||||
log.debug(
|
||||
"assign_voice_to_character", skip=True, reason="no voices available"
|
||||
)
|
||||
return
|
||||
|
||||
voice_candidates = {
|
||||
voice.id: VoiceCandidate(**voice.model_dump()) for voice in voices
|
||||
}
|
||||
|
||||
for scene_character in self.scene.all_characters:
|
||||
if scene_character.voice:
|
||||
voice_candidates[scene_character.voice.id].used = True
|
||||
|
||||
async def assign_voice(voice_id: str):
|
||||
voice = vl.get_voice(voice_id) or self.scene.voice_library.get_voice(
|
||||
voice_id
|
||||
)
|
||||
if not voice:
|
||||
log.error(
|
||||
"assign_voice_to_character",
|
||||
skip=True,
|
||||
reason="voice not found",
|
||||
voice_id=voice_id,
|
||||
)
|
||||
return
|
||||
await set_voice(character, voice, auto=True)
|
||||
await self.log_action(
|
||||
f"Assigned voice `{voice.label}` to `{character.name}`",
|
||||
"Assigned voice",
|
||||
console_only=True,
|
||||
)
|
||||
|
||||
focal_handler = focal.Focal(
|
||||
self.client,
|
||||
callbacks=[
|
||||
focal.Callback(
|
||||
name="assign_voice",
|
||||
arguments=[focal.Argument(name="voice_id", type="str")],
|
||||
fn=assign_voice,
|
||||
),
|
||||
],
|
||||
max_calls=1,
|
||||
character=character,
|
||||
voices=list(voice_candidates.values()),
|
||||
scene=self.scene,
|
||||
narrator_voice=tts_agent.narrator_voice,
|
||||
)
|
||||
|
||||
await focal_handler.request("director.cm-assign-voice")
|
||||
|
||||
log.debug("assign_voice_to_character", calls=focal_handler.state.calls)
|
||||
|
||||
return focal_handler.state.calls
|
||||
@@ -1,14 +1,12 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import random
|
||||
import structlog
|
||||
from functools import wraps
|
||||
import dataclasses
|
||||
from talemate.agents.base import (
|
||||
set_processing,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentTemplateEmission,
|
||||
DynamicInstruction,
|
||||
)
|
||||
from talemate.events import GameLoopStartEvent
|
||||
from talemate.scene_message import NarratorMessage, CharacterMessage
|
||||
@@ -24,7 +22,7 @@ __all__ = [
|
||||
log = structlog.get_logger()
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.director.generate_choices.before_generate",
|
||||
"agent.director.generate_choices.before_generate",
|
||||
"agent.director.generate_choices.inject_instructions",
|
||||
"agent.director.generate_choices.generated",
|
||||
)
|
||||
@@ -32,18 +30,19 @@ talemate.emit.async_signals.register(
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GenerateChoicesEmission(AgentTemplateEmission):
|
||||
character: "Character | None" = None
|
||||
choices: list[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
class GenerateChoicesMixin:
|
||||
|
||||
"""
|
||||
Director agent mixin that provides functionality for automatically guiding
|
||||
the actors or the narrator during the scene progression.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["_generate_choices"] = AgentAction(
|
||||
@@ -65,7 +64,6 @@ class GenerateChoicesMixin:
|
||||
max=1,
|
||||
step=0.1,
|
||||
),
|
||||
|
||||
"num_choices": AgentActionConfig(
|
||||
type="number",
|
||||
label="Number of Actions",
|
||||
@@ -75,33 +73,31 @@ class GenerateChoicesMixin:
|
||||
max=10,
|
||||
step=1,
|
||||
),
|
||||
|
||||
"never_auto_progress": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Never Auto Progress on Action Selection",
|
||||
description="If enabled, the scene will not auto progress after you select an action.",
|
||||
value=False,
|
||||
),
|
||||
|
||||
"instructions": AgentActionConfig(
|
||||
type="blob",
|
||||
label="Instructions",
|
||||
description="Provide some instructions to the director for generating actions.",
|
||||
value="",
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def generate_choices_enabled(self):
|
||||
return self.actions["_generate_choices"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def generate_choices_chance(self):
|
||||
return self.actions["_generate_choices"].config["chance"].value
|
||||
|
||||
|
||||
@property
|
||||
def generate_choices_num_choices(self):
|
||||
return self.actions["_generate_choices"].config["num_choices"].value
|
||||
@@ -115,24 +111,25 @@ class GenerateChoicesMixin:
|
||||
return self.actions["_generate_choices"].config["instructions"].value
|
||||
|
||||
# signal connect
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("player_turn_start").connect(self.on_player_turn_start)
|
||||
|
||||
talemate.emit.async_signals.get("player_turn_start").connect(
|
||||
self.on_player_turn_start
|
||||
)
|
||||
|
||||
async def on_player_turn_start(self, event: GameLoopStartEvent):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
if self.generate_choices_enabled:
|
||||
|
||||
# look backwards through history and abort if we encounter
|
||||
# a character message with source "player" before either
|
||||
# a character message with a different source or a narrator message
|
||||
#
|
||||
# this is so choices aren't generated when the player message was
|
||||
# the most recent content in the scene
|
||||
|
||||
|
||||
for i in range(len(self.scene.history) - 1, -1, -1):
|
||||
message = self.scene.history[i]
|
||||
if isinstance(message, NarratorMessage):
|
||||
@@ -141,12 +138,11 @@ class GenerateChoicesMixin:
|
||||
if message.source == "player":
|
||||
return
|
||||
break
|
||||
|
||||
|
||||
if random.random() < self.generate_choices_chance:
|
||||
await self.generate_choices()
|
||||
|
||||
await self.generate_choices()
|
||||
|
||||
# methods
|
||||
|
||||
|
||||
@set_processing
|
||||
async def generate_choices(
|
||||
@@ -154,20 +150,23 @@ class GenerateChoicesMixin:
|
||||
instructions: str = None,
|
||||
character: "Character | str | None" = None,
|
||||
):
|
||||
|
||||
emission: GenerateChoicesEmission = GenerateChoicesEmission(agent=self)
|
||||
|
||||
if isinstance(character, str):
|
||||
character = self.scene.get_character(character)
|
||||
|
||||
|
||||
if not character:
|
||||
character = self.scene.get_player_character()
|
||||
|
||||
|
||||
emission.character = character
|
||||
|
||||
await talemate.emit.async_signals.get("agent.director.generate_choices.before_generate").send(emission)
|
||||
await talemate.emit.async_signals.get("agent.director.generate_choices.inject_instructions").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.director.generate_choices.before_generate"
|
||||
).send(emission)
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.director.generate_choices.inject_instructions"
|
||||
).send(emission)
|
||||
|
||||
response = await Prompt.request(
|
||||
"director.generate-choices",
|
||||
self.client,
|
||||
@@ -178,7 +177,9 @@ class GenerateChoicesMixin:
|
||||
"character": character,
|
||||
"num_choices": self.generate_choices_num_choices,
|
||||
"instructions": instructions or self.generate_choices_instructions,
|
||||
"dynamic_instructions": emission.dynamic_instructions if emission else None,
|
||||
"dynamic_instructions": emission.dynamic_instructions
|
||||
if emission
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -187,10 +188,10 @@ class GenerateChoicesMixin:
|
||||
choices = util.extract_list(choice_text)
|
||||
# strip quotes
|
||||
choices = [choice.strip().strip('"') for choice in choices]
|
||||
|
||||
|
||||
# limit to num_choices
|
||||
choices = choices[:self.generate_choices_num_choices]
|
||||
|
||||
choices = choices[: self.generate_choices_num_choices]
|
||||
|
||||
except Exception as e:
|
||||
log.error("generate_choices failed", error=str(e), response=response)
|
||||
return
|
||||
@@ -198,15 +199,17 @@ class GenerateChoicesMixin:
|
||||
emit(
|
||||
"player_choice",
|
||||
response,
|
||||
data = {
|
||||
data={
|
||||
"choices": choices,
|
||||
"character": character.name,
|
||||
},
|
||||
websocket_passthrough=True
|
||||
websocket_passthrough=True,
|
||||
)
|
||||
|
||||
|
||||
emission.response = response
|
||||
emission.choices = choices
|
||||
await talemate.emit.async_signals.get("agent.director.generate_choices.generated").send(emission)
|
||||
|
||||
return emission.response
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.director.generate_choices.generated"
|
||||
).send(emission)
|
||||
|
||||
return emission.response
|
||||
|
||||
@@ -17,12 +17,11 @@ from talemate.util import strip_partial_sentences
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
from talemate.agents.summarize.analyze_scene import SceneAnalysisEmission
|
||||
from talemate.agents.editor.revision import RevisionAnalysisEmission
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.director.guide.before_generate",
|
||||
"agent.director.guide.before_generate",
|
||||
"agent.director.guide.inject_instructions",
|
||||
"agent.director.guide.generated",
|
||||
)
|
||||
@@ -32,6 +31,7 @@ talemate.emit.async_signals.register(
|
||||
class DirectorGuidanceEmission(AgentTemplateEmission):
|
||||
pass
|
||||
|
||||
|
||||
def set_processing(fn):
|
||||
"""
|
||||
Custom decorator that emits the agent status as processing while the function
|
||||
@@ -42,29 +42,33 @@ def set_processing(fn):
|
||||
@wraps(fn)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
emission: DirectorGuidanceEmission = DirectorGuidanceEmission(agent=self)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.director.guide.before_generate").send(emission)
|
||||
await talemate.emit.async_signals.get("agent.director.guide.inject_instructions").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.director.guide.before_generate"
|
||||
).send(emission)
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.director.guide.inject_instructions"
|
||||
).send(emission)
|
||||
|
||||
agent_context = active_agent.get()
|
||||
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
|
||||
response = await fn(self, *args, **kwargs)
|
||||
emission.response = response
|
||||
await talemate.emit.async_signals.get("agent.director.guide.generated").send(emission)
|
||||
await talemate.emit.async_signals.get("agent.director.guide.generated").send(
|
||||
emission
|
||||
)
|
||||
return emission.response
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
class GuideSceneMixin:
|
||||
|
||||
"""
|
||||
Director agent mixin that provides functionality for automatically guiding
|
||||
the actors or the narrator during the scene progression.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["guide_scene"] = AgentAction(
|
||||
@@ -81,13 +85,13 @@ class GuideSceneMixin:
|
||||
type="bool",
|
||||
label="Guide actors",
|
||||
description="Guide the actors in the scene. This happens during every actor turn.",
|
||||
value=True
|
||||
value=True,
|
||||
),
|
||||
"guide_narrator": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Guide narrator",
|
||||
description="Guide the narrator during the scene. This happens during the narrator's turn.",
|
||||
value=True
|
||||
value=True,
|
||||
),
|
||||
"guidance_length": AgentActionConfig(
|
||||
type="text",
|
||||
@@ -101,7 +105,7 @@ class GuideSceneMixin:
|
||||
{"label": "Medium (512)", "value": "512"},
|
||||
{"label": "Medium Long (768)", "value": "768"},
|
||||
{"label": "Long (1024)", "value": "1024"},
|
||||
]
|
||||
],
|
||||
),
|
||||
"cache_guidance": AgentActionConfig(
|
||||
type="bool",
|
||||
@@ -109,57 +113,57 @@ class GuideSceneMixin:
|
||||
description="Will not regenerate the guidance until the scene moves forward or the analysis changes.",
|
||||
value=False,
|
||||
quick_toggle=True,
|
||||
)
|
||||
}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def guide_scene(self) -> bool:
|
||||
return self.actions["guide_scene"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def guide_actors(self) -> bool:
|
||||
return self.actions["guide_scene"].config["guide_actors"].value
|
||||
|
||||
|
||||
@property
|
||||
def guide_narrator(self) -> bool:
|
||||
return self.actions["guide_scene"].config["guide_narrator"].value
|
||||
|
||||
|
||||
@property
|
||||
def guide_scene_guidance_length(self) -> int:
|
||||
return int(self.actions["guide_scene"].config["guidance_length"].value)
|
||||
|
||||
|
||||
@property
|
||||
def guide_scene_cache_guidance(self) -> bool:
|
||||
return self.actions["guide_scene"].config["cache_guidance"].value
|
||||
|
||||
|
||||
# signal connect
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").connect(
|
||||
self.on_summarization_scene_analysis_after
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").connect(
|
||||
self.on_summarization_scene_analysis_after
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
|
||||
self.on_editor_revision_analysis_before
|
||||
)
|
||||
|
||||
async def on_summarization_scene_analysis_after(self, emission: "SceneAnalysisEmission"):
|
||||
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.after"
|
||||
).connect(self.on_summarization_scene_analysis_after)
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.cached"
|
||||
).connect(self.on_summarization_scene_analysis_after)
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.editor.revision-analysis.before"
|
||||
).connect(self.on_editor_revision_analysis_before)
|
||||
|
||||
async def on_summarization_scene_analysis_after(
|
||||
self, emission: "SceneAnalysisEmission"
|
||||
):
|
||||
if not self.guide_scene:
|
||||
return
|
||||
|
||||
|
||||
guidance = None
|
||||
|
||||
|
||||
cached_guidance = await self.get_cached_guidance(emission.response)
|
||||
|
||||
|
||||
if emission.analysis_type == "narration" and self.guide_narrator:
|
||||
|
||||
if cached_guidance:
|
||||
guidance = cached_guidance
|
||||
else:
|
||||
@@ -167,15 +171,14 @@ class GuideSceneMixin:
|
||||
emission.response,
|
||||
response_length=self.guide_scene_guidance_length,
|
||||
)
|
||||
|
||||
|
||||
if not guidance:
|
||||
log.warning("director.guide_scene.narration: Empty resonse")
|
||||
return
|
||||
|
||||
|
||||
self.set_context_states(narrator_guidance=guidance)
|
||||
|
||||
elif emission.analysis_type == "conversation" and self.guide_actors:
|
||||
|
||||
|
||||
elif emission.analysis_type == "conversation" and self.guide_actors:
|
||||
if cached_guidance:
|
||||
guidance = cached_guidance
|
||||
else:
|
||||
@@ -184,94 +187,112 @@ class GuideSceneMixin:
|
||||
emission.template_vars.get("character"),
|
||||
response_length=self.guide_scene_guidance_length,
|
||||
)
|
||||
|
||||
|
||||
if not guidance:
|
||||
log.warning("director.guide_scene.conversation: Empty resonse")
|
||||
return
|
||||
|
||||
|
||||
self.set_context_states(actor_guidance=guidance)
|
||||
|
||||
if guidance:
|
||||
await self.set_cached_guidance(
|
||||
emission.response,
|
||||
guidance,
|
||||
emission.analysis_type,
|
||||
emission.template_vars.get("character")
|
||||
emission.response,
|
||||
guidance,
|
||||
emission.analysis_type,
|
||||
emission.template_vars.get("character"),
|
||||
)
|
||||
|
||||
|
||||
async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
|
||||
cached_guidance = await self.get_cached_guidance(emission.response)
|
||||
if cached_guidance:
|
||||
emission.dynamic_instructions.append(DynamicInstruction(
|
||||
title="Guidance",
|
||||
content=cached_guidance
|
||||
))
|
||||
|
||||
emission.dynamic_instructions.append(
|
||||
DynamicInstruction(title="Guidance", content=cached_guidance)
|
||||
)
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def _cache_key(self) -> str:
|
||||
return f"cached_guidance"
|
||||
|
||||
async def get_cached_guidance(self, analysis:str | None = None) -> str | None:
|
||||
return "cached_guidance"
|
||||
|
||||
async def get_cached_guidance(self, analysis: str | None = None) -> str | None:
|
||||
"""
|
||||
Returns the cached guidance for the given analysis.
|
||||
|
||||
|
||||
If analysis is not provided, it will return the cached guidance for the last analysis regardless
|
||||
of the fingerprint.
|
||||
"""
|
||||
|
||||
|
||||
if not self.guide_scene_cache_guidance:
|
||||
return None
|
||||
|
||||
|
||||
key = self._cache_key()
|
||||
cached_guidance = self.get_scene_state(key)
|
||||
|
||||
|
||||
if cached_guidance:
|
||||
if not analysis:
|
||||
return cached_guidance.get("guidance")
|
||||
elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]):
|
||||
elif cached_guidance.get("fp") == self.context_fingerprint(
|
||||
extra=[analysis]
|
||||
):
|
||||
return cached_guidance.get("guidance")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def set_cached_guidance(self, analysis:str, guidance: str, analysis_type: str, character: "Character | None" = None):
|
||||
|
||||
async def set_cached_guidance(
|
||||
self,
|
||||
analysis: str,
|
||||
guidance: str,
|
||||
analysis_type: str,
|
||||
character: "Character | None" = None,
|
||||
):
|
||||
"""
|
||||
Sets the cached guidance for the given analysis.
|
||||
"""
|
||||
key = self._cache_key()
|
||||
self.set_scene_states(**{
|
||||
key: {
|
||||
"fp": self.context_fingerpint(extra=[analysis]),
|
||||
"guidance": guidance,
|
||||
"analysis_type": analysis_type,
|
||||
"character": character.name if character else None,
|
||||
self.set_scene_states(
|
||||
**{
|
||||
key: {
|
||||
"fp": self.context_fingerprint(extra=[analysis]),
|
||||
"guidance": guidance,
|
||||
"analysis_type": analysis_type,
|
||||
"character": character.name if character else None,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
)
|
||||
|
||||
async def get_cached_character_guidance(self, character_name: str) -> str | None:
|
||||
"""
|
||||
Returns the cached guidance for the given character.
|
||||
"""
|
||||
key = self._cache_key()
|
||||
cached_guidance = self.get_scene_state(key)
|
||||
|
||||
|
||||
if not cached_guidance:
|
||||
return None
|
||||
|
||||
if cached_guidance.get("character") == character_name and cached_guidance.get("analysis_type") == "conversation":
|
||||
|
||||
if (
|
||||
cached_guidance.get("character") == character_name
|
||||
and cached_guidance.get("analysis_type") == "conversation"
|
||||
):
|
||||
return cached_guidance.get("guidance")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# methods
|
||||
|
||||
|
||||
@set_processing
|
||||
async def guide_actor_off_of_scene_analysis(self, analysis: str, character: "Character", response_length: int = 256):
|
||||
async def guide_actor_off_of_scene_analysis(
|
||||
self, analysis: str, character: "Character", response_length: int = 256
|
||||
):
|
||||
"""
|
||||
Guides the actor based on the scene analysis.
|
||||
"""
|
||||
|
||||
log.debug("director.guide_actor_off_of_scene_analysis", analysis=analysis, character=character)
|
||||
|
||||
log.debug(
|
||||
"director.guide_actor_off_of_scene_analysis",
|
||||
analysis=analysis,
|
||||
character=character,
|
||||
)
|
||||
response = await Prompt.request(
|
||||
"director.guide-conversation",
|
||||
self.client,
|
||||
@@ -285,19 +306,17 @@ class GuideSceneMixin:
|
||||
},
|
||||
)
|
||||
return strip_partial_sentences(response).strip()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def guide_narrator_off_of_scene_analysis(
|
||||
self,
|
||||
analysis: str,
|
||||
response_length: int = 256
|
||||
self, analysis: str, response_length: int = 256
|
||||
):
|
||||
"""
|
||||
Guides the narrator based on the scene analysis.
|
||||
"""
|
||||
|
||||
|
||||
log.debug("director.guide_narrator_off_of_scene_analysis", analysis=analysis)
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
"director.guide-narration",
|
||||
self.client,
|
||||
@@ -309,4 +328,4 @@ class GuideSceneMixin:
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
return strip_partial_sentences(response).strip()
|
||||
return strip_partial_sentences(response).strip()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import structlog
|
||||
from talemate.agents.base import (
|
||||
set_processing,
|
||||
@@ -9,24 +8,27 @@ from talemate.events import GameLoopActorIterEvent, SceneStateEvent
|
||||
|
||||
log = structlog.get_logger("talemate.agents.conversation.legacy_scene_instructions")
|
||||
|
||||
|
||||
class LegacySceneInstructionsMixin(
|
||||
GameInstructionsMixin,
|
||||
):
|
||||
"""
|
||||
Legacy support for scoped api instructions in scenes.
|
||||
|
||||
|
||||
This is being replaced by node based in structions, but kept for backwards compatibility.
|
||||
|
||||
|
||||
THIS WILL BE DEPRECATED IN THE FUTURE.
|
||||
"""
|
||||
|
||||
|
||||
# signal connect
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.LSI_on_player_dialog)
|
||||
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
|
||||
self.LSI_on_player_dialog
|
||||
)
|
||||
talemate.emit.async_signals.get("scene_init").connect(self.LSI_on_scene_init)
|
||||
|
||||
|
||||
async def LSI_on_scene_init(self, event: SceneStateEvent):
|
||||
"""
|
||||
LEGACY: If game state instructions specify to be run at the start of the game loop
|
||||
@@ -57,8 +59,10 @@ class LegacySceneInstructionsMixin(
|
||||
|
||||
if not event.actor.character.is_player:
|
||||
return
|
||||
|
||||
log.warning(f"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
|
||||
|
||||
log.warning(
|
||||
"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
|
||||
)
|
||||
|
||||
if event.game_loop.had_passive_narration:
|
||||
log.debug(
|
||||
@@ -77,17 +81,17 @@ class LegacySceneInstructionsMixin(
|
||||
not self.scene.npc_character_names
|
||||
or self.scene.game_state.ops.always_direct
|
||||
)
|
||||
|
||||
log.warning(f"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.", always_direct=always_direct)
|
||||
|
||||
log.warning(
|
||||
"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.",
|
||||
always_direct=always_direct,
|
||||
)
|
||||
|
||||
next_direct = self.next_direct_scene
|
||||
|
||||
TURNS = 5
|
||||
|
||||
if (
|
||||
next_direct % TURNS != 0
|
||||
or next_direct == 0
|
||||
):
|
||||
if next_direct % TURNS != 0 or next_direct == 0:
|
||||
if not always_direct:
|
||||
log.info("direct", skip=True, next_direct=next_direct)
|
||||
self.next_direct_scene += 1
|
||||
@@ -112,8 +116,10 @@ class LegacySceneInstructionsMixin(
|
||||
async def LSI_direct_scene(self):
|
||||
"""
|
||||
LEGACY: Direct the scene based scoped api scene instructions.
|
||||
This is being replaced by node based instructions, but kept for
|
||||
This is being replaced by node based instructions, but kept for
|
||||
backwards compatibility.
|
||||
"""
|
||||
log.warning(f"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
|
||||
await self.run_scene_instructions(self.scene)
|
||||
log.warning(
|
||||
"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
|
||||
)
|
||||
await self.run_scene_instructions(self.scene)
|
||||
|
||||
@@ -1,29 +1,34 @@
|
||||
import structlog
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from talemate.context import active_scene
|
||||
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
|
||||
from typing import ClassVar
|
||||
from talemate.game.engine.nodes.core import (
|
||||
GraphState,
|
||||
PropertyField,
|
||||
TYPE_CHOICES,
|
||||
)
|
||||
from talemate.game.engine.nodes.registry import register
|
||||
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
|
||||
from talemate.character import Character
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene
|
||||
|
||||
TYPE_CHOICES.extend([
|
||||
"director/direction",
|
||||
])
|
||||
TYPE_CHOICES.extend(
|
||||
[
|
||||
"director/direction",
|
||||
]
|
||||
)
|
||||
|
||||
log = structlog.get_logger("talemate.game.engine.nodes.agents.director")
|
||||
|
||||
|
||||
@register("agents/director/Settings")
|
||||
class DirectorSettings(AgentSettingsNode):
|
||||
"""
|
||||
Base node to render director agent settings.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "director"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
def __init__(self, title="Director Settings", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
|
||||
@register("agents/director/PersistCharacter")
|
||||
class PersistCharacter(AgentNode):
|
||||
@@ -31,45 +36,131 @@ class PersistCharacter(AgentNode):
|
||||
Persists a character that currently only exists as part of the given context
|
||||
as a real character that can actively participate in the scene.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "director"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
class Fields:
|
||||
determine_name = PropertyField(
|
||||
name="determine_name",
|
||||
type="bool",
|
||||
description="Whether to determine the name of the character",
|
||||
default=True
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Persist Character", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("character_name", socket_type="str")
|
||||
self.add_input("context", socket_type="str", optional=True)
|
||||
self.add_input("attributes", socket_type="dict,str", optional=True)
|
||||
|
||||
|
||||
self.set_property("determine_name", True)
|
||||
|
||||
|
||||
self.add_output("state")
|
||||
self.add_output("character", socket_type="character")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
|
||||
character_name = self.get_input_value("character_name")
|
||||
context = self.normalized_input_value("context")
|
||||
attributes = self.normalized_input_value("attributes")
|
||||
determine_name = self.normalized_input_value("determine_name")
|
||||
|
||||
|
||||
character = await self.agent.persist_character(
|
||||
name=character_name,
|
||||
content=context,
|
||||
attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()]) if attributes else None,
|
||||
determine_name=determine_name
|
||||
attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()])
|
||||
if attributes
|
||||
else None,
|
||||
determine_name=determine_name,
|
||||
)
|
||||
|
||||
self.set_output_values({
|
||||
"state": state,
|
||||
"character": character
|
||||
})
|
||||
|
||||
self.set_output_values({"state": state, "character": character})
|
||||
|
||||
|
||||
@register("agents/director/AssignVoice")
|
||||
class AssignVoice(AgentNode):
|
||||
"""
|
||||
Assigns a voice to a character.
|
||||
"""
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
def __init__(self, title="Assign Voice", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("character", socket_type="character")
|
||||
|
||||
self.add_output("state")
|
||||
self.add_output("character", socket_type="character")
|
||||
self.add_output("voice", socket_type="tts/voice")
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
character: "Character" = self.require_input("character")
|
||||
|
||||
await self.agent.assign_voice_to_character(character)
|
||||
|
||||
voice = character.voice
|
||||
|
||||
self.set_output_values({"state": state, "character": character, "voice": voice})
|
||||
|
||||
|
||||
@register("agents/director/LogAction")
|
||||
class LogAction(AgentNode):
|
||||
"""
|
||||
Logs an action to the console.
|
||||
"""
|
||||
|
||||
_agent_name: ClassVar[str] = "director"
|
||||
|
||||
class Fields:
|
||||
action = PropertyField(
|
||||
name="action",
|
||||
type="str",
|
||||
description="The action to log",
|
||||
default="",
|
||||
)
|
||||
action_description = PropertyField(
|
||||
name="action_description",
|
||||
type="str",
|
||||
description="The description of the action",
|
||||
default="",
|
||||
)
|
||||
console_only = PropertyField(
|
||||
name="console_only",
|
||||
type="bool",
|
||||
description="Whether to log the action to the console only",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def __init__(self, title="Log Director Action", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("action", socket_type="str")
|
||||
self.add_input("action_description", socket_type="str")
|
||||
self.add_input("console_only", socket_type="bool", optional=True)
|
||||
|
||||
self.set_property("action", "")
|
||||
self.set_property("action_description", "")
|
||||
self.set_property("console_only", False)
|
||||
|
||||
self.add_output("state")
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
state = self.require_input("state")
|
||||
action = self.require_input("action")
|
||||
action_description = self.require_input("action_description")
|
||||
console_only = self.normalized_input_value("console_only") or False
|
||||
|
||||
await self.agent.log_action(
|
||||
action=action,
|
||||
action_description=action_description,
|
||||
console_only=console_only,
|
||||
)
|
||||
|
||||
self.set_output_values({"state": state})
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import pydantic
|
||||
import asyncio
|
||||
import structlog
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from talemate.instance import get_agent
|
||||
from talemate.server.websocket_plugin import Plugin
|
||||
from talemate.context import interaction
|
||||
from talemate.context import interaction, handle_generation_cancelled
|
||||
from talemate.status import set_loading
|
||||
from talemate.exceptions import GenerationCancelled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene
|
||||
@@ -18,42 +18,52 @@ __all__ = [
|
||||
|
||||
log = structlog.get_logger("talemate.server.director")
|
||||
|
||||
|
||||
class InstructionPayload(pydantic.BaseModel):
|
||||
instructions:str = ""
|
||||
instructions: str = ""
|
||||
|
||||
|
||||
class SelectChoicePayload(pydantic.BaseModel):
|
||||
choice: str
|
||||
character:str = ""
|
||||
character: str = ""
|
||||
|
||||
|
||||
class CharacterPayload(InstructionPayload):
|
||||
character:str = ""
|
||||
character: str = ""
|
||||
|
||||
|
||||
class PersistCharacterPayload(pydantic.BaseModel):
|
||||
name: str
|
||||
templates: list[str] | None = None
|
||||
narrate_entry: bool = True
|
||||
narrate_entry_direction: str = ""
|
||||
|
||||
|
||||
active: bool = True
|
||||
determine_name: bool = True
|
||||
augment_attributes: str = ""
|
||||
generate_attributes: bool = True
|
||||
|
||||
|
||||
content: str = ""
|
||||
description: str = ""
|
||||
|
||||
is_player: bool = False
|
||||
|
||||
|
||||
class AssignVoiceToCharacterPayload(pydantic.BaseModel):
|
||||
character_name: str
|
||||
|
||||
|
||||
class DirectorWebsocketHandler(Plugin):
|
||||
"""
|
||||
Handles director actions
|
||||
"""
|
||||
|
||||
|
||||
router = "director"
|
||||
|
||||
|
||||
@property
|
||||
def director(self):
|
||||
return get_agent("director")
|
||||
|
||||
|
||||
@set_loading("Generating dynamic actions", cancellable=True, as_async=True)
|
||||
async def handle_request_dynamic_choices(self, data: dict):
|
||||
"""
|
||||
@@ -61,21 +71,21 @@ class DirectorWebsocketHandler(Plugin):
|
||||
"""
|
||||
payload = CharacterPayload(**data)
|
||||
await self.director.generate_choices(**payload.model_dump())
|
||||
|
||||
|
||||
async def handle_select_choice(self, data: dict):
|
||||
payload = SelectChoicePayload(**data)
|
||||
|
||||
|
||||
log.debug("selecting choice", payload=payload)
|
||||
|
||||
|
||||
if payload.character:
|
||||
character = self.scene.get_character(payload.character)
|
||||
else:
|
||||
character = self.scene.get_player_character()
|
||||
|
||||
|
||||
if not character:
|
||||
log.error("handle_select_choice: could not find character", payload=payload)
|
||||
return
|
||||
|
||||
|
||||
# hijack the interaction state
|
||||
try:
|
||||
interaction_state = interaction.get()
|
||||
@@ -83,24 +93,33 @@ class DirectorWebsocketHandler(Plugin):
|
||||
# no interaction state
|
||||
log.error("handle_select_choice: no interaction state", payload=payload)
|
||||
return
|
||||
|
||||
|
||||
interaction_state.from_choice = payload.choice
|
||||
interaction_state.act_as = character.name if not character.is_player else None
|
||||
interaction_state.input = f"@{payload.choice}"
|
||||
|
||||
|
||||
async def handle_persist_character(self, data: dict):
|
||||
payload = PersistCharacterPayload(**data)
|
||||
scene: "Scene" = self.scene
|
||||
|
||||
|
||||
if not payload.content:
|
||||
payload.content = scene.snapshot(lines=15)
|
||||
|
||||
|
||||
# add as asyncio task
|
||||
task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
|
||||
task = asyncio.create_task(
|
||||
self.director.persist_character(**payload.model_dump())
|
||||
)
|
||||
|
||||
async def handle_task_done(task):
|
||||
if task.exception():
|
||||
log.error("Error persisting character", error=task.exception())
|
||||
await self.signal_operation_failed("Error persisting character")
|
||||
exc = task.exception()
|
||||
log.error("Error persisting character", error=exc)
|
||||
|
||||
# Handle GenerationCancelled properly to reset cancel_requested flag
|
||||
if isinstance(exc, GenerationCancelled):
|
||||
handle_generation_cancelled(exc)
|
||||
|
||||
await self.signal_operation_failed(f"Error persisting character: {exc}")
|
||||
else:
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
@@ -113,3 +132,62 @@ class DirectorWebsocketHandler(Plugin):
|
||||
|
||||
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
|
||||
|
||||
async def handle_assign_voice_to_character(self, data: dict):
|
||||
"""
|
||||
Assign a voice to a character using the director agent
|
||||
"""
|
||||
try:
|
||||
payload = AssignVoiceToCharacterPayload(**data)
|
||||
except pydantic.ValidationError as e:
|
||||
await self.signal_operation_failed(str(e))
|
||||
return
|
||||
|
||||
scene: "Scene" = self.scene
|
||||
if not scene:
|
||||
await self.signal_operation_failed("No scene active")
|
||||
return
|
||||
|
||||
character = scene.get_character(payload.character_name)
|
||||
if not character:
|
||||
await self.signal_operation_failed(
|
||||
f"Character '{payload.character_name}' not found"
|
||||
)
|
||||
return
|
||||
|
||||
character.voice = None
|
||||
|
||||
# Add as asyncio task
|
||||
task = asyncio.create_task(self.director.assign_voice_to_character(character))
|
||||
|
||||
async def handle_task_done(task):
|
||||
if task.exception():
|
||||
exc = task.exception()
|
||||
log.error("Error assigning voice to character", error=exc)
|
||||
|
||||
# Handle GenerationCancelled properly to reset cancel_requested flag
|
||||
if isinstance(exc, GenerationCancelled):
|
||||
handle_generation_cancelled(exc)
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": self.router,
|
||||
"action": "assign_voice_to_character_failed",
|
||||
"character_name": payload.character_name,
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
await self.signal_operation_failed(
|
||||
f"Error assigning voice to character: {exc}"
|
||||
)
|
||||
else:
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": self.router,
|
||||
"action": "assign_voice_to_character_done",
|
||||
"character_name": payload.character_name,
|
||||
}
|
||||
)
|
||||
await self.signal_operation_done()
|
||||
self.scene.emit_status()
|
||||
|
||||
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
|
||||
|
||||
@@ -6,6 +6,7 @@ import structlog
|
||||
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.client import ClientBase
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, set_processing
|
||||
@@ -40,7 +41,7 @@ class EditorAgent(
|
||||
agent_type = "editor"
|
||||
verbose_name = "Editor"
|
||||
websocket_handler = EditorWebsocketHandler
|
||||
|
||||
|
||||
@classmethod
|
||||
def init_actions(cls) -> dict[str, AgentAction]:
|
||||
actions = {
|
||||
@@ -56,9 +57,9 @@ class EditorAgent(
|
||||
description="The formatting to use for exposition.",
|
||||
value="novel",
|
||||
choices=[
|
||||
{"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
|
||||
{"label": "Novel: \"Speech\" narration", "value": "novel"},
|
||||
]
|
||||
{"label": 'Chat RP: "Speech" *narration*', "value": "chat"},
|
||||
{"label": 'Novel: "Speech" narration', "value": "novel"},
|
||||
],
|
||||
),
|
||||
"narrator": AgentActionConfig(
|
||||
type="bool",
|
||||
@@ -81,16 +82,16 @@ class EditorAgent(
|
||||
description="Attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
MemoryRAGMixin.add_actions(actions)
|
||||
RevisionMixin.add_actions(actions)
|
||||
return actions
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
def __init__(self, client: ClientBase | None = None, **kwargs):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
self.actions = EditorAgent.init_actions()
|
||||
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
@@ -102,11 +103,11 @@ class EditorAgent(
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def fix_exposition_enabled(self):
|
||||
return self.actions["fix_exposition"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def fix_exposition_formatting(self):
|
||||
return self.actions["fix_exposition"].config["formatting"].value
|
||||
@@ -114,11 +115,10 @@ class EditorAgent(
|
||||
@property
|
||||
def fix_exposition_narrator(self):
|
||||
return self.actions["fix_exposition"].config["narrator"].value
|
||||
|
||||
|
||||
@property
|
||||
def fix_exposition_user_input(self):
|
||||
return self.actions["fix_exposition"].config["user_input"].value
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
@@ -134,7 +134,7 @@ class EditorAgent(
|
||||
formatting = "md"
|
||||
else:
|
||||
formatting = None
|
||||
|
||||
|
||||
if self.fix_exposition_formatting == "chat":
|
||||
text = text.replace("**", "*")
|
||||
text = text.replace("[", "*").replace("]", "*")
|
||||
@@ -143,15 +143,14 @@ class EditorAgent(
|
||||
text = text.replace("*", "")
|
||||
text = text.replace("[", "").replace("]", "")
|
||||
text = text.replace("(", "").replace(")", "")
|
||||
|
||||
cleaned = util.ensure_dialog_format(
|
||||
text,
|
||||
talking_character=character.name if character else None,
|
||||
formatting=formatting
|
||||
)
|
||||
|
||||
return cleaned
|
||||
|
||||
cleaned = util.ensure_dialog_format(
|
||||
text,
|
||||
talking_character=character.name if character else None,
|
||||
formatting=formatting,
|
||||
)
|
||||
|
||||
return cleaned
|
||||
|
||||
async def on_conversation_generated(self, emission: ConversationAgentEmission):
|
||||
"""
|
||||
@@ -181,11 +180,13 @@ class EditorAgent(
|
||||
emission.response = edit
|
||||
|
||||
@set_processing
|
||||
async def cleanup_character_message(self, content: str, character: Character, force: bool = False):
|
||||
async def cleanup_character_message(
|
||||
self, content: str, character: Character, force: bool = False
|
||||
):
|
||||
"""
|
||||
Edits a text to make sure all narrative exposition and emotes is encased in *
|
||||
"""
|
||||
|
||||
|
||||
# if not content was generated, return it as is
|
||||
if not content:
|
||||
return content
|
||||
@@ -210,14 +211,14 @@ class EditorAgent(
|
||||
|
||||
content = util.clean_dialogue(content, main_name=character.name)
|
||||
content = util.strip_partial_sentences(content)
|
||||
|
||||
|
||||
# if there are uneven quotation marks, fix them by adding a closing quote
|
||||
if '"' in content and content.count('"') % 2 != 0:
|
||||
content += '"'
|
||||
|
||||
|
||||
if not self.fix_exposition_enabled and not exposition_fixed:
|
||||
return content
|
||||
|
||||
|
||||
content = self.fix_exposition_in_text(content, character)
|
||||
|
||||
return content
|
||||
@@ -225,7 +226,7 @@ class EditorAgent(
|
||||
@set_processing
|
||||
async def clean_up_narration(self, content: str, force: bool = False):
|
||||
content = util.strip_partial_sentences(content)
|
||||
if (self.fix_exposition_enabled and self.fix_exposition_narrator or force):
|
||||
if self.fix_exposition_enabled and self.fix_exposition_narrator or force:
|
||||
content = self.fix_exposition_in_text(content, None)
|
||||
if self.fix_exposition_formatting == "chat":
|
||||
if '"' not in content and "*" not in content:
|
||||
@@ -234,25 +235,27 @@ class EditorAgent(
|
||||
return content
|
||||
|
||||
@set_processing
|
||||
async def cleanup_user_input(self, text: str, as_narration: bool = False, force: bool = False):
|
||||
|
||||
async def cleanup_user_input(
|
||||
self, text: str, as_narration: bool = False, force: bool = False
|
||||
):
|
||||
# special prefix characters - when found, never edit
|
||||
PREFIX_CHARACTERS = ("!", "@", "/")
|
||||
if text.startswith(PREFIX_CHARACTERS):
|
||||
return text
|
||||
|
||||
if (not self.fix_exposition_user_input or not self.fix_exposition_enabled) and not force:
|
||||
|
||||
if (
|
||||
not self.fix_exposition_user_input or not self.fix_exposition_enabled
|
||||
) and not force:
|
||||
return text
|
||||
|
||||
|
||||
if not as_narration:
|
||||
if self.fix_exposition_formatting == "chat":
|
||||
if '"' not in text and "*" not in text:
|
||||
text = f'"{text}"'
|
||||
else:
|
||||
return await self.clean_up_narration(text)
|
||||
|
||||
|
||||
return self.fix_exposition_in_text(text)
|
||||
|
||||
|
||||
@set_processing
|
||||
async def add_detail(self, content: str, character: Character):
|
||||
@@ -279,4 +282,4 @@ class EditorAgent(
|
||||
response = util.clean_dialogue(response, main_name=character.name)
|
||||
response = util.strip_partial_sentences(response)
|
||||
|
||||
return response
|
||||
return response
|
||||
|
||||
@@ -1,70 +1,37 @@
|
||||
import structlog
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
from talemate.context import active_scene
|
||||
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
|
||||
from talemate.game.engine.nodes.core import (
|
||||
GraphState,
|
||||
PropertyField,
|
||||
)
|
||||
from talemate.game.engine.nodes.registry import register
|
||||
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene
|
||||
from talemate.agents.editor import EditorAgent
|
||||
|
||||
log = structlog.get_logger("talemate.game.engine.nodes.agents.editor")
|
||||
|
||||
|
||||
@register("agents/editor/Settings")
|
||||
class EditorSettings(AgentSettingsNode):
|
||||
"""
|
||||
Base node to render editor agent settings.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "editor"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "editor"
|
||||
|
||||
def __init__(self, title="Editor Settings", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
|
||||
@register("agents/editor/CleanUpUserInput")
|
||||
class CleanUpUserInput(AgentNode):
|
||||
"""
|
||||
Cleans up user input.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "editor"
|
||||
|
||||
class Fields:
|
||||
force = PropertyField(
|
||||
name="force",
|
||||
description="Force clean up",
|
||||
type="bool",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Clean Up User Input", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
def setup(self):
|
||||
self.add_input("user_input", socket_type="str")
|
||||
self.add_input("as_narration", socket_type="bool")
|
||||
|
||||
self.set_property("force", False)
|
||||
|
||||
self.add_output("cleaned_user_input", socket_type="str")
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
editor:"EditorAgent" = self.agent
|
||||
user_input = self.get_input_value("user_input")
|
||||
force = self.get_property("force")
|
||||
as_narration = self.get_input_value("as_narration")
|
||||
cleaned_user_input = await editor.cleanup_user_input(user_input, as_narration=as_narration, force=force)
|
||||
self.set_output_values({
|
||||
"cleaned_user_input": cleaned_user_input,
|
||||
})
|
||||
|
||||
@register("agents/editor/CleanUpNarration")
|
||||
class CleanUpNarration(AgentNode):
|
||||
"""
|
||||
Cleans up narration.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "editor"
|
||||
|
||||
_agent_name: ClassVar[str] = "editor"
|
||||
|
||||
class Fields:
|
||||
force = PropertyField(
|
||||
@@ -73,31 +40,41 @@ class CleanUpNarration(AgentNode):
|
||||
type="bool",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def __init__(self, title="Clean Up Narration", **kwargs):
|
||||
|
||||
def __init__(self, title="Clean Up User Input", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("narration", socket_type="str")
|
||||
self.add_input("user_input", socket_type="str")
|
||||
self.add_input("as_narration", socket_type="bool")
|
||||
|
||||
self.set_property("force", False)
|
||||
self.add_output("cleaned_narration", socket_type="str")
|
||||
|
||||
|
||||
self.add_output("cleaned_user_input", socket_type="str")
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
editor:"EditorAgent" = self.agent
|
||||
narration = self.get_input_value("narration")
|
||||
editor: "EditorAgent" = self.agent
|
||||
user_input = self.get_input_value("user_input")
|
||||
force = self.get_property("force")
|
||||
cleaned_narration = await editor.cleanup_narration(narration, force=force)
|
||||
self.set_output_values({
|
||||
"cleaned_narration": cleaned_narration,
|
||||
})
|
||||
|
||||
@register("agents/editor/CleanUoCharacterMessage")
|
||||
class CleanUpCharacterMessage(AgentNode):
|
||||
as_narration = self.get_input_value("as_narration")
|
||||
cleaned_user_input = await editor.cleanup_user_input(
|
||||
user_input, as_narration=as_narration, force=force
|
||||
)
|
||||
self.set_output_values(
|
||||
{
|
||||
"cleaned_user_input": cleaned_user_input,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@register("agents/editor/CleanUpNarration")
|
||||
class CleanUpNarration(AgentNode):
|
||||
"""
|
||||
Cleans up character message.
|
||||
Cleans up narration.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "editor"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "editor"
|
||||
|
||||
class Fields:
|
||||
force = PropertyField(
|
||||
name="force",
|
||||
@@ -105,22 +82,62 @@ class CleanUpCharacterMessage(AgentNode):
|
||||
type="bool",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Clean Up Narration", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
def setup(self):
|
||||
self.add_input("narration", socket_type="str")
|
||||
self.set_property("force", False)
|
||||
self.add_output("cleaned_narration", socket_type="str")
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
editor: "EditorAgent" = self.agent
|
||||
narration = self.get_input_value("narration")
|
||||
force = self.get_property("force")
|
||||
cleaned_narration = await editor.cleanup_narration(narration, force=force)
|
||||
self.set_output_values(
|
||||
{
|
||||
"cleaned_narration": cleaned_narration,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@register("agents/editor/CleanUoCharacterMessage")
|
||||
class CleanUpCharacterMessage(AgentNode):
|
||||
"""
|
||||
Cleans up character message.
|
||||
"""
|
||||
|
||||
_agent_name: ClassVar[str] = "editor"
|
||||
|
||||
class Fields:
|
||||
force = PropertyField(
|
||||
name="force",
|
||||
description="Force clean up",
|
||||
type="bool",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def __init__(self, title="Clean Up Character Message", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("text", socket_type="str")
|
||||
self.add_input("character", socket_type="character")
|
||||
self.set_property("force", False)
|
||||
self.add_output("cleaned_character_message", socket_type="str")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
editor:"EditorAgent" = self.agent
|
||||
editor: "EditorAgent" = self.agent
|
||||
text = self.get_input_value("text")
|
||||
force = self.get_property("force")
|
||||
character = self.get_input_value("character")
|
||||
cleaned_character_message = await editor.cleanup_character_message(text, character, force=force)
|
||||
self.set_output_values({
|
||||
"cleaned_character_message": cleaned_character_message,
|
||||
})
|
||||
cleaned_character_message = await editor.cleanup_character_message(
|
||||
text, character, force=force
|
||||
)
|
||||
self.set_output_values(
|
||||
{
|
||||
"cleaned_character_message": cleaned_character_message,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from talemate.instance import get_agent
|
||||
from talemate.server.websocket_plugin import Plugin
|
||||
from talemate.status import set_loading
|
||||
from talemate.scene_message import CharacterMessage
|
||||
from talemate.agents.editor.revision import RevisionContext, RevisionInformation
|
||||
|
||||
@@ -17,47 +16,52 @@ __all__ = [
|
||||
|
||||
log = structlog.get_logger("talemate.server.editor")
|
||||
|
||||
|
||||
class RevisionPayload(pydantic.BaseModel):
|
||||
message_id: int
|
||||
|
||||
|
||||
class EditorWebsocketHandler(Plugin):
|
||||
"""
|
||||
Handles editor actions
|
||||
"""
|
||||
|
||||
|
||||
router = "editor"
|
||||
|
||||
|
||||
@property
|
||||
def editor(self):
|
||||
return get_agent("editor")
|
||||
|
||||
|
||||
async def handle_request_revision(self, data: dict):
|
||||
"""
|
||||
Generate clickable actions for the user
|
||||
"""
|
||||
|
||||
|
||||
editor = self.editor
|
||||
scene:"Scene" = self.scene
|
||||
|
||||
scene: "Scene" = self.scene
|
||||
|
||||
if not editor.revision_enabled:
|
||||
raise Exception("Revision is not enabled")
|
||||
|
||||
|
||||
payload = RevisionPayload(**data)
|
||||
message = scene.get_message(payload.message_id)
|
||||
|
||||
|
||||
character = None
|
||||
|
||||
|
||||
if isinstance(message, CharacterMessage):
|
||||
character = scene.get_character(message.character_name)
|
||||
|
||||
|
||||
if not message:
|
||||
raise Exception("Message not found")
|
||||
|
||||
|
||||
with RevisionContext(message.id):
|
||||
info = RevisionInformation(
|
||||
text=message.message,
|
||||
character=character,
|
||||
)
|
||||
revised = await editor.revision_revise(info)
|
||||
|
||||
scene.edit_message(message.id, revised)
|
||||
if isinstance(message, CharacterMessage):
|
||||
if not revised.startswith(character.name + ":"):
|
||||
revised = f"{character.name}: {revised}"
|
||||
|
||||
scene.edit_message(message.id, revised)
|
||||
|
||||
@@ -23,10 +23,9 @@ from talemate.agents.base import (
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
from talemate.config import load_config
|
||||
from talemate.config.schema import EmbeddingFunctionPreset
|
||||
from talemate.context import scene_is_loading, active_scene
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
import talemate.emit.async_signals as async_signals
|
||||
from talemate.agents.memory.context import memory_request, MemoryRequest
|
||||
from talemate.agents.memory.exceptions import (
|
||||
@@ -75,10 +74,9 @@ class MemoryAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def init_actions(cls, presets: list[dict] | None = None) -> dict[str, AgentAction]:
|
||||
|
||||
if presets is None:
|
||||
presets = []
|
||||
|
||||
|
||||
actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
@@ -101,23 +99,24 @@ class MemoryAgent(Agent):
|
||||
choices=[
|
||||
{"value": "cpu", "label": "CPU"},
|
||||
{"value": "cuda", "label": "CUDA"},
|
||||
]
|
||||
],
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
return actions
|
||||
|
||||
def __init__(self, scene, **kwargs):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.db = None
|
||||
self.scene = scene
|
||||
self.memory_tracker = {}
|
||||
self.config = load_config()
|
||||
self._ready_to_add = False
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)
|
||||
|
||||
|
||||
async_signals.get("config.changed").connect(self.on_config_changed)
|
||||
|
||||
async_signals.get("client.embeddings_available").connect(
|
||||
self.on_client_embeddings_available
|
||||
)
|
||||
|
||||
self.actions = MemoryAgent.init_actions(presets=self.get_presets)
|
||||
|
||||
@property
|
||||
@@ -132,30 +131,33 @@ class MemoryAgent(Agent):
|
||||
@property
|
||||
def db_name(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@property
|
||||
def get_presets(self):
|
||||
|
||||
def _label(embedding:dict):
|
||||
prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
|
||||
if embedding['model']:
|
||||
return f"{prefix}: {embedding['model']}"
|
||||
def _label(embedding: EmbeddingFunctionPreset):
|
||||
prefix = embedding.client if embedding.client else embedding.embeddings
|
||||
if embedding.model:
|
||||
return f"{prefix}: {embedding.model}"
|
||||
else:
|
||||
return f"{prefix}"
|
||||
|
||||
|
||||
return [
|
||||
{"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
|
||||
{"value": k, "label": _label(v)}
|
||||
for k, v in self.config.presets.embeddings.items()
|
||||
]
|
||||
|
||||
|
||||
@property
|
||||
def embeddings_config(self):
|
||||
_embeddings = self.actions["_config"].config["embeddings"].value
|
||||
return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {})
|
||||
|
||||
return self.config.presets.embeddings.get(_embeddings)
|
||||
|
||||
@property
|
||||
def embeddings(self):
|
||||
return self.embeddings_config.get("embeddings", "sentence-transformer")
|
||||
|
||||
try:
|
||||
return self.embeddings_config.embeddings
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def using_openai_embeddings(self):
|
||||
return self.embeddings == "openai"
|
||||
@@ -163,42 +165,46 @@ class MemoryAgent(Agent):
|
||||
@property
|
||||
def using_instructor_embeddings(self):
|
||||
return self.embeddings == "instructor"
|
||||
|
||||
|
||||
@property
|
||||
def using_sentence_transformer_embeddings(self):
|
||||
return self.embeddings == "default" or self.embeddings == "sentence-transformer"
|
||||
|
||||
|
||||
@property
|
||||
def using_client_api_embeddings(self):
|
||||
return self.embeddings == "client-api"
|
||||
|
||||
|
||||
@property
|
||||
def using_local_embeddings(self):
|
||||
return self.embeddings in [
|
||||
"instructor",
|
||||
"sentence-transformer",
|
||||
"default"
|
||||
]
|
||||
|
||||
|
||||
return self.embeddings in ["instructor", "sentence-transformer", "default"]
|
||||
|
||||
@property
|
||||
def embeddings_client(self):
|
||||
return self.embeddings_config.get("client")
|
||||
|
||||
try:
|
||||
return self.embeddings_config.client
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def max_distance(self) -> float:
|
||||
distance = float(self.embeddings_config.get("distance", 1.0))
|
||||
distance_mod = float(self.embeddings_config.get("distance_mod", 1.0))
|
||||
|
||||
distance = float(self.embeddings_config.distance)
|
||||
distance_mod = float(self.embeddings_config.distance_mod)
|
||||
|
||||
return distance * distance_mod
|
||||
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.embeddings_config.get("model")
|
||||
|
||||
try:
|
||||
return self.embeddings_config.model
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def distance_function(self):
|
||||
return self.embeddings_config.get("distance_function", "l2")
|
||||
try:
|
||||
return self.embeddings_config.distance_function
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
@@ -206,32 +212,38 @@ class MemoryAgent(Agent):
|
||||
|
||||
@property
|
||||
def trust_remote_code(self) -> bool:
|
||||
return self.embeddings_config.get("trust_remote_code", False)
|
||||
try:
|
||||
return self.embeddings_config.trust_remote_code
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def fingerprint(self) -> str:
|
||||
"""
|
||||
Returns a unique fingerprint for the current configuration
|
||||
"""
|
||||
|
||||
model_name = self.model.replace('/','-') if self.model else "none"
|
||||
|
||||
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
|
||||
|
||||
model_name = self.model.replace("/", "-") if self.model else "none"
|
||||
|
||||
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
|
||||
_fingerprint = self.fingerprint
|
||||
|
||||
|
||||
await super().apply_config(*args, **kwargs)
|
||||
|
||||
|
||||
fingerprint_changed = _fingerprint != self.fingerprint
|
||||
|
||||
|
||||
# have embeddings or device changed?
|
||||
if fingerprint_changed:
|
||||
log.warning("memory agent", status="embedding function changed", old=_fingerprint, new=self.fingerprint)
|
||||
log.warning(
|
||||
"memory agent",
|
||||
status="embedding function changed",
|
||||
old=_fingerprint,
|
||||
new=self.fingerprint,
|
||||
)
|
||||
await self.handle_embeddings_change()
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def handle_embeddings_change(self):
|
||||
scene = active_scene.get()
|
||||
@@ -239,10 +251,10 @@ class MemoryAgent(Agent):
|
||||
# if sentence-transformer and no model-name, set embeddings to default
|
||||
if self.using_sentence_transformer_embeddings and not self.model:
|
||||
self.actions["_config"].config["embeddings"].value = "default"
|
||||
|
||||
if not scene or not scene.get_helper("memory"):
|
||||
|
||||
if not scene or not scene.active:
|
||||
return
|
||||
|
||||
|
||||
self.close_db(scene)
|
||||
emit("status", "Re-importing context database", status="busy")
|
||||
await scene.commit_to_memory()
|
||||
@@ -254,45 +266,66 @@ class MemoryAgent(Agent):
|
||||
self.actions["_config"].config["embeddings"].choices = self.get_presets
|
||||
return self.actions["_config"].config["embeddings"].choices
|
||||
|
||||
def on_config_saved(self, event):
|
||||
loop = asyncio.get_running_loop()
|
||||
openai_key = self.openai_api_key
|
||||
|
||||
fingerprint = self.fingerprint
|
||||
|
||||
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
|
||||
|
||||
self.config = load_config()
|
||||
new_presets = self.sync_presets()
|
||||
if fingerprint != self.fingerprint:
|
||||
log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
|
||||
loop.run_until_complete(self.handle_embeddings_change())
|
||||
|
||||
emit_status = False
|
||||
|
||||
if openai_key != self.openai_api_key:
|
||||
emit_status = True
|
||||
|
||||
if old_presets != new_presets:
|
||||
emit_status = True
|
||||
|
||||
if emit_status:
|
||||
loop.run_until_complete(self.emit_status())
|
||||
|
||||
|
||||
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
|
||||
current_embeddings = self.actions["_config"].config["embeddings"].value
|
||||
|
||||
if current_embeddings == event.client.embeddings_identifier:
|
||||
return
|
||||
|
||||
if not self.using_client_api_embeddings or not self.ready:
|
||||
log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
|
||||
self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
|
||||
async def fix_broken_embeddings(self):
|
||||
if not self.embeddings_config:
|
||||
self.actions["_config"].config["embeddings"].value = "default"
|
||||
await self.emit_status()
|
||||
await self.handle_embeddings_change()
|
||||
await self.save_config()
|
||||
|
||||
async def on_config_changed(self, event):
|
||||
loop = asyncio.get_running_loop()
|
||||
openai_key = self.openai_api_key
|
||||
|
||||
fingerprint = self.fingerprint
|
||||
|
||||
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
|
||||
|
||||
new_presets = self.sync_presets()
|
||||
if fingerprint != self.fingerprint:
|
||||
log.warning(
|
||||
"memory agent",
|
||||
status="embedding function changed",
|
||||
old=fingerprint,
|
||||
new=self.fingerprint,
|
||||
)
|
||||
loop.run_until_complete(self.handle_embeddings_change())
|
||||
|
||||
emit_status = False
|
||||
|
||||
if openai_key != self.openai_api_key:
|
||||
emit_status = True
|
||||
|
||||
if old_presets != new_presets:
|
||||
emit_status = True
|
||||
|
||||
if emit_status:
|
||||
loop.run_until_complete(self.emit_status())
|
||||
|
||||
await self.fix_broken_embeddings()
|
||||
|
||||
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
|
||||
current_embeddings = self.actions["_config"].config["embeddings"].value
|
||||
|
||||
if current_embeddings == event.client.embeddings_identifier:
|
||||
event.seen = True
|
||||
return
|
||||
|
||||
if not self.using_client_api_embeddings or not self.ready:
|
||||
log.warning(
|
||||
"memory agent - client embeddings available",
|
||||
status="changing embeddings",
|
||||
old=current_embeddings,
|
||||
new=event.client.embeddings_identifier,
|
||||
)
|
||||
self.actions["_config"].config[
|
||||
"embeddings"
|
||||
].value = event.client.embeddings_identifier
|
||||
await self.emit_status()
|
||||
await self.handle_embeddings_change()
|
||||
await self.save_config()
|
||||
event.seen = True
|
||||
|
||||
@set_processing
|
||||
async def set_db(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -301,13 +334,16 @@ class MemoryAgent(Agent):
|
||||
except EmbeddingsModelLoadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error("memory agent", error="failed to set db", details=traceback.format_exc())
|
||||
|
||||
if "torchvision::nms does not exist" in str(e):
|
||||
raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")
|
||||
|
||||
raise SetDBError(str(e))
|
||||
log.error(
|
||||
"memory agent", error="failed to set db", details=traceback.format_exc()
|
||||
)
|
||||
|
||||
if "torchvision::nms does not exist" in str(e):
|
||||
raise SetDBError(
|
||||
"The embeddings you are trying to use require the `torchvision` package to be installed"
|
||||
)
|
||||
|
||||
raise SetDBError(str(e))
|
||||
|
||||
def close_db(self):
|
||||
raise NotImplementedError()
|
||||
@@ -426,9 +462,9 @@ class MemoryAgent(Agent):
|
||||
with MemoryRequest(query=text, query_params=query) as active_memory_request:
|
||||
active_memory_request.max_distance = self.max_distance
|
||||
return await asyncio.to_thread(self._get, text, character, **query)
|
||||
#return await loop.run_in_executor(
|
||||
# return await loop.run_in_executor(
|
||||
# None, functools.partial(self._get, text, character, **query)
|
||||
#)
|
||||
# )
|
||||
|
||||
def _get(self, text, character=None, **query):
|
||||
raise NotImplementedError()
|
||||
@@ -528,9 +564,7 @@ class MemoryAgent(Agent):
|
||||
continue
|
||||
|
||||
# Fetch potential memories for this query.
|
||||
raw_results = await self.get(
|
||||
formatter(query), limit=limit, **where
|
||||
)
|
||||
raw_results = await self.get(formatter(query), limit=limit, **where)
|
||||
|
||||
# Apply filter and respect the `iterate` limit for this query.
|
||||
accepted: list[str] = []
|
||||
@@ -591,9 +625,9 @@ class MemoryAgent(Agent):
|
||||
|
||||
Returns a dictionary with 'cosine_similarity' and 'euclidean_distance'.
|
||||
"""
|
||||
|
||||
|
||||
embed_fn = self.embedding_function
|
||||
|
||||
|
||||
# Embed the two strings
|
||||
vec1 = np.array(embed_fn([string1])[0])
|
||||
vec2 = np.array(embed_fn([string2])[0])
|
||||
@@ -604,17 +638,14 @@ class MemoryAgent(Agent):
|
||||
# Compute Euclidean distance
|
||||
euclidean_dist = np.linalg.norm(vec1 - vec2)
|
||||
|
||||
return {
|
||||
"cosine_similarity": cosine_sim,
|
||||
"euclidean_distance": euclidean_dist
|
||||
}
|
||||
return {"cosine_similarity": cosine_sim, "euclidean_distance": euclidean_dist}
|
||||
|
||||
async def compare_string_lists(
|
||||
self,
|
||||
list_a: list[str],
|
||||
list_b: list[str],
|
||||
similarity_threshold: float = None,
|
||||
distance_threshold: float = None
|
||||
distance_threshold: float = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Compare two lists of strings using the current embedding function without touching the database.
|
||||
@@ -625,15 +656,21 @@ class MemoryAgent(Agent):
|
||||
- 'similarity_matches': list of (i, j, score) (filtered if threshold set, otherwise all)
|
||||
- 'distance_matches': list of (i, j, distance) (filtered if threshold set, otherwise all)
|
||||
"""
|
||||
if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
|
||||
raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
|
||||
if (
|
||||
not self.db
|
||||
or not hasattr(self.db, "_embedding_function")
|
||||
or self.db._embedding_function is None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Embedding function is not initialized. Make sure the database is set."
|
||||
)
|
||||
|
||||
if not list_a or not list_b:
|
||||
return {
|
||||
"cosine_similarity_matrix": np.array([]),
|
||||
"euclidean_distance_matrix": np.array([]),
|
||||
"similarity_matches": [],
|
||||
"distance_matches": []
|
||||
"distance_matches": [],
|
||||
}
|
||||
|
||||
embed_fn = self.db._embedding_function
|
||||
@@ -653,21 +690,29 @@ class MemoryAgent(Agent):
|
||||
cosine_similarity_matrix = np.dot(vecs_a_norm, vecs_b_norm.T)
|
||||
|
||||
# Euclidean distance matrix
|
||||
a_squared = np.sum(vecs_a ** 2, axis=1).reshape(-1, 1)
|
||||
b_squared = np.sum(vecs_b ** 2, axis=1).reshape(1, -1)
|
||||
euclidean_distance_matrix = np.sqrt(a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T))
|
||||
a_squared = np.sum(vecs_a**2, axis=1).reshape(-1, 1)
|
||||
b_squared = np.sum(vecs_b**2, axis=1).reshape(1, -1)
|
||||
euclidean_distance_matrix = np.sqrt(
|
||||
a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T)
|
||||
)
|
||||
|
||||
# Prepare matches
|
||||
similarity_matches = []
|
||||
distance_matches = []
|
||||
|
||||
# Populate similarity matches
|
||||
sim_indices = np.argwhere(cosine_similarity_matrix >= (similarity_threshold if similarity_threshold is not None else -np.inf))
|
||||
sim_indices = np.argwhere(
|
||||
cosine_similarity_matrix
|
||||
>= (similarity_threshold if similarity_threshold is not None else -np.inf)
|
||||
)
|
||||
for i, j in sim_indices:
|
||||
similarity_matches.append((i, j, cosine_similarity_matrix[i, j]))
|
||||
|
||||
# Populate distance matches
|
||||
dist_indices = np.argwhere(euclidean_distance_matrix <= (distance_threshold if distance_threshold is not None else np.inf))
|
||||
dist_indices = np.argwhere(
|
||||
euclidean_distance_matrix
|
||||
<= (distance_threshold if distance_threshold is not None else np.inf)
|
||||
)
|
||||
for i, j in dist_indices:
|
||||
distance_matches.append((i, j, euclidean_distance_matrix[i, j]))
|
||||
|
||||
@@ -675,9 +720,10 @@ class MemoryAgent(Agent):
|
||||
"cosine_similarity_matrix": cosine_similarity_matrix,
|
||||
"euclidean_distance_matrix": euclidean_distance_matrix,
|
||||
"similarity_matches": similarity_matches,
|
||||
"distance_matches": distance_matches
|
||||
"distance_matches": distance_matches,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@register(condition=lambda: chromadb is not None)
|
||||
class ChromaDBMemoryAgent(MemoryAgent):
|
||||
requires_llm_client = False
|
||||
@@ -690,32 +736,34 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
if getattr(self, "db_client", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def client_api_ready(self) -> bool:
|
||||
if self.using_client_api_embeddings:
|
||||
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
|
||||
embeddings_client: ClientBase | None = instance.get_client(
|
||||
self.embeddings_client
|
||||
)
|
||||
if not embeddings_client:
|
||||
return False
|
||||
|
||||
|
||||
if not embeddings_client.supports_embeddings:
|
||||
return False
|
||||
|
||||
|
||||
if not embeddings_client.embeddings_status:
|
||||
return False
|
||||
|
||||
|
||||
if embeddings_client.current_status not in ["idle", "busy"]:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.using_client_api_embeddings and not self.client_api_ready:
|
||||
return "error"
|
||||
|
||||
|
||||
if self.ready:
|
||||
return "active" if not getattr(self, "processing", False) else "busy"
|
||||
|
||||
@@ -726,7 +774,6 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
|
||||
details = {
|
||||
"backend": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
@@ -738,23 +785,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
value=self.embeddings,
|
||||
description="The embeddings type.",
|
||||
).model_dump(),
|
||||
|
||||
}
|
||||
|
||||
|
||||
if self.model:
|
||||
details["model"] = AgentDetail(
|
||||
icon="mdi-brain",
|
||||
value=self.model,
|
||||
description="The embeddings model.",
|
||||
).model_dump()
|
||||
|
||||
|
||||
if self.embeddings_client:
|
||||
details["client"] = AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
value=self.embeddings_client,
|
||||
description="The client to use for embeddings.",
|
||||
).model_dump()
|
||||
|
||||
|
||||
if self.using_local_embeddings:
|
||||
details["device"] = AgentDetail(
|
||||
icon="mdi-memory",
|
||||
@@ -770,9 +816,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
|
||||
"color": "error",
|
||||
}
|
||||
|
||||
|
||||
if self.using_client_api_embeddings:
|
||||
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
|
||||
embeddings_client: ClientBase | None = instance.get_client(
|
||||
self.embeddings_client
|
||||
)
|
||||
|
||||
if not embeddings_client:
|
||||
details["error"] = {
|
||||
@@ -784,7 +832,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
return details
|
||||
|
||||
client_name = embeddings_client.name
|
||||
|
||||
|
||||
if not embeddings_client.supports_embeddings:
|
||||
error_message = f"{client_name} does not support embeddings"
|
||||
elif embeddings_client.current_status not in ["idle", "busy"]:
|
||||
@@ -793,7 +841,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
error_message = f"{client_name} has no embeddings model loaded"
|
||||
else:
|
||||
error_message = None
|
||||
|
||||
|
||||
if error_message:
|
||||
details["error"] = {
|
||||
"icon": "mdi-alert",
|
||||
@@ -810,12 +858,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
return self.config.openai.api_key
|
||||
|
||||
@property
|
||||
def embedding_function(self) -> Callable:
|
||||
if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
|
||||
raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
|
||||
if (
|
||||
not self.db
|
||||
or not hasattr(self.db, "_embedding_function")
|
||||
or self.db._embedding_function is None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Embedding function is not initialized. Make sure the database is set."
|
||||
)
|
||||
|
||||
embed_fn = self.db._embedding_function
|
||||
return embed_fn
|
||||
@@ -823,18 +877,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
def make_collection_name(self, scene) -> str:
|
||||
# generate plain text collection name
|
||||
collection_name = f"{self.fingerprint}"
|
||||
|
||||
|
||||
# chromadb collection names have the following rules:
|
||||
# Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address
|
||||
|
||||
# Step 1: Hash the input string using MD5
|
||||
md5_hash = hashlib.md5(collection_name.encode()).hexdigest()
|
||||
|
||||
|
||||
# Step 2: Ensure the result is exactly 32 characters long
|
||||
hashed_collection_name = md5_hash[:32]
|
||||
|
||||
|
||||
return f"{scene.memory_id}-tm-{hashed_collection_name}"
|
||||
|
||||
|
||||
async def count(self):
|
||||
await asyncio.sleep(0)
|
||||
return self.db.count()
|
||||
@@ -853,14 +907,17 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
self.collection_name = collection_name = self.make_collection_name(self.scene)
|
||||
|
||||
log.info(
|
||||
"chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
|
||||
"chromadb agent",
|
||||
status="setting up db",
|
||||
collection_name=collection_name,
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
|
||||
|
||||
distance_function = self.distance_function
|
||||
collection_metadata = {"hnsw:space": distance_function}
|
||||
device = self.actions["_config"].config["device"].value
|
||||
device = self.actions["_config"].config["device"].value
|
||||
model_name = self.model
|
||||
|
||||
|
||||
if self.using_openai_embeddings:
|
||||
if not openai_key:
|
||||
raise ValueError(
|
||||
@@ -878,7 +935,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
model_name=model_name,
|
||||
)
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=openai_ef, metadata=collection_metadata
|
||||
collection_name,
|
||||
embedding_function=openai_ef,
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
elif self.using_client_api_embeddings:
|
||||
log.info(
|
||||
@@ -886,20 +945,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
embeddings="Client API",
|
||||
client=self.embeddings_client,
|
||||
)
|
||||
|
||||
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
|
||||
|
||||
embeddings_client: ClientBase | None = instance.get_client(
|
||||
self.embeddings_client
|
||||
)
|
||||
if not embeddings_client:
|
||||
raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")
|
||||
|
||||
raise ValueError(
|
||||
f"Client API embeddings client {self.embeddings_client} not found"
|
||||
)
|
||||
|
||||
if not embeddings_client.supports_embeddings:
|
||||
raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")
|
||||
|
||||
raise ValueError(
|
||||
f"Client API embeddings client {self.embeddings_client} does not support embeddings"
|
||||
)
|
||||
|
||||
ef = embeddings_client.embeddings_function
|
||||
|
||||
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=ef, metadata=collection_metadata
|
||||
)
|
||||
|
||||
|
||||
elif self.using_instructor_embeddings:
|
||||
log.info(
|
||||
"chromadb",
|
||||
@@ -909,7 +974,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
)
|
||||
|
||||
ef = embedding_functions.InstructorEmbeddingFunction(
|
||||
model_name=model_name, device=device, instruction="Represent the document for retrieval:"
|
||||
model_name=model_name,
|
||||
device=device,
|
||||
instruction="Represent the document for retrieval:",
|
||||
)
|
||||
|
||||
log.info("chromadb", status="embedding function ready")
|
||||
@@ -919,25 +986,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
log.info("chromadb", status="instructor db ready")
|
||||
else:
|
||||
log.info(
|
||||
"chromadb",
|
||||
embeddings="SentenceTransformer",
|
||||
"chromadb",
|
||||
embeddings="SentenceTransformer",
|
||||
model=model_name,
|
||||
device=device,
|
||||
distance_function=distance_function
|
||||
distance_function=distance_function,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
device=device
|
||||
device=device,
|
||||
)
|
||||
except ValueError as e:
|
||||
if "`trust_remote_code=True` to remove this error" in str(e):
|
||||
raise EmbeddingsModelLoadError(model_name, "Model requires `Trust remote code` to be enabled")
|
||||
raise EmbeddingsModelLoadError(
|
||||
model_name, "Model requires `Trust remote code` to be enabled"
|
||||
)
|
||||
raise EmbeddingsModelLoadError(model_name, str(e))
|
||||
|
||||
|
||||
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=ef, metadata=collection_metadata
|
||||
)
|
||||
@@ -989,9 +1057,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
try:
|
||||
self.db_client.delete_collection(collection_name)
|
||||
except chromadb.errors.NotFoundError as exc:
|
||||
log.error(
|
||||
"chromadb agent", error="collection not found", details=exc
|
||||
)
|
||||
log.error("chromadb agent", error="collection not found", details=exc)
|
||||
except ValueError as exc:
|
||||
log.error(
|
||||
"chromadb agent", error="failed to delete collection", details=exc
|
||||
@@ -1049,16 +1115,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
if not objects:
|
||||
return
|
||||
|
||||
|
||||
# track seen documents by id
|
||||
seen_ids = set()
|
||||
|
||||
for obj in objects:
|
||||
if obj["id"] in seen_ids:
|
||||
log.warning("chromadb agent", status="duplicate id discarded", id=obj["id"])
|
||||
log.warning(
|
||||
"chromadb agent", status="duplicate id discarded", id=obj["id"]
|
||||
)
|
||||
continue
|
||||
seen_ids.add(obj["id"])
|
||||
|
||||
|
||||
documents.append(obj["text"])
|
||||
meta = obj.get("meta", {})
|
||||
source = meta.get("source", "talemate")
|
||||
@@ -1071,7 +1139,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
metadatas.append(meta)
|
||||
uid = obj.get("id", f"{character}-{self.memory_tracker[character]}")
|
||||
ids.append(uid)
|
||||
|
||||
|
||||
self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)
|
||||
|
||||
def _delete(self, meta: dict):
|
||||
@@ -1081,11 +1149,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
return
|
||||
|
||||
where = {"$and": [{k: v} for k, v in meta.items()]}
|
||||
|
||||
|
||||
# if there is only one item in $and reduce it to the key value pair
|
||||
if len(where["$and"]) == 1:
|
||||
where = where["$and"][0]
|
||||
|
||||
|
||||
self.db.delete(where=where)
|
||||
log.debug("chromadb agent delete", meta=meta, where=where)
|
||||
|
||||
@@ -1122,16 +1190,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
log.error("chromadb agent", error="failed to query", details=e)
|
||||
return []
|
||||
|
||||
#import json
|
||||
#print(json.dumps(_results["ids"], indent=2))
|
||||
#print(json.dumps(_results["distances"], indent=2))
|
||||
# import json
|
||||
# print(json.dumps(_results["ids"], indent=2))
|
||||
# print(json.dumps(_results["distances"], indent=2))
|
||||
|
||||
results = []
|
||||
|
||||
max_distance = self.max_distance
|
||||
|
||||
|
||||
closest = None
|
||||
|
||||
|
||||
active_memory_request = memory_request.get()
|
||||
|
||||
for i in range(len(_results["distances"][0])):
|
||||
@@ -1139,7 +1207,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
doc = _results["documents"][0][i]
|
||||
meta = _results["metadatas"][0][i]
|
||||
|
||||
|
||||
active_memory_request.add_result(doc, distance, meta)
|
||||
|
||||
if not meta:
|
||||
@@ -1150,7 +1218,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
# skip pin_only entries
|
||||
if meta.get("pin_only", False):
|
||||
continue
|
||||
|
||||
|
||||
if closest is None:
|
||||
closest = {"distance": distance, "doc": doc}
|
||||
elif distance < closest["distance"]:
|
||||
@@ -1172,13 +1240,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
if len(results) > limit:
|
||||
break
|
||||
|
||||
|
||||
log.debug("chromadb agent get", closest=closest, max_distance=max_distance)
|
||||
self.last_query = {
|
||||
"query": text,
|
||||
"closest": closest,
|
||||
}
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def convert_ts_to_date_prefix(self, ts):
|
||||
@@ -1197,10 +1265,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
return None
|
||||
|
||||
def _get_document(self, id) -> dict:
|
||||
|
||||
if not id:
|
||||
return {}
|
||||
|
||||
|
||||
result = self.db.get(ids=[id] if isinstance(id, str) else id)
|
||||
documents = {}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Context manager that collects and tracks memory agent requests
|
||||
Context manager that collects and tracks memory agent requests
|
||||
for profiling and debugging purposes
|
||||
"""
|
||||
|
||||
@@ -11,91 +11,118 @@ import time
|
||||
from talemate.emit import emit
|
||||
from talemate.agents.context import active_agent
|
||||
|
||||
__all__ = [
|
||||
"MemoryRequest",
|
||||
"start_memory_request"
|
||||
"MemoryRequestState"
|
||||
"memory_request"
|
||||
]
|
||||
__all__ = ["MemoryRequest"]
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
DEBUG_MEMORY_REQUESTS = False
|
||||
|
||||
|
||||
class MemoryRequestResult(pydantic.BaseModel):
|
||||
doc: str
|
||||
distance: float
|
||||
meta: dict = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
|
||||
class MemoryRequestState(pydantic.BaseModel):
|
||||
query:str
|
||||
query: str
|
||||
results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
|
||||
accepted_results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
|
||||
query_params: dict = pydantic.Field(default_factory=dict)
|
||||
closest_distance: float | None = None
|
||||
furthest_distance: float | None = None
|
||||
max_distance: float | None = None
|
||||
|
||||
def add_result(self, doc:str, distance:float, meta:dict):
|
||||
|
||||
|
||||
def add_result(self, doc: str, distance: float, meta: dict):
|
||||
if doc is None:
|
||||
return
|
||||
|
||||
|
||||
self.results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))
|
||||
self.closest_distance = min(self.closest_distance, distance) if self.closest_distance is not None else distance
|
||||
self.furthest_distance = max(self.furthest_distance, distance) if self.furthest_distance is not None else distance
|
||||
|
||||
def accept_result(self, doc:str, distance:float, meta:dict):
|
||||
|
||||
self.closest_distance = (
|
||||
min(self.closest_distance, distance)
|
||||
if self.closest_distance is not None
|
||||
else distance
|
||||
)
|
||||
self.furthest_distance = (
|
||||
max(self.furthest_distance, distance)
|
||||
if self.furthest_distance is not None
|
||||
else distance
|
||||
)
|
||||
|
||||
def accept_result(self, doc: str, distance: float, meta: dict):
|
||||
if doc is None:
|
||||
return
|
||||
|
||||
self.accepted_results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))
|
||||
|
||||
|
||||
self.accepted_results.append(
|
||||
MemoryRequestResult(doc=doc, distance=distance, meta=meta)
|
||||
)
|
||||
|
||||
@property
|
||||
def closest_text(self):
|
||||
return str(self.results[0].doc) if self.results else None
|
||||
|
||||
|
||||
|
||||
memory_request = contextvars.ContextVar("memory_request", default=None)
|
||||
|
||||
|
||||
class MemoryRequest:
|
||||
|
||||
def __init__(self, query:str, query_params:dict=None):
|
||||
def __init__(self, query: str, query_params: dict = None):
|
||||
self.query = query
|
||||
self.query_params = query_params
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
self.state = MemoryRequestState(query=self.query, query_params=self.query_params)
|
||||
self.state = MemoryRequestState(
|
||||
query=self.query, query_params=self.query_params
|
||||
)
|
||||
self.token = memory_request.set(self.state)
|
||||
self.time_start = time.time()
|
||||
return self.state
|
||||
|
||||
|
||||
def __exit__(self, *args):
|
||||
|
||||
self.time_end = time.time()
|
||||
|
||||
|
||||
if DEBUG_MEMORY_REQUESTS:
|
||||
max_length = 50
|
||||
query = self.state.query[:max_length]+"..." if len(self.state.query) > max_length else self.state.query
|
||||
log.debug("MemoryRequest", number_of_results=len(self.state.results), query=query)
|
||||
log.debug("MemoryRequest", number_of_accepted_results=len(self.state.accepted_results), query=query)
|
||||
|
||||
query = (
|
||||
self.state.query[:max_length] + "..."
|
||||
if len(self.state.query) > max_length
|
||||
else self.state.query
|
||||
)
|
||||
log.debug(
|
||||
"MemoryRequest", number_of_results=len(self.state.results), query=query
|
||||
)
|
||||
log.debug(
|
||||
"MemoryRequest",
|
||||
number_of_accepted_results=len(self.state.accepted_results),
|
||||
query=query,
|
||||
)
|
||||
|
||||
for result in self.state.results:
|
||||
# distance to 2 decimal places
|
||||
log.debug("MemoryRequest RESULT", distance=f"{result.distance:.2f}", doc=result.doc[:max_length]+"...")
|
||||
|
||||
log.debug(
|
||||
"MemoryRequest RESULT",
|
||||
distance=f"{result.distance:.2f}",
|
||||
doc=result.doc[:max_length] + "...",
|
||||
)
|
||||
|
||||
agent_context = active_agent.get()
|
||||
|
||||
emit("memory_request", data=self.state.model_dump(), meta={
|
||||
"agent_stack": agent_context.agent_stack if agent_context else [],
|
||||
"agent_stack_uid": agent_context.agent_stack_uid if agent_context else None,
|
||||
"duration": self.time_end - self.time_start,
|
||||
}, websocket_passthrough=True)
|
||||
|
||||
|
||||
emit(
|
||||
"memory_request",
|
||||
data=self.state.model_dump(),
|
||||
meta={
|
||||
"agent_stack": agent_context.agent_stack if agent_context else [],
|
||||
"agent_stack_uid": agent_context.agent_stack_uid
|
||||
if agent_context
|
||||
else None,
|
||||
"duration": self.time_end - self.time_start,
|
||||
},
|
||||
websocket_passthrough=True,
|
||||
)
|
||||
|
||||
memory_request.reset(self.token)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# decorator that opens a memory request context
|
||||
async def start_memory_request(query):
|
||||
@@ -103,5 +130,7 @@ async def start_memory_request(query):
|
||||
async def wrapper(*args, **kwargs):
|
||||
with MemoryRequest(query):
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
__all__ = [
|
||||
'EmbeddingsModelLoadError',
|
||||
'MemoryAgentError',
|
||||
'SetDBError'
|
||||
]
|
||||
__all__ = ["EmbeddingsModelLoadError", "MemoryAgentError", "SetDBError"]
|
||||
|
||||
|
||||
class MemoryAgentError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SetDBError(OSError, MemoryAgentError):
|
||||
|
||||
def __init__(self, details:str):
|
||||
def __init__(self, details: str):
|
||||
super().__init__(f"Memory Agent - Failed to set up the database: {details}")
|
||||
|
||||
|
||||
class EmbeddingsModelLoadError(ValueError, MemoryAgentError):
|
||||
|
||||
def __init__(self, model_name:str, details:str):
|
||||
super().__init__(f"Memory Agent - Failed to load embeddings model {model_name}: {details}")
|
||||
def __init__(self, model_name: str, details: str):
|
||||
super().__init__(
|
||||
f"Memory Agent - Failed to load embeddings model {model_name}: {details}"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from talemate.agents.base import (
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
)
|
||||
from talemate.emit import emit
|
||||
import talemate.instance as instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,11 +13,10 @@ __all__ = ["MemoryRAGMixin"]
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class MemoryRAGMixin:
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
|
||||
actions["use_long_term_memory"] = AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
@@ -44,7 +42,7 @@ class MemoryRAGMixin:
|
||||
{
|
||||
"label": "AI compiled question and answers (slow)",
|
||||
"value": "questions",
|
||||
}
|
||||
},
|
||||
],
|
||||
),
|
||||
"number_of_queries": AgentActionConfig(
|
||||
@@ -65,7 +63,7 @@ class MemoryRAGMixin:
|
||||
{"label": "Short (256)", "value": "256"},
|
||||
{"label": "Medium (512)", "value": "512"},
|
||||
{"label": "Long (1024)", "value": "1024"},
|
||||
]
|
||||
],
|
||||
),
|
||||
"cache": AgentActionConfig(
|
||||
type="bool",
|
||||
@@ -73,16 +71,16 @@ class MemoryRAGMixin:
|
||||
description="Cache the long term memory for faster retrieval.",
|
||||
note="This is a cross-agent cache, assuming they use the same options.",
|
||||
value=True,
|
||||
)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def long_term_memory_enabled(self):
|
||||
return self.actions["use_long_term_memory"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def long_term_memory_retrieval_method(self):
|
||||
return self.actions["use_long_term_memory"].config["retrieval_method"].value
|
||||
@@ -90,60 +88,60 @@ class MemoryRAGMixin:
|
||||
@property
|
||||
def long_term_memory_number_of_queries(self):
|
||||
return self.actions["use_long_term_memory"].config["number_of_queries"].value
|
||||
|
||||
|
||||
@property
|
||||
def long_term_memory_answer_length(self):
|
||||
return int(self.actions["use_long_term_memory"].config["answer_length"].value)
|
||||
|
||||
|
||||
@property
|
||||
def long_term_memory_cache(self):
|
||||
return self.actions["use_long_term_memory"].config["cache"].value
|
||||
|
||||
|
||||
@property
|
||||
def long_term_memory_cache_key(self):
|
||||
"""
|
||||
Build the key from the various options
|
||||
"""
|
||||
|
||||
|
||||
parts = [
|
||||
self.long_term_memory_retrieval_method,
|
||||
self.long_term_memory_number_of_queries,
|
||||
self.long_term_memory_answer_length
|
||||
self.long_term_memory_answer_length,
|
||||
]
|
||||
|
||||
|
||||
return "-".join(map(str, parts))
|
||||
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
|
||||
|
||||
# new scene, reset cache
|
||||
scene.rag_cache = {}
|
||||
|
||||
|
||||
# methods
|
||||
|
||||
async def rag_set_cache(self, content:list[str]):
|
||||
|
||||
async def rag_set_cache(self, content: list[str]):
|
||||
self.scene.rag_cache[self.long_term_memory_cache_key] = {
|
||||
"content": content,
|
||||
"fingerprint": self.scene.history[-1].fingerprint if self.scene.history else 0
|
||||
"fingerprint": self.scene.history[-1].fingerprint
|
||||
if self.scene.history
|
||||
else 0,
|
||||
}
|
||||
|
||||
|
||||
async def rag_get_cache(self) -> list[str] | None:
|
||||
|
||||
if not self.long_term_memory_cache:
|
||||
return None
|
||||
|
||||
|
||||
fingerprint = self.scene.history[-1].fingerprint if self.scene.history else 0
|
||||
cache = self.scene.rag_cache.get(self.long_term_memory_cache_key)
|
||||
|
||||
|
||||
if cache and cache["fingerprint"] == fingerprint:
|
||||
return cache["content"]
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def rag_build(
|
||||
self,
|
||||
character: "Character | None" = None,
|
||||
self,
|
||||
character: "Character | None" = None,
|
||||
prompt: str = "",
|
||||
sub_instruction: str = "",
|
||||
) -> list[str]:
|
||||
@@ -153,37 +151,41 @@ class MemoryRAGMixin:
|
||||
|
||||
if not self.long_term_memory_enabled:
|
||||
return []
|
||||
|
||||
|
||||
cached = await self.rag_get_cache()
|
||||
|
||||
|
||||
if cached:
|
||||
log.debug(f"Using cached long term memory", agent=self.agent_type, key=self.long_term_memory_cache_key)
|
||||
log.debug(
|
||||
"Using cached long term memory",
|
||||
agent=self.agent_type,
|
||||
key=self.long_term_memory_cache_key,
|
||||
)
|
||||
return cached
|
||||
|
||||
memory_context = ""
|
||||
retrieval_method = self.long_term_memory_retrieval_method
|
||||
|
||||
|
||||
if not sub_instruction:
|
||||
if character:
|
||||
sub_instruction = f"continue the scene as {character.name}"
|
||||
elif hasattr(self, "rag_build_sub_instruction"):
|
||||
sub_instruction = await self.rag_build_sub_instruction()
|
||||
|
||||
|
||||
if not sub_instruction:
|
||||
sub_instruction = "continue the scene"
|
||||
|
||||
|
||||
if retrieval_method != "direct":
|
||||
world_state = instance.get_agent("world_state")
|
||||
|
||||
|
||||
if not prompt:
|
||||
prompt = self.scene.context_history(
|
||||
keep_director=False,
|
||||
budget=int(self.client.max_token_length * 0.75),
|
||||
)
|
||||
|
||||
|
||||
if isinstance(prompt, list):
|
||||
prompt = "\n".join(prompt)
|
||||
|
||||
|
||||
log.debug(
|
||||
"memory_rag_mixin.build_prompt_default_memory",
|
||||
direct=False,
|
||||
@@ -193,20 +195,21 @@ class MemoryRAGMixin:
|
||||
if retrieval_method == "questions":
|
||||
memory_context = (
|
||||
await world_state.analyze_text_and_extract_context(
|
||||
prompt, sub_instruction,
|
||||
prompt,
|
||||
sub_instruction,
|
||||
include_character_context=True,
|
||||
response_length=self.long_term_memory_answer_length,
|
||||
num_queries=self.long_term_memory_number_of_queries
|
||||
num_queries=self.long_term_memory_number_of_queries,
|
||||
)
|
||||
).split("\n")
|
||||
elif retrieval_method == "queries":
|
||||
memory_context = (
|
||||
await world_state.analyze_text_and_extract_context_via_queries(
|
||||
prompt, sub_instruction,
|
||||
prompt,
|
||||
sub_instruction,
|
||||
include_character_context=True,
|
||||
response_length=self.long_term_memory_answer_length,
|
||||
num_queries=self.long_term_memory_number_of_queries
|
||||
|
||||
num_queries=self.long_term_memory_number_of_queries,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -223,4 +226,4 @@ class MemoryRAGMixin:
|
||||
|
||||
await self.rag_set_cache(memory_context)
|
||||
|
||||
return memory_context
|
||||
return memory_context
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import random
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
@@ -50,15 +49,18 @@ log = structlog.get_logger("talemate.agents.narrator")
|
||||
class NarratorAgentEmission(AgentEmission):
|
||||
generation: list[str] = dataclasses.field(default_factory=list)
|
||||
response: str = dataclasses.field(default="")
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.narrator.before_generate",
|
||||
"agent.narrator.before_generate",
|
||||
"agent.narrator.inject_instructions",
|
||||
"agent.narrator.generated",
|
||||
)
|
||||
|
||||
|
||||
def set_processing(fn):
|
||||
"""
|
||||
Custom decorator that emits the agent status as processing while the function
|
||||
@@ -74,11 +76,15 @@ def set_processing(fn):
|
||||
if self.content_use_writing_style:
|
||||
self.set_context_states(writing_style=self.scene.writing_style)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.narrator.before_generate").send(emission)
|
||||
await talemate.emit.async_signals.get("agent.narrator.inject_instructions").send(emission)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.narrator.before_generate").send(
|
||||
emission
|
||||
)
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.narrator.inject_instructions"
|
||||
).send(emission)
|
||||
|
||||
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
|
||||
response = await fn(self, *args, **kwargs)
|
||||
emission.response = response
|
||||
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
|
||||
@@ -88,10 +94,7 @@ def set_processing(fn):
|
||||
|
||||
|
||||
@register()
|
||||
class NarratorAgent(
|
||||
MemoryRAGMixin,
|
||||
Agent
|
||||
):
|
||||
class NarratorAgent(MemoryRAGMixin, Agent):
|
||||
"""
|
||||
Handles narration of the story
|
||||
"""
|
||||
@@ -99,7 +102,7 @@ class NarratorAgent(
|
||||
agent_type = "narrator"
|
||||
verbose_name = "Narrator"
|
||||
set_processing = set_processing
|
||||
|
||||
|
||||
websocket_handler = NarratorWebsocketHandler
|
||||
|
||||
@classmethod
|
||||
@@ -117,7 +120,7 @@ class NarratorAgent(
|
||||
min=32,
|
||||
max=1024,
|
||||
step=32,
|
||||
),
|
||||
),
|
||||
"instructions": AgentActionConfig(
|
||||
type="text",
|
||||
label="Instructions",
|
||||
@@ -135,11 +138,6 @@ class NarratorAgent(
|
||||
),
|
||||
},
|
||||
),
|
||||
"auto_break_repetition": AgentAction(
|
||||
enabled=True,
|
||||
label="Auto Break Repetition",
|
||||
description="Will attempt to automatically break AI repetition.",
|
||||
),
|
||||
"content": AgentAction(
|
||||
enabled=True,
|
||||
can_be_disabled=False,
|
||||
@@ -154,7 +152,7 @@ class NarratorAgent(
|
||||
description="Use the writing style selected in the scene settings",
|
||||
value=True,
|
||||
),
|
||||
}
|
||||
},
|
||||
),
|
||||
"narrate_time_passage": AgentAction(
|
||||
enabled=True,
|
||||
@@ -201,13 +199,13 @@ class NarratorAgent(
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
MemoryRAGMixin.add_actions(actions)
|
||||
return actions
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: client.TaleMateClient,
|
||||
client: client.ClientBase | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.client = client
|
||||
@@ -232,28 +230,33 @@ class NarratorAgent(
|
||||
if self.actions["generation_override"].enabled:
|
||||
return self.actions["generation_override"].config["length"].value
|
||||
return 128
|
||||
|
||||
|
||||
@property
|
||||
def narrate_time_passage_enabled(self) -> bool:
|
||||
return self.actions["narrate_time_passage"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def narrate_dialogue_enabled(self) -> bool:
|
||||
return self.actions["narrate_dialogue"].enabled
|
||||
|
||||
@property
|
||||
def narrate_dialogue_ai_chance(self) -> float:
|
||||
def narrate_dialogue_ai_chance(self) -> float:
|
||||
return self.actions["narrate_dialogue"].config["ai_dialog"].value
|
||||
|
||||
|
||||
@property
|
||||
def narrate_dialogue_player_chance(self) -> float:
|
||||
return self.actions["narrate_dialogue"].config["player_dialog"].value
|
||||
|
||||
|
||||
@property
|
||||
def content_use_writing_style(self) -> bool:
|
||||
return self.actions["content"].config["use_writing_style"].value
|
||||
|
||||
def clean_result(self, result:str, ensure_dialog_format:bool=True, force_narrative:bool=True) -> str:
|
||||
|
||||
def clean_result(
|
||||
self,
|
||||
result: str,
|
||||
ensure_dialog_format: bool = True,
|
||||
force_narrative: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Cleans the result of a narration
|
||||
"""
|
||||
@@ -264,15 +267,14 @@ class NarratorAgent(
|
||||
|
||||
cleaned = []
|
||||
for line in result.split("\n"):
|
||||
|
||||
# skip lines that start with a #
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
|
||||
|
||||
log.debug("clean_result", line=line)
|
||||
|
||||
|
||||
character_dialogue_detected = False
|
||||
|
||||
|
||||
for character_name in character_names:
|
||||
if not character_name:
|
||||
continue
|
||||
@@ -280,25 +282,24 @@ class NarratorAgent(
|
||||
character_dialogue_detected = True
|
||||
elif line.startswith(f"{character_name.upper()}"):
|
||||
character_dialogue_detected = True
|
||||
|
||||
|
||||
if character_dialogue_detected:
|
||||
break
|
||||
|
||||
|
||||
if character_dialogue_detected:
|
||||
break
|
||||
|
||||
|
||||
cleaned.append(line)
|
||||
|
||||
result = "\n".join(cleaned)
|
||||
|
||||
|
||||
result = util.strip_partial_sentences(result)
|
||||
editor = get_agent("editor")
|
||||
|
||||
|
||||
if ensure_dialog_format or force_narrative:
|
||||
if editor.fix_exposition_enabled and editor.fix_exposition_narrator:
|
||||
result = editor.fix_exposition_in_text(result)
|
||||
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def connect(self, scene):
|
||||
@@ -324,16 +325,16 @@ class NarratorAgent(
|
||||
event.duration, event.human_duration, event.narrative
|
||||
)
|
||||
narrator_message = NarratorMessage(
|
||||
response,
|
||||
meta = {
|
||||
response,
|
||||
meta={
|
||||
"agent": "narrator",
|
||||
"function": "narrate_time_passage",
|
||||
"arguments": {
|
||||
"duration": event.duration,
|
||||
"time_passed": event.human_duration,
|
||||
"narrative_direction": event.narrative,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
emit("narrator", narrator_message)
|
||||
self.scene.push_history(narrator_message)
|
||||
@@ -345,7 +346,7 @@ class NarratorAgent(
|
||||
|
||||
if not self.narrate_dialogue_enabled:
|
||||
return
|
||||
|
||||
|
||||
if event.game_loop.had_passive_narration:
|
||||
log.debug(
|
||||
"narrate on dialog",
|
||||
@@ -375,14 +376,14 @@ class NarratorAgent(
|
||||
|
||||
response = await self.narrate_after_dialogue(event.actor.character)
|
||||
narrator_message = NarratorMessage(
|
||||
response,
|
||||
response,
|
||||
meta={
|
||||
"agent": "narrator",
|
||||
"function": "narrate_after_dialogue",
|
||||
"arguments": {
|
||||
"character": event.actor.character.name,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
emit("narrator", narrator_message)
|
||||
self.scene.push_history(narrator_message)
|
||||
@@ -390,7 +391,7 @@ class NarratorAgent(
|
||||
event.game_loop.had_passive_narration = True
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction', visual_narration=True)
|
||||
@store_context_state("narrative_direction", visual_narration=True)
|
||||
async def narrate_scene(self, narrative_direction: str | None = None):
|
||||
"""
|
||||
Narrate the scene
|
||||
@@ -413,13 +414,13 @@ class NarratorAgent(
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction')
|
||||
@store_context_state("narrative_direction")
|
||||
async def progress_story(self, narrative_direction: str | None = None):
|
||||
"""
|
||||
Narrate scene progression, moving the plot forward.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
- narrative_direction: A string describing the direction the narrative should take. If not provided, will attempt to subtly move the story forward.
|
||||
"""
|
||||
|
||||
@@ -431,10 +432,8 @@ class NarratorAgent(
|
||||
if narrative_direction is None:
|
||||
narrative_direction = "Slightly move the current scene forward."
|
||||
|
||||
log.debug(
|
||||
"narrative_direction", narrative_direction=narrative_direction
|
||||
)
|
||||
|
||||
log.debug("narrative_direction", narrative_direction=narrative_direction)
|
||||
|
||||
response = await Prompt.request(
|
||||
"narrator.narrate-progress",
|
||||
self.client,
|
||||
@@ -453,13 +452,17 @@ class NarratorAgent(
|
||||
log.debug("progress_story", response=response)
|
||||
|
||||
response = self.clean_result(response.strip())
|
||||
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@store_context_state('query', query_narration=True)
|
||||
@store_context_state("query", query_narration=True)
|
||||
async def narrate_query(
|
||||
self, query: str, at_the_end: bool = False, as_narrative: bool = True, extra_context: str = None
|
||||
self,
|
||||
query: str,
|
||||
at_the_end: bool = False,
|
||||
as_narrative: bool = True,
|
||||
extra_context: str = None,
|
||||
):
|
||||
"""
|
||||
Narrate a specific query
|
||||
@@ -479,20 +482,20 @@ class NarratorAgent(
|
||||
},
|
||||
)
|
||||
response = self.clean_result(
|
||||
response.strip(),
|
||||
ensure_dialog_format=False,
|
||||
force_narrative=as_narrative
|
||||
response.strip(), ensure_dialog_format=False, force_narrative=as_narrative
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
@store_context_state('character', 'narrative_direction', visual_narration=True)
|
||||
async def narrate_character(self, character:"Character", narrative_direction: str = None):
|
||||
@store_context_state("character", "narrative_direction", visual_narration=True)
|
||||
async def narrate_character(
|
||||
self, character: "Character", narrative_direction: str = None
|
||||
):
|
||||
"""
|
||||
Narrate a specific character
|
||||
"""
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
"narrator.narrate-character",
|
||||
self.client,
|
||||
@@ -506,12 +509,14 @@ class NarratorAgent(
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response.strip(), ensure_dialog_format=False, force_narrative=True)
|
||||
response = self.clean_result(
|
||||
response.strip(), ensure_dialog_format=False, force_narrative=True
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction', time_narration=True)
|
||||
@store_context_state("narrative_direction", time_narration=True)
|
||||
async def narrate_time_passage(
|
||||
self, duration: str, time_passed: str, narrative_direction: str
|
||||
):
|
||||
@@ -528,7 +533,7 @@ class NarratorAgent(
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"duration": duration,
|
||||
"time_passed": time_passed,
|
||||
"narrative": narrative_direction, # backwards compatibility
|
||||
"narrative": narrative_direction, # backwards compatibility
|
||||
"narrative_direction": narrative_direction,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
},
|
||||
@@ -541,7 +546,7 @@ class NarratorAgent(
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction', sensory_narration=True)
|
||||
@store_context_state("narrative_direction", sensory_narration=True)
|
||||
async def narrate_after_dialogue(
|
||||
self,
|
||||
character: Character,
|
||||
@@ -572,16 +577,16 @@ class NarratorAgent(
|
||||
async def narrate_environment(self, narrative_direction: str = None):
|
||||
"""
|
||||
Narrate the environment
|
||||
|
||||
|
||||
Wraps narrate_after_dialogue with the player character
|
||||
as the perspective character
|
||||
"""
|
||||
|
||||
|
||||
pc = self.scene.get_player_character()
|
||||
return await self.narrate_after_dialogue(pc, narrative_direction)
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction', 'character')
|
||||
@store_context_state("narrative_direction", "character")
|
||||
async def narrate_character_entry(
|
||||
self, character: Character, narrative_direction: str = None
|
||||
):
|
||||
@@ -607,11 +612,9 @@ class NarratorAgent(
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction', 'character')
|
||||
@store_context_state("narrative_direction", "character")
|
||||
async def narrate_character_exit(
|
||||
self,
|
||||
character: Character,
|
||||
narrative_direction: str = None
|
||||
self, character: Character, narrative_direction: str = None
|
||||
):
|
||||
"""
|
||||
Narrate a character exiting the scene
|
||||
@@ -679,8 +682,10 @@ class NarratorAgent(
|
||||
later on
|
||||
"""
|
||||
args = parameters.copy()
|
||||
|
||||
if args.get("character") and isinstance(args["character"], self.scene.Character):
|
||||
|
||||
if args.get("character") and isinstance(
|
||||
args["character"], self.scene.Character
|
||||
):
|
||||
args["character"] = args["character"].name
|
||||
|
||||
return {
|
||||
@@ -701,7 +706,9 @@ class NarratorAgent(
|
||||
fn = getattr(self, action_name)
|
||||
narration = await fn(**kwargs)
|
||||
|
||||
narrator_message = NarratorMessage(narration, meta=self.action_to_meta(action_name, kwargs))
|
||||
narrator_message = NarratorMessage(
|
||||
narration, meta=self.action_to_meta(action_name, kwargs)
|
||||
)
|
||||
self.scene.push_history(narrator_message)
|
||||
|
||||
if emit_message:
|
||||
@@ -720,35 +727,36 @@ class NarratorAgent(
|
||||
kind=kind,
|
||||
agent_function_name=agent_function_name,
|
||||
)
|
||||
|
||||
|
||||
# depending on conversation format in the context, stopping strings
|
||||
# for character names may change format
|
||||
conversation_agent = get_agent("conversation")
|
||||
|
||||
|
||||
if conversation_agent.conversation_format == "movie_script":
|
||||
character_names = [f"\n{c.name.upper()}\n" for c in self.scene.get_characters()]
|
||||
else:
|
||||
character_names = [
|
||||
f"\n{c.name.upper()}\n" for c in self.scene.get_characters()
|
||||
]
|
||||
else:
|
||||
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
|
||||
|
||||
|
||||
if prompt_param.get("extra_stopping_strings") is None:
|
||||
prompt_param["extra_stopping_strings"] = []
|
||||
prompt_param["extra_stopping_strings"] += character_names
|
||||
|
||||
|
||||
self.set_generation_overrides(prompt_param)
|
||||
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
if auto and not self.actions["auto_break_repetition"].enabled:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def set_generation_overrides(self, prompt_param: dict):
|
||||
if not self.actions["generation_override"].enabled:
|
||||
return
|
||||
|
||||
prompt_param["max_tokens"] = min(prompt_param.get("max_tokens", 256), self.max_generation_length)
|
||||
prompt_param["max_tokens"] = min(
|
||||
prompt_param.get("max_tokens", 256), self.max_generation_length
|
||||
)
|
||||
|
||||
if self.jiggle > 0.0:
|
||||
nuke_repetition = client_context_attribute("nuke_repetition")
|
||||
@@ -756,4 +764,4 @@ class NarratorAgent(
|
||||
# we only apply the agent override if some other mechanism isn't already
|
||||
# setting the nuke_repetition value
|
||||
nuke_repetition = self.jiggle
|
||||
set_client_context_attribute("nuke_repetition", nuke_repetition)
|
||||
set_client_context_attribute("nuke_repetition", nuke_repetition)
|
||||
|
||||
@@ -8,15 +8,16 @@ from talemate.util import iso8601_duration_to_human
|
||||
|
||||
log = structlog.get_logger("talemate.game.engine.nodes.agents.narrator")
|
||||
|
||||
|
||||
class GenerateNarrationBase(AgentNode):
|
||||
"""
|
||||
Generate a narration message
|
||||
"""
|
||||
|
||||
_agent_name:ClassVar[str] = "narrator"
|
||||
_action_name:ClassVar[str] = ""
|
||||
_title:ClassVar[str] = "Generate Narration"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "narrator"
|
||||
_action_name: ClassVar[str] = ""
|
||||
_title: ClassVar[str] = "Generate Narration"
|
||||
|
||||
class Fields:
|
||||
narrative_direction = PropertyField(
|
||||
name="narrative_direction",
|
||||
@@ -24,151 +25,170 @@ class GenerateNarrationBase(AgentNode):
|
||||
default="",
|
||||
type="str",
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
if "title" not in kwargs:
|
||||
kwargs["title"] = self._title
|
||||
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("narrative_direction", socket_type="str", optional=True)
|
||||
|
||||
|
||||
self.add_output("generated", socket_type="str")
|
||||
self.add_output("message", socket_type="message_object")
|
||||
|
||||
|
||||
async def prepare_input_values(self) -> dict:
|
||||
input_values = self.get_input_values()
|
||||
input_values.pop("state", None)
|
||||
return input_values
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
input_values = await self.prepare_input_values()
|
||||
try:
|
||||
agent_fn = getattr(self.agent, self._action_name)
|
||||
except AttributeError:
|
||||
raise InputValueError(self, "_action_name", f"Agent does not have a function named {self._action_name}")
|
||||
|
||||
raise InputValueError(
|
||||
self,
|
||||
"_action_name",
|
||||
f"Agent does not have a function named {self._action_name}",
|
||||
)
|
||||
|
||||
narration = await agent_fn(**input_values)
|
||||
|
||||
|
||||
message = NarratorMessage(
|
||||
message=narration,
|
||||
meta=self.agent.action_to_meta(self._action_name, input_values),
|
||||
)
|
||||
|
||||
self.set_output_values({
|
||||
"generated": narration,
|
||||
"message": message
|
||||
})
|
||||
|
||||
|
||||
|
||||
self.set_output_values({"generated": narration, "message": message})
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateProgress")
|
||||
class GenerateProgressNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a progress narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "progress_story"
|
||||
_title:ClassVar[str] = "Generate Progress Narration"
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "progress_story"
|
||||
_title: ClassVar[str] = "Generate Progress Narration"
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateSceneNarration")
|
||||
class GenerateSceneNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a scene narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_scene"
|
||||
_title:ClassVar[str] = "Generate Scene Narration"
|
||||
"""
|
||||
|
||||
@register("agents/narrator/GenerateAfterDialogNarration")
|
||||
_action_name: ClassVar[str] = "narrate_scene"
|
||||
_title: ClassVar[str] = "Generate Scene Narration"
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateAfterDialogNarration")
|
||||
class GenerateAfterDialogNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate an after dialog narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_after_dialogue"
|
||||
_title:ClassVar[str] = "Generate After Dialog Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_after_dialogue"
|
||||
_title: ClassVar[str] = "Generate After Dialog Narration"
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.add_input("character", socket_type="character")
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateEnvironmentNarration")
|
||||
class GenerateEnvironmentNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate an environment narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_environment"
|
||||
_title:ClassVar[str] = "Generate Environment Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_environment"
|
||||
_title: ClassVar[str] = "Generate Environment Narration"
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateQueryNarration")
|
||||
class GenerateQueryNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a query narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_query"
|
||||
_title:ClassVar[str] = "Generate Query Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_query"
|
||||
_title: ClassVar[str] = "Generate Query Narration"
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.add_input("query", socket_type="str")
|
||||
self.add_input("extra_context", socket_type="str", optional=True)
|
||||
self.remove_input("narrative_direction")
|
||||
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateCharacterNarration")
|
||||
class GenerateCharacterNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a character narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_character"
|
||||
_title:ClassVar[str] = "Generate Character Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_character"
|
||||
_title: ClassVar[str] = "Generate Character Narration"
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.add_input("character", socket_type="character")
|
||||
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateTimeNarration")
|
||||
class GenerateTimeNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a time narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_time_passage"
|
||||
_title:ClassVar[str] = "Generate Time Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_time_passage"
|
||||
_title: ClassVar[str] = "Generate Time Narration"
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.add_input("duration", socket_type="str")
|
||||
self.set_property("duration", "P0T1S")
|
||||
|
||||
|
||||
async def prepare_input_values(self) -> dict:
|
||||
input_values = await super().prepare_input_values()
|
||||
input_values["time_passed"] = iso8601_duration_to_human(input_values["duration"])
|
||||
input_values["time_passed"] = iso8601_duration_to_human(
|
||||
input_values["duration"]
|
||||
)
|
||||
return input_values
|
||||
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateCharacterEntryNarration")
|
||||
class GenerateCharacterEntryNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a character entry narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_character_entry"
|
||||
_title:ClassVar[str] = "Generate Character Entry Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_character_entry"
|
||||
_title: ClassVar[str] = "Generate Character Entry Narration"
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.add_input("character", socket_type="character")
|
||||
|
||||
|
||||
|
||||
@register("agents/narrator/GenerateCharacterExitNarration")
|
||||
class GenerateCharacterExitNarration(GenerateNarrationBase):
|
||||
"""
|
||||
Generate a character exit narration message
|
||||
"""
|
||||
_action_name:ClassVar[str] = "narrate_character_exit"
|
||||
_title:ClassVar[str] = "Generate Character Exit Narration"
|
||||
|
||||
"""
|
||||
|
||||
_action_name: ClassVar[str] = "narrate_character_exit"
|
||||
_title: ClassVar[str] = "Generate Character Exit Narration"
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.add_input("character", socket_type="character")
|
||||
|
||||
|
||||
|
||||
@register("agents/narrator/UnpackSource")
|
||||
class UnpackSource(AgentNode):
|
||||
"""
|
||||
@@ -176,25 +196,19 @@ class UnpackSource(AgentNode):
|
||||
into action name and arguments
|
||||
DEPRECATED
|
||||
"""
|
||||
|
||||
_agent_name:ClassVar[str] = "narrator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "narrator"
|
||||
|
||||
def __init__(self, title="Unpack Source", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("source", socket_type="str")
|
||||
self.add_output("action_name", socket_type="str")
|
||||
self.add_output("arguments", socket_type="dict")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
source = self.get_input_value("source")
|
||||
action_name = ""
|
||||
arguments = {}
|
||||
|
||||
self.set_output_values({
|
||||
"action_name": action_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
|
||||
|
||||
self.set_output_values({"action_name": action_name, "arguments": arguments})
|
||||
|
||||
@@ -15,27 +15,31 @@ __all__ = [
|
||||
|
||||
log = structlog.get_logger("talemate.server.narrator")
|
||||
|
||||
|
||||
class QueryPayload(pydantic.BaseModel):
|
||||
query:str
|
||||
at_the_end:bool=True
|
||||
|
||||
query: str
|
||||
at_the_end: bool = True
|
||||
|
||||
|
||||
class NarrativeDirectionPayload(pydantic.BaseModel):
|
||||
narrative_direction:str = ""
|
||||
narrative_direction: str = ""
|
||||
|
||||
|
||||
class CharacterPayload(NarrativeDirectionPayload):
|
||||
character:str = ""
|
||||
character: str = ""
|
||||
|
||||
|
||||
class NarratorWebsocketHandler(Plugin):
|
||||
"""
|
||||
Handles narrator actions
|
||||
"""
|
||||
|
||||
|
||||
router = "narrator"
|
||||
|
||||
|
||||
@property
|
||||
def narrator(self):
|
||||
return get_agent("narrator")
|
||||
|
||||
|
||||
@set_loading("Progressing the story", cancellable=True, as_async=True)
|
||||
async def handle_progress(self, data: dict):
|
||||
"""
|
||||
@@ -47,7 +51,7 @@ class NarratorWebsocketHandler(Plugin):
|
||||
narrative_direction=payload.narrative_direction,
|
||||
emit_message=True,
|
||||
)
|
||||
|
||||
|
||||
@set_loading("Narrating the environment", cancellable=True, as_async=True)
|
||||
async def handle_narrate_environment(self, data: dict):
|
||||
"""
|
||||
@@ -59,8 +63,7 @@ class NarratorWebsocketHandler(Plugin):
|
||||
narrative_direction=payload.narrative_direction,
|
||||
emit_message=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@set_loading("Working on a query", cancellable=True, as_async=True)
|
||||
async def handle_query(self, data: dict):
|
||||
"""
|
||||
@@ -68,56 +71,55 @@ class NarratorWebsocketHandler(Plugin):
|
||||
message.
|
||||
"""
|
||||
payload = QueryPayload(**data)
|
||||
|
||||
|
||||
narration = await self.narrator.narrate_query(**payload.model_dump())
|
||||
message: ContextInvestigationMessage = ContextInvestigationMessage(
|
||||
narration, sub_type="query"
|
||||
narration, sub_type="query"
|
||||
)
|
||||
message.set_source("narrator", "narrate_query", **payload.model_dump())
|
||||
|
||||
|
||||
|
||||
emit("context_investigation", message=message)
|
||||
self.scene.push_history(message)
|
||||
|
||||
|
||||
@set_loading("Looking at the scene", cancellable=True, as_async=True)
|
||||
async def handle_look_at_scene(self, data: dict):
|
||||
"""
|
||||
Look at the scene (optionally to a specific direction)
|
||||
|
||||
|
||||
This will result in a context investigation message.
|
||||
"""
|
||||
payload = NarrativeDirectionPayload(**data)
|
||||
|
||||
narration = await self.narrator.narrate_scene(narrative_direction=payload.narrative_direction)
|
||||
|
||||
|
||||
narration = await self.narrator.narrate_scene(
|
||||
narrative_direction=payload.narrative_direction
|
||||
)
|
||||
|
||||
message: ContextInvestigationMessage = ContextInvestigationMessage(
|
||||
narration, sub_type="visual-scene"
|
||||
)
|
||||
message.set_source("narrator", "narrate_scene", **payload.model_dump())
|
||||
|
||||
|
||||
emit("context_investigation", message=message)
|
||||
self.scene.push_history(message)
|
||||
|
||||
|
||||
@set_loading("Looking at a character", cancellable=True, as_async=True)
|
||||
async def handle_look_at_character(self, data: dict):
|
||||
"""
|
||||
Look at a character (optionally to a specific direction)
|
||||
|
||||
|
||||
This will result in a context investigation message.
|
||||
"""
|
||||
payload = CharacterPayload(**data)
|
||||
|
||||
|
||||
|
||||
narration = await self.narrator.narrate_character(
|
||||
character = self.scene.get_character(payload.character),
|
||||
character=self.scene.get_character(payload.character),
|
||||
narrative_direction=payload.narrative_direction,
|
||||
)
|
||||
|
||||
|
||||
message: ContextInvestigationMessage = ContextInvestigationMessage(
|
||||
narration, sub_type="visual-character"
|
||||
)
|
||||
message.set_source("narrator", "narrate_character", **payload.model_dump())
|
||||
|
||||
|
||||
emit("context_investigation", message=message)
|
||||
self.scene.push_history(message)
|
||||
|
||||
@@ -24,4 +24,4 @@ def get_agent_class(name):
|
||||
|
||||
|
||||
def get_agent_types() -> list[str]:
|
||||
return list(AGENT_CLASSES.keys())
|
||||
return list(AGENT_CLASSES.keys())
|
||||
|
||||
@@ -7,26 +7,21 @@ import structlog
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopEvent
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import (
|
||||
DirectorMessage,
|
||||
TimePassageMessage,
|
||||
ContextInvestigationMessage,
|
||||
DirectorMessage,
|
||||
TimePassageMessage,
|
||||
ContextInvestigationMessage,
|
||||
ReinforcementMessage,
|
||||
)
|
||||
from talemate.world_state.templates import GenerationOptions
|
||||
from talemate.instance import get_agent
|
||||
from talemate.exceptions import GenerationCancelled
|
||||
import talemate.game.focal as focal
|
||||
import talemate.emit.async_signals
|
||||
|
||||
from talemate.client import ClientBase
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
set_processing,
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
set_processing,
|
||||
AgentEmission,
|
||||
AgentTemplateEmission,
|
||||
RagBuildSubInstructionEmission,
|
||||
@@ -39,6 +34,7 @@ from talemate.history import ArchiveEntry
|
||||
from .analyze_scene import SceneAnalyzationMixin
|
||||
from .context_investigation import ContextInvestigationMixin
|
||||
from .layered_history import LayeredHistoryMixin
|
||||
from .tts_utils import TTSUtilsMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
@@ -53,10 +49,12 @@ talemate.emit.async_signals.register(
|
||||
"agent.summarization.summarize.after",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BuildArchiveEmission(AgentEmission):
|
||||
generation_options: GenerationOptions | None = None
|
||||
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SummarizeEmission(AgentTemplateEmission):
|
||||
text: str = ""
|
||||
@@ -66,6 +64,7 @@ class SummarizeEmission(AgentTemplateEmission):
|
||||
summarization_history: list[str] | None = None
|
||||
summarization_type: Literal["dialogue", "events"] = "dialogue"
|
||||
|
||||
|
||||
@register()
|
||||
class SummarizeAgent(
|
||||
MemoryRAGMixin,
|
||||
@@ -73,7 +72,8 @@ class SummarizeAgent(
|
||||
ContextInvestigationMixin,
|
||||
# Needs to be after ContextInvestigationMixin so signals are connected in the right order
|
||||
SceneAnalyzationMixin,
|
||||
Agent
|
||||
TTSUtilsMixin,
|
||||
Agent,
|
||||
):
|
||||
"""
|
||||
An agent that can be used to summarize text
|
||||
@@ -82,7 +82,7 @@ class SummarizeAgent(
|
||||
agent_type = "summarizer"
|
||||
verbose_name = "Summarizer"
|
||||
auto_squish = False
|
||||
|
||||
|
||||
@classmethod
|
||||
def init_actions(cls) -> dict[str, AgentAction]:
|
||||
actions = {
|
||||
@@ -131,7 +131,7 @@ class SummarizeAgent(
|
||||
ContextInvestigationMixin.add_actions(actions)
|
||||
return actions
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
def __init__(self, client: ClientBase | None = None, **kwargs):
|
||||
self.client = client
|
||||
|
||||
self.actions = SummarizeAgent.init_actions()
|
||||
@@ -148,11 +148,11 @@ class SummarizeAgent(
|
||||
@property
|
||||
def archive_threshold(self):
|
||||
return self.actions["archive"].config["threshold"].value
|
||||
|
||||
|
||||
@property
|
||||
def archive_method(self):
|
||||
return self.actions["archive"].config["method"].value
|
||||
|
||||
|
||||
@property
|
||||
def archive_include_previous(self):
|
||||
return self.actions["archive"].config["include_previous"].value
|
||||
@@ -179,7 +179,7 @@ class SummarizeAgent(
|
||||
return result
|
||||
|
||||
# RAG HELPERS
|
||||
|
||||
|
||||
async def rag_build_sub_instruction(self):
|
||||
# Fire event to get the sub instruction from mixins
|
||||
emission = RagBuildSubInstructionEmission(
|
||||
@@ -188,37 +188,42 @@ class SummarizeAgent(
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.rag_build_sub_instruction"
|
||||
).send(emission)
|
||||
|
||||
|
||||
return emission.sub_instruction
|
||||
|
||||
|
||||
# SUMMARIZATION HELPERS
|
||||
|
||||
|
||||
async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:
|
||||
|
||||
num_previous = self.archive_include_previous
|
||||
|
||||
|
||||
# find entry by .id
|
||||
entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None)
|
||||
entry_index = next(
|
||||
(
|
||||
i
|
||||
for i, e in enumerate(self.scene.archived_history)
|
||||
if e["id"] == entry.id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if entry_index is None:
|
||||
raise ValueError("Entry not found")
|
||||
end = entry_index - 1
|
||||
|
||||
previous_summaries = []
|
||||
|
||||
|
||||
if entry and num_previous > 0:
|
||||
if self.layered_history_available:
|
||||
previous_summaries = self.compile_layered_history(
|
||||
include_base_layer=True,
|
||||
base_layer_end_id=entry.id
|
||||
include_base_layer=True, base_layer_end_id=entry.id
|
||||
)[-num_previous:]
|
||||
else:
|
||||
previous_summaries = [
|
||||
entry.text for entry in self.scene.archived_history[end-num_previous:end]
|
||||
entry.text
|
||||
for entry in self.scene.archived_history[end - num_previous : end]
|
||||
]
|
||||
|
||||
return previous_summaries
|
||||
|
||||
|
||||
# SUMMARIZE
|
||||
|
||||
@set_processing
|
||||
@@ -226,13 +231,15 @@ class SummarizeAgent(
|
||||
self, scene, generation_options: GenerationOptions | None = None
|
||||
):
|
||||
end = None
|
||||
|
||||
|
||||
emission = BuildArchiveEmission(
|
||||
agent=self,
|
||||
generation_options=generation_options,
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.before_build_archive").send(emission)
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.before_build_archive"
|
||||
).send(emission)
|
||||
|
||||
if not self.actions["archive"].enabled:
|
||||
return
|
||||
@@ -260,7 +267,7 @@ class SummarizeAgent(
|
||||
extra_context = [
|
||||
entry["text"] for entry in scene.archived_history[-num_previous:]
|
||||
]
|
||||
|
||||
|
||||
else:
|
||||
extra_context = None
|
||||
|
||||
@@ -283,7 +290,10 @@ class SummarizeAgent(
|
||||
|
||||
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
|
||||
|
||||
if isinstance(dialogue, (DirectorMessage, ContextInvestigationMessage, ReinforcementMessage)):
|
||||
if isinstance(
|
||||
dialogue,
|
||||
(DirectorMessage, ContextInvestigationMessage, ReinforcementMessage),
|
||||
):
|
||||
# these messages are not part of the dialogue and should not be summarized
|
||||
if i == start:
|
||||
start += 1
|
||||
@@ -292,7 +302,7 @@ class SummarizeAgent(
|
||||
if isinstance(dialogue, TimePassageMessage):
|
||||
log.debug("build_archive", time_passage_message=dialogue)
|
||||
ts = util.iso8601_add(ts, dialogue.ts)
|
||||
|
||||
|
||||
if i == start:
|
||||
log.debug(
|
||||
"build_archive",
|
||||
@@ -347,23 +357,28 @@ class SummarizeAgent(
|
||||
if str(line) in terminating_line:
|
||||
break
|
||||
adjusted_dialogue.append(line)
|
||||
|
||||
|
||||
# if difference start and end is less than 4, ignore the termination
|
||||
if len(adjusted_dialogue) > 4:
|
||||
dialogue_entries = adjusted_dialogue
|
||||
end = start + len(dialogue_entries) - 1
|
||||
else:
|
||||
log.debug("build_archive", message="Ignoring termination", start=start, end=end, adjusted_dialogue=adjusted_dialogue)
|
||||
log.debug(
|
||||
"build_archive",
|
||||
message="Ignoring termination",
|
||||
start=start,
|
||||
end=end,
|
||||
adjusted_dialogue=adjusted_dialogue,
|
||||
)
|
||||
|
||||
if dialogue_entries:
|
||||
|
||||
if not extra_context:
|
||||
# prepend scene intro to dialogue
|
||||
dialogue_entries.insert(0, scene.intro)
|
||||
|
||||
|
||||
summarized = None
|
||||
retries = 5
|
||||
|
||||
|
||||
while not summarized and retries > 0:
|
||||
summarized = await self.summarize(
|
||||
"\n".join(map(str, dialogue_entries)),
|
||||
@@ -371,7 +386,7 @@ class SummarizeAgent(
|
||||
generation_options=generation_options,
|
||||
)
|
||||
retries -= 1
|
||||
|
||||
|
||||
if not summarized:
|
||||
raise IOError("Failed to summarize dialogue", dialogue=dialogue_entries)
|
||||
|
||||
@@ -382,13 +397,17 @@ class SummarizeAgent(
|
||||
|
||||
# determine the appropariate timestamp for the summarization
|
||||
|
||||
await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts))
|
||||
|
||||
scene.ts=ts
|
||||
await scene.push_archive(
|
||||
ArchiveEntry(text=summarized, start=start, end=end, ts=ts)
|
||||
)
|
||||
|
||||
scene.ts = ts
|
||||
scene.emit_status()
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.after_build_archive").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.after_build_archive"
|
||||
).send(emission)
|
||||
|
||||
return True
|
||||
|
||||
@set_processing
|
||||
@@ -408,23 +427,23 @@ class SummarizeAgent(
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def find_natural_scene_termination(self, event_chunks:list[str]) -> list[list[str]]:
|
||||
async def find_natural_scene_termination(
|
||||
self, event_chunks: list[str]
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Will analyze a list of events and return a list of events that
|
||||
has been separated at a natural scene termination points.
|
||||
"""
|
||||
|
||||
|
||||
# scan through event chunks and split into paragraphs
|
||||
rebuilt_chunks = []
|
||||
|
||||
|
||||
for chunk in event_chunks:
|
||||
paragraphs = [
|
||||
p.strip() for p in chunk.split("\n") if p.strip()
|
||||
]
|
||||
paragraphs = [p.strip() for p in chunk.split("\n") if p.strip()]
|
||||
rebuilt_chunks.extend(paragraphs)
|
||||
|
||||
|
||||
event_chunks = rebuilt_chunks
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
"summarizer.find-natural-scene-termination-events",
|
||||
self.client,
|
||||
@@ -436,38 +455,42 @@ class SummarizeAgent(
|
||||
},
|
||||
)
|
||||
response = response.strip()
|
||||
|
||||
|
||||
items = util.extract_list(response)
|
||||
|
||||
# will be a list of
|
||||
|
||||
# will be a list of
|
||||
# ["Progress 1", "Progress 12", "Progress 323", ...]
|
||||
# convert to a list of just numbers
|
||||
|
||||
# convert to a list of just numbers
|
||||
|
||||
numbers = []
|
||||
|
||||
|
||||
for item in items:
|
||||
match = re.match(r"Progress (\d+)", item.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
|
||||
# make sure its unique and sorted
|
||||
numbers = sorted(list(set(numbers)))
|
||||
|
||||
|
||||
result = []
|
||||
prev_number = 0
|
||||
for number in numbers:
|
||||
result.append(event_chunks[prev_number:number+1])
|
||||
prev_number = number+1
|
||||
|
||||
#result = {
|
||||
result.append(event_chunks[prev_number : number + 1])
|
||||
prev_number = number + 1
|
||||
|
||||
# result = {
|
||||
# "selected": event_chunks[:number+1],
|
||||
# "remaining": event_chunks[number+1:]
|
||||
#}
|
||||
|
||||
log.debug("find_natural_scene_termination", response=response, result=result, numbers=numbers)
|
||||
|
||||
# }
|
||||
|
||||
log.debug(
|
||||
"find_natural_scene_termination",
|
||||
response=response,
|
||||
result=result,
|
||||
numbers=numbers,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@set_processing
|
||||
async def summarize(
|
||||
@@ -481,9 +504,9 @@ class SummarizeAgent(
|
||||
"""
|
||||
Summarize the given text
|
||||
"""
|
||||
|
||||
|
||||
response_length = 1024
|
||||
|
||||
|
||||
template_vars = {
|
||||
"dialogue": text,
|
||||
"scene": self.scene,
|
||||
@@ -500,7 +523,7 @@ class SummarizeAgent(
|
||||
"analyze_chunks": self.layered_history_analyze_chunks,
|
||||
"response_length": response_length,
|
||||
}
|
||||
|
||||
|
||||
emission = SummarizeEmission(
|
||||
agent=self,
|
||||
text=text,
|
||||
@@ -511,43 +534,46 @@ class SummarizeAgent(
|
||||
summarization_history=extra_context or [],
|
||||
summarization_type="dialogue",
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.summarize.before"
|
||||
).send(emission)
|
||||
|
||||
template_vars["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
f"summarizer.summarize-dialogue",
|
||||
"summarizer.summarize-dialogue",
|
||||
self.client,
|
||||
f"summarize_{response_length}",
|
||||
vars=template_vars,
|
||||
dedupe_enabled=False
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"summarize", dialogue_length=len(text), summarized_length=len(response)
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
summary = response.split("SUMMARY:")[1].strip()
|
||||
except Exception as e:
|
||||
log.error("summarize failed", response=response, exc=e)
|
||||
return ""
|
||||
|
||||
|
||||
# capitalize first letter
|
||||
try:
|
||||
summary = summary[0].upper() + summary[1:]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
|
||||
emission.response = self.clean_result(summary)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.summarize.after"
|
||||
).send(emission)
|
||||
|
||||
summary = emission.response
|
||||
|
||||
|
||||
return self.clean_result(summary)
|
||||
|
||||
|
||||
@set_processing
|
||||
async def summarize_events(
|
||||
@@ -563,15 +589,14 @@ class SummarizeAgent(
|
||||
"""
|
||||
Summarize the given text
|
||||
"""
|
||||
|
||||
|
||||
if not extra_context:
|
||||
extra_context = ""
|
||||
|
||||
|
||||
mentioned_characters: list["Character"] = self.scene.parse_characters_from_text(
|
||||
text + extra_context,
|
||||
exclude_active=True
|
||||
text + extra_context, exclude_active=True
|
||||
)
|
||||
|
||||
|
||||
template_vars = {
|
||||
"dialogue": text,
|
||||
"scene": self.scene,
|
||||
@@ -585,7 +610,7 @@ class SummarizeAgent(
|
||||
"response_length": response_length,
|
||||
"mentioned_characters": mentioned_characters,
|
||||
}
|
||||
|
||||
|
||||
emission = SummarizeEmission(
|
||||
agent=self,
|
||||
text=text,
|
||||
@@ -596,46 +621,62 @@ class SummarizeAgent(
|
||||
summarization_history=[extra_context] if extra_context else [],
|
||||
summarization_type="events",
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.summarize.before"
|
||||
).send(emission)
|
||||
|
||||
template_vars["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
f"summarizer.summarize-events",
|
||||
"summarizer.summarize-events",
|
||||
self.client,
|
||||
f"summarize_{response_length}",
|
||||
vars=template_vars,
|
||||
dedupe_enabled=False
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
response = response.strip()
|
||||
response = response.replace('"', "")
|
||||
|
||||
|
||||
log.debug(
|
||||
"layered_history_summarize", original_length=len(text), summarized_length=len(response)
|
||||
"layered_history_summarize",
|
||||
original_length=len(text),
|
||||
summarized_length=len(response),
|
||||
)
|
||||
|
||||
|
||||
# clean up analyzation (remove analyzation text)
|
||||
if self.layered_history_analyze_chunks:
|
||||
# remove all lines that begin with "ANALYSIS OF CHUNK \d+:"
|
||||
response = "\n".join([line for line in response.split("\n") if not line.startswith("ANALYSIS OF CHUNK")])
|
||||
|
||||
response = "\n".join(
|
||||
[
|
||||
line
|
||||
for line in response.split("\n")
|
||||
if not line.startswith("ANALYSIS OF CHUNK")
|
||||
]
|
||||
)
|
||||
|
||||
# strip all occurences of "CHUNK \d+: " from the summary
|
||||
response = re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response)
|
||||
|
||||
|
||||
# capitalize first letter
|
||||
try:
|
||||
response = response[0].upper() + response[1:]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
|
||||
emission.response = self.clean_result(response)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.summarize.after"
|
||||
).send(emission)
|
||||
|
||||
response = emission.response
|
||||
|
||||
log.debug("summarize_events", original_length=len(text), summarized_length=len(response))
|
||||
|
||||
|
||||
log.debug(
|
||||
"summarize_events",
|
||||
original_length=len(text),
|
||||
summarized_length=len(response),
|
||||
)
|
||||
|
||||
return self.clean_result(response)
|
||||
|
||||
@@ -16,25 +16,44 @@ from talemate.agents.conversation import ConversationAgentEmission
|
||||
from talemate.agents.narrator import NarratorAgentEmission
|
||||
from talemate.agents.context import active_agent
|
||||
from talemate.agents.base import RagBuildSubInstructionEmission
|
||||
from contextvars import ContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
## CONTEXT
|
||||
|
||||
scene_analysis_disabled_context = ContextVar("scene_analysis_disabled", default=False)
|
||||
|
||||
|
||||
class SceneAnalysisDisabled:
|
||||
"""
|
||||
Context manager to disable scene analysis during specific agent actions.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
self.token = scene_analysis_disabled_context.set(True)
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _traceback):
|
||||
scene_analysis_disabled_context.reset(self.token)
|
||||
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.summarization.scene_analysis.before",
|
||||
"agent.summarization.scene_analysis.after",
|
||||
"agent.summarization.scene_analysis.cached",
|
||||
"agent.summarization.scene_analysis.before_deep_analysis",
|
||||
"agent.summarization.scene_analysis.after_deep_analysis",
|
||||
"agent.summarization.scene_analysis.after_deep_analysis",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SceneAnalysisEmission(AgentTemplateEmission):
|
||||
analysis_type: str | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SceneAnalysisDeepAnalysisEmission(AgentEmission):
|
||||
analysis: str
|
||||
@@ -42,15 +61,16 @@ class SceneAnalysisDeepAnalysisEmission(AgentEmission):
|
||||
analysis_sub_type: str | None = None
|
||||
max_content_investigations: int = 1
|
||||
character: "Character" = None
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
|
||||
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
|
||||
class SceneAnalyzationMixin:
|
||||
|
||||
"""
|
||||
Summarizer agent mixin that provides functionality for scene analyzation.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["analyze_scene"] = AgentAction(
|
||||
@@ -71,8 +91,8 @@ class SceneAnalyzationMixin:
|
||||
choices=[
|
||||
{"label": "Short (256)", "value": "256"},
|
||||
{"label": "Medium (512)", "value": "512"},
|
||||
{"label": "Long (1024)", "value": "1024"}
|
||||
]
|
||||
{"label": "Long (1024)", "value": "1024"},
|
||||
],
|
||||
),
|
||||
"for_conversation": AgentActionConfig(
|
||||
type="bool",
|
||||
@@ -109,196 +129,221 @@ class SceneAnalyzationMixin:
|
||||
value=True,
|
||||
quick_toggle=True,
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def analyze_scene(self) -> bool:
|
||||
return self.actions["analyze_scene"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def analysis_length(self) -> int:
|
||||
return int(self.actions["analyze_scene"].config["analysis_length"].value)
|
||||
|
||||
|
||||
@property
|
||||
def cache_analysis(self) -> bool:
|
||||
return self.actions["analyze_scene"].config["cache_analysis"].value
|
||||
|
||||
|
||||
@property
|
||||
def deep_analysis(self) -> bool:
|
||||
return self.actions["analyze_scene"].config["deep_analysis"].value
|
||||
|
||||
|
||||
@property
|
||||
def deep_analysis_max_context_investigations(self) -> int:
|
||||
return self.actions["analyze_scene"].config["deep_analysis_max_context_investigations"].value
|
||||
|
||||
return (
|
||||
self.actions["analyze_scene"]
|
||||
.config["deep_analysis_max_context_investigations"]
|
||||
.value
|
||||
)
|
||||
|
||||
@property
|
||||
def analyze_scene_for_conversation(self) -> bool:
|
||||
return self.actions["analyze_scene"].config["for_conversation"].value
|
||||
|
||||
|
||||
@property
|
||||
def analyze_scene_for_narration(self) -> bool:
|
||||
return self.actions["analyze_scene"].config["for_narration"].value
|
||||
|
||||
|
||||
# signal connect
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
|
||||
self.on_inject_instructions
|
||||
)
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.conversation.inject_instructions"
|
||||
).connect(self.on_inject_instructions)
|
||||
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
|
||||
self.on_inject_instructions
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.summarization.rag_build_sub_instruction").connect(
|
||||
self.on_rag_build_sub_instruction
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
|
||||
self.on_editor_revision_analysis_before
|
||||
)
|
||||
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.summarization.rag_build_sub_instruction"
|
||||
).connect(self.on_rag_build_sub_instruction)
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.editor.revision-analysis.before"
|
||||
).connect(self.on_editor_revision_analysis_before)
|
||||
|
||||
async def on_inject_instructions(
|
||||
self,
|
||||
emission:ConversationAgentEmission | NarratorAgentEmission,
|
||||
self,
|
||||
emission: ConversationAgentEmission | NarratorAgentEmission,
|
||||
):
|
||||
"""
|
||||
Injects instructions into the conversation.
|
||||
"""
|
||||
|
||||
|
||||
if isinstance(emission, ConversationAgentEmission):
|
||||
emission_type = "conversation"
|
||||
elif isinstance(emission, NarratorAgentEmission):
|
||||
emission_type = "narration"
|
||||
else:
|
||||
raise ValueError("Invalid emission type.")
|
||||
|
||||
|
||||
if not self.analyze_scene:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
if scene_analysis_disabled_context.get():
|
||||
log.debug(
|
||||
"on_inject_instructions: scene analysis disabled through context",
|
||||
emission=emission,
|
||||
)
|
||||
return
|
||||
except LookupError:
|
||||
pass
|
||||
|
||||
analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}")
|
||||
|
||||
|
||||
if not analyze_scene_for_type:
|
||||
return
|
||||
|
||||
|
||||
analysis = None
|
||||
|
||||
|
||||
# self.set_scene_states and self.get_scene_state to store
|
||||
# cached analysis in scene states
|
||||
|
||||
|
||||
if self.cache_analysis:
|
||||
analysis = await self.get_cached_analysis(emission_type)
|
||||
if analysis:
|
||||
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").send(
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.cached"
|
||||
).send(
|
||||
SceneAnalysisEmission(
|
||||
agent=self,
|
||||
analysis_type=emission_type,
|
||||
response=analysis,
|
||||
agent=self,
|
||||
analysis_type=emission_type,
|
||||
response=analysis,
|
||||
template_vars={
|
||||
"character": emission.character if hasattr(emission, "character") else None,
|
||||
"character": emission.character
|
||||
if hasattr(emission, "character")
|
||||
else None,
|
||||
},
|
||||
dynamic_instructions=emission.dynamic_instructions
|
||||
dynamic_instructions=emission.dynamic_instructions,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if not analysis and self.analyze_scene:
|
||||
# analyze the scene for the next action
|
||||
analysis = await self.analyze_scene_for_next_action(
|
||||
emission_type,
|
||||
emission.character if hasattr(emission, "character") else None,
|
||||
self.analysis_length
|
||||
self.analysis_length,
|
||||
)
|
||||
|
||||
|
||||
await self.set_cached_analysis(emission_type, analysis)
|
||||
|
||||
|
||||
if not analysis:
|
||||
return
|
||||
emission.dynamic_instructions.append(
|
||||
DynamicInstruction(
|
||||
title="Scene Analysis",
|
||||
content=analysis
|
||||
)
|
||||
DynamicInstruction(title="Scene Analysis", content=analysis)
|
||||
)
|
||||
|
||||
async def on_rag_build_sub_instruction(self, emission:"RagBuildSubInstructionEmission"):
|
||||
|
||||
async def on_rag_build_sub_instruction(
|
||||
self, emission: "RagBuildSubInstructionEmission"
|
||||
):
|
||||
"""
|
||||
Injects the sub instruction into the analysis.
|
||||
"""
|
||||
sub_instruction = await self.analyze_scene_rag_build_sub_instruction()
|
||||
|
||||
|
||||
if sub_instruction:
|
||||
emission.sub_instruction = sub_instruction
|
||||
|
||||
|
||||
async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
|
||||
last_analysis = self.get_scene_state("scene_analysis")
|
||||
if last_analysis:
|
||||
emission.dynamic_instructions.append(DynamicInstruction(
|
||||
title="Scene Analysis",
|
||||
content=last_analysis
|
||||
))
|
||||
|
||||
emission.dynamic_instructions.append(
|
||||
DynamicInstruction(title="Scene Analysis", content=last_analysis)
|
||||
)
|
||||
|
||||
# helpers
|
||||
|
||||
async def get_cached_analysis(self, typ:str) -> str | None:
|
||||
|
||||
async def get_cached_analysis(self, typ: str) -> str | None:
|
||||
"""
|
||||
Returns the cached analysis for the given type.
|
||||
"""
|
||||
|
||||
|
||||
cached_analysis = self.get_scene_state(f"cached_analysis_{typ}")
|
||||
|
||||
|
||||
if not cached_analysis:
|
||||
return None
|
||||
|
||||
fingerprint = self.context_fingerpint()
|
||||
|
||||
|
||||
fingerprint = self.context_fingerprint()
|
||||
|
||||
log.debug(
|
||||
"get_cached_analysis",
|
||||
fingerprint=fingerprint,
|
||||
cached_analysis_fp=cached_analysis.get("fp"),
|
||||
match=cached_analysis.get("fp") == fingerprint,
|
||||
)
|
||||
|
||||
if cached_analysis.get("fp") == fingerprint:
|
||||
return cached_analysis["guidance"]
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def set_cached_analysis(self, typ:str, analysis:str):
|
||||
|
||||
async def set_cached_analysis(self, typ: str, analysis: str):
|
||||
"""
|
||||
Sets the cached analysis for the given type.
|
||||
"""
|
||||
|
||||
fingerprint = self.context_fingerpint()
|
||||
|
||||
|
||||
fingerprint = self.context_fingerprint()
|
||||
|
||||
self.set_scene_states(
|
||||
**{f"cached_analysis_{typ}": {
|
||||
"fp": fingerprint,
|
||||
"guidance": analysis,
|
||||
}}
|
||||
**{
|
||||
f"cached_analysis_{typ}": {
|
||||
"fp": fingerprint,
|
||||
"guidance": analysis,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def analyze_scene_sub_type(self, analysis_type:str) -> str:
|
||||
|
||||
async def analyze_scene_sub_type(self, analysis_type: str) -> str:
|
||||
"""
|
||||
Analyzes the active agent context to figure out the appropriate sub type
|
||||
"""
|
||||
|
||||
|
||||
fn = getattr(self, f"analyze_scene_{analysis_type}_sub_type", None)
|
||||
|
||||
|
||||
if fn:
|
||||
return await fn()
|
||||
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
async def analyze_scene_narration_sub_type(self) -> str:
|
||||
"""
|
||||
Analyzes the active agent context to figure out the appropriate sub type
|
||||
for narration analysis. (progress, query etc.)
|
||||
"""
|
||||
|
||||
|
||||
active_agent_context = active_agent.get()
|
||||
|
||||
|
||||
if not active_agent_context:
|
||||
return "progress"
|
||||
|
||||
|
||||
state = active_agent_context.state
|
||||
|
||||
|
||||
if state.get("narrator__query_narration"):
|
||||
return "query"
|
||||
|
||||
|
||||
if state.get("narrator__sensory_narration"):
|
||||
return "sensory"
|
||||
|
||||
@@ -306,70 +351,85 @@ class SceneAnalyzationMixin:
|
||||
if state.get("narrator__character"):
|
||||
return "visual-character"
|
||||
return "visual"
|
||||
|
||||
|
||||
if state.get("narrator__fn_narrate_character_entry"):
|
||||
return "progress-character-entry"
|
||||
|
||||
|
||||
if state.get("narrator__fn_narrate_character_exit"):
|
||||
return "progress-character-exit"
|
||||
|
||||
|
||||
return "progress"
|
||||
|
||||
|
||||
async def analyze_scene_rag_build_sub_instruction(self):
|
||||
"""
|
||||
Analyzes the active agent context to figure out the appropriate sub type
|
||||
for rag build sub instruction.
|
||||
"""
|
||||
|
||||
|
||||
active_agent_context = active_agent.get()
|
||||
|
||||
|
||||
if not active_agent_context:
|
||||
return ""
|
||||
|
||||
|
||||
state = active_agent_context.state
|
||||
|
||||
|
||||
if state.get("narrator__query_narration"):
|
||||
query = state["narrator__query"]
|
||||
if query.endswith("?"):
|
||||
return "Answer the following question: " + query
|
||||
else:
|
||||
return query
|
||||
|
||||
|
||||
narrative_direction = state.get("narrator__narrative_direction")
|
||||
|
||||
|
||||
if state.get("narrator__sensory_narration") and narrative_direction:
|
||||
return "Collect information that aids in describing the following sensory experience: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following sensory experience: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__visual_narration") and narrative_direction:
|
||||
return "Collect information that aids in describing the following visual experience: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following visual experience: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__fn_narrate_character_entry") and narrative_direction:
|
||||
return "Collect information that aids in describing the following character entry: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following character entry: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__fn_narrate_character_exit") and narrative_direction:
|
||||
return "Collect information that aids in describing the following character exit: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following character exit: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__fn_narrate_progress"):
|
||||
return "Collect information that aids in progressing the story: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in progressing the story: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
# actions
|
||||
|
||||
@set_processing
|
||||
async def analyze_scene_for_next_action(self, typ:str, character:"Character"=None, length:int=1024) -> str:
|
||||
|
||||
async def analyze_scene_for_next_action(
|
||||
self, typ: str, character: "Character" = None, length: int = 1024
|
||||
) -> str:
|
||||
"""
|
||||
Analyzes the current scene progress and gives a suggestion for the next action.
|
||||
taken by the given actor.
|
||||
"""
|
||||
|
||||
|
||||
# deep analysis is only available if the scene has a layered history
|
||||
# and context investigation is enabled
|
||||
deep_analysis = (self.deep_analysis and self.context_investigation_available)
|
||||
deep_analysis = self.deep_analysis and self.context_investigation_available
|
||||
analysis_sub_type = await self.analyze_scene_sub_type(typ)
|
||||
|
||||
|
||||
template_vars = {
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": self.scene,
|
||||
@@ -381,58 +441,60 @@ class SceneAnalyzationMixin:
|
||||
"analysis_type": typ,
|
||||
"analysis_sub_type": analysis_sub_type,
|
||||
}
|
||||
|
||||
emission = SceneAnalysisEmission(agent=self, template_vars=template_vars, analysis_type=typ)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before").send(
|
||||
emission
|
||||
|
||||
emission = SceneAnalysisEmission(
|
||||
agent=self, template_vars=template_vars, analysis_type=typ
|
||||
)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.before"
|
||||
).send(emission)
|
||||
|
||||
template_vars["dynamic_instructions"] = emission.dynamic_instructions
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
f"summarizer.analyze-scene-for-next-{typ}",
|
||||
self.client,
|
||||
f"investigate_{length}",
|
||||
vars=template_vars,
|
||||
)
|
||||
|
||||
|
||||
response = strip_partial_sentences(response)
|
||||
|
||||
|
||||
if not response.strip():
|
||||
return response
|
||||
|
||||
|
||||
if deep_analysis:
|
||||
|
||||
emission = SceneAnalysisDeepAnalysisEmission(
|
||||
agent=self,
|
||||
agent=self,
|
||||
analysis=response,
|
||||
analysis_type=typ,
|
||||
analysis_sub_type=analysis_sub_type,
|
||||
character=character,
|
||||
max_content_investigations=self.deep_analysis_max_context_investigations
|
||||
max_content_investigations=self.deep_analysis_max_context_investigations,
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").send(
|
||||
emission
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after_deep_analysis").send(
|
||||
emission
|
||||
)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").send(
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.before_deep_analysis"
|
||||
).send(emission)
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.after_deep_analysis"
|
||||
).send(emission)
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.after"
|
||||
).send(
|
||||
SceneAnalysisEmission(
|
||||
agent=self,
|
||||
template_vars=template_vars,
|
||||
response=response,
|
||||
agent=self,
|
||||
template_vars=template_vars,
|
||||
response=response,
|
||||
analysis_type=typ,
|
||||
dynamic_instructions=emission.dynamic_instructions
|
||||
dynamic_instructions=emission.dynamic_instructions,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
self.set_context_states(scene_analysis=response)
|
||||
self.set_scene_states(scene_analysis=response)
|
||||
|
||||
return response
|
||||
|
||||
return response
|
||||
|
||||
@@ -22,14 +22,12 @@ if TYPE_CHECKING:
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
|
||||
class ContextInvestigationMixin:
|
||||
|
||||
"""
|
||||
Summarizer agent mixin that provides functionality for context investigation
|
||||
through the layered history of the scene.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["context_investigation"] = AgentAction(
|
||||
@@ -52,7 +50,7 @@ class ContextInvestigationMixin:
|
||||
{"label": "Short (256)", "value": "256"},
|
||||
{"label": "Medium (512)", "value": "512"},
|
||||
{"label": "Long (1024)", "value": "1024"},
|
||||
]
|
||||
],
|
||||
),
|
||||
"update_method": AgentActionConfig(
|
||||
type="text",
|
||||
@@ -62,57 +60,56 @@ class ContextInvestigationMixin:
|
||||
choices=[
|
||||
{"label": "Replace", "value": "replace"},
|
||||
{"label": "Smart Merge", "value": "merge"},
|
||||
]
|
||||
)
|
||||
}
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def context_investigation_enabled(self):
|
||||
return self.actions["context_investigation"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def context_investigation_available(self):
|
||||
return (
|
||||
self.context_investigation_enabled and
|
||||
self.layered_history_available
|
||||
)
|
||||
|
||||
return self.context_investigation_enabled and self.layered_history_available
|
||||
|
||||
@property
|
||||
def context_investigation_answer_length(self) -> int:
|
||||
return int(self.actions["context_investigation"].config["answer_length"].value)
|
||||
|
||||
|
||||
@property
|
||||
def context_investigation_update_method(self) -> str:
|
||||
return self.actions["context_investigation"].config["update_method"].value
|
||||
|
||||
|
||||
# signal connect
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
|
||||
self.on_inject_context_investigation
|
||||
)
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.conversation.inject_instructions"
|
||||
).connect(self.on_inject_context_investigation)
|
||||
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
|
||||
self.on_inject_context_investigation
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.director.guide.inject_instructions").connect(
|
||||
self.on_inject_context_investigation
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").connect(
|
||||
self.on_summarization_scene_analysis_before_deep_analysis
|
||||
)
|
||||
|
||||
async def on_summarization_scene_analysis_before_deep_analysis(self, emission:SceneAnalysisDeepAnalysisEmission):
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.director.guide.inject_instructions"
|
||||
).connect(self.on_inject_context_investigation)
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.summarization.scene_analysis.before_deep_analysis"
|
||||
).connect(self.on_summarization_scene_analysis_before_deep_analysis)
|
||||
|
||||
async def on_summarization_scene_analysis_before_deep_analysis(
|
||||
self, emission: SceneAnalysisDeepAnalysisEmission
|
||||
):
|
||||
"""
|
||||
Handles context investigation for deep scene analysis.
|
||||
"""
|
||||
|
||||
|
||||
if not self.context_investigation_enabled:
|
||||
return
|
||||
|
||||
|
||||
suggested_investigations = await self.suggest_context_investigations(
|
||||
emission.analysis,
|
||||
emission.analysis_type,
|
||||
@@ -120,67 +117,72 @@ class ContextInvestigationMixin:
|
||||
max_calls=emission.max_content_investigations,
|
||||
character=emission.character,
|
||||
)
|
||||
|
||||
|
||||
response = emission.analysis
|
||||
|
||||
ci_calls:list[focal.Call] = await self.request_context_investigations(
|
||||
suggested_investigations,
|
||||
max_calls=emission.max_content_investigations
|
||||
|
||||
ci_calls: list[focal.Call] = await self.request_context_investigations(
|
||||
suggested_investigations, max_calls=emission.max_content_investigations
|
||||
)
|
||||
|
||||
|
||||
log.debug("analyze_scene_for_next_action", ci_calls=ci_calls)
|
||||
|
||||
|
||||
# append call queries and answers to the response
|
||||
ci_text = []
|
||||
for ci_call in ci_calls:
|
||||
try:
|
||||
ci_text.append(f"{ci_call.arguments['query']}\n{ci_call.result}")
|
||||
except KeyError as e:
|
||||
log.error("analyze_scene_for_next_action", error="Missing key in call", ci_call=ci_call)
|
||||
|
||||
context_investigation="\n\n".join(ci_text if ci_text else [])
|
||||
except KeyError:
|
||||
log.error(
|
||||
"analyze_scene_for_next_action",
|
||||
error="Missing key in call",
|
||||
ci_call=ci_call,
|
||||
)
|
||||
|
||||
context_investigation = "\n\n".join(ci_text if ci_text else [])
|
||||
current_context_investigation = self.get_scene_state("context_investigation")
|
||||
if current_context_investigation and context_investigation:
|
||||
if self.context_investigation_update_method == "merge":
|
||||
context_investigation = await self.update_context_investigation(
|
||||
current_context_investigation, context_investigation, response
|
||||
)
|
||||
|
||||
|
||||
self.set_scene_states(context_investigation=context_investigation)
|
||||
self.set_context_states(context_investigation=context_investigation)
|
||||
|
||||
|
||||
|
||||
async def on_inject_context_investigation(self, emission:ConversationAgentEmission | NarratorAgentEmission):
|
||||
|
||||
async def on_inject_context_investigation(
|
||||
self, emission: ConversationAgentEmission | NarratorAgentEmission
|
||||
):
|
||||
"""
|
||||
Injects context investigation into the conversation.
|
||||
"""
|
||||
|
||||
|
||||
if not self.context_investigation_enabled:
|
||||
return
|
||||
|
||||
|
||||
context_investigation = self.get_scene_state("context_investigation")
|
||||
log.debug("summarizer.on_inject_context_investigation", context_investigation=context_investigation, emission=emission)
|
||||
log.debug(
|
||||
"summarizer.on_inject_context_investigation",
|
||||
context_investigation=context_investigation,
|
||||
emission=emission,
|
||||
)
|
||||
if context_investigation:
|
||||
emission.dynamic_instructions.append(
|
||||
DynamicInstruction(
|
||||
title="Context Investigation",
|
||||
content=context_investigation
|
||||
title="Context Investigation", content=context_investigation
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# methods
|
||||
|
||||
|
||||
@set_processing
|
||||
async def suggest_context_investigations(
|
||||
self,
|
||||
analysis:str,
|
||||
analysis_type:str,
|
||||
analysis_sub_type:str="",
|
||||
max_calls:int=3,
|
||||
character:"Character"=None,
|
||||
analysis: str,
|
||||
analysis_type: str,
|
||||
analysis_sub_type: str = "",
|
||||
max_calls: int = 3,
|
||||
character: "Character" = None,
|
||||
) -> str:
|
||||
|
||||
template_vars = {
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": self.scene,
|
||||
@@ -192,111 +194,119 @@ class ContextInvestigationMixin:
|
||||
"analysis_type": analysis_type,
|
||||
"analysis_sub_type": analysis_sub_type,
|
||||
}
|
||||
|
||||
|
||||
if not analysis_sub_type:
|
||||
template = f"summarizer.suggest-context-investigations-for-{analysis_type}"
|
||||
else:
|
||||
template = f"summarizer.suggest-context-investigations-for-{analysis_type}-{analysis_sub_type}"
|
||||
|
||||
log.debug("summarizer.suggest_context_investigations", template=template, template_vars=template_vars)
|
||||
|
||||
|
||||
log.debug(
|
||||
"summarizer.suggest_context_investigations",
|
||||
template=template,
|
||||
template_vars=template_vars,
|
||||
)
|
||||
|
||||
response = await Prompt.request(
|
||||
template,
|
||||
self.client,
|
||||
"investigate_512",
|
||||
vars=template_vars,
|
||||
)
|
||||
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def investigate_context(
|
||||
self,
|
||||
layer:int,
|
||||
index:int,
|
||||
query:str,
|
||||
analysis:str="",
|
||||
max_calls:int=3,
|
||||
pad_entries:int=5,
|
||||
self,
|
||||
layer: int,
|
||||
index: int,
|
||||
query: str,
|
||||
analysis: str = "",
|
||||
max_calls: int = 3,
|
||||
pad_entries: int = 5,
|
||||
) -> str:
|
||||
"""
|
||||
Processes a context investigation.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
- layer: The layer to investigate
|
||||
- index: The index in the layer to investigate
|
||||
- query: The query to investigate
|
||||
- analysis: Scene analysis text
|
||||
- pad_entries: if > 0 will pad the entries with the given number of entries before and after the start and end index
|
||||
"""
|
||||
|
||||
log.debug("summarizer.investigate_context", layer=layer, index=index, query=query)
|
||||
|
||||
log.debug(
|
||||
"summarizer.investigate_context", layer=layer, index=index, query=query
|
||||
)
|
||||
entry = self.scene.layered_history[layer][index]
|
||||
|
||||
|
||||
layer_to_investigate = layer - 1
|
||||
|
||||
|
||||
start = max(entry["start"] - pad_entries, 0)
|
||||
end = entry["end"] + pad_entries + 1
|
||||
|
||||
|
||||
if layer_to_investigate == -1:
|
||||
entries = self.scene.archived_history[start:end]
|
||||
else:
|
||||
entries = self.scene.layered_history[layer_to_investigate][start:end]
|
||||
|
||||
async def answer(query:str, instructions:str) -> str:
|
||||
log.debug("Answering context investigation", query=query, instructions=answer)
|
||||
|
||||
|
||||
async def answer(query: str, instructions: str) -> str:
|
||||
log.debug(
|
||||
"Answering context investigation", query=query, instructions=answer
|
||||
)
|
||||
|
||||
world_state = get_agent("world_state")
|
||||
|
||||
|
||||
return await world_state.analyze_history_and_follow_instructions(
|
||||
entries,
|
||||
f"{query}\n{instructions}",
|
||||
analysis=analysis,
|
||||
response_length=self.context_investigation_answer_length
|
||||
response_length=self.context_investigation_answer_length,
|
||||
)
|
||||
|
||||
|
||||
async def investigate_context(chapter_number:str, query:str) -> str:
|
||||
async def investigate_context(chapter_number: str, query: str) -> str:
|
||||
# look for \d.\d in the chapter number, extract as layer and index
|
||||
match = re.match(r"(\d+)\.(\d+)", chapter_number)
|
||||
if not match:
|
||||
log.error("summarizer.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
|
||||
log.error(
|
||||
"summarizer.investigate_context",
|
||||
error="Invalid chapter number",
|
||||
chapter_number=chapter_number,
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
layer = int(match.group(1))
|
||||
index = int(match.group(2))
|
||||
|
||||
return await self.investigate_context(layer-1, index-1, query, analysis=analysis, max_calls=max_calls)
|
||||
|
||||
|
||||
return await self.investigate_context(
|
||||
layer - 1, index - 1, query, analysis=analysis, max_calls=max_calls
|
||||
)
|
||||
|
||||
async def abort():
|
||||
log.debug("Aborting context investigation")
|
||||
|
||||
|
||||
focal_handler: focal.Focal = focal.Focal(
|
||||
self.client,
|
||||
callbacks=[
|
||||
focal.Callback(
|
||||
name="investigate_context",
|
||||
arguments = [
|
||||
arguments=[
|
||||
focal.Argument(name="chapter_number", type="str"),
|
||||
focal.Argument(name="query", type="str")
|
||||
focal.Argument(name="query", type="str"),
|
||||
],
|
||||
fn=investigate_context
|
||||
fn=investigate_context,
|
||||
),
|
||||
focal.Callback(
|
||||
name="answer",
|
||||
arguments = [
|
||||
arguments=[
|
||||
focal.Argument(name="instructions", type="str"),
|
||||
focal.Argument(name="query", type="str")
|
||||
focal.Argument(name="query", type="str"),
|
||||
],
|
||||
fn=answer
|
||||
fn=answer,
|
||||
),
|
||||
focal.Callback(
|
||||
name="abort",
|
||||
fn=abort
|
||||
)
|
||||
focal.Callback(name="abort", fn=abort),
|
||||
],
|
||||
max_calls=max_calls,
|
||||
scene=self.scene,
|
||||
@@ -307,84 +317,86 @@ class ContextInvestigationMixin:
|
||||
entries=entries,
|
||||
analysis=analysis,
|
||||
)
|
||||
|
||||
|
||||
await focal_handler.request(
|
||||
"summarizer.investigate-context",
|
||||
)
|
||||
|
||||
|
||||
log.debug("summarizer.investigate_context", calls=focal_handler.state.calls)
|
||||
|
||||
return focal_handler.state.calls
|
||||
|
||||
return focal_handler.state.calls
|
||||
|
||||
@set_processing
|
||||
async def request_context_investigations(
|
||||
self,
|
||||
analysis:str,
|
||||
max_calls:int=3,
|
||||
self,
|
||||
analysis: str,
|
||||
max_calls: int = 3,
|
||||
) -> list[focal.Call]:
|
||||
|
||||
"""
|
||||
Requests context investigations for the given analysis.
|
||||
"""
|
||||
|
||||
|
||||
async def abort():
|
||||
log.debug("Aborting context investigations")
|
||||
|
||||
async def investigate_context(chapter_number:str, query:str) -> str:
|
||||
|
||||
async def investigate_context(chapter_number: str, query: str) -> str:
|
||||
# look for \d.\d in the chapter number, extract as layer and index
|
||||
match = re.match(r"(\d+)\.(\d+)", chapter_number)
|
||||
if not match:
|
||||
log.error("summarizer.request_context_investigations.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
|
||||
log.error(
|
||||
"summarizer.request_context_investigations.investigate_context",
|
||||
error="Invalid chapter number",
|
||||
chapter_number=chapter_number,
|
||||
)
|
||||
return ""
|
||||
|
||||
layer = int(match.group(1))
|
||||
index = int(match.group(2))
|
||||
|
||||
|
||||
num_layers = len(self.scene.layered_history)
|
||||
|
||||
return await self.investigate_context(num_layers - layer, index-1, query, analysis, max_calls=max_calls)
|
||||
|
||||
|
||||
return await self.investigate_context(
|
||||
num_layers - layer, index - 1, query, analysis, max_calls=max_calls
|
||||
)
|
||||
|
||||
focal_handler: focal.Focal = focal.Focal(
|
||||
self.client,
|
||||
callbacks=[
|
||||
focal.Callback(
|
||||
name="investigate_context",
|
||||
arguments = [
|
||||
arguments=[
|
||||
focal.Argument(name="chapter_number", type="str"),
|
||||
focal.Argument(name="query", type="str")
|
||||
focal.Argument(name="query", type="str"),
|
||||
],
|
||||
fn=investigate_context
|
||||
fn=investigate_context,
|
||||
),
|
||||
focal.Callback(
|
||||
name="abort",
|
||||
fn=abort
|
||||
)
|
||||
focal.Callback(name="abort", fn=abort),
|
||||
],
|
||||
max_calls=max_calls,
|
||||
scene=self.scene,
|
||||
text=analysis
|
||||
text=analysis,
|
||||
)
|
||||
|
||||
|
||||
await focal_handler.request(
|
||||
"summarizer.request-context-investigation",
|
||||
)
|
||||
|
||||
log.debug("summarizer.request_context_investigations", calls=focal_handler.state.calls)
|
||||
|
||||
return focal.collect_calls(
|
||||
focal_handler.state.calls,
|
||||
nested=True,
|
||||
filter=lambda c: c.name == "answer"
|
||||
|
||||
log.debug(
|
||||
"summarizer.request_context_investigations", calls=focal_handler.state.calls
|
||||
)
|
||||
|
||||
# return focal_handler.state.calls
|
||||
|
||||
|
||||
return focal.collect_calls(
|
||||
focal_handler.state.calls, nested=True, filter=lambda c: c.name == "answer"
|
||||
)
|
||||
|
||||
# return focal_handler.state.calls
|
||||
|
||||
@set_processing
|
||||
async def update_context_investigation(
|
||||
self,
|
||||
current_context_investigation:str,
|
||||
new_context_investigation:str,
|
||||
analysis:str,
|
||||
current_context_investigation: str,
|
||||
new_context_investigation: str,
|
||||
analysis: str,
|
||||
):
|
||||
response = await Prompt.request(
|
||||
"summarizer.update-context-investigation",
|
||||
@@ -398,5 +410,5 @@ class ContextInvestigationMixin:
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
return response.strip()
|
||||
|
||||
return response.strip()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import structlog
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
from talemate.agents.base import (
|
||||
set_processing,
|
||||
AgentAction,
|
||||
@@ -24,11 +24,12 @@ talemate.emit.async_signals.register(
|
||||
"agent.summarization.layered_history.finalize",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LayeredHistoryFinalizeEmission(AgentEmission):
|
||||
entry: LayeredArchiveEntry | None = None
|
||||
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])
|
||||
|
||||
|
||||
@property
|
||||
def response(self) -> str | None:
|
||||
return self.entry.text if self.entry else None
|
||||
@@ -38,22 +39,23 @@ class LayeredHistoryFinalizeEmission(AgentEmission):
|
||||
if self.entry:
|
||||
self.entry.text = value
|
||||
|
||||
|
||||
class SummaryLongerThanOriginalError(ValueError):
|
||||
def __init__(self, original_length:int, summarized_length:int):
|
||||
def __init__(self, original_length: int, summarized_length: int):
|
||||
self.original_length = original_length
|
||||
self.summarized_length = summarized_length
|
||||
super().__init__(f"Summarized text is longer than original text: {summarized_length} > {original_length}")
|
||||
super().__init__(
|
||||
f"Summarized text is longer than original text: {summarized_length} > {original_length}"
|
||||
)
|
||||
|
||||
|
||||
class LayeredHistoryMixin:
|
||||
|
||||
"""
|
||||
Summarizer agent mixin that provides functionality for maintaining a layered history.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
|
||||
actions["layered_history"] = AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
@@ -80,7 +82,7 @@ class LayeredHistoryMixin:
|
||||
max=5,
|
||||
step=1,
|
||||
value=3,
|
||||
),
|
||||
),
|
||||
"max_process_tokens": AgentActionConfig(
|
||||
type="number",
|
||||
label="Maximum tokens to process",
|
||||
@@ -116,69 +118,71 @@ class LayeredHistoryMixin:
|
||||
{"label": "Medium (512)", "value": "512"},
|
||||
{"label": "Long (1024)", "value": "1024"},
|
||||
{"label": "Exhaustive (2048)", "value": "2048"},
|
||||
]
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# config property helpers
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_enabled(self):
|
||||
return self.actions["layered_history"].enabled
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_threshold(self):
|
||||
return self.actions["layered_history"].config["threshold"].value
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_max_process_tokens(self):
|
||||
return self.actions["layered_history"].config["max_process_tokens"].value
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_max_layers(self):
|
||||
return self.actions["layered_history"].config["max_layers"].value
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_chunk_size(self) -> int:
|
||||
return self.actions["layered_history"].config["chunk_size"].value
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_analyze_chunks(self) -> bool:
|
||||
return self.actions["layered_history"].config["analyze_chunks"].value
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_response_length(self) -> int:
|
||||
return int(self.actions["layered_history"].config["response_length"].value)
|
||||
|
||||
|
||||
@property
|
||||
def layered_history_available(self):
|
||||
return self.layered_history_enabled and self.scene.layered_history and self.scene.layered_history[0]
|
||||
|
||||
|
||||
return (
|
||||
self.layered_history_enabled
|
||||
and self.scene.layered_history
|
||||
and self.scene.layered_history[0]
|
||||
)
|
||||
|
||||
# signals
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.summarization.after_build_archive").connect(
|
||||
self.on_after_build_archive
|
||||
)
|
||||
|
||||
|
||||
async def on_after_build_archive(self, emission:"BuildArchiveEmission"):
|
||||
talemate.emit.async_signals.get(
|
||||
"agent.summarization.after_build_archive"
|
||||
).connect(self.on_after_build_archive)
|
||||
|
||||
async def on_after_build_archive(self, emission: "BuildArchiveEmission"):
|
||||
"""
|
||||
After the archive has been built, we will update the layered history.
|
||||
"""
|
||||
|
||||
|
||||
if self.layered_history_enabled:
|
||||
await self.summarize_to_layered_history(
|
||||
generation_options=emission.generation_options
|
||||
)
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
async def _lh_split_and_summarize_chunks(
|
||||
self,
|
||||
self,
|
||||
chunks: list[dict],
|
||||
extra_context: str,
|
||||
generation_options: GenerationOptions | None = None,
|
||||
@@ -189,21 +193,29 @@ class LayeredHistoryMixin:
|
||||
"""
|
||||
summaries = []
|
||||
current_chunk = chunks.copy()
|
||||
|
||||
|
||||
while current_chunk:
|
||||
partial_chunk = []
|
||||
max_process_tokens = self.layered_history_max_process_tokens
|
||||
|
||||
|
||||
# Build partial chunk up to max_process_tokens
|
||||
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
|
||||
while (
|
||||
current_chunk
|
||||
and util.count_tokens(
|
||||
"\n\n".join(chunk["text"] for chunk in partial_chunk)
|
||||
)
|
||||
< max_process_tokens
|
||||
):
|
||||
partial_chunk.append(current_chunk.pop(0))
|
||||
|
||||
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
|
||||
|
||||
log.debug("_split_and_summarize_chunks",
|
||||
tokens_in_chunk=util.count_tokens(text_to_summarize),
|
||||
max_process_tokens=max_process_tokens)
|
||||
|
||||
|
||||
text_to_summarize = "\n\n".join(chunk["text"] for chunk in partial_chunk)
|
||||
|
||||
log.debug(
|
||||
"_split_and_summarize_chunks",
|
||||
tokens_in_chunk=util.count_tokens(text_to_summarize),
|
||||
max_process_tokens=max_process_tokens,
|
||||
)
|
||||
|
||||
summary_text = await self.summarize_events(
|
||||
text_to_summarize,
|
||||
extra_context=extra_context + "\n\n".join(summaries),
|
||||
@@ -213,9 +225,9 @@ class LayeredHistoryMixin:
|
||||
chunk_size=self.layered_history_chunk_size,
|
||||
)
|
||||
summaries.append(summary_text)
|
||||
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
|
||||
"""
|
||||
Validates that the summarized text is not longer than the original.
|
||||
@@ -224,17 +236,19 @@ class LayeredHistoryMixin:
|
||||
summarized_length = util.count_tokens(summaries)
|
||||
if summarized_length > original_length:
|
||||
raise SummaryLongerThanOriginalError(original_length, summarized_length)
|
||||
|
||||
log.debug("_validate_summary_length",
|
||||
original_length=original_length,
|
||||
summarized_length=summarized_length)
|
||||
|
||||
|
||||
log.debug(
|
||||
"_validate_summary_length",
|
||||
original_length=original_length,
|
||||
summarized_length=summarized_length,
|
||||
)
|
||||
|
||||
def _lh_build_extra_context(self, layer_index: int) -> str:
|
||||
"""
|
||||
Builds extra context from compiled layered history for the given layer.
|
||||
"""
|
||||
return "\n\n".join(self.compile_layered_history(layer_index))
|
||||
|
||||
|
||||
def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
|
||||
"""
|
||||
Extracts timestamps from a chunk of entries.
|
||||
@@ -242,144 +256,156 @@ class LayeredHistoryMixin:
|
||||
"""
|
||||
if not chunk:
|
||||
return "PT1S", "PT1S", "PT1S"
|
||||
|
||||
ts = chunk[0].get('ts', 'PT1S')
|
||||
ts_start = chunk[0].get('ts_start', ts)
|
||||
ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts))
|
||||
|
||||
|
||||
ts = chunk[0].get("ts", "PT1S")
|
||||
ts_start = chunk[0].get("ts_start", ts)
|
||||
ts_end = chunk[-1].get("ts_end", chunk[-1].get("ts", ts))
|
||||
|
||||
return ts, ts_start, ts_end
|
||||
|
||||
|
||||
async def _lh_finalize_archive_entry(
|
||||
self,
|
||||
self,
|
||||
entry: LayeredArchiveEntry,
|
||||
summarization_history: list[str] | None = None,
|
||||
) -> LayeredArchiveEntry:
|
||||
"""
|
||||
Finalizes an archive entry by summarizing it and adding it to the layered history.
|
||||
"""
|
||||
|
||||
|
||||
emission = LayeredHistoryFinalizeEmission(
|
||||
agent=self,
|
||||
entry=entry,
|
||||
summarization_history=summarization_history,
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)
|
||||
|
||||
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.summarization.layered_history.finalize"
|
||||
).send(emission)
|
||||
|
||||
return emission.entry
|
||||
|
||||
|
||||
# methods
|
||||
|
||||
def compile_layered_history(
|
||||
self,
|
||||
for_layer_index:int = None,
|
||||
as_objects:bool=False,
|
||||
include_base_layer:bool=False,
|
||||
max:int = None,
|
||||
self,
|
||||
for_layer_index: int = None,
|
||||
as_objects: bool = False,
|
||||
include_base_layer: bool = False,
|
||||
max: int = None,
|
||||
base_layer_end_id: str | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Starts at the last layer and compiles the layered history into a single
|
||||
list of events.
|
||||
|
||||
|
||||
We are iterating backwards, so the last layer will be the most granular.
|
||||
|
||||
|
||||
Each preceeding layer starts from the end of the the next layer.
|
||||
"""
|
||||
|
||||
|
||||
layered_history = self.scene.layered_history
|
||||
compiled = []
|
||||
next_layer_start = None
|
||||
|
||||
|
||||
len_layered_history = len(layered_history)
|
||||
|
||||
|
||||
for i in range(len_layered_history - 1, -1, -1):
|
||||
|
||||
if for_layer_index is not None:
|
||||
if i < for_layer_index:
|
||||
break
|
||||
|
||||
|
||||
log.debug("compilelayered history", i=i, next_layer_start=next_layer_start)
|
||||
|
||||
|
||||
if not layered_history[i]:
|
||||
continue
|
||||
|
||||
|
||||
entry_num = 1
|
||||
|
||||
for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]:
|
||||
|
||||
|
||||
for layered_history_entry in layered_history[i][
|
||||
next_layer_start if next_layer_start is not None else 0 :
|
||||
]:
|
||||
if base_layer_end_id:
|
||||
contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry(
|
||||
index=0,
|
||||
layer=i+1,
|
||||
**layered_history_entry)
|
||||
contained = entry_contained(
|
||||
self.scene,
|
||||
base_layer_end_id,
|
||||
HistoryEntry(index=0, layer=i + 1, **layered_history_entry),
|
||||
)
|
||||
if contained:
|
||||
log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id)
|
||||
log.debug(
|
||||
"compile_layered_history",
|
||||
contained=True,
|
||||
base_layer_end_id=base_layer_end_id,
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
text = f"{layered_history_entry['text']}"
|
||||
|
||||
if for_layer_index == i and max is not None and max <= layered_history_entry["end"]:
|
||||
|
||||
if (
|
||||
for_layer_index == i
|
||||
and max is not None
|
||||
and max <= layered_history_entry["end"]
|
||||
):
|
||||
break
|
||||
|
||||
|
||||
if as_objects:
|
||||
compiled.append({
|
||||
"text": text,
|
||||
"start": layered_history_entry["start"],
|
||||
"end": layered_history_entry["end"],
|
||||
"layer": i,
|
||||
"layer_r": len_layered_history - i,
|
||||
"ts_start": layered_history_entry["ts_start"],
|
||||
"index": entry_num,
|
||||
})
|
||||
compiled.append(
|
||||
{
|
||||
"text": text,
|
||||
"start": layered_history_entry["start"],
|
||||
"end": layered_history_entry["end"],
|
||||
"layer": i,
|
||||
"layer_r": len_layered_history - i,
|
||||
"ts_start": layered_history_entry["ts_start"],
|
||||
"index": entry_num,
|
||||
}
|
||||
)
|
||||
entry_num += 1
|
||||
else:
|
||||
compiled.append(text)
|
||||
|
||||
|
||||
next_layer_start = layered_history_entry["end"] + 1
|
||||
|
||||
|
||||
if i == 0 and include_base_layer:
|
||||
# we are are at layered history layer zero and inclusion of base layer (archived history) is requested
|
||||
# so we append the base layer to the compiled list, starting from
|
||||
# index `next_layer_start`
|
||||
|
||||
|
||||
entry_num = 1
|
||||
|
||||
for ah in self.scene.archived_history[next_layer_start or 0:]:
|
||||
|
||||
|
||||
for ah in self.scene.archived_history[next_layer_start or 0 :]:
|
||||
if base_layer_end_id and ah["id"] == base_layer_end_id:
|
||||
break
|
||||
|
||||
|
||||
text = f"{ah['text']}"
|
||||
if as_objects:
|
||||
compiled.append({
|
||||
"text": text,
|
||||
"start": ah["start"],
|
||||
"end": ah["end"],
|
||||
"layer": -1,
|
||||
"layer_r": 1,
|
||||
"ts": ah["ts"],
|
||||
"index": entry_num,
|
||||
})
|
||||
compiled.append(
|
||||
{
|
||||
"text": text,
|
||||
"start": ah["start"],
|
||||
"end": ah["end"],
|
||||
"layer": -1,
|
||||
"layer_r": 1,
|
||||
"ts": ah["ts"],
|
||||
"index": entry_num,
|
||||
}
|
||||
)
|
||||
entry_num += 1
|
||||
else:
|
||||
compiled.append(text)
|
||||
|
||||
|
||||
return compiled
|
||||
|
||||
|
||||
@set_processing
|
||||
async def summarize_to_layered_history(self, generation_options: GenerationOptions | None = None):
|
||||
|
||||
async def summarize_to_layered_history(
|
||||
self, generation_options: GenerationOptions | None = None
|
||||
):
|
||||
"""
|
||||
The layered history is a summarized archive with dynamic layers that
|
||||
will get less and less granular as the scene progresses.
|
||||
|
||||
|
||||
The most granular is still self.scene.archived_history, which holds
|
||||
all the base layer summarizations.
|
||||
|
||||
|
||||
self.scene.layered_history = [
|
||||
# first layer after archived_history
|
||||
[
|
||||
@@ -391,7 +417,7 @@ class LayeredHistoryMixin:
|
||||
},
|
||||
...
|
||||
],
|
||||
|
||||
|
||||
# second layer
|
||||
[
|
||||
{
|
||||
@@ -402,29 +428,29 @@ class LayeredHistoryMixin:
|
||||
},
|
||||
...
|
||||
],
|
||||
|
||||
|
||||
# additional layers
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
The same token threshold as for the base layer will be used for the
|
||||
layers.
|
||||
|
||||
|
||||
The same summarization function will be used for the layers.
|
||||
|
||||
|
||||
The next level layer will be generated automatically when the token
|
||||
threshold is reached.
|
||||
"""
|
||||
|
||||
|
||||
if not self.scene.archived_history:
|
||||
return # No base layer summaries to work with
|
||||
|
||||
|
||||
token_threshold = self.layered_history_threshold
|
||||
max_layers = self.layered_history_max_layers
|
||||
|
||||
if not hasattr(self.scene, 'layered_history'):
|
||||
if not hasattr(self.scene, "layered_history"):
|
||||
self.scene.layered_history = []
|
||||
|
||||
|
||||
layered_history = self.scene.layered_history
|
||||
|
||||
async def summarize_layer(source_layer, next_layer_index, start_from) -> bool:
|
||||
@@ -432,147 +458,192 @@ class LayeredHistoryMixin:
|
||||
current_tokens = 0
|
||||
start_index = start_from
|
||||
noop = True
|
||||
|
||||
total_tokens_in_previous_layer = util.count_tokens([
|
||||
entry['text'] for entry in source_layer
|
||||
])
|
||||
|
||||
total_tokens_in_previous_layer = util.count_tokens(
|
||||
[entry["text"] for entry in source_layer]
|
||||
)
|
||||
estimated_entries = total_tokens_in_previous_layer // token_threshold
|
||||
|
||||
for i in range(start_from, len(source_layer)):
|
||||
entry = source_layer[i]
|
||||
entry_tokens = util.count_tokens(entry['text'])
|
||||
|
||||
log.debug("summarize_to_layered_history", entry=entry["text"][:100]+"...", tokens=entry_tokens, current_layer=next_layer_index-1)
|
||||
|
||||
entry_tokens = util.count_tokens(entry["text"])
|
||||
|
||||
log.debug(
|
||||
"summarize_to_layered_history",
|
||||
entry=entry["text"][:100] + "...",
|
||||
tokens=entry_tokens,
|
||||
current_layer=next_layer_index - 1,
|
||||
)
|
||||
|
||||
if current_tokens + entry_tokens > token_threshold:
|
||||
if current_chunk:
|
||||
|
||||
try:
|
||||
# check if the next layer exists
|
||||
next_layer = layered_history[next_layer_index]
|
||||
except IndexError:
|
||||
# create the next layer
|
||||
layered_history.append([])
|
||||
log.debug("summarize_to_layered_history", created_layer=next_layer_index)
|
||||
log.debug(
|
||||
"summarize_to_layered_history",
|
||||
created_layer=next_layer_index,
|
||||
)
|
||||
next_layer = layered_history[next_layer_index]
|
||||
|
||||
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
|
||||
|
||||
|
||||
ts, ts_start, ts_end = self._lh_extract_timestamps(
|
||||
current_chunk
|
||||
)
|
||||
|
||||
extra_context = self._lh_build_extra_context(next_layer_index)
|
||||
|
||||
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
|
||||
text_length = util.count_tokens(
|
||||
"\n\n".join(chunk["text"] for chunk in current_chunk)
|
||||
)
|
||||
|
||||
num_entries_in_layer = len(layered_history[next_layer_index])
|
||||
|
||||
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})
|
||||
|
||||
emit(
|
||||
"status",
|
||||
status="busy",
|
||||
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}",
|
||||
data={"cancellable": True},
|
||||
)
|
||||
|
||||
summaries = await self._lh_split_and_summarize_chunks(
|
||||
current_chunk,
|
||||
extra_context,
|
||||
generation_options=generation_options,
|
||||
)
|
||||
noop = False
|
||||
|
||||
|
||||
# validate summary length
|
||||
self._lh_validate_summary_length(summaries, text_length)
|
||||
|
||||
next_layer.append(LayeredArchiveEntry(**{
|
||||
"start": start_index,
|
||||
"end": i,
|
||||
"ts": ts,
|
||||
"ts_start": ts_start,
|
||||
"ts_end": ts_end,
|
||||
"text": "\n\n".join(summaries),
|
||||
}).model_dump(exclude_none=True))
|
||||
|
||||
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}")
|
||||
|
||||
|
||||
next_layer.append(
|
||||
LayeredArchiveEntry(
|
||||
**{
|
||||
"start": start_index,
|
||||
"end": i,
|
||||
"ts": ts,
|
||||
"ts_start": ts_start,
|
||||
"ts_end": ts_end,
|
||||
"text": "\n\n".join(summaries),
|
||||
}
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
emit(
|
||||
"status",
|
||||
status="busy",
|
||||
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer + 1} / {estimated_entries}",
|
||||
)
|
||||
|
||||
current_chunk = []
|
||||
current_tokens = 0
|
||||
start_index = i
|
||||
|
||||
current_chunk.append(entry)
|
||||
current_tokens += entry_tokens
|
||||
|
||||
log.debug("summarize_to_layered_history", tokens=current_tokens, threshold=token_threshold, next_layer=next_layer_index)
|
||||
|
||||
|
||||
log.debug(
|
||||
"summarize_to_layered_history",
|
||||
tokens=current_tokens,
|
||||
threshold=token_threshold,
|
||||
next_layer=next_layer_index,
|
||||
)
|
||||
|
||||
return not noop
|
||||
|
||||
|
||||
|
||||
# First layer (always the base layer)
|
||||
has_been_updated = False
|
||||
|
||||
|
||||
try:
|
||||
|
||||
if not layered_history:
|
||||
layered_history.append([])
|
||||
log.debug("summarize_to_layered_history", layer="base", new_layer=True)
|
||||
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
|
||||
has_been_updated = await summarize_layer(
|
||||
self.scene.archived_history, 0, 0
|
||||
)
|
||||
elif layered_history[0]:
|
||||
# determine starting point by checking for `end` in the last entry
|
||||
last_entry = layered_history[0][-1]
|
||||
end = last_entry["end"]
|
||||
log.debug("summarize_to_layered_history", layer="base", start=end)
|
||||
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end)
|
||||
has_been_updated = await summarize_layer(
|
||||
self.scene.archived_history, 0, end
|
||||
)
|
||||
else:
|
||||
log.debug("summarize_to_layered_history", layer="base", empty=True)
|
||||
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
|
||||
|
||||
has_been_updated = await summarize_layer(
|
||||
self.scene.archived_history, 0, 0
|
||||
)
|
||||
|
||||
except SummaryLongerThanOriginalError as exc:
|
||||
log.error("summarize_to_layered_history", error=exc, layer="base")
|
||||
emit("status", status="error", message="Layered history update failed.")
|
||||
return
|
||||
except GenerationCancelled as e:
|
||||
log.info("Generation cancelled, stopping rebuild of historical layered history")
|
||||
emit("status", message="Rebuilding of layered history cancelled", status="info")
|
||||
log.info(
|
||||
"Generation cancelled, stopping rebuild of historical layered history"
|
||||
)
|
||||
emit(
|
||||
"status",
|
||||
message="Rebuilding of layered history cancelled",
|
||||
status="info",
|
||||
)
|
||||
handle_generation_cancelled(e)
|
||||
return
|
||||
|
||||
|
||||
# process layers
|
||||
async def update_layers() -> bool:
|
||||
noop = True
|
||||
for index in range(0, len(layered_history)):
|
||||
|
||||
# check against max layers
|
||||
if index + 1 > max_layers:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# check if the next layer exists
|
||||
next_layer = layered_history[index + 1]
|
||||
except IndexError:
|
||||
next_layer = None
|
||||
|
||||
|
||||
end = next_layer[-1]["end"] if next_layer else 0
|
||||
|
||||
|
||||
log.debug("summarize_to_layered_history", layer=index, start=end)
|
||||
summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0)
|
||||
|
||||
summarized = await summarize_layer(
|
||||
layered_history[index], index + 1, end if end else 0
|
||||
)
|
||||
|
||||
if summarized:
|
||||
noop = False
|
||||
|
||||
|
||||
return not noop
|
||||
|
||||
|
||||
try:
|
||||
while await update_layers():
|
||||
has_been_updated = True
|
||||
if has_been_updated:
|
||||
emit("status", status="success", message="Layered history updated.")
|
||||
|
||||
|
||||
except SummaryLongerThanOriginalError as exc:
|
||||
log.error("summarize_to_layered_history", error=exc, layer="subsequent")
|
||||
emit("status", status="error", message="Layered history update failed.")
|
||||
return
|
||||
except GenerationCancelled as e:
|
||||
log.info("Generation cancelled, stopping rebuild of historical layered history")
|
||||
emit("status", message="Rebuilding of layered history cancelled", status="info")
|
||||
log.info(
|
||||
"Generation cancelled, stopping rebuild of historical layered history"
|
||||
)
|
||||
emit(
|
||||
"status",
|
||||
message="Rebuilding of layered history cancelled",
|
||||
status="info",
|
||||
)
|
||||
handle_generation_cancelled(e)
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def summarize_entries_to_layered_history(
|
||||
self,
|
||||
entries: list[dict],
|
||||
self,
|
||||
entries: list[dict],
|
||||
next_layer_index: int,
|
||||
start_index: int,
|
||||
end_index: int,
|
||||
@@ -580,11 +651,11 @@ class LayeredHistoryMixin:
|
||||
) -> list[LayeredArchiveEntry]:
|
||||
"""
|
||||
Summarizes a list of entries into layered history entries.
|
||||
|
||||
|
||||
This method is used for regenerating specific history entries by processing
|
||||
their source entries. It chunks the entries based on the token threshold and
|
||||
summarizes each chunk into a LayeredArchiveEntry.
|
||||
|
||||
|
||||
Args:
|
||||
entries: List of dictionaries containing the text entries to summarize.
|
||||
Each entry should have at least a 'text' field and optionally
|
||||
@@ -597,12 +668,12 @@ class LayeredHistoryMixin:
|
||||
correspond to.
|
||||
generation_options: Optional generation options to pass to the summarization
|
||||
process.
|
||||
|
||||
|
||||
Returns:
|
||||
List of LayeredArchiveEntry objects containing the summarized text along
|
||||
with timestamp and index information. Currently returns a list with a
|
||||
single entry, but the structure supports multiple entries if needed.
|
||||
|
||||
|
||||
Notes:
|
||||
- The method respects the layered_history_threshold for chunking
|
||||
- Uses helper methods for timestamp extraction, context building, and
|
||||
@@ -611,63 +682,73 @@ class LayeredHistoryMixin:
|
||||
- The last entry is always included in the final chunk if it doesn't
|
||||
exceed the token threshold
|
||||
"""
|
||||
|
||||
|
||||
token_threshold = self.layered_history_threshold
|
||||
|
||||
|
||||
archive_entries = []
|
||||
summaries = []
|
||||
current_chunk = []
|
||||
current_tokens = 0
|
||||
|
||||
|
||||
ts = "PT1S"
|
||||
ts_start = "PT1S"
|
||||
ts_end = "PT1S"
|
||||
|
||||
|
||||
|
||||
for entry_index, entry in enumerate(entries):
|
||||
is_last_entry = entry_index == len(entries) - 1
|
||||
entry_tokens = util.count_tokens(entry['text'])
|
||||
|
||||
log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens)
|
||||
|
||||
entry_tokens = util.count_tokens(entry["text"])
|
||||
|
||||
log.debug(
|
||||
"summarize_entries_to_layered_history",
|
||||
entry=entry["text"][:100] + "...",
|
||||
entry_tokens=entry_tokens,
|
||||
current_layer=next_layer_index - 1,
|
||||
current_tokens=current_tokens,
|
||||
)
|
||||
|
||||
if current_tokens + entry_tokens > token_threshold or is_last_entry:
|
||||
|
||||
if is_last_entry and current_tokens + entry_tokens <= token_threshold:
|
||||
# if we are here because this is the last entry and adding it to
|
||||
# the current chunk would not exceed the token threshold, we will
|
||||
# add it to the current chunk
|
||||
current_chunk.append(entry)
|
||||
|
||||
|
||||
if current_chunk:
|
||||
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
|
||||
|
||||
extra_context = self._lh_build_extra_context(next_layer_index)
|
||||
|
||||
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
|
||||
text_length = util.count_tokens(
|
||||
"\n\n".join(chunk["text"] for chunk in current_chunk)
|
||||
)
|
||||
|
||||
summaries = await self._lh_split_and_summarize_chunks(
|
||||
current_chunk,
|
||||
extra_context,
|
||||
generation_options=generation_options,
|
||||
)
|
||||
|
||||
|
||||
# validate summary length
|
||||
self._lh_validate_summary_length(summaries, text_length)
|
||||
|
||||
archive_entry = LayeredArchiveEntry(**{
|
||||
"start": start_index,
|
||||
"end": end_index,
|
||||
"ts": ts,
|
||||
"ts_start": ts_start,
|
||||
"ts_end": ts_end,
|
||||
"text": "\n\n".join(summaries),
|
||||
})
|
||||
|
||||
archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n"))
|
||||
|
||||
|
||||
archive_entry = LayeredArchiveEntry(
|
||||
**{
|
||||
"start": start_index,
|
||||
"end": end_index,
|
||||
"ts": ts,
|
||||
"ts_start": ts_start,
|
||||
"ts_end": ts_end,
|
||||
"text": "\n\n".join(summaries),
|
||||
}
|
||||
)
|
||||
|
||||
archive_entry = await self._lh_finalize_archive_entry(
|
||||
archive_entry, extra_context.split("\n\n")
|
||||
)
|
||||
|
||||
archive_entries.append(archive_entry)
|
||||
|
||||
current_chunk.append(entry)
|
||||
current_tokens += entry_tokens
|
||||
|
||||
|
||||
return archive_entries
|
||||
|
||||
59
src/talemate/agents/summarize/tts_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import structlog
|
||||
from talemate.agents.base import (
|
||||
set_processing,
|
||||
)
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.status import set_loading
|
||||
from talemate.util.dialogue import separate_dialogue_from_exposition
|
||||
|
||||
log = structlog.get_logger("talemate.agents.summarize.tts_utils")
|
||||
|
||||
|
||||
class TTSUtilsMixin:
|
||||
"""
|
||||
Summarizer Mixin for text-to-speech utilities.
|
||||
"""
|
||||
|
||||
@set_loading("Preparing TTS context")
|
||||
@set_processing
|
||||
async def markup_context_for_tts(self, text: str) -> str:
|
||||
"""
|
||||
Markup the context for text-to-speech.
|
||||
"""
|
||||
|
||||
original_text = text
|
||||
|
||||
log.debug("Markup context for TTS", text=text)
|
||||
|
||||
# if there are no quotes in the text, there is nothing to separate
|
||||
if '"' not in text:
|
||||
return original_text
|
||||
|
||||
# here we separate dialogue from exposition because into
|
||||
# obvious segments. It seems to have a positive effect on some
|
||||
# LLMs returning the complete text.
|
||||
separate_chunks = separate_dialogue_from_exposition(text)
|
||||
|
||||
numbered_chunks = []
|
||||
for i, chunk in enumerate(separate_chunks):
|
||||
numbered_chunks.append(f"[{i + 1}] {chunk.text.strip()}")
|
||||
|
||||
text = "\n".join(numbered_chunks)
|
||||
|
||||
response = await Prompt.request(
|
||||
"summarizer.markup-context-for-tts",
|
||||
self.client,
|
||||
"investigate_1024",
|
||||
vars={
|
||||
"text": text,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": self.scene,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response = response.split("<MARKUP>")[1].split("</MARKUP>")[0].strip()
|
||||
return response
|
||||
except IndexError:
|
||||
log.error("Failed to extract markup from response", response=response)
|
||||
return original_text
|
||||
@@ -1,673 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
import nltk
|
||||
import pydantic
|
||||
import structlog
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
import talemate.config as config
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.events import GameLoopNewMessageEvent
|
||||
from talemate.scene_message import CharacterMessage, NarratorMessage
|
||||
|
||||
from .base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
from .registry import register
|
||||
|
||||
try:
|
||||
from TTS.api import TTS
|
||||
except ImportError:
|
||||
TTS = None
|
||||
|
||||
log = structlog.get_logger("talemate.agents.tts") #
|
||||
|
||||
if not TTS:
|
||||
# TTS installation is massive and requires a lot of dependencies
|
||||
# so we don't want to require it unless the user wants to use it
|
||||
log.info(
|
||||
"TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
|
||||
)
|
||||
|
||||
|
||||
def parse_chunks(text: str) -> list[str]:
|
||||
|
||||
"""
|
||||
Takes a string and splits it into chunks based on punctuation.
|
||||
|
||||
In case of an error it will return the original text as a single chunk and
|
||||
the error will be logged.
|
||||
"""
|
||||
|
||||
try:
|
||||
text = text.replace("...", "__ellipsis__")
|
||||
chunks = sent_tokenize(text)
|
||||
cleaned_chunks = []
|
||||
|
||||
for chunk in chunks:
|
||||
chunk = chunk.replace("*", "")
|
||||
if not chunk:
|
||||
continue
|
||||
cleaned_chunks.append(chunk)
|
||||
|
||||
for i, chunk in enumerate(cleaned_chunks):
|
||||
chunk = chunk.replace("__ellipsis__", "...")
|
||||
cleaned_chunks[i] = chunk
|
||||
|
||||
return cleaned_chunks
|
||||
except Exception as e:
|
||||
log.error("chunking error", error=e, text=text)
|
||||
return [text.replace("__ellipsis__", "...").replace("*", "")]
|
||||
|
||||
|
||||
def clean_quotes(chunk: str):
|
||||
# if there is an uneven number of quotes, remove the last one if its
|
||||
# at the end of the chunk. If its in the middle, add a quote to the end
|
||||
if chunk.count('"') % 2 == 1:
|
||||
if chunk.endswith('"'):
|
||||
chunk = chunk[:-1]
|
||||
else:
|
||||
chunk += '"'
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
|
||||
"""
|
||||
Will combine chunks split by punctuation into a single chunk until
|
||||
max chunk size is reached
|
||||
"""
|
||||
|
||||
joined_chunks = []
|
||||
|
||||
current_chunk = ""
|
||||
|
||||
for chunk in chunks:
|
||||
if len(current_chunk) + len(chunk) > chunk_size:
|
||||
joined_chunks.append(clean_quotes(current_chunk))
|
||||
current_chunk = ""
|
||||
|
||||
current_chunk += chunk
|
||||
|
||||
if current_chunk:
|
||||
joined_chunks.append(clean_quotes(current_chunk))
|
||||
return joined_chunks
|
||||
|
||||
|
||||
class Voice(pydantic.BaseModel):
|
||||
value: str
|
||||
label: str
|
||||
|
||||
|
||||
class VoiceLibrary(pydantic.BaseModel):
|
||||
api: str
|
||||
voices: list[Voice] = pydantic.Field(default_factory=list)
|
||||
last_synced: float = None
|
||||
|
||||
|
||||
@register()
|
||||
class TTSAgent(Agent):
|
||||
"""
|
||||
Text to speech agent
|
||||
"""
|
||||
|
||||
agent_type = "tts"
|
||||
verbose_name = "Voice"
|
||||
requires_llm_client = False
|
||||
essential = False
|
||||
|
||||
@classmethod
|
||||
def config_options(cls, agent=None):
|
||||
config_options = super().config_options(agent=agent)
|
||||
|
||||
if agent:
|
||||
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
|
||||
voice.model_dump() for voice in agent.list_voices_sync()
|
||||
]
|
||||
|
||||
return config_options
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.is_enabled = False #
|
||||
|
||||
try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
except LookupError:
|
||||
try:
|
||||
nltk.download("punkt", quiet=True)
|
||||
except Exception as e:
|
||||
log.error("nltk download error", error=e)
|
||||
except Exception as e:
|
||||
log.error("nltk find error", error=e)
|
||||
|
||||
self.voices = {
|
||||
"elevenlabs": VoiceLibrary(api="elevenlabs"),
|
||||
"tts": VoiceLibrary(api="tts"),
|
||||
"openai": VoiceLibrary(api="openai"),
|
||||
}
|
||||
self.config = config.load_config()
|
||||
self.playback_done_event = asyncio.Event()
|
||||
self.preselect_voice = None
|
||||
self.actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
label="Configure",
|
||||
description="TTS agent configuration",
|
||||
config={
|
||||
"api": AgentActionConfig(
|
||||
type="text",
|
||||
choices=[
|
||||
{"value": "tts", "label": "TTS (Local)"},
|
||||
{"value": "elevenlabs", "label": "Eleven Labs"},
|
||||
{"value": "openai", "label": "OpenAI"},
|
||||
],
|
||||
value="tts",
|
||||
label="API",
|
||||
description="Which TTS API to use",
|
||||
onchange="emit",
|
||||
),
|
||||
"voice_id": AgentActionConfig(
|
||||
type="text",
|
||||
value="default",
|
||||
label="Narrator Voice",
|
||||
description="Voice ID/Name to use for TTS",
|
||||
choices=[],
|
||||
),
|
||||
"generate_for_player": AgentActionConfig(
|
||||
type="bool",
|
||||
value=False,
|
||||
label="Generate for player",
|
||||
description="Generate audio for player messages",
|
||||
),
|
||||
"generate_for_npc": AgentActionConfig(
|
||||
type="bool",
|
||||
value=True,
|
||||
label="Generate for NPCs",
|
||||
description="Generate audio for NPC messages",
|
||||
),
|
||||
"generate_for_narration": AgentActionConfig(
|
||||
type="bool",
|
||||
value=True,
|
||||
label="Generate for narration",
|
||||
description="Generate audio for narration messages",
|
||||
),
|
||||
"generate_chunks": AgentActionConfig(
|
||||
type="bool",
|
||||
value=False,
|
||||
label="Split generation",
|
||||
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
|
||||
),
|
||||
},
|
||||
),
|
||||
"openai": AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
icon="mdi-server-outline",
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.api", value="openai"
|
||||
),
|
||||
label="OpenAI",
|
||||
config={
|
||||
"model": AgentActionConfig(
|
||||
type="text",
|
||||
value="tts-1",
|
||||
choices=[
|
||||
{"value": "tts-1", "label": "TTS 1"},
|
||||
{"value": "tts-1-hd", "label": "TTS 1 HD"},
|
||||
],
|
||||
label="Model",
|
||||
description="TTS model to use",
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
self.actions["_config"].model_dump()
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def not_ready_reason(self) -> str:
|
||||
"""
|
||||
Returns a string explaining why the agent is not ready
|
||||
"""
|
||||
|
||||
if self.ready:
|
||||
return ""
|
||||
|
||||
if self.api == "tts":
|
||||
if not TTS:
|
||||
return "TTS not installed"
|
||||
|
||||
elif self.requires_token and not self.token:
|
||||
return "No API token"
|
||||
|
||||
elif not self.default_voice_id:
|
||||
return "No voice selected"
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
|
||||
details = {
|
||||
"api": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value=self.api_label,
|
||||
description="The backend to use for TTS",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
if self.ready and self.enabled:
|
||||
details["voice"] = AgentDetail(
|
||||
icon="mdi-account-voice",
|
||||
value=self.voice_id_to_label(self.default_voice_id) or "",
|
||||
description="The voice to use for TTS",
|
||||
color="info",
|
||||
).model_dump()
|
||||
elif self.enabled:
|
||||
details["error"] = AgentDetail(
|
||||
icon="mdi-alert",
|
||||
value=self.not_ready_reason,
|
||||
description=self.not_ready_reason,
|
||||
color="error",
|
||||
).model_dump()
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def api(self):
|
||||
return self.actions["_config"].config["api"].value
|
||||
|
||||
@property
|
||||
def api_label(self):
|
||||
choices = self.actions["_config"].config["api"].choices
|
||||
api = self.api
|
||||
for choice in choices:
|
||||
if choice["value"] == api:
|
||||
return choice["label"]
|
||||
return api
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
api = self.api
|
||||
return self.config.get(api, {}).get("api_key")
|
||||
|
||||
@property
|
||||
def default_voice_id(self):
|
||||
return self.actions["_config"].config["voice_id"].value
|
||||
|
||||
@property
|
||||
def requires_token(self):
|
||||
return self.api != "tts"
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
if self.api == "tts":
|
||||
if not TTS:
|
||||
return False
|
||||
return True
|
||||
|
||||
return (not self.requires_token or self.token) and self.default_voice_id
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
if self.ready:
|
||||
if getattr(self, "processing_bg", 0) > 0:
|
||||
return "busy_bg" if not getattr(self, "processing", False) else "busy"
|
||||
return "active" if not getattr(self, "processing", False) else "busy"
|
||||
if self.requires_token and not self.token:
|
||||
return "error"
|
||||
if self.api == "tts":
|
||||
if not TTS:
|
||||
return "error"
|
||||
return "uninitialized"
|
||||
|
||||
@property
|
||||
def max_generation_length(self):
|
||||
if self.api == "elevenlabs":
|
||||
return 1024
|
||||
elif self.api == "coqui":
|
||||
return 250
|
||||
|
||||
return 250
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
try:
|
||||
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
|
||||
except KeyError:
|
||||
api = self.api
|
||||
|
||||
api_changed = api != self.api
|
||||
|
||||
# log.debug(
|
||||
# "apply_config",
|
||||
# api=api,
|
||||
# api_changed=api != self.api,
|
||||
# current_api=self.api,
|
||||
# args=args,
|
||||
# kwargs=kwargs,
|
||||
# )
|
||||
|
||||
try:
|
||||
self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
|
||||
"value"
|
||||
]
|
||||
except KeyError:
|
||||
self.preselect_voice = self.default_voice_id
|
||||
|
||||
await super().apply_config(*args, **kwargs)
|
||||
|
||||
if api_changed:
|
||||
try:
|
||||
self.actions["_config"].config["voice_id"].value = (
|
||||
self.voices[api].voices[0].value
|
||||
)
|
||||
except IndexError:
|
||||
self.actions["_config"].config["voice_id"].value = ""
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop_new_message").connect(
|
||||
self.on_game_loop_new_message
|
||||
)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
instance.emit_agent_status(self.__class__, self)
|
||||
|
||||
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
"""
|
||||
|
||||
if not self.enabled or not self.ready:
|
||||
return
|
||||
|
||||
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
|
||||
return
|
||||
|
||||
if (
|
||||
isinstance(emission.message, NarratorMessage)
|
||||
and not self.actions["_config"].config["generate_for_narration"].value
|
||||
):
|
||||
return
|
||||
|
||||
if isinstance(emission.message, CharacterMessage):
|
||||
if (
|
||||
emission.message.source == "player"
|
||||
and not self.actions["_config"].config["generate_for_player"].value
|
||||
):
|
||||
return
|
||||
elif (
|
||||
emission.message.source == "ai"
|
||||
and not self.actions["_config"].config["generate_for_npc"].value
|
||||
):
|
||||
return
|
||||
|
||||
if isinstance(emission.message, CharacterMessage):
|
||||
character_prefix = emission.message.split(":", 1)[0]
|
||||
else:
|
||||
character_prefix = ""
|
||||
|
||||
log.info(
|
||||
"reactive tts", message=emission.message, character_prefix=character_prefix
|
||||
)
|
||||
|
||||
await self.generate(str(emission.message).replace(character_prefix + ": ", ""))
|
||||
|
||||
def voice(self, voice_id: str) -> Union[Voice, None]:
|
||||
for voice in self.voices[self.api].voices:
|
||||
if voice.value == voice_id:
|
||||
return voice
|
||||
return None
|
||||
|
||||
def voice_id_to_label(self, voice_id: str):
|
||||
for voice in self.voices[self.api].voices:
|
||||
if voice.value == voice_id:
|
||||
return voice.label
|
||||
return None
|
||||
|
||||
def list_voices_sync(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(self.list_voices())
|
||||
|
||||
async def list_voices(self):
|
||||
if self.requires_token and not self.token:
|
||||
return []
|
||||
|
||||
library = self.voices[self.api]
|
||||
|
||||
# TODO: allow re-syncing voices
|
||||
if library.last_synced:
|
||||
return library.voices
|
||||
|
||||
list_fn = getattr(self, f"_list_voices_{self.api}")
|
||||
log.info("Listing voices", api=self.api)
|
||||
|
||||
library.voices = await list_fn()
|
||||
library.last_synced = time.time()
|
||||
|
||||
if self.preselect_voice:
|
||||
if self.voice(self.preselect_voice):
|
||||
self.actions["_config"].config["voice_id"].value = self.preselect_voice
|
||||
self.preselect_voice = None
|
||||
|
||||
# if the current voice cannot be found, reset it
|
||||
if not self.voice(self.default_voice_id):
|
||||
self.actions["_config"].config["voice_id"].value = ""
|
||||
|
||||
# set loading to false
|
||||
return library.voices
|
||||
|
||||
@set_processing
|
||||
async def generate(self, text: str):
|
||||
if not self.enabled or not self.ready or not text:
|
||||
return
|
||||
|
||||
self.playback_done_event.set()
|
||||
|
||||
generate_fn = getattr(self, f"_generate_{self.api}")
|
||||
|
||||
if self.actions["_config"].config["generate_chunks"].value:
|
||||
chunks = parse_chunks(text)
|
||||
chunks = rejoin_chunks(chunks)
|
||||
else:
|
||||
chunks = parse_chunks(text)
|
||||
chunks = rejoin_chunks(chunks, chunk_size=self.max_generation_length)
|
||||
|
||||
# Start generating audio chunks in the background
|
||||
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
|
||||
await self.set_background_processing(generation_task)
|
||||
|
||||
# Wait for both tasks to complete
|
||||
# await asyncio.gather(generation_task)
|
||||
|
||||
async def generate_chunks(self, generate_fn, chunks):
|
||||
for chunk in chunks:
|
||||
chunk = chunk.replace("*", "").strip()
|
||||
log.info("Generating audio", api=self.api, chunk=chunk)
|
||||
audio_data = await generate_fn(chunk)
|
||||
self.play_audio(audio_data)
|
||||
|
||||
def play_audio(self, audio_data):
|
||||
# play audio through the python audio player
|
||||
# play(audio_data)
|
||||
|
||||
emit(
|
||||
"audio_queue",
|
||||
data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
|
||||
)
|
||||
|
||||
self.playback_done_event.set() # Signal that playback is finished
|
||||
|
||||
# LOCAL
|
||||
|
||||
async def _generate_tts(self, text: str) -> Union[bytes, None]:
|
||||
if not TTS:
|
||||
return
|
||||
|
||||
tts_config = self.config.get("tts", {})
|
||||
model = tts_config.get("model")
|
||||
device = tts_config.get("device", "cpu")
|
||||
|
||||
log.debug("tts local", model=model, device=device)
|
||||
|
||||
if not hasattr(self, "tts_instance"):
|
||||
self.tts_instance = TTS(model).to(device)
|
||||
|
||||
tts = self.tts_instance
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
voice = self.voice(self.default_voice_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.tts_to_file,
|
||||
text=text,
|
||||
speaker_wav=voice.value,
|
||||
language="en",
|
||||
file_path=file_path,
|
||||
),
|
||||
)
|
||||
# tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
async def _list_voices_tts(self) -> dict[str, str]:
|
||||
return [
|
||||
Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
|
||||
]
|
||||
|
||||
# ELEVENLABS
|
||||
|
||||
async def _generate_elevenlabs(
|
||||
self, text: str, chunk_size: int = 1024
|
||||
) -> Union[bytes, None]:
|
||||
api_key = self.token
|
||||
if not api_key:
|
||||
return
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.default_voice_id}"
|
||||
headers = {
|
||||
"Accept": "audio/mpeg",
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": api_key,
|
||||
}
|
||||
data = {
|
||||
"text": text,
|
||||
"model_id": self.config.get("elevenlabs", {}).get("model"),
|
||||
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
|
||||
}
|
||||
|
||||
response = await client.post(url, json=data, headers=headers, timeout=300)
|
||||
|
||||
if response.status_code == 200:
|
||||
bytes_io = io.BytesIO()
|
||||
for chunk in response.iter_bytes(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
bytes_io.write(chunk)
|
||||
|
||||
# Put the audio data in the queue for playback
|
||||
return bytes_io.getvalue()
|
||||
else:
|
||||
log.error(f"Error generating audio: {response.text}")
|
||||
|
||||
async def _list_voices_elevenlabs(self) -> dict[str, str]:
|
||||
url_voices = "https://api.elevenlabs.io/v1/voices"
|
||||
|
||||
voices = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"xi-api-key": self.token,
|
||||
}
|
||||
response = await client.get(
|
||||
url_voices, headers=headers, params={"per_page": 1000}
|
||||
)
|
||||
speakers = response.json()["voices"]
|
||||
voices.extend(
|
||||
[
|
||||
Voice(value=speaker["voice_id"], label=speaker["name"])
|
||||
for speaker in speakers
|
||||
]
|
||||
)
|
||||
|
||||
# sort by name
|
||||
voices.sort(key=lambda x: x.label)
|
||||
|
||||
return voices
|
||||
|
||||
# OPENAI
|
||||
|
||||
async def _generate_openai(self, text: str, chunk_size: int = 1024):
|
||||
|
||||
client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
|
||||
model = self.actions["openai"].config["model"].value
|
||||
|
||||
response = await client.audio.speech.create(
|
||||
model=model, voice=self.default_voice_id, input=text
|
||||
)
|
||||
|
||||
bytes_io = io.BytesIO()
|
||||
for chunk in response.iter_bytes(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
bytes_io.write(chunk)
|
||||
|
||||
# Put the audio data in the queue for playback
|
||||
return bytes_io.getvalue()
|
||||
|
||||
async def _list_voices_openai(self) -> dict[str, str]:
|
||||
return [
|
||||
Voice(value="alloy", label="Alloy"),
|
||||
Voice(value="echo", label="Echo"),
|
||||
Voice(value="fable", label="Fable"),
|
||||
Voice(value="onyx", label="Onyx"),
|
||||
Voice(value="nova", label="Nova"),
|
||||
Voice(value="shimmer", label="Shimmer"),
|
||||
]
|
||||
995
src/talemate/agents/tts/__init__.py
Normal file
@@ -0,0 +1,995 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import uuid
|
||||
from collections import deque
|
||||
|
||||
import structlog
|
||||
from nltk.tokenize import sent_tokenize
|
||||
|
||||
import talemate.util.dialogue as dialogue_utils
|
||||
import talemate.emit.async_signals as async_signals
|
||||
import talemate.instance as instance
|
||||
from talemate.ux.schema import Note
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopNewMessageEvent
|
||||
from talemate.scene_message import (
|
||||
CharacterMessage,
|
||||
NarratorMessage,
|
||||
ContextInvestigationMessage,
|
||||
)
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
AgentActionNote,
|
||||
set_processing,
|
||||
)
|
||||
from talemate.agents.registry import register
|
||||
|
||||
from .schema import (
|
||||
APIStatus,
|
||||
Voice,
|
||||
VoiceLibrary,
|
||||
GenerationContext,
|
||||
Chunk,
|
||||
VoiceGenerationEmission,
|
||||
)
|
||||
from .providers import provider
|
||||
|
||||
import talemate.agents.tts.voice_library as voice_library
|
||||
|
||||
from .elevenlabs import ElevenLabsMixin
|
||||
from .openai import OpenAIMixin
|
||||
from .google import GoogleMixin
|
||||
from .kokoro import KokoroMixin
|
||||
from .chatterbox import ChatterboxMixin
|
||||
from .websocket_handler import TTSWebsocketHandler
|
||||
from .f5tts import F5TTSMixin
|
||||
|
||||
import talemate.agents.tts.nodes as tts_nodes # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.character import Character, VoiceChangedEvent
|
||||
from talemate.agents.summarize import SummarizeAgent
|
||||
from talemate.game.engine.nodes.scene import SceneLoopEvent
|
||||
|
||||
log = structlog.get_logger("talemate.agents.tts")
|
||||
|
||||
HOT_SWAP_NOTIFICATION_TIME = 60
|
||||
|
||||
VOICE_LIBRARY_NOTE = "Voices are not managed here, but in the voice library which can be accessed through the Talemate application bar at the top. When disabling/enabling APIS, close and open this window to refresh the choices."
|
||||
|
||||
async_signals.register(
|
||||
"agent.tts.prepare.before",
|
||||
"agent.tts.prepare.after",
|
||||
"agent.tts.generate.before",
|
||||
"agent.tts.generate.after",
|
||||
)
|
||||
|
||||
|
||||
def parse_chunks(text: str) -> list[str]:
|
||||
"""
|
||||
Takes a string and splits it into chunks based on punctuation.
|
||||
|
||||
In case of an error it will return the original text as a single chunk and
|
||||
the error will be logged.
|
||||
"""
|
||||
|
||||
try:
|
||||
text = text.replace("*", "")
|
||||
|
||||
# ensure sentence terminators are before quotes
|
||||
# otherwise the beginning of dialog will bleed into narration
|
||||
text = re.sub(r'([^.?!]+) "', r'\1. "', text)
|
||||
|
||||
text = text.replace("...", "__ellipsis__")
|
||||
chunks = sent_tokenize(text)
|
||||
cleaned_chunks = []
|
||||
|
||||
for chunk in chunks:
|
||||
if not chunk.strip():
|
||||
continue
|
||||
cleaned_chunks.append(chunk)
|
||||
|
||||
for i, chunk in enumerate(cleaned_chunks):
|
||||
chunk = chunk.replace("__ellipsis__", "...")
|
||||
cleaned_chunks[i] = chunk
|
||||
|
||||
return cleaned_chunks
|
||||
except Exception as e:
|
||||
log.error("chunking error", error=e, text=text)
|
||||
return [text.replace("__ellipsis__", "...").replace("*", "")]
|
||||
|
||||
|
||||
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
|
||||
"""
|
||||
Will combine chunks split by punctuation into a single chunk until
|
||||
max chunk size is reached
|
||||
"""
|
||||
|
||||
joined_chunks = []
|
||||
|
||||
current_chunk = ""
|
||||
|
||||
for chunk in chunks:
|
||||
if len(current_chunk) + len(chunk) > chunk_size:
|
||||
joined_chunks.append(current_chunk)
|
||||
current_chunk = ""
|
||||
|
||||
current_chunk += chunk
|
||||
|
||||
if current_chunk:
|
||||
joined_chunks.append(current_chunk)
|
||||
return joined_chunks
|
||||
|
||||
|
||||
@register()
|
||||
class TTSAgent(
|
||||
ElevenLabsMixin,
|
||||
OpenAIMixin,
|
||||
GoogleMixin,
|
||||
KokoroMixin,
|
||||
ChatterboxMixin,
|
||||
F5TTSMixin,
|
||||
Agent,
|
||||
):
|
||||
"""
|
||||
Text to speech agent
|
||||
"""
|
||||
|
||||
agent_type = "tts"
|
||||
verbose_name = "Voice"
|
||||
requires_llm_client = False
|
||||
essential = False
|
||||
|
||||
# websocket handler for frontend voice library management
|
||||
websocket_handler = TTSWebsocketHandler
|
||||
|
||||
@classmethod
|
||||
def config_options(cls, agent=None):
|
||||
config_options = super().config_options(agent=agent)
|
||||
|
||||
if not agent:
|
||||
return config_options
|
||||
|
||||
narrator_voice_id = config_options["actions"]["_config"]["config"][
|
||||
"narrator_voice_id"
|
||||
]
|
||||
|
||||
narrator_voice_id["choices"] = cls.narrator_voice_id_choices(agent)
|
||||
|
||||
return config_options
|
||||
|
||||
@classmethod
|
||||
def narrator_voice_id_choices(cls, agent: "TTSAgent") -> list[dict[str, str]]:
|
||||
choices = voice_library.voices_for_apis(agent.ready_apis, agent.voice_library)
|
||||
choices.sort(key=lambda x: x.label)
|
||||
return [
|
||||
{
|
||||
"label": f"{voice.label} ({voice.provider})",
|
||||
"value": voice.id,
|
||||
}
|
||||
for voice in choices
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def init_actions(cls) -> dict[str, AgentAction]:
|
||||
actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
label="Configure",
|
||||
description="TTS agent configuration",
|
||||
config={
|
||||
"apis": AgentActionConfig(
|
||||
type="flags",
|
||||
value=[
|
||||
"kokoro",
|
||||
],
|
||||
label="Enabled APIs",
|
||||
description="APIs to use for TTS",
|
||||
choices=[],
|
||||
),
|
||||
"narrator_voice_id": AgentActionConfig(
|
||||
type="autocomplete",
|
||||
value="kokoro:am_adam",
|
||||
label="Narrator Voice",
|
||||
description="Voice to use for narration",
|
||||
choices=[],
|
||||
note=VOICE_LIBRARY_NOTE,
|
||||
),
|
||||
"speaker_separation": AgentActionConfig(
|
||||
type="text",
|
||||
value="simple",
|
||||
label="Speaker separation",
|
||||
description="How to separate speaker dialogue from exposition",
|
||||
choices=[
|
||||
{"label": "No separation", "value": "none"},
|
||||
{"label": "Simple", "value": "simple"},
|
||||
{"label": "Mixed", "value": "mixed"},
|
||||
{"label": "AI assisted", "value": "ai_assisted"},
|
||||
],
|
||||
note_on_value={
|
||||
"none": AgentActionNote(
|
||||
type="primary",
|
||||
text="Character messages will be voiced entirely by the character's voice with a fallback to the narrator voice if the character has no voice selecte. Narrator messages will be voiced exclusively by the narrator voice.",
|
||||
),
|
||||
"simple": AgentActionNote(
|
||||
type="primary",
|
||||
text="Exposition and dialogue will be separated in character messages. Narrator messages will be voiced exclusively by the narrator voice. This means",
|
||||
),
|
||||
"mixed": AgentActionNote(
|
||||
type="primary",
|
||||
text="A mix of `simple` and `ai_assisted`. Character messages are separated into narrator and the character's voice. Narrator messages that have dialogue are analyzed by the Summarizer agent to determine the appropriate speaker(s).",
|
||||
),
|
||||
"ai_assisted": AgentActionNote(
|
||||
type="primary",
|
||||
text="Appropriate speaker separation will be attempted based on the content of the message with help from the Summarizer agent. This sends an extra prompt to the LLM to determine the appropriate speaker(s).",
|
||||
),
|
||||
},
|
||||
),
|
||||
"generate_for_player": AgentActionConfig(
|
||||
type="bool",
|
||||
value=False,
|
||||
label="Auto-generate for player",
|
||||
description="Generate audio for player messages",
|
||||
),
|
||||
"generate_for_npc": AgentActionConfig(
|
||||
type="bool",
|
||||
value=True,
|
||||
label="Auto-generate for AI characters",
|
||||
description="Generate audio for NPC messages",
|
||||
),
|
||||
"generate_for_narration": AgentActionConfig(
|
||||
type="bool",
|
||||
value=True,
|
||||
label="Auto-generate for narration",
|
||||
description="Generate audio for narration messages",
|
||||
),
|
||||
"generate_for_context_investigation": AgentActionConfig(
|
||||
type="bool",
|
||||
value=True,
|
||||
label="Auto-generate for context investigation",
|
||||
description="Generate audio for context investigation messages",
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
KokoroMixin.add_actions(actions)
|
||||
ChatterboxMixin.add_actions(actions)
|
||||
GoogleMixin.add_actions(actions)
|
||||
ElevenLabsMixin.add_actions(actions)
|
||||
OpenAIMixin.add_actions(actions)
|
||||
F5TTSMixin.add_actions(actions)
|
||||
|
||||
return actions
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.is_enabled = False # tts agent is disabled by default
|
||||
self.actions = TTSAgent.init_actions()
|
||||
self.playback_done_event = asyncio.Event()
|
||||
|
||||
# Queue management for voice generation
|
||||
# Each queue instance gets a unique id so it can later be referenced
|
||||
# (e.g. for cancellation of all remaining items).
|
||||
# Only one queue can be active at a time. New generation requests that
|
||||
# arrive while a queue is processing will be appended to the same
|
||||
# queue. Once the queue is fully processed it is discarded and a new
|
||||
# one will be created for subsequent generation requests.
|
||||
# Queue now holds individual (context, chunk) pairs so interruption can
|
||||
# happen between chunks even when a single context produced many.
|
||||
self._generation_queue: deque[tuple[GenerationContext, Chunk]] = deque()
|
||||
self._queue_id: str | None = None
|
||||
self._queue_task: asyncio.Task | None = None
|
||||
self._queue_lock = asyncio.Lock()
|
||||
|
||||
# general helpers
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def voice_library(self) -> VoiceLibrary:
|
||||
return voice_library.get_instance()
|
||||
|
||||
# config helpers
|
||||
|
||||
@property
|
||||
def narrator_voice_id(self) -> str:
|
||||
return self.actions["_config"].config["narrator_voice_id"].value
|
||||
|
||||
@property
|
||||
def generate_for_player(self) -> bool:
|
||||
return self.actions["_config"].config["generate_for_player"].value
|
||||
|
||||
@property
|
||||
def generate_for_npc(self) -> bool:
|
||||
return self.actions["_config"].config["generate_for_npc"].value
|
||||
|
||||
@property
|
||||
def generate_for_narration(self) -> bool:
|
||||
return self.actions["_config"].config["generate_for_narration"].value
|
||||
|
||||
@property
|
||||
def generate_for_context_investigation(self) -> bool:
|
||||
return (
|
||||
self.actions["_config"].config["generate_for_context_investigation"].value
|
||||
)
|
||||
|
||||
@property
|
||||
def speaker_separation(self) -> str:
|
||||
return self.actions["_config"].config["speaker_separation"].value
|
||||
|
||||
@property
|
||||
def apis(self) -> list[str]:
|
||||
return self.actions["_config"].config["apis"].value
|
||||
|
||||
@property
|
||||
def all_apis(self) -> list[str]:
|
||||
return [api["value"] for api in self.actions["_config"].config["apis"].choices]
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
details = {}
|
||||
|
||||
self.actions["_config"].config[
|
||||
"narrator_voice_id"
|
||||
].choices = self.narrator_voice_id_choices(self)
|
||||
|
||||
if not self.enabled:
|
||||
return details
|
||||
|
||||
used_apis: set[str] = set()
|
||||
|
||||
used_disabled_apis: set[str] = set()
|
||||
|
||||
if self.narrator_voice:
|
||||
#
|
||||
|
||||
label = self.narrator_voice.label
|
||||
color = "primary"
|
||||
used_apis.add(self.narrator_voice.provider)
|
||||
|
||||
if not self.api_enabled(self.narrator_voice.provider):
|
||||
used_disabled_apis.add(self.narrator_voice.provider)
|
||||
|
||||
if not self.api_ready(self.narrator_voice.provider):
|
||||
color = "error"
|
||||
|
||||
details["narrator_voice"] = AgentDetail(
|
||||
icon="mdi-script-text",
|
||||
value=label,
|
||||
description="Default voice",
|
||||
color=color,
|
||||
).model_dump()
|
||||
|
||||
scene = getattr(self, "scene", None)
|
||||
if scene:
|
||||
for character in scene.characters:
|
||||
if character.voice:
|
||||
label = character.voice.label
|
||||
color = "primary"
|
||||
used_apis.add(character.voice.provider)
|
||||
if not self.api_enabled(character.voice.provider):
|
||||
used_disabled_apis.add(character.voice.provider)
|
||||
if not self.api_ready(character.voice.provider):
|
||||
color = "error"
|
||||
|
||||
details[f"{character.name}_voice"] = AgentDetail(
|
||||
icon="mdi-account-voice",
|
||||
value=f"{character.name}",
|
||||
description=f"{character.name}'s voice: {label} ({character.voice.provider})",
|
||||
color=color,
|
||||
).model_dump()
|
||||
|
||||
for api in used_disabled_apis:
|
||||
details[f"{api}_disabled"] = AgentDetail(
|
||||
icon="mdi-alert-circle",
|
||||
value=f"{api} disabled",
|
||||
description=f"{api} disabled - at least one voice is attempting to use this api but is not enabled",
|
||||
color="error",
|
||||
).model_dump()
|
||||
|
||||
for api in used_apis:
|
||||
fn = getattr(self, f"{api}_agent_details", None)
|
||||
if fn:
|
||||
details.update(fn)
|
||||
return details
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
if self.ready:
|
||||
if getattr(self, "processing_bg", 0) > 0:
|
||||
return "busy_bg" if not getattr(self, "processing", False) else "busy"
|
||||
return "idle" if not getattr(self, "processing", False) else "busy"
|
||||
return "uninitialized"
|
||||
|
||||
@property
|
||||
def narrator_voice(self) -> Voice | None:
|
||||
return self.voice_library.get_voice(self.narrator_voice_id)
|
||||
|
||||
@property
|
||||
def api_status(self) -> list[APIStatus]:
|
||||
api_status: list[APIStatus] = []
|
||||
|
||||
for api in self.all_apis:
|
||||
not_configured_reason = getattr(self, f"{api}_not_configured_reason", None)
|
||||
not_configured_action = getattr(self, f"{api}_not_configured_action", None)
|
||||
api_info: str | None = getattr(self, f"{api}_info", None)
|
||||
messages: list[Note] = []
|
||||
if not_configured_reason:
|
||||
messages.append(
|
||||
Note(
|
||||
text=not_configured_reason,
|
||||
color="error",
|
||||
icon="mdi-alert-circle-outline",
|
||||
actions=[not_configured_action]
|
||||
if not_configured_action
|
||||
else None,
|
||||
)
|
||||
)
|
||||
if api_info:
|
||||
messages.append(
|
||||
Note(
|
||||
text=api_info.strip(),
|
||||
color="muted",
|
||||
icon="mdi-information-outline",
|
||||
)
|
||||
)
|
||||
_status = APIStatus(
|
||||
api=api,
|
||||
enabled=self.api_enabled(api),
|
||||
ready=self.api_ready(api),
|
||||
configured=self.api_configured(api),
|
||||
messages=messages,
|
||||
supports_mixing=getattr(self, f"{api}_supports_mixing", False),
|
||||
provider=provider(api),
|
||||
default_model=getattr(self, f"{api}_model", None),
|
||||
model_choices=getattr(self, f"{api}_model_choices", []),
|
||||
)
|
||||
api_status.append(_status)
|
||||
|
||||
# order by api
|
||||
api_status.sort(key=lambda x: x.api)
|
||||
|
||||
return api_status
|
||||
|
||||
# events
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
async_signals.get("game_loop_new_message").connect(
|
||||
self.on_game_loop_new_message
|
||||
)
|
||||
async_signals.get("voice_library.update.after").connect(
|
||||
self.on_voice_library_update
|
||||
)
|
||||
async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init)
|
||||
async_signals.get("character.voice_changed").connect(
|
||||
self.on_character_voice_changed
|
||||
)
|
||||
|
||||
async def on_scene_loop_init(self, event: "SceneLoopEvent"):
|
||||
if not self.enabled or not self.ready or not self.generate_for_narration:
|
||||
return
|
||||
|
||||
if self.scene.environment == "creative":
|
||||
return
|
||||
|
||||
content_messages = self.scene.last_message_of_type(
|
||||
["character", "narrator", "context_investigation"]
|
||||
)
|
||||
|
||||
if content_messages:
|
||||
# we already have a history, so we don't need to generate TTS for the intro
|
||||
return
|
||||
|
||||
await self.generate(self.scene.get_intro(), character=None)
|
||||
|
||||
async def on_voice_library_update(self, voice_library: VoiceLibrary):
|
||||
log.debug("Voice library updated - refreshing narrator voice choices")
|
||||
self.actions["_config"].config[
|
||||
"narrator_voice_id"
|
||||
].choices = self.narrator_voice_id_choices(self)
|
||||
await self.emit_status()
|
||||
|
||||
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
"""
|
||||
|
||||
if self.scene.environment == "creative":
|
||||
return
|
||||
|
||||
character: Character | None = None
|
||||
|
||||
if not self.enabled or not self.ready:
|
||||
return
|
||||
|
||||
if not isinstance(
|
||||
emission.message,
|
||||
(CharacterMessage, NarratorMessage, ContextInvestigationMessage),
|
||||
):
|
||||
return
|
||||
|
||||
if (
|
||||
isinstance(emission.message, NarratorMessage)
|
||||
and not self.generate_for_narration
|
||||
):
|
||||
return
|
||||
|
||||
if (
|
||||
isinstance(emission.message, ContextInvestigationMessage)
|
||||
and not self.generate_for_context_investigation
|
||||
):
|
||||
return
|
||||
|
||||
if isinstance(emission.message, CharacterMessage):
|
||||
if emission.message.source == "player" and not self.generate_for_player:
|
||||
return
|
||||
elif emission.message.source == "ai" and not self.generate_for_npc:
|
||||
return
|
||||
|
||||
character = self.scene.get_character(emission.message.character_name)
|
||||
|
||||
if isinstance(emission.message, CharacterMessage):
|
||||
character_prefix = emission.message.split(":", 1)[0]
|
||||
text_to_generate = str(emission.message).replace(
|
||||
character_prefix + ": ", ""
|
||||
)
|
||||
elif isinstance(emission.message, ContextInvestigationMessage):
|
||||
character_prefix = ""
|
||||
text_to_generate = (
|
||||
emission.message.message
|
||||
) # Use just the message content, not the title prefix
|
||||
else:
|
||||
character_prefix = ""
|
||||
text_to_generate = str(emission.message)
|
||||
|
||||
log.info(
|
||||
"reactive tts", message=emission.message, character_prefix=character_prefix
|
||||
)
|
||||
|
||||
await self.generate(
|
||||
text_to_generate,
|
||||
character=character,
|
||||
message=emission.message,
|
||||
)
|
||||
|
||||
async def on_character_voice_changed(self, event: "VoiceChangedEvent"):
|
||||
log.debug(
|
||||
"Character voice changed", character=event.character, voice=event.voice
|
||||
)
|
||||
await self.emit_status()
|
||||
|
||||
# voice helpers
|
||||
|
||||
@property
|
||||
def ready_apis(self) -> list[str]:
|
||||
"""
|
||||
Returns a list of apis that are ready
|
||||
"""
|
||||
return [api for api in self.apis if self.api_ready(api)]
|
||||
|
||||
@property
|
||||
def used_apis(self) -> list[str]:
|
||||
"""
|
||||
Returns a list of apis that are in use
|
||||
|
||||
The api is in use if it is the narrator voice or if any of the active characters in the scene use a voice from the api.
|
||||
"""
|
||||
return [api for api in self.apis if self.api_used(api)]
|
||||
|
||||
def api_enabled(self, api: str) -> bool:
|
||||
"""
|
||||
Returns whether the api is currently in the .apis list, which means it is enabled.
|
||||
"""
|
||||
return api in self.apis
|
||||
|
||||
def api_ready(self, api: str) -> bool:
|
||||
"""
|
||||
Returns whether the api is ready.
|
||||
|
||||
The api must be enabled and configured.
|
||||
"""
|
||||
|
||||
if not self.api_enabled(api):
|
||||
return False
|
||||
|
||||
return self.api_configured(api)
|
||||
|
||||
def api_configured(self, api: str) -> bool:
|
||||
return getattr(self, f"{api}_configured", True)
|
||||
|
||||
def api_used(self, api: str) -> bool:
|
||||
"""
|
||||
Returns whether the narrator or any of the active characters in the scene
|
||||
use a voice from the given api
|
||||
|
||||
Args:
|
||||
api (str): The api to check
|
||||
|
||||
Returns:
|
||||
bool: Whether the api is in use
|
||||
"""
|
||||
|
||||
if self.narrator_voice and self.narrator_voice.provider == api:
|
||||
return True
|
||||
|
||||
if not getattr(self, "scene", None):
|
||||
return False
|
||||
|
||||
for character in self.scene.characters:
|
||||
if not character.voice:
|
||||
continue
|
||||
voice = self.voice_library.get_voice(character.voice.id)
|
||||
if voice and voice.provider == api:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def use_ai_assisted_speaker_separation(
|
||||
self,
|
||||
text: str,
|
||||
message: CharacterMessage
|
||||
| NarratorMessage
|
||||
| ContextInvestigationMessage
|
||||
| None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether the ai assisted speaker separation should be used for the given text.
|
||||
"""
|
||||
try:
|
||||
if not message and '"' not in text:
|
||||
return False
|
||||
|
||||
if not message and '"' in text:
|
||||
return self.speaker_separation in ["ai_assisted", "mixed"]
|
||||
|
||||
if message.source == "player":
|
||||
return False
|
||||
|
||||
if self.speaker_separation == "ai_assisted":
|
||||
return True
|
||||
|
||||
if (
|
||||
isinstance(message, NarratorMessage)
|
||||
and self.speaker_separation == "mixed"
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Error using ai assisted speaker separation",
|
||||
error=e,
|
||||
traceback=traceback.format_exc(),
|
||||
)
|
||||
return False
|
||||
|
||||
# tts markup cache
|
||||
|
||||
async def get_tts_markup_cache(self, text: str) -> str | None:
|
||||
"""
|
||||
Returns the cached tts markup for the given text.
|
||||
"""
|
||||
fp = hash(text)
|
||||
cached_markup = self.get_scene_state("tts_markup_cache")
|
||||
if cached_markup and cached_markup.get("fp") == fp:
|
||||
return cached_markup.get("markup")
|
||||
return None
|
||||
|
||||
async def set_tts_markup_cache(self, text: str, markup: str):
|
||||
fp = hash(text)
|
||||
self.set_scene_states(
|
||||
tts_markup_cache={
|
||||
"fp": fp,
|
||||
"markup": markup,
|
||||
}
|
||||
)
|
||||
|
||||
# generation
|
||||
|
||||
@set_processing
|
||||
async def generate(
|
||||
self,
|
||||
text: str,
|
||||
character: Character | None = None,
|
||||
force_voice: Voice | None = None,
|
||||
message: CharacterMessage | NarratorMessage | None = None,
|
||||
):
|
||||
"""
|
||||
Public entry-point for voice generation.
|
||||
|
||||
The actual audio generation happens sequentially inside a single
|
||||
background queue. If a queue is currently active, we simply append the
|
||||
new request to it; if not, we create a new queue (with its own unique
|
||||
id) and start processing.
|
||||
"""
|
||||
if not self.enabled or not self.ready or not text:
|
||||
return
|
||||
|
||||
self.playback_done_event.set()
|
||||
|
||||
summarizer: "SummarizeAgent" = instance.get_agent("summarizer")
|
||||
|
||||
context = GenerationContext(voice_id=self.narrator_voice_id)
|
||||
character_voice: Voice = force_voice or self.narrator_voice
|
||||
|
||||
if character and character.voice:
|
||||
voice = character.voice
|
||||
if voice and self.api_ready(voice.provider):
|
||||
character_voice = voice
|
||||
else:
|
||||
log.warning(
|
||||
"Character voice not available",
|
||||
character=character.name,
|
||||
voice=character.voice,
|
||||
)
|
||||
|
||||
log.debug("Voice routing", character=character, voice=character_voice)
|
||||
|
||||
# initial chunking by separating dialogue from exposition
|
||||
chunks: list[Chunk] = []
|
||||
if self.speaker_separation != "none":
|
||||
if self.use_ai_assisted_speaker_separation(text, message):
|
||||
markup = await self.get_tts_markup_cache(text)
|
||||
if not markup:
|
||||
log.debug("No markup cache found, generating markup")
|
||||
markup = await summarizer.markup_context_for_tts(text)
|
||||
await self.set_tts_markup_cache(text, markup)
|
||||
else:
|
||||
log.debug("Using markup cache")
|
||||
# Use the new markup parser for AI-assisted format
|
||||
dlg_chunks = dialogue_utils.parse_tts_markup(markup)
|
||||
else:
|
||||
# Use the original parser for non-AI-assisted format
|
||||
dlg_chunks = dialogue_utils.separate_dialogue_from_exposition(text)
|
||||
|
||||
for _dlg_chunk in dlg_chunks:
|
||||
_voice = (
|
||||
character_voice
|
||||
if _dlg_chunk.type == "dialogue"
|
||||
else self.narrator_voice
|
||||
)
|
||||
|
||||
if _dlg_chunk.speaker is not None:
|
||||
# speaker name has been identified
|
||||
_character = self.scene.get_character(_dlg_chunk.speaker)
|
||||
log.debug(
|
||||
"Identified speaker",
|
||||
speaker=_dlg_chunk.speaker,
|
||||
character=_character,
|
||||
)
|
||||
if (
|
||||
_character
|
||||
and _character.voice
|
||||
and self.api_ready(_character.voice.provider)
|
||||
):
|
||||
log.debug(
|
||||
"Using character voice",
|
||||
character=_character.name,
|
||||
voice=_character.voice,
|
||||
)
|
||||
_voice = _character.voice
|
||||
|
||||
_api: str = _voice.provider if _voice else self.api
|
||||
chunk = Chunk(
|
||||
api=_api,
|
||||
voice=Voice(**_voice.model_dump()),
|
||||
model=_voice.provider_model,
|
||||
generate_fn=getattr(self, f"{_api}_generate"),
|
||||
prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
|
||||
character_name=character.name if character else None,
|
||||
text=[_dlg_chunk.text],
|
||||
type=_dlg_chunk.type,
|
||||
message_id=message.id if message else None,
|
||||
)
|
||||
chunks.append(chunk)
|
||||
else:
|
||||
_voice = character_voice if character else self.narrator_voice
|
||||
_api: str = _voice.provider if _voice else self.api
|
||||
chunks = [
|
||||
Chunk(
|
||||
api=_api,
|
||||
voice=Voice(**_voice.model_dump()),
|
||||
model=_voice.provider_model,
|
||||
generate_fn=getattr(self, f"{_api}_generate"),
|
||||
prepare_fn=getattr(self, f"{_api}_prepare_chunk", None),
|
||||
character_name=character.name if character else None,
|
||||
text=[text],
|
||||
type="dialogue" if character else "exposition",
|
||||
message_id=message.id if message else None,
|
||||
)
|
||||
]
|
||||
|
||||
# second chunking by splitting into chunks of max_generation_length
|
||||
|
||||
for chunk in chunks:
|
||||
api_chunk_size = getattr(self, f"{chunk.api}_chunk_size", 0)
|
||||
|
||||
log.debug("chunking", api=chunk.api, api_chunk_size=api_chunk_size)
|
||||
|
||||
_text = []
|
||||
|
||||
max_generation_length = getattr(self, f"{chunk.api}_max_generation_length")
|
||||
|
||||
if api_chunk_size > 0:
|
||||
max_generation_length = min(max_generation_length, api_chunk_size)
|
||||
|
||||
for _chunk_text in chunk.text:
|
||||
if len(_chunk_text) <= max_generation_length:
|
||||
_text.append(_chunk_text)
|
||||
continue
|
||||
|
||||
_parsed = parse_chunks(_chunk_text)
|
||||
_joined = rejoin_chunks(_parsed, chunk_size=max_generation_length)
|
||||
_text.extend(_joined)
|
||||
|
||||
log.debug("chunked for size", before=chunk.text, after=_text)
|
||||
|
||||
chunk.text = _text
|
||||
|
||||
context.chunks = chunks
|
||||
|
||||
# Enqueue each chunk individually for fine-grained interruptibility
|
||||
async with self._queue_lock:
|
||||
if self._queue_id is None:
|
||||
self._queue_id = str(uuid.uuid4())
|
||||
|
||||
for chunk in context.chunks:
|
||||
self._generation_queue.append((context, chunk))
|
||||
|
||||
# Start processing task if needed
|
||||
if self._queue_task is None or self._queue_task.done():
|
||||
self._queue_task = asyncio.create_task(
|
||||
self._process_queue(self._queue_id)
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"tts queue enqueue",
|
||||
queue_id=self._queue_id,
|
||||
total_items=len(self._generation_queue),
|
||||
)
|
||||
|
||||
# The caller doesn't need to wait for the queue to finish; it runs in
|
||||
# the background. We still register the task with Talemate's
|
||||
# background-processing tracking so that UI can reflect activity.
|
||||
await self.set_background_processing(self._queue_task)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Queue helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def _process_queue(self, queue_id: str):
|
||||
"""Sequentially processes all GenerationContext objects in the queue.
|
||||
|
||||
Once the last context has been processed the queue state is reset so a
|
||||
future generation call will create a new queue (and therefore a new
|
||||
id). The *queue_id* argument allows us to later add cancellation logic
|
||||
that can target a specific queue instance.
|
||||
"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
async with self._queue_lock:
|
||||
if not self._generation_queue:
|
||||
break
|
||||
|
||||
context, chunk = self._generation_queue.popleft()
|
||||
|
||||
log.debug(
|
||||
"tts queue dequeue",
|
||||
queue_id=queue_id,
|
||||
total_items=len(self._generation_queue),
|
||||
chunk_type=chunk.type,
|
||||
)
|
||||
|
||||
# Process outside lock so other coroutines can enqueue
|
||||
await self._generate_chunk(chunk, context)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Error processing queue", error=e, traceback=traceback.format_exc()
|
||||
)
|
||||
finally:
|
||||
# Clean up queue state after finishing (or on cancellation)
|
||||
async with self._queue_lock:
|
||||
if queue_id == self._queue_id:
|
||||
self._queue_id = None
|
||||
self._queue_task = None
|
||||
self._generation_queue.clear()
|
||||
|
||||
# Public helper so external code (e.g. later cancellation UI) can find the current queue id
|
||||
def current_queue_id(self) -> str | None:
|
||||
return self._queue_id
|
||||
|
||||
async def _generate_chunk(self, chunk: Chunk, context: GenerationContext):
|
||||
"""Generate audio for a single chunk (all its sub-chunks)."""
|
||||
|
||||
for _chunk in chunk.sub_chunks:
|
||||
if not _chunk.cleaned_text.strip():
|
||||
continue
|
||||
|
||||
emission: VoiceGenerationEmission = VoiceGenerationEmission(
|
||||
chunk=_chunk, context=context
|
||||
)
|
||||
|
||||
if _chunk.prepare_fn:
|
||||
await async_signals.get("agent.tts.prepare.before").send(emission)
|
||||
await _chunk.prepare_fn(_chunk)
|
||||
await async_signals.get("agent.tts.prepare.after").send(emission)
|
||||
|
||||
log.info(
|
||||
"Generating audio",
|
||||
api=chunk.api,
|
||||
text=_chunk.cleaned_text,
|
||||
parameters=_chunk.voice.parameters,
|
||||
prepare_fn=_chunk.prepare_fn,
|
||||
)
|
||||
|
||||
await async_signals.get("agent.tts.generate.before").send(emission)
|
||||
try:
|
||||
emission.wav_bytes = await _chunk.generate_fn(_chunk, context)
|
||||
except Exception as e:
|
||||
log.error("Error generating audio", error=e, chunk=_chunk)
|
||||
continue
|
||||
await async_signals.get("agent.tts.generate.after").send(emission)
|
||||
self.play_audio(emission.wav_bytes, chunk.message_id)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Deprecated: kept for backward compatibility but no longer used.
|
||||
async def generate_chunks(self, context: GenerationContext):
|
||||
for chunk in context.chunks:
|
||||
await self._generate_chunk(chunk, context)
|
||||
|
||||
def play_audio(self, audio_data, message_id: int | None = None):
|
||||
# play audio through the websocket (browser)
|
||||
|
||||
audio_data_encoded: str = base64.b64encode(audio_data).decode("utf-8")
|
||||
|
||||
emit(
|
||||
"audio_queue",
|
||||
data={"audio_data": audio_data_encoded, "message_id": message_id},
|
||||
)
|
||||
|
||||
self.playback_done_event.set() # Signal that playback is finished
|
||||
|
||||
async def stop_and_clear_queue(self):
|
||||
"""Cancel any ongoing generation and clear the pending queue.
|
||||
|
||||
This is triggered by UI actions that request immediate stop of TTS
|
||||
synthesis and playback. It cancels the background task (if still
|
||||
running) and clears all queued items in a thread-safe manner.
|
||||
"""
|
||||
async with self._queue_lock:
|
||||
# Clear all queued items
|
||||
self._generation_queue.clear()
|
||||
|
||||
# Cancel the background task if it is still running
|
||||
if self._queue_task and not self._queue_task.done():
|
||||
self._queue_task.cancel()
|
||||
|
||||
# Reset queue identifiers/state
|
||||
self._queue_id = None
|
||||
self._queue_task = None
|
||||
|
||||
# Ensure downstream components know playback is finished
|
||||
self.playback_done_event.set()
|
||||
317
src/talemate/agents/tts/chatterbox.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import os
|
||||
import functools
|
||||
import tempfile
|
||||
import uuid
|
||||
import asyncio
|
||||
import structlog
|
||||
import pydantic
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Lazy imports for heavy dependencies
|
||||
def _import_heavy_deps():
|
||||
global ta, ChatterboxTTS
|
||||
import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
|
||||
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
|
||||
from talemate.agents.base import (
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
)
|
||||
from talemate.ux.schema import Field
|
||||
|
||||
from .schema import Voice, Chunk, GenerationContext, VoiceProvider, INFO_CHUNK_SIZE
|
||||
from .voice_library import add_default_voices
|
||||
from .providers import register, provider
|
||||
from .util import voice_is_talemate_asset
|
||||
|
||||
log = structlog.get_logger("talemate.agents.tts.chatterbox")
|
||||
|
||||
add_default_voices(
|
||||
[
|
||||
Voice(
|
||||
label="Eva",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/eva.wav",
|
||||
tags=["female", "calm", "mature", "thoughtful"],
|
||||
),
|
||||
Voice(
|
||||
label="Lisa",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/lisa.wav",
|
||||
tags=["female", "energetic", "young"],
|
||||
),
|
||||
Voice(
|
||||
label="Adam",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/adam.wav",
|
||||
tags=["male", "calm", "mature", "thoughtful", "deep"],
|
||||
),
|
||||
Voice(
|
||||
label="Bradford",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/bradford.wav",
|
||||
tags=["male", "calm", "mature", "thoughtful", "deep"],
|
||||
),
|
||||
Voice(
|
||||
label="Julia",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/julia.wav",
|
||||
tags=["female", "calm", "mature"],
|
||||
),
|
||||
Voice(
|
||||
label="Zoe",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/zoe.wav",
|
||||
tags=["female"],
|
||||
),
|
||||
Voice(
|
||||
label="William",
|
||||
provider="chatterbox",
|
||||
provider_id="tts/voice/chatterbox/william.wav",
|
||||
tags=["male", "young"],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
CHATTERBOX_INFO = """
|
||||
Chatterbox is a local text to speech model.
|
||||
|
||||
The voice id is the path to the .wav file for the voice.
|
||||
|
||||
The path can be relative to the talemate root directory, and you can put new *.wav samples
|
||||
in the `tts/voice/chatterbox` directory. It is also ok if you want to load the files from somewhere else as long as the filepath is available to the talemate backend.
|
||||
|
||||
First generation will download the models (2.13GB + 1.06GB).
|
||||
|
||||
Uses about 4GB of VRAM.
|
||||
"""
|
||||
|
||||
|
||||
@register()
|
||||
class ChatterboxProvider(VoiceProvider):
|
||||
name: str = "chatterbox"
|
||||
allow_model_override: bool = False
|
||||
allow_file_upload: bool = True
|
||||
upload_file_types: list[str] = ["audio/wav"]
|
||||
voice_parameters: list[Field] = [
|
||||
Field(
|
||||
name="exaggeration",
|
||||
type="number",
|
||||
label="Exaggeration level",
|
||||
value=0.5,
|
||||
min=0.25,
|
||||
max=2.0,
|
||||
step=0.05,
|
||||
),
|
||||
Field(
|
||||
name="cfg_weight",
|
||||
type="number",
|
||||
label="CFG/Pace",
|
||||
value=0.5,
|
||||
min=0.2,
|
||||
max=1.0,
|
||||
step=0.1,
|
||||
),
|
||||
Field(
|
||||
name="temperature",
|
||||
type="number",
|
||||
label="Temperature",
|
||||
value=0.8,
|
||||
min=0.05,
|
||||
max=5.0,
|
||||
step=0.05,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ChatterboxInstance(pydantic.BaseModel):
|
||||
model: "ChatterboxTTS"
|
||||
device: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ChatterboxMixin:
|
||||
"""
|
||||
Chatterbox agent mixin for local text to speech.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["_config"].config["apis"].choices.append(
|
||||
{
|
||||
"value": "chatterbox",
|
||||
"label": "Chatterbox (Local)",
|
||||
"help": "Chatterbox is a local text to speech model.",
|
||||
}
|
||||
)
|
||||
|
||||
actions["chatterbox"] = AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
icon="mdi-server-outline",
|
||||
label="Chatterbox",
|
||||
description="Chatterbox is a local text to speech model.",
|
||||
config={
|
||||
"device": AgentActionConfig(
|
||||
type="text",
|
||||
value="cuda" if CUDA_AVAILABLE else "cpu",
|
||||
label="Device",
|
||||
choices=[
|
||||
{"value": "cpu", "label": "CPU"},
|
||||
{"value": "cuda", "label": "CUDA"},
|
||||
],
|
||||
description="Device to use for TTS",
|
||||
),
|
||||
"chunk_size": AgentActionConfig(
|
||||
type="number",
|
||||
min=0,
|
||||
step=64,
|
||||
max=2048,
|
||||
value=256,
|
||||
label="Chunk size",
|
||||
note=INFO_CHUNK_SIZE,
|
||||
),
|
||||
},
|
||||
)
|
||||
return actions
|
||||
|
||||
@property
|
||||
def chatterbox_configured(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def chatterbox_max_generation_length(self) -> int:
|
||||
return 512
|
||||
|
||||
@property
|
||||
def chatterbox_device(self) -> str:
|
||||
return self.actions["chatterbox"].config["device"].value
|
||||
|
||||
@property
|
||||
def chatterbox_chunk_size(self) -> int:
|
||||
return self.actions["chatterbox"].config["chunk_size"].value
|
||||
|
||||
@property
|
||||
def chatterbox_info(self) -> str:
|
||||
return CHATTERBOX_INFO
|
||||
|
||||
@property
|
||||
def chatterbox_agent_details(self) -> dict:
|
||||
if not self.chatterbox_configured:
|
||||
return {}
|
||||
details = {}
|
||||
|
||||
details["chatterbox_device"] = AgentDetail(
|
||||
icon="mdi-memory",
|
||||
value=f"Chatterbox: {self.chatterbox_device}",
|
||||
description="The device to use for Chatterbox",
|
||||
).model_dump()
|
||||
|
||||
return details
|
||||
|
||||
def chatterbox_delete_voice(self, voice: Voice):
|
||||
"""
|
||||
Remove the voice from the file system.
|
||||
Only do this if the path is within TALEMATE_ROOT.
|
||||
"""
|
||||
|
||||
is_talemate_asset, resolved = voice_is_talemate_asset(
|
||||
voice, provider(voice.provider)
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"chatterbox_delete_voice",
|
||||
voice_id=voice.provider_id,
|
||||
is_talemate_asset=is_talemate_asset,
|
||||
resolved=resolved,
|
||||
)
|
||||
|
||||
if not is_talemate_asset:
|
||||
return
|
||||
|
||||
try:
|
||||
if resolved.exists() and resolved.is_file():
|
||||
resolved.unlink()
|
||||
log.debug("Deleted chatterbox voice file", path=str(resolved))
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Failed to delete chatterbox voice file", error=e, path=str(resolved)
|
||||
)
|
||||
|
||||
def _chatterbox_generate_file(
|
||||
self,
|
||||
model: "ChatterboxTTS",
|
||||
text: str,
|
||||
audio_prompt_path: str,
|
||||
output_path: str,
|
||||
**kwargs,
|
||||
):
|
||||
wav = model.generate(text=text, audio_prompt_path=audio_prompt_path, **kwargs)
|
||||
ta.save(output_path, wav, model.sr)
|
||||
return output_path
|
||||
|
||||
async def chatterbox_generate(
|
||||
self, chunk: Chunk, context: GenerationContext
|
||||
) -> bytes | None:
|
||||
chatterbox_instance: ChatterboxInstance | None = getattr(
|
||||
self, "chatterbox_instance", None
|
||||
)
|
||||
|
||||
reload: bool = False
|
||||
|
||||
if not chatterbox_instance:
|
||||
reload = True
|
||||
elif chatterbox_instance.device != self.chatterbox_device:
|
||||
reload = True
|
||||
|
||||
if reload:
|
||||
log.debug(
|
||||
"chatterbox - reinitializing tts instance",
|
||||
device=self.chatterbox_device,
|
||||
)
|
||||
# Lazy import heavy dependencies only when needed
|
||||
_import_heavy_deps()
|
||||
|
||||
self.chatterbox_instance = ChatterboxInstance(
|
||||
model=ChatterboxTTS.from_pretrained(device=self.chatterbox_device),
|
||||
device=self.chatterbox_device,
|
||||
)
|
||||
|
||||
model: "ChatterboxTTS" = self.chatterbox_instance.model
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
voice = chunk.voice
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self._chatterbox_generate_file,
|
||||
model=model,
|
||||
text=chunk.cleaned_text,
|
||||
audio_prompt_path=voice.provider_id,
|
||||
output_path=file_path,
|
||||
**voice.parameters,
|
||||
),
|
||||
)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
async def chatterbox_prepare_chunk(self, chunk: Chunk):
|
||||
voice = chunk.voice
|
||||
P = provider(voice.provider)
|
||||
exaggeration = P.voice_parameter(voice, "exaggeration")
|
||||
|
||||
voice.parameters["exaggeration"] = exaggeration
|
||||
248
src/talemate/agents/tts/elevenlabs.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import io
|
||||
from typing import Union
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
# Lazy imports for heavy dependencies
|
||||
def _import_heavy_deps():
|
||||
global AsyncElevenLabs, ApiError
|
||||
from elevenlabs.client import AsyncElevenLabs
|
||||
|
||||
# Added explicit ApiError import for clearer error handling
|
||||
from elevenlabs.core.api_error import ApiError
|
||||
|
||||
|
||||
from talemate.ux.schema import Action
|
||||
|
||||
from talemate.agents.base import (
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
)
|
||||
from .schema import Voice, VoiceLibrary, GenerationContext, Chunk, INFO_CHUNK_SIZE
|
||||
from .voice_library import add_default_voices
|
||||
|
||||
# emit helper to propagate status messages to the UX
|
||||
from talemate.emit import emit
|
||||
|
||||
log = structlog.get_logger("talemate.agents.tts.elevenlabs")
|
||||
|
||||
|
||||
add_default_voices(
|
||||
[
|
||||
Voice(
|
||||
label="Adam",
|
||||
provider="elevenlabs",
|
||||
provider_id="wBXNqKUATyqu0RtYt25i",
|
||||
tags=["male", "deep"],
|
||||
),
|
||||
Voice(
|
||||
label="Amy",
|
||||
provider="elevenlabs",
|
||||
provider_id="oGn4Ha2pe2vSJkmIJgLQ",
|
||||
tags=["female"],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
ELEVENLABS_INFO = """
|
||||
ElevenLabs is a cloud-based text to speech API.
|
||||
|
||||
To add new voices, head to their voice library at [https://elevenlabs.io/app/voice-library](https://elevenlabs.io/app/voice-library) and note the voice id of the voice you want to use. (Click 'More Actions -> Copy Voice ID')
|
||||
|
||||
**About elevenlabs voices**
|
||||
Your elevenlabs subscription allows you to maintain a set number of voices (10 for cheapest plan).
|
||||
|
||||
Any voice that you generate audio for is automatically added to your voices at [https://elevenlabs.io/app/voice-lab](https://elevenlabs.io/app/voice-lab). This also happens when you use the "Test" button above. It is recommend testing via their voice library instead.
|
||||
"""
|
||||
|
||||
|
||||
class ElevenLabsMixin:
|
||||
"""
|
||||
ElevenLabs TTS agent mixin for cloud-based text to speech.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):
|
||||
actions["_config"].config["apis"].choices.append(
|
||||
{
|
||||
"value": "elevenlabs",
|
||||
"label": "ElevenLabs",
|
||||
"help": "ElevenLabs is a cloud-based text to speech model that uses the ElevenLabs API. (API key required)",
|
||||
}
|
||||
)
|
||||
|
||||
actions["elevenlabs"] = AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
icon="mdi-server-outline",
|
||||
label="ElevenLabs",
|
||||
description="ElevenLabs is a cloud-based text to speech API. (API key required and must be set in the Talemate Settings -> Application -> ElevenLabs)",
|
||||
config={
|
||||
"model": AgentActionConfig(
|
||||
type="text",
|
||||
value="eleven_flash_v2_5",
|
||||
label="Model",
|
||||
description="Model to use for TTS",
|
||||
choices=[
|
||||
{
|
||||
"value": "eleven_multilingual_v2",
|
||||
"label": "Eleven Multilingual V2",
|
||||
},
|
||||
{"value": "eleven_flash_v2_5", "label": "Eleven Flash V2.5"},
|
||||
{"value": "eleven_turbo_v2_5", "label": "Eleven Turbo V2.5"},
|
||||
],
|
||||
),
|
||||
"chunk_size": AgentActionConfig(
|
||||
type="number",
|
||||
min=0,
|
||||
step=64,
|
||||
max=2048,
|
||||
value=0,
|
||||
label="Chunk size",
|
||||
note=INFO_CHUNK_SIZE,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
@classmethod
|
||||
def add_voices(cls, voices: dict[str, VoiceLibrary]):
|
||||
voices["elevenlabs"] = VoiceLibrary(api="elevenlabs", local=True)
|
||||
|
||||
@property
|
||||
def elevenlabs_chunk_size(self) -> int:
|
||||
return self.actions["elevenlabs"].config["chunk_size"].value
|
||||
|
||||
@property
|
||||
def elevenlabs_configured(self) -> bool:
|
||||
api_key_set = bool(self.elevenlabs_api_key)
|
||||
model_set = bool(self.elevenlabs_model)
|
||||
return api_key_set and model_set
|
||||
|
||||
@property
|
||||
def elevenlabs_not_configured_reason(self) -> str | None:
|
||||
if not self.elevenlabs_api_key:
|
||||
return "ElevenLabs API key not set"
|
||||
if not self.elevenlabs_model:
|
||||
return "ElevenLabs model not set"
|
||||
return None
|
||||
|
||||
@property
|
||||
def elevenlabs_not_configured_action(self) -> Action | None:
|
||||
if not self.elevenlabs_api_key:
|
||||
return Action(
|
||||
action_name="openAppConfig",
|
||||
arguments=["application", "elevenlabs_api"],
|
||||
label="Set API Key",
|
||||
icon="mdi-key",
|
||||
)
|
||||
if not self.elevenlabs_model:
|
||||
return Action(
|
||||
action_name="openAgentSettings",
|
||||
arguments=["tts", "elevenlabs"],
|
||||
label="Set Model",
|
||||
icon="mdi-brain",
|
||||
)
|
||||
return None
|
||||
|
||||
@property
|
||||
def elevenlabs_max_generation_length(self) -> int:
|
||||
return 1024
|
||||
|
||||
@property
|
||||
def elevenlabs_model(self) -> str:
|
||||
return self.actions["elevenlabs"].config["model"].value
|
||||
|
||||
@property
|
||||
def elevenlabs_model_choices(self) -> list[str]:
|
||||
return [
|
||||
{"label": choice["label"], "value": choice["value"]}
|
||||
for choice in self.actions["elevenlabs"].config["model"].choices
|
||||
]
|
||||
|
||||
@property
|
||||
def elevenlabs_info(self) -> str:
|
||||
return ELEVENLABS_INFO
|
||||
|
||||
@property
|
||||
def elevenlabs_agent_details(self) -> dict:
|
||||
details = {}
|
||||
|
||||
if not self.elevenlabs_configured:
|
||||
details["elevenlabs_api_key"] = AgentDetail(
|
||||
icon="mdi-key",
|
||||
value="ElevenLabs API key not set",
|
||||
description="ElevenLabs API key not set. You can set it in the Talemate Settings -> Application -> ElevenLabs",
|
||||
color="error",
|
||||
).model_dump()
|
||||
else:
|
||||
details["elevenlabs_model"] = AgentDetail(
|
||||
icon="mdi-brain",
|
||||
value=self.elevenlabs_model,
|
||||
description="The model to use for ElevenLabs",
|
||||
).model_dump()
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def elevenlabs_api_key(self) -> str:
|
||||
return self.config.elevenlabs.api_key
|
||||
|
||||
async def elevenlabs_generate(
|
||||
self, chunk: Chunk, context: GenerationContext, chunk_size: int = 1024
|
||||
) -> Union[bytes, None]:
|
||||
api_key = self.elevenlabs_api_key
|
||||
if not api_key:
|
||||
return
|
||||
|
||||
# Lazy import heavy dependencies only when needed
|
||||
_import_heavy_deps()
|
||||
|
||||
client = AsyncElevenLabs(api_key=api_key)
|
||||
|
||||
try:
|
||||
response_async_iter = client.text_to_speech.convert(
|
||||
text=chunk.cleaned_text,
|
||||
voice_id=chunk.voice.provider_id,
|
||||
model_id=chunk.model or self.elevenlabs_model,
|
||||
)
|
||||
|
||||
bytes_io = io.BytesIO()
|
||||
|
||||
async for _chunk_bytes in response_async_iter:
|
||||
if _chunk_bytes:
|
||||
bytes_io.write(_chunk_bytes)
|
||||
|
||||
return bytes_io.getvalue()
|
||||
|
||||
except ApiError as e:
|
||||
# Emit detailed status message to the frontend UI
|
||||
error_message = "ElevenLabs API Error"
|
||||
try:
|
||||
# The ElevenLabs ApiError often contains a JSON body with details
|
||||
detail = e.body.get("detail", {}) if hasattr(e, "body") else {}
|
||||
error_message = detail.get("message", str(e)) or str(e)
|
||||
except Exception:
|
||||
error_message = str(e)
|
||||
|
||||
log.error("ElevenLabs API error", error=str(e))
|
||||
emit(
|
||||
"status",
|
||||
message=f"ElevenLabs TTS: {error_message}",
|
||||
status="error",
|
||||
)
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
# Catch-all to ensure the app does not crash on unexpected errors
|
||||
log.error("ElevenLabs TTS generation error", error=str(e))
|
||||
emit(
|
||||
"status",
|
||||
message=f"ElevenLabs TTS: Unexpected error – {str(e)}",
|
||||
status="error",
|
||||
)
|
||||
raise e
|
||||