* some prompt cleanup

* prompt tweaks

* prompt tweaks

* prompt tweaks

* set 0.31.0

* relock

* rag queries add brief analysis

* brief analysis before building rag questions

* rag improvements

* prompt tweaks

* address circular import issues

* set 0.30.1

* docs

* numpy to 2

* docs

* prompt tweaks

* prompt tweak

* some template cleanup

* prompt viewer increase height

* fix codemirror highlighting not working

* adjust details height

* update default

* change to log debug

* allow response to scale to max height

* template cleanup

* prompt tweaks

* first progress for installing modules to scene

* package install logic

* package install polish

* package install polish

* package install polish and fixes

* refactor initial world state update and expose setting for it

* add `num` property to ModuleProperty to control order of widgets

* dynamic storyline package info

* fix issue where deactivating player character would cause inconsistencies in the creative tools menu

* cruft

* add openrouter support

* ollama support

* refactor how model choices are loaded, so that can be done per client instance as opposed to just per client type

* set num_ctx

* remove debug messages

* ollama tweaks

* toggle for whether or not default character gets added to blank talemate scenes

* narrator prompt tweaks and template cleanup

* cleanup

* prompt tweaks and template cleanup

* prompt tweaks

* fix instructor embeddings

* add additional error handling to prevent broken world state templates from breaking the world editor side menu

* fix openrouter breaking startup if not configured

* remove debug message

* prompt tweaks

* fix example dialogue generation no longer working

* prompt tweaks and better showing of dialogue examples in conversation instructions

* prompt tweak

* add initial startup message

* prompt tweaks

* fix asset error

* move complex acting instructions into the task block

* fix content input socket on DynamicInstructions node

* log.error with traceback over log.exception since that has a tendency to explode and hang everything

* fix issue with number nodes where they would try to execute even if the incoming wire wasn't active

* fix issue with editor revision events missing template_vars

* DynamicInstruction node should only run if both header and content can resolve

* removed remaining references to 90s adventure game writing style

* prompt tweaks

* support embeddings via client apis (koboldcpp)

* fix label on client-api embeddings

* fix issue where adding / removing an embedding preset would not be reflected immediately in the memory agent config

* remove debug output

* prompt tweaks

* prompt tweaks

* autocomplete passes message object

* validate group names to be filename valid

* embedded windows env installs and up to poetry 2

* version config

* get-pip

* relock

* pin runpod

* no longer needed

* remove rapidfuzz dependency

* nodejs directly into embedded_node without a versioned middleman dir - also remove defunct local-tts install script

* fix update script

* update script error handling

* update.bat error handling

* adjust wording

* support loading jinja2 templates node modules in templates/modules

* update google model list

* client t/s and business indicator - also switch all clients to async streaming

* formatting

* support coercion for anthropic / google
switch to the new google genai sdk
upgrade websockets

* more coercion fixes

* gracefully handle keyboard interrupt

* EmitSystemMessage node

* allow visual prompt generation without image generation

* relock

* chromadb to v1

* fix error handling

* fix issue where adding client model list would be empty

* suppress pip install warnings

* allow overriding of base urls

* remove key from log

* add fade effect

* tweak request info ux

* further clarification of endpoint override api key

* world state manager: fix issue that caused max changes setting to disappear from character progress config

* fix issue with google safety settings off causing generation failures

* update to base url should always reset the client

* getattr

* support v3 chara card version and attempt to future proof

* client based embeddings improvements

* more fixes for client based embeddings

* use client icon

* history management tools progress

* history memory ids fixed and added validation

* regenerate summary fixes

* more history regeneration fixes

* fix layered history gen and prompt tweaks

* allow regeneration of individual layered history entries

* return list of LayeredArchiveEntry

* reorg for less code duplication

* new scene message renderer based on marked

* add inspect functionality to history viewer

* message if no history entries yet

* allow adding of history entries manually

* allow deletion of history

* summarization unslop improvements

* fix make character real action from worldstate listing

* allow overriding length in all context generation instruction dialogs

* fix issue where extract_list could fail with an unhandled error if the llm response didn't contain a list

* update what's new

* fix issues with the new history management tools

* fix check

* Switch dependency handling to UV (#202)

* Migrate from Poetry to uv package manager            (#200)

* migrate from poetry to uv package manager

* Update all installation and startup scripts for uv migration

* Fix pyproject.toml for uv - allow direct references for hatchling

* Fix PR feedback: Restore removed functionality

- Restored embedded Python/Node.js functionality in install.bat and update.bat
- Restored environment variable exposure in docker-compose.yml (CUDA_AVAILABLE, port configs)
- Fixed GitHub Actions branches (main, prep-* instead of main, dev)
- Restored fail-fast: false and cache configuration in test.yml

These changes preserve all the functionality that should not be removed
during the migration from Poetry to uv.

---------

Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>

* remove uv.lock from .gitignore

* add lock file

* fix install issues

* warn if unable to remove legacy poetry virt env dir

* uv needs to be explicitly installed into the .venv so it's available

* third time's the charm?

* fix windows install scripts

* add .venv guard to update.bat

* call :die

* fix docker venv install

* node 21

* fix cuda install

* start.bat calls install if needed

* sync start-local to other startup scripts

* no need to activate venv

---------

Co-authored-by: Ztripez <reg@otherland.nu>
Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>

* ignore hfhub symlink warnings

* add openrouter and ollama mentions

* update windows install documentation

* docs

* docs

* fix issue with memory agent fingerprint

* removing a client that supports embeddings will also remove any embedding functions it created

* on invalid embeddings reset to default

* docs

* typo

* formatting

* docs

* docs

* install package

* adjust topic

* add more obvious way to exit creative mode

* when importing character cards immediately persist a usable save after the restoration save

---------

Co-authored-by: Ztripez <reg@otherland.nu>
Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>
Author: veguAI
Date: 2025-06-29 18:06:11 +03:00
Committed by: GitHub
Parent: e4d465ba42
Commit: 9eb4c48d79
222 changed files with 43,178 additions and 9,603 deletions

@@ -2,9 +2,9 @@ name: Python Tests
 on:
   push:
-    branches: [ master, main, 'prep-*' ]
+    branches: [ main, 'prep-*' ]
   pull_request:
-    branches: [ master, main, 'prep-*' ]
+    branches: [ main, 'prep-*' ]
 jobs:
   test:
@@ -23,25 +23,24 @@ jobs:
           python-version: ${{ matrix.python-version }}
           cache: 'pip'
-      - name: Install poetry
+      - name: Install uv
         run: |
           python -m pip install --upgrade pip
-          pip install poetry
-      - name: Cache poetry dependencies
+          pip install uv
+      - name: Cache uv dependencies
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+          path: ~/.cache/uv
+          key: ${{ runner.os }}-uv-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
           restore-keys: |
-            ${{ runner.os }}-poetry-${{ matrix.python-version }}-
+            ${{ runner.os }}-uv-${{ matrix.python-version }}-
       - name: Install dependencies
         run: |
-          python -m venv talemate_env
-          source talemate_env/bin/activate
-          poetry config virtualenvs.create false
-          poetry install
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e ".[dev]"
       - name: Setup configuration file
         run: |
@@ -49,10 +48,10 @@ jobs:
       - name: Download NLTK data
         run: |
-          source talemate_env/bin/activate
+          source .venv/bin/activate
           python -c "import nltk; nltk.download('punkt_tab')"
       - name: Run tests
         run: |
-          source talemate_env/bin/activate
+          source .venv/bin/activate
           pytest tests/ -p no:warnings

.gitignore

@@ -8,6 +8,9 @@
 talemate_env
 chroma
 config.yaml
+# uv
+.venv/
 templates/llm-prompt/user/*.jinja2
 templates/world-state/*.yaml
 scenes/

.python-version (new file)

@@ -0,0 +1 @@
+3.11

@@ -1,15 +1,19 @@
 # Stage 1: Frontend build
-FROM node:21 AS frontend-build
+FROM node:21-slim AS frontend-build
+ENV NODE_ENV=development
 WORKDIR /app
-# Copy the frontend directory contents into the container at /app
-COPY ./talemate_frontend /app
-# Install all dependencies and build
-RUN npm install && npm run build
+# Copy frontend package files
+COPY talemate_frontend/package*.json ./
+# Install dependencies
+RUN npm ci
+# Copy frontend source
+COPY talemate_frontend/ ./
+# Build frontend
+RUN npm run build
 # Stage 2: Backend build
 FROM python:3.11-slim AS backend-build
@@ -22,30 +26,25 @@ RUN apt-get update && apt-get install -y \
     gcc \
     && rm -rf /var/lib/apt/lists/*
-# Install poetry
-RUN pip install poetry
+# Install uv
+RUN pip install uv
-# Copy poetry files
-COPY pyproject.toml poetry.lock* /app/
+# Copy installation files
+COPY pyproject.toml uv.lock /app/
-# Create a virtual environment
-RUN python -m venv /app/talemate_env
-# Activate virtual environment and install dependencies
-RUN . /app/talemate_env/bin/activate && \
-    poetry config virtualenvs.create false && \
-    poetry install --only main --no-root
-# Copy the Python source code
+# Copy the Python source code (needed for editable install)
 COPY ./src /app/src
+# Create virtual environment and install dependencies
+RUN uv sync
 # Conditional PyTorch+CUDA install
 ARG CUDA_AVAILABLE=false
-RUN . /app/talemate_env/bin/activate && \
+RUN . /app/.venv/bin/activate && \
     if [ "$CUDA_AVAILABLE" = "true" ]; then \
     echo "Installing PyTorch with CUDA support..." && \
-    pip uninstall torch torchaudio -y && \
-    pip install torch~=2.4.1 torchaudio~=2.4.1 --index-url https://download.pytorch.org/whl/cu121; \
+    uv pip uninstall torch torchaudio && \
+    uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
     fi
 # Stage 3: Final image
@@ -57,8 +56,11 @@ RUN apt-get update && apt-get install -y \
     bash \
     && rm -rf /var/lib/apt/lists/*
+# Install uv in the final stage
+RUN pip install uv
 # Copy virtual environment from backend-build stage
-COPY --from=backend-build /app/talemate_env /app/talemate_env
+COPY --from=backend-build /app/.venv /app/.venv
 # Copy Python source code
 COPY --from=backend-build /app/src /app/src
@@ -83,4 +85,4 @@ EXPOSE 5050
 EXPOSE 8080
 # Use bash as the shell, activate the virtual environment, and run backend server
-CMD ["/bin/bash", "-c", "source /app/talemate_env/bin/activate && python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050 --frontend-host 0.0.0.0 --frontend-port 8080"]
+CMD ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]

@@ -39,12 +39,14 @@ Need help? Join the new [Discord community](https://discord.gg/8bGNRmFxMj)
 - [Cohere](https://www.cohere.com/)
 - [Groq](https://www.groq.com/)
 - [Google Gemini](https://console.cloud.google.com/)
+- [OpenRouter](https://openrouter.ai/)
 Supported self-hosted APIs:
 - [KoboldCpp](https://koboldai.org/cpp) ([Local](https://koboldai.org/cpp), [Runpod](https://koboldai.org/runpodcpp), [VastAI](https://koboldai.org/vastcpp), also includes image gen support)
 - [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
 - [LMStudio](https://lmstudio.ai/)
 - [TabbyAPI](https://github.com/theroyallab/tabbyAPI/)
+- [Ollama](https://ollama.com/)
 Generic OpenAI api implementations (tested and confirmed working):
 - [DeepInfra](https://deepinfra.com/)

@@ -18,4 +18,4 @@ services:
     environment:
       - PYTHONUNBUFFERED=1
       - PYTHONPATH=/app/src:$PYTHONPATH
-    command: ["/bin/bash", "-c", "source /app/talemate_env/bin/activate && python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050 --frontend-host 0.0.0.0 --frontend-port 8080"]
+    command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]


@@ -0,0 +1,166 @@
# Adding a new world-state template
I am writing this up as I add phrase detection functionality to the `Writing Style` template, so that in the future, when new template types need to be added, this document can just be given to the LLM of the month to do it.
## Introduction
World state templates are reusable components that plug into various parts of Talemate.
At this point there are the following types:
- Character Attribute
- Character Detail
- Writing Style
- Spice (for randomization of content during generation)
- Scene Type
- State Reinforcement
Basically whenever we want to add something reusable and customizable by the user, a world state template is likely a good solution.
## Steps to creating a new template type
### 1. Add a pydantic schema (python)
In `src/talemate/world_state/templates` create a new `.py` file with a reasonable name.
In this example I am extending the `Writing Style` template to include phrase detection functionality, which will be used by the `Editor` agent to detect certain phrases and then act upon them.
There already is a `content.py` file - so it makes sense to just add this new functionality to this file.
```python
class PhraseDetection(pydantic.BaseModel):
    phrase: str
    instructions: str
    # can be "unwanted" for now, more added later
    classification: Literal["unwanted"] = "unwanted"

@register("writing_style")
class WritingStyle(Template):
    description: str | None = None
    phrases: list[PhraseDetection] = pydantic.Field(default_factory=list)

    def render(self, scene: "Scene", character_name: str):
        return self.formatted("instructions", scene, character_name)
```
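For orientation, here is a minimal usage sketch of the schema above. It only exercises the `PhraseDetection` model; the fields required by the `Template` base class (name, favorite, etc.) are not shown in this document and are therefore omitted here.

```python
# Minimal sketch, assuming only the fields shown above. The Template base
# class contributes additional fields (e.g. name/favorite) that are not
# covered in this document.
rule = PhraseDetection(
    phrase="a shiver ran down her spine",
    instructions="Rephrase without the cliché.",
)
print(rule.classification)  # -> "unwanted" (the only allowed value for now)
```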
If I were to create a new file I'd still want to read one of the existing files first to understand imports and style.
### 2. Add a vue component to allow management (vue, js)
Next we need to add a new vue component that exposes a UX for us to manage this new template type.
For this I am creating `talemate_frontend/src/components/WorldStateManagerTemplateWritingStyle.vue`.
## Bare Minimum Understanding for New Template Components
When adding a new component for managing a template type, you need to understand:
### Component Structure
1. **Props**: The component always receives an `immutableTemplate` prop with the template data.
2. **Data Management**: Create a local copy of the template data for editing before saving back.
3. **Emits**: Use the `update` event to send modified template data back to the parent.
### Core Implementation Requirements
1. **Template Properties**: Always include fields for `name`, `description`, and `favorite` status.
2. **Data Binding**: Implement two-way binding with `v-model` for all editable fields.
3. **Dirty State Tracking**: Track when changes are made but not yet saved.
4. **Save Method**: Implement a `save()` method that emits the updated template.
### Component Lifecycle
1. **Initialization**: Use the `created` hook to initialize the local template copy.
2. **Watching for Changes**: Set up a watcher for the `immutableTemplate` to handle external updates.
### UI Patterns
1. **Forms**: Use Vuetify form components with consistent validation.
2. **Actions**: Provide clear user actions for editing and managing template items.
3. **Feedback**: Give visual feedback when changes are being made or saved.
The WorldStateManagerTemplate components follow a consistent pattern where they:
- Display and edit general template metadata (name, description, favorite status)
- Provide specialized UI for the template's unique properties
- Handle the create, read, update, delete (CRUD) operations for template items
- Maintain data integrity by properly handling template updates
You absolutely should read an existing component like `WorldStateManagerTemplateWritingStyle.vue` first to get a good understanding of the implementation.
## Integrating with WorldStateManagerTemplates
After creating your template component, you need to integrate it with the WorldStateManagerTemplates component:
### 1. Import the Component
Edit `talemate_frontend/src/components/WorldStateManagerTemplates.vue` and add an import for your new component:
```javascript
import WorldStateManagerTemplateWritingStyle from './WorldStateManagerTemplateWritingStyle.vue'
```
### 2. Register the Component
Add your component to the components section of the WorldStateManagerTemplates:
```javascript
components: {
    // ... existing components
    WorldStateManagerTemplateWritingStyle
}
```
### 3. Add Conditional Rendering
In the template section, add a new conditional block to render your component when the template type matches:
```html
<WorldStateManagerTemplateWritingStyle v-else-if="template.template_type === 'writing_style'"
    :immutableTemplate="template"
    @update="(template) => applyAndSaveTemplate(template)"
/>
```
### 4. Add Icon and Color
Add cases for your template type in the `iconForTemplate` and `colorForTemplate` methods:
```javascript
iconForTemplate(template) {
    // ... existing conditions
    else if (template.template_type == 'writing_style') {
        return 'mdi-script-text';
    }
    return 'mdi-cube-scan';
},
colorForTemplate(template) {
    // ... existing conditions
    else if (template.template_type == 'writing_style') {
        return 'highlight5';
    }
    return 'grey';
}
```
### 5. Add Help Message
Add a help message for your template type in the `helpMessages` object in the data section:
```javascript
helpMessages: {
    // ... existing messages
    writing_style: "Writing style templates are used to define a writing style that can be applied to the generated content. They can be used to add a specific flavor or tone. A template must explicitly support writing styles to be able to use a writing style template.",
}
```
### 6. Update Template Type Selection
Add your template type to the `templateTypes` array in the data section:
```javascript
templateTypes: [
    // ... existing types
    { "title": "Writing style", "value": 'writing_style'},
]
```

@@ -10,20 +10,19 @@ To run the server on a different host and port, you need to change the values pa
 #### :material-linux: Linux
-Copy `start.sh` to `start_custom.sh` and edit the `--host` and `--port` parameters in the `uvicorn` command.
+Copy `start.sh` to `start_custom.sh` and edit the `--host` and `--port` parameters.
 ```bash
 #!/bin/sh
-. talemate_env/bin/activate
-python src/talemate/server/run.py runserver --host 0.0.0.0 --port 1234
+uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 1234
 ```
 #### :material-microsoft-windows: Windows
-Copy `start.bat` to `start_custom.bat` and edit the `--host` and `--port` parameters in the `uvicorn` command.
+Copy `start.bat` to `start_custom.bat` and edit the `--host` and `--port` parameters.
 ```batch
-start cmd /k "cd talemate_env\Scripts && activate && cd ../../ && python src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234"
+uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234
 ```
 ### Letting the frontend know about the new host and port
@@ -71,8 +70,7 @@ Copy `start.sh` to `start_custom.sh` and edit the `--frontend-host` and `--front
 ```bash
 #!/bin/sh
-. talemate_env/bin/activate
-python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
+uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
     --frontend-host localhost --frontend-port 8082
 ```
@@ -81,7 +79,7 @@ python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
 Copy `start.bat` to `start_custom.bat` and edit the `--frontend-host` and `--frontend-port` parameters.
 ```batch
-start cmd /k "cd talemate_env\Scripts && activate && cd ../../ && python src\talemate\server\run.py runserver --host 0.0.0.0 --port 5055 --frontend-host localhost --frontend-port 8082"
+uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 5055 --frontend-host localhost --frontend-port 8082
 ```
 ### Start the backend and frontend
@@ -99,4 +97,3 @@ Start the backend and frontend as usual.
 ```batch
 start_custom.bat
 ```

@@ -1,4 +1,3 @@
 ## Quick install instructions
 ### Dependencies
@@ -7,6 +6,7 @@
 1. node.js and npm - see instructions [here](https://nodejs.org/en/download/package-manager/)
 1. python- see instructions [here](https://www.python.org/downloads/)
+1. uv - see instructions [here](https://github.com/astral-sh/uv#installation)
 ### Installation
@@ -25,19 +25,15 @@ If everything went well, you can proceed to [connect a client](../../connect-a-c
 1. Open a terminal.
 2. Navigate to the project directory.
-3. Create a virtual environment by running `python3 -m venv talemate_env`.
-4. Activate the virtual environment by running `source talemate_env/bin/activate`.
+3. uv will automatically create a virtual environment when you run `uv venv`.
 ### Installing Dependencies
-1. With the virtual environment activated, install poetry by running `pip install poetry`.
-2. Use poetry to install dependencies by running `poetry install`.
+1. Use uv to install dependencies by running `uv pip install -e ".[dev]"`.
 ### Running the Backend
-1. With the virtual environment activated and dependencies installed, you can start the backend server.
-2. Navigate to the `src/talemate/server` directory.
-3. Run the server with `python run.py runserver --host 0.0.0.0 --port 5050`.
+1. You can start the backend server using `uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
 ### Running the Frontend
@@ -45,4 +41,4 @@ If everything went well, you can proceed to [connect a client](../../connect-a-c
 2. If you haven't already, install npm dependencies by running `npm install`.
 3. Start the server with `npm run serve`.
-Please note that you may need to set environment variables or modify the host and port as per your setup. You can refer to the `runserver.sh` and `frontend.sh` files for more details.
+Please note that you may need to set environment variables or modify the host and port as per your setup. You can refer to the various start scripts for more details.

@@ -2,16 +2,9 @@
 ## Windows
-### Installation fails with "Microsoft Visual C++" or "ValueError: The onnxruntime python package is not installed." errors
-If your installation errors with a notification to upgrade "Microsoft Visual C++" go to [https://visualstudio.microsoft.com/visual-cpp-build-tools/](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and click "Download Build Tools" and run it.
-- During installation make sure you select the C++ development package (upper left corner)
-- Run `reinstall.bat` inside talemate directory
 ### Frontend fails with errors
-- ensure none of the directories have special characters in them, this can cause issues with the frontend. so no `(1)` in the directory name.
+- ensure none of the directories leading to your talemate directory have special characters in them, this can cause issues with the frontend. so no `(1)` in the directory name.
 ## Docker

@@ -1,53 +1,32 @@
 ## Quick install instructions
-1. Download and install Python 3.10 - 3.13 from the [official Python website](https://www.python.org/downloads/windows/).
-    - [Click here for direct link to python 3.11.9 download](https://www.python.org/downloads/release/python-3119/)
-    - June 2025: people have reported issues with python 3.13 still, due to some dependencies not being available yet, if you run into issues during installation try downgrading.
-1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/prebuilt-installer). This will also install npm.
-1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
-1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
-1. **Optional:** If you are using an nvidia graphics card with CUDA support you may want to also run `install-cuda.bat` **afterwards**, to install the cuda enabled version of torch - although this is only needed if you want to run some bigger embedding models where CUDA can be helpful.
-1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
-1. Once the talemate logo shows up, navigate your browser to http://localhost:8080
-!!! note "First start up may take a while"
-    We have seen cases where the first start of talemate will sit at a black screen for a minute or two. Just wait it out, eventually the Talemate logo should show up.
-If everything went well, you can proceed to [connect a client](../../connect-a-client).
-## Additional Information
-### How to Install Python
---8<-- "docs/snippets/common.md:python-versions"
-1. Visit the official Python website's download page for Windows at [https://www.python.org/downloads/windows/](https://www.python.org/downloads/windows/).
-2. Find the latest updated of Python 3.13 and click on one of the download links. (You will likely want the Windows installer (64-bit))
-4. Run the installer file and follow the setup instructions. Make sure to check the box that says Add Python 3.13 to PATH before you click Install Now.
-### How to Install npm
-1. Download Node.js from the official site [https://nodejs.org/en/download/prebuilt-installer](https://nodejs.org/en/download/prebuilt-installer).
-2. Run the installer (the .msi installer is recommended).
-3. Follow the prompts in the installer (Accept the license agreement, click the NEXT button a bunch of times and accept the default installation settings).
-### Usage of the Supplied bat Files
-#### install.bat
-This batch file is used to set up the project on your local machine. It creates a virtual environment, activates it, installs poetry, and uses poetry to install dependencies. It then navigates to the frontend directory and installs the necessary npm packages.
-To run this file, simply double click on it or open a command prompt in the same directory and type `install.bat`.
-#### update.bat
-If you are inside a git checkout of talemate you can use this to pull and reinstall talemate if there have been updates.
-!!! note "CUDA needs to be reinstalled manually"
-    Running `update.bat` will downgrade your torch install to the non-CUDA version, so if you want CUDA support you will need to run the `install-cuda.bat` script after the update is finished.
-#### start.bat
-This batch file is used to start the backend and frontend servers. It opens two command prompts, one for the frontend and one for the backend.
-To run this file, simply double click on it or open a command prompt in the same directory and type `start.bat`.
+1. Download the latest Talemate release ZIP from the [Releases page](https://github.com/vegu-ai/talemate/releases) and extract it anywhere on your system (for example, `C:\Talemate`).
+2. Double-click **`start.bat`**.
+    - On the very first run Talemate will automatically:
+        1. Download a portable build of Python 3 and Node.js (no global installs required).
+        2. Create and configure a Python virtual environment.
+        3. Install all back-end and front-end dependencies with the included *uv* and *npm*.
+        4. Build the web client.
+3. When the console window prints **"Talemate is now running"** and the logo appears, open your browser at **http://localhost:8080**.
+!!! note "First start can take a while"
+    The initial download and dependency installation may take several minutes, especially on slow internet connections. The console will keep you updated; just wait until the Talemate logo shows up.
+### Optional: CUDA support
+If you have an NVIDIA GPU and want CUDA acceleration for larger embedding models:
+1. Close Talemate (if it is running).
+2. Double-click **`install-cuda.bat`**. This script swaps the CPU-only Torch build for the CUDA 12.8 build.
+3. Start Talemate again via **`start.bat`**.
+## Maintenance & advanced usage
+| Script | Purpose |
+|--------|---------|
+| **`start.bat`** | Primary entry point; performs the initial install if needed and then starts Talemate. |
+| **`install.bat`** | Runs the installer without launching the server. Useful for automated setups or debugging. |
+| **`install-cuda.bat`** | Installs the CUDA-enabled Torch build (run after the regular install). |
+| **`update.bat`** | Pulls the latest changes from GitHub, updates dependencies, rebuilds the web client. |
+No system-wide Python or Node.js is required; Talemate uses the embedded runtimes it downloads automatically.

(New binary image assets for the 0.31.0 docs added, including docs/img/0.31.0/history.png; contents not shown.)

@@ -0,0 +1,25 @@
# KoboldCpp Embeddings
Talemate can leverage an embeddings model that is already loaded in your KoboldCpp instance.
## 1. Start KoboldCpp with an embeddings model
Launch KoboldCpp with the `--embeddingsmodel` flag so that it loads an embeddings-capable GGUF model alongside the main LLM:
```bash
koboldcpp_cu12.exe --model google_gemma-3-27b-it-Q6_K.gguf --embeddingsmodel bge-large-en-v1.5.Q8_0.gguf
```
## 2. Talemate will pick it up automatically
When Talemate starts, the **Memory** agent probes every connected client that advertises embedding support. If it detects that your KoboldCpp instance has an embeddings model loaded:
1. The Memory backend switches the current embedding to **Client API**.
2. The `client` field in the agent details shows the name of the KoboldCpp client.
3. A banner informs you that Talemate has switched to the new embedding. <!-- stub: screenshot -->
![Memory agent automatically switched to KoboldCpp embeddings](/talemate/img/0.31.0/koboldcpp-embeddings.png)
## 3. Reverting to a local embedding
Open the memory agent settings and pick a different embedding. See [Memory agent settings](/talemate/user-guide/agents/memory/settings).

@@ -5,5 +5,6 @@ nav:
   - Google Cloud: google.md
   - Groq: groq.md
   - Mistral.ai: mistral.md
+  - OpenRouter: openrouter.md
   - OpenAI: openai.md
   - ...


@@ -0,0 +1,11 @@
# OpenRouter API Setup
Talemate can use any model accessible through OpenRouter.
You need an OpenRouter API key and must set it in the application config. You can create and manage keys in your OpenRouter dashboard at [https://openrouter.ai/keys](https://openrouter.ai/keys).
Once you have generated a key, open the Talemate settings, switch to the `APPLICATION` tab, and then select the `OPENROUTER API` category. Paste your key into the **API Key** field.
![OpenRouter API settings](/talemate/img/0.31.0/openrouter-settings.png)
Finally click **Save** to store the credentials.

@@ -4,4 +4,5 @@ nav:
   - Recommended Local Models: recommended-models.md
   - Inference Presets: presets.md
   - Client Types: types
+  - Endpoint Override: endpoint-override.md
   - ...


@@ -0,0 +1,24 @@
# Endpoint Override
Starting in version 0.31.0, some of the remote clients can override the endpoint used for the API.
This is helpful when you want to point the client at a proxy gateway that serves the API instead (LiteLLM, for example).
!!! warning "Only use trusted endpoints"
    Only use endpoints that you trust, and NEVER use your actual API key with them unless you are hosting the endpoint proxy yourself.
    If you need to provide an API key, there is a separate field for that in the endpoint override settings.
## How to use
Clients that support it will have a tab in their settings that allows you to override the endpoint.
![Endpoint Override](/talemate/img/0.31.0/client-endpoint-override.png)
##### Base URL
The base URL of the endpoint. For example, `http://localhost:4000` if you're running a local LiteLLM gateway.
##### API Key
The API key to use for the endpoint. This is only required if the endpoint requires an API key. This is **NOT** the API key you would use for the official API. For LiteLLM, for example, this could be the `general_settings.master_key` value.
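As an illustration of what the two fields point at (this is not Talemate code): a LiteLLM gateway exposes an OpenAI-compatible API, so the same Base URL / API Key pair would be used by any OpenAI-style client. The URL, model name, and key below are placeholders.

```python
# Illustration only, not Talemate code. A LiteLLM gateway speaks the
# OpenAI-compatible API, so the Base URL / API Key pair from the override
# settings maps onto a standard OpenAI client like this. The key is a
# placeholder for your own LiteLLM general_settings.master_key.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-your-litellm-master-key")
response = client.chat.completions.create(
    model="gpt-4o",  # whichever model your proxy routes
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)
```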

@@ -8,6 +8,8 @@ nav:
   - Mistral.ai: mistral.md
   - OpenAI: openai.md
   - OpenAI Compatible: openai-compatible.md
+  - Ollama: ollama.md
+  - OpenRouter: openrouter.md
   - TabbyAPI: tabbyapi.md
   - Text-Generation-WebUI: text-generation-webui.md
   - ...


@@ -0,0 +1,59 @@
# Ollama Client
If you want to add an Ollama client, change the `Client Type` to `Ollama`.
![Client Ollama](/talemate/img/0.31.0/client-ollama.png)
Click `Save` to add the client.
### Ollama Server
The client should appear in the clients list. Talemate will ping the Ollama server to verify that it is running. If the server is not reachable you will see a warning.
![Client ollama offline](/talemate/img/0.31.0/client-ollama-offline.png)
Make sure that the Ollama server is running (by default at `http://localhost:11434`) and that the model you want to use has been pulled.
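If you want to double-check outside of Talemate that the server is reachable and which models have been pulled, you can query the standard Ollama HTTP API directly (a quick sketch; `requests` is just one way to do it):

```python
# Quick sanity check against the standard Ollama HTTP API (not part of
# Talemate): GET /api/tags lists the models the local server has pulled.
import requests

resp = requests.get("http://localhost:11434/api/tags", timeout=5)
resp.raise_for_status()
print([m["name"] for m in resp.json().get("models", [])])
```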
It may also show a yellow dot next to it, saying that there is no model loaded.
![Client ollama no model](/talemate/img/0.31.0/client-ollama-no-model.png)
Open the client settings by clicking the :material-cogs: icon, to select a model.
![Ollama settings](/talemate/img/0.31.0/client-ollama-select-model.png)
Click save and the client should have a green dot next to it, indicating that it is ready to go.
![Client ollama ready](/talemate/img/0.31.0/client-ollama-ready.png)
### Settings
##### Client Name
A unique name for the client that makes sense to you.
##### API URL
The base URL where the Ollama HTTP endpoint is running. Defaults to `http://localhost:11434`.
##### Model
Name of the Ollama model to use. Talemate will automatically fetch the list of models that are currently available in your local Ollama instance.
##### API handles prompt template
If enabled, Talemate will send the raw prompt and let Ollama apply its own built-in prompt template. If you are unsure, leave this disabled; Talemate's own prompt template generally produces better results.
##### Allow thinking
If enabled, Talemate will allow models that support "thinking" (`assistant:thinking` messages) to deliberate before forming the final answer. At the moment Talemate has limited support for this feature when it is handling the prompt template itself. It is probably fine to turn this on if you let Ollama handle the prompt template.
!!! tip
You can quickly refresh the list of models by making sure the Ollama server is running and then hitting **Save** again in the client settings.
### Common issues
#### Generations are weird / bad
If you are letting Talemate handle the prompt template, make sure the [correct prompt template is assigned](/talemate/user-guide/clients/prompt-templates/).


@@ -0,0 +1,48 @@
# OpenRouter Client
If you want to add an OpenRouter client, change the `Client Type` to `OpenRouter`.
![Client OpenRouter](/talemate/img/0.31.0/client-openrouter.png)
Click `Save` to add the client.
### OpenRouter API Key
The client should appear in the clients list. If you haven't set up OpenRouter before, you will see a warning that the API key is missing.
![Client openrouter no api key](/talemate/img/0.31.0/client-openrouter-no-api-key.png)
Click the `SET API KEY` button. This will open the API settings window where you can add your OpenRouter API key.
For additional instructions on obtaining and setting your OpenRouter API key, see [OpenRouter API instructions](/talemate/user-guide/apis/openrouter/).
![OpenRouter settings](/talemate/img/0.31.0/openrouter-settings.png)
Click `Save` and after a moment the client should have a red dot next to it, saying that there is no model loaded.
Click the :material-cogs: icon to open the client settings and select a model.
![OpenRouter select model](/talemate/img/0.31.0/client-openrouter-select-model.png)
Click save and the client should have a green dot next to it, indicating that it is ready to go.
### Ready to use
![Client OpenRouter Ready](/talemate/img/0.31.0/client-openrouter-ready.png)
### Settings
##### Client Name
A unique name for the client that makes sense to you.
##### Model
Choose any model available via your OpenRouter account. Talemate dynamically fetches the list of models associated with your API key so new models will show up automatically.
##### Max token length
Maximum context length (in tokens) that OpenRouter should consider. If you are not sure, leave the default value.
!!! note "Available models are fetched automatically"
Talemate fetches the list of available OpenRouter models when you save the configuration (if a valid API key is present). If you add or remove models to your account later, simply click **Save** in the application settings again to refresh the list.


@@ -2,17 +2,6 @@
 This tutorial will show you how to use the `Dynamic Storyline` module (added in `0.30`) to randomize the scene introduction for ANY scene.
-!!! note "A more streamlined approach is coming soon"
-    I am aware that some people may not want to touch the node editor at all, so a more streamlined approach is planned.
-    For now this will lay out the simplest way to set this up while still using the node editor.
-!!! learn-more "For those interested..."
-    There is tutorial on how the `Dynamic Storyline` module was made (or at least the beginnings of it).
-    If you are interested in the process, you can find it [here](/talemate/user-guide/howto/infinity-quest-dynamic).
 ## Save a foundation scene copy
 This should be a save of your scene that has had NO progress made to it yet. We are generating a new scene introduction after all.
@@ -21,59 +10,52 @@ The introduction is only generated once. So you should maintain a save-file of t
 To ensure this foundation scene save isn't overwritten you can go to the scene settings in the world editor and turn on the Locked save file flag:
-![Immutable save](./img/0008.png)
+![Immutable save](./img/0001.png)
 Save the scene.
-## Switch to the node editor
-In your scene tools find the :material-puzzle-edit: creative menu and click on the **Node Editor** option.
-![Node Editor](./img/0001.png)
-Find the `COPY AS EDITABLE MODULE FOR ..` button beneath the node editor.
-![Copy as editable module](./img/0002.png)
-Click it.
-In the next window, don't even read any of the stuff, just click **Continue**.
-## Find a blank area
-Use the mousewheel to zoom out a bit, then click the canvas and drag it to the side so you're looking at some blank space. Literally anywhere that's grey background is fine.
-Double click the empty area to bring up the module search and type "Dynamic Story" into the search.
-![Dynamic Story](./img/0003.png)
-Select the `Dynamic Storyline` node to add it to the scene.
-![Dynamic Story](./img/0004.png)
-Click the `topic` input and type in a general genre or thematic guide for the story.
-Some examples
-- `sci-fi with cosmic horror elements`
-- `dungeons and dragons campaign ideas`
-- `slice of life story ideas`
-Whatever you enter will be used to generate a list of story ideas, of which one will be chosen at random to bootstrap a new story, taking the scene context that exists already into account.
-This will NOT create new characters or world context.
-It simply bootstraps a story premise based on the random topic and what's already there.
-Once the topic is set, save the changes by clicking the node editor's **Save** button in the upper right corner.
-![Save](./img/0005.png)
-Exit the node editor through the same menu as before.
-![Exit node editor](./img/0006.png)
-Once back in the scene, if everything was done correctly you should see it working on setting the scene introduction.
-![Scene introduction](./img/0007.png)
+## Install the module
+Click the `Mods` tab in the world editor.
+![Mods Tab](./img/0002.png)
+Find the `Dynamic Storyline` module and click **Install**.
+It will say installed (not configured)
+![Installed (not configured)](./img/0003.png)
+Click **Configure** and set topic to something like `Sci-fi adventure with lovecraftian horror`.
+![Configure Module](./img/0004.png)
+!!! note "Optional settings"
+    ##### Max intro text length
+    How many tokens to generate for the intro text.
+    ##### Additional instructions for topic analysis task
+    If topic analysis is enabled, this will be used to augment the topic analysis task with further instructions
+    ##### Enable topic analysis
+    This will enable the topic analysis task
+**Save** the module configuration.
+Finally click "Reload Scene" in the left sidebar.
+![Reload Scene](./img/0007.png)
+If everything is configured correctly, the storyline generation will begin immediately.
+![Dynamic Storyline Module Configured](./img/0005.png)
+!!! note "Switch out of edit mode"
+    If nothing is happening after configuration and reloading the scene, make sure you are not in edit mode.
+    You can leave edit mode by clicking the "Exit Node Editor" button in the creative menu.
+![Exit Node Editor](./img/0006.png)


@@ -0,0 +1,105 @@
# Installable Packages
It is possible to "package" your node modules so they can be installed into a scene.
This allows for easier, controlled setup and makes your node module more shareable, as users no longer need to use the node editor to install it.
Installable packages show up in the Mods list once a scene is loaded.
![Mods List](../img/package-0003.png)
## 1. Create a package module
To create a package - click the **:material-plus: Create Module** button in the node editor and select **Package**.
![Create Package Module](../img/package-0001.png)
The package itself is a special kind of node module that will let Talemate know that your node module is installable and how to install it.
## 2. Open the module properties
With the package module open find the module properties in the upper left corner of the node editor.
![Module Properties](../img/package-0002.png)
Fill in the fields:
##### The name of the node module
This is what the module package will be called in the Mods list.
Example:
```
Dynamic Storyline
```
##### The author of the node module
Your name or handle. This is arbitrary and just lets people know who made the package.
##### The description of the node module
A short description of the package. This is displayed in the Mods list.
Example:
```
Generate a random story premise at the beginning of the scene.
```
##### Whether the node module is installable to the scene
A checkbox to indicate if the package is installable to the scene.
Right now this should always be checked; there are no other package types currently.
##### Whether the scene loop should be restarted when the package is installed
If checked, installing this package will restart the scene loop. This is mostly important for modules that need to hook into the scene loop init event.
## 3. Install instructions
Currently there are only two nodes relevant for the node module install process.
1. `Install Node Module` - this node is used to make sure the target node module is added to the scene loop when installing the package. You can have more than one of these nodes in your package.
1. `Promote Config` - this node is used to promote your node module's properties to configurable fields in the mods list. E.g., this dictates what the user can configure when installing the package.
### Install Node Module
![Install Node Module](../img/package-0004.png)
!!! payload "Install Node Module"
| Property | Value |
|----------|-------|
| node_registry | the registry path of the node module to install |
### Promote Config
![Promote Config](../img/package-0005.png)
!!! payload "Promote Config"
| Property | Value |
|----------|-------|
| node_registry | the registry path of the node module |
| property_name | the name of the property to promote (as it is set in the node module) |
| exposed_property_name | expose as this name in the mods list, this can be the same as the property name or a different name - this is important if you're installing multiple node modules with the same property name, so you can differentiate between them |
| required | whether the property is required to be set when installing the package |
| label | a user friendly label for the property |
### Make Talemate aware of the package
For Talemate to be aware of the package, you need to copy it to the public node module directory, which exists as `templates/modules/`.
Create a new subdirectory:
```
./templates/modules/<your-package-name>/
```
Copy the package module and your node module files into the directory.
Restart Talemate and the package should now show up in the Mods list.


@@ -1,14 +1,63 @@
 # History
-Will hold the archived history for the scene.
-This historical archive is extended everytime the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) summarizes the scene.
-Summarization happens when a certain progress treshold is reached (tokens) or when a time passage event occurs.
-All archived history is a potential candidate to be included in the context sent to the AI based on relevancy. This is handled by the [Memory Agent](/talemate/user-guide/agents/memory/).
-You can use the **:material-refresh: Regenerate History** button to force a new summarization of the scene.
-!!! warning
-    If there has been lots of progress this will potentially take a long time to complete.
+Will hold historical events for the scene.
+There are three types of historical entries:
+- **Archived history (static)** - These entries are manually defined and dated before the starting point of the scene. Things that happened in the past that are **IMPORTANT** for the understanding of the world. For anything that is not **VITAL**, use world entries instead.
+- **Archived history (from summary)** - These are historical entries generated from the progress in the scene. Whenever a certain token (length) threshold is reached, the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) will generate a summary of the progression and add it to the history.
+- **Layered history (from summary)** - As summary archives are generated, they themselves will be summarized again and added to the history, leading to a natural compression of the history sent with the context while also keeping track of the most important events. (hopefully)
+![History](/talemate/img/0.31.0/history.png)
+## Layers
+There is always the **BASE** layer, which is where the archived history (both static and from summary) is stored. For all intents and purposes, this is layer 0.
+At the beginning of a scene, there won't be any additional layers, as any layer past layer 0 will come from summarization down the line.
+Note that layered history is managed by the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) and can be disabled in its settings.
+### Managing entries
+- **All entries** can be edited by double-clicking the text.
+- **Static entries** can be deleted by clicking the **:material-close-box-outline: Delete** button.
+- **Summarized entries** can be regenerated by clicking the **:material-refresh: Regenerate** button. This will cause the LLM to re-summarize the entry and update the text.
+- **Summarized entries** can be inspected by clicking the **:material-magnify-expand: Inspect** button. This will expand the entry and show the source entries that were used to generate the summary.
+### Adding static entries
+Static entries can be added by clicking the **:material-plus: Add Entry** button.
+!!! note "Static entries must be older than any summary entries"
+    Static entries must be older than any summary entries. This is to ensure that the history is always chronological.
+    Trying to add a static entry that is more recent than any summary entry will result in an error.
+##### Entry Text
+The text of the entry. Should be at most 1 - 2 paragraphs. Less is more. Anything that needs great detail should be a world entry instead.
+##### Unit
+Defines the duration unit of the entry. So minutes, hours, days, weeks, months or years.
+##### Amount
+Defines the duration unit amount of the entry.
+So if you want to define something that happened 10 months ago (from the current moment in the scene), you would set the unit to months and the amount to 10.
+![Add Static Entry](/talemate/img/0.31.0/history-add-entry.png)
+## Regenerate everything
+It is possible to regenerate the entire history by clicking the **:material-refresh: Regenerate All History** button in the left sidebar. Static entries will remain unchanged.
+![Regenerate All History](/talemate/img/0.31.0/history-regenerate-all.png)
+!!! warning "This can take a long time"
+    This will go through the entire scene progress and regenerate all summarized entries.
+    If you have a lot of progress, be ready to wait for a while.

@@ -1,8 +1,23 @@
-REM activate the virtual environment
-call talemate_env\Scripts\activate
+@echo off
-REM uninstall torch and torchaudio
-python -m pip uninstall torch torchaudio -y
+REM Check if .venv exists
+IF NOT EXIST ".venv" (
+    echo [ERROR] .venv directory not found. Please run install.bat first.
+    goto :eof
+)
-REM install torch and torchaudio
-python -m pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
+REM Check if embedded Python exists
+IF NOT EXIST "embedded_python\python.exe" (
+    echo [ERROR] embedded_python not found. Please run install.bat first.
+    goto :eof
+)
+REM uninstall torch and torchaudio using embedded Python's uv
+embedded_python\python.exe -m uv pip uninstall torch torchaudio --python .venv\Scripts\python.exe
+REM install torch and torchaudio with CUDA support using embedded Python's uv
+embedded_python\python.exe -m uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128 --python .venv\Scripts\python.exe
+echo.
+echo CUDA versions of torch and torchaudio installed!
+echo You may need to restart your application for changes to take effect.

@@ -1,10 +1,7 @@
 #!/bin/bash
-# activate the virtual environment
-source talemate_env/bin/activate
 # uninstall torch and torchaudio
-python -m pip uninstall torch torchaudio -y
+uv pip uninstall torch torchaudio
-# install torch and torchaudio
-python -m pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
+# install torch and torchaudio with CUDA support
+uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128

@@ -1,4 +0,0 @@
-REM activate the virtual environment
-call talemate_env\Scripts\activate
-call pip install "TTS>=0.21.1"

install-utils/get-pip.py (new file, 28,579 lines; diff suppressed because it is too large)

View File

@@ -1,65 +1,227 @@
@echo off
REM Check for Python version and use a supported version if available
SET PYTHON=python
python -c "import sys; sys.exit(0 if sys.version_info[:2] in [(3, 10), (3, 11), (3, 12), (3, 13)] else 1)" 2>nul
IF NOT ERRORLEVEL 1 (
echo Selected Python version: %PYTHON%
GOTO EndVersionCheck
)
SET PYTHON=python
FOR /F "tokens=*" %%i IN ('py --list') DO (
echo %%i | findstr /C:"-V:3.11 " >nul && SET PYTHON=py -3.11 && GOTO EndPythonCheck
echo %%i | findstr /C:"-V:3.10 " >nul && SET PYTHON=py -3.10 && GOTO EndPythonCheck
)
:EndPythonCheck
%PYTHON% -c "import sys; sys.exit(0 if sys.version_info[:2] in [(3, 10), (3, 11), (3, 12), (3, 13)] else 1)" 2>nul
IF ERRORLEVEL 1 (
echo Unsupported Python version. Please install Python 3.10 or 3.11.
exit /b 1
)
IF "%PYTHON%"=="python" (
echo Default Python version is being used: %PYTHON%
) ELSE (
echo Selected Python version: %PYTHON%
)
:EndVersionCheck
IF ERRORLEVEL 1 (
echo Unsupported Python version. Please install Python 3.10 or 3.11.
exit /b 1
)
REM create a virtual environment
%PYTHON% -m venv talemate_env
REM activate the virtual environment
call talemate_env\Scripts\activate
REM upgrade pip and setuptools
python -m pip install --upgrade pip setuptools
REM install poetry
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
REM use poetry to install dependencies
python -m poetry install
REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
IF NOT EXIST config.yaml copy config.example.yaml config.yaml
REM navigate to the frontend directory
echo Installing frontend dependencies...
cd talemate_frontend
call npm install
echo Building frontend...
call npm run build
REM return to the root directory
cd ..
echo Installation completed successfully.
pause

REM ===============================
REM Talemate project installer
REM ===============================
REM 1. Detect CPU architecture and pick the best-fitting embedded Python build.
REM 2. Download & extract that build into .\embedded_python\
REM 3. Bootstrap pip via install-utils\get-pip.py
REM 4. Install virtualenv and create .\talemate_env\ using the embedded Python.
REM 5. Activate the venv and proceed with Poetry + frontend installation.
REM ---------------------------------------------------------------
SETLOCAL ENABLEDELAYEDEXPANSION
REM Define fatal-error handler
REM Usage: CALL :die "Message explaining what failed"
goto :after_die
:die
echo.
echo ============================================================
echo !!! INSTALL FAILED !!!
echo %*
echo ============================================================
pause
exit 1
:after_die
REM ---------[ Check Prerequisites ]---------
ECHO Checking prerequisites...
where tar >nul 2>&1 || CALL :die "tar command not found. Please ensure Windows 10 version 1803+ or install tar manually."
where curl >nul 2>&1
IF %ERRORLEVEL% NEQ 0 (
where bitsadmin >nul 2>&1 || CALL :die "Neither curl nor bitsadmin found. Cannot download files."
)
REM ---------[ Remove legacy Poetry venv if present ]---------
IF EXIST "talemate_env" (
ECHO Detected legacy Poetry virtual environment 'talemate_env'. Removing...
RD /S /Q "talemate_env"
IF ERRORLEVEL 1 (
ECHO [WARNING] Failed to fully remove legacy 'talemate_env' directory. Continuing installation.
)
)
REM ---------[ Clean reinstall check ]---------
SET "NEED_CLEAN=0"
IF EXIST ".venv" SET "NEED_CLEAN=1"
IF EXIST "embedded_python" SET "NEED_CLEAN=1"
IF EXIST "embedded_node" SET "NEED_CLEAN=1"
IF "%NEED_CLEAN%"=="1" (
ECHO.
ECHO Detected existing Talemate environments.
REM Prompt user (empty input defaults to Y)
SET "ANSWER=Y"
SET /P "ANSWER=Perform a clean reinstall of the python and node.js environments? [Y/n] "
IF /I "!ANSWER!"=="N" (
ECHO Installation aborted by user.
GOTO :EOF
)
ECHO Removing previous installation...
IF EXIST ".venv" RD /S /Q ".venv"
IF EXIST "embedded_python" RD /S /Q "embedded_python"
IF EXIST "embedded_node" RD /S /Q "embedded_node"
ECHO Cleanup complete.
)
REM ---------[ Version configuration ]---------
SET "PYTHON_VERSION=3.11.9"
SET "NODE_VERSION=22.16.0"
REM ---------[ Detect architecture & choose download URL ]---------
REM Prefer PROCESSOR_ARCHITEW6432 when the script is run from a 32-bit shell on 64-bit Windows
IF DEFINED PROCESSOR_ARCHITEW6432 (
SET "ARCH=%PROCESSOR_ARCHITEW6432%"
) ELSE (
SET "ARCH=%PROCESSOR_ARCHITECTURE%"
)
REM Map architecture to download URL
IF /I "%ARCH%"=="AMD64" (
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x64.zip"
) ELSE IF /I "%ARCH%"=="IA64" (
REM Itanium systems are rare, but AMD64 build works with WoW64 layer
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x64.zip"
) ELSE IF /I "%ARCH%"=="ARM64" (
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-arm64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-arm64.zip"
) ELSE (
REM Fallback to 64-bit build for x86 / unknown architectures
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x86.zip"
)
ECHO Detected architecture: %ARCH%
ECHO Downloading embedded Python from: %PY_URL%
REM ---------[ Download ]---------
SET "PY_ZIP=python_embed.zip"
where curl >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Using curl to download Python...
curl -L -# -o "%PY_ZIP%" "%PY_URL%" || CALL :die "Failed to download Python embed package with curl."
) ELSE (
ECHO curl not found, falling back to bitsadmin...
bitsadmin /transfer "DownloadPython" /download /priority normal "%PY_URL%" "%CD%\%PY_ZIP%" || CALL :die "Failed to download Python embed package (curl & bitsadmin unavailable)."
)
REM ---------[ Extract ]---------
SET "PY_DIR=embedded_python"
IF EXIST "%PY_DIR%" RD /S /Q "%PY_DIR%"
mkdir "%PY_DIR%" || CALL :die "Could not create directory %PY_DIR%."
where tar >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Extracting with tar...
tar -xf "%PY_ZIP%" -C "%PY_DIR%" || CALL :die "Failed to extract Python embed package with tar."
) ELSE (
CALL :die "tar utility not found (required to unpack zip without PowerShell)."
)
DEL /F /Q "%PY_ZIP%"
SET "PYTHON=%PY_DIR%\python.exe"
ECHO Using embedded Python at %PYTHON%
REM ---------[ Enable site-packages in embedded Python ]---------
FOR %%f IN ("%PY_DIR%\python*._pth") DO (
ECHO Adding 'import site' to %%~nxf ...
echo import site>>"%%~ff"
)
REM ---------[ Ensure pip ]---------
ECHO Installing pip...
"%PYTHON%" install-utils\get-pip.py || (
CALL :die "pip installation failed."
)
REM Upgrade pip to latest
"%PYTHON%" -m pip install --no-warn-script-location --upgrade pip || CALL :die "Failed to upgrade pip in embedded Python."
REM ---------[ Install uv ]---------
ECHO Installing uv...
"%PYTHON%" -m pip install uv || (
CALL :die "uv installation failed."
)
REM ---------[ Create virtual environment with uv ]---------
ECHO Creating virtual environment with uv...
"%PYTHON%" -m uv venv || (
CALL :die "Virtual environment creation failed."
)
REM ---------[ Install dependencies using embedded Python's uv ]---------
ECHO Installing backend dependencies with uv...
"%PYTHON%" -m uv sync || CALL :die "Failed to install backend dependencies with uv."
REM Activate the venv for the remainder of the script
CALL .venv\Scripts\activate
REM echo python version
python --version
REM ---------[ Config file ]---------
IF NOT EXIST config.yaml COPY config.example.yaml config.yaml
REM ---------[ Node.js portable runtime ]---------
ECHO.
ECHO Downloading portable Node.js runtime...
REM Node download variables already set earlier based on %ARCH%.
ECHO Downloading Node.js from: %NODE_URL%
SET "NODE_ZIP=node_embed.zip"
where curl >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Using curl to download Node.js...
curl -L -# -o "%NODE_ZIP%" "%NODE_URL%" || CALL :die "Failed to download Node.js package with curl."
) ELSE (
ECHO curl not found, falling back to bitsadmin...
bitsadmin /transfer "DownloadNode" /download /priority normal "%NODE_URL%" "%CD%\%NODE_ZIP%" || CALL :die "Failed to download Node.js package (curl & bitsadmin unavailable)."
)
REM ---------[ Extract Node.js ]---------
SET "NODE_DIR=embedded_node"
IF EXIST "%NODE_DIR%" RD /S /Q "%NODE_DIR%"
mkdir "%NODE_DIR%" || CALL :die "Could not create directory %NODE_DIR%."
where tar >nul 2>&1
IF %ERRORLEVEL% EQU 0 (
ECHO Extracting Node.js...
tar -xf "%NODE_ZIP%" -C "%NODE_DIR%" --strip-components 1 || CALL :die "Failed to extract Node.js package with tar."
) ELSE (
CALL :die "tar utility not found (required to unpack zip without PowerShell)."
)
DEL /F /Q "%NODE_ZIP%"
REM Prepend Node.js folder to PATH so npm & node are available
SET "PATH=%CD%\%NODE_DIR%;%PATH%"
ECHO Using portable Node.js at %CD%\%NODE_DIR%\node.exe
ECHO Node.js version:
node -v
REM ---------[ Frontend ]---------
ECHO Installing frontend dependencies...
CD talemate_frontend
CALL npm install || CALL :die "npm install failed."
ECHO Building frontend...
CALL npm run build || CALL :die "Frontend build failed."
REM Return to repo root
CD ..
ECHO.
ECHO ==============================
ECHO Installation completed!
ECHO ==============================
PAUSE
ENDLOCAL


@@ -1,20 +1,16 @@
#!/bin/bash
# create a virtual environment
echo "Creating a virtual environment..."
python3 -m venv talemate_env
# create a virtual environment with uv
echo "Creating a virtual environment with uv..."
uv venv
# activate the virtual environment
echo "Activating the virtual environment..."
source talemate_env/bin/activate
source .venv/bin/activate
# install poetry
echo "Installing poetry..."
pip install poetry
# use poetry to install dependencies
# install dependencies with uv
echo "Installing dependencies..."
poetry install
uv pip install -e ".[dev]"
# copy config.example.yaml to config.yaml only if config.yaml doesn't exist
if [ ! -f config.yaml ]; then

poetry.lock (generated file, 6554 lines): diff suppressed because it is too large.


@@ -1,77 +1,82 @@
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
[tool.poetry]
[project]
name = "talemate"
version = "0.30.0"
version = "0.31.0"
description = "AI-backed roleplay and narrative tools"
authors = ["VeguAITools"]
authors = [{name = "VeguAITools"}]
license = "GNU Affero General Public License v3.0"
license = {text = "GNU Affero General Public License v3.0"}
requires-python = ">=3.10,<3.14"
dependencies = [
"astroid>=2.8",
"jedi>=0.18",
"black",
"rope>=0.22",
"isort>=5.10",
"jinja2>=3.0",
"openai>=1",
"mistralai>=0.1.8",
"cohere>=5.2.2",
"anthropic>=0.19.1",
"groq>=0.5.0",
"requests>=2.26",
"colorama>=0.4.6",
"Pillow>=9.5",
"httpx<1",
"piexif>=1.1",
"typing-inspect==0.8.0",
"typing_extensions>=4.5.0",
"uvicorn>=0.23",
"blinker>=1.6.2",
"pydantic<3",
"beautifulsoup4>=4.12.2",
"python-dotenv>=1.0.0",
"structlog>=23.1.0",
# 1.7.11 breaks subprocess stuff ???
"runpod==1.7.10",
"google-genai>=1.20.0",
"nest_asyncio>=1.5.7",
"isodate>=0.6.1",
"thefuzz>=0.20.0",
"tiktoken>=0.5.1",
"nltk>=3.8.1",
"huggingface-hub>=0.20.2",
"RestrictedPython>7.1",
"numpy>=2",
"aiofiles>=24.1.0",
"pyyaml>=6.0",
"limits>=5.0",
"diff-match-patch>=20241021",
"sseclient-py>=1.8.0",
"ollama>=0.5.1",
# ChromaDB
"chromadb>=1.0.12",
"InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
"torch>=2.7.0",
"torchaudio>=2.7.0",
# locked for instructor embeddings
#sentence-transformers==2.2.2
"sentence_transformers>=2.7.0",
]
[tool.poetry.dependencies]
python = ">=3.10,<3.14"
astroid = "^2.8"
jedi = "^0.18"
black = "*"
rope = "^0.22"
isort = "^5.10"
jinja2 = ">=3.0"
openai = ">=1"
mistralai = ">=0.1.8"
cohere = ">=5.2.2"
anthropic = ">=0.19.1"
groq = ">=0.5.0"
requests = "^2.26"
colorama = ">=0.4.6"
Pillow = ">=9.5"
httpx = "<1"
piexif = "^1.1"
typing-inspect = "0.8.0"
typing_extensions = "^4.5.0"
uvicorn = "^0.23"
blinker = "^1.6.2"
pydantic = "<3"
beautifulsoup4 = "^4.12.2"
python-dotenv = "^1.0.0"
websockets = "^11.0.3"
structlog = "^23.1.0"
runpod = "^1.2.0"
google-cloud-aiplatform = ">=1.50.0"
nest_asyncio = "^1.5.7"
isodate = ">=0.6.1"
thefuzz = ">=0.20.0"
tiktoken = ">=0.5.1"
nltk = ">=3.8.1"
huggingface-hub = ">=0.20.2"
RestrictedPython = ">7.1"
numpy = "^2"
aiofiles = ">=24.1.0"
pyyaml = ">=6.0"
limits = ">=5.0"
diff-match-patch = ">=20241021"
sseclient-py = "^1.8.0"
# ChromaDB
chromadb = ">=0.4.17,<1"
InstructorEmbedding = "^1.0.1"
torch = "^2.7.0"
torchaudio = "^2.7.0"
# locked for instructor embeddings
#sentence-transformers="==2.2.2"
sentence_transformers=">=2.7.0"
[tool.poetry.dev-dependencies]
pytest = ">=6.2"
pytest-asyncio = ">=0.25.3"
mypy = "^0.910"
mkdocs-material = ">=9.5.27"
mkdocs-awesome-pages-plugin = ">=2.9.2"
mkdocs-glightbox = ">=0.4.0"
[tool.poetry.scripts]

[project.optional-dependencies]
dev = [
"pytest>=6.2",
"pytest-asyncio>=0.25.3",
"mypy>=0.910",
"mkdocs-material>=9.5.27",
"mkdocs-awesome-pages-plugin>=2.9.2",
"mkdocs-glightbox>=0.4.0",
]
[project.scripts]
talemate = "talemate:cli.main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.black]
line-length = 88
target-version = ['py38']
@@ -87,6 +92,7 @@ exclude = '''
| buck-out
| build
| dist
| talemate_env
)/
'''


@@ -1,18 +0,0 @@
@echo off
IF EXIST talemate_env rmdir /s /q "talemate_env"
REM create a virtual environment
python -m venv talemate_env
REM activate the virtual environment
call talemate_env\Scripts\activate
REM install poetry
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
REM use poetry to install dependencies
python -m poetry install
echo Virtual environment re-created.
pause


@@ -0,0 +1,65 @@
{
"packages": [
{
"name": "Dynamic Storyline",
"author": "Talemate",
"description": "Generate a random story premise at the beginning of the scene.",
"installable": true,
"registry": "package/talemate/DynamicStoryline",
"status": "installed",
"errors": [],
"package_properties": {
"topic": {
"module": "scene/dynamicStoryline",
"name": "topic",
"label": "Topic",
"description": "The overarching topic - will be used to generate a theme that falls within this category. Example - 'Sci-fi adventure with cosmic horror'.",
"type": "str",
"default": "",
"value": "Sci-fi episodic adventures onboard of a spaceship, with focus on AI, alien contact, ancient creators and cosmic horror.",
"required": true,
"choices": []
},
"intro_length": {
"module": "scene/dynamicStoryline",
"name": "intro_length",
"label": "Max. intro text length (tokens)",
"description": "Length of the introduction",
"type": "int",
"default": 512,
"value": 512,
"required": true,
"choices": []
},
"analysis_instructions": {
"module": "scene/dynamicStoryline",
"name": "analysis_instructions",
"label": "Additional instructions for topic analysis task",
"description": "Additional instructions for topic analysis task - if topic analysis is enabled, this will be used to augment the topic analysis task with further instructions.",
"type": "text",
"default": "",
"value": "",
"required": false,
"choices": []
},
"analysis_enabled": {
"module": "scene/dynamicStoryline",
"name": "analysis_enabled",
"label": "Enable topic analysis",
"description": "Theme analysis",
"type": "bool",
"default": true,
"value": true,
"required": false,
"choices": []
}
},
"install_nodes": [
"scene/dynamicStoryline"
],
"installed_nodes": [],
"restart_scene_loop": true,
"configured": true
}
]
}


@@ -1,6 +1,6 @@
{
"title": "Scene Loop",
"id": "71652a76-5db3-4836-8f00-1085977cd8e8",
"id": "af468414-b30d-4f67-b08e-5b7cfd139adc",
"properties": {
"trigger_game_loop": true
},
@@ -11,50 +11,10 @@
"collapsed": false, "collapsed": false,
"inherited": false, "inherited": false,
"registry": "scene/SceneLoop", "registry": "scene/SceneLoop",
"nodes": { "nodes": {},
"ede29db4-700d-4edc-b93b-bf7c79f6a6a5": {
"title": "Dynamic Storyline",
"id": "ede29db4-700d-4edc-b93b-bf7c79f6a6a5",
"properties": {
"event_name": "scene_loop_init",
"analysis_instructions": "",
"reset": false,
"topic": "sci-fi with cosmic horror elements",
"analysis_enabled": true,
"intro_length": 512
},
"x": 32,
"y": -249,
"width": 295,
"height": 158,
"collapsed": false,
"inherited": false,
"registry": "scene/dynamicStoryline",
"base_type": "core/Event"
}
},
"edges": {},
"groups": [
{
"title": "Randomize Story",
"x": 8,
"y": -321,
"width": 619,
"height": 257,
"color": "#a1309b",
"font_size": 24,
"inherited": false
}
],
"comments": [
{
"text": "Will generate a randomized story line based on the topic given",
"x": 352,
"y": -269,
"width": 215,
"inherited": false
}
],
"groups": [],
"comments": [],
"extends": "src/talemate/game/engine/nodes/modules/scene/scene-loop.json",
"sleep": 0.001,
"base_type": "scene/SceneLoop",


@@ -19,6 +19,7 @@ from talemate.agents.context import ActiveAgent, active_agent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
from talemate.context import active_scene
import talemate.config as config
from talemate.client.context import (
ClientContext,
set_client_context_attribute,
@@ -438,6 +439,29 @@ class Agent(ABC):
except AttributeError:
pass
async def save_config(self, app_config: config.Config | None = None):
"""
Saves the agent config to the config file.
If no config object is provided, the config is loaded from the config file.
"""
if not app_config:
app_config:config.Config = config.load_config(as_model=True)
app_config.agents[self.agent_type] = config.Agent(
name=self.agent_type,
client=self.client.name if self.client else None,
enabled=self.enabled,
actions={action_key: config.AgentAction(
enabled=action.enabled,
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
) for action_key, action in self.actions.items()}
)
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
config.save_config(app_config)
async def on_game_loop_start(self, event: GameLoopStartEvent):
"""
Finds all ActionConfigs that have a scope of "scene" and resets them to their default values


@@ -233,10 +233,20 @@ class ConversationAgent(
def generation_settings_actor_instructions_offset(self):
return self.actions["generation_override"].config["actor_instructions_offset"].value
@property
def generation_settings_response_length(self):
return self.actions["generation_override"].config["length"].value
@property
def generation_settings_override_enabled(self):
return self.actions["generation_override"].enabled
@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value
def connect(self, scene):
super().connect(scene)
@@ -322,6 +332,7 @@ class ConversationAgent(
"actor_instructions_offset": self.generation_settings_actor_instructions_offset, "actor_instructions_offset": self.generation_settings_actor_instructions_offset,
"direct_instruction": instruction, "direct_instruction": instruction,
"decensor": self.client.decensor_enabled, "decensor": self.client.decensor_enabled,
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
},
)


@@ -3,6 +3,7 @@ import random
import uuid
from typing import TYPE_CHECKING, Tuple
import dataclasses
import traceback
import pydantic
import structlog
@@ -328,7 +329,14 @@ class AssistantMixin:
if not content.startswith(generation_context.character + ":"):
content = generation_context.character + ": " + content
content = util.strip_partial_sentences(content)
emission.response = await editor.cleanup_character_message(content, generation_context.character.name)
character = self.scene.get_character(generation_context.character)
if not character:
log.warning("Character not found", character=generation_context.character)
return content
emission.response = await editor.cleanup_character_message(content, character)
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response
@@ -447,6 +455,7 @@ class AssistantMixin:
)
continuing_message = False
message = None
try:
message = self.scene.history[-1]
@@ -470,6 +479,7 @@ class AssistantMixin:
"can_coerce": self.client.can_be_coerced, "can_coerce": self.client.can_be_coerced,
"response_length": response_length, "response_length": response_length,
"continuing_message": continuing_message, "continuing_message": continuing_message,
"message": message,
"anchor": anchor, "anchor": anchor,
"non_anchor": non_anchor, "non_anchor": non_anchor,
"prefix": prefix, "prefix": prefix,
@@ -675,7 +685,7 @@ class AssistantMixin:
emit("status", f"Scene forked", status="success")
except Exception as e:
log.exception("Scene fork failed", exc=e)
log.error("Scene fork failed", exc=traceback.format_exc())
emit("status", "Scene fork failed", status="error")


@@ -4,6 +4,7 @@ import random
from typing import TYPE_CHECKING, List
import structlog
import traceback
import talemate.emit.async_signals
import talemate.instance as instance
@@ -259,7 +260,7 @@ class DirectorAgent(
except Exception as e:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.exception("Error persisting character", error=e)
log.error("Error persisting character", error=traceback.format_exc())
async def log_action(self, action: str, action_description: str):
message = DirectorMessage(message=action_description, action=action)


@@ -1,6 +1,7 @@
import pydantic
import asyncio
import structlog
import traceback
from typing import TYPE_CHECKING
from talemate.instance import get_agent
@@ -98,7 +99,7 @@ class DirectorWebsocketHandler(Plugin):
task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
async def handle_task_done(task):
if task.exception():
log.exception("Error persisting character", error=task.exception())
log.error("Error persisting character", error=task.exception())
await self.signal_operation_failed("Error persisting character")
else:
self.websocket_handler.queue_put(


@@ -54,7 +54,7 @@ class EditorAgent(
type="text",
label="Formatting",
description="The formatting to use for exposition.",
value="chat",
value="novel",
choices=[
{"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
{"label": "Novel: \"Speech\" narration", "value": "novel"},


@@ -27,6 +27,7 @@ from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.agents.creator.assistant import ContextualGenerateEmission
from talemate.agents.summarize import SummarizeEmission
from talemate.agents.summarize.layered_history import LayeredHistoryFinalizeEmission
from talemate.scene_message import CharacterMessage
from talemate.util.dedupe import (
dedupe_sentences,
@@ -387,13 +388,16 @@ class RevisionMixin:
async_signals.get("agent.summarization.summarize.after").connect(
self.revision_on_generation
)
async_signals.get("agent.summarization.layered_history.finalize").connect(
self.revision_on_generation
)
# connect to the super class AFTER so these run first.
super().connect(scene)
async def revision_on_generation(
self,
emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission,
emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission | LayeredHistoryFinalizeEmission,
):
"""
Called when a conversation or narrator message is generated
@@ -411,7 +415,15 @@ class RevisionMixin:
if isinstance(emission, NarratorAgentEmission) and "narrator" not in self.revision_automatic_targets:
return
if isinstance(emission, SummarizeEmission) and "summarization" not in self.revision_automatic_targets:
if isinstance(emission, SummarizeEmission):
if emission.summarization_type == "dialogue" and "summarization" not in self.revision_automatic_targets:
return
if emission.summarization_type == "events":
# event summarization is very pragmatic and doesn't really benefit
# from revision, so we skip it
return
if isinstance(emission, LayeredHistoryFinalizeEmission) and "summarization" not in self.revision_automatic_targets:
return
try:
@@ -428,7 +440,7 @@ class RevisionMixin:
context_name = getattr(emission, "context_name", None),
)
if isinstance(emission, SummarizeEmission):
if isinstance(emission, (SummarizeEmission, LayeredHistoryFinalizeEmission)):
info.summarization_history = emission.summarization_history or []
if isinstance(emission, ContextualGenerateEmission) and info.context_type not in CONTEXTUAL_GENERATION_TYPES:
@@ -489,7 +501,8 @@ class RevisionMixin:
log.warning("revision_revise: generation cancelled", text=info.text)
return info.text
except Exception as e:
log.exception("revision_revise: error", error=e)
import traceback
log.error("revision_revise: error", error=traceback.format_exc())
return info.text
finally:
info.loading_status.done()
@@ -872,7 +885,13 @@ class RevisionMixin:
if loading_status:
loading_status("Editor - Issues identified, analyzing text...")
template_vars = {
emission = RevisionEmission(
agent=self,
info=info,
issues=issues,
)
emission.template_vars = {
"text": text,
"character": character,
"scene": self.scene,
@@ -880,14 +899,11 @@ class RevisionMixin:
"max_tokens": self.client.max_token_length, "max_tokens": self.client.max_token_length,
"repetition": issues.repetition, "repetition": issues.repetition,
"bad_prose": issues.bad_prose, "bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
}
emission = RevisionEmission(
agent=self,
template_vars=template_vars,
info=info,
issues=issues,
)
await async_signals.get("agent.editor.revision-revise.before").send(
emission
@@ -898,18 +914,7 @@ class RevisionMixin:
"editor.revision-analysis", "editor.revision-analysis",
self.client,
f"edit_768",
vars={
vars=emission.template_vars,
"text": text,
"character": character,
"scene": self.scene,
"response_length": token_count,
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
},
dedupe_enabled=False,
)
@@ -1016,39 +1021,43 @@ class RevisionMixin:
log.debug("revision_unslop: issues", issues=issues, template=template)
emission = RevisionEmission(
agent=self,
info=info,
issues=issues,
)
emission.template_vars = {
"text": text,
"scene_analysis": scene_analysis,
"character": character,
"scene": self.scene,
"response_length": response_length,
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
"summarization_history": info.summarization_history,
}
await async_signals.get("agent.editor.revision-revise.before").send(emission)
response = await Prompt.request(
template,
self.client,
"edit_768",
vars={
vars=emission.template_vars,
"text": text,
"scene_analysis": scene_analysis,
"character": character,
"scene": self.scene,
"response_length": response_length,
"max_tokens": self.client.max_token_length,
"repetition": issues.repetition,
"bad_prose": issues.bad_prose,
"dynamic_instructions": emission.dynamic_instructions,
"context_type": info.context_type,
"context_name": info.context_name,
"summarization_history": info.summarization_history,
},
dedupe_enabled=False,
)
# extract <FIX>...</FIX>
if "<FIX>" not in response:
log.error("revision_unslop: no <FIX> found in response", response=response)
log.debug("revision_unslop: no <FIX> found in response", response=response)
return original_text
fix = response.split("<FIX>", 1)[1]


@@ -1,9 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import asyncio
import functools
import hashlib
import uuid
import traceback
import numpy as np
from typing import Callable
@@ -12,6 +14,8 @@ from chromadb.config import Settings
import talemate.events as events
import talemate.util as util
from talemate.client import ClientBase
import talemate.instance as instance
from talemate.agents.base import (
Agent,
AgentAction,
@@ -23,6 +27,7 @@ from talemate.config import load_config
from talemate.context import scene_is_loading, active_scene
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.emit.async_signals as async_signals
from talemate.agents.memory.context import memory_request, MemoryRequest
from talemate.agents.memory.exceptions import (
EmbeddingsModelLoadError,
@@ -31,19 +36,23 @@ from talemate.agents.memory.exceptions import (
try:
import chromadb
import chromadb.errors
from chromadb.utils import embedding_functions
except ImportError:
chromadb = None
pass
from talemate.agents.registry import register
if TYPE_CHECKING:
from talemate.client.base import ClientEmbeddingsStatus
log = structlog.get_logger("talemate.agents.memory")
if not chromadb:
log.info("ChromaDB not found, disabling Chroma agent")
from talemate.agents.registry import register
class MemoryDocument(str):
def __new__(cls, text, meta, id, raw):
inst = super().__new__(cls, text)
@@ -107,6 +116,7 @@ class MemoryAgent(Agent):
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)
self.actions = MemoryAgent.init_actions(presets=self.get_presets)
@@ -125,8 +135,16 @@ class MemoryAgent(Agent):
@property
def get_presets(self):
def _label(embedding:dict):
prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
if embedding['model']:
return f"{prefix}: {embedding['model']}"
else:
return f"{prefix}"
return [
{"value": k, "label": f"{v['embeddings']}: {v['model']}"} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
{"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
]
@property
@@ -150,6 +168,10 @@ class MemoryAgent(Agent):
def using_sentence_transformer_embeddings(self):
return self.embeddings == "default" or self.embeddings == "sentence-transformer"
@property
def using_client_api_embeddings(self):
return self.embeddings == "client-api"
@property
def using_local_embeddings(self):
return self.embeddings in [
@@ -158,6 +180,11 @@ class MemoryAgent(Agent):
"default" "default"
] ]
@property
def embeddings_client(self):
return self.embeddings_config.get("client")
@property
def max_distance(self) -> float:
distance = float(self.embeddings_config.get("distance", 1.0))
@@ -186,7 +213,10 @@ class MemoryAgent(Agent):
""" """
Returns a unique fingerprint for the current configuration
"""
return f"{self.embeddings}-{self.model.replace('/','-')}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
model_name = self.model.replace('/','-') if self.model else "none"
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
async def apply_config(self, *args, **kwargs):
@@ -206,6 +236,10 @@ class MemoryAgent(Agent):
async def handle_embeddings_change(self):
scene = active_scene.get()
# if sentence-transformer and no model-name, set embeddings to default
if self.using_sentence_transformer_embeddings and not self.model:
self.actions["_config"].config["embeddings"].value = "default"
if not scene or not scene.get_helper("memory"):
return
@@ -216,21 +250,49 @@ class MemoryAgent(Agent):
await scene.save(auto=True)
emit("status", "Context database re-imported", status="success")
def sync_presets(self) -> list[dict]:
self.actions["_config"].config["embeddings"].choices = self.get_presets
return self.actions["_config"].config["embeddings"].choices
def on_config_saved(self, event):
loop = asyncio.get_running_loop()
openai_key = self.openai_api_key
fingerprint = self.fingerprint
self.config = load_config()
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
self.config = load_config()
new_presets = self.sync_presets()
if fingerprint != self.fingerprint:
log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
loop.run_until_complete(self.handle_embeddings_change())
emit_status = False
if openai_key != self.openai_api_key:
emit_status = True
if old_presets != new_presets:
emit_status = True
if emit_status:
loop.run_until_complete(self.emit_status())
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
current_embeddings = self.actions["_config"].config["embeddings"].value
if current_embeddings == event.client.embeddings_identifier:
return
if not self.using_client_api_embeddings or not self.ready:
log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
@set_processing
async def set_db(self):
loop = asyncio.get_running_loop()
@@ -239,7 +301,7 @@ class MemoryAgent(Agent):
except EmbeddingsModelLoadError:
raise
except Exception as e:
log.error("memory agent", error="failed to set db", details=e)
log.error("memory agent", error="failed to set db", details=traceback.format_exc())
if "torchvision::nms does not exist" in str(e):
raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")
@@ -379,14 +441,12 @@ class MemoryAgent(Agent):
def _get_document(self, id):
raise NotImplementedError()
def on_archive_add(self, event: events.ArchiveEvent):
asyncio.ensure_future(
self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
)
async def on_archive_add(self, event: events.ArchiveEvent):
await self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
def connect(self, scene):
super().connect(scene)
scene.signals["archive_add"].connect(self.on_archive_add)
async_signals.get("archive_add").connect(self.on_archive_add)
async def memory_context(
self,
@@ -453,29 +513,72 @@ class MemoryAgent(Agent):
Get the character memory context for a given character
"""
memory_context = []
for query in queries:
if not query:
continue
i = 0
for memory in await self.get(formatter(query), limit=limit, **where):
if memory in memory_context:
continue
if filter and not filter(memory):
continue
memory_context.append(memory)
i += 1
if i >= iterate:
break
if util.count_tokens(memory_context) >= max_tokens:
break
if util.count_tokens(memory_context) >= max_tokens:
break
return memory_context

# First, collect results for each individual query (respecting the
# per-query `iterate` limit) so that we have them available before we
# start filling the final context. This prevents early queries from
# monopolising the token budget.
per_query_results: list[list[str]] = []
for query in queries:
# Skip empty queries so that we keep indexing consistent for the
# round-robin step that follows.
if not query:
per_query_results.append([])
continue
# Fetch potential memories for this query.
raw_results = await self.get(
formatter(query), limit=limit, **where
)
# Apply filter and respect the `iterate` limit for this query.
accepted: list[str] = []
for memory in raw_results:
if filter and not filter(memory):
continue
accepted.append(memory)
if len(accepted) >= iterate:
break
per_query_results.append(accepted)
# Now interleave the results in a round-robin fashion so that each
# query gets a fair chance to contribute, until we hit the token
# budget.
memory_context: list[str] = []
idx = 0
while True:
added_any = False
for result_list in per_query_results:
if idx >= len(result_list):
# No more items remaining for this query at this depth.
continue
memory = result_list[idx]
# Avoid duplicates in the final context.
if memory in memory_context:
continue
memory_context.append(memory)
added_any = True
# Check token budget after each addition.
if util.count_tokens(memory_context) >= max_tokens:
return memory_context
if not added_any:
# We iterated over all query result lists without adding
# anything. That means we have exhausted all available
# memories.
break
idx += 1
return memory_context
@property
@@ -588,8 +691,31 @@ class ChromaDBMemoryAgent(MemoryAgent):
return True
return False
@property
def client_api_ready(self) -> bool:
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
if not embeddings_client:
return False
if not embeddings_client.supports_embeddings:
return False
if not embeddings_client.embeddings_status:
return False
if embeddings_client.current_status not in ["idle", "busy"]:
return False
return True
return False
@property
def status(self):
if self.using_client_api_embeddings and not self.client_api_ready:
return "error"
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
@@ -612,12 +738,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
value=self.embeddings,
description="The embeddings type.",
).model_dump(),
"model": AgentDetail(
}
if self.model:
details["model"] = AgentDetail(
icon="mdi-brain",
value=self.model,
description="The embeddings model.",
).model_dump(),
).model_dump()
}
if self.embeddings_client:
details["client"] = AgentDetail(
icon="mdi-network-outline",
value=self.embeddings_client,
description="The client to use for embeddings.",
).model_dump()
if self.using_local_embeddings:
details["device"] = AgentDetail(
@@ -635,6 +771,37 @@ class ChromaDBMemoryAgent(MemoryAgent):
"color": "error", "color": "error",
} }
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
if not embeddings_client:
details["error"] = {
"icon": "mdi-alert",
"value": f"Client {self.embeddings_client} not found",
"description": f"Client {self.embeddings_client} not found",
"color": "error",
}
return details
client_name = embeddings_client.name
if not embeddings_client.supports_embeddings:
error_message = f"{client_name} does not support embeddings"
elif embeddings_client.current_status not in ["idle", "busy"]:
error_message = f"{client_name} is not ready"
elif not embeddings_client.embeddings_status:
error_message = f"{client_name} has no embeddings model loaded"
else:
error_message = None
if error_message:
details["error"] = {
"icon": "mdi-alert",
"value": error_message,
"description": error_message,
"color": "error",
}
return details
@property
@@ -686,7 +853,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
self.collection_name = collection_name = self.make_collection_name(self.scene)
log.info(
"chromadb agent", status="setting up db", collection_name=collection_name
"chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
)
distance_function = self.distance_function
@@ -713,6 +880,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=openai_ef, metadata=collection_metadata
)
elif self.using_client_api_embeddings:
log.info(
"chromadb",
embeddings="Client API",
client=self.embeddings_client,
)
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
if not embeddings_client:
raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")
if not embeddings_client.supports_embeddings:
raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")
ef = embeddings_client.embeddings_function
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef, metadata=collection_metadata
)
elif self.using_instructor_embeddings:
log.info(
"chromadb",
@@ -722,7 +909,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
ef = embedding_functions.InstructorEmbeddingFunction(
model_name=model_name, device=device
model_name=model_name, device=device, instruction="Represent the document for retrieval:"
)
log.info("chromadb", status="embedding function ready")
@@ -801,6 +988,10 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
try:
self.db_client.delete_collection(collection_name)
except chromadb.errors.NotFoundError as exc:
log.error(
"chromadb agent", error="collection not found", details=exc
)
except ValueError as exc:
log.error(
"chromadb agent", error="failed to delete collection", details=exc


@@ -510,53 +510,6 @@ class NarratorAgent(
return response
@set_processing
async def augment_context(self):
"""
Takes a context history generated via scene.context_history() and augments it with additional information
by asking and answering questions with help from the long term memory.
"""
memory = self.scene.get_helper("memory").agent
questions = await Prompt.request(
"narrator.context-questions",
self.client,
"narrate",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
},
)
log.debug("context_questions", questions=questions)
questions = [q for q in questions.split("\n") if q.strip()]
memory_context = await memory.multi_query(
questions, iterate=2, max_tokens=self.client.max_token_length - 1000
)
answers = await Prompt.request(
"narrator.context-answers",
self.client,
"narrate",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"memory": memory_context,
"questions": questions,
"extra_instructions": self.extra_instructions,
},
)
log.debug("context_answers", answers=answers)
answers = [a for a in answers.split("\n") if a.strip()]
# return questions and answers
return list(zip(questions, answers))
@set_processing
@store_context_state('narrative_direction', time_narration=True)
async def narrate_time_passage(


@@ -4,8 +4,7 @@ import re
import dataclasses
import structlog
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal
import talemate.data_objects as data_objects
import talemate.emit.async_signals
import talemate.util as util
from talemate.emit import emit
@@ -35,6 +34,8 @@ from talemate.agents.base import (
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.history import ArchiveEntry
from .analyze_scene import SceneAnalyzationMixin
from .context_investigation import ContextInvestigationMixin
from .layered_history import LayeredHistoryMixin
@@ -63,6 +64,7 @@ class SummarizeEmission(AgentTemplateEmission):
extra_instructions: str | None = None
generation_options: GenerationOptions | None = None
summarization_history: list[str] | None = None
summarization_type: Literal["dialogue", "events"] = "dialogue"
@register()
class SummarizeAgent(
@@ -189,6 +191,34 @@ class SummarizeAgent(
return emission.sub_instruction
# SUMMARIZATION HELPERS
async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:
num_previous = self.archive_include_previous
# find entry by .id
entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None)
if entry_index is None:
raise ValueError("Entry not found")
end = entry_index - 1
previous_summaries = []
if entry and num_previous > 0:
if self.layered_history_available:
previous_summaries = self.compile_layered_history(
include_base_layer=True,
base_layer_end_id=entry.id
)[-num_previous:]
else:
previous_summaries = [
entry.text for entry in self.scene.archived_history[end-num_previous:end]
]
return previous_summaries
# SUMMARIZE
@set_processing
@@ -352,7 +382,7 @@ class SummarizeAgent(
# determine the appropariate timestamp for the summarization
scene.push_archive(data_objects.ArchiveEntry(summarized, start, end, ts=ts))
await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts))
scene.ts=ts
scene.emit_status()
@@ -478,7 +508,8 @@ class SummarizeAgent(
extra_instructions=extra_instructions,
generation_options=generation_options,
template_vars=template_vars,
summarization_history=extra_context or []
summarization_history=extra_context or [],
summarization_type="dialogue",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
@@ -562,7 +593,8 @@ class SummarizeAgent(
extra_instructions=extra_instructions,
generation_options=generation_options,
template_vars=template_vars,
summarization_history=[extra_context] if extra_context else []
summarization_history=[extra_context] if extra_context else [],
summarization_type="events",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)


@@ -1,17 +1,18 @@
import structlog
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig
AgentActionConfig,
AgentEmission,
)
from talemate.prompts import Prompt
import dataclasses
import talemate.emit.async_signals
from talemate.exceptions import GenerationCancelled
from talemate.world_state.templates import GenerationOptions
from talemate.emit import emit
from talemate.context import handle_generation_cancelled
from talemate.history import LayeredArchiveEntry, HistoryEntry, entry_contained
import talemate.util as util
if TYPE_CHECKING:
@@ -19,6 +20,24 @@ if TYPE_CHECKING:
log = structlog.get_logger()
talemate.emit.async_signals.register(
"agent.summarization.layered_history.finalize",
)
@dataclasses.dataclass
class LayeredHistoryFinalizeEmission(AgentEmission):
entry: LayeredArchiveEntry | None = None
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])
@property
def response(self) -> str | None:
return self.entry.text if self.entry else None
@response.setter
def response(self, value: str):
if self.entry:
self.entry.text = value
class SummaryLongerThanOriginalError(ValueError):
def __init__(self, original_length:int, summarized_length:int):
self.original_length = original_length
@@ -156,6 +175,101 @@ class LayeredHistoryMixin:
generation_options=emission.generation_options
)
# helpers
async def _lh_split_and_summarize_chunks(
self,
chunks: list[dict],
extra_context: str,
generation_options: GenerationOptions | None = None,
) -> list[str]:
"""
Split chunks based on max_process_tokens and summarize each part.
Returns a list of summary texts.
"""
summaries = []
current_chunk = chunks.copy()
while current_chunk:
partial_chunk = []
max_process_tokens = self.layered_history_max_process_tokens
# Build partial chunk up to max_process_tokens
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
partial_chunk.append(current_chunk.pop(0))
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
log.debug("_split_and_summarize_chunks",
tokens_in_chunk=util.count_tokens(text_to_summarize),
max_process_tokens=max_process_tokens)
summary_text = await self.summarize_events(
text_to_summarize,
extra_context=extra_context + "\n\n".join(summaries),
generation_options=generation_options,
response_length=self.layered_history_response_length,
analyze_chunks=self.layered_history_analyze_chunks,
chunk_size=self.layered_history_chunk_size,
)
summaries.append(summary_text)
return summaries
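
The splitting loop above packs consecutive chunks greedily until the joined text crosses max_process_tokens. A standalone sketch of that packing behaviour (the helper name is hypothetical and whitespace splitting stands in for util.count_tokens):

def pack_chunks(chunks: list[dict], max_tokens: int) -> list[list[dict]]:
    # Greedily group consecutive chunks until the joined text reaches max_tokens,
    # mirroring the inner while-loop of _lh_split_and_summarize_chunks.
    groups, remaining = [], list(chunks)
    while remaining:
        group = []
        while remaining and len("\n\n".join(c["text"] for c in group).split()) < max_tokens:
            group.append(remaining.pop(0))
        groups.append(group)
    return groups

print(pack_chunks([{"text": "a b c"}, {"text": "d e"}, {"text": "f"}], max_tokens=4))
# -> [[{'text': 'a b c'}, {'text': 'd e'}], [{'text': 'f'}]]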
def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
"""
Validates that the summarized text is not longer than the original.
Raises SummaryLongerThanOriginalError if validation fails.
"""
summarized_length = util.count_tokens(summaries)
if summarized_length > original_length:
raise SummaryLongerThanOriginalError(original_length, summarized_length)
log.debug("_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length)
def _lh_build_extra_context(self, layer_index: int) -> str:
"""
Builds extra context from compiled layered history for the given layer.
"""
return "\n\n".join(self.compile_layered_history(layer_index))
def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
"""
Extracts timestamps from a chunk of entries.
Returns (ts, ts_start, ts_end)
"""
if not chunk:
return "PT1S", "PT1S", "PT1S"
ts = chunk[0].get('ts', 'PT1S')
ts_start = chunk[0].get('ts_start', ts)
ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts))
return ts, ts_start, ts_end
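
The fallback is worth spelling out: timestamps are ISO-8601 durations, the first chunk supplies ts and ts_start, the last chunk supplies ts_end, and an empty chunk falls back to "PT1S". Illustrative values:

chunk = [
    {"text": "...", "ts": "PT5M"},
    {"text": "...", "ts": "PT12M", "ts_end": "PT20M"},
]
# self._lh_extract_timestamps(chunk) -> ("PT5M", "PT5M", "PT20M")
# self._lh_extract_timestamps([])    -> ("PT1S", "PT1S", "PT1S")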
async def _lh_finalize_archive_entry(
self,
entry: LayeredArchiveEntry,
summarization_history: list[str] | None = None,
) -> LayeredArchiveEntry:
"""
Finalizes an archive entry by summarizing it and adding it to the layered history.
"""
emission = LayeredHistoryFinalizeEmission(
agent=self,
entry=entry,
summarization_history=summarization_history,
)
await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)
return emission.entry
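
The finalize signal gives other components a chance to rewrite a freshly summarized entry before it is stored; the emission exposes the entry text through its response property. A minimal listener sketch (the post-processing step itself is hypothetical):

import talemate.emit.async_signals

async def on_layered_history_finalize(emission: LayeredHistoryFinalizeEmission):
    # emission.response proxies emission.entry.text, so assigning it
    # rewrites the archive entry in place.
    if emission.response:
        emission.response = emission.response.strip()

talemate.emit.async_signals.get(
    "agent.summarization.layered_history.finalize"
).connect(on_layered_history_finalize)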
     # methods

     def compile_layered_history(
@@ -164,6 +278,7 @@ class LayeredHistoryMixin:
         as_objects:bool=False,
         include_base_layer:bool=False,
         max:int = None,
+        base_layer_end_id: str | None = None,
     ) -> list[str]:
         """
         Starts at the last layer and compiles the layered history into a single
@@ -194,6 +309,17 @@ class LayeredHistoryMixin:
             entry_num = 1
             for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]:
+                if base_layer_end_id:
+                    contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry(
+                        index=0,
+                        layer=i+1,
+                        **layered_history_entry)
+                    )
+                    if contained:
+                        log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id)
+                        break
+
                 text = f"{layered_history_entry['text']}"

                 if for_layer_index == i and max is not None and max <= layered_history_entry["end"]:
@@ -213,7 +339,7 @@ class LayeredHistoryMixin:
                 else:
                     compiled.append(text)

             next_layer_start = layered_history_entry["end"] + 1

         if i == 0 and include_base_layer:
             # we are are at layered history layer zero and inclusion of base layer (archived history) is requested
@@ -222,7 +348,10 @@ class LayeredHistoryMixin:
             entry_num = 1
-            for ah in self.scene.archived_history[next_layer_start:]:
+            for ah in self.scene.archived_history[next_layer_start or 0:]:
+                if base_layer_end_id and ah["id"] == base_layer_end_id:
+                    break
+
                 text = f"{ah['text']}"

                 if as_objects:
@@ -291,8 +420,6 @@ class LayeredHistoryMixin:
             return  # No base layer summaries to work with

         token_threshold = self.layered_history_threshold
-        method = self.actions["archive"].config["method"].value
-        max_process_tokens = self.layered_history_max_process_tokens
         max_layers = self.layered_history_max_layers

         if not hasattr(self.scene, 'layered_history'):
@@ -329,15 +456,9 @@ class LayeredHistoryMixin:
             log.debug("summarize_to_layered_history", created_layer=next_layer_index)
             next_layer = layered_history[next_layer_index]

-            ts = current_chunk[0]['ts']
-            ts_start = current_chunk[0]['ts_start'] if 'ts_start' in current_chunk[0] else ts
-            ts_end = current_chunk[-1]['ts_end'] if 'ts_end' in current_chunk[-1] else ts
-            summaries = []
-            extra_context = "\n\n".join(
-                self.compile_layered_history(next_layer_index)
-            )
+            ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
+            extra_context = self._lh_build_extra_context(next_layer_index)

             text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
@@ -345,44 +466,24 @@ class LayeredHistoryMixin:
             emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})

-            while current_chunk:
-                noop = False
-                log.debug("summarize_to_layered_history", tokens_in_chunk=util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk)), max_process_tokens=max_process_tokens)
-                partial_chunk = []
-                while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
-                    partial_chunk.append(current_chunk.pop(0))
-                text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
-                summary_text = await self.summarize_events(
-                    text_to_summarize,
-                    extra_context=extra_context + "\n\n".join(summaries),
-                    generation_options=generation_options,
-                    response_length=self.layered_history_response_length,
-                    analyze_chunks=self.layered_history_analyze_chunks,
-                    chunk_size=self.layered_history_chunk_size,
-                )
-                noop = False
-                summaries.append(summary_text)
-
-            # if summarized text is longer than the original, we will
-            # raise an error
-            if util.count_tokens(summaries) > text_length:
-                raise SummaryLongerThanOriginalError(text_length, util.count_tokens(summaries))
-
-            log.debug("summarize_to_layered_history", original_length=text_length, summarized_length=util.count_tokens(summaries))
-
-            next_layer.append({
+            summaries = await self._lh_split_and_summarize_chunks(
+                current_chunk,
+                extra_context,
+                generation_options=generation_options,
+            )
+
+            # validate summary length
+            self._lh_validate_summary_length(summaries, text_length)
+
+            next_layer.append(LayeredArchiveEntry(**{
                 "start": start_index,
-                "end": i - 1,
+                "end": i,
                 "ts": ts,
                 "ts_start": ts_start,
                 "ts_end": ts_end,
-                "text": "\n\n".join(summaries)
-            })
+                "text": "\n\n".join(summaries),
+            }).model_dump(exclude_none=True))

             emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}")
@@ -412,7 +513,7 @@ class LayeredHistoryMixin:
             last_entry = layered_history[0][-1]
             end = last_entry["end"]
             log.debug("summarize_to_layered_history", layer="base", start=end)
-            has_been_updated = await summarize_layer(self.scene.archived_history, 0, end + 1)
+            has_been_updated = await summarize_layer(self.scene.archived_history, 0, end)
         else:
             log.debug("summarize_to_layered_history", layer="base", empty=True)
             has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
@@ -445,7 +546,7 @@ class LayeredHistoryMixin:
             end = next_layer[-1]["end"] if next_layer else 0
             log.debug("summarize_to_layered_history", layer=index, start=end)
-            summarized = await summarize_layer(layered_history[index], index + 1, end + 1 if end else 0)
+            summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0)
             if summarized:
                 noop = False
@@ -467,3 +568,106 @@ class LayeredHistoryMixin:
                 emit("status", message="Rebuilding of layered history cancelled", status="info")
                 handle_generation_cancelled(e)
                 return
async def summarize_entries_to_layered_history(
self,
entries: list[dict],
next_layer_index: int,
start_index: int,
end_index: int,
generation_options: GenerationOptions | None = None,
) -> list[LayeredArchiveEntry]:
"""
Summarizes a list of entries into layered history entries.
This method is used for regenerating specific history entries by processing
their source entries. It chunks the entries based on the token threshold and
summarizes each chunk into a LayeredArchiveEntry.
Args:
entries: List of dictionaries containing the text entries to summarize.
Each entry should have at least a 'text' field and optionally
'ts', 'ts_start', and 'ts_end' fields.
next_layer_index: The index of the layer where the summarized entries
will be placed.
start_index: The starting index in the source layer that these entries
correspond to.
end_index: The ending index in the source layer that these entries
correspond to.
generation_options: Optional generation options to pass to the summarization
process.
Returns:
List of LayeredArchiveEntry objects containing the summarized text along
with timestamp and index information. Currently returns a list with a
single entry, but the structure supports multiple entries if needed.
Notes:
- The method respects the layered_history_threshold for chunking
- Uses helper methods for timestamp extraction, context building, and
chunk summarization
- Validates that summaries are not longer than the original text
- The last entry is always included in the final chunk if it doesn't
exceed the token threshold
"""
token_threshold = self.layered_history_threshold
archive_entries = []
summaries = []
current_chunk = []
current_tokens = 0
ts = "PT1S"
ts_start = "PT1S"
ts_end = "PT1S"
for entry_index, entry in enumerate(entries):
is_last_entry = entry_index == len(entries) - 1
entry_tokens = util.count_tokens(entry['text'])
log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens)
if current_tokens + entry_tokens > token_threshold or is_last_entry:
if is_last_entry and current_tokens + entry_tokens <= token_threshold:
# if we are here because this is the last entry and adding it to
# the current chunk would not exceed the token threshold, we will
# add it to the current chunk
current_chunk.append(entry)
if current_chunk:
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
archive_entry = LayeredArchiveEntry(**{
"start": start_index,
"end": end_index,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
})
archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n"))
archive_entries.append(archive_entry)
current_chunk.append(entry)
current_tokens += entry_tokens
return archive_entries
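
As a rough usage sketch (the agent handle, entry slice, and indices are assumptions for illustration, not part of this diff), regenerating a single layer-0 entry from its archived-history sources could look like:

entries = scene.archived_history[12:20]      # archived-history entries the summary covers
regenerated = await summarizer.summarize_entries_to_layered_history(
    entries,
    next_layer_index=0,     # target layer in scene.layered_history
    start_index=12,         # first covered index in the source layer
    end_index=19,           # last covered index in the source layer
)
new_entry = regenerated[0]  # currently a single LayeredArchiveEntry is returned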


@@ -23,7 +23,7 @@ from talemate.emit.signals import handlers as signal_handlers
 from talemate.prompts.base import Prompt

 from .commands import * # noqa
-from .context import VIS_TYPES, VisualContext, visual_context
+from .context import VIS_TYPES, VisualContext, VisualContextState, visual_context
 from .handlers import HANDLERS
 from .schema import RESOLUTION_MAP, RenderSettings
 from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
@@ -40,6 +40,14 @@ BACKENDS = [
     for mixin_backend, mixin in HANDLERS.items()
 ]
PROMPT_OUTPUT_FORMAT = """
### Positive
{positive_prompt}
### Negative
{negative_prompt}
"""
log = structlog.get_logger("talemate.agents.visual")
@@ -284,7 +292,7 @@ class VisualBase(Agent):
         try:
             backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
-        except KeyError:
+        except (KeyError, TypeError):
             backend = self.backend

         backend_changed = backend != self.backend
@@ -425,10 +433,9 @@ class VisualBase(Agent):
         self, format: str = "portrait", prompt: str = None, automatic: bool = False
     ):
-        context = visual_context.get()
-
-        if not self.enabled:
-            return
+        context:VisualContextState = visual_context.get()
+
+        log.debug("visual generate", context=context)

         if automatic and not self.allow_automatic_generation:
             return
@@ -459,7 +466,7 @@ class VisualBase(Agent):
         thematic_style = self.default_style
         vis_type_styles = self.vis_type_styles(context.vis_type)

-        prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
+        prompt:Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])

         if context.vis_type == VIS_TYPES.CHARACTER:
             prompt.keywords.append("character portrait")
@@ -482,6 +489,33 @@ class VisualBase(Agent):
         context.format = format
can_generate_image = self.enabled and self.backend_ready
if not context.prompt_only and not can_generate_image:
emit("status", "Visual agent is not ready for image generation, will output prompt instead.", status="warning")
# if prompt_only, we don't need to generate an image
# instead we emit a system message with the prompt
if context.prompt_only or not can_generate_image:
emit(
"system",
message=PROMPT_OUTPUT_FORMAT.format(
positive_prompt=prompt.positive_prompt,
negative_prompt=prompt.negative_prompt,
),
meta={
"icon": "mdi-image-text",
"color": "highlight7",
"title": f"Visual Prompt - {context.title}",
"display": "tonal",
"as_markdown": True,
}
)
return
if not can_generate_image:
return
         # Call the backend specific generate function
         backend = self.backend
@@ -541,8 +575,16 @@ class VisualBase(Agent):
         return response.strip()

-    async def generate_environment_background(self, instructions: str = None):
-        with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
+    async def generate_environment_background(
+        self,
+        instructions: str = None,
+        prompt_only: bool = False,
+    ):
+        with VisualContext(
+            vis_type=VIS_TYPES.ENVIRONMENT,
+            instructions=instructions,
+            prompt_only=prompt_only,
+        ):
             await self.generate(format="landscape")

     async def generate_character_portrait(
@@ -550,12 +592,14 @@ class VisualBase(Agent):
         character_name: str,
         instructions: str = None,
         replace: bool = False,
+        prompt_only: bool = False,
     ):
         with VisualContext(
             vis_type=VIS_TYPES.CHARACTER,
             character_name=character_name,
             instructions=instructions,
             replace=replace,
+            prompt_only=prompt_only,
         ):
             await self.generate(format="portrait")


@@ -29,6 +29,15 @@ class VisualContextState(pydantic.BaseModel):
prepared_prompt: Union[str, None] = None prepared_prompt: Union[str, None] = None
format: Union[str, None] = None format: Union[str, None] = None
replace: bool = False replace: bool = False
prompt_only: bool = False
@property
def title(self) -> str:
if self.vis_type == VIS_TYPES.ENVIRONMENT:
return "Environment"
elif self.vis_type == VIS_TYPES.CHARACTER:
return f"Character: {self.character_name}"
return "Visual Context"
class VisualContext: class VisualContext:


@@ -90,12 +90,16 @@ class VisualWebsocketHandler(Plugin):
         payload = GeneratePayload(**data)
         visual = get_agent("visual")
         await visual.generate_character_portrait(
-            payload.context.character_name, payload.context.instructions, replace=True
+            payload.context.character_name,
+            payload.context.instructions,
+            replace=True,
+            prompt_only=payload.context.prompt_only,
         )

     async def handle_visualize_environment(self, data: dict):
         payload = GeneratePayload(**data)
         visual = get_agent("visual")
         await visual.generate_environment_background(
-            instructions=payload.context.instructions
+            instructions=payload.context.instructions,
+            prompt_only=payload.context.prompt_only,
         )


@@ -18,6 +18,7 @@ from talemate.scene_message import (
     ReinforcementMessage,
     TimePassageMessage,
 )
+from talemate.util.response import extract_list

 from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
@@ -76,6 +77,12 @@ class WorldStateAgent(
             label="Update world state",
             description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
             config={
+                "initial": AgentActionConfig(
+                    type="bool",
+                    label="When a new scene is started",
+                    description="Whether to update the world state on scene start.",
+                    value=True,
+                ),
                 "turns": AgentActionConfig(
                     type="number",
                     label="Turns",
@@ -134,9 +141,14 @@ class WorldStateAgent(
     def experimental(self):
         return True

+    @property
+    def initial_update(self):
+        return self.actions["update_world_state"].config["initial"].value
+
     def connect(self, scene):
         super().connect(scene)
         talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
+        talemate.emit.async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init_after)

     async def advance_time(self, duration: str, narrative: str = None):
         """
@@ -162,6 +174,22 @@ class WorldStateAgent(
             )
         )

+    async def on_scene_loop_init_after(self, emission):
+        """
+        Called when a scene is initialized
+        """
+        if not self.enabled:
+            return
+
+        if not self.initial_update:
+            return
+
+        if self.get_scene_state("inital_update_done"):
+            return
+
+        await self.scene.world_state.request_update()
+        self.set_scene_states(inital_update_done=True)
+
     async def on_game_loop(self, emission: GameLoopEvent):
         """
         Called when a conversation is generated
@@ -305,7 +333,7 @@ class WorldStateAgent(
             },
         )

-        queries = response.split("\n")
+        queries = extract_list(response)

         memory_agent = get_agent("memory")


@@ -10,7 +10,9 @@ from talemate.client.groq import GroqClient
 from talemate.client.koboldcpp import KoboldCppClient
 from talemate.client.lmstudio import LMStudioClient
 from talemate.client.mistral import MistralAIClient
+from talemate.client.ollama import OllamaClient
 from talemate.client.openai import OpenAIClient
+from talemate.client.openrouter import OpenRouterClient
 from talemate.client.openai_compat import OpenAICompatibleClient
 from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
 from talemate.client.tabbyapi import TabbyAPIClient


@@ -2,8 +2,14 @@ import pydantic
import structlog import structlog
from anthropic import AsyncAnthropic, PermissionDeniedError from anthropic import AsyncAnthropic, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults from talemate.client.base import ClientBase, ErrorAction, CommonDefaults, ExtraField
from talemate.client.registry import register from talemate.client.registry import register
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config from talemate.config import load_config
from talemate.emit import emit from talemate.emit import emit
from talemate.emit.signals import handlers from talemate.emit.signals import handlers
@@ -28,13 +34,17 @@ SUPPORTED_MODELS = [
] ]
class Defaults(CommonDefaults, pydantic.BaseModel): class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384 max_token_length: int = 16384
model: str = "claude-3-5-sonnet-latest" model: str = "claude-3-5-sonnet-latest"
double_coercion: str = None
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register() @register()
class AnthropicClient(ClientBase): class AnthropicClient(EndpointOverrideMixin, ClientBase):
""" """
Anthropic client for generating text. Anthropic client for generating text.
""" """
@@ -44,6 +54,7 @@ class AnthropicClient(ClientBase):
auto_break_repetition_enabled = False auto_break_repetition_enabled = False
# TODO: make this configurable? # TODO: make this configurable?
decensor_enabled = False decensor_enabled = False
config_cls = ClientConfig
class Meta(ClientBase.Meta): class Meta(ClientBase.Meta):
name_prefix: str = "Anthropic" name_prefix: str = "Anthropic"
@@ -52,15 +63,21 @@ class AnthropicClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False requires_prompt_template: bool = False
defaults: Defaults = Defaults() defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="claude-3-5-sonnet-latest", **kwargs): def __init__(self, model="claude-3-5-sonnet-latest", **kwargs):
self.model_name = model self.model_name = model
self.api_key_status = None self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config() self.config = load_config()
super().__init__(**kwargs) super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved) handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
@property @property
def anthropic_api_key(self): def anthropic_api_key(self):
return self.config.get("anthropic", {}).get("api_key") return self.config.get("anthropic", {}).get("api_key")
@@ -103,6 +120,7 @@ class AnthropicClient(ClientBase):
data={ data={
"error_action": error_action.model_dump() if error_action else None, "error_action": error_action.model_dump() if error_action else None,
"double_coercion": self.double_coercion,
"meta": self.Meta().model_dump(), "meta": self.Meta().model_dump(),
"enabled": self.enabled, "enabled": self.enabled,
} }
@@ -117,7 +135,7 @@ class AnthropicClient(ClientBase):
) )
def set_client(self, max_token_length: int = None): def set_client(self, max_token_length: int = None):
if not self.anthropic_api_key: if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncAnthropic(api_key="sk-1111") self.client = AsyncAnthropic(api_key="sk-1111")
log.error("No anthropic API key set") log.error("No anthropic API key set")
if self.api_key_status: if self.api_key_status:
@@ -134,7 +152,7 @@ class AnthropicClient(ClientBase):
model = self.model_name model = self.model_name
self.client = AsyncAnthropic(api_key=self.anthropic_api_key) self.client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384 self.max_token_length = max_token_length or 16384
if not self.api_key_status: if not self.api_key_status:
@@ -158,7 +176,11 @@ class AnthropicClient(ClientBase):
if "enabled" in kwargs: if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs) self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event): def on_config_saved(self, event):
config = event.data config = event.data
@@ -175,13 +197,10 @@ class AnthropicClient(ClientBase):
         self.emit_status()

     def prompt_template(self, system_message: str, prompt: str):
-        if "<|BOT|>" in prompt:
-            _, right = prompt.split("<|BOT|>", 1)
-            if right:
-                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
-            else:
-                prompt = prompt.replace("<|BOT|>", "")
+        """
+        Anthropic handles the prompt template internally, so we just
+        give the prompt as is.
+        """
         return prompt
     async def generate(self, prompt: str, parameters: dict, kind: str):
@@ -189,20 +208,20 @@ class AnthropicClient(ClientBase):
         Generates text from the given prompt and parameters.
         """

-        if not self.anthropic_api_key:
+        if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
             raise Exception("No anthropic API key set")

-        right = None
-        expected_response = None
-        try:
-            _, right = prompt.split("\nStart your response with: ")
-            expected_response = right.strip()
-        except (IndexError, ValueError):
-            pass
-
-        human_message = {"role": "user", "content": prompt.strip()}
+        prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
+
         system_message = self.get_system_message(kind)

+        messages = [
+            {"role": "user", "content": prompt.strip()}
+        ]
+
+        if coercion_prompt:
+            messages.append({"role": "assistant", "content": coercion_prompt.strip()})
+
         self.log.debug(
             "generate",
             prompt=prompt[:128] + " ...",
@@ -210,27 +229,38 @@ class AnthropicClient(ClientBase):
             system_message=system_message,
         )

+        completion_tokens = 0
+        prompt_tokens = 0
+
         try:
-            response = await self.client.messages.create(
+            stream = await self.client.messages.create(
                 model=self.model_name,
                 system=system_message,
-                messages=[human_message],
+                messages=messages,
+                stream=True,
                 **parameters,
             )

-            self._returned_prompt_tokens = self.prompt_tokens(response)
-            self._returned_response_tokens = self.response_tokens(response)
-
-            log.debug("generated response", response=response.content)
-            response = response.content[0].text
-
-            if expected_response and expected_response.startswith("{"):
-                if response.startswith("```json") and response.endswith("```"):
-                    response = response[7:-3].strip()
-
-            if right and response.startswith(right):
-                response = response[len(right) :].strip()
+            response = ""
+
+            async for event in stream:
+                if event.type == "content_block_delta":
+                    content = event.delta.text
+                    response += content
+                    self.update_request_tokens(self.count_tokens(content))
+                elif event.type == "message_start":
+                    prompt_tokens = event.message.usage.input_tokens
+                elif event.type == "message_delta":
+                    completion_tokens += event.usage.output_tokens
+
+            self._returned_prompt_tokens = prompt_tokens
+            self._returned_response_tokens = completion_tokens
+
+            log.debug("generated response", response=response)

             return response
         except PermissionDeniedError as e:


@@ -6,10 +6,12 @@ import ipaddress
import logging import logging
import random import random
import time import time
import traceback
import asyncio import asyncio
from typing import Callable, Union, Literal from typing import Callable, Union, Literal
import pydantic import pydantic
import dataclasses
import structlog import structlog
import urllib3 import urllib3
from openai import AsyncOpenAI, PermissionDeniedError from openai import AsyncOpenAI, PermissionDeniedError
@@ -23,7 +25,10 @@ from talemate.client.model_prompts import model_prompt
from talemate.client.ratelimit import CounterRateLimiter from talemate.client.ratelimit import CounterRateLimiter
from talemate.context import active_scene from talemate.context import active_scene
from talemate.emit import emit from talemate.emit import emit
from talemate.config import load_config, save_config, EmbeddingFunctionPreset
import talemate.emit.async_signals as async_signals
from talemate.exceptions import SceneInactiveError, GenerationCancelled from talemate.exceptions import SceneInactiveError, GenerationCancelled
import talemate.ux.schema as ux_schema
from talemate.client.system_prompts import SystemPrompts from talemate.client.system_prompts import SystemPrompts
@@ -77,13 +82,20 @@ class Defaults(CommonDefaults, pydantic.BaseModel):
double_coercion: str = None double_coercion: str = None
class FieldGroup(pydantic.BaseModel):
name: str
label: str
description: str
icon: str = "mdi-cog"
class ExtraField(pydantic.BaseModel): class ExtraField(pydantic.BaseModel):
name: str name: str
type: str type: str
label: str label: str
required: bool required: bool
description: str description: str
group: FieldGroup | None = None
note: ux_schema.Note | None = None
class ParameterReroute(pydantic.BaseModel): class ParameterReroute(pydantic.BaseModel):
talemate_parameter: str talemate_parameter: str
@@ -101,6 +113,56 @@ class ParameterReroute(pydantic.BaseModel):
return str(self) == str(other) return str(self) == str(other)
class RequestInformation(pydantic.BaseModel):
start_time: float = pydantic.Field(default_factory=time.time)
end_time: float | None = None
tokens: int = 0
@pydantic.computed_field(description="Duration")
@property
def duration(self) -> float:
end_time = self.end_time or time.time()
return end_time - self.start_time
@pydantic.computed_field(description="Tokens per second")
@property
def rate(self) -> float:
try:
end_time = self.end_time or time.time()
return self.tokens / (end_time - self.start_time)
except:
pass
return 0
@pydantic.computed_field(description="Status")
@property
def status(self) -> str:
if self.end_time:
return "completed"
elif self.start_time:
if self.duration > 1 and self.rate == 0:
return "stopped"
return "in progress"
else:
return "pending"
@pydantic.computed_field(description="Age")
@property
def age(self) -> float:
if not self.end_time:
return -1
return time.time() - self.end_time
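
RequestInformation backs the new tokens-per-second / busy indicator. A small lifecycle sketch (timings and token counts are illustrative):

import time

info = RequestInformation()          # start_time is set automatically
info.tokens += 128                   # tokens streamed so far
print(info.status)                   # "in progress" while end_time is unset
info.end_time = time.time()
print(f"{info.rate:.0f} t/s", info.status)   # rate over the whole request, now "completed"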
@dataclasses.dataclass
class ClientEmbeddingsStatus:
client: "ClientBase | None" = None
embedding_name: str | None = None
async_signals.register(
"client.embeddings_available",
)
class ClientBase: class ClientBase:
api_url: str api_url: str
model_name: str model_name: str
@@ -120,6 +182,7 @@ class ClientBase:
data_format: Literal["yaml", "json"] | None = None data_format: Literal["yaml", "json"] | None = None
rate_limit: int | None = None rate_limit: int | None = None
client_type = "base" client_type = "base"
request_information: RequestInformation | None = None
status_request_timeout:int = 2 status_request_timeout:int = 2
@@ -171,6 +234,13 @@ class ClientBase:
""" """
return self.Meta().requires_prompt_template return self.Meta().requires_prompt_template
@property
def can_think(self) -> bool:
"""
Allow reasoning models to think before responding.
"""
return False
@property @property
def max_tokens_param_name(self): def max_tokens_param_name(self):
return "max_tokens" return "max_tokens"
@@ -183,9 +253,87 @@ class ClientBase:
"max_tokens", "max_tokens",
] ]
@property
def supports_embeddings(self) -> bool:
return False
@property
def embeddings_function(self):
return None
@property
def embeddings_status(self) -> bool:
return getattr(self, "_embeddings_status", False)
@property
def embeddings_model_name(self) -> str | None:
return getattr(self, "_embeddings_model_name", None)
@property
def embeddings_url(self) -> str:
return None
@property
def embeddings_identifier(self) -> str:
return f"client-api/{self.name}/{self.embeddings_model_name}"
async def destroy(self, config:dict):
"""
This is called before the client is removed from talemate.instance.clients
Use this to perform any cleanup that is necessary.
If a subclass overrides this method, it should call super().destroy(config) in the
end of the method.
"""
if self.supports_embeddings:
self.remove_embeddings(config)
def reset_embeddings(self):
self._embeddings_model_name = None
self._embeddings_status = False
def set_client(self, **kwargs): def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111") self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def set_embeddings(self):
log.debug("setting embeddings", client=self.name, supports_embeddings=self.supports_embeddings, embeddings_status=self.embeddings_status)
if not self.supports_embeddings or not self.embeddings_status:
return
config = load_config(as_model=True)
key = self.embeddings_identifier
if key in config.presets.embeddings:
log.debug("embeddings already set", client=self.name, key=key)
return config.presets.embeddings[key]
log.debug("setting embeddings", client=self.name, key=key)
config.presets.embeddings[key] = EmbeddingFunctionPreset(
embeddings="client-api",
client=self.name,
model=self.embeddings_model_name,
distance=1,
distance_function="cosine",
local=False,
custom=True,
)
save_config(config)
def remove_embeddings(self, config:dict | None = None):
# remove all embeddings for this client
for key, value in list(config["presets"]["embeddings"].items()):
if value["client"] == self.name and value["embeddings"] == "client-api":
log.warning("!!! removing embeddings", client=self.name, key=key)
config["presets"]["embeddings"].pop(key)
def set_system_prompts(self, system_prompts: dict | SystemPrompts): def set_system_prompts(self, system_prompts: dict | SystemPrompts):
if isinstance(system_prompts, dict): if isinstance(system_prompts, dict):
self.system_prompts = SystemPrompts(**system_prompts) self.system_prompts = SystemPrompts(**system_prompts)
@@ -222,6 +370,19 @@ class ClientBase:
self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}" self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
) )
def split_prompt_for_coercion(self, prompt: str) -> tuple[str, str]:
"""
Splits the prompt and the prefill/coercion prompt.
"""
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if self.double_coercion:
right = f"{self.double_coercion}\n\n{right}"
return prompt, right
return prompt, None
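
To illustrate the split (the prompt text and double_coercion value are made up): everything after the <|BOT|> marker becomes the prefill that coercion-capable clients send as an assistant message, optionally prefixed by double_coercion, while the main prompt is returned unchanged.

prompt = "Summarize the scene.<|BOT|>Certainly, here is the summary:"
# With self.double_coercion = "You are the narrator."
# split_prompt_for_coercion(prompt) returns:
#   ("Summarize the scene.<|BOT|>Certainly, here is the summary:",
#    "You are the narrator.\n\nCertainly, here is the summary:")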
def reconfigure(self, **kwargs): def reconfigure(self, **kwargs):
""" """
Reconfigures the client. Reconfigures the client.
@@ -241,6 +402,8 @@ class ClientBase:
if "enabled" in kwargs: if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
if not self.enabled and self.supports_embeddings and self.embeddings_status:
self.reset_embeddings()
if "double_coercion" in kwargs: if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"] self.double_coercion = kwargs["double_coercion"]
@@ -388,6 +551,8 @@ class ClientBase:
for field_name in getattr(self.Meta(), "extra_fields", {}).keys(): for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
data[field_name] = getattr(self, field_name, None) data[field_name] = getattr(self, field_name, None)
data = self.finalize_status(data)
emit( emit(
"client_status", "client_status",
message=self.client_type, message=self.client_type,
@@ -400,13 +565,31 @@ class ClientBase:
if status_change: if status_change:
instance.emit_agent_status_by_client(self) instance.emit_agent_status_by_client(self)
def finalize_status(self, data: dict):
"""
Finalizes the status data for the client.
"""
return data
def _common_status_data(self): def _common_status_data(self):
return { common_data = {
"can_be_coerced": self.can_be_coerced,
"preset_group": self.preset_group or "", "preset_group": self.preset_group or "",
"rate_limit": self.rate_limit, "rate_limit": self.rate_limit,
"data_format": self.data_format, "data_format": self.data_format,
"manual_model_choices": getattr(self.Meta(), "manual_model_choices", []),
"supports_embeddings": self.supports_embeddings,
"embeddings_status": self.embeddings_status,
"embeddings_model_name": self.embeddings_model_name,
"request_information": self.request_information.model_dump() if self.request_information else None,
} }
extra_fields = getattr(self.Meta(), "extra_fields", {})
for field_name in extra_fields.keys():
common_data[field_name] = getattr(self, field_name, None)
return common_data
def populate_extra_fields(self, data: dict): def populate_extra_fields(self, data: dict):
""" """
Updates data with the extra fields from the client's Meta Updates data with the extra fields from the client's Meta
@@ -438,6 +621,7 @@ class ClientBase:
:return: None :return: None
""" """
if self.processing: if self.processing:
self.emit_status()
return return
if not self.enabled: if not self.enabled:
@@ -619,6 +803,27 @@ class ClientBase:
""" """
pass pass
def new_request(self):
"""
Creates a new request information object.
"""
self.request_information = RequestInformation()
def end_request(self):
"""
Ends the request information object.
"""
self.request_information.end_time = time.time()
def update_request_tokens(self, tokens: int, replace: bool = False):
"""
Updates the request information object with the number of tokens received.
"""
if self.request_information:
if replace:
self.request_information.tokens = tokens
else:
self.request_information.tokens += tokens
async def send_prompt( async def send_prompt(
self, self,
@@ -690,7 +895,7 @@ class ClientBase:
except GenerationCancelled: except GenerationCancelled:
raise raise
except Exception as e: except Exception as e:
log.exception("Error during rate limit check", e=e) log.error("Error during rate limit check", e=traceback.format_exc())
if not active_scene.get(): if not active_scene.get():
@@ -736,8 +941,12 @@ class ClientBase:
) )
prompt_sent = self.repetition_adjustment(finalized_prompt) prompt_sent = self.repetition_adjustment(finalized_prompt)
self.new_request()
response = await self._cancelable_generate(prompt_sent, prompt_param, kind) response = await self._cancelable_generate(prompt_sent, prompt_param, kind)
self.end_request()
if isinstance(response, GenerationCancelled): if isinstance(response, GenerationCancelled):
# generation was cancelled # generation was cancelled
raise response raise response
@@ -786,7 +995,7 @@ class ClientBase:
except GenerationCancelled as e: except GenerationCancelled as e:
raise raise
except Exception as e: except Exception as e:
self.log.exception("send_prompt error", e=e) self.log.error("send_prompt error", e=traceback.format_exc())
emit( emit(
"status", message="Error during generation (check logs)", status="error" "status", message="Error during generation (check logs)", status="error"
) )


@@ -1,10 +1,15 @@
import pydantic import pydantic
import structlog import structlog
from cohere import AsyncClient from cohere import AsyncClientV2
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.registry import register from talemate.client.registry import register
from talemate.config import load_config from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.emit import emit from talemate.emit import emit
from talemate.emit.signals import handlers from talemate.emit.signals import handlers
from talemate.util import count_tokens from talemate.util import count_tokens
@@ -26,13 +31,17 @@ SUPPORTED_MODELS = [
] ]
class Defaults(CommonDefaults, pydantic.BaseModel): class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384 max_token_length: int = 16384
model: str = "command-r-plus" model: str = "command-r-plus"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register() @register()
class CohereClient(ClientBase): class CohereClient(EndpointOverrideMixin, ClientBase):
""" """
Cohere client for generating text. Cohere client for generating text.
""" """
@@ -41,6 +50,7 @@ class CohereClient(ClientBase):
conversation_retries = 0 conversation_retries = 0
auto_break_repetition_enabled = False auto_break_repetition_enabled = False
decensor_enabled = True decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta): class Meta(ClientBase.Meta):
name_prefix: str = "Cohere" name_prefix: str = "Cohere"
@@ -48,11 +58,13 @@ class CohereClient(ClientBase):
manual_model: bool = True manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False requires_prompt_template: bool = False
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
defaults: Defaults = Defaults() defaults: Defaults = Defaults()
def __init__(self, model="command-r-plus", **kwargs): def __init__(self, model="command-r-plus", **kwargs):
self.model_name = model self.model_name = model
self.api_key_status = None self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config() self.config = load_config()
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -119,8 +131,8 @@ class CohereClient(ClientBase):
         )

     def set_client(self, max_token_length: int = None):
-        if not self.cohere_api_key:
-            self.client = AsyncClient("sk-1111")
+        if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
+            self.client = AsyncClientV2("sk-1111")
             log.error("No cohere API key set")
             if self.api_key_status:
                 self.api_key_status = False
@@ -136,7 +148,7 @@ class CohereClient(ClientBase):
         model = self.model_name

-        self.client = AsyncClient(self.cohere_api_key)
+        self.client = AsyncClientV2(self.api_key, base_url=self.base_url)
         self.max_token_length = max_token_length or 16384

         if not self.api_key_status:
@@ -161,6 +173,7 @@ class CohereClient(ClientBase):
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs) self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event): def on_config_saved(self, event):
config = event.data config = event.data
@@ -168,7 +181,7 @@ class CohereClient(ClientBase):
self.set_client(max_token_length=self.max_token_length) self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str): def response_tokens(self, response: str):
return count_tokens(response.text) return count_tokens(response)
def prompt_tokens(self, prompt: str): def prompt_tokens(self, prompt: str):
return count_tokens(prompt) return count_tokens(prompt)
@@ -207,7 +220,7 @@ class CohereClient(ClientBase):
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
""" """
if not self.cohere_api_key: if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No cohere API key set") raise Exception("No cohere API key set")
right = None right = None
@@ -228,20 +241,42 @@ class CohereClient(ClientBase):
system_message=system_message, system_message=system_message,
) )
messages = [
{
"role": "system",
"content": system_message,
},
{
"role": "user",
"content": human_message,
}
]
try: try:
response = await self.client.chat( # Cohere's `chat_stream` returns an **asynchronous generator** that can be
# consumed directly with `async for`. It is not an asynchronous context
# manager, so attempting to use `async with` raises a `TypeError` as seen
# in issue logs. We therefore iterate over the generator directly.
stream = self.client.chat_stream(
model=self.model_name, model=self.model_name,
preamble=system_message, messages=messages,
message=human_message,
**parameters, **parameters,
) )
response = ""
async for event in stream:
if event and event.type == "content-delta":
chunk = event.delta.message.content.text
response += chunk
# Track token usage incrementally
self.update_request_tokens(self.count_tokens(chunk))
self._returned_prompt_tokens = self.prompt_tokens(prompt) self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response) self._returned_response_tokens = self.response_tokens(response)
log.debug("generated response", response=response.text) log.debug("generated response", response=response)
response = response.text
if expected_response and expected_response.startswith("{"): if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"): if response.startswith("```json") and response.endswith("```"):
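
Outside of Talemate, the same consumption pattern against the Cohere v2 SDK looks roughly like this (the model and message are placeholders; the event shape matches what the streaming loop above expects):

from cohere import AsyncClientV2

async def stream_reply(api_key: str, user_message: str) -> str:
    client = AsyncClientV2(api_key)
    text = ""
    # chat_stream returns an async generator, so it is consumed
    # directly with `async for` rather than `async with`.
    async for event in client.chat_stream(
        model="command-r-plus",
        messages=[{"role": "user", "content": user_message}],
    ):
        if event and event.type == "content-delta":
            text += event.delta.message.content.text
    return text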


@@ -187,6 +187,14 @@ class DeepSeekClient(ClientBase):
return prompt return prompt
def response_tokens(self, response: str):
# Count tokens in a response string using the util.count_tokens helper
return self.count_tokens(response)
def prompt_tokens(self, prompt: str):
# Count tokens in a prompt string using the util.count_tokens helper
return self.count_tokens(prompt)
async def generate(self, prompt: str, parameters: dict, kind: str): async def generate(self, prompt: str, parameters: dict, kind: str):
""" """
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
@@ -221,13 +229,30 @@ class DeepSeekClient(ClientBase):
) )
         try:
-            response = await self.client.chat.completions.create(
+            # Use streaming so we can update_request_tokens incrementally
+            stream = await self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[system_message, human_message],
+                stream=True,
                 **parameters,
             )
-            response = response.choices[0].message.content
+            response = ""
+
+            # Iterate over streamed chunks
+            async for chunk in stream:
+                if not chunk.choices:
+                    continue
+                delta = chunk.choices[0].delta
+                if delta and getattr(delta, "content", None):
+                    content_piece = delta.content
+                    response += content_piece
+                    # Incrementally track token usage
+                    self.update_request_tokens(self.count_tokens(content_piece))
+
+            # Save token accounting for whole request
+            self._returned_prompt_tokens = self.prompt_tokens(prompt)
+            self._returned_response_tokens = self.response_tokens(response)

             # older models don't support json_object response coercion
             # and often like to return the response wrapped in ```json


@@ -3,19 +3,18 @@ import os
import pydantic import pydantic
import structlog import structlog
import vertexai from google import genai
from google.api_core.exceptions import ResourceExhausted import google.genai.types as genai_types
from vertexai.generative_models import ( from google.genai.errors import APIError
ChatSession,
GenerationConfig,
GenerativeModel,
ResponseValidationError,
SafetySetting,
)
from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults
from talemate.client.registry import register from talemate.client.registry import register
from talemate.client.remote import RemoteServiceMixin from talemate.client.remote import (
RemoteServiceMixin,
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig from talemate.config import Client as BaseClientConfig
from talemate.config import load_config from talemate.config import load_config
from talemate.emit import emit from talemate.emit import emit
@@ -31,23 +30,29 @@ log = structlog.get_logger("talemate")
SUPPORTED_MODELS = [ SUPPORTED_MODELS = [
"gemini-1.0-pro", "gemini-1.0-pro",
"gemini-1.5-pro-preview-0409", "gemini-1.5-pro-preview-0409",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash-preview-05-20",
"gemini-2.5-pro-preview-03-25", "gemini-2.5-pro-preview-03-25",
"gemini-2.5-pro-preview-06-05",
] ]
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
class Defaults(CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384 max_token_length: int = 16384
model: str = "gemini-1.0-pro" model: str = "gemini-2.0-flash"
disable_safety_settings: bool = False disable_safety_settings: bool = False
double_coercion: str = None
class ClientConfig(EndpointOverride, BaseClientConfig):
class ClientConfig(BaseClientConfig):
disable_safety_settings: bool = False disable_safety_settings: bool = False
@register() @register()
class GoogleClient(RemoteServiceMixin, ClientBase): class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
""" """
Google client for generating text. Google client for generating text.
""" """
@@ -74,19 +79,26 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
description="Disable Google's safety settings for responses generated by the model.", description="Disable Google's safety settings for responses generated by the model.",
), ),
} }
extra_fields.update(endpoint_override_extra_fields())
def __init__(self, model="gemini-1.0-pro", **kwargs):
def __init__(self, model="gemini-2.0-flash", **kwargs):
self.model_name = model self.model_name = model
self.setup_status = None self.setup_status = None
self.model_instance = None self.model_instance = None
self.disable_safety_settings = kwargs.get("disable_safety_settings", False) self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
self.google_credentials_read = False self.google_credentials_read = False
self.google_project_id = None self.google_project_id = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config() self.config = load_config()
super().__init__(**kwargs) super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved) handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
@property @property
def google_credentials(self): def google_credentials(self):
path = self.google_credentials_path path = self.google_credentials_path
@@ -103,15 +115,35 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
def google_location(self): def google_location(self):
return self.config.get("google").get("gcloud_location") return self.config.get("google").get("gcloud_location")
@property
def google_api_key(self):
return self.config.get("google").get("api_key")
@property
def vertexai_ready(self) -> bool:
return all([
self.google_credentials_path,
self.google_location,
])
@property
def developer_api_ready(self) -> bool:
return all([
self.google_api_key,
])
@property
def using(self) -> str:
if self.developer_api_ready:
return "API"
if self.vertexai_ready:
return "VertexAI"
return "Unknown"
@property @property
def ready(self): def ready(self):
# all google settings must be set # all google settings must be set
return all( return self.vertexai_ready or self.developer_api_ready or self.endpoint_override_base_url_configured
[
self.google_credentials_path,
self.google_location,
]
)
@property @property
def safety_settings(self): def safety_settings(self):
@@ -119,30 +151,39 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
return None return None
safety_settings = [ safety_settings = [
SafetySetting( genai_types.SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE, threshold="BLOCK_NONE",
), ),
SafetySetting( genai_types.SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, category="HARM_CATEGORY_DANGEROUS_CONTENT",
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE, threshold="BLOCK_NONE",
), ),
SafetySetting( genai_types.SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT, category="HARM_CATEGORY_HARASSMENT",
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE, threshold="BLOCK_NONE",
), ),
SafetySetting( genai_types.SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH, category="HARM_CATEGORY_HATE_SPEECH",
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE, threshold="BLOCK_NONE",
), ),
SafetySetting( genai_types.SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_UNSPECIFIED, category="HARM_CATEGORY_CIVIC_INTEGRITY",
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE, threshold="BLOCK_NONE",
), ),
] ]
return safety_settings return safety_settings
@property
def http_options(self) -> genai_types.HttpOptions | None:
if not self.endpoint_override_base_url_configured:
return None
return genai_types.HttpOptions(
base_url=self.base_url
)
@property @property
def supported_parameters(self): def supported_parameters(self):
return [ return [
@@ -184,6 +225,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
self.current_status = status self.current_status = status
data = { data = {
"double_coercion": self.double_coercion,
"error_action": error_action.model_dump() if error_action else None, "error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(), "meta": self.Meta().model_dump(),
"enabled": self.enabled, "enabled": self.enabled,
@@ -191,15 +233,27 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
data.update(self._common_status_data()) data.update(self._common_status_data())
self.populate_extra_fields(data) self.populate_extra_fields(data)
if self.using == "VertexAI":
details = f"{model_name} (VertexAI)"
else:
details = model_name
emit( emit(
"client_status", "client_status",
message=self.client_type, message=self.client_type,
id=self.name, id=self.name,
details=model_name, details=details,
status=status if self.enabled else "disabled", status=status if self.enabled else "disabled",
data=data, data=data,
) )
def set_client_base_url(self, base_url: str | None):
if getattr(self, "client", None):
try:
self.client.http_options.base_url = base_url
except Exception as e:
log.error("Error setting client base URL", error=e, client=self.client_type)
def set_client(self, max_token_length: int = None, **kwargs): def set_client(self, max_token_length: int = None, **kwargs):
if not self.ready: if not self.ready:
log.error("Google cloud setup incomplete") log.error("Google cloud setup incomplete")
@@ -210,7 +264,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
return return
if not self.model_name: if not self.model_name:
self.model_name = "gemini-1.0-pro" self.model_name = "gemini-2.0-flash"
if max_token_length and not isinstance(max_token_length, int): if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length) max_token_length = int(max_token_length)
@@ -222,17 +276,14 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
self.max_token_length = max_token_length or 16384 self.max_token_length = max_token_length or 16384
if not self.setup_status: if self.vertexai_ready and not self.developer_api_ready:
if self.setup_status is False: self.client = genai.Client(
project_id = self.google_credentials.get("project_id") vertexai=True,
self.google_project_id = project_id project=self.google_project_id,
if self.google_credentials_path: location=self.google_location,
vertexai.init(project=project_id, location=self.google_location) )
emit("request_client_status") else:
emit("request_agent_status") self.client = genai.Client(api_key=self.api_key or None, http_options=self.http_options)
self.setup_status = True
self.model_instance = GenerativeModel(model_name=model)
log.info( log.info(
"google set client", "google set client",
@@ -241,8 +292,9 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
model=model, model=model,
) )
def response_tokens(self, response: str): def response_tokens(self, response: str):
return count_tokens(response.text) """Return token count for a response which may be a string or SDK object."""
return count_tokens(response)
def prompt_tokens(self, prompt: str): def prompt_tokens(self, prompt: str):
return count_tokens(prompt) return count_tokens(prompt)
@@ -258,6 +310,9 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
if "enabled" in kwargs: if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs) self._reconfigure_common_parameters(**kwargs)
def clean_prompt_parameters(self, parameters: dict): def clean_prompt_parameters(self, parameters: dict):
@@ -267,27 +322,53 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
if "top_k" in parameters and parameters["top_k"] == 0: if "top_k" in parameters and parameters["top_k"] == 0:
del parameters["top_k"] del parameters["top_k"]
def prompt_template(self, system_message: str, prompt: str):
"""
Google handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str): async def generate(self, prompt: str, parameters: dict, kind: str):
""" """
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
""" """
if not self.ready: if not self.ready:
raise Exception("Google cloud setup incomplete") raise Exception("Google setup incomplete")
right = None prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
human_message = prompt.strip() human_message = prompt.strip()
system_message = self.get_system_message(kind) system_message = self.get_system_message(kind)
contents = [
genai_types.Content(
role="user",
parts=[
genai_types.Part.from_text(
text=human_message,
)
]
)
]
if coercion_prompt:
contents.append(
genai_types.Content(
role="model",
parts=[
genai_types.Part.from_text(
text=coercion_prompt,
)
]
)
)
self.log.debug( self.log.debug(
"generate", "generate",
base_url=self.base_url,
prompt=prompt[:128] + " ...", prompt=prompt[:128] + " ...",
parameters=parameters, parameters=parameters,
system_message=system_message, system_message=system_message,
@@ -296,48 +377,53 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
) )
try: try:
# Use streaming so we can update_request_tokens incrementally
#stream = await chat.send_message_async(
# human_message,
# safety_settings=self.safety_settings,
# generation_config=parameters,
# stream=True
#)
chat = self.model_instance.start_chat()
response = await chat.send_message_async( stream = await self.client.aio.models.generate_content_stream(
human_message, model=self.model_name,
safety_settings=self.safety_settings, contents=contents,
generation_config=parameters, config=genai_types.GenerateContentConfig(
safety_settings=self.safety_settings,
http_options=self.http_options,
**parameters
),
) )
response = ""
async for chunk in stream:
# For each streamed chunk, append content and update token counts
content_piece = getattr(chunk, "text", None)
if not content_piece:
# Some SDK versions wrap text under candidates[0].text
try:
content_piece = chunk.candidates[0].text # type: ignore
except Exception:
content_piece = None
if content_piece:
response += content_piece
# Incrementally update token usage
self.update_request_tokens(count_tokens(content_piece))
# Store total token accounting for prompt/response
self._returned_prompt_tokens = self.prompt_tokens(prompt) self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response) self._returned_response_tokens = self.response_tokens(response)
response = response.text
log.debug("generated response", response=response) log.debug("generated response", response=response)
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response return response
# except PermissionDeniedError as e: except APIError as e:
# self.log.error("generate error", e=e)
# emit("status", message="google API: Permission Denied", status="error")
# return ""
except ResourceExhausted as e:
self.log.error("generate error", e=e) self.log.error("generate error", e=e)
emit("status", message="google API: Quota Limit reached", status="error") emit("status", message="google API: API Error", status="error")
return "" return ""
except ResponseValidationError as e:
self.log.error("generate error", e=e)
emit(
"status",
message="google API: Response Validation Error",
status="error",
)
if not self.disable_safety_settings:
return "Failed to generate response. Probably due to safety settings, you can turn them off in the client settings."
return "Failed to generate response. Please check logs."
except Exception as e: except Exception as e:
raise raise
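
For reference, a rough standalone sketch of the new google-genai streaming call this diff switches to. The model name and prompt are placeholders; talemate itself routes each chunk through `count_tokens` / `update_request_tokens` rather than just concatenating.

```python
# Hedged sketch only: mirrors the generate_content_stream usage in the diff,
# with a placeholder model name and prompt.
from google import genai
from google.genai import types as genai_types

async def stream_example(api_key: str) -> str:
    client = genai.Client(api_key=api_key)
    contents = [
        genai_types.Content(
            role="user",
            parts=[genai_types.Part.from_text(text="Write one sentence about rain.")],
        )
    ]
    stream = await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=contents,
        config=genai_types.GenerateContentConfig(temperature=0.7),
    )
    response = ""
    async for chunk in stream:
        piece = getattr(chunk, "text", None)
        if piece:
            response += piece  # talemate also counts tokens per chunk here
    return response
```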

View File

@@ -2,11 +2,16 @@ import pydantic
import structlog import structlog
from groq import AsyncGroq, PermissionDeniedError from groq import AsyncGroq, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, ExtraField
from talemate.client.registry import register from talemate.client.registry import register
from talemate.config import load_config from talemate.config import load_config
from talemate.emit import emit from talemate.emit import emit
from talemate.emit.signals import handlers from talemate.emit.signals import handlers
from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
__all__ = [ __all__ = [
"GroqClient", "GroqClient",
@@ -23,13 +28,13 @@ SUPPORTED_MODELS = [
JSON_OBJECT_RESPONSE_MODELS = [] JSON_OBJECT_RESPONSE_MODELS = []
class Defaults(pydantic.BaseModel): class Defaults(EndpointOverride, pydantic.BaseModel):
max_token_length: int = 8192 max_token_length: int = 8192
model: str = "llama3-70b-8192" model: str = "llama3-70b-8192"
@register() @register()
class GroqClient(ClientBase): class GroqClient(EndpointOverrideMixin, ClientBase):
""" """
Groq client for generating text. Groq client for generating text.
""" """
@@ -47,10 +52,13 @@ class GroqClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False requires_prompt_template: bool = False
defaults: Defaults = Defaults() defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="llama3-70b-8192", **kwargs): def __init__(self, model="llama3-70b-8192", **kwargs):
self.model_name = model self.model_name = model
self.api_key_status = None self.api_key_status = None
# Apply any endpoint override parameters provided via kwargs before creating client
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config() self.config = load_config()
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -100,21 +108,27 @@ class GroqClient(ClientBase):
self.current_status = status self.current_status = status
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
# Include shared/common status data (rate limit, etc.)
data.update(self._common_status_data())
emit( emit(
"client_status", "client_status",
message=self.client_type, message=self.client_type,
id=self.name, id=self.name,
details=model_name, details=model_name,
status=status if self.enabled else "disabled", status=status if self.enabled else "disabled",
data={ data=data,
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
},
) )
def set_client(self, max_token_length: int = None): def set_client(self, max_token_length: int = None):
if not self.groq_api_key: # Determine if we should use the globally configured API key or the override key
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
# No API key and no endpoint override, so the client cannot be initialized correctly
self.client = AsyncGroq(api_key="sk-1111") self.client = AsyncGroq(api_key="sk-1111")
log.error("No groq.ai API key set") log.error("No groq.ai API key set")
if self.api_key_status: if self.api_key_status:
@@ -131,7 +145,8 @@ class GroqClient(ClientBase):
model = self.model_name model = self.model_name
self.client = AsyncGroq(api_key=self.groq_api_key) # Use the override values (if any) when constructing the Groq client
self.client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)
self.max_token_length = max_token_length or 16384 self.max_token_length = max_token_length or 16384
if not self.api_key_status: if not self.api_key_status:
@@ -155,6 +170,11 @@ class GroqClient(ClientBase):
if "enabled" in kwargs: if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
# Allow dynamic reconfiguration of endpoint override parameters
self._reconfigure_endpoint_override(**kwargs)
# Reconfigure any common parameters (rate limit, data format, etc.)
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event): def on_config_saved(self, event):
config = event.data config = event.data
self.config = config self.config = config
@@ -184,7 +204,7 @@ class GroqClient(ClientBase):
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
""" """
if not self.groq_api_key: if not self.groq_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No groq.ai API key set") raise Exception("No groq.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
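
A minimal sketch (not the actual `set_client` body) of the override fallback introduced here: when an override base URL is configured, the global Groq key is deliberately ignored and the override key plus URL are used instead.

```python
from groq import AsyncGroq

def build_groq_client(global_key: str | None,
                      override_base_url: str | None,
                      override_api_key: str | None) -> AsyncGroq:
    if override_base_url and override_base_url.strip():
        # Trusted custom endpoint: use its key (if any), never the global one.
        return AsyncGroq(api_key=override_api_key or "sk-1111",
                         base_url=override_base_url)
    # Default path: official Groq endpoint with the globally configured key.
    return AsyncGroq(api_key=global_key or "sk-1111")
```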

View File

@@ -1,6 +1,10 @@
import random import random
import re import json
import sseclient
import asyncio
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import requests
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
# import urljoin # import urljoin
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
@@ -10,12 +14,14 @@ import structlog
import talemate.util as util import talemate.util as util
from talemate.client.base import ( from talemate.client.base import (
STOPPING_STRINGS,
ClientBase, ClientBase,
Defaults, Defaults,
ParameterReroute, ParameterReroute,
ClientEmbeddingsStatus
) )
from talemate.client.registry import register from talemate.client.registry import register
import talemate.emit.async_signals as async_signals
if TYPE_CHECKING: if TYPE_CHECKING:
from talemate.agents.visual import VisualBase from talemate.agents.visual import VisualBase
@@ -28,6 +34,37 @@ class KoboldCppClientDefaults(Defaults):
api_key: str = "" api_key: str = ""
class KoboldEmbeddingFunction(EmbeddingFunction):
def __init__(self, api_url: str, model_name: str = None):
"""
Initialize the embedding function with the KoboldCPP API endpoint.
"""
self.api_url = api_url
self.model_name = model_name
def __call__(self, texts: Documents) -> Embeddings:
"""
Embed a list of input texts using the KoboldCPP embeddings endpoint.
"""
log.debug("KoboldCppEmbeddingFunction", api_url=self.api_url, model_name=self.model_name)
# Prepare the request payload for KoboldCPP. Include model name if required.
payload = {"input": texts}
if self.model_name is not None:
payload["model"] = self.model_name # e.g. the model's name/ID if needed
# Send POST request to the local KoboldCPP embeddings endpoint
response = requests.post(self.api_url, json=payload)
response.raise_for_status() # Throw an error if the request failed (e.g., connection issue)
# Parse the JSON response to extract embedding vectors
data = response.json()
# The 'data' field contains a list of embeddings (one per input)
embedding_results = data.get("data", [])
embeddings = [item["embedding"] for item in embedding_results]
return embeddings
@register() @register()
class KoboldCppClient(ClientBase): class KoboldCppClient(ClientBase):
auto_determine_prompt_template: bool = True auto_determine_prompt_template: bool = True
@@ -58,7 +95,7 @@ class KoboldCppClient(ClientBase):
kcpp has two apis kcpp has two apis
open-ai implementation at /v1 open-ai implementation at /v1
their own implenation at /api/v1 their own implementation at /api/v1
""" """
return "/api/v1" not in self.api_url return "/api/v1" not in self.api_url
@@ -77,8 +114,8 @@ class KoboldCppClient(ClientBase):
# join /v1/completions # join /v1/completions
return urljoin(self.api_url, "completions") return urljoin(self.api_url, "completions")
else: else:
# join /api/v1/generate # join /api/extra/generate/stream
return urljoin(self.api_url, "generate") return urljoin(self.api_url.replace("v1", "extra"), "generate/stream")
@property @property
def max_tokens_param_name(self): def max_tokens_param_name(self):
@@ -132,6 +169,21 @@ class KoboldCppClient(ClientBase):
"temperature", "temperature",
] ]
@property
def supports_embeddings(self) -> bool:
return True
@property
def embeddings_url(self) -> str:
if self.is_openai:
return urljoin(self.api_url, "embeddings")
else:
return urljoin(self.api_url, "api/extra/embeddings")
@property
def embeddings_function(self):
return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)
def api_endpoint_specified(self, url: str) -> bool: def api_endpoint_specified(self, url: str) -> bool:
return "/v1" in self.api_url return "/v1" in self.api_url
@@ -152,15 +204,63 @@ class KoboldCppClient(ClientBase):
self.api_key = kwargs.get("api_key", self.api_key) self.api_key = kwargs.get("api_key", self.api_key)
self.ensure_api_endpoint_specified() self.ensure_api_endpoint_specified()
async def get_model_name(self): async def get_embeddings_model_name(self):
self.ensure_api_endpoint_specified() # if self._embeddings_model_name is set, return it
if self.embeddings_model_name:
return self.embeddings_model_name
# otherwise, get the model name by doing a request to
# the embeddings endpoint with a single character
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.post(
self.api_url_for_model, self.embeddings_url,
json={"input": ["test"]},
timeout=2, timeout=2,
headers=self.request_headers, headers=self.request_headers,
) )
response_data = response.json()
self._embeddings_model_name = response_data.get("model")
return self._embeddings_model_name
async def get_embeddings_status(self):
url_version = urljoin(self.api_url, "api/extra/version")
async with httpx.AsyncClient() as client:
response = await client.get(url_version, timeout=2)
response_data = response.json()
self._embeddings_status = response_data.get("embeddings", False)
if not self.embeddings_status or self.embeddings_model_name:
return
await self.get_embeddings_model_name()
log.debug("KoboldCpp embeddings are enabled, suggesting embeddings", model_name=self.embeddings_model_name)
self.set_embeddings()
await async_signals.get("client.embeddings_available").send(
ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)
)
async def get_model_name(self):
self.ensure_api_endpoint_specified()
try:
async with httpx.AsyncClient() as client:
response = await client.get(
self.api_url_for_model,
timeout=2,
headers=self.request_headers,
)
except Exception:
self._embeddings_model_name = None
raise
if response.status_code == 404: if response.status_code == 404:
raise KeyError(f"Could not find model info at: {self.api_url_for_model}") raise KeyError(f"Could not find model info at: {self.api_url_for_model}")
@@ -176,6 +276,8 @@ class KoboldCppClient(ClientBase):
if model_name: if model_name:
model_name = model_name.split("/")[-1] model_name = model_name.split("/")[-1]
await self.get_embeddings_status()
return model_name return model_name
async def tokencount(self, content: str) -> int: async def tokencount(self, content: str) -> int:
@@ -228,6 +330,43 @@ class KoboldCppClient(ClientBase):
""" """
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
""" """
if self.is_openai:
return await self._generate_openai(prompt, parameters, kind)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._generate_kcpp_stream, prompt, parameters, kind)
def _generate_kcpp_stream(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
parameters["prompt"] = prompt.strip(" ")
response = ""
parameters["stream"] = True
stream_response = requests.post(
self.api_url_for_generation,
json=parameters,
timeout=None,
headers=self.request_headers,
stream=True,
)
stream_response.raise_for_status()
sse = sseclient.SSEClient(stream_response)
for event in sse.events():
payload = json.loads(event.data)
chunk = payload['token']
response += chunk
self.update_request_tokens(self.count_tokens(chunk))
return response
async def _generate_openai(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
parameters["prompt"] = prompt.strip(" ") parameters["prompt"] = prompt.strip(" ")

View File

@@ -54,18 +54,55 @@ class LMStudioClient(ClientBase):
async def generate(self, prompt: str, parameters: dict, kind: str): async def generate(self, prompt: str, parameters: dict, kind: str):
""" """
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters using a streaming
request so that token usage can be tracked incrementally via
`update_request_tokens`.
""" """
human_message = {"role": "user", "content": prompt.strip()}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters) self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
)
try: try:
response = await self.client.chat.completions.create( # Send the request in streaming mode so we can update token counts
model=self.model_name, messages=[human_message], **parameters stream = await self.client.completions.create(
model=self.model_name,
prompt=prompt,
stream=True,
**parameters,
) )
return response.choices[0].message.content response = ""
# Iterate over streamed chunks and accumulate the response while
# incrementally updating the token counter
async for chunk in stream:
if not chunk.choices:
continue
content_piece = chunk.choices[0].text
response += content_piece
# Track token usage incrementally
self.update_request_tokens(self.count_tokens(content_piece))
# Store overall token accounting once the stream is finished
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
return response
except Exception as e: except Exception as e:
self.log.error("generate error", e=e) self.log.error("generate error", e=e)
return "" return ""
# ------------------------------------------------------------------
# Token helpers
# ------------------------------------------------------------------
def response_tokens(self, response: str):
"""Count tokens in a model response string."""
return self.count_tokens(response)
def prompt_tokens(self, prompt: str):
"""Count tokens in a prompt string."""
return self.count_tokens(prompt)
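
The same streaming pattern, sketched as a standalone call against a local LM Studio server (OpenAI-compatible API). The base URL, API key, and model name are assumptions for illustration.

```python
from openai import AsyncOpenAI

async def lmstudio_stream(prompt: str) -> str:
    client = AsyncOpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
    stream = await client.completions.create(
        model="local-model",  # LM Studio serves whatever model is currently loaded
        prompt=prompt,
        stream=True,
        max_tokens=128,
    )
    text = ""
    async for chunk in stream:
        if chunk.choices:
            text += chunk.choices[0].text or ""
    return text
```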

View File

@@ -4,9 +4,14 @@ from typing import Literal
from mistralai import Mistral from mistralai import Mistral
from mistralai.models.sdkerror import SDKError from mistralai.models.sdkerror import SDKError
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.registry import register from talemate.client.registry import register
from talemate.config import load_config from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.emit import emit from talemate.emit import emit
from talemate.emit.signals import handlers from talemate.emit.signals import handlers
@@ -35,14 +40,15 @@ JSON_OBJECT_RESPONSE_MODELS = [
] ]
class Defaults(CommonDefaults, pydantic.BaseModel): class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384 max_token_length: int = 16384
model: str = "open-mixtral-8x22b" model: str = "open-mixtral-8x22b"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register() @register()
class MistralAIClient(ClientBase): class MistralAIClient(EndpointOverrideMixin, ClientBase):
""" """
MistralAI client for generating text. MistralAI client for generating text.
""" """
@@ -52,6 +58,7 @@ class MistralAIClient(ClientBase):
auto_break_repetition_enabled = False auto_break_repetition_enabled = False
# TODO: make this configurable? # TODO: make this configurable?
decensor_enabled = True decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta): class Meta(ClientBase.Meta):
name_prefix: str = "MistralAI" name_prefix: str = "MistralAI"
@@ -60,16 +67,18 @@ class MistralAIClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False requires_prompt_template: bool = False
defaults: Defaults = Defaults() defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="open-mixtral-8x22b", **kwargs): def __init__(self, model="open-mixtral-8x22b", **kwargs):
self.model_name = model self.model_name = model
self.api_key_status = None self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config() self.config = load_config()
super().__init__(**kwargs) super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved) handlers["config_saved"].connect(self.on_config_saved)
@property @property
def mistralai_api_key(self): def mistral_api_key(self):
return self.config.get("mistralai", {}).get("api_key") return self.config.get("mistralai", {}).get("api_key")
@property @property
@@ -85,7 +94,7 @@ class MistralAIClient(ClientBase):
if processing is not None: if processing is not None:
self.processing = processing self.processing = processing
if self.mistralai_api_key: if self.mistral_api_key:
status = "busy" if self.processing else "idle" status = "busy" if self.processing else "idle"
model_name = self.model_name model_name = self.model_name
else: else:
@@ -122,7 +131,7 @@ class MistralAIClient(ClientBase):
) )
def set_client(self, max_token_length: int = None): def set_client(self, max_token_length: int = None):
if not self.mistralai_api_key: if not self.mistral_api_key and not self.endpoint_override_base_url_configured:
self.client = Mistral(api_key="sk-1111") self.client = Mistral(api_key="sk-1111")
log.error("No mistral.ai API key set") log.error("No mistral.ai API key set")
if self.api_key_status: if self.api_key_status:
@@ -139,7 +148,7 @@ class MistralAIClient(ClientBase):
model = self.model_name model = self.model_name
self.client = Mistral(api_key=self.mistralai_api_key) self.client = Mistral(api_key=self.api_key, server_url=self.base_url)
self.max_token_length = max_token_length or 16384 self.max_token_length = max_token_length or 16384
if not self.api_key_status: if not self.api_key_status:
@@ -160,6 +169,7 @@ class MistralAIClient(ClientBase):
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs) self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
if kwargs.get("model"): if kwargs.get("model"):
self.model_name = kwargs["model"] self.model_name = kwargs["model"]
@@ -201,7 +211,7 @@ class MistralAIClient(ClientBase):
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
""" """
if not self.mistralai_api_key: if not self.mistral_api_key:
raise Exception("No mistral.ai API key set") raise Exception("No mistral.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
@@ -224,22 +234,36 @@ class MistralAIClient(ClientBase):
self.log.debug( self.log.debug(
"generate", "generate",
base_url=self.base_url,
prompt=prompt[:128] + " ...", prompt=prompt[:128] + " ...",
parameters=parameters, parameters=parameters,
system_message=system_message, system_message=system_message,
) )
try: try:
response = await self.client.chat.complete_async( event_stream = await self.client.chat.stream_async(
model=self.model_name, model=self.model_name,
messages=messages, messages=messages,
**parameters, **parameters,
) )
self._returned_prompt_tokens = self.prompt_tokens(response) response = ""
self._returned_response_tokens = self.response_tokens(response)
response = response.choices[0].message.content completion_tokens = 0
prompt_tokens = 0
async for event in event_stream:
if event.data.choices:
response += event.data.choices[0].delta.content
self.update_request_tokens(self.count_tokens(event.data.choices[0].delta.content))
if event.data.usage:
completion_tokens += event.data.usage.completion_tokens
prompt_tokens += event.data.usage.prompt_tokens
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
#response = response.choices[0].message.content
# older models don't support json_object response coercion # older models don't support json_object response coercion
# and often like to return the response wrapped in ```json # and often like to return the response wrapped in ```json
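
Standalone sketch of the `stream_async` accumulation used above, outside the client class. It assumes `MISTRAL_API_KEY` is set in the environment; the model name is a placeholder.

```python
import os
from mistralai import Mistral

async def mistral_stream(prompt: str) -> tuple[str, int]:
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
    event_stream = await client.chat.stream_async(
        model="open-mixtral-8x22b",
        messages=[{"role": "user", "content": prompt}],
    )
    text, completion_tokens = "", 0
    async for event in event_stream:
        if event.data.choices:
            text += event.data.choices[0].delta.content or ""
        if event.data.usage:
            completion_tokens += event.data.usage.completion_tokens
    return text, completion_tokens
```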

View File

@@ -129,6 +129,12 @@ class ModelPrompt:
return prompt return prompt
def clean_model_name(self, model_name: str):
"""
Clean the model name to be used in the template file name.
"""
return model_name.replace("/", "__").replace(":", "_")
def get_template(self, model_name: str): def get_template(self, model_name: str):
""" """
Will attempt to load an LLM prompt template - this supports Will attempt to load an LLM prompt template - this supports
@@ -137,7 +143,7 @@ class ModelPrompt:
matches = [] matches = []
cleaned_model_name = model_name.replace("/", "__") cleaned_model_name = self.clean_model_name(model_name)
# Iterate over all templates in the loader's directory # Iterate over all templates in the loader's directory
for template_name in self.env.list_templates(): for template_name in self.env.list_templates():
@@ -166,7 +172,7 @@ class ModelPrompt:
template_name = template_name.split(".jinja2")[0] template_name = template_name.split(".jinja2")[0]
cleaned_model_name = model_name.replace("/", "__") cleaned_model_name = self.clean_model_name(model_name)
shutil.copyfile( shutil.copyfile(
os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"), os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
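
What `clean_model_name` produces for a couple of representative model names, shown as the plain string transform (slashes become `__`, colons become `_`) so template filenames stay valid:

```python
name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
print(name.replace("/", "__").replace(":", "_"))
# mistralai__Mixtral-8x7B-Instruct-v0.1

name = "llama3:8b-instruct"  # ollama-style tag
print(name.replace("/", "__").replace(":", "_"))
# llama3_8b-instruct
```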

View File

@@ -0,0 +1,313 @@
import asyncio
import structlog
import httpx
import ollama
import time
from typing import Union
from talemate.client.base import STOPPING_STRINGS, ClientBase, CommonDefaults, ErrorAction, ParameterReroute, ExtraField
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
log = structlog.get_logger("talemate.client.ollama")
FETCH_MODELS_INTERVAL = 15
class OllamaClientDefaults(CommonDefaults):
api_url: str = "http://localhost:11434" # Default Ollama URL
model: str = "" # Allow empty default, will fetch from Ollama
api_handles_prompt_template: bool = False
allow_thinking: bool = False
class ClientConfig(BaseClientConfig):
api_handles_prompt_template: bool = False
allow_thinking: bool = False
@register()
class OllamaClient(ClientBase):
"""
Ollama client for generating text using locally hosted models.
"""
auto_determine_prompt_template: bool = True
client_type = "ollama"
conversation_retries = 0
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Ollama"
title: str = "Ollama"
enable_api_auth: bool = False
manual_model: bool = True
manual_model_choices: list[str] = [] # Will be overridden by finalize_status
defaults: OllamaClientDefaults = OllamaClientDefaults()
extra_fields: dict[str, ExtraField] = {
"api_handles_prompt_template": ExtraField(
name="api_handles_prompt_template",
type="bool",
label="API handles prompt template",
required=False,
description="Let Ollama handle the prompt template. Only do this if you don't know which prompt template to use. Letting talemate handle the prompt template will generally lead to improved responses.",
),
"allow_thinking": ExtraField(
name="allow_thinking",
type="bool",
label="Allow thinking",
required=False,
description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
)
}
@property
def supported_parameters(self):
# Parameters supported by Ollama's generate endpoint
# Based on the API documentation
return [
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
ParameterReroute(
talemate_parameter="repetition_penalty",
client_parameter="repeat_penalty"
),
ParameterReroute(
talemate_parameter="max_tokens",
client_parameter="num_predict"
),
"stopping_strings",
# internal parameters that will be removed before sending
"extra_stopping_strings",
]
@property
def can_be_coerced(self):
"""
Determines whether or not this client can use LLM coercion (e.g., is able
to predefine partial LLM output in the prompt).
"""
return not self.api_handles_prompt_template
@property
def can_think(self) -> bool:
"""
Allow reasoning models to think before responding.
"""
return self.allow_thinking
def __init__(self, model=None, api_handles_prompt_template=False, allow_thinking=False, **kwargs):
self.model_name = model
self.api_handles_prompt_template = api_handles_prompt_template
self.allow_thinking = allow_thinking
self._available_models = []
self._models_last_fetched = 0
self.client = None
super().__init__(**kwargs)
def set_client(self, **kwargs):
"""
Initialize the Ollama client with the API URL.
"""
# Update model if provided
if kwargs.get("model"):
self.model_name = kwargs["model"]
# Create async client with the configured API URL
# Ollama's AsyncClient expects just the base URL without any path
self.client = ollama.AsyncClient(host=self.api_url)
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
self.allow_thinking = kwargs.get(
"allow_thinking", self.allow_thinking
)
async def status(self):
"""
Check that the Ollama server is reachable and refresh the list of
available models before emitting the client status.
:return: None
"""
if self.processing:
self.emit_status()
return
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
# instead of using the client (which apparently cannot set a timeout per endpoint)
# we use httpx to check {api_url}/api/version to see if the server is running
# use a timeout of 2 seconds
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url}/api/version", timeout=2)
response.raise_for_status()
# if the server is running, fetch the available models
await self.fetch_available_models()
except Exception as e:
log.error("Failed to fetch models from Ollama", error=str(e))
self.connected = False
self.emit_status()
return
await super().status()
async def fetch_available_models(self):
"""
Fetch list of available models from Ollama.
"""
if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
return self._available_models
response = await self.client.list()
models = response.get("models", [])
model_names = [model.model for model in models]
self._available_models = sorted(model_names)
self._models_last_fetched = time.time()
return self._available_models
def finalize_status(self, data: dict):
"""
Finalizes the status data for the client.
"""
data["manual_model_choices"] = self._available_models
return data
async def get_model_name(self):
return self.model_name
def prompt_template(self, system_message: str, prompt: str):
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
"""
Tune parameters for Ollama's generate endpoint.
"""
super().tune_prompt_parameters(parameters, kind)
# Build stopping strings list
parameters["stop"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
# Ollama uses num_predict instead of max_tokens
if "max_tokens" in parameters:
parameters["num_predict"] = parameters["max_tokens"]
def clean_prompt_parameters(self, parameters: dict):
"""
Clean and prepare parameters for Ollama API.
"""
# First let parent class handle parameter reroutes and cleanup
super().clean_prompt_parameters(parameters)
# Remove our internal parameters
if "extra_stopping_strings" in parameters:
del parameters["extra_stopping_strings"]
if "stopping_strings" in parameters:
del parameters["stopping_strings"]
if "stream" in parameters:
del parameters["stream"]
# Remove max_tokens as we've already converted it to num_predict
if "max_tokens" in parameters:
del parameters["max_tokens"]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generate text using Ollama's generate endpoint.
"""
if not self.model_name:
# Try to get a model name
await self.get_model_name()
if not self.model_name:
raise Exception("No model specified or available in Ollama")
# Prepare options for Ollama
options = parameters
options["num_ctx"] = self.max_token_length
try:
# Use generate endpoint for completion
stream = await self.client.generate(
model=self.model_name,
prompt=prompt.strip(),
options=options,
raw=self.can_be_coerced,
think=self.can_think,
stream=True,
)
response = ""
async for part in stream:
content = part.response
response += content
self.update_request_tokens(self.count_tokens(content))
# Extract the response text
return response
except Exception as e:
log.error("Ollama generation error", error=str(e), model=self.model_name)
raise Exception(f"Ollama generation failed: {str(e)}")
async def abort_generation(self):
"""
Ollama doesn't have a direct abort endpoint, but we can try to stop the model.
"""
# This is a no-op for now as Ollama doesn't expose an abort endpoint
# in the Python client
pass
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
Adjusts temperature and repetition_penalty by random values.
"""
import random
temp = prompt_config["temperature"]
rep_pen = prompt_config.get("repetition_penalty", 1.0)
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
prompt_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)
def reconfigure(self, **kwargs):
"""
Reconfigure the client with new settings.
"""
# Handle model update
if kwargs.get("model"):
self.model_name = kwargs["model"]
super().reconfigure(**kwargs)
# Re-initialize client if API URL changed or model changed
if "api_url" in kwargs or "model" in kwargs:
self.set_client(**kwargs)
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
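
Hedged sketch of the raw streaming generate the new client uses; the host, model name, and option values are placeholders, and `num_ctx` mirrors how the client passes `max_token_length`.

```python
import ollama

async def ollama_stream(prompt: str) -> str:
    client = ollama.AsyncClient(host="http://localhost:11434")
    stream = await client.generate(
        model="llama3",
        prompt=prompt,
        options={"num_ctx": 8192, "temperature": 0.7},
        raw=True,      # talemate applies its own prompt template in this mode
        stream=True,
    )
    text = ""
    async for part in stream:
        text += part.response  # the client also counts tokens per chunk here
    return text
```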

View File

@@ -5,9 +5,14 @@ import structlog
import tiktoken import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults from talemate.client.base import ClientBase, ErrorAction, CommonDefaults, ExtraField
from talemate.client.registry import register from talemate.client.registry import register
from talemate.config import load_config from talemate.client.remote import (
EndpointOverride,
EndpointOverrideMixin,
endpoint_override_extra_fields,
)
from talemate.config import Client as BaseClientConfig, load_config
from talemate.emit import emit from talemate.emit import emit
from talemate.emit.signals import handlers from talemate.emit.signals import handlers
@@ -79,9 +84,6 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
elif "gpt-3.5-turbo" in model: elif "gpt-3.5-turbo" in model:
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model or "o1" in model or "o3" in model: elif "gpt-4" in model or "o1" in model or "o3" in model:
print(
"Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
)
return num_tokens_from_messages(messages, model="gpt-4-0613") return num_tokens_from_messages(messages, model="gpt-4-0613")
else: else:
raise NotImplementedError( raise NotImplementedError(
@@ -102,13 +104,15 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
return num_tokens return num_tokens
class Defaults(CommonDefaults, pydantic.BaseModel): class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384 max_token_length: int = 16384
model: str = "gpt-4o" model: str = "gpt-4o"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register() @register()
class OpenAIClient(ClientBase): class OpenAIClient(EndpointOverrideMixin, ClientBase):
""" """
OpenAI client for generating text. OpenAI client for generating text.
""" """
@@ -118,6 +122,7 @@ class OpenAIClient(ClientBase):
auto_break_repetition_enabled = False auto_break_repetition_enabled = False
# TODO: make this configurable? # TODO: make this configurable?
decensor_enabled = False decensor_enabled = False
config_cls = ClientConfig
class Meta(ClientBase.Meta): class Meta(ClientBase.Meta):
name_prefix: str = "OpenAI" name_prefix: str = "OpenAI"
@@ -126,10 +131,11 @@ class OpenAIClient(ClientBase):
manual_model_choices: list[str] = SUPPORTED_MODELS manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False requires_prompt_template: bool = False
defaults: Defaults = Defaults() defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="gpt-4o", **kwargs): def __init__(self, model="gpt-4o", **kwargs):
self.model_name = model self.model_name = model
self.api_key_status = None self.api_key_status = None
self._reconfigure_endpoint_override(**kwargs)
self.config = load_config() self.config = load_config()
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -192,7 +198,7 @@ class OpenAIClient(ClientBase):
) )
def set_client(self, max_token_length: int = None): def set_client(self, max_token_length: int = None):
if not self.openai_api_key: if not self.openai_api_key and not self.endpoint_override_base_url_configured:
self.client = AsyncOpenAI(api_key="sk-1111") self.client = AsyncOpenAI(api_key="sk-1111")
log.error("No OpenAI API key set") log.error("No OpenAI API key set")
if self.api_key_status: if self.api_key_status:
@@ -209,7 +215,7 @@ class OpenAIClient(ClientBase):
model = self.model_name model = self.model_name
self.client = AsyncOpenAI(api_key=self.openai_api_key) self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
if model == "gpt-3.5-turbo": if model == "gpt-3.5-turbo":
self.max_token_length = min(max_token_length or 4096, 4096) self.max_token_length = min(max_token_length or 4096, 4096)
elif model == "gpt-4": elif model == "gpt-4":
@@ -247,6 +253,7 @@ class OpenAIClient(ClientBase):
self.enabled = bool(kwargs["enabled"]) self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs) self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
def on_config_saved(self, event): def on_config_saved(self, event):
config = event.data config = event.data
@@ -278,7 +285,7 @@ class OpenAIClient(ClientBase):
Generates text from the given prompt and parameters. Generates text from the given prompt and parameters.
""" """
if not self.openai_api_key: if not self.openai_api_key and not self.endpoint_override_base_url_configured:
raise Exception("No OpenAI API key set") raise Exception("No OpenAI API key set")
# only gpt-4-* supports enforcing json object # only gpt-4-* supports enforcing json object
@@ -333,13 +340,28 @@ class OpenAIClient(ClientBase):
) )
try: try:
response = await self.client.chat.completions.create( stream = await self.client.chat.completions.create(
model=self.model_name, model=self.model_name,
messages=messages, messages=messages,
stream=True,
**parameters, **parameters,
) )
response = response.choices[0].message.content response = ""
# Iterate over streamed chunks
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
if delta and getattr(delta, "content", None):
content_piece = delta.content
response += content_piece
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))
#self._returned_prompt_tokens = self.prompt_tokens(prompt)
#self._returned_response_tokens = self.response_tokens(response)
# older models don't support json_object response coercion # older models don't support json_object response coercion
# and often like to return the response wrapped in ```json # and often like to return the response wrapped in ```json
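
The delta-accumulation loop above, sketched as a standalone call; the model name is a placeholder and the key is read from the environment.

```python
import os
from openai import AsyncOpenAI

async def openai_stream(prompt: str) -> str:
    client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    stream = await client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    text = ""
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            text += chunk.choices[0].delta.content
    return text
```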

View File

@@ -0,0 +1,329 @@
import pydantic
import structlog
import httpx
import asyncio
import json
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"OpenRouterClient",
]
log = structlog.get_logger("talemate.client.openrouter")
# Available models will be populated when first client with API key is initialized
AVAILABLE_MODELS = []
DEFAULT_MODEL = ""
MODELS_FETCHED = False
async def fetch_available_models(api_key: str = None):
"""Fetch available models from OpenRouter API"""
global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED
if not api_key:
return []
if MODELS_FETCHED:
return AVAILABLE_MODELS
# Only fetch if we haven't already or if explicitly requested
if AVAILABLE_MODELS and not api_key:
return AVAILABLE_MODELS
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://openrouter.ai/api/v1/models",
timeout=10.0
)
if response.status_code == 200:
data = response.json()
models = []
for model in data.get("data", []):
model_id = model.get("id")
if model_id:
models.append(model_id)
AVAILABLE_MODELS = sorted(models)
log.debug(f"Fetched {len(AVAILABLE_MODELS)} models from OpenRouter")
else:
log.warning(f"Failed to fetch models from OpenRouter: {response.status_code}")
except Exception as e:
log.error(f"Error fetching models from OpenRouter: {e}")
MODELS_FETCHED = True
return AVAILABLE_MODELS
def fetch_models_sync(event):
api_key = event.data.get("openrouter", {}).get("api_key")
loop = asyncio.get_event_loop()
loop.run_until_complete(fetch_available_models(api_key))
handlers["config_saved"].connect(fetch_models_sync)
handlers["talemate_started"].connect(fetch_models_sync)
class Defaults(CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = DEFAULT_MODEL
@register()
class OpenRouterClient(ClientBase):
"""
OpenRouter client for generating text using various models.
"""
client_type = "openrouter"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
class Meta(ClientBase.Meta):
name_prefix: str = "OpenRouter"
title: str = "OpenRouter"
manual_model: bool = True
manual_model_choices: list[str] = pydantic.Field(default_factory=lambda: AVAILABLE_MODELS)
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model=None, **kwargs):
self.model_name = model or DEFAULT_MODEL
self.api_key_status = None
self.config = load_config()
self._models_fetched = False
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def can_be_coerced(self) -> bool:
return True
@property
def openrouter_api_key(self):
return self.config.get("openrouter", {}).get("api_key")
@property
def supported_parameters(self):
return [
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"repetition_penalty",
"max_tokens",
]
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.openrouter_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"openrouter_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status if self.enabled else "disabled",
data=data,
)
def set_client(self, max_token_length: int = None):
# Unlike other clients, we don't need to set up a client instance
# We'll use httpx directly in the generate method
if not self.openrouter_api_key:
log.error("No OpenRouter API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = DEFAULT_MODEL
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
# Set max token length (default to 16k if not specified)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"openrouter set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=self.model_name,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
async def status(self):
# Fetch models if we have an API key and haven't fetched yet
if self.openrouter_api_key and not self._models_fetched:
self._models_fetched = True
# Update the Meta class with new model choices
self.Meta.manual_model_choices = AVAILABLE_MODELS
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
"""
OpenRouter handles the prompt template internally, so we just
give the prompt as is.
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters using OpenRouter API.
"""
if not self.openrouter_api_key:
raise Exception("No OpenRouter API key set")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
# Prepare messages for chat completion
messages = [
{"role": "system", "content": self.get_system_message(kind)},
{"role": "user", "content": prompt.strip()}
]
if coercion_prompt:
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
# Prepare request payload
payload = {
"model": self.model_name,
"messages": messages,
"stream": True,
**parameters
}
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
model=self.model_name,
)
response_text = ""
buffer = ""
completion_tokens = 0
prompt_tokens = 0
try:
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
"https://openrouter.ai/api/v1/chat/completions",
headers={
"Authorization": f"Bearer {self.openrouter_api_key}",
"Content-Type": "application/json",
},
json=payload,
timeout=120.0 # 2 minute timeout for generation
) as response:
async for chunk in response.aiter_text():
buffer += chunk
while True:
# Find the next complete SSE line
line_end = buffer.find('\n')
if line_end == -1:
break
line = buffer[:line_end].strip()
buffer = buffer[line_end + 1:]
if line.startswith('data: '):
data = line[6:]
if data == '[DONE]':
break
try:
data_obj = json.loads(data)
content = data_obj["choices"][0]["delta"].get("content")
usage = data_obj.get("usage", {})
completion_tokens += usage.get("completion_tokens", 0)
prompt_tokens += usage.get("prompt_tokens", 0)
if content:
response_text += content
# Update tokens as content streams in
self.update_request_tokens(self.count_tokens(content))
except json.JSONDecodeError:
pass
# Extract the response content
response_content = response_text
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
return response_content
except httpx.ConnectTimeout:
self.log.error("OpenRouter API timeout")
emit("status", message="OpenRouter API: Request timed out", status="error")
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message=f"OpenRouter API Error: {str(e)}", status="error")
raise
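
The SSE parsing above in isolation: feed it raw text chunks from the stream and it yields the delta content strings. Everything here is standard library; the payload shape matches what the client parses.

```python
import json

def parse_sse_chunks(chunks):
    """Yield delta content strings from raw SSE text chunks."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        while (line_end := buffer.find("\n")) != -1:
            line, buffer = buffer[:line_end].strip(), buffer[line_end + 1:]
            if not line.startswith("data: "):
                continue
            data = line[6:]
            if data == "[DONE]":
                return
            try:
                obj = json.loads(data)
            except json.JSONDecodeError:
                continue
            choices = obj.get("choices") or []
            if choices:
                content = choices[0].get("delta", {}).get("content")
                if content:
                    yield content

# "".join(parse_sse_chunks(['data: {"choices":[{"delta":{"content":"Hi"}}]}\n']))  # -> "Hi"
```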

View File

@@ -1,5 +1,109 @@
__all__ = ["RemoteServiceMixin"] import pydantic
import structlog
from .base import FieldGroup, ExtraField
import talemate.ux.schema as ux_schema
log = structlog.get_logger(__name__)
__all__ = [
"RemoteServiceMixin",
"EndpointOverrideMixin",
"endpoint_override_extra_fields",
"EndpointOverrideGroup",
"EndpointOverrideField",
"EndpointOverrideBaseURLField",
"EndpointOverrideAPIKeyField",
]
def endpoint_override_extra_fields():
return {
"override_base_url": EndpointOverrideBaseURLField(),
"override_api_key": EndpointOverrideAPIKeyField(),
}
class EndpointOverride(pydantic.BaseModel):
override_base_url: str | None = None
override_api_key: str | None = None
class EndpointOverrideGroup(FieldGroup):
name: str = "endpoint_override"
label: str = "Endpoint Override"
description: str = ("Override the default base URL used by this client to access the {client_type} service API.\n\n"
"IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
"If the override endpoint requires an API key, enter it below.")
icon: str = "mdi-api"
class EndpointOverrideField(ExtraField):
group: EndpointOverrideGroup = pydantic.Field(default_factory=EndpointOverrideGroup)
class EndpointOverrideBaseURLField(EndpointOverrideField):
name: str = "override_base_url"
type: str = "text"
label: str = "Base URL"
required: bool = False
description: str = "Override the base URL for the remote service"
class EndpointOverrideAPIKeyField(EndpointOverrideField):
name: str = "override_api_key"
type: str = "password"
label: str = "API Key"
required: bool = False
description: str = "Override the API key for the remote service"
note: ux_schema.Note = pydantic.Field(default_factory=lambda: ux_schema.Note(
text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
color="warning",
))
class EndpointOverrideMixin:
override_base_url: str | None = None
override_api_key: str | None = None
def set_client_api_key(self, api_key: str | None):
if getattr(self, "client", None):
try:
self.client.api_key = api_key
except Exception as e:
log.error("Error setting client API key", error=e, client=self.client_type)
@property
def api_key(self) -> str | None:
if self.endpoint_override_base_url_configured:
return self.override_api_key
return getattr(self, f"{self.client_type}_api_key", None)
@property
def base_url(self) -> str | None:
if self.override_base_url and self.override_base_url.strip():
return self.override_base_url
return None
@property
def endpoint_override_base_url_configured(self) -> bool:
return self.override_base_url and self.override_base_url.strip()
@property
def endpoint_override_api_key_configured(self) -> bool:
return self.override_api_key and self.override_api_key.strip()
@property
def endpoint_override_fully_configured(self) -> bool:
return self.endpoint_override_base_url_configured and self.endpoint_override_api_key_configured
def _reconfigure_endpoint_override(self, **kwargs):
if "override_base_url" in kwargs:
orig = getattr(self, "override_base_url", None)
self.override_base_url = kwargs["override_base_url"]
if getattr(self, "client", None) and orig != self.override_base_url:
log.info("Reconfiguring client base URL", new=self.override_base_url)
self.set_client(kwargs.get("max_token_length"))
if "override_api_key" in kwargs:
self.override_api_key = kwargs["override_api_key"]
self.set_client_api_key(self.override_api_key)
class RemoteServiceMixin: class RemoteServiceMixin:
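
How a client composed with `EndpointOverrideMixin` resolves credentials, shown with a hypothetical `DemoClient`: without an override the global `<client_type>_api_key` is used; once an override base URL is set, the override key is used and the global key is deliberately ignored.

```python
class DemoClient(EndpointOverrideMixin):
    client_type = "openai"
    openai_api_key = "global-key"  # stands in for the globally configured key

client = DemoClient()
assert client.api_key == "global-key" and client.base_url is None

client.override_base_url = "https://proxy.example.com/v1"
client.override_api_key = "proxy-key"
assert client.base_url == "https://proxy.example.com/v1"
assert client.api_key == "proxy-key"  # global key deliberately ignored
```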

View File

@@ -1,10 +1,10 @@
import random import random
import urllib
from typing import Literal from typing import Literal
import aiohttp import json
import httpx
import pydantic import pydantic
import structlog import structlog
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError from openai import PermissionDeniedError
from talemate.client.base import ClientBase, ExtraField, CommonDefaults from talemate.client.base import ClientBase, ExtraField, CommonDefaults
from talemate.client.registry import register from talemate.client.registry import register
@@ -17,61 +17,6 @@ log = structlog.get_logger("talemate.client.tabbyapi")
 EXPERIMENTAL_DESCRIPTION = """Use this client to use all of TabbyAPI's features"""

-class CustomAPIClient:
-    def __init__(self, base_url, api_key):
-        self.base_url = base_url
-        self.api_key = api_key
-
-    async def get_model_name(self):
-        url = urljoin(self.base_url, "model")
-        headers = {
-            "x-api-key": self.api_key,
-        }
-        async with aiohttp.ClientSession() as session:
-            async with session.get(url, headers=headers) as response:
-                if response.status != 200:
-                    raise Exception(f"Request failed: {response.status}")
-                response_data = await response.json()
-        model_name = response_data.get("id")
-        # split by "/" and take last
-        if model_name:
-            model_name = model_name.split("/")[-1]
-        return model_name
-
-    async def create_chat_completion(self, model, messages, **parameters):
-        url = urljoin(self.base_url, "chat/completions")
-        headers = {
-            "x-api-key": self.api_key,
-            "Content-Type": "application/json",
-        }
-        data = {
-            "model": model,
-            "messages": messages,
-            **parameters,
-        }
-        async with aiohttp.ClientSession() as session:
-            async with session.post(url, headers=headers, json=data) as response:
-                if response.status != 200:
-                    raise Exception(f"Request failed: {response.status}")
-                return await response.json()
-
-    async def create_completion(self, model, **parameters):
-        url = urljoin(self.base_url, "completions")
-        headers = {
-            "x-api-key": self.api_key,
-            "Content-Type": "application/json",
-        }
-        data = {
-            "model": model,
-            **parameters,
-        }
-        async with aiohttp.ClientSession() as session:
-            async with session.post(url, headers=headers, json=data) as response:
-                if response.status != 200:
-                    raise Exception(f"Request failed: {response.status}")
-                return await response.json()
-
 class Defaults(CommonDefaults, pydantic.BaseModel):
     api_url: str = "http://localhost:5000/v1"
     api_key: str = ""
@@ -153,7 +98,6 @@ class TabbyAPIClient(ClientBase):
         self.api_handles_prompt_template = kwargs.get(
             "api_handles_prompt_template", self.api_handles_prompt_template
         )
-        self.client = CustomAPIClient(base_url=self.api_url, api_key=self.api_key)
         self.model_name = (
             kwargs.get("model") or kwargs.get("model_name") or self.model_name
         )
@@ -178,49 +122,150 @@ class TabbyAPIClient(ClientBase):
         return prompt

     async def get_model_name(self):
-        return await self.client.get_model_name()
+        url = urljoin(self.api_url, "model")
+        headers = {
+            "x-api-key": self.api_key,
+        }
+        async with httpx.AsyncClient() as client:
+            response = await client.get(url, headers=headers, timeout=10.0)
+            if response.status_code != 200:
+                raise Exception(f"Request failed: {response.status_code}")
+            response_data = response.json()
+        model_name = response_data.get("id")
+        # split by "/" and take last
+        if model_name:
+            model_name = model_name.split("/")[-1]
+        return model_name

     async def generate(self, prompt: str, parameters: dict, kind: str):
         """
-        Generates text from the given prompt and parameters.
+        Generates text from the given prompt and parameters using streaming responses.
         """
+        # Determine whether we are using chat or completions endpoint
+        is_chat = self.api_handles_prompt_template
+
         try:
-            if self.api_handles_prompt_template:
-                # Custom API handles prompt template
-                # Use the chat completions endpoint
+            if is_chat:
+                # Chat completions endpoint
                 self.log.debug(
                     "generate (chat/completions)",
                     prompt=prompt[:128] + " ...",
                     parameters=parameters,
                 )
                 human_message = {"role": "user", "content": prompt.strip()}
-                response = await self.client.create_chat_completion(
-                    self.model_name, [human_message], **parameters
-                )
-                response = response["choices"][0]["message"]["content"]
-                return self.process_response_for_indirect_coercion(prompt, response)
+
+                payload = {
+                    "model": self.model_name,
+                    "messages": [human_message],
+                    "stream": True,
+                    "stream_options": {
+                        "include_usage": True,
+                    },
+                    **parameters,
+                }
+                endpoint = "chat/completions"
             else:
-                # Talemate handles prompt template
-                # Use the completions endpoint
+                # Completions endpoint
                 self.log.debug(
                     "generate (completions)",
                     prompt=prompt[:128] + " ...",
                     parameters=parameters,
                 )
-                parameters["prompt"] = prompt
-                response = await self.client.create_completion(
-                    self.model_name, **parameters
-                )
-                return response["choices"][0]["text"]
+
+                payload = {
+                    "model": self.model_name,
+                    "prompt": prompt,
+                    "stream": True,
+                    **parameters,
+                }
+                endpoint = "completions"
+
+            url = urljoin(self.api_url, endpoint)
+            headers = {
+                "x-api-key": self.api_key,
+                "Content-Type": "application/json",
+            }
+
+            response_text = ""
+            buffer = ""
+            completion_tokens = 0
+            prompt_tokens = 0
+
+            async with httpx.AsyncClient() as client:
+                async with client.stream(
+                    "POST",
+                    url,
+                    headers=headers,
+                    json=payload,
+                    timeout=120.0
+                ) as response:
+                    async for chunk in response.aiter_text():
+                        buffer += chunk
+                        while True:
+                            line_end = buffer.find('\n')
+                            if line_end == -1:
+                                break
+                            line = buffer[:line_end].strip()
+                            buffer = buffer[line_end + 1:]
+                            if not line:
+                                continue
+                            if line.startswith("data: "):
+                                data = line[6:]
+                                if data == "[DONE]":
+                                    break
+                                try:
+                                    data_obj = json.loads(data)
+                                    choice = data_obj.get("choices", [{}])[0]
+                                    # Chat completions use delta -> content.
+                                    delta = choice.get("delta", {})
+                                    content = (
+                                        delta.get("content")
+                                        or delta.get("text")
+                                        or choice.get("text")
+                                    )
+                                    usage = data_obj.get("usage", {})
+                                    completion_tokens = usage.get("completion_tokens", 0)
+                                    prompt_tokens = usage.get("prompt_tokens", 0)
+                                    if content:
+                                        response_text += content
+                                        self.update_request_tokens(self.count_tokens(content))
+                                except json.JSONDecodeError:
+                                    # ignore malformed json chunks
+                                    pass
+
+            # Save token stats for logging
+            self._returned_prompt_tokens = prompt_tokens
+            self._returned_response_tokens = completion_tokens
+
+            if is_chat:
+                # Process indirect coercion
+                response_text = self.process_response_for_indirect_coercion(prompt, response_text)
+
+            return response_text
         except PermissionDeniedError as e:
             self.log.error("generate error", e=e)
             emit("status", message="Client API: Permission Denied", status="error")
             return ""
+        except httpx.ConnectTimeout:
+            self.log.error("API timeout")
+            emit("status", message="TabbyAPI: Request timed out", status="error")
+            return ""
         except Exception as e:
             self.log.error("generate error", e=e)
-            emit(
-                "status", message="Error during generation (check logs)", status="error"
-            )
+            emit("status", message="Error during generation (check logs)", status="error")
             return ""

     def reconfigure(self, **kwargs):

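The rewritten generate() above consumes OpenAI-compatible Server-Sent Events. As a point of reference, the hypothetical stream below is the shape of input the "data:" parsing loop handles; actual field contents depend on the backend.

# Hypothetical example of the SSE lines the streaming loop consumes
# (OpenAI-compatible streaming; real servers vary in the fields they send).
import json

sample_stream = (
    'data: {"choices": [{"delta": {"content": "Hel"}}]}\n'
    'data: {"choices": [{"delta": {"content": "lo"}}], '
    '"usage": {"prompt_tokens": 12, "completion_tokens": 2}}\n'
    'data: [DONE]\n'
)

text = ""
for line in sample_stream.splitlines():
    line = line.strip()
    if not line.startswith("data: "):
        continue
    data = line[6:]
    if data == "[DONE]":
        break
    chunk = json.loads(data)
    delta = chunk.get("choices", [{}])[0].get("delta", {})
    text += delta.get("content") or ""

print(text)  # -> "Hello"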

@@ -195,6 +195,7 @@ class TextGeneratorWebuiClient(ClientBase):
             payload = json.loads(event.data)
             chunk = payload['choices'][0]['text']
             response += chunk
+            self.update_request_tokens(self.count_tokens(chunk))
         return response


@@ -22,7 +22,10 @@ class CmdSetEnvironmentToScene(TalemateCommand):
         player_character = self.scene.get_player_character()

         if not player_character:
-            self.system_message("No player character found")
+            self.system_message("No characters found - cannot switch to gameplay mode.", meta={
+                "icon": "mdi-alert",
+                "color": "warning",
+            })
             return True

         self.scene.set_environment("scene")


@@ -93,6 +93,7 @@ class General(BaseModel):
     auto_save: bool = True
     auto_progress: bool = True
     max_backscroll: int = 512
+    add_default_character: bool = True

 class StateReinforcementTemplate(BaseModel):
@@ -161,6 +162,9 @@ class DeepSeekConfig(BaseModel):
     api_key: Union[str, None] = None

+class OpenRouterConfig(BaseModel):
+    api_key: Union[str, None] = None
+
 class RunPodConfig(BaseModel):
     api_key: Union[str, None] = None
@@ -177,6 +181,7 @@ class CoquiConfig(BaseModel):
 class GoogleConfig(BaseModel):
     gcloud_credentials_path: Union[str, None] = None
     gcloud_location: Union[str, None] = None
+    api_key: Union[str, None] = None

 class TTSVoiceSamples(BaseModel):
@@ -209,6 +214,7 @@ class EmbeddingFunctionPreset(BaseModel):
     gpu_recommendation: bool = False
     local: bool = True
     custom: bool = False
+    client: str | None = None
@@ -506,6 +512,8 @@ class Config(BaseModel):
     anthropic: AnthropicConfig = AnthropicConfig()
+    openrouter: OpenRouterConfig = OpenRouterConfig()
     cohere: CohereConfig = CohereConfig()
     groq: GroqConfig = GroqConfig()

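For illustration, the new and extended config models above can be constructed directly. This is a sketch only: the talemate.config import path is an assumption based on the diff, and the key values are placeholders.

# Sketch only - assumes these pydantic models are importable from talemate.config.
from talemate.config import General, OpenRouterConfig, GoogleConfig

general = General(add_default_character=False)      # new toggle, defaults to True
openrouter = OpenRouterConfig(api_key="sk-or-...")  # new OpenRouter section
google = GoogleConfig(api_key="AIza...")            # GoogleConfig gains an api_key field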

@@ -1,13 +0,0 @@
-from dataclasses import dataclass
-
-__all__ = [
-    "ArchiveEntry",
-]
-
-@dataclass
-class ArchiveEntry:
-    text: str
-    start: int = None
-    end: int = None
-    ts: str = None


@@ -180,11 +180,11 @@ class Emitter:
     def setup_emitter(self, scene: Scene = None):
         self.emit_for_scene = scene

-    def emit(self, typ: str, message: str, character: Character = None):
-        emit(typ, message, character=character, scene=self.emit_for_scene)
+    def emit(self, typ: str, message: str, character: Character = None, **kwargs):
+        emit(typ, message, character=character, scene=self.emit_for_scene, **kwargs)

-    def system_message(self, message: str):
-        self.emit("system", message)
+    def system_message(self, message: str, **kwargs):
+        self.emit("system", message, **kwargs)

     def narrator_message(self, message: str):
         self.emit("narrator", message)

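The **kwargs forwarding above is what lets callers attach extra metadata to a message, as the CmdSetEnvironmentToScene change earlier does. A rough sketch follows; the talemate.emit import path is an assumption, and ExampleCommand is an illustrative stand-in rather than a real command.

# Illustrative only - assumes Emitter is exposed by talemate.emit; the meta
# payload mirrors the CmdSetEnvironmentToScene change shown earlier.
from talemate.emit import Emitter

class ExampleCommand(Emitter):
    def warn_no_player(self):
        self.system_message(
            "No characters found - cannot switch to gameplay mode.",
            meta={"icon": "mdi-alert", "color": "warning"},
        )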

@@ -49,6 +49,8 @@ SpiceApplied = signal("spice_applied")
 WorldSateManager = signal("world_state_manager")

+TalemateStarted = signal("talemate_started")
+
 handlers = {
     "system": SystemMessage,
     "narrator": NarratorMessage,
@@ -86,4 +88,5 @@ handlers = {
     "memory_request": MemoryRequest,
     "player_choice": PlayerChoiceMessage,
     "world_state_manager": WorldSateManager,
+    "talemate_started": TalemateStarted,
 }


@@ -2,7 +2,6 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import TYPE_CHECKING

-import pydantic
 import talemate.emit.async_signals as async_signals
@@ -29,10 +28,9 @@ class HistoryEvent(Event):
 @dataclass
 class ArchiveEvent(Event):
     text: str
-    memory_id: str = None
+    memory_id: str
     ts: str = None

 @dataclass
 class CharacterStateEvent(Event):
     state: str


@@ -465,7 +465,7 @@ class DynamicInstruction(Node):
     def setup(self):
         self.add_input("header", socket_type="str", optional=True)
-        self.add_input("content", socket_type="text", optional=True)
+        self.add_input("content", socket_type="str", optional=True)
         self.set_property("header", UNRESOLVED)
         self.set_property("content", UNRESOLVED)
@@ -473,8 +473,11 @@ class DynamicInstruction(Node):
         self.add_output("dynamic_instruction", socket_type="dynamic_instruction")

     async def run(self, state: GraphState):
-        header = self.require_input("header")
-        content = self.require_input("content")
+        header = self.normalized_input_value("header")
+        content = self.normalized_input_value("content")
+
+        if not header or not content:
+            return

         self.set_output_values({
             "dynamic_instruction": DynamicInstructionType(title=header, content=content)

Some files were not shown because too many files have changed in this diff.