* some prompt cleanup

* prompt tweaks

* prompt tweaks

* prompt tweaks

* set 0.31.0

* relock

* rag queries add brief analysis

* brief analysis before building rag questions

* rag improvements

* prompt tweaks

* address circular import issues

* set 0.30.1

* docs

* numpy to 2

* docs

* prompt tweaks

* prompt tweak

* some template cleanup

* prompt viewer increase height

* fix codemirror highlighting not working

* adjust details height

* update default

* change to log debug

* allow response to scale to max height

* template cleanup

* prompt tweaks

* first progress for installing modules to scene

* package install logic

* package install polish

* package install polish

* package install polish and fixes

* refactor initial world state update and expose setting for it

* add `num` property to ModuleProperty to control order of widgets

* dynamic storyline package info

* fix issue where deactivating player character would cause inconsistencies in the creative tools menu

* cruft

* add openrouter support

* ollama support

* refactor how model choices are loaded, so that can be done per client instance as opposed to just per client type

* set num_ctx

* remove debug messages

* ollama tweaks

* toggle for whether or not default character gets added to blank talemate scenes

* narrator prompt tweaks and template cleanup

* cleanup

* prompt tweaks and template cleanup

* prompt tweaks

* fix instructor embeddings

* add additional error handling to prevent broken world state templates from breaking the world editor side menu

* fix openrouter breaking startup if not configured

* remove debug message

* prompt tweaks

* fix example dialogue generation no longer working

* prompt tweaks and better showing of dialogue examples in conversation instructions

* prompt tweak

* add initial startup message

* prompt tweaks

* fix asset error

* move complex acting instructions into the task block

* fix content input socket on DynamicInstructions node

* log.error with traceback over log.exception since that has a tendency to explode and hang everything

* fix issue with number nodes where they would try to execute even if the incoming wire wasn't active

* fix issue with editor revision events missing template_vars

* DynamicInstruction node should only run if both header and content can resolve

* removed remaining references to 90s adventure game writing style

* prompt tweaks

* support embeddings via client apis (koboldcpp)

* fix label on client-api embeddings

* fix issue where adding / removing an embedding preset would not be reflected immediately in the memory agent config

* remove debug output

* prompt tweaks

* prompt tweaks

* autocomplete passes message object

* validate group names to be filename valid

* embedded Windows env installs and upgrade to Poetry 2

* version config

* get-pip

* relock

* pin runpod

* no longer needed

* remove rapidfuzz dependency

* nodejs directly into embedded_node without a versioned middleman dir - also remove defunct local-tts install script

* fix update script

* update script error handling

* update.bat error handling

* adjust wording

* support loading jinja2 templates node modules in templates/modules

* update google model list

* client t/s and busy indicator - also switch all clients to async streaming

* formatting

* support coercion for anthropic / google
switch to the new google genai sdk
upgrade websockets

* more coercion fixes

* gracefully handle keyboard interrupt

* EmitSystemMessage node

* allow visual prompt generation without image generation

* relock

* chromadb to v1

* fix error handling

* fix issue where the model list would be empty when adding a client

* suppress pip install warnings

* allow overriding of base urls

* remove key from log

* add fade effect

* tweak request info ux

* further clarification of endpoint override api key

* world state manager: fix issue that caused max changes setting to disappear from character progress config

* fix issue with google safety settings off causing generation failures

* update to base url should always reset the client

* getattr

* support v3 chara card version and attempt to future proof

* client based embeddings improvements

* more fixes for client based embeddings

* use client icon

* history management tools progress

* history memory ids fixed and added validation

* regenerate summary fixes

* more history regeneration fixes

* fix layered history gen and prompt tweaks

* allow regeneration of individual layered history entries

* return list of LayeredArchiveEntry

* reorg for less code duplication

* new scene message renderer based on marked

* add inspect functionality to history viewer

* message if no history entries yet

* allow adding of history entries manually

* allow deletion of history

* summarization unslop improvements

* fix make character real action from worldstate listing

* allow overriding length in all context generation instruction dialogs

* fix issue where extract_list could fail with an unhandled error if the LLM response didn't contain a list

* update what's new

* fix issues with the new history management tools

* fix check

* Switch dependency handling to UV (#202)

* Migrate from Poetry to uv package manager (#200)

* migrate from poetry to uv package manager

* Update all installation and startup scripts for uv migration

* Fix pyproject.toml for uv - allow direct references for hatchling

* Fix PR feedback: Restore removed functionality

- Restored embedded Python/Node.js functionality in install.bat and update.bat
- Restored environment variable exposure in docker-compose.yml (CUDA_AVAILABLE, port configs)
- Fixed GitHub Actions branches (main, prep-* instead of main, dev)
- Restored fail-fast: false and cache configuration in test.yml

These changes preserve all the functionality that should not be removed
during the migration from Poetry to uv.

---------

Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>

* remove uv.lock from .gitignore

* add lock file

* fix install issues

* warn if unable to remove legacy poetry virt env dir

* uv needs to be explicitly installed into the .venv so it's available

* third time's the charm?

* fix windows install scripts

* add .venv guard to update.bat

* call :die

* fix docker venv install

* node 21

* fix cuda install

* start.bat calls install if needed

* sync start-local to other startup scripts

* no need to activate venv

---------

Co-authored-by: Ztripez <reg@otherland.nu>
Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>

* ignore hfhub symlink warnings

* add openrouter and ollama mentions

* update windows install documentation

* docs

* docs

* fix issue with memory agent fingerprint

* removing a client that supports embeddings will also remove any embedding functions it created

* on invalid embeddings reset to default

* docs

* typo

* formatting

* docs

* docs

* install package

* adjust topic

* add more obvious way to exit creative mode

* when importing character cards immediately persist a usable save after the restoration save

---------

Co-authored-by: Ztripez <reg@otherland.nu>
Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>
Commit 9eb4c48d79 (parent e4d465ba42) by veguAI, committed by GitHub, 2025-06-29 18:06:11 +03:00.
222 changed files with 43178 additions and 9603 deletions

View File

@@ -2,9 +2,9 @@ name: Python Tests
on:
push:
branches: [ master, main, 'prep-*' ]
branches: [ main, 'prep-*' ]
pull_request:
branches: [ master, main, 'prep-*' ]
branches: [ main, 'prep-*' ]
jobs:
test:
@@ -23,25 +23,24 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install poetry
- name: Install uv
run: |
python -m pip install --upgrade pip
pip install poetry
pip install uv
- name: Cache poetry dependencies
- name: Cache uv dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry
key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
path: ~/.cache/uv
key: ${{ runner.os }}-uv-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-poetry-${{ matrix.python-version }}-
${{ runner.os }}-uv-${{ matrix.python-version }}-
- name: Install dependencies
run: |
python -m venv talemate_env
source talemate_env/bin/activate
poetry config virtualenvs.create false
poetry install
uv venv
source .venv/bin/activate
uv pip install -e ".[dev]"
- name: Setup configuration file
run: |
@@ -49,10 +48,10 @@ jobs:
- name: Download NLTK data
run: |
source talemate_env/bin/activate
source .venv/bin/activate
python -c "import nltk; nltk.download('punkt_tab')"
- name: Run tests
run: |
source talemate_env/bin/activate
source .venv/bin/activate
pytest tests/ -p no:warnings

3
.gitignore vendored
View File

@@ -8,6 +8,9 @@
talemate_env
chroma
config.yaml
# uv
.venv/
templates/llm-prompt/user/*.jinja2
templates/world-state/*.yaml
scenes/

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.11

View File

@@ -1,15 +1,19 @@
# Stage 1: Frontend build
FROM node:21 AS frontend-build
ENV NODE_ENV=development
FROM node:21-slim AS frontend-build
WORKDIR /app
# Copy the frontend directory contents into the container at /app
COPY ./talemate_frontend /app
# Copy frontend package files
COPY talemate_frontend/package*.json ./
# Install all dependencies and build
RUN npm install && npm run build
# Install dependencies
RUN npm ci
# Copy frontend source
COPY talemate_frontend/ ./
# Build frontend
RUN npm run build
# Stage 2: Backend build
FROM python:3.11-slim AS backend-build
@@ -22,30 +26,25 @@ RUN apt-get update && apt-get install -y \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Install poetry
RUN pip install poetry
# Install uv
RUN pip install uv
# Copy poetry files
COPY pyproject.toml poetry.lock* /app/
# Copy installation files
COPY pyproject.toml uv.lock /app/
# Create a virtual environment
RUN python -m venv /app/talemate_env
# Activate virtual environment and install dependencies
RUN . /app/talemate_env/bin/activate && \
poetry config virtualenvs.create false && \
poetry install --only main --no-root
# Copy the Python source code
# Copy the Python source code (needed for editable install)
COPY ./src /app/src
# Create virtual environment and install dependencies
RUN uv sync
# Conditional PyTorch+CUDA install
ARG CUDA_AVAILABLE=false
RUN . /app/talemate_env/bin/activate && \
RUN . /app/.venv/bin/activate && \
if [ "$CUDA_AVAILABLE" = "true" ]; then \
echo "Installing PyTorch with CUDA support..." && \
pip uninstall torch torchaudio -y && \
pip install torch~=2.4.1 torchaudio~=2.4.1 --index-url https://download.pytorch.org/whl/cu121; \
uv pip uninstall torch torchaudio && \
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
fi
# Stage 3: Final image
@@ -57,8 +56,11 @@ RUN apt-get update && apt-get install -y \
bash \
&& rm -rf /var/lib/apt/lists/*
# Install uv in the final stage
RUN pip install uv
# Copy virtual environment from backend-build stage
COPY --from=backend-build /app/talemate_env /app/talemate_env
COPY --from=backend-build /app/.venv /app/.venv
# Copy Python source code
COPY --from=backend-build /app/src /app/src
@@ -83,4 +85,4 @@ EXPOSE 5050
EXPOSE 8080
# Use bash as the shell, activate the virtual environment, and run backend server
CMD ["/bin/bash", "-c", "source /app/talemate_env/bin/activate && python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050 --frontend-host 0.0.0.0 --frontend-port 8080"]
CMD ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]

View File

@@ -39,12 +39,14 @@ Need help? Join the new [Discord community](https://discord.gg/8bGNRmFxMj)
- [Cohere](https://www.cohere.com/)
- [Groq](https://www.groq.com/)
- [Google Gemini](https://console.cloud.google.com/)
- [OpenRouter](https://openrouter.ai/)
Supported self-hosted APIs:
- [KoboldCpp](https://koboldai.org/cpp) ([Local](https://koboldai.org/cpp), [Runpod](https://koboldai.org/runpodcpp), [VastAI](https://koboldai.org/vastcpp), also includes image gen support)
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
- [LMStudio](https://lmstudio.ai/)
- [TabbyAPI](https://github.com/theroyallab/tabbyAPI/)
- [Ollama](https://ollama.com/)
Generic OpenAI api implementations (tested and confirmed working):
- [DeepInfra](https://deepinfra.com/)

View File

@@ -18,4 +18,4 @@ services:
environment:
- PYTHONUNBUFFERED=1
- PYTHONPATH=/app/src:$PYTHONPATH
command: ["/bin/bash", "-c", "source /app/talemate_env/bin/activate && python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050 --frontend-host 0.0.0.0 --frontend-port 8080"]
command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]

View File

@@ -0,0 +1,166 @@
# Adding a new world-state template
I am writing this up as I add phrase detection functionality to the `Writing Style` template, so that in the future, when new template types need to be added, this document can hopefully just be given to the LLM of the month to do it.
## Introduction
World state templates are reusable components that plug into various parts of talemate.
At this point there are the following types:
- Character Attribute
- Character Detail
- Writing Style
- Spice (for randomization of content during generation)
- Scene Type
- State Reinforcement
Basically whenever we want to add something reusable and customizable by the user, a world state template is likely a good solution.
## Steps to creating a new template type
### 1. Add a pydantic schema (python)
In `src/talemate/world_state/templates` create a new `.py` file with a reasonable name.
In this example I am extending the `Writing Style` template to include phrase detection functionality, which will be used by the `Editor` agent to detect certain phrases and then act upon them.
There already is a `content.py` file - so it makes sense to just add this new functionality to this file.
```python
class PhraseDetection(pydantic.BaseModel):
phrase: str
instructions: str
# can be "unwanted" for now, more added later
classification: Literal["unwanted"] = "unwanted"
@register("writing_style")
class WritingStyle(Template):
description: str | None = None
phrases: list[PhraseDetection] = pydantic.Field(default_factory=list)
def render(self, scene: "Scene", character_name: str):
return self.formatted("instructions", scene, character_name)
```
If I were to create a new file I'd still want to read one of the existing files first to understand imports and style.
### 2. Add a vue component to allow management (vue, js)
Next we need to add a new vue component that exposes a UX for us to manage this new template type.
For this I am creating `talemate_frontend/src/components/WorldStateManagerTemplateWritingStyle.vue`.
## Bare Minimum Understanding for New Template Components
When adding a new component for managing a template type, you need to understand:
### Component Structure
1. **Props**: The component always receives an `immutableTemplate` prop with the template data.
2. **Data Management**: Create a local copy of the template data for editing before saving back.
3. **Emits**: Use the `update` event to send modified template data back to the parent.
### Core Implementation Requirements
1. **Template Properties**: Always include fields for `name`, `description`, and `favorite` status.
2. **Data Binding**: Implement two-way binding with `v-model` for all editable fields.
3. **Dirty State Tracking**: Track when changes are made but not yet saved.
4. **Save Method**: Implement a `save()` method that emits the updated template.
### Component Lifecycle
1. **Initialization**: Use the `created` hook to initialize the local template copy.
2. **Watching for Changes**: Set up a watcher for the `immutableTemplate` to handle external updates.
### UI Patterns
1. **Forms**: Use Vuetify form components with consistent validation.
2. **Actions**: Provide clear user actions for editing and managing template items.
3. **Feedback**: Give visual feedback when changes are being made or saved.
The WorldStateManagerTemplate components follow a consistent pattern where they:
- Display and edit general template metadata (name, description, favorite status)
- Provide specialized UI for the template's unique properties
- Handle the create, read, update, delete (CRUD) operations for template items
- Maintain data integrity by properly handling template updates
You absolutely should read an existing component like `WorldStateManagerTemplateWritingStyle.vue` first to get a good understanding of the implementation.
## Integrating with WorldStateManagerTemplates
After creating your template component, you need to integrate it with the WorldStateManagerTemplates component:
### 1. Import the Component
Edit `talemate_frontend/src/components/WorldStateManagerTemplates.vue` and add an import for your new component:
```javascript
import WorldStateManagerTemplateWritingStyle from './WorldStateManagerTemplateWritingStyle.vue'
```
### 2. Register the Component
Add your component to the components section of the WorldStateManagerTemplates:
```javascript
components: {
// ... existing components
WorldStateManagerTemplateWritingStyle
}
```
### 3. Add Conditional Rendering
In the template section, add a new conditional block to render your component when the template type matches:
```html
<WorldStateManagerTemplateWritingStyle v-else-if="template.template_type === 'writing_style'"
:immutableTemplate="template"
@update="(template) => applyAndSaveTemplate(template)"
/>
```
### 4. Add Icon and Color
Add cases for your template type in the `iconForTemplate` and `colorForTemplate` methods:
```javascript
iconForTemplate(template) {
// ... existing conditions
else if (template.template_type == 'writing_style') {
return 'mdi-script-text';
}
return 'mdi-cube-scan';
},
colorForTemplate(template) {
// ... existing conditions
else if (template.template_type == 'writing_style') {
return 'highlight5';
}
return 'grey';
}
```
### 5. Add Help Message
Add a help message for your template type in the `helpMessages` object in the data section:
```javascript
helpMessages: {
// ... existing messages
writing_style: "Writing style templates are used to define a writing style that can be applied to the generated content. They can be used to add a specific flavor or tone. A template must explicitly support writing styles to be able to use a writing style template.",
}
```
### 6. Update Template Type Selection
Add your template type to the `templateTypes` array in the data section:
```javascript
templateTypes: [
// ... existing types
{ "title": "Writing style", "value": 'writing_style'},
]
```

View File

@@ -10,20 +10,19 @@ To run the server on a different host and port, you need to change the values pa
#### :material-linux: Linux
Copy `start.sh` to `start_custom.sh` and edit the `--host` and `--port` parameters in the `uvicorn` command.
Copy `start.sh` to `start_custom.sh` and edit the `--host` and `--port` parameters.
```bash
#!/bin/sh
. talemate_env/bin/activate
python src/talemate/server/run.py runserver --host 0.0.0.0 --port 1234
uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 1234
```
#### :material-microsoft-windows: Windows
Copy `start.bat` to `start_custom.bat` and edit the `--host` and `--port` parameters in the `uvicorn` command.
Copy `start.bat` to `start_custom.bat` and edit the `--host` and `--port` parameters.
```batch
start cmd /k "cd talemate_env\Scripts && activate && cd ../../ && python src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234"
uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234
```
### Letting the frontend know about the new host and port
@@ -71,8 +70,7 @@ Copy `start.sh` to `start_custom.sh` and edit the `--frontend-host` and `--front
```bash
#!/bin/sh
. talemate_env/bin/activate
python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
--frontend-host localhost --frontend-port 8082
```
@@ -81,7 +79,7 @@ python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
Copy `start.bat` to `start_custom.bat` and edit the `--frontend-host` and `--frontend-port` parameters.
```batch
start cmd /k "cd talemate_env\Scripts && activate && cd ../../ && python src\talemate\server\run.py runserver --host 0.0.0.0 --port 5055 --frontend-host localhost --frontend-port 8082"
uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 5055 --frontend-host localhost --frontend-port 8082
```
### Start the backend and frontend
@@ -98,5 +96,4 @@ Start the backend and frontend as usual.
```batch
start_custom.bat
```
```

View File

@@ -1,4 +1,3 @@
## Quick install instructions
### Dependencies
@@ -7,6 +6,7 @@
1. node.js and npm - see instructions [here](https://nodejs.org/en/download/package-manager/)
1. python- see instructions [here](https://www.python.org/downloads/)
1. uv - see instructions [here](https://github.com/astral-sh/uv#installation)
### Installation
@@ -25,19 +25,15 @@ If everything went well, you can proceed to [connect a client](../../connect-a-c
1. Open a terminal.
2. Navigate to the project directory.
3. Create a virtual environment by running `python3 -m venv talemate_env`.
4. Activate the virtual environment by running `source talemate_env/bin/activate`.
3. uv will automatically create a virtual environment when you run `uv venv`.
### Installing Dependencies
1. With the virtual environment activated, install poetry by running `pip install poetry`.
2. Use poetry to install dependencies by running `poetry install`.
1. Use uv to install dependencies by running `uv pip install -e ".[dev]"`.
### Running the Backend
1. With the virtual environment activated and dependencies installed, you can start the backend server.
2. Navigate to the `src/talemate/server` directory.
3. Run the server with `python run.py runserver --host 0.0.0.0 --port 5050`.
1. You can start the backend server using `uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
### Running the Frontend
@@ -45,4 +41,4 @@ If everything went well, you can proceed to [connect a client](../../connect-a-c
2. If you haven't already, install npm dependencies by running `npm install`.
3. Start the server with `npm run serve`.
Please note that you may need to set environment variables or modify the host and port as per your setup. You can refer to the `runserver.sh` and `frontend.sh` files for more details.
Please note that you may need to set environment variables or modify the host and port as per your setup. You can refer to the various start scripts for more details.
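Putting the steps above together, a typical first run on Linux might look like this sketch (the checkout path and ports are assumptions; adjust to your setup):
```bash
# Backend: create the environment, install dependencies, start the server
cd talemate                      # assumed checkout directory
uv venv                          # creates .venv
uv pip install -e ".[dev]"       # install Talemate and dev dependencies
uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050

# Frontend (in a second terminal)
cd talemate/talemate_frontend
npm install
npm run serve
```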

View File

@@ -2,16 +2,9 @@
## Windows
### Installation fails with "Microsoft Visual C++" or "ValueError: The onnxruntime python package is not installed." errors
If your installation errors with a notification to upgrade "Microsoft Visual C++" go to [https://visualstudio.microsoft.com/visual-cpp-build-tools/](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and click "Download Build Tools" and run it.
- During installation make sure you select the C++ development package (upper left corner)
- Run `reinstall.bat` inside talemate directory
### Frontend fails with errors
- ensure none of the directories have special characters in them, this can cause issues with the frontend. so no `(1)` in the directory name.
- ensure none of the directories leading to your talemate directory have special characters in them; these can cause issues with the frontend, so no `(1)` in the directory name.
## Docker

View File

@@ -1,53 +1,32 @@
## Quick install instructions
1. Download and install Python 3.10 - 3.13 from the [official Python website](https://www.python.org/downloads/windows/).
- [Click here for direct link to python 3.11.9 download](https://www.python.org/downloads/release/python-3119/)
- June 2025: people have reported issues with python 3.13 still, due to some dependencies not being available yet, if you run into issues during installation try downgrading.
1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/prebuilt-installer). This will also install npm.
1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
1. **Optional:** If you are using an nvidia graphics card with CUDA support you may want to also run `install-cuda.bat` **afterwards**, to install the cuda enabled version of torch - although this is only needed if you want to run some bigger embedding models where CUDA can be helpful.
1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
1. Once the talemate logo shows up, navigate your browser to http://localhost:8080
1. Download the latest Talemate release ZIP from the [Releases page](https://github.com/vegu-ai/talemate/releases) and extract it anywhere on your system (for example, `C:\Talemate`).
2. Double-click **`start.bat`**.
- On the very first run Talemate will automatically:
1. Download a portable build of Python 3 and Node.js (no global installs required).
2. Create and configure a Python virtual environment.
3. Install all back-end and front-end dependencies with the included *uv* and *npm*.
4. Build the web client.
3. When the console window prints **"Talemate is now running"** and the logo appears, open your browser at **http://localhost:8080**.
!!! note "First start up may take a while"
We have seen cases where the first start of talemate will sit at a black screen for a minute or two. Just wait it out, eventually the Talemate logo should show up.
!!! note "First start can take a while"
The initial download and dependency installation may take several minutes, especially on slow internet connections. The console will keep you updated; just wait until the Talemate logo shows up.
If everything went well, you can proceed to [connect a client](../../connect-a-client).
### Optional: CUDA support
## Additional Information
If you have an NVIDIA GPU and want CUDA acceleration for larger embedding models:
### How to Install Python
1. Close Talemate (if it is running).
2. Double-click **`install-cuda.bat`**. This script swaps the CPU-only Torch build for the CUDA 12.8 build.
3. Start Talemate again via **`start.bat`**.
--8<-- "docs/snippets/common.md:python-versions"
## Maintenance & advanced usage
1. Visit the official Python website's download page for Windows at [https://www.python.org/downloads/windows/](https://www.python.org/downloads/windows/).
2. Find the latest updated of Python 3.13 and click on one of the download links. (You will likely want the Windows installer (64-bit))
4. Run the installer file and follow the setup instructions. Make sure to check the box that says Add Python 3.13 to PATH before you click Install Now.
| Script | Purpose |
|--------|---------|
| **`start.bat`** | Primary entry point; performs the initial install if needed and then starts Talemate. |
| **`install.bat`** | Runs the installer without launching the server. Useful for automated setups or debugging. |
| **`install-cuda.bat`** | Installs the CUDA-enabled Torch build (run after the regular install). |
| **`update.bat`** | Pulls the latest changes from GitHub, updates dependencies, rebuilds the web client. |
### How to Install npm
1. Download Node.js from the official site [https://nodejs.org/en/download/prebuilt-installer](https://nodejs.org/en/download/prebuilt-installer).
2. Run the installer (the .msi installer is recommended).
3. Follow the prompts in the installer (Accept the license agreement, click the NEXT button a bunch of times and accept the default installation settings).
### Usage of the Supplied bat Files
#### install.bat
This batch file is used to set up the project on your local machine. It creates a virtual environment, activates it, installs poetry, and uses poetry to install dependencies. It then navigates to the frontend directory and installs the necessary npm packages.
To run this file, simply double click on it or open a command prompt in the same directory and type `install.bat`.
#### update.bat
If you are inside a git checkout of talemate you can use this to pull and reinstall talemate if there have been updates.
!!! note "CUDA needs to be reinstalled manually"
Running `update.bat` will downgrade your torch install to the non-CUDA version, so if you want CUDA support you will need to run the `install-cuda.bat` script after the update is finished.
#### start.bat
This batch file is used to start the backend and frontend servers. It opens two command prompts, one for the frontend and one for the backend.
To run this file, simply double click on it or open a command prompt in the same directory and type `start.bat`.
No system-wide Python or Node.js is required; Talemate uses the embedded runtimes it downloads automatically.

Binary image files added (not shown); includes docs/img/0.31.0/history.png.

View File

@@ -0,0 +1,25 @@
# KoboldCpp Embeddings
Talemate can leverage an embeddings model that is already loaded in your KoboldCpp instance.
## 1. Start KoboldCpp with an embeddings model
Launch KoboldCpp with the `--embeddingsmodel` flag so that it loads an embeddings-capable GGUF model alongside the main LLM:
```bash
koboldcpp_cu12.exe --model google_gemma-3-27b-it-Q6_K.gguf --embeddingsmodel bge-large-en-v1.5.Q8_0.gguf
```
## 2. Talemate will pick it up automatically
When Talemate starts, the **Memory** agent probes every connected client that advertises embedding support. If it detects that your KoboldCpp instance has an embeddings model loaded:
1. The Memory backend switches the current embedding to **Client API**.
2. The `client` field in the agent details shows the name of the KoboldCpp client.
3. A banner informs you that Talemate has switched to the new embedding. <!-- stub: screenshot -->
![Memory agent automatically switched to KoboldCpp embeddings](/talemate/img/0.31.0/koboldcpp-embeddings.png)
## 3. Reverting to a local embedding
Open the memory agent settings and pick a different embedding. See [Memory agent settings](/talemate/user-guide/agents/memory/settings).

View File

@@ -5,5 +5,6 @@ nav:
- Google Cloud: google.md
- Groq: groq.md
- Mistral.ai: mistral.md
- OpenRouter: openrouter.md
- OpenAI: openai.md
- ...

View File

@@ -0,0 +1,11 @@
# OpenRouter API Setup
Talemate can use any model accessible through OpenRouter.
You need an OpenRouter API key and must set it in the application config. You can create and manage keys in your OpenRouter dashboard at [https://openrouter.ai/keys](https://openrouter.ai/keys).
Once you have generated a key open the Talemate settings, switch to the `APPLICATION` tab and then select the `OPENROUTER API` category. Paste your key in the **API Key** field.
![OpenRouter API settings](/talemate/img/0.31.0/openrouter-settings.png)
Finally click **Save** to store the credentials.
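As an optional sanity check (assuming `curl` is available), you can query OpenRouter's model list with your key; the key value below is a placeholder:
```bash
export OPENROUTER_API_KEY="sk-or-..."   # placeholder, use your own key
curl -s https://openrouter.ai/api/v1/models \
  -H "Authorization: Bearer $OPENROUTER_API_KEY" | head
```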

View File

@@ -4,4 +4,5 @@ nav:
- Recommended Local Models: recommended-models.md
- Inference Presets: presets.md
- Client Types: types
- Endpoint Override: endpoint-override.md
- ...

View File

@@ -0,0 +1,24 @@
# Endpoint Override
Starting in version 0.31.0, it is possible for some of the remote clients to override the endpoint used for the API.
This is helpful when you want to point the client at a proxy gateway that serves the API instead (LiteLLM, for example).
!!! warning "Only use trusted endpoints"
Only use endpoints that you trust and NEVER use your actual API key with them, unless you are hosting the endpoint proxy yourself.
If you need to provide an API key, there is a separate field for that specifically in the endpoint override settings.
## How to use
Clients that support it will have a tab in their settings that allows you to override the endpoint.
![Endpoint Override](/talemate/img/0.31.0/client-endpoint-override.png)
##### Base URL
The base URL of the endpoint. For example, `http://localhost:4000` if you're running a local LiteLLM gateway.
##### API Key
The API key to use for the endpoint. This is only required if the endpoint requires an API key. This is **NOT** the API key you would use for the official API. For LiteLLM, for example, this could be the `general_settings.master_key` value.
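As a sketch (the port and key below are assumptions for a self-hosted LiteLLM gateway), you can check that the proxy answers OpenAI-style requests before pointing a Talemate client at it:
```bash
# Hypothetical local LiteLLM gateway; the master key is whatever you configured
curl -s http://localhost:4000/v1/models \
  -H "Authorization: Bearer sk-litellm-master-key"
```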

View File

@@ -8,6 +8,8 @@ nav:
- Mistral.ai: mistral.md
- OpenAI: openai.md
- OpenAI Compatible: openai-compatible.md
- Ollama: ollama.md
- OpenRouter: openrouter.md
- TabbyAPI: tabbyapi.md
- Text-Generation-WebUI: text-generation-webui.md
- ...

View File

@@ -0,0 +1,59 @@
# Ollama Client
If you want to add an Ollama client, change the `Client Type` to `Ollama`.
![Client Ollama](/talemate/img/0.31.0/client-ollama.png)
Click `Save` to add the client.
### Ollama Server
The client should appear in the clients list. Talemate will ping the Ollama server to verify that it is running. If the server is not reachable you will see a warning.
![Client ollama offline](/talemate/img/0.31.0/client-ollama-offline.png)
Make sure that the Ollama server is running (by default at `http://localhost:11434`) and that the model you want to use has been pulled.
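For example (the model name is just an illustration), you can pull a model and confirm the server is reachable:
```bash
# Pull a model so it shows up in Talemate's model list
ollama pull llama3.1

# Verify the server is running and list locally available models
curl -s http://localhost:11434/api/tags
```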
It may also show a yellow dot next to it, saying that there is no model loaded.
![Client ollama no model](/talemate/img/0.31.0/client-ollama-no-model.png)
Open the client settings by clicking the :material-cogs: icon, to select a model.
![Ollama settings](/talemate/img/0.31.0/client-ollama-select-model.png)
Click save and the client should have a green dot next to it, indicating that it is ready to go.
![Client ollama ready](/talemate/img/0.31.0/client-ollama-ready.png)
### Settings
##### Client Name
A unique name for the client that makes sense to you.
##### API URL
The base URL where the Ollama HTTP endpoint is running. Defaults to `http://localhost:11434`.
##### Model
Name of the Ollama model to use. Talemate will automatically fetch the list of models that are currently available in your local Ollama instance.
##### API handles prompt template
If enabled, Talemate will send the raw prompt and let Ollama apply its own built-in prompt template. If you are unsure, leave this disabled; Talemate's own prompt template generally produces better results.
##### Allow thinking
If enabled, Talemate will allow models that support "thinking" (`assistant:thinking` messages) to deliberate before forming the final answer. At the moment Talemate has limited support for this feature when Talemate is handling the prompt template. It's probably OK to turn it on if you let Ollama handle the prompt template.
!!! tip
You can quickly refresh the list of models by making sure the Ollama server is running and then hitting **Save** again in the client settings.
### Common issues
#### Generations are weird / bad
If letting talemate handle the prompt template, make sure the [correct prompt template is assigned](/talemate/user-guide/clients/prompt-templates/).

View File

@@ -0,0 +1,48 @@
# OpenRouter Client
If you want to add an OpenRouter client, change the `Client Type` to `OpenRouter`.
![Client OpenRouter](/talemate/img/0.31.0/client-openrouter.png)
Click `Save` to add the client.
### OpenRouter API Key
The client should appear in the clients list. If you haven't set up OpenRouter before, you will see a warning that the API key is missing.
![Client openrouter no api key](/talemate/img/0.31.0/client-openrouter-no-api-key.png)
Click the `SET API KEY` button. This will open the API settings window where you can add your OpenRouter API key.
For additional instructions on obtaining and setting your OpenRouter API key, see [OpenRouter API instructions](/talemate/user-guide/apis/openrouter/).
![OpenRouter settings](/talemate/img/0.31.0/openrouter-settings.png)
Click `Save` and after a moment the client should have a red dot next to it, saying that there is no model loaded.
Click the :material-cogs: icon to open the client settings and select a model.
![OpenRouter select model](/talemate/img/0.31.0/client-openrouter-select-model.png).
Click save and the client should have a green dot next to it, indicating that it is ready to go.
### Ready to use
![Client OpenRouter Ready](/talemate/img/0.31.0/client-openrouter-ready.png)
### Settings
##### Client Name
A unique name for the client that makes sense to you.
##### Model
Choose any model available via your OpenRouter account. Talemate dynamically fetches the list of models associated with your API key so new models will show up automatically.
##### Max token length
Maximum context length (in tokens) that OpenRouter should consider. If you are not sure, leave the default value.
!!! note "Available models are fetched automatically"
Talemate fetches the list of available OpenRouter models when you save the configuration (if a valid API key is present). If you add or remove models to your account later, simply click **Save** in the application settings again to refresh the list.

Binary image files changed (not shown).

View File

@@ -2,17 +2,6 @@
This tutorial will show you how to use the `Dynamic Storyline` module (added in `0.30`) to randomize the scene introduction for ANY scene.
!!! note "A more streamlined approach is coming soon"
I am aware that some people may not want to touch the node editor at all, so a more streamlined approach is planned.
For now this will lay out the simplest way to set this up while still using the node editor.
!!! learn-more "For those interested..."
There is tutorial on how the `Dynamic Storyline` module was made (or at least the beginnings of it).
If you are interested in the process, you can find it [here](/talemate/user-guide/howto/infinity-quest-dynamic).
## Save a foundation scene copy
This should be a save of your scene that has had NO progress made to it yet. We are generating a new scene introduction after all.
@@ -21,59 +10,52 @@ The introduction is only generated once. So you should maintain a save-file of t
To ensure this foundation scene save isn't overwritten you can go to the scene settings in the world editor and turn on the Locked save file flag:
![Immutable save](./img/0008.png)
![Immutable save](./img/0001.png)
Save the scene.
## Switch to the node editor
## Install the module
In your scene tools find the :material-puzzle-edit: creative menu and click on the **Node Editor** option.
Click the `Mods` tab in the world editor.
![Mods Tab](./img/0002.png)
![Node Editor](./img/0001.png)
Find the `COPY AS EDITABLE MODULE FOR ..` button beneath the node editor.
Find the `Dynamic Storyline` module and click **Install**.
![Copy as editable module](./img/0002.png)
It will say installed (not configured)
Click it.
![Installed (not configured)](./img/0003.png)
In the next window, don't even read any of the stuff, just click **Continue**.
Click **Configure** and set topic to something like `Sci-fi adventure with lovecraftian horror`.
## Find a blank area
![Configure Module](./img/0004.png)
Use the mousewheel to zoom out a bit, then click the canvas and drag it to the side so you're looking at some blank space. Literally anywhere that's grey background is fine.
!!! note "Optional settings"
Double click the empty area to bring up the module search and type in "Dynamic Story" into the search.
##### Max intro text length
How many tokens to generate for the intro text.
![Dynamic Story](./img/0003.png)
##### Additional instructions for topic analysis task
If topic analysis is enabled, this will be used to augment the topic analysis task with further instructions
Select the `Dynamic Storyline` node to add it to the scene.
##### Enable topic analysis
This will enable the topic analysis task
![Dynamic Story](./img/0004.png)
**Save** the module configuration.
Click the `topic` input and type in a general genre or thematic guide for the story.
Finally click "Reload Scene" in the left sidebar.
Some examples
![Reload Scene](./img/0007.png)
- `sci-fi with cosmic horror elements`
- `dungeons and dragons campaign ideas`
- `slice of life story ideas`
If everything is configured correctly, the storyline generation will begin immediately.
Whatever you enter will be used to generate a list of story ideas, of which one will be chosen at random to bootstrap a new story, taking the scene context that exists already into account.
![Dynamic Storyline Module Configured](./img/0005.png)
This will NOT create new characters or world context.
!!! note "Switch out of edit mode"
It simply bootstraps a story premise based on the random topic and what's already there.
Once the topic is set, save the changes by clicking the node editor's **Save** button in the upper right corner.
![Save](./img/0005.png)
Exit the node editor through the same menu as before.
![Exit node editor](./img/0006.png)
Once back in the scene, if everything was done correctly you should see it working on setting the scene introduction.
![Scene introduction](./img/0007.png)
If nothing is happening after configuration and reloading the scene, make sure you are not in edit mode.
You can leave edit mode by clicking the "Exit Node Editor" button in the creative menu.
![Exit Node Editor](./img/0006.png)

View File

@@ -0,0 +1,105 @@
# Installable Packages
It is possible to "package" your node modules so they can be installed into a scene.
This allows for easier, controlled setup and makes your node module more shareable, as users no longer need to use the node editor to install it.
Installable packages show up in the Mods list once a scene is loaded.
![Mods List](../img/package-0003.png)
## 1. Create a package module
To create a package - click the **:material-plus: Create Module** button in the node editor and select **Package**.
![Create Package Module](../img/package-0001.png)
The package itself is a special kind of node module that will let Talemate know that your node module is installable and how to install it.
## 2. Open the module properties
With the package module open find the module properties in the upper left corner of the node editor.
![Module Properties](../img/package-0002.png)
Fill in the fields:
##### The name of the node module
This is what the module package will be called in the Mods list.
Example:
```
Dynamic Storyline
```
##### The author of the node module
Your name or handle. This is arbitrary and just lets people know who made the package.
##### The description of the node module
A short description of the package. This is displayed in the Mods list.
Example:
```
Generate a random story premise at the beginning of the scene.
```
##### Whether the node module is installable to the scene
A checkbox to indicate if the package is installable to the scene.
Right now this should always be checked, there are no other package types currently.
##### Whether the scene loop should be restarted when the package is installed
If checked, installing this package will restart the scene loop. This is mostly important for modules that need to hook into the scene loop init event.
## 3. Install instructions
Currently there are only two nodes relevant for the node module install process.
1. `Install Node Module` - this node is used to make sure the target node module is added to the scene loop when installing the package. You can have more than one of these nodes in your package.
1. `Promote Config` - this node is used to promote your node module's properties to configurable fields in the mods list. That is, it dictates what the user can configure when installing the package.
### Install Node Module
![Install Node Module](../img/package-0004.png)
!!! payload "Install Node Module"
| Property | Value |
|----------|-------|
| node_registry | the registry path of the node module to install |
### Promote Config
![Promote Config](../img/package-0005.png)
!!! payload "Promote Config"
| Property | Value |
|----------|-------|
| node_registry | the registry path of the node module |
| property_name | the name of the property to promote (as it is set in the node module) |
| exposed_property_name | expose as this name in the mods list; this can be the same as the property name or a different one - this is important if you're installing multiple node modules with the same property name, so you can differentiate between them |
| required | whether the property is required to be set when installing the package |
| label | a user friendly label for the property |
### Make talemate aware of the package
For talemate to be aware of the package, you need to copy it to the public node module directory, which exists as `templates/modules/`.
Create a new sub directory:
```
./templates/modules/<your-package-name>/
```
Copy the package module and your node module files into the directory.
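For example (directory and file names below are placeholders for your own package and node module files):
```bash
mkdir -p templates/modules/dynamic-storyline
# copy the package module and the node module(s) it installs; paths are illustrative
cp /path/to/dynamic-storyline-package.json templates/modules/dynamic-storyline/
cp /path/to/dynamic-storyline.json templates/modules/dynamic-storyline/
```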
Restart talemate and the package should now show up in the Mods list.

Binary image files added (not shown).

View File

@@ -1,14 +1,63 @@
# History
Will hold the archived history for the scene.
Will hold historical events for the scene.
This historical archive is extended every time the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) summarizes the scene.
There are three types of historical entries:
Summarization happens when a certain progress threshold is reached (tokens) or when a time passage event occurs.
- **Archived history (static)** - These entries are manually defined and dated before the starting point of the scene. Things that happened in the past that are **IMPORTANT** for the understanding of the world. For anything that is not **VITAL**, use world entries instead.
- **Archived history (from summary)** - These are historical entries generated from the progress in the scene. Whenever a certain token (length) threshold is reached, the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) will generate a summary of the progression and add it to the history.
- **Layered history (from summary)** - As summary archives are generated, they themselves will be summarized again and added to the history, leading to a natural compression of the history sent with the context while also keeping track of the most important events. (hopefully)
All archived history is a potential candidate to be included in the context sent to the AI based on relevancy. This is handled by the [Memory Agent](/talemate/user-guide/agents/memory/).
![History](/talemate/img/0.31.0/history.png)
You can use the **:material-refresh: Regenerate History** button to force a new summarization of the scene.
## Layers
!!! warning
If there has been lots of progress this will potentially take a long time to complete.
There is always the **BASE** layer, which is where the archived history (both static and from summary) is stored. For all intents and purposes, this is layer 0.
At the beginning of a scene, there won't be any additional layers, as any layer past layer 0 will come from summarization down the line.
Note that layered history is managed by the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) and can be disabled in its settings.
### Managing entries
- **All entries** can be edited by double-clicking the text.
- **Static entries** can be deleted by clicking the **:material-close-box-outline: Delete** button.
- **Summarized entries** can be regenerated by clicking the **:material-refresh: Regenerate** button. This will cause the LLM to re-summarize the entry and update the text.
- **Summarized entries** can be inspected by clicking the **:material-magnify-expand: Inspect** button. This will expand the entry and show the source entries that were used to generate the summary.
### Adding static entries
Static entries can be added by clicking the **:material-plus: Add Entry** button.
!!! note "Static entries must be older than any summary entries"
Static entries must be older than any summary entries. This is to ensure that the history is always chronological.
Trying to add a static entry that is more recent than any summary entry will result in an error.
##### Entry Text
The text of the entry. Should be at most 1 - 2 paragraphs. Less is more. Anything that needs great detail should be a world entry instead.
##### Unit
Defines the duration unit of the entry. So minutes, hours, days, weeks, months or years.
##### Amount
Defines the duration unit amount of the entry.
So if you want to define something that happened 10 months ago (from the current moment in the scene), you would set the unit to months and the amount to 10.
![Add Static Entry](/talemate/img/0.31.0/history-add-entry.png)
## Regenerate everything
It is possible to regenerate the entire history by clicking the **:material-refresh: Regenerate All History** button in the left sidebar. Static entries will remain unchanged.
![Regenerate All History](/talemate/img/0.31.0/history-regenerate-all.png)
!!! warning "This can take a long time"
This will go through the entire scene progress and regenerate all summarized entries.
If you have a lot of progress, be ready to wait for a while.

View File

@@ -1,8 +1,23 @@
REM activate the virtual environment
call talemate_env\Scripts\activate
@echo off
REM uninstall torch and torchaudio
python -m pip uninstall torch torchaudio -y
REM Check if .venv exists
IF NOT EXIST ".venv" (
echo [ERROR] .venv directory not found. Please run install.bat first.
goto :eof
)
REM install torch and torchaudio
python -m pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
REM Check if embedded Python exists
IF NOT EXIST "embedded_python\python.exe" (
echo [ERROR] embedded_python not found. Please run install.bat first.
goto :eof
)
REM uninstall torch and torchaudio using embedded Python's uv
embedded_python\python.exe -m uv pip uninstall torch torchaudio --python .venv\Scripts\python.exe
REM install torch and torchaudio with CUDA support using embedded Python's uv
embedded_python\python.exe -m uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128 --python .venv\Scripts\python.exe
echo.
echo CUDA versions of torch and torchaudio installed!
echo You may need to restart your application for changes to take effect.

View File

@@ -1,10 +1,7 @@
#!/bin/bash
# activate the virtual environment
source talemate_env/bin/activate
# uninstall torch and torchaudio
python -m pip uninstall torch torchaudio -y
uv pip uninstall torch torchaudio
# install torch and torchaudio
python -m pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
# install torch and torchaudio with CUDA support
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128

View File

@@ -1,4 +0,0 @@
REM activate the virtual environment
call talemate_env\Scripts\activate
call pip install "TTS>=0.21.1"

install-utils/get-pip.py Normal file (28579 lines; diff suppressed because it is too large)

View File

@@ -1,65 +1,227 @@
@echo off
REM Check for Python version and use a supported version if available
SET PYTHON=python
python -c "import sys; sys.exit(0 if sys.version_info[:2] in [(3, 10), (3, 11), (3, 12), (3, 13)] else 1)" 2>nul
IF NOT ERRORLEVEL 1 (
echo Selected Python version: %PYTHON%
GOTO EndVersionCheck
)
REM ===============================
REM Talemate project installer
REM ===============================
REM 1. Detect CPU architecture and pick the best-fitting embedded Python build.
REM 2. Download & extract that build into .\embedded_python\
REM 3. Bootstrap pip via install-utils\get-pip.py
REM 4. Install virtualenv and create .\talemate_env\ using the embedded Python.
REM 5. Activate the venv and proceed with Poetry + frontend installation.
REM ---------------------------------------------------------------
SET PYTHON=python
FOR /F "tokens=*" %%i IN ('py --list') DO (
echo %%i | findstr /C:"-V:3.11 " >nul && SET PYTHON=py -3.11 && GOTO EndPythonCheck
echo %%i | findstr /C:"-V:3.10 " >nul && SET PYTHON=py -3.10 && GOTO EndPythonCheck
)
:EndPythonCheck
%PYTHON% -c "import sys; sys.exit(0 if sys.version_info[:2] in [(3, 10), (3, 11), (3, 12), (3, 13)] else 1)" 2>nul
IF ERRORLEVEL 1 (
echo Unsupported Python version. Please install Python 3.10 or 3.11.
exit /b 1
)
IF "%PYTHON%"=="python" (
echo Default Python version is being used: %PYTHON%
) ELSE (
echo Selected Python version: %PYTHON%
)
SETLOCAL ENABLEDELAYEDEXPANSION
:EndVersionCheck
REM Define fatal-error handler
REM Usage: CALL :die "Message explaining what failed"
goto :after_die
IF ERRORLEVEL 1 (
echo Unsupported Python version. Please install Python 3.10 or 3.11.
exit /b 1
)
REM create a virtual environment
%PYTHON% -m venv talemate_env
REM activate the virtual environment
call talemate_env\Scripts\activate
REM upgrade pip and setuptools
python -m pip install --upgrade pip setuptools
REM install poetry
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
REM use poetry to install dependencies
python -m poetry install
REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
IF NOT EXIST config.yaml copy config.example.yaml config.yaml
REM navigate to the frontend directory
echo Installing frontend dependencies...
cd talemate_frontend
call npm install
echo Building frontend...
call npm run build
REM return to the root directory
cd ..
echo Installation completed successfully.
:die
echo.
echo ============================================================
echo !!! INSTALL FAILED !!!
echo %*
echo ============================================================
pause
exit 1
:after_die
REM ---------[ Check Prerequisites ]---------
ECHO Checking prerequisites...
where tar >nul 2>&1 || CALL :die "tar command not found. Please ensure Windows 10 version 1803+ or install tar manually."
where curl >nul 2>&1
IF %ERRORLEVEL% NEQ 0 (
where bitsadmin >nul 2>&1 || CALL :die "Neither curl nor bitsadmin found. Cannot download files."
)
REM ---------[ Remove legacy Poetry venv if present ]---------
IF EXIST "talemate_env" (
ECHO Detected legacy Poetry virtual environment 'talemate_env'. Removing...
RD /S /Q "talemate_env"
IF ERRORLEVEL 1 (
ECHO [WARNING] Failed to fully remove legacy 'talemate_env' directory. Continuing installation.
)
)
REM ---------[ Clean reinstall check ]---------
SET "NEED_CLEAN=0"
IF EXIST ".venv" SET "NEED_CLEAN=1"
IF EXIST "embedded_python" SET "NEED_CLEAN=1"
IF EXIST "embedded_node" SET "NEED_CLEAN=1"
IF "%NEED_CLEAN%"=="1" (
ECHO.
ECHO Detected existing Talemate environments.
REM Prompt user (empty input defaults to Y)
SET "ANSWER=Y"
SET /P "ANSWER=Perform a clean reinstall of the python and node.js environments? [Y/n] "
IF /I "!ANSWER!"=="N" (
ECHO Installation aborted by user.
GOTO :EOF
)
ECHO Removing previous installation...
IF EXIST ".venv" RD /S /Q ".venv"
IF EXIST "embedded_python" RD /S /Q "embedded_python"
IF EXIST "embedded_node" RD /S /Q "embedded_node"
ECHO Cleanup complete.
)
REM ---------[ Version configuration ]---------
SET "PYTHON_VERSION=3.11.9"
SET "NODE_VERSION=22.16.0"
REM ---------[ Detect architecture & choose download URL ]---------
REM Prefer PROCESSOR_ARCHITEW6432 when the script is run from a 32-bit shell on 64-bit Windows
IF DEFINED PROCESSOR_ARCHITEW6432 (
SET "ARCH=%PROCESSOR_ARCHITEW6432%"
) ELSE (
SET "ARCH=%PROCESSOR_ARCHITECTURE%"
)
REM Map architecture to download URL
IF /I "%ARCH%"=="AMD64" (
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x64.zip"
) ELSE IF /I "%ARCH%"=="IA64" (
REM Itanium systems are rare, but AMD64 build works with WoW64 layer
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x64.zip"
) ELSE IF /I "%ARCH%"=="ARM64" (
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-arm64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-arm64.zip"
) ELSE (
REM Fallback to 64-bit build for x86 / unknown architectures
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x86.zip"
)
ECHO Detected architecture: %ARCH%
ECHO Downloading embedded Python from: %PY_URL%
REM ---------[ Download ]---------
SET "PY_ZIP=python_embed.zip"
where curl >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Using curl to download Python...
curl -L -# -o "%PY_ZIP%" "%PY_URL%" || CALL :die "Failed to download Python embed package with curl."
) ELSE (
ECHO curl not found, falling back to bitsadmin...
bitsadmin /transfer "DownloadPython" /download /priority normal "%PY_URL%" "%CD%\%PY_ZIP%" || CALL :die "Failed to download Python embed package (curl & bitsadmin unavailable)."
)
REM ---------[ Extract ]---------
SET "PY_DIR=embedded_python"
IF EXIST "%PY_DIR%" RD /S /Q "%PY_DIR%"
mkdir "%PY_DIR%" || CALL :die "Could not create directory %PY_DIR%."
where tar >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Extracting with tar...
tar -xf "%PY_ZIP%" -C "%PY_DIR%" || CALL :die "Failed to extract Python embed package with tar."
) ELSE (
CALL :die "tar utility not found (required to unpack zip without PowerShell)."
)
DEL /F /Q "%PY_ZIP%"
SET "PYTHON=%PY_DIR%\python.exe"
ECHO Using embedded Python at %PYTHON%
REM ---------[ Enable site-packages in embedded Python ]---------
FOR %%f IN ("%PY_DIR%\python*._pth") DO (
ECHO Adding 'import site' to %%~nxf ...
echo import site>>"%%~ff"
)
REM ---------[ Ensure pip ]---------
ECHO Installing pip...
"%PYTHON%" install-utils\get-pip.py || (
CALL :die "pip installation failed."
)
REM Upgrade pip to latest
"%PYTHON%" -m pip install --no-warn-script-location --upgrade pip || CALL :die "Failed to upgrade pip in embedded Python."
REM ---------[ Install uv ]---------
ECHO Installing uv...
"%PYTHON%" -m pip install uv || (
CALL :die "uv installation failed."
)
REM ---------[ Create virtual environment with uv ]---------
ECHO Creating virtual environment with uv...
"%PYTHON%" -m uv venv || (
CALL :die "Virtual environment creation failed."
)
REM ---------[ Install dependencies using embedded Python's uv ]---------
ECHO Installing backend dependencies with uv...
"%PYTHON%" -m uv sync || CALL :die "Failed to install backend dependencies with uv."
REM Activate the venv for the remainder of the script
CALL .venv\Scripts\activate
REM echo python version
python --version
REM ---------[ Config file ]---------
IF NOT EXIST config.yaml COPY config.example.yaml config.yaml
REM ---------[ Node.js portable runtime ]---------
ECHO.
ECHO Downloading portable Node.js runtime...
REM Node download variables already set earlier based on %ARCH%.
ECHO Downloading Node.js from: %NODE_URL%
SET "NODE_ZIP=node_embed.zip"
where curl >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Using curl to download Node.js...
curl -L -# -o "%NODE_ZIP%" "%NODE_URL%" || CALL :die "Failed to download Node.js package with curl."
) ELSE (
ECHO curl not found, falling back to bitsadmin...
bitsadmin /transfer "DownloadNode" /download /priority normal "%NODE_URL%" "%CD%\%NODE_ZIP%" || CALL :die "Failed to download Node.js package (curl & bitsadmin unavailable)."
)
REM ---------[ Extract Node.js ]---------
SET "NODE_DIR=embedded_node"
IF EXIST "%NODE_DIR%" RD /S /Q "%NODE_DIR%"
mkdir "%NODE_DIR%" || CALL :die "Could not create directory %NODE_DIR%."
where tar >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Extracting Node.js...
tar -xf "%NODE_ZIP%" -C "%NODE_DIR%" --strip-components 1 || CALL :die "Failed to extract Node.js package with tar."
) ELSE (
CALL :die "tar utility not found (required to unpack zip without PowerShell)."
)
DEL /F /Q "%NODE_ZIP%"
REM Prepend Node.js folder to PATH so npm & node are available
SET "PATH=%CD%\%NODE_DIR%;%PATH%"
ECHO Using portable Node.js at %CD%\%NODE_DIR%\node.exe
ECHO Node.js version:
node -v
REM ---------[ Frontend ]---------
ECHO Installing frontend dependencies...
CD talemate_frontend
CALL npm install || CALL :die "npm install failed."
ECHO Building frontend...
CALL npm run build || CALL :die "Frontend build failed."
REM Return to repo root
CD ..
ECHO.
ECHO ==============================
ECHO Installation completed!
ECHO ==============================
PAUSE
ENDLOCAL

View File

@@ -1,20 +1,16 @@
#!/bin/bash
# create a virtual environment
echo "Creating a virtual environment..."
python3 -m venv talemate_env
# create a virtual environment with uv
echo "Creating a virtual environment with uv..."
uv venv
# activate the virtual environment
echo "Activating the virtual environment..."
source talemate_env/bin/activate
source .venv/bin/activate
# install poetry
echo "Installing poetry..."
pip install poetry
# use poetry to install dependencies
# install dependencies with uv
echo "Installing dependencies..."
poetry install
uv pip install -e ".[dev]"
# copy config.example.yaml to config.yaml only if config.yaml doesn't exist
if [ ! -f config.yaml ]; then

poetry.lock (generated): diff suppressed because it is too large.

View File

@@ -1,77 +1,82 @@
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
[tool.poetry]
[project]
name = "talemate"
version = "0.30.0"
version = "0.31.0"
description = "AI-backed roleplay and narrative tools"
authors = ["VeguAITools"]
license = "GNU Affero General Public License v3.0"
authors = [{name = "VeguAITools"}]
license = {text = "GNU Affero General Public License v3.0"}
requires-python = ">=3.10,<3.14"
dependencies = [
"astroid>=2.8",
"jedi>=0.18",
"black",
"rope>=0.22",
"isort>=5.10",
"jinja2>=3.0",
"openai>=1",
"mistralai>=0.1.8",
"cohere>=5.2.2",
"anthropic>=0.19.1",
"groq>=0.5.0",
"requests>=2.26",
"colorama>=0.4.6",
"Pillow>=9.5",
"httpx<1",
"piexif>=1.1",
"typing-inspect==0.8.0",
"typing_extensions>=4.5.0",
"uvicorn>=0.23",
"blinker>=1.6.2",
"pydantic<3",
"beautifulsoup4>=4.12.2",
"python-dotenv>=1.0.0",
"structlog>=23.1.0",
# 1.7.11 breaks subprocess stuff ???
"runpod==1.7.10",
"google-genai>=1.20.0",
"nest_asyncio>=1.5.7",
"isodate>=0.6.1",
"thefuzz>=0.20.0",
"tiktoken>=0.5.1",
"nltk>=3.8.1",
"huggingface-hub>=0.20.2",
"RestrictedPython>7.1",
"numpy>=2",
"aiofiles>=24.1.0",
"pyyaml>=6.0",
"limits>=5.0",
"diff-match-patch>=20241021",
"sseclient-py>=1.8.0",
"ollama>=0.5.1",
# ChromaDB
"chromadb>=1.0.12",
"InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
"torch>=2.7.0",
"torchaudio>=2.7.0",
# locked for instructor embeddings
#sentence-transformers==2.2.2
"sentence_transformers>=2.7.0",
]
[tool.poetry.dependencies]
python = ">=3.10,<3.14"
astroid = "^2.8"
jedi = "^0.18"
black = "*"
rope = "^0.22"
isort = "^5.10"
jinja2 = ">=3.0"
openai = ">=1"
mistralai = ">=0.1.8"
cohere = ">=5.2.2"
anthropic = ">=0.19.1"
groq = ">=0.5.0"
requests = "^2.26"
colorama = ">=0.4.6"
Pillow = ">=9.5"
httpx = "<1"
piexif = "^1.1"
typing-inspect = "0.8.0"
typing_extensions = "^4.5.0"
uvicorn = "^0.23"
blinker = "^1.6.2"
pydantic = "<3"
beautifulsoup4 = "^4.12.2"
python-dotenv = "^1.0.0"
websockets = "^11.0.3"
structlog = "^23.1.0"
runpod = "^1.2.0"
google-cloud-aiplatform = ">=1.50.0"
nest_asyncio = "^1.5.7"
isodate = ">=0.6.1"
thefuzz = ">=0.20.0"
tiktoken = ">=0.5.1"
nltk = ">=3.8.1"
huggingface-hub = ">=0.20.2"
RestrictedPython = ">7.1"
numpy = "^2"
aiofiles = ">=24.1.0"
pyyaml = ">=6.0"
limits = ">=5.0"
diff-match-patch = ">=20241021"
sseclient-py = "^1.8.0"
[project.optional-dependencies]
dev = [
"pytest>=6.2",
"pytest-asyncio>=0.25.3",
"mypy>=0.910",
"mkdocs-material>=9.5.27",
"mkdocs-awesome-pages-plugin>=2.9.2",
"mkdocs-glightbox>=0.4.0",
]
# ChromaDB
chromadb = ">=0.4.17,<1"
InstructorEmbedding = "^1.0.1"
torch = "^2.7.0"
torchaudio = "^2.7.0"
# locked for instructor embeddings
#sentence-transformers="==2.2.2"
sentence_transformers=">=2.7.0"
[tool.poetry.dev-dependencies]
pytest = ">=6.2"
pytest-asyncio = ">=0.25.3"
mypy = "^0.910"
mkdocs-material = ">=9.5.27"
mkdocs-awesome-pages-plugin = ">=2.9.2"
mkdocs-glightbox = ">=0.4.0"
[tool.poetry.scripts]
[project.scripts]
talemate = "talemate:cli.main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.black]
line-length = 88
target-version = ['py38']
@@ -87,6 +92,7 @@ exclude = '''
| buck-out
| build
| dist
| talemate_env
)/
'''
@@ -97,4 +103,4 @@ include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
line_length = 88

View File

@@ -1,18 +0,0 @@
@echo off
IF EXIST talemate_env rmdir /s /q "talemate_env"
REM create a virtual environment
python -m venv talemate_env
REM activate the virtual environment
call talemate_env\Scripts\activate
REM install poetry
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
REM use poetry to install dependencies
python -m poetry install
echo Virtual environment re-created.
pause

View File

@@ -0,0 +1,65 @@
{
"packages": [
{
"name": "Dynamic Storyline",
"author": "Talemate",
"description": "Generate a random story premise at the beginning of the scene.",
"installable": true,
"registry": "package/talemate/DynamicStoryline",
"status": "installed",
"errors": [],
"package_properties": {
"topic": {
"module": "scene/dynamicStoryline",
"name": "topic",
"label": "Topic",
"description": "The overarching topic - will be used to generate a theme that falls within this category. Example - 'Sci-fi adventure with cosmic horror'.",
"type": "str",
"default": "",
"value": "Sci-fi episodic adventures onboard of a spaceship, with focus on AI, alien contact, ancient creators and cosmic horror.",
"required": true,
"choices": []
},
"intro_length": {
"module": "scene/dynamicStoryline",
"name": "intro_length",
"label": "Max. intro text length (tokens)",
"description": "Length of the introduction",
"type": "int",
"default": 512,
"value": 512,
"required": true,
"choices": []
},
"analysis_instructions": {
"module": "scene/dynamicStoryline",
"name": "analysis_instructions",
"label": "Additional instructions for topic analysis task",
"description": "Additional instructions for topic analysis task - if topic analysis is enabled, this will be used to augment the topic analysis task with further instructions.",
"type": "text",
"default": "",
"value": "",
"required": false,
"choices": []
},
"analysis_enabled": {
"module": "scene/dynamicStoryline",
"name": "analysis_enabled",
"label": "Enable topic analysis",
"description": "Theme analysis",
"type": "bool",
"default": true,
"value": true,
"required": false,
"choices": []
}
},
"install_nodes": [
"scene/dynamicStoryline"
],
"installed_nodes": [],
"restart_scene_loop": true,
"configured": true
}
]
}
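Each entry in package_properties above follows the same widget schema: module, name, label, description, type, default, value, required and choices. A minimal sketch of that shape as a pydantic model, offered only as an illustration; the class name and field defaults here are assumptions, not Talemate's actual definition.

import pydantic
from typing import Any

class PackageProperty(pydantic.BaseModel):
    # Illustrative mirror of the widget schema used in package_properties;
    # the real Talemate class may differ.
    module: str          # owning node module, e.g. "scene/dynamicStoryline"
    name: str            # property key
    label: str           # label shown on the UI widget
    description: str = ""
    type: str = "str"    # "str", "int", "text", "bool", ...
    default: Any = ""
    value: Any = ""
    required: bool = False
    choices: list[Any] = pydantic.Field(default_factory=list)

# Example: the "intro_length" property from the manifest above.
intro_length = PackageProperty(
    module="scene/dynamicStoryline",
    name="intro_length",
    label="Max. intro text length (tokens)",
    description="Length of the introduction",
    type="int",
    default=512,
    value=512,
    required=True,
)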

View File

@@ -1,6 +1,6 @@
{
"title": "Scene Loop",
"id": "71652a76-5db3-4836-8f00-1085977cd8e8",
"id": "af468414-b30d-4f67-b08e-5b7cfd139adc",
"properties": {
"trigger_game_loop": true
},
@@ -11,50 +11,10 @@
"collapsed": false,
"inherited": false,
"registry": "scene/SceneLoop",
"nodes": {
"ede29db4-700d-4edc-b93b-bf7c79f6a6a5": {
"title": "Dynamic Storyline",
"id": "ede29db4-700d-4edc-b93b-bf7c79f6a6a5",
"properties": {
"event_name": "scene_loop_init",
"analysis_instructions": "",
"reset": false,
"topic": "sci-fi with cosmic horror elements",
"analysis_enabled": true,
"intro_length": 512
},
"x": 32,
"y": -249,
"width": 295,
"height": 158,
"collapsed": false,
"inherited": false,
"registry": "scene/dynamicStoryline",
"base_type": "core/Event"
}
},
"nodes": {},
"edges": {},
"groups": [
{
"title": "Randomize Story",
"x": 8,
"y": -321,
"width": 619,
"height": 257,
"color": "#a1309b",
"font_size": 24,
"inherited": false
}
],
"comments": [
{
"text": "Will generate a randomized story line based on the topic given",
"x": 352,
"y": -269,
"width": 215,
"inherited": false
}
],
"groups": [],
"comments": [],
"extends": "src/talemate/game/engine/nodes/modules/scene/scene-loop.json",
"sleep": 0.001,
"base_type": "scene/SceneLoop",

View File

@@ -19,6 +19,7 @@ from talemate.agents.context import ActiveAgent, active_agent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
from talemate.context import active_scene
import talemate.config as config
from talemate.client.context import (
ClientContext,
set_client_context_attribute,
@@ -438,6 +439,29 @@ class Agent(ABC):
except AttributeError:
pass
async def save_config(self, app_config: config.Config | None = None):
"""
Saves the agent config to the config file.
If no config object is provided, the config is loaded from the config file.
"""
if not app_config:
app_config:config.Config = config.load_config(as_model=True)
app_config.agents[self.agent_type] = config.Agent(
name=self.agent_type,
client=self.client.name if self.client else None,
enabled=self.enabled,
actions={action_key: config.AgentAction(
enabled=action.enabled,
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
) for action_key, action in self.actions.items()}
)
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
config.save_config(app_config)
async def on_game_loop_start(self, event: GameLoopStartEvent):
"""
Finds all ActionConfigs that have a scope of "scene" and resets them to their default values

View File

@@ -232,11 +232,21 @@ class ConversationAgent(
@property
def generation_settings_actor_instructions_offset(self):
return self.actions["generation_override"].config["actor_instructions_offset"].value
@property
def generation_settings_response_length(self):
return self.actions["generation_override"].config["length"].value
@property
def generation_settings_override_enabled(self):
return self.actions["generation_override"].enabled
@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value
def connect(self, scene):
super().connect(scene)
@@ -322,6 +332,7 @@ class ConversationAgent(
"actor_instructions_offset": self.generation_settings_actor_instructions_offset,
"direct_instruction": instruction,
"decensor": self.client.decensor_enabled,
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
},
)

View File

@@ -3,6 +3,7 @@ import random
import uuid
from typing import TYPE_CHECKING, Tuple
import dataclasses
import traceback
import pydantic
import structlog
@@ -328,7 +329,14 @@ class AssistantMixin:
if not content.startswith(generation_context.character + ":"):
content = generation_context.character + ": " + content
content = util.strip_partial_sentences(content)
emission.response = await editor.cleanup_character_message(content, generation_context.character.name)
character = self.scene.get_character(generation_context.character)
if not character:
log.warning("Character not found", character=generation_context.character)
return content
emission.response = await editor.cleanup_character_message(content, character)
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response
@@ -447,6 +455,7 @@ class AssistantMixin:
)
continuing_message = False
message = None
try:
message = self.scene.history[-1]
@@ -470,6 +479,7 @@ class AssistantMixin:
"can_coerce": self.client.can_be_coerced,
"response_length": response_length,
"continuing_message": continuing_message,
"message": message,
"anchor": anchor,
"non_anchor": non_anchor,
"prefix": prefix,
@@ -675,7 +685,7 @@ class AssistantMixin:
emit("status", f"Scene forked", status="success")
except Exception as e:
log.exception("Scene fork failed", exc=e)
log.error("Scene fork failed", exc=traceback.format_exc())
emit("status", "Scene fork failed", status="error")

View File

@@ -4,6 +4,7 @@ import random
from typing import TYPE_CHECKING, List
import structlog
import traceback
import talemate.emit.async_signals
import talemate.instance as instance
@@ -259,7 +260,7 @@ class DirectorAgent(
except Exception as e:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.exception("Error persisting character", error=e)
log.error("Error persisting character", error=traceback.format_exc())
async def log_action(self, action: str, action_description: str):
message = DirectorMessage(message=action_description, action=action)

View File

@@ -1,6 +1,7 @@
import pydantic
import asyncio
import structlog
import traceback
from typing import TYPE_CHECKING
from talemate.instance import get_agent
@@ -98,7 +99,7 @@ class DirectorWebsocketHandler(Plugin):
task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
async def handle_task_done(task):
if task.exception():
log.exception("Error persisting character", error=task.exception())
log.error("Error persisting character", error=task.exception())
await self.signal_operation_failed("Error persisting character")
else:
self.websocket_handler.queue_put(

View File

@@ -54,7 +54,7 @@ class EditorAgent(
type="text",
label="Formatting",
description="The formatting to use for exposition.",
value="chat",
value="novel",
choices=[
{"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
{"label": "Novel: \"Speech\" narration", "value": "novel"},

View File

@@ -27,6 +27,7 @@ from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.agents.creator.assistant import ContextualGenerateEmission
from talemate.agents.summarize import SummarizeEmission
from talemate.agents.summarize.layered_history import LayeredHistoryFinalizeEmission
from talemate.scene_message import CharacterMessage
from talemate.util.dedupe import (
dedupe_sentences,
@@ -387,13 +388,16 @@ class RevisionMixin:
async_signals.get("agent.summarization.summarize.after").connect(
self.revision_on_generation
)
async_signals.get("agent.summarization.layered_history.finalize").connect(
self.revision_on_generation
)
# connect to the super class AFTER so these run first.
super().connect(scene)
async def revision_on_generation(
self,
emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission,
emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission | LayeredHistoryFinalizeEmission,
):
"""
Called when a conversation or narrator message is generated
@@ -411,7 +415,15 @@ class RevisionMixin:
if isinstance(emission, NarratorAgentEmission) and "narrator" not in self.revision_automatic_targets:
return
if isinstance(emission, SummarizeEmission) and "summarization" not in self.revision_automatic_targets:
if isinstance(emission, SummarizeEmission):
if emission.summarization_type == "dialogue" and "summarization" not in self.revision_automatic_targets:
return
if emission.summarization_type == "events":
# event summarization is very pragmatic and doesn't really benefit
# from revision, so we skip it
return
if isinstance(emission, LayeredHistoryFinalizeEmission) and "summarization" not in self.revision_automatic_targets:
return
try:
@@ -428,7 +440,7 @@ class RevisionMixin:
context_name = getattr(emission, "context_name", None),
)
if isinstance(emission, SummarizeEmission):
if isinstance(emission, (SummarizeEmission, LayeredHistoryFinalizeEmission)):
info.summarization_history = emission.summarization_history or []
if isinstance(emission, ContextualGenerateEmission) and info.context_type not in CONTEXTUAL_GENERATION_TYPES:
@@ -489,7 +501,8 @@ class RevisionMixin:
log.warning("revision_revise: generation cancelled", text=info.text)
return info.text
except Exception as e:
log.exception("revision_revise: error", error=e)
import traceback
log.error("revision_revise: error", error=traceback.format_exc())
return info.text
finally:
info.loading_status.done()
@@ -871,8 +884,14 @@ class RevisionMixin:
if loading_status:
loading_status("Editor - Issues identified, analyzing text...")
template_vars = {
emission = RevisionEmission(
agent=self,
info=info,
issues=issues,
)
emission.template_vars = {
"text": text,
"character": character,
"scene": self.scene,
@@ -880,14 +899,11 @@ class RevisionMixin:
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
}
emission = RevisionEmission(
agent=self,
template_vars=template_vars,
info=info,
issues=issues,
)
await async_signals.get("agent.editor.revision-revise.before").send(
emission
@@ -898,18 +914,7 @@ class RevisionMixin:
"editor.revision-analysis",
self.client,
f"edit_768",
vars={
"text": text,
"character": character,
"scene": self.scene,
"response_length": token_count,
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
},
vars=emission.template_vars,
dedupe_enabled=False,
)
@@ -1016,39 +1021,43 @@ class RevisionMixin:
log.debug("revision_unslop: issues", issues=issues, template=template)
emission = RevisionEmission(
agent=self,
info=info,
issues=issues,
)
emission.template_vars = {
"text": text,
"scene_analysis": scene_analysis,
"character": character,
"scene": self.scene,
"response_length": response_length,
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
"summarization_history": info.summarization_history,
}
await async_signals.get("agent.editor.revision-revise.before").send(emission)
response = await Prompt.request(
template,
self.client,
"edit_768",
vars={
"text": text,
"scene_analysis": scene_analysis,
"character": character,
"scene": self.scene,
"response_length": response_length,
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
"summarization_history": info.summarization_history,
},
vars=emission.template_vars,
dedupe_enabled=False,
)
# extract <FIX>...</FIX>
if "<FIX>" not in response:
log.error("revision_unslop: no <FIX> found in response", response=response)
log.debug("revision_unslop: no <FIX> found in response", response=response)
return original_text
fix = response.split("<FIX>", 1)[1]
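With template_vars now attached to the emission before the agent.editor.revision-revise.before signal is sent, a listener can adjust the prompt variables ahead of the request. A minimal sketch, assuming the async_signals connect pattern used elsewhere in this changeset; the listener and the extra variable are illustrative only.

import talemate.emit.async_signals as async_signals

async def inject_style_hint(emission):
    # Hypothetical listener: add one extra template variable before the
    # revision prompt is rendered; the other keys were already populated
    # by the editor agent.
    emission.template_vars["style_hint"] = "keep the prose terse"

async_signals.get("agent.editor.revision-revise.before").connect(inject_style_hint)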

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import asyncio
import functools
import hashlib
import uuid
import traceback
import numpy as np
from typing import Callable
@@ -12,6 +14,8 @@ from chromadb.config import Settings
import talemate.events as events
import talemate.util as util
from talemate.client import ClientBase
import talemate.instance as instance
from talemate.agents.base import (
Agent,
AgentAction,
@@ -23,6 +27,7 @@ from talemate.config import load_config
from talemate.context import scene_is_loading, active_scene
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.emit.async_signals as async_signals
from talemate.agents.memory.context import memory_request, MemoryRequest
from talemate.agents.memory.exceptions import (
EmbeddingsModelLoadError,
@@ -31,19 +36,23 @@ from talemate.agents.memory.exceptions import (
try:
import chromadb
import chromadb.errors
from chromadb.utils import embedding_functions
except ImportError:
chromadb = None
pass
from talemate.agents.registry import register
if TYPE_CHECKING:
from talemate.client.base import ClientEmbeddingsStatus
log = structlog.get_logger("talemate.agents.memory")
if not chromadb:
log.info("ChromaDB not found, disabling Chroma agent")
from talemate.agents.registry import register
class MemoryDocument(str):
def __new__(cls, text, meta, id, raw):
inst = super().__new__(cls, text)
@@ -105,8 +114,9 @@ class MemoryAgent(Agent):
self.memory_tracker = {}
self.config = load_config()
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)
self.actions = MemoryAgent.init_actions(presets=self.get_presets)
@@ -125,8 +135,16 @@ class MemoryAgent(Agent):
@property
def get_presets(self):
def _label(embedding:dict):
prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
if embedding['model']:
return f"{prefix}: {embedding['model']}"
else:
return f"{prefix}"
return [
{"value": k, "label": f"{v['embeddings']}: {v['model']}"} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
{"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
]
@property
@@ -150,6 +168,10 @@ class MemoryAgent(Agent):
def using_sentence_transformer_embeddings(self):
return self.embeddings == "default" or self.embeddings == "sentence-transformer"
@property
def using_client_api_embeddings(self):
return self.embeddings == "client-api"
@property
def using_local_embeddings(self):
return self.embeddings in [
@@ -158,6 +180,11 @@ class MemoryAgent(Agent):
"default"
]
@property
def embeddings_client(self):
return self.embeddings_config.get("client")
@property
def max_distance(self) -> float:
distance = float(self.embeddings_config.get("distance", 1.0))
@@ -186,7 +213,10 @@ class MemoryAgent(Agent):
"""
Returns a unique fingerprint for the current configuration
"""
return f"{self.embeddings}-{self.model.replace('/','-')}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
model_name = self.model.replace('/','-') if self.model else "none"
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
async def apply_config(self, *args, **kwargs):
@@ -205,7 +235,11 @@ class MemoryAgent(Agent):
@set_processing
async def handle_embeddings_change(self):
scene = active_scene.get()
# if sentence-transformer and no model-name, set embeddings to default
if self.using_sentence_transformer_embeddings and not self.model:
self.actions["_config"].config["embeddings"].value = "default"
if not scene or not scene.get_helper("memory"):
return
@@ -216,21 +250,49 @@ class MemoryAgent(Agent):
await scene.save(auto=True)
emit("status", "Context database re-imported", status="success")
def sync_presets(self) -> list[dict]:
self.actions["_config"].config["embeddings"].choices = self.get_presets
return self.actions["_config"].config["embeddings"].choices
def on_config_saved(self, event):
loop = asyncio.get_running_loop()
openai_key = self.openai_api_key
fingerprint = self.fingerprint
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
self.config = load_config()
new_presets = self.sync_presets()
if fingerprint != self.fingerprint:
log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
loop.run_until_complete(self.handle_embeddings_change())
emit_status = False
if openai_key != self.openai_api_key:
emit_status = True
if old_presets != new_presets:
emit_status = True
if emit_status:
loop.run_until_complete(self.emit_status())
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
current_embeddings = self.actions["_config"].config["embeddings"].value
if current_embeddings == event.client.embeddings_identifier:
return
if not self.using_client_api_embeddings or not self.ready:
log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
@set_processing
async def set_db(self):
loop = asyncio.get_running_loop()
@@ -239,7 +301,7 @@ class MemoryAgent(Agent):
except EmbeddingsModelLoadError:
raise
except Exception as e:
log.error("memory agent", error="failed to set db", details=e)
log.error("memory agent", error="failed to set db", details=traceback.format_exc())
if "torchvision::nms does not exist" in str(e):
raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")
@@ -379,14 +441,12 @@ class MemoryAgent(Agent):
def _get_document(self, id):
raise NotImplementedError()
def on_archive_add(self, event: events.ArchiveEvent):
asyncio.ensure_future(
self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
)
async def on_archive_add(self, event: events.ArchiveEvent):
await self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
def connect(self, scene):
super().connect(scene)
scene.signals["archive_add"].connect(self.on_archive_add)
async_signals.get("archive_add").connect(self.on_archive_add)
async def memory_context(
self,
@@ -453,29 +513,72 @@ class MemoryAgent(Agent):
Get the character memory context for a given character
"""
memory_context = []
# First, collect results for each individual query (respecting the
# per-query `iterate` limit) so that we have them available before we
# start filling the final context. This prevents early queries from
# monopolising the token budget.
per_query_results: list[list[str]] = []
for query in queries:
# Skip empty queries so that we keep indexing consistent for the
# round-robin step that follows.
if not query:
per_query_results.append([])
continue
i = 0
for memory in await self.get(formatter(query), limit=limit, **where):
if memory in memory_context:
continue
# Fetch potential memories for this query.
raw_results = await self.get(
formatter(query), limit=limit, **where
)
# Apply filter and respect the `iterate` limit for this query.
accepted: list[str] = []
for memory in raw_results:
if filter and not filter(memory):
continue
accepted.append(memory)
if len(accepted) >= iterate:
break
per_query_results.append(accepted)
# Now interleave the results in a round-robin fashion so that each
# query gets a fair chance to contribute, until we hit the token
# budget.
memory_context: list[str] = []
idx = 0
while True:
added_any = False
for result_list in per_query_results:
if idx >= len(result_list):
# No more items remaining for this query at this depth.
continue
memory = result_list[idx]
# Avoid duplicates in the final context.
if memory in memory_context:
continue
memory_context.append(memory)
added_any = True
i += 1
if i >= iterate:
break
# Check token budget after each addition.
if util.count_tokens(memory_context) >= max_tokens:
break
if util.count_tokens(memory_context) >= max_tokens:
return memory_context
if not added_any:
# We iterated over all query result lists without adding
# anything. That means we have exhausted all available
# memories.
break
idx += 1
return memory_context
@property
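Outside the agent, the interleaving step above reduces to a round-robin walk over the per-query result lists under a token budget. A standalone sketch, assuming plain lists of strings and a whitespace word count standing in for util.count_tokens.

def interleave_round_robin(per_query_results: list[list[str]], max_tokens: int) -> list[str]:
    # Take result index 0 from every query, then index 1, and so on,
    # skipping duplicates, until the budget is spent or all lists are exhausted.
    def count_tokens(items: list[str]) -> int:
        return sum(len(item.split()) for item in items)  # stand-in for util.count_tokens

    context: list[str] = []
    idx = 0
    while True:
        added_any = False
        for results in per_query_results:
            if idx >= len(results):
                continue
            memory = results[idx]
            if memory in context:
                continue
            context.append(memory)
            added_any = True
            if count_tokens(context) >= max_tokens:
                return context
        if not added_any:
            return context
        idx += 1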
@@ -587,9 +690,32 @@ class ChromaDBMemoryAgent(MemoryAgent):
if getattr(self, "db_client", None):
return True
return False
@property
def client_api_ready(self) -> bool:
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
if not embeddings_client:
return False
if not embeddings_client.supports_embeddings:
return False
if not embeddings_client.embeddings_status:
return False
if embeddings_client.current_status not in ["idle", "busy"]:
return False
return True
return False
@property
def status(self):
if self.using_client_api_embeddings and not self.client_api_ready:
return "error"
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
@@ -612,12 +738,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
value=self.embeddings,
description="The embeddings type.",
).model_dump(),
"model": AgentDetail(
}
if self.model:
details["model"] = AgentDetail(
icon="mdi-brain",
value=self.model,
description="The embeddings model.",
).model_dump(),
}
).model_dump()
if self.embeddings_client:
details["client"] = AgentDetail(
icon="mdi-network-outline",
value=self.embeddings_client,
description="The client to use for embeddings.",
).model_dump()
if self.using_local_embeddings:
details["device"] = AgentDetail(
@@ -634,6 +770,37 @@ class ChromaDBMemoryAgent(MemoryAgent):
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
"color": "error",
}
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
if not embeddings_client:
details["error"] = {
"icon": "mdi-alert",
"value": f"Client {self.embeddings_client} not found",
"description": f"Client {self.embeddings_client} not found",
"color": "error",
}
return details
client_name = embeddings_client.name
if not embeddings_client.supports_embeddings:
error_message = f"{client_name} does not support embeddings"
elif embeddings_client.current_status not in ["idle", "busy"]:
error_message = f"{client_name} is not ready"
elif not embeddings_client.embeddings_status:
error_message = f"{client_name} has no embeddings model loaded"
else:
error_message = None
if error_message:
details["error"] = {
"icon": "mdi-alert",
"value": error_message,
"description": error_message,
"color": "error",
}
return details
@@ -686,7 +853,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
self.collection_name = collection_name = self.make_collection_name(self.scene)
log.info(
"chromadb agent", status="setting up db", collection_name=collection_name
"chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
)
distance_function = self.distance_function
@@ -713,6 +880,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=openai_ef, metadata=collection_metadata
)
elif self.using_client_api_embeddings:
log.info(
"chromadb",
embeddings="Client API",
client=self.embeddings_client,
)
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
if not embeddings_client:
raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")
if not embeddings_client.supports_embeddings:
raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")
ef = embeddings_client.embeddings_function
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef, metadata=collection_metadata
)
elif self.using_instructor_embeddings:
log.info(
"chromadb",
@@ -722,7 +909,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
ef = embedding_functions.InstructorEmbeddingFunction(
model_name=model_name, device=device
model_name=model_name, device=device, instruction="Represent the document for retrieval:"
)
log.info("chromadb", status="embedding function ready")
@@ -801,6 +988,10 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
try:
self.db_client.delete_collection(collection_name)
except chromadb.errors.NotFoundError as exc:
log.error(
"chromadb agent", error="collection not found", details=exc
)
except ValueError as exc:
log.error(
"chromadb agent", error="failed to delete collection", details=exc

View File

@@ -510,53 +510,6 @@ class NarratorAgent(
return response
@set_processing
async def augment_context(self):
"""
Takes a context history generated via scene.context_history() and augments it with additional information
by asking and answering questions with help from the long term memory.
"""
memory = self.scene.get_helper("memory").agent
questions = await Prompt.request(
"narrator.context-questions",
self.client,
"narrate",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
},
)
log.debug("context_questions", questions=questions)
questions = [q for q in questions.split("\n") if q.strip()]
memory_context = await memory.multi_query(
questions, iterate=2, max_tokens=self.client.max_token_length - 1000
)
answers = await Prompt.request(
"narrator.context-answers",
self.client,
"narrate",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"memory": memory_context,
"questions": questions,
"extra_instructions": self.extra_instructions,
},
)
log.debug("context_answers", answers=answers)
answers = [a for a in answers.split("\n") if a.strip()]
# return questions and answers
return list(zip(questions, answers))
@set_processing
@store_context_state('narrative_direction', time_narration=True)
async def narrate_time_passage(

View File

@@ -4,8 +4,7 @@ import re
import dataclasses
import structlog
from typing import TYPE_CHECKING
import talemate.data_objects as data_objects
from typing import TYPE_CHECKING, Literal
import talemate.emit.async_signals
import talemate.util as util
from talemate.emit import emit
@@ -35,6 +34,8 @@ from talemate.agents.base import (
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.history import ArchiveEntry
from .analyze_scene import SceneAnalyzationMixin
from .context_investigation import ContextInvestigationMixin
from .layered_history import LayeredHistoryMixin
@@ -63,6 +64,7 @@ class SummarizeEmission(AgentTemplateEmission):
extra_instructions: str | None = None
generation_options: GenerationOptions | None = None
summarization_history: list[str] | None = None
summarization_type: Literal["dialogue", "events"] = "dialogue"
@register()
class SummarizeAgent(
@@ -189,6 +191,34 @@ class SummarizeAgent(
return emission.sub_instruction
# SUMMARIZATION HELPERS
async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:
num_previous = self.archive_include_previous
# find entry by .id
entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None)
if entry_index is None:
raise ValueError("Entry not found")
end = entry_index - 1
previous_summaries = []
if entry and num_previous > 0:
if self.layered_history_available:
previous_summaries = self.compile_layered_history(
include_base_layer=True,
base_layer_end_id=entry.id
)[-num_previous:]
else:
previous_summaries = [
entry.text for entry in self.scene.archived_history[end-num_previous:end]
]
return previous_summaries
# SUMMARIZE
@set_processing
@@ -352,7 +382,7 @@ class SummarizeAgent(
# determine the appropriate timestamp for the summarization
scene.push_archive(data_objects.ArchiveEntry(summarized, start, end, ts=ts))
await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts))
scene.ts=ts
scene.emit_status()
@@ -478,7 +508,8 @@ class SummarizeAgent(
extra_instructions=extra_instructions,
generation_options=generation_options,
template_vars=template_vars,
summarization_history=extra_context or []
summarization_history=extra_context or [],
summarization_type="dialogue",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
@@ -562,7 +593,8 @@ class SummarizeAgent(
extra_instructions=extra_instructions,
generation_options=generation_options,
template_vars=template_vars,
summarization_history=[extra_context] if extra_context else []
summarization_history=[extra_context] if extra_context else [],
summarization_type="events",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)

View File

@@ -1,17 +1,18 @@
import structlog
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig
AgentActionConfig,
AgentEmission,
)
from talemate.prompts import Prompt
import dataclasses
import talemate.emit.async_signals
from talemate.exceptions import GenerationCancelled
from talemate.world_state.templates import GenerationOptions
from talemate.emit import emit
from talemate.context import handle_generation_cancelled
from talemate.history import LayeredArchiveEntry, HistoryEntry, entry_contained
import talemate.util as util
if TYPE_CHECKING:
@@ -19,6 +20,24 @@ if TYPE_CHECKING:
log = structlog.get_logger()
talemate.emit.async_signals.register(
"agent.summarization.layered_history.finalize",
)
@dataclasses.dataclass
class LayeredHistoryFinalizeEmission(AgentEmission):
entry: LayeredArchiveEntry | None = None
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])
@property
def response(self) -> str | None:
return self.entry.text if self.entry else None
@response.setter
def response(self, value: str):
if self.entry:
self.entry.text = value
class SummaryLongerThanOriginalError(ValueError):
def __init__(self, original_length:int, summarized_length:int):
self.original_length = original_length
@@ -155,7 +174,102 @@ class LayeredHistoryMixin:
await self.summarize_to_layered_history(
generation_options=emission.generation_options
)
# helpers
async def _lh_split_and_summarize_chunks(
self,
chunks: list[dict],
extra_context: str,
generation_options: GenerationOptions | None = None,
) -> list[str]:
"""
Split chunks based on max_process_tokens and summarize each part.
Returns a list of summary texts.
"""
summaries = []
current_chunk = chunks.copy()
while current_chunk:
partial_chunk = []
max_process_tokens = self.layered_history_max_process_tokens
# Build partial chunk up to max_process_tokens
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
partial_chunk.append(current_chunk.pop(0))
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
log.debug("_split_and_summarize_chunks",
tokens_in_chunk=util.count_tokens(text_to_summarize),
max_process_tokens=max_process_tokens)
summary_text = await self.summarize_events(
text_to_summarize,
extra_context=extra_context + "\n\n".join(summaries),
generation_options=generation_options,
response_length=self.layered_history_response_length,
analyze_chunks=self.layered_history_analyze_chunks,
chunk_size=self.layered_history_chunk_size,
)
summaries.append(summary_text)
return summaries
def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
"""
Validates that the summarized text is not longer than the original.
Raises SummaryLongerThanOriginalError if validation fails.
"""
summarized_length = util.count_tokens(summaries)
if summarized_length > original_length:
raise SummaryLongerThanOriginalError(original_length, summarized_length)
log.debug("_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length)
def _lh_build_extra_context(self, layer_index: int) -> str:
"""
Builds extra context from compiled layered history for the given layer.
"""
return "\n\n".join(self.compile_layered_history(layer_index))
def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
"""
Extracts timestamps from a chunk of entries.
Returns (ts, ts_start, ts_end)
"""
if not chunk:
return "PT1S", "PT1S", "PT1S"
ts = chunk[0].get('ts', 'PT1S')
ts_start = chunk[0].get('ts_start', ts)
ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts))
return ts, ts_start, ts_end
async def _lh_finalize_archive_entry(
self,
entry: LayeredArchiveEntry,
summarization_history: list[str] | None = None,
) -> LayeredArchiveEntry:
"""
Finalizes an archive entry by summarizing it and adding it to the layered history.
"""
emission = LayeredHistoryFinalizeEmission(
agent=self,
entry=entry,
summarization_history=summarization_history,
)
await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)
return emission.entry
# methods
def compile_layered_history(
@@ -164,6 +278,7 @@ class LayeredHistoryMixin:
as_objects:bool=False,
include_base_layer:bool=False,
max:int = None,
base_layer_end_id: str | None = None,
) -> list[str]:
"""
Starts at the last layer and compiles the layered history into a single
@@ -194,6 +309,17 @@ class LayeredHistoryMixin:
entry_num = 1
for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]:
if base_layer_end_id:
contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry(
index=0,
layer=i+1,
**layered_history_entry)
)
if contained:
log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id)
break
text = f"{layered_history_entry['text']}"
if for_layer_index == i and max is not None and max <= layered_history_entry["end"]:
@@ -212,8 +338,8 @@ class LayeredHistoryMixin:
entry_num += 1
else:
compiled.append(text)
next_layer_start = layered_history_entry["end"] + 1
next_layer_start = layered_history_entry["end"] + 1
if i == 0 and include_base_layer:
# we are at layered history layer zero and inclusion of the base layer (archived history) is requested
@@ -222,7 +348,10 @@ class LayeredHistoryMixin:
entry_num = 1
for ah in self.scene.archived_history[next_layer_start:]:
for ah in self.scene.archived_history[next_layer_start or 0:]:
if base_layer_end_id and ah["id"] == base_layer_end_id:
break
text = f"{ah['text']}"
if as_objects:
@@ -291,8 +420,6 @@ class LayeredHistoryMixin:
return # No base layer summaries to work with
token_threshold = self.layered_history_threshold
method = self.actions["archive"].config["method"].value
max_process_tokens = self.layered_history_max_process_tokens
max_layers = self.layered_history_max_layers
if not hasattr(self.scene, 'layered_history'):
@@ -329,15 +456,9 @@ class LayeredHistoryMixin:
log.debug("summarize_to_layered_history", created_layer=next_layer_index)
next_layer = layered_history[next_layer_index]
ts = current_chunk[0]['ts']
ts_start = current_chunk[0]['ts_start'] if 'ts_start' in current_chunk[0] else ts
ts_end = current_chunk[-1]['ts_end'] if 'ts_end' in current_chunk[-1] else ts
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
summaries = []
extra_context = "\n\n".join(
self.compile_layered_history(next_layer_index)
)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
@@ -345,44 +466,24 @@ class LayeredHistoryMixin:
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})
while current_chunk:
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
noop = False
log.debug("summarize_to_layered_history", tokens_in_chunk=util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk)), max_process_tokens=max_process_tokens)
partial_chunk = []
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
partial_chunk.append(current_chunk.pop(0))
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
summary_text = await self.summarize_events(
text_to_summarize,
extra_context=extra_context + "\n\n".join(summaries),
generation_options=generation_options,
response_length=self.layered_history_response_length,
analyze_chunks=self.layered_history_analyze_chunks,
chunk_size=self.layered_history_chunk_size,
)
noop = False
summaries.append(summary_text)
# if summarized text is longer than the original, we will
# raise an error
if util.count_tokens(summaries) > text_length:
raise SummaryLongerThanOriginalError(text_length, util.count_tokens(summaries))
log.debug("summarize_to_layered_history", original_length=text_length, summarized_length=util.count_tokens(summaries))
next_layer.append({
next_layer.append(LayeredArchiveEntry(**{
"start": start_index,
"end": i - 1,
"end": i,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries)
})
"text": "\n\n".join(summaries),
}).model_dump(exclude_none=True))
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}")
@@ -412,7 +513,7 @@ class LayeredHistoryMixin:
last_entry = layered_history[0][-1]
end = last_entry["end"]
log.debug("summarize_to_layered_history", layer="base", start=end)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end + 1)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end)
else:
log.debug("summarize_to_layered_history", layer="base", empty=True)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
@@ -445,7 +546,7 @@ class LayeredHistoryMixin:
end = next_layer[-1]["end"] if next_layer else 0
log.debug("summarize_to_layered_history", layer=index, start=end)
summarized = await summarize_layer(layered_history[index], index + 1, end + 1 if end else 0)
summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0)
if summarized:
noop = False
@@ -466,4 +567,107 @@ class LayeredHistoryMixin:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
handle_generation_cancelled(e)
return
return
async def summarize_entries_to_layered_history(
self,
entries: list[dict],
next_layer_index: int,
start_index: int,
end_index: int,
generation_options: GenerationOptions | None = None,
) -> list[LayeredArchiveEntry]:
"""
Summarizes a list of entries into layered history entries.
This method is used for regenerating specific history entries by processing
their source entries. It chunks the entries based on the token threshold and
summarizes each chunk into a LayeredArchiveEntry.
Args:
entries: List of dictionaries containing the text entries to summarize.
Each entry should have at least a 'text' field and optionally
'ts', 'ts_start', and 'ts_end' fields.
next_layer_index: The index of the layer where the summarized entries
will be placed.
start_index: The starting index in the source layer that these entries
correspond to.
end_index: The ending index in the source layer that these entries
correspond to.
generation_options: Optional generation options to pass to the summarization
process.
Returns:
List of LayeredArchiveEntry objects containing the summarized text along
with timestamp and index information. Currently returns a list with a
single entry, but the structure supports multiple entries if needed.
Notes:
- The method respects the layered_history_threshold for chunking
- Uses helper methods for timestamp extraction, context building, and
chunk summarization
- Validates that summaries are not longer than the original text
- The last entry is always included in the final chunk if it doesn't
exceed the token threshold
"""
token_threshold = self.layered_history_threshold
archive_entries = []
summaries = []
current_chunk = []
current_tokens = 0
ts = "PT1S"
ts_start = "PT1S"
ts_end = "PT1S"
for entry_index, entry in enumerate(entries):
is_last_entry = entry_index == len(entries) - 1
entry_tokens = util.count_tokens(entry['text'])
log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens)
if current_tokens + entry_tokens > token_threshold or is_last_entry:
if is_last_entry and current_tokens + entry_tokens <= token_threshold:
# if we are here because this is the last entry and adding it to
# the current chunk would not exceed the token threshold, we will
# add it to the current chunk
current_chunk.append(entry)
if current_chunk:
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
archive_entry = LayeredArchiveEntry(**{
"start": start_index,
"end": end_index,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
})
archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n"))
archive_entries.append(archive_entry)
current_chunk.append(entry)
current_tokens += entry_tokens
return archive_entries
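The chunking in summarize_entries_to_layered_history boils down to grouping consecutive entries until a token threshold is reached and then flushing the group. A minimal standalone sketch of that grouping, with a whitespace word count standing in for util.count_tokens.

def chunk_by_token_threshold(entries: list[dict], token_threshold: int) -> list[list[dict]]:
    # Group consecutive entries so each group stays within token_threshold;
    # the trailing partial group is always flushed.
    def count_tokens(text: str) -> int:
        return len(text.split())  # stand-in for util.count_tokens

    chunks: list[list[dict]] = []
    current: list[dict] = []
    current_tokens = 0
    for entry in entries:
        entry_tokens = count_tokens(entry["text"])
        if current and current_tokens + entry_tokens > token_threshold:
            chunks.append(current)
            current = []
            current_tokens = 0
        current.append(entry)
        current_tokens += entry_tokens
    if current:
        chunks.append(current)
    return chunks

# chunk_by_token_threshold([{"text": "A short entry."}, {"text": "Another."}], token_threshold=256)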

View File

@@ -23,7 +23,7 @@ from talemate.emit.signals import handlers as signal_handlers
from talemate.prompts.base import Prompt
from .commands import * # noqa
from .context import VIS_TYPES, VisualContext, visual_context
from .context import VIS_TYPES, VisualContext, VisualContextState, visual_context
from .handlers import HANDLERS
from .schema import RESOLUTION_MAP, RenderSettings
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
@@ -40,6 +40,14 @@ BACKENDS = [
for mixin_backend, mixin in HANDLERS.items()
]
PROMPT_OUTPUT_FORMAT = """
### Positive
{positive_prompt}
### Negative
{negative_prompt}
"""
log = structlog.get_logger("talemate.agents.visual")
@@ -284,7 +292,7 @@ class VisualBase(Agent):
try:
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
except KeyError:
except (KeyError, TypeError):
backend = self.backend
backend_changed = backend != self.backend
@@ -425,10 +433,9 @@ class VisualBase(Agent):
self, format: str = "portrait", prompt: str = None, automatic: bool = False
):
context = visual_context.get()
if not self.enabled:
return
context:VisualContextState = visual_context.get()
log.debug("visual generate", context=context)
if automatic and not self.allow_automatic_generation:
return
@@ -459,7 +466,7 @@ class VisualBase(Agent):
thematic_style = self.default_style
vis_type_styles = self.vis_type_styles(context.vis_type)
prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
prompt:Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
if context.vis_type == VIS_TYPES.CHARACTER:
prompt.keywords.append("character portrait")
@@ -481,7 +488,34 @@ class VisualBase(Agent):
format = "portrait"
context.format = format
can_generate_image = self.enabled and self.backend_ready
if not context.prompt_only and not can_generate_image:
emit("status", "Visual agent is not ready for image generation, will output prompt instead.", status="warning")
# if prompt_only, we don't need to generate an image
# instead we emit a system message with the prompt
if context.prompt_only or not can_generate_image:
emit(
"system",
message=PROMPT_OUTPUT_FORMAT.format(
positive_prompt=prompt.positive_prompt,
negative_prompt=prompt.negative_prompt,
),
meta={
"icon": "mdi-image-text",
"color": "highlight7",
"title": f"Visual Prompt - {context.title}",
"display": "tonal",
"as_markdown": True,
}
)
return
if not can_generate_image:
return
# Call the backend specific generate function
backend = self.backend
@@ -541,8 +575,16 @@ class VisualBase(Agent):
return response.strip()
async def generate_environment_background(self, instructions: str = None):
with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
async def generate_environment_background(
self,
instructions: str = None,
prompt_only: bool = False,
):
with VisualContext(
vis_type=VIS_TYPES.ENVIRONMENT,
instructions=instructions,
prompt_only=prompt_only,
):
await self.generate(format="landscape")
async def generate_character_portrait(
@@ -550,12 +592,14 @@ class VisualBase(Agent):
character_name: str,
instructions: str = None,
replace: bool = False,
prompt_only: bool = False,
):
with VisualContext(
vis_type=VIS_TYPES.CHARACTER,
character_name=character_name,
instructions=instructions,
replace=replace,
prompt_only=prompt_only,
):
await self.generate(format="portrait")

View File

@@ -29,6 +29,15 @@ class VisualContextState(pydantic.BaseModel):
prepared_prompt: Union[str, None] = None
format: Union[str, None] = None
replace: bool = False
prompt_only: bool = False
@property
def title(self) -> str:
if self.vis_type == VIS_TYPES.ENVIRONMENT:
return "Environment"
elif self.vis_type == VIS_TYPES.CHARACTER:
return f"Character: {self.character_name}"
return "Visual Context"
class VisualContext:

View File

@@ -90,12 +90,16 @@ class VisualWebsocketHandler(Plugin):
payload = GeneratePayload(**data)
visual = get_agent("visual")
await visual.generate_character_portrait(
payload.context.character_name, payload.context.instructions, replace=True
payload.context.character_name,
payload.context.instructions,
replace=True,
prompt_only=payload.context.prompt_only,
)
async def handle_visualize_environment(self, data: dict):
payload = GeneratePayload(**data)
visual = get_agent("visual")
await visual.generate_environment_background(
instructions=payload.context.instructions
instructions=payload.context.instructions,
prompt_only=payload.context.prompt_only,
)

View File

@@ -18,6 +18,7 @@ from talemate.scene_message import (
ReinforcementMessage,
TimePassageMessage,
)
from talemate.util.response import extract_list
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
@@ -76,6 +77,12 @@ class WorldStateAgent(
label="Update world state",
description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
config={
"initial": AgentActionConfig(
type="bool",
label="When a new scene is started",
description="Whether to update the world state on scene start.",
value=True,
),
"turns": AgentActionConfig(
type="number",
label="Turns",
@@ -133,10 +140,15 @@ class WorldStateAgent(
@property
def experimental(self):
return True
@property
def initial_update(self):
return self.actions["update_world_state"].config["initial"].value
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
talemate.emit.async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init_after)
async def advance_time(self, duration: str, narrative: str = None):
"""
@@ -162,6 +174,22 @@ class WorldStateAgent(
)
)
async def on_scene_loop_init_after(self, emission):
"""
Called when a scene is initialized
"""
if not self.enabled:
return
if not self.initial_update:
return
if self.get_scene_state("inital_update_done"):
return
await self.scene.world_state.request_update()
self.set_scene_states(inital_update_done=True)
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called when a conversation is generated
@@ -305,7 +333,7 @@ class WorldStateAgent(
},
)
queries = response.split("\n")
queries = extract_list(response)
memory_agent = get_agent("memory")

View File

@@ -10,7 +10,9 @@ from talemate.client.groq import GroqClient
from talemate.client.koboldcpp import KoboldCppClient
from talemate.client.lmstudio import LMStudioClient
from talemate.client.mistral import MistralAIClient
from talemate.client.ollama import OllamaClient
from talemate.client.openai import OpenAIClient
from talemate.client.openrouter import OpenRouterClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.tabbyapi import TabbyAPIClient

View File

@@ -2,8 +2,14 @@ import pydantic
import structlog
from anthropic import AsyncAnthropic, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults, ExtraField
from talemate.client.registry import register
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
@@ -28,13 +34,17 @@ SUPPORTED_MODELS = [
]
class Defaults(CommonDefaults, pydantic.BaseModel):
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "claude-3-5-sonnet-latest"
double_coercion: str = None
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
class AnthropicClient(ClientBase):
class AnthropicClient(EndpointOverrideMixin, ClientBase):
"""
Anthropic client for generating text.
"""
@@ -44,6 +54,7 @@ class AnthropicClient(ClientBase):
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Anthropic"
@@ -52,15 +63,21 @@ class AnthropicClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="claude-3-5-sonnet-latest", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
@property
def anthropic_api_key(self):
return self.config.get("anthropic", {}).get("api_key")
@@ -103,6 +120,7 @@ class AnthropicClient(ClientBase):
data={
"error_action": error_action.model_dump() if error_action else None,
"double_coercion": self.double_coercion,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
@@ -117,7 +135,7 @@ class AnthropicClient(ClientBase):
)
def set_client(self, max_token_length: int = None):
if not self.anthropic_api_key:
if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncAnthropic(api_key="sk-1111")
log.error("No anthropic API key set")
if self.api_key_status:
@@ -134,7 +152,7 @@ class AnthropicClient(ClientBase):
model = self.model_name
self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
self.client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
@@ -158,7 +176,11 @@ class AnthropicClient(ClientBase):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event):
config = event.data
@@ -175,13 +197,10 @@ class AnthropicClient(ClientBase):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
"""
Anthropic handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
@@ -189,19 +208,19 @@ class AnthropicClient(ClientBase):
Generates text from the given prompt and parameters.
"""
if not self.anthropic_api_key:
if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No anthropic API key set")
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
human_message = {"role": "user", "content": prompt.strip()}
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
system_message = self.get_system_message(kind)
messages = [
{"role": "user", "content": prompt.strip()}
]
if coercion_prompt:
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
self.log.debug(
"generate",
@@ -209,28 +228,39 @@ class AnthropicClient(ClientBase):
parameters=parameters,
system_message=system_message,
)
completion_tokens = 0
prompt_tokens = 0
try:
response = await self.client.messages.create(
stream = await self.client.messages.create(
model=self.model_name,
system=system_message,
messages=[human_message],
messages=messages,
stream=True,
**parameters,
)
response = ""
async for event in stream:
if event.type == "content_block_delta":
content = event.delta.text
response += content
self.update_request_tokens(self.count_tokens(content))
elif event.type == "message_start":
prompt_tokens = event.message.usage.input_tokens
elif event.type == "message_delta":
completion_tokens += event.usage.output_tokens
self._returned_prompt_tokens = self.prompt_tokens(response)
self._returned_response_tokens = self.response_tokens(response)
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
log.debug("generated response", response=response.content)
response = response.content[0].text
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
log.debug("generated response", response=response)
return response
except PermissionDeniedError as e:
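In practice the coercion change means the text after <|BOT|> is no longer folded into a "Start your response with:" instruction; it is sent as an assistant prefill that the streamed completion continues. Roughly (values invented for illustration):

    messages = [
        {"role": "user", "content": "Describe the ruined chapel."},   # main prompt
        {"role": "assistant", "content": "Certainly, the chapel"},    # coercion / prefill
    ]
    # await self.client.messages.create(model=..., system=..., messages=messages, stream=True, ...)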

View File

@@ -6,10 +6,12 @@ import ipaddress
import logging
import random
import time
import traceback
import asyncio
from typing import Callable, Union, Literal
import pydantic
import dataclasses
import structlog
import urllib3
from openai import AsyncOpenAI, PermissionDeniedError
@@ -23,7 +25,10 @@ from talemate.client.model_prompts import model_prompt
from talemate.client.ratelimit import CounterRateLimiter
from talemate.context import active_scene
from talemate.emit import emit
from talemate.config import load_config, save_config, EmbeddingFunctionPreset
import talemate.emit.async_signals as async_signals
from talemate.exceptions import SceneInactiveError, GenerationCancelled
import talemate.ux.schema as ux_schema
from talemate.client.system_prompts import SystemPrompts
@@ -77,13 +82,20 @@ class Defaults(CommonDefaults, pydantic.BaseModel):
double_coercion: str = None
class FieldGroup(pydantic.BaseModel):
name: str
label: str
description: str
icon: str = "mdi-cog"
class ExtraField(pydantic.BaseModel):
name: str
type: str
label: str
required: bool
description: str
group: FieldGroup | None = None
note: ux_schema.Note | None = None
class ParameterReroute(pydantic.BaseModel):
talemate_parameter: str
@@ -101,6 +113,56 @@ class ParameterReroute(pydantic.BaseModel):
return str(self) == str(other)
class RequestInformation(pydantic.BaseModel):
start_time: float = pydantic.Field(default_factory=time.time)
end_time: float | None = None
tokens: int = 0
@pydantic.computed_field(description="Duration")
@property
def duration(self) -> float:
end_time = self.end_time or time.time()
return end_time - self.start_time
@pydantic.computed_field(description="Tokens per second")
@property
def rate(self) -> float:
try:
end_time = self.end_time or time.time()
return self.tokens / (end_time - self.start_time)
except:
pass
return 0
@pydantic.computed_field(description="Status")
@property
def status(self) -> str:
if self.end_time:
return "completed"
elif self.start_time:
if self.duration > 1 and self.rate == 0:
return "stopped"
return "in progress"
else:
return "pending"
@pydantic.computed_field(description="Age")
@property
def age(self) -> float:
if not self.end_time:
return -1
return time.time() - self.end_time
@dataclasses.dataclass
class ClientEmbeddingsStatus:
client: "ClientBase | None" = None
embedding_name: str | None = None
async_signals.register(
"client.embeddings_available",
)
class ClientBase:
api_url: str
model_name: str
@@ -120,6 +182,7 @@ class ClientBase:
data_format: Literal["yaml", "json"] | None = None
rate_limit: int | None = None
client_type = "base"
request_information: RequestInformation | None = None
status_request_timeout:int = 2
@@ -171,6 +234,13 @@ class ClientBase:
"""
return self.Meta().requires_prompt_template
@property
def can_think(self) -> bool:
"""
Allow reasoning models to think before responding.
"""
return False
@property
def max_tokens_param_name(self):
return "max_tokens"
@@ -182,9 +252,87 @@ class ClientBase:
"temperature",
"max_tokens",
]
@property
def supports_embeddings(self) -> bool:
return False
@property
def embeddings_function(self):
return None
@property
def embeddings_status(self) -> bool:
return getattr(self, "_embeddings_status", False)
@property
def embeddings_model_name(self) -> str | None:
return getattr(self, "_embeddings_model_name", None)
@property
def embeddings_url(self) -> str:
return None
@property
def embeddings_identifier(self) -> str:
return f"client-api/{self.name}/{self.embeddings_model_name}"
async def destroy(self, config:dict):
"""
This is called before the client is removed from talemate.instance.clients
Use this to perform any cleanup that is necessary.
If a subclass overrides this method, it should call super().destroy(config) at the
end of the method.
"""
if self.supports_embeddings:
self.remove_embeddings(config)
def reset_embeddings(self):
self._embeddings_model_name = None
self._embeddings_status = False
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def set_embeddings(self):
log.debug("setting embeddings", client=self.name, supports_embeddings=self.supports_embeddings, embeddings_status=self.embeddings_status)
if not self.supports_embeddings or not self.embeddings_status:
return
config = load_config(as_model=True)
key = self.embeddings_identifier
if key in config.presets.embeddings:
log.debug("embeddings already set", client=self.name, key=key)
return config.presets.embeddings[key]
log.debug("setting embeddings", client=self.name, key=key)
config.presets.embeddings[key] = EmbeddingFunctionPreset(
embeddings="client-api",
client=self.name,
model=self.embeddings_model_name,
distance=1,
distance_function="cosine",
local=False,
custom=True,
)
save_config(config)
def remove_embeddings(self, config:dict | None = None):
# remove all embeddings for this client
for key, value in list(config["presets"]["embeddings"].items()):
if value["client"] == self.name and value["embeddings"] == "client-api":
log.warning("!!! removing embeddings", client=self.name, key=key)
config["presets"]["embeddings"].pop(key)
def set_system_prompts(self, system_prompts: dict | SystemPrompts):
if isinstance(system_prompts, dict):
@@ -222,6 +370,19 @@ class ClientBase:
self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
)
def split_prompt_for_coercion(self, prompt: str) -> tuple[str, str]:
"""
Splits the prompt and the prefill/coercion prompt.
"""
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if self.double_coercion:
right = f"{self.double_coercion}\n\n{right}"
return prompt, right
return prompt, None
def reconfigure(self, **kwargs):
"""
Reconfigures the client.
@@ -241,6 +402,8 @@ class ClientBase:
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if not self.enabled and self.supports_embeddings and self.embeddings_status:
self.reset_embeddings()
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
@@ -327,7 +490,7 @@ class ClientBase:
"""
Sets and emits the client status.
"""
if processing is not None:
self.processing = processing
@@ -388,6 +551,8 @@ class ClientBase:
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
data[field_name] = getattr(self, field_name, None)
data = self.finalize_status(data)
emit(
"client_status",
message=self.client_type,
@@ -400,13 +565,31 @@ class ClientBase:
if status_change:
instance.emit_agent_status_by_client(self)
def finalize_status(self, data: dict):
"""
Finalizes the status data for the client.
"""
return data
def _common_status_data(self):
return {
common_data = {
"can_be_coerced": self.can_be_coerced,
"preset_group": self.preset_group or "",
"rate_limit": self.rate_limit,
"data_format": self.data_format,
"manual_model_choices": getattr(self.Meta(), "manual_model_choices", []),
"supports_embeddings": self.supports_embeddings,
"embeddings_status": self.embeddings_status,
"embeddings_model_name": self.embeddings_model_name,
"request_information": self.request_information.model_dump() if self.request_information else None,
}
extra_fields = getattr(self.Meta(), "extra_fields", {})
for field_name in extra_fields.keys():
common_data[field_name] = getattr(self, field_name, None)
return common_data
def populate_extra_fields(self, data: dict):
"""
Updates data with the extra fields from the client's Meta
@@ -438,6 +621,7 @@ class ClientBase:
:return: None
"""
if self.processing:
self.emit_status()
return
if not self.enabled:
@@ -618,8 +802,29 @@ class ClientBase:
at the other side of the client.
"""
pass
def new_request(self):
"""
Creates a new request information object.
"""
self.request_information = RequestInformation()
def end_request(self):
"""
Ends the request information object.
"""
self.request_information.end_time = time.time()
def update_request_tokens(self, tokens: int, replace: bool = False):
"""
Updates the request information object with the number of tokens received.
"""
if self.request_information:
if replace:
self.request_information.tokens = tokens
else:
self.request_information.tokens += tokens
async def send_prompt(
self,
prompt: str,
@@ -690,7 +895,7 @@ class ClientBase:
except GenerationCancelled:
raise
except Exception as e:
log.exception("Error during rate limit check", e=e)
log.error("Error during rate limit check", e=traceback.format_exc())
if not active_scene.get():
@@ -736,8 +941,12 @@ class ClientBase:
)
prompt_sent = self.repetition_adjustment(finalized_prompt)
self.new_request()
response = await self._cancelable_generate(prompt_sent, prompt_param, kind)
self.end_request()
if isinstance(response, GenerationCancelled):
# generation was cancelled
raise response
@@ -786,7 +995,7 @@ class ClientBase:
except GenerationCancelled as e:
raise
except Exception as e:
self.log.exception("send_prompt error", e=e)
self.log.error("send_prompt error", e=traceback.format_exc())
emit(
"status", message="Error during generation (check logs)", status="error"
)

View File

@@ -1,10 +1,15 @@
import pydantic
import structlog
from cohere import AsyncClient
from cohere import AsyncClientV2
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.registry import register
from talemate.config import load_config
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens
@@ -26,13 +31,17 @@ SUPPORTED_MODELS = [
]
class Defaults(CommonDefaults, pydantic.BaseModel):
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "command-r-plus"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
class CohereClient(ClientBase):
class CohereClient(EndpointOverrideMixin, ClientBase):
"""
Cohere client for generating text.
"""
@@ -41,18 +50,21 @@ class CohereClient(ClientBase):
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Cohere"
title: str = "Cohere"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
defaults: Defaults = Defaults()
def __init__(self, model="command-r-plus", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
@@ -119,8 +131,8 @@ class CohereClient(ClientBase):
)
def set_client(self, max_token_length: int = None):
if not self.cohere_api_key:
self.client = AsyncClient("sk-1111")
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncClientV2("sk-1111")
log.error("No cohere API key set")
if self.api_key_status:
self.api_key_status = False
@@ -136,7 +148,7 @@ class CohereClient(ClientBase):
model = self.model_name
self.client = AsyncClient(self.cohere_api_key)
self.client = AsyncClientV2(self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
@@ -161,6 +173,7 @@ class CohereClient(ClientBase):
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event):
config = event.data
@@ -168,7 +181,7 @@ class CohereClient(ClientBase):
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return count_tokens(response.text)
return count_tokens(response)
def prompt_tokens(self, prompt: str):
return count_tokens(prompt)
@@ -207,7 +220,7 @@ class CohereClient(ClientBase):
Generates text from the given prompt and parameters.
"""
if not self.cohere_api_key:
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No cohere API key set")
right = None
@@ -227,21 +240,43 @@ class CohereClient(ClientBase):
parameters=parameters,
system_message=system_message,
)
messages = [
{
"role": "system",
"content": system_message,
},
{
"role": "user",
"content": human_message,
}
]
try:
response = await self.client.chat(
# Cohere's `chat_stream` returns an **asynchronous generator** that can be
# consumed directly with `async for`. It is not an asynchronous context
# manager, so attempting to use `async with` raises a `TypeError` as seen
# in issue logs. We therefore iterate over the generator directly.
stream = self.client.chat_stream(
model=self.model_name,
preamble=system_message,
message=human_message,
messages=messages,
**parameters,
)
response = ""
async for event in stream:
if event and event.type == "content-delta":
chunk = event.delta.message.content.text
response += chunk
# Track token usage incrementally
self.update_request_tokens(self.count_tokens(chunk))
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
log.debug("generated response", response=response.text)
response = response.text
log.debug("generated response", response=response)
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):

View File

@@ -187,6 +187,14 @@ class DeepSeekClient(ClientBase):
return prompt
def response_tokens(self, response: str):
# Count tokens in a response string using the util.count_tokens helper
return self.count_tokens(response)
def prompt_tokens(self, prompt: str):
# Count tokens in a prompt string using the util.count_tokens helper
return self.count_tokens(prompt)
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -221,13 +229,30 @@ class DeepSeekClient(ClientBase):
)
try:
response = await self.client.chat.completions.create(
# Use streaming so we can call update_request_tokens incrementally
stream = await self.client.chat.completions.create(
model=self.model_name,
messages=[system_message, human_message],
stream=True,
**parameters,
)
response = response.choices[0].message.content
response = ""
# Iterate over streamed chunks
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta and getattr(delta, "content", None):
content_piece = delta.content
response += content_piece
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))
# Save token accounting for whole request
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json

View File

@@ -3,19 +3,18 @@ import os
import pydantic
import structlog
import vertexai
from google.api_core.exceptions import ResourceExhausted
from vertexai.generative_models import (
ChatSession,
GenerationConfig,
GenerativeModel,
ResponseValidationError,
SafetySetting,
)
from google import genai
import google.genai.types as genai_types
from google.genai.errors import APIError
from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults
from talemate.client.registry import register
from talemate.client.remote import RemoteServiceMixin
from talemate.client.remote import (
RemoteServiceMixin,
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.emit import emit
@@ -31,23 +30,29 @@ log = structlog.get_logger("talemate")
SUPPORTED_MODELS = [
"gemini-1.0-pro",
"gemini-1.5-pro-preview-0409",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash-preview-05-20",
"gemini-2.5-pro-preview-03-25",
"gemini-2.5-pro-preview-06-05",
]
class Defaults(CommonDefaults, pydantic.BaseModel):
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gemini-1.0-pro"
model: str = "gemini-2.0-flash"
disable_safety_settings: bool = False
double_coercion: str = None
class ClientConfig(BaseClientConfig):
class ClientConfig(EndpointOverride, BaseClientConfig):
disable_safety_settings: bool = False
@register()
class GoogleClient(RemoteServiceMixin, ClientBase):
class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"""
Google client for generating text.
"""
@@ -74,19 +79,26 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
description="Disable Google's safety settings for responses generated by the model.",
),
}
extra_fields.update(endpoint_override_extra_fields())
def __init__(self, model="gemini-1.0-pro", **kwargs):
def __init__(self, model="gemini-2.0-flash", **kwargs):
self.model_name = model
self.setup_status = None
self.model_instance = None
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
self.google_credentials_read = False
self.google_project_id = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
@property
def google_credentials(self):
path = self.google_credentials_path
@@ -102,16 +114,36 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
@property
def google_location(self):
return self.config.get("google").get("gcloud_location")
@property
def google_api_key(self):
return self.config.get("google").get("api_key")
@property
def vertexai_ready(self) -> bool:
return all([
self.google_credentials_path,
self.google_location,
])
@property
def developer_api_ready(self) -> bool:
return all([
self.google_api_key,
])
@property
def using(self) -> str:
if self.developer_api_ready:
return "API"
if self.vertexai_ready:
return "VertexAI"
return "Unknown"
@property
def ready(self):
# all google settings must be set
return all(
[
self.google_credentials_path,
self.google_location,
]
)
return self.vertexai_ready or self.developer_api_ready or self.endpoint_override_base_url_configured
@property
def safety_settings(self):
@@ -119,30 +151,39 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
return None
safety_settings = [
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
genai_types.SafetySetting(
category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
threshold="BLOCK_NONE",
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
genai_types.SafetySetting(
category="HARM_CATEGORY_DANGEROUS_CONTENT",
threshold="BLOCK_NONE",
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
genai_types.SafetySetting(
category="HARM_CATEGORY_HARASSMENT",
threshold="BLOCK_NONE",
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
genai_types.SafetySetting(
category="HARM_CATEGORY_HATE_SPEECH",
threshold="BLOCK_NONE",
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
genai_types.SafetySetting(
category="HARM_CATEGORY_CIVIC_INTEGRITY",
threshold="BLOCK_NONE",
),
]
return safety_settings
@property
def http_options(self) -> genai_types.HttpOptions | None:
if not self.endpoint_override_base_url_configured:
return None
return genai_types.HttpOptions(
base_url=self.base_url
)
@property
def supported_parameters(self):
return [
@@ -184,6 +225,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
self.current_status = status
data = {
"double_coercion": self.double_coercion,
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
@@ -191,15 +233,27 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
data.update(self._common_status_data())
self.populate_extra_fields(data)
if self.using == "VertexAI":
details = f"{model_name} (VertexAI)"
else:
details = model_name
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
details=details,
status=status if self.enabled else "disabled",
data=data,
)
def set_client_base_url(self, base_url: str | None):
if getattr(self, "client", None):
try:
self.client.http_options.base_url = base_url
except Exception as e:
log.error("Error setting client base URL", error=e, client=self.client_type)
def set_client(self, max_token_length: int = None, **kwargs):
if not self.ready:
log.error("Google cloud setup incomplete")
@@ -210,7 +264,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
return
if not self.model_name:
self.model_name = "gemini-1.0-pro"
self.model_name = "gemini-2.0-flash"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
@@ -222,17 +276,14 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
self.max_token_length = max_token_length or 16384
if not self.setup_status:
if self.setup_status is False:
project_id = self.google_credentials.get("project_id")
self.google_project_id = project_id
if self.google_credentials_path:
vertexai.init(project=project_id, location=self.google_location)
emit("request_client_status")
emit("request_agent_status")
self.setup_status = True
self.model_instance = GenerativeModel(model_name=model)
if self.vertexai_ready and not self.developer_api_ready:
self.client = genai.Client(
vertexai=True,
project=self.google_project_id,
location=self.google_location,
)
else:
self.client = genai.Client(api_key=self.api_key or None, http_options=self.http_options)
log.info(
"google set client",
@@ -241,8 +292,9 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
model=model,
)
def response_tokens(self, response: str):
return count_tokens(response.text)
def response_tokens(self, response:str):
"""Return token count for a response which may be a string or SDK object."""
return count_tokens(response)
def prompt_tokens(self, prompt: str):
return count_tokens(prompt)
@@ -258,6 +310,9 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
def clean_prompt_parameters(self, parameters: dict):
@@ -267,27 +322,53 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
if "top_k" in parameters and parameters["top_k"] == 0:
del parameters["top_k"]
def prompt_template(self, system_message: str, prompt: str):
"""
Google handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.ready:
raise Exception("Google cloud setup incomplete")
raise Exception("Google setup incomplete")
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
human_message = prompt.strip()
system_message = self.get_system_message(kind)
contents = [
genai_types.Content(
role="user",
parts=[
genai_types.Part.from_text(
text=human_message,
)
]
)
]
if coercion_prompt:
contents.append(
genai_types.Content(
role="model",
parts=[
genai_types.Part.from_text(
text=coercion_prompt,
)
]
)
)
self.log.debug(
"generate",
base_url=self.base_url,
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
@@ -296,48 +377,53 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
)
try:
# Use streaming so we can call update_request_tokens incrementally
#stream = await chat.send_message_async(
# human_message,
# safety_settings=self.safety_settings,
# generation_config=parameters,
# stream=True
#)
chat = self.model_instance.start_chat()
response = await chat.send_message_async(
human_message,
safety_settings=self.safety_settings,
generation_config=parameters,
stream = await self.client.aio.models.generate_content_stream(
model=self.model_name,
contents=contents,
config=genai_types.GenerateContentConfig(
safety_settings=self.safety_settings,
http_options=self.http_options,
**parameters
),
)
response = ""
async for chunk in stream:
# For each streamed chunk, append content and update token counts
content_piece = getattr(chunk, "text", None)
if not content_piece:
# Some SDK versions wrap text under candidates[0].text
try:
content_piece = chunk.candidates[0].text # type: ignore
except Exception:
content_piece = None
if content_piece:
response += content_piece
# Incrementally update token usage
self.update_request_tokens(count_tokens(content_piece))
# Store total token accounting for prompt/response
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
response = response.text
log.debug("generated response", response=response)
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
# except PermissionDeniedError as e:
# self.log.error("generate error", e=e)
# emit("status", message="google API: Permission Denied", status="error")
# return ""
except ResourceExhausted as e:
except APIError as e:
self.log.error("generate error", e=e)
emit("status", message="google API: Quota Limit reached", status="error")
emit("status", message="google API: API Error", status="error")
return ""
except ResponseValidationError as e:
self.log.error("generate error", e=e)
emit(
"status",
message="google API: Response Validation Error",
status="error",
)
if not self.disable_safety_settings:
return "Failed to generate response. Probably due to safety settings, you can turn them off in the client settings."
return "Failed to generate response. Please check logs."
except Exception as e:
raise

View File

@@ -2,11 +2,16 @@ import pydantic
import structlog
from groq import AsyncGroq, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, ExtraField
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
__all__ = [
"GroqClient",
@@ -23,13 +28,13 @@ SUPPORTED_MODELS = [
JSON_OBJECT_RESPONSE_MODELS = []
class Defaults(pydantic.BaseModel):
class Defaults(EndpointOverride, pydantic.BaseModel):
max_token_length: int = 8192
model: str = "llama3-70b-8192"
@register()
class GroqClient(ClientBase):
class GroqClient(EndpointOverrideMixin, ClientBase):
"""
OpenAI client for generating text.
"""
@@ -47,10 +52,13 @@ class GroqClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="llama3-70b-8192", **kwargs):
self.model_name = model
self.api_key_status = None
# Apply any endpoint override parameters provided via kwargs before creating client
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
@@ -100,21 +108,27 @@ class GroqClient(ClientBase):
self.current_status = status
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
# Include shared/common status data (rate limit, etc.)
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status if self.enabled else "disabled",
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
},
data=data,
)
def set_client(self, max_token_length: int = None):
if not self.groq_api_key:
# Determine if we should use the globally configured API key or the override key
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
# No API key and no endpoint override; the client cannot be initialized correctly
self.client = AsyncGroq(api_key="sk-1111")
log.error("No groq.ai API key set")
if self.api_key_status:
@@ -131,7 +145,8 @@ class GroqClient(ClientBase):
model = self.model_name
self.client = AsyncGroq(api_key=self.groq_api_key)
# Use the override values (if any) when constructing the Groq client
self.client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
@@ -155,6 +170,11 @@ class GroqClient(ClientBase):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
# Allow dynamic reconfiguration of endpoint override parameters
self._reconfigure_endpoint_override(**kwargs)
# Reconfigure any common parameters (rate limit, data format, etc.)
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
@@ -184,7 +204,7 @@ class GroqClient(ClientBase):
Generates text from the given prompt and parameters.
"""
if not self.groq_api_key:
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No groq.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS

View File

@@ -1,6 +1,10 @@
import random
import re
import json
import sseclient
import asyncio
from typing import TYPE_CHECKING
import requests
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
# import urljoin
from urllib.parse import urljoin, urlparse
@@ -10,12 +14,14 @@ import structlog
import talemate.util as util
from talemate.client.base import (
STOPPING_STRINGS,
ClientBase,
Defaults,
ParameterReroute,
ClientEmbeddingsStatus
)
from talemate.client.registry import register
import talemate.emit.async_signals as async_signals
if TYPE_CHECKING:
from talemate.agents.visual import VisualBase
@@ -28,6 +34,37 @@ class KoboldCppClientDefaults(Defaults):
api_key: str = ""
class KoboldEmbeddingFunction(EmbeddingFunction):
def __init__(self, api_url: str, model_name: str = None):
"""
Initialize the embedding function with the KoboldCPP API endpoint.
"""
self.api_url = api_url
self.model_name = model_name
def __call__(self, texts: Documents) -> Embeddings:
"""
Embed a list of input texts using the KoboldCPP embeddings endpoint.
"""
log.debug("KoboldCppEmbeddingFunction", api_url=self.api_url, model_name=self.model_name)
# Prepare the request payload for KoboldCPP. Include model name if required.
payload = {"input": texts}
if self.model_name is not None:
payload["model"] = self.model_name # e.g. the model's name/ID if needed
# Send POST request to the local KoboldCPP embeddings endpoint
response = requests.post(self.api_url, json=payload)
response.raise_for_status() # Throw an error if the request failed (e.g., connection issue)
# Parse the JSON response to extract embedding vectors
data = response.json()
# The 'data' field contains a list of embeddings (one per input)
embedding_results = data.get("data", [])
embeddings = [item["embedding"] for item in embedding_results]
return embeddings
@register()
class KoboldCppClient(ClientBase):
auto_determine_prompt_template: bool = True
@@ -58,7 +95,7 @@ class KoboldCppClient(ClientBase):
kcpp has two apis
open-ai implementation at /v1
their own implenation at /api/v1
their own implementation at /api/v1
"""
return "/api/v1" not in self.api_url
@@ -77,8 +114,8 @@ class KoboldCppClient(ClientBase):
# join /v1/completions
return urljoin(self.api_url, "completions")
else:
# join /api/v1/generate
return urljoin(self.api_url, "generate")
# join /api/extra/generate/stream
return urljoin(self.api_url.replace("v1", "extra"), "generate/stream")
@property
def max_tokens_param_name(self):
@@ -132,6 +169,21 @@ class KoboldCppClient(ClientBase):
"temperature",
]
@property
def supports_embeddings(self) -> bool:
return True
@property
def embeddings_url(self) -> str:
if self.is_openai:
return urljoin(self.api_url, "embeddings")
else:
return urljoin(self.api_url, "api/extra/embeddings")
@property
def embeddings_function(self):
return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)
def api_endpoint_specified(self, url: str) -> bool:
return "/v1" in self.api_url
@@ -152,14 +204,62 @@ class KoboldCppClient(ClientBase):
self.api_key = kwargs.get("api_key", self.api_key)
self.ensure_api_endpoint_specified()
async def get_model_name(self):
self.ensure_api_endpoint_specified()
async def get_embeddings_model_name(self):
# if self._embeddings_model_name is set, return it
if self.embeddings_model_name:
return self.embeddings_model_name
# otherwise, get the model name by doing a request to
# the embeddings endpoint with a single character
async with httpx.AsyncClient() as client:
response = await client.get(
self.api_url_for_model,
response = await client.post(
self.embeddings_url,
json={"input": ["test"]},
timeout=2,
headers=self.request_headers,
)
response_data = response.json()
self._embeddings_model_name = response_data.get("model")
return self._embeddings_model_name
async def get_embeddings_status(self):
url_version = urljoin(self.api_url, "api/extra/version")
async with httpx.AsyncClient() as client:
response = await client.get(url_version, timeout=2)
response_data = response.json()
self._embeddings_status = response_data.get("embeddings", False)
if not self.embeddings_status or self.embeddings_model_name:
return
await self.get_embeddings_model_name()
log.debug("KoboldCpp embeddings are enabled, suggesting embeddings", model_name=self.embeddings_model_name)
self.set_embeddings()
await async_signals.get("client.embeddings_available").send(
ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)
)
async def get_model_name(self):
self.ensure_api_endpoint_specified()
try:
async with httpx.AsyncClient() as client:
response = await client.get(
self.api_url_for_model,
timeout=2,
headers=self.request_headers,
)
except Exception:
self._embeddings_model_name = None
raise
if response.status_code == 404:
raise KeyError(f"Could not find model info at: {self.api_url_for_model}")
@@ -175,6 +275,8 @@ class KoboldCppClient(ClientBase):
# split by "/" and take last
if model_name:
model_name = model_name.split("/")[-1]
await self.get_embeddings_status()
return model_name
@@ -223,11 +325,48 @@ class KoboldCppClient(ClientBase):
url_abort,
headers=self.request_headers,
)
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if self.is_openai:
return await self._generate_openai(prompt, parameters, kind)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._generate_kcpp_stream, prompt, parameters, kind)
def _generate_kcpp_stream(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
parameters["prompt"] = prompt.strip(" ")
response = ""
parameters["stream"] = True
stream_response = requests.post(
self.api_url_for_generation,
json=parameters,
timeout=None,
headers=self.request_headers,
stream=True,
)
stream_response.raise_for_status()
sse = sseclient.SSEClient(stream_response)
for event in sse.events():
payload = json.loads(event.data)
chunk = payload['token']
response += chunk
self.update_request_tokens(self.count_tokens(chunk))
return response
async def _generate_openai(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
parameters["prompt"] = prompt.strip(" ")

View File

@@ -54,18 +54,55 @@ class LMStudioClient(ClientBase):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
Generates text from the given prompt and parameters using a streaming
request so that token usage can be tracked incrementally via
`update_request_tokens`.
"""
human_message = {"role": "user", "content": prompt.strip()}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
# Send the request in streaming mode so we can update token counts
stream = await self.client.completions.create(
model=self.model_name,
prompt=prompt,
stream=True,
**parameters,
)
return response.choices[0].message.content
response = ""
# Iterate over streamed chunks and accumulate the response while
# incrementally updating the token counter
async for chunk in stream:
if not chunk.choices:
continue
content_piece = chunk.choices[0].text
response += content_piece
# Track token usage incrementally
self.update_request_tokens(self.count_tokens(content_piece))
# Store overall token accounting once the stream is finished
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
return response
except Exception as e:
self.log.error("generate error", e=e)
return ""
# ------------------------------------------------------------------
# Token helpers
# ------------------------------------------------------------------
def response_tokens(self, response: str):
"""Count tokens in a model response string."""
return self.count_tokens(response)
def prompt_tokens(self, prompt: str):
"""Count tokens in a prompt string."""
return self.count_tokens(prompt)

View File

@@ -4,9 +4,14 @@ from typing import Literal
from mistralai import Mistral
from mistralai.models.sdkerror import SDKError
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.registry import register
from talemate.config import load_config
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
@@ -35,14 +40,15 @@ JSON_OBJECT_RESPONSE_MODELS = [
]
class Defaults(CommonDefaults, pydantic.BaseModel):
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "open-mixtral-8x22b"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
class MistralAIClient(ClientBase):
class MistralAIClient(EndpointOverrideMixin, ClientBase):
"""
OpenAI client for generating text.
"""
@@ -52,6 +58,7 @@ class MistralAIClient(ClientBase):
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "MistralAI"
@@ -60,16 +67,18 @@ class MistralAIClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="open-mixtral-8x22b", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def mistralai_api_key(self):
def mistral_api_key(self):
return self.config.get("mistralai", {}).get("api_key")
@property
@@ -85,7 +94,7 @@ class MistralAIClient(ClientBase):
if processing is not None:
self.processing = processing
if self.mistralai_api_key:
if self.mistral_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
@@ -122,7 +131,7 @@ class MistralAIClient(ClientBase):
)
def set_client(self, max_token_length: int = None):
if not self.mistralai_api_key:
if not self.mistral_api_key and not self.endpoint_override_base_url_configured:
self.client = Mistral(api_key="sk-1111")
log.error("No mistral.ai API key set")
if self.api_key_status:
@@ -139,7 +148,7 @@ class MistralAIClient(ClientBase):
model = self.model_name
self.client = Mistral(api_key=self.mistralai_api_key)
self.client = Mistral(api_key=self.api_key, server_url=self.base_url)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
@@ -160,7 +169,8 @@ class MistralAIClient(ClientBase):
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
@@ -201,7 +211,7 @@ class MistralAIClient(ClientBase):
Generates text from the given prompt and parameters.
"""
if not self.mistralai_api_key:
if not self.mistral_api_key:
raise Exception("No mistral.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
@@ -224,22 +234,36 @@ class MistralAIClient(ClientBase):
self.log.debug(
"generate",
base_url=self.base_url,
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
try:
response = await self.client.chat.complete_async(
event_stream = await self.client.chat.stream_async(
model=self.model_name,
messages=messages,
**parameters,
)
self._returned_prompt_tokens = self.prompt_tokens(response)
self._returned_response_tokens = self.response_tokens(response)
response = ""
completion_tokens = 0
prompt_tokens = 0
response = response.choices[0].message.content
async for event in event_stream:
if event.data.choices:
response += event.data.choices[0].delta.content
self.update_request_tokens(self.count_tokens(event.data.choices[0].delta.content))
if event.data.usage:
completion_tokens += event.data.usage.completion_tokens
prompt_tokens += event.data.usage.prompt_tokens
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
#response = response.choices[0].message.content
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json

View File

@@ -129,6 +129,12 @@ class ModelPrompt:
return prompt
def clean_model_name(self, model_name: str):
"""
Clean the model name to be used in the template file name.
"""
return model_name.replace("/", "__").replace(":", "_")
def get_template(self, model_name: str):
"""
Will attempt to load an LLM prompt template - this supports
@@ -137,7 +143,7 @@ class ModelPrompt:
matches = []
cleaned_model_name = model_name.replace("/", "__")
cleaned_model_name = self.clean_model_name(model_name)
# Iterate over all templates in the loader's directory
for template_name in self.env.list_templates():
@@ -166,7 +172,7 @@ class ModelPrompt:
template_name = template_name.split(".jinja2")[0]
cleaned_model_name = model_name.replace("/", "__")
cleaned_model_name = self.clean_model_name(model_name)
shutil.copyfile(
os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),

View File

@@ -0,0 +1,313 @@
import asyncio
import structlog
import httpx
import ollama
import time
from typing import Union
from talemate.client.base import STOPPING_STRINGS, ClientBase, CommonDefaults, ErrorAction, ParameterReroute, ExtraField
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
log = structlog.get_logger("talemate.client.ollama")
FETCH_MODELS_INTERVAL = 15
class OllamaClientDefaults(CommonDefaults):
api_url: str = "http://localhost:11434" # Default Ollama URL
model: str = "" # Allow empty default, will fetch from Ollama
api_handles_prompt_template: bool = False
allow_thinking: bool = False
class ClientConfig(BaseClientConfig):
api_handles_prompt_template: bool = False
allow_thinking: bool = False
@register()
class OllamaClient(ClientBase):
"""
Ollama client for generating text using locally hosted models.
"""
auto_determine_prompt_template: bool = True
client_type = "ollama"
conversation_retries = 0
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Ollama"
title: str = "Ollama"
enable_api_auth: bool = False
manual_model: bool = True
manual_model_choices: list[str] = [] # Will be overridden by finalize_status
defaults: OllamaClientDefaults = OllamaClientDefaults()
extra_fields: dict[str, ExtraField] = {
"api_handles_prompt_template": ExtraField(
name="api_handles_prompt_template",
type="bool",
label="API handles prompt template",
required=False,
description="Let Ollama handle the prompt template. Only do this if you don't know which prompt template to use. Letting talemate handle the prompt template will generally lead to improved responses.",
),
"allow_thinking": ExtraField(
name="allow_thinking",
type="bool",
label="Allow thinking",
required=False,
description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
)
}
@property
def supported_parameters(self):
# Parameters supported by Ollama's generate endpoint
# Based on the API documentation
return [
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
ParameterReroute(
talemate_parameter="repetition_penalty",
client_parameter="repeat_penalty"
),
ParameterReroute(
talemate_parameter="max_tokens",
client_parameter="num_predict"
),
"stopping_strings",
# internal parameters that will be removed before sending
"extra_stopping_strings",
]
@property
def can_be_coerced(self):
"""
Determines whether or not this client can pass LLM coercion (e.g., is able
to predefine partial LLM output in the prompt).
"""
return not self.api_handles_prompt_template
@property
def can_think(self) -> bool:
"""
Allow reasoning models to think before responding.
"""
return self.allow_thinking
def __init__(self, model=None, api_handles_prompt_template=False, allow_thinking=False, **kwargs):
self.model_name = model
self.api_handles_prompt_template = api_handles_prompt_template
self.allow_thinking = allow_thinking
self._available_models = []
self._models_last_fetched = 0
self.client = None
super().__init__(**kwargs)
def set_client(self, **kwargs):
"""
Initialize the Ollama client with the API URL.
"""
# Update model if provided
if kwargs.get("model"):
self.model_name = kwargs["model"]
# Create async client with the configured API URL
# Ollama's AsyncClient expects just the base URL without any path
self.client = ollama.AsyncClient(host=self.api_url)
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
self.allow_thinking = kwargs.get(
"allow_thinking", self.allow_thinking
)
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
Raises an error if no model name is returned.
:return: None
"""
if self.processing:
self.emit_status()
return
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
# instead of using the client (which apparently cannot set a timeout per endpoint)
# we use httpx to check {api_url}/api/version to see if the server is running
# use a timeout of 2 seconds
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url}/api/version", timeout=2)
response.raise_for_status()
# if the server is running, fetch the available models
await self.fetch_available_models()
except Exception as e:
log.error("Failed to fetch models from Ollama", error=str(e))
self.connected = False
self.emit_status()
return
await super().status()
async def fetch_available_models(self):
"""
Fetch list of available models from Ollama.
"""
if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
return self._available_models
response = await self.client.list()
models = response.get("models", [])
model_names = [model.model for model in models]
self._available_models = sorted(model_names)
self._models_last_fetched = time.time()
return self._available_models
def finalize_status(self, data: dict):
"""
Finalizes the status data for the client.
"""
data["manual_model_choices"] = self._available_models
return data
async def get_model_name(self):
return self.model_name
def prompt_template(self, system_message: str, prompt: str):
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
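# When Ollama applies its own chat template we can't prefill the assistant
# turn directly, so the <|BOT|> coercion marker is rewritten as an explicit
# instruction. Illustrative example (assumed marker usage):
#   "...<|BOT|>She said" -> "...\nStart your response with: She said"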
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
"""
Tune parameters for Ollama's generate endpoint.
"""
super().tune_prompt_parameters(parameters, kind)
# Build stopping strings list
parameters["stop"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
# Ollama uses num_predict instead of max_tokens
if "max_tokens" in parameters:
parameters["num_predict"] = parameters["max_tokens"]
def clean_prompt_parameters(self, parameters: dict):
"""
Clean and prepare parameters for Ollama API.
"""
# First let parent class handle parameter reroutes and cleanup
super().clean_prompt_parameters(parameters)
# Remove our internal parameters
if "extra_stopping_strings" in parameters:
del parameters["extra_stopping_strings"]
if "stopping_strings" in parameters:
del parameters["stopping_strings"]
if "stream" in parameters:
del parameters["stream"]
# Remove max_tokens as we've already converted it to num_predict
if "max_tokens" in parameters:
del parameters["max_tokens"]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generate text using Ollama's generate endpoint.
"""
if not self.model_name:
# Try to get a model name
await self.get_model_name()
if not self.model_name:
raise Exception("No model specified or available in Ollama")
# Prepare options for Ollama
options = parameters
options["num_ctx"] = self.max_token_length
try:
# Use generate endpoint for completion
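# raw=True (when talemate renders the prompt template itself) tells Ollama to
# skip its own chat template and send the prompt verbatim; think= toggles
# reasoning output on models that support it.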
stream = await self.client.generate(
model=self.model_name,
prompt=prompt.strip(),
options=options,
raw=self.can_be_coerced,
think=self.can_think,
stream=True,
)
response = ""
async for part in stream:
content = part.response
response += content
self.update_request_tokens(self.count_tokens(content))
# Extract the response text
return response
except Exception as e:
log.error("Ollama generation error", error=str(e), model=self.model_name)
raise ErrorAction(
message=f"Ollama generation failed: {str(e)}",
title="Generation Error"
)
async def abort_generation(self):
"""
Ollama's Python client doesn't expose an abort endpoint, so this is currently a no-op.
"""
# This is a no-op for now as Ollama doesn't expose an abort endpoint
# in the Python client
pass
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
Adjusts temperature and repetition_penalty by random values.
"""
import random
temp = prompt_config["temperature"]
rep_pen = prompt_config.get("repetition_penalty", 1.0)
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
prompt_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)
def reconfigure(self, **kwargs):
"""
Reconfigure the client with new settings.
"""
# Handle model update
if kwargs.get("model"):
self.model_name = kwargs["model"]
super().reconfigure(**kwargs)
# Re-initialize client if API URL changed or model changed
if "api_url" in kwargs or "model" in kwargs:
self.set_client(**kwargs)
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]

View File

@@ -5,9 +5,14 @@ import structlog
import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults, ExtraField
from talemate.client.registry import register
from talemate.config import load_config
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
@@ -79,9 +84,6 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
elif "gpt-3.5-turbo" in model:
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model or "o1" in model or "o3" in model:
print(
"Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
)
return num_tokens_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
@@ -102,13 +104,15 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
return num_tokens
class Defaults(CommonDefaults, pydantic.BaseModel):
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gpt-4o"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
class OpenAIClient(ClientBase):
class OpenAIClient(EndpointOverrideMixin, ClientBase):
"""
OpenAI client for generating text.
"""
@@ -118,7 +122,8 @@ class OpenAIClient(ClientBase):
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "OpenAI"
title: str = "OpenAI"
@@ -126,10 +131,11 @@ class OpenAIClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="gpt-4o", **kwargs):
self.model_name = model
self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config()
super().__init__(**kwargs)
@@ -192,7 +198,7 @@ class OpenAIClient(ClientBase):
)
def set_client(self, max_token_length: int = None):
if not self.openai_api_key:
if not self.openai_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncOpenAI(api_key="sk-1111")
log.error("No OpenAI API key set")
if self.api_key_status:
@@ -209,7 +215,7 @@ class OpenAIClient(ClientBase):
model = self.model_name
self.client = AsyncOpenAI(api_key=self.openai_api_key)
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
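# self.api_key / self.base_url come from EndpointOverrideMixin: base_url is
# None unless an override is configured, in which case the override API key
# is used instead of the global OpenAI key.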
if model == "gpt-3.5-turbo":
self.max_token_length = min(max_token_length or 4096, 4096)
elif model == "gpt-4":
@@ -247,6 +253,7 @@ class OpenAIClient(ClientBase):
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event):
config = event.data
@@ -278,7 +285,7 @@ class OpenAIClient(ClientBase):
Generates text from the given prompt and parameters.
"""
if not self.openai_api_key:
if not self.openai_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No OpenAI API key set")
# only gpt-4-* supports enforcing json object
@@ -333,13 +340,28 @@ class OpenAIClient(ClientBase):
)
try:
response = await self.client.chat.completions.create(
stream = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
stream=True,
**parameters,
)
response = ""
response = response.choices[0].message.content
# Iterate over streamed chunks
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta and getattr(delta, "content", None):
content_piece = delta.content
response += content_piece
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))
#self._returned_prompt_tokens = self.prompt_tokens(prompt)
#self._returned_response_tokens = self.response_tokens(response)
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json

View File

@@ -0,0 +1,329 @@
import pydantic
import structlog
import httpx
import asyncio
import json
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"OpenRouterClient",
]
log = structlog.get_logger("talemate.client.openrouter")
# Available models will be populated when first client with API key is initialized
AVAILABLE_MODELS = []
DEFAULT_MODEL = ""
MODELS_FETCHED = False
async def fetch_available_models(api_key: str = None):
"""Fetch available models from OpenRouter API"""
global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED
if not api_key:
return []
if MODELS_FETCHED:
return AVAILABLE_MODELS
# Only fetch if we haven't already or if explicitly requested
if AVAILABLE_MODELS and not api_key:
return AVAILABLE_MODELS
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://openrouter.ai/api/v1/models",
timeout=10.0
)
if response.status_code == 200:
data = response.json()
models = []
for model in data.get("data", []):
model_id = model.get("id")
if model_id:
models.append(model_id)
AVAILABLE_MODELS = sorted(models)
log.debug(f"Fetched {len(AVAILABLE_MODELS)} models from OpenRouter")
else:
log.warning(f"Failed to fetch models from OpenRouter: {response.status_code}")
except Exception as e:
log.error(f"Error fetching models from OpenRouter: {e}")
MODELS_FETCHED = True
return AVAILABLE_MODELS
def fetch_models_sync(event):
api_key = event.data.get("openrouter", {}).get("api_key")
loop = asyncio.get_event_loop()
loop.run_until_complete(fetch_available_models(api_key))
handlers["config_saved"].connect(fetch_models_sync)
handlers["talemate_started"].connect(fetch_models_sync)
class Defaults(CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = DEFAULT_MODEL
@register()
class OpenRouterClient(ClientBase):
"""
OpenRouter client for generating text using various models.
"""
client_type = "openrouter"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
class Meta(ClientBase.Meta):
name_prefix: str = "OpenRouter"
title: str = "OpenRouter"
manual_model: bool = True
manual_model_choices: list[str] = pydantic.Field(default_factory=lambda: AVAILABLE_MODELS)
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model=None, **kwargs):
self.model_name = model or DEFAULT_MODEL
self.api_key_status = None
self.config = load_config()
self._models_fetched = False
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
@property
def openrouter_api_key(self):
return self.config.get("openrouter", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"repetition_penalty",
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.openrouter_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"openrouter_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
# Unlike other clients, we don't need to set up a client instance
# We'll use httpx directly in the generate method
if not self.openrouter_api_key:
log.error("No OpenRouter API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = DEFAULT_MODEL
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
# Set max token length (default to 16k if not specified)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"openrouter set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=self.model_name,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
async def status(self):
# Fetch models if we have an API key and haven't fetched yet
if self.openrouter_api_key and not self._models_fetched:
self._models_fetched = True
# Update the Meta class with new model choices
self.Meta.manual_model_choices = AVAILABLE_MODELS
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
"""
OpenRouter handles the prompt template internally, so we just
pass the prompt through as-is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters using OpenRouter API.
"""
if not self.openrouter_api_key:
raise Exception("No OpenRouter API key set")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
# Prepare messages for chat completion
messages = [
{"role": "system", "content": self.get_system_message(kind)},
{"role": "user", "content": prompt.strip()}
]
if coercion_prompt:
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
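# A trailing assistant message acts as a prefill: providers that support
# assistant-prefill continuation will continue from the coercion text rather
# than starting a fresh response (behavior varies by underlying provider).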
# Prepare request payload
payload = {
"model": self.model_name,
"messages": messages,
"stream": True,
**parameters
}
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
model=self.model_name,
)
response_text = ""
buffer = ""
completion_tokens = 0
prompt_tokens = 0
try:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
"https://openrouter.ai/api/v1/chat/completions",
headers={
"Authorization": f"Bearer {self.openrouter_api_key}",
"Content-Type": "application/json",
},
json=payload,
timeout=120.0 # 2 minute timeout for generation
) as response:
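# OpenRouter streams Server-Sent Events: lines of the form 'data: {json}'
# ending with 'data: [DONE]'. Buffer raw text and split on newlines so that
# chunks arriving mid-line are handled correctly.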
async for chunk in response.aiter_text():
buffer += chunk
while True:
# Find the next complete SSE line
line_end = buffer.find('\n')
if line_end == -1:
break
line = buffer[:line_end].strip()
buffer = buffer[line_end + 1:]
if line.startswith('data: '):
data = line[6:]
if data == '[DONE]':
break
try:
data_obj = json.loads(data)
content = data_obj["choices"][0]["delta"].get("content")
usage = data_obj.get("usage", {})
completion_tokens += usage.get("completion_tokens", 0)
prompt_tokens += usage.get("prompt_tokens", 0)
if content:
response_text += content
# Update tokens as content streams in
self.update_request_tokens(self.count_tokens(content))
except json.JSONDecodeError:
pass
# Extract the response content
response_content = response_text
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
return response_content
except httpx.ConnectTimeout:
self.log.error("OpenRouter API timeout")
emit("status", message="OpenRouter API: Request timed out", status="error")
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message=f"OpenRouter API Error: {str(e)}", status="error")
raise

View File

@@ -1,5 +1,109 @@
__all__ = ["RemoteServiceMixin"]
import pydantic
import structlog
from .base import FieldGroup, ExtraField
import talemate.ux.schema as ux_schema
log = structlog.get_logger(__name__)
__all__ = [
"RemoteServiceMixin",
"EndpointOverrideMixin",
"endpoint_override_extra_fields",
"EndpointOverrideGroup",
"EndpointOverrideField",
"EndpointOverrideBaseURLField",
"EndpointOverrideAPIKeyField",
]
def endpoint_override_extra_fields():
return {
"override_base_url": EndpointOverrideBaseURLField(),
"override_api_key": EndpointOverrideAPIKeyField(),
}
class EndpointOverride(pydantic.BaseModel):
override_base_url: str | None = None
override_api_key: str | None = None
class EndpointOverrideGroup(FieldGroup):
name: str = "endpoint_override"
label: str = "Endpoint Override"
description: str = ("Override the default base URL used by this client to access the {client_type} service API.\n\n"
"IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
"If the override endpoint requires an API key, enter it below.")
icon: str = "mdi-api"
class EndpointOverrideField(ExtraField):
group: EndpointOverrideGroup = pydantic.Field(default_factory=EndpointOverrideGroup)
class EndpointOverrideBaseURLField(EndpointOverrideField):
name: str = "override_base_url"
type: str = "text"
label: str = "Base URL"
required: bool = False
description: str = "Override the base URL for the remote service"
class EndpointOverrideAPIKeyField(EndpointOverrideField):
name: str = "override_api_key"
type: str = "password"
label: str = "API Key"
required: bool = False
description: str = "Override the API key for the remote service"
note: ux_schema.Note = pydantic.Field(default_factory=lambda: ux_schema.Note(
text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
color="warning",
))
class EndpointOverrideMixin:
override_base_url: str | None = None
override_api_key: str | None = None
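# Usage sketch (assumed wiring, mirroring the OpenAI client changes above):
# a client subclasses EndpointOverrideMixin next to ClientBase, exposes
# endpoint_override_extra_fields() via Meta.extra_fields, and constructs its
# SDK client with base_url=self.base_url and api_key=self.api_key so the
# override (when configured) replaces the official endpoint and key.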
def set_client_api_key(self, api_key: str | None):
if getattr(self, "client", None):
try:
self.client.api_key = api_key
except Exception as e:
log.error("Error setting client API key", error=e, client=self.client_type)
@property
def api_key(self) -> str | None:
if self.endpoint_override_base_url_configured:
return self.override_api_key
return getattr(self, f"{self.client_type}_api_key", None)
@property
def base_url(self) -> str | None:
if self.override_base_url and self.override_base_url.strip():
return self.override_base_url
return None
@property
def endpoint_override_base_url_configured(self) -> bool:
return self.override_base_url and self.override_base_url.strip()
@property
def endpoint_override_api_key_configured(self) -> bool:
return self.override_api_key and self.override_api_key.strip()
@property
def endpoint_override_fully_configured(self) -> bool:
return self.endpoint_override_base_url_configured and self.endpoint_override_api_key_configured
def _reconfigure_endpoint_override(self, **kwargs):
if "override_base_url" in kwargs:
orig = getattr(self, "override_base_url", None)
self.override_base_url = kwargs["override_base_url"]
if getattr(self, "client", None) and orig != self.override_base_url:
log.info("Reconfiguring client base URL", new=self.override_base_url)
self.set_client(kwargs.get("max_token_length"))
if "override_api_key" in kwargs:
self.override_api_key = kwargs["override_api_key"]
self.set_client_api_key(self.override_api_key)
class RemoteServiceMixin:

View File

@@ -1,10 +1,10 @@
import random
import urllib
from typing import Literal
import aiohttp
import json
import httpx
import pydantic
import structlog
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
from openai import PermissionDeniedError
from talemate.client.base import ClientBase, ExtraField, CommonDefaults
from talemate.client.registry import register
@@ -17,61 +17,6 @@ log = structlog.get_logger("talemate.client.tabbyapi")
EXPERIMENTAL_DESCRIPTION = """Use this client to use all of TabbyAPI's features"""
class CustomAPIClient:
def __init__(self, base_url, api_key):
self.base_url = base_url
self.api_key = api_key
async def get_model_name(self):
url = urljoin(self.base_url, "model")
headers = {
"x-api-key": self.api_key,
}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Request failed: {response.status}")
response_data = await response.json()
model_name = response_data.get("id")
# split by "/" and take last
if model_name:
model_name = model_name.split("/")[-1]
return model_name
async def create_chat_completion(self, model, messages, **parameters):
url = urljoin(self.base_url, "chat/completions")
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
}
data = {
"model": model,
"messages": messages,
**parameters,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
raise Exception(f"Request failed: {response.status}")
return await response.json()
async def create_completion(self, model, **parameters):
url = urljoin(self.base_url, "completions")
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
}
data = {
"model": model,
**parameters,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
if response.status != 200:
raise Exception(f"Request failed: {response.status}")
return await response.json()
class Defaults(CommonDefaults, pydantic.BaseModel):
api_url: str = "http://localhost:5000/v1"
api_key: str = ""
@@ -153,7 +98,6 @@ class TabbyAPIClient(ClientBase):
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
self.client = CustomAPIClient(base_url=self.api_url, api_key=self.api_key)
self.model_name = (
kwargs.get("model") or kwargs.get("model_name") or self.model_name
)
@@ -178,49 +122,150 @@ class TabbyAPIClient(ClientBase):
return prompt
async def get_model_name(self):
return await self.client.get_model_name()
url = urljoin(self.api_url, "model")
headers = {
"x-api-key": self.api_key,
}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers, timeout=10.0)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}")
response_data = response.json()
model_name = response_data.get("id")
# split by "/" and take last
if model_name:
model_name = model_name.split("/")[-1]
return model_name
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
Generates text from the given prompt and parameters using streaming responses.
"""
# Determine whether we are using chat or completions endpoint
is_chat = self.api_handles_prompt_template
try:
if self.api_handles_prompt_template:
# Custom API handles prompt template
# Use the chat completions endpoint
if is_chat:
# Chat completions endpoint
self.log.debug(
"generate (chat/completions)",
prompt=prompt[:128] + " ...",
parameters=parameters,
)
human_message = {"role": "user", "content": prompt.strip()}
response = await self.client.create_chat_completion(
self.model_name, [human_message], **parameters
)
response = response["choices"][0]["message"]["content"]
return self.process_response_for_indirect_coercion(prompt, response)
payload = {
"model": self.model_name,
"messages": [human_message],
"stream": True,
"stream_options": {
"include_usage": True,
},
**parameters,
}
endpoint = "chat/completions"
else:
# Talemate handles prompt template
# Use the completions endpoint
# Completions endpoint
self.log.debug(
"generate (completions)",
prompt=prompt[:128] + " ...",
parameters=parameters,
)
parameters["prompt"] = prompt
response = await self.client.create_completion(
self.model_name, **parameters
)
return response["choices"][0]["text"]
payload = {
"model": self.model_name,
"prompt": prompt,
"stream": True,
**parameters,
}
endpoint = "completions"
url = urljoin(self.api_url, endpoint)
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
}
response_text = ""
buffer = ""
completion_tokens = 0
prompt_tokens = 0
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
url,
headers=headers,
json=payload,
timeout=120.0
) as response:
async for chunk in response.aiter_text():
buffer += chunk
while True:
line_end = buffer.find('\n')
if line_end == -1:
break
line = buffer[:line_end].strip()
buffer = buffer[line_end + 1:]
if not line:
continue
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
try:
data_obj = json.loads(data)
choice = data_obj.get("choices", [{}])[0]
# Chat completions use delta -> content.
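# The plain completions endpoint streams text in choice["text"] instead,
# so fall back through the possible fields below.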
delta = choice.get("delta", {})
content = (
delta.get("content")
or delta.get("text")
or choice.get("text")
)
usage = data_obj.get("usage", {})
completion_tokens = usage.get("completion_tokens", 0)
prompt_tokens = usage.get("prompt_tokens", 0)
if content:
response_text += content
self.update_request_tokens(self.count_tokens(content))
except json.JSONDecodeError:
# ignore malformed json chunks
pass
# Save token stats for logging
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
if is_chat:
# Process indirect coercion
response_text = self.process_response_for_indirect_coercion(prompt, response_text)
return response_text
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
emit("status", message="Client API: Permission Denied", status="error")
return ""
except httpx.ConnectTimeout:
self.log.error("API timeout")
emit("status", message="TabbyAPI: Request timed out", status="error")
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit(
"status", message="Error during generation (check logs)", status="error"
)
emit("status", message="Error during generation (check logs)", status="error")
return ""
def reconfigure(self, **kwargs):

View File

@@ -195,7 +195,8 @@ class TextGeneratorWebuiClient(ClientBase):
payload = json.loads(event.data)
chunk = payload['choices'][0]['text']
response += chunk
self.update_request_tokens(self.count_tokens(chunk))
return response

View File

@@ -22,7 +22,10 @@ class CmdSetEnvironmentToScene(TalemateCommand):
player_character = self.scene.get_player_character()
if not player_character:
self.system_message("No player character found")
self.system_message("No characters found - cannot switch to gameplay mode.", meta={
"icon": "mdi-alert",
"color": "warning",
})
return True
self.scene.set_environment("scene")

View File

@@ -93,6 +93,7 @@ class General(BaseModel):
auto_save: bool = True
auto_progress: bool = True
max_backscroll: int = 512
add_default_character: bool = True
class StateReinforcementTemplate(BaseModel):
@@ -161,6 +162,9 @@ class DeepSeekConfig(BaseModel):
api_key: Union[str, None] = None
class OpenRouterConfig(BaseModel):
api_key: Union[str, None] = None
class RunPodConfig(BaseModel):
api_key: Union[str, None] = None
@@ -177,6 +181,7 @@ class CoquiConfig(BaseModel):
class GoogleConfig(BaseModel):
gcloud_credentials_path: Union[str, None] = None
gcloud_location: Union[str, None] = None
api_key: Union[str, None] = None
class TTSVoiceSamples(BaseModel):
@@ -209,6 +214,7 @@ class EmbeddingFunctionPreset(BaseModel):
gpu_recommendation: bool = False
local: bool = True
custom: bool = False
client: str | None = None
@@ -506,6 +512,8 @@ class Config(BaseModel):
anthropic: AnthropicConfig = AnthropicConfig()
openrouter: OpenRouterConfig = OpenRouterConfig()
cohere: CohereConfig = CohereConfig()
groq: GroqConfig = GroqConfig()

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
__all__ = [
"ArchiveEntry",
]
@dataclass
class ArchiveEntry:
text: str
start: int = None
end: int = None
ts: str = None

View File

@@ -25,7 +25,7 @@ class AsyncSignal:
async def send(self, emission):
for receiver in self.receivers:
await receiver(emission)
def _register(name: str):
"""

View File

@@ -180,11 +180,11 @@ class Emitter:
def setup_emitter(self, scene: Scene = None):
self.emit_for_scene = scene
def emit(self, typ: str, message: str, character: Character = None):
emit(typ, message, character=character, scene=self.emit_for_scene)
def emit(self, typ: str, message: str, character: Character = None, **kwargs):
emit(typ, message, character=character, scene=self.emit_for_scene, **kwargs)
def system_message(self, message: str):
self.emit("system", message)
def system_message(self, message: str, **kwargs):
self.emit("system", message, **kwargs)
def narrator_message(self, message: str):
self.emit("narrator", message)

View File

@@ -49,6 +49,8 @@ SpiceApplied = signal("spice_applied")
WorldSateManager = signal("world_state_manager")
TalemateStarted = signal("talemate_started")
handlers = {
"system": SystemMessage,
"narrator": NarratorMessage,
@@ -86,4 +88,5 @@ handlers = {
"memory_request": MemoryRequest,
"player_choice": PlayerChoiceMessage,
"world_state_manager": WorldSateManager,
"talemate_started": TalemateStarted,
}

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import pydantic
import talemate.emit.async_signals as async_signals
@@ -29,10 +28,9 @@ class HistoryEvent(Event):
@dataclass
class ArchiveEvent(Event):
text: str
memory_id: str = None
memory_id: str
ts: str = None
@dataclass
class CharacterStateEvent(Event):
state: str

Some files were not shown because too many files have changed in this diff.