0.31.0 (#193)
* some prompt cleanup * prompt tweaks * prompt tweaks * prompt tweaks * set 0.31.0 * relock * rag queries add brief analysis * brief analysis before building rag questions * rag improvements * prompt tweaks * address circular import issues * set 0.30.1 * docs * numpy to 2 * docs * prompt tweaks * prompt tweak * some template cleanup * prompt viewer increase height * fix codemirror highlighting not working * adjust details height * update default * change to log debug * allow response to scale to max height * template cleanup * prompt tweaks * first progress for installing modules to scene * package install logic * package install polish * package install polish * package install polish and fixes * refactor initial world state update and expose setting for it * add `num` property to ModuleProperty to control order of widgets * dynamic storyline package info * fix issue where deactivating player character would cause inconsistencies in the creative tools menui * cruft * add openrouter support * ollama support * refactor how model choices are loaded, so that can be done per client instance as opposed to just per client type * set num_ctx * remove debug messages * ollama tweaks * toggle for whether or not default character gets added to blank talemate scenes * narrator prompt tweaks and template cleanup * cleanup * prompt tweaks and template cleanup * prompt tweaks * fix instructor embeddings * add additional error handling to prevent broken world state templates from breaking the world editor side menu * fix openrouter breaking startup if not configured * remove debug message * promp tweaks * fix example dialogue generation no longer working * prompt tweaks and better showing of dialogue examples in conversation instructions * prompt tweak * add initial startup message * prompt tweaks * fix asset error * move complex acting instructions into the task block * fix content input socket on DynamicInstructions node * log.error with traceback over log.exception since that has a tendency to explode and hang everything * fix usse with number nodes where they would try to execute even if the incoming wire wasnt active * fix issue with editor revision events missing template_vars * DynamicInstruction node should only run if both header and content can resolve * removed remaining references to 90s adventure game writing style * prompt tweaks * support embeddings via client apis (koboldcpp) * fix label on client-api embeddings * fix issue where adding / removing an embedding preset would not be reflected immediately in the memory agent config * remove debug output * prompt tweaks * prompt tweaks * autocomplete passes message object * validate group names to be filename valid * embedded winsows env installs and up to poetry2 * version config * get-pip * relock * pin runpod * no longer needed * remove rapidfuzz dependency * nodejs directly into embedded_node without a versioned middleman dir - also remove defunct local-tts install script * fix update script * update script error handling * update.bat error handling * adjust wording * support loading jinja2 templates node modules in templates/modules * update google model list * client t/s and business indicator - also switch all clients to async streaming * formatting * support coercion for anthropic / google switch to the new google genai sdk upgrade websockets * more coercion fixes * gracefully handle keyboard interrupt * EmitSystemMessage node * allow visual prompt generation without image generation * relock * chromadb to v1 * fix error handling * fix issue where adding client model list would be empty * supress pip install warnings * allow overriding of base urls * remove key from log * add fade effect * tweak request info ux * further clarification of endpoint override api key * world state manager: fix issue that caused max changes setting to disappear from character progress config * fix issue with google safety settings off causing generation failures * update to base url should always reset the client * getattr * support v3 chara card version and attempt to future proof * client based embeddings improvements * more fixes for client based embeddings * use client icon * history management tools progress * history memory ids fixed and added validation * regenerate summary fixes * more history regeneration fixes * fix layered history gen and prompt twweaks * allow regeneration of individual layered history entries * return list of LayeredArchiveEntry * reorg for less code dupelication * new scene message renderer based on marked * add inspect functionality to history viewer * message if no history entries yet * allow adding of history entries manually * allow deletion of history * summarization unslop improvements * fix make charcter real action from worldstate listing * allow overriding length in all context generation isntructioon dialogs * fix issue where extract_list could fail with an unhandled error if the llm response didnt contain a list * update whats'new * fix issues with the new history management tools * fix check * Switch dependency handling to UV (#202) * Migrate from Poetry to uv package manager (#200) * migrate from poetry to uv package manager * Update all installation and startup scripts for uv migration * Fix pyproject.toml for uv - allow direct references for hatchling * Fix PR feedback: Restore removed functionality - Restored embedded Python/Node.js functionality in install.bat and update.bat - Restored environment variable exposure in docker-compose.yml (CUDA_AVAILABLE, port configs) - Fixed GitHub Actions branches (main, prep-* instead of main, dev) - Restored fail-fast: false and cache configuration in test.yml These changes preserve all the functionality that should not be removed during the migration from Poetry to uv. --------- Co-authored-by: Ztripez von Matérn <ztripez@bobby.se> * remove uv.lock from .gitignore * add lock file * fix install issues * warn if unable to remove legacy poetry virt env dir * uv needs to be explicitly installed into the .venv so its available * third time's the charm? * fix windows install scripts * add .venv guard to update.bat * call :die * fix docker venv install * node 21 * fix cuda install * start.bat calls install if needed * sync start-local to other startup scripts * no need to activate venv --------- Co-authored-by: Ztripez <reg@otherland.nu> Co-authored-by: Ztripez von Matérn <ztripez@bobby.se> * ignore hfhub symlink warnings * add openrouter and ollama mentions * update windows install documentation * docs * docs * fix issue with memory agent fingerprint * removing a client that supports embeddings will also remove any embedding functions it created * on invalid embeddings reset to default * docs * typo * formatting * docs * docs * install package * adjust topic * add more obvious way to exit creative mode * when importing character cards immediately persist a usable save after the restoration save --------- Co-authored-by: Ztripez <reg@otherland.nu> Co-authored-by: Ztripez von Matérn <ztripez@bobby.se>
27
.github/workflows/test.yml
vendored
@@ -2,9 +2,9 @@ name: Python Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, main, 'prep-*' ]
|
||||
branches: [ main, 'prep-*' ]
|
||||
pull_request:
|
||||
branches: [ master, main, 'prep-*' ]
|
||||
branches: [ main, 'prep-*' ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -23,25 +23,24 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install poetry
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install poetry
|
||||
pip install uv
|
||||
|
||||
- name: Cache poetry dependencies
|
||||
- name: Cache uv dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
|
||||
path: ~/.cache/uv
|
||||
key: ${{ runner.os }}-uv-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-poetry-${{ matrix.python-version }}-
|
||||
${{ runner.os }}-uv-${{ matrix.python-version }}-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv talemate_env
|
||||
source talemate_env/bin/activate
|
||||
poetry config virtualenvs.create false
|
||||
poetry install
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Setup configuration file
|
||||
run: |
|
||||
@@ -49,10 +48,10 @@ jobs:
|
||||
|
||||
- name: Download NLTK data
|
||||
run: |
|
||||
source talemate_env/bin/activate
|
||||
source .venv/bin/activate
|
||||
python -c "import nltk; nltk.download('punkt_tab')"
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
source talemate_env/bin/activate
|
||||
source .venv/bin/activate
|
||||
pytest tests/ -p no:warnings
|
||||
3
.gitignore
vendored
@@ -8,6 +8,9 @@
|
||||
talemate_env
|
||||
chroma
|
||||
config.yaml
|
||||
|
||||
# uv
|
||||
.venv/
|
||||
templates/llm-prompt/user/*.jinja2
|
||||
templates/world-state/*.yaml
|
||||
scenes/
|
||||
|
||||
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.11
|
||||
52
Dockerfile
@@ -1,15 +1,19 @@
|
||||
# Stage 1: Frontend build
|
||||
FROM node:21 AS frontend-build
|
||||
|
||||
ENV NODE_ENV=development
|
||||
FROM node:21-slim AS frontend-build
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the frontend directory contents into the container at /app
|
||||
COPY ./talemate_frontend /app
|
||||
# Copy frontend package files
|
||||
COPY talemate_frontend/package*.json ./
|
||||
|
||||
# Install all dependencies and build
|
||||
RUN npm install && npm run build
|
||||
# Install dependencies
|
||||
RUN npm ci
|
||||
|
||||
# Copy frontend source
|
||||
COPY talemate_frontend/ ./
|
||||
|
||||
# Build frontend
|
||||
RUN npm run build
|
||||
|
||||
# Stage 2: Backend build
|
||||
FROM python:3.11-slim AS backend-build
|
||||
@@ -22,30 +26,25 @@ RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install poetry
|
||||
RUN pip install poetry
|
||||
# Install uv
|
||||
RUN pip install uv
|
||||
|
||||
# Copy poetry files
|
||||
COPY pyproject.toml poetry.lock* /app/
|
||||
# Copy installation files
|
||||
COPY pyproject.toml uv.lock /app/
|
||||
|
||||
# Create a virtual environment
|
||||
RUN python -m venv /app/talemate_env
|
||||
|
||||
# Activate virtual environment and install dependencies
|
||||
RUN . /app/talemate_env/bin/activate && \
|
||||
poetry config virtualenvs.create false && \
|
||||
poetry install --only main --no-root
|
||||
|
||||
# Copy the Python source code
|
||||
# Copy the Python source code (needed for editable install)
|
||||
COPY ./src /app/src
|
||||
|
||||
# Create virtual environment and install dependencies
|
||||
RUN uv sync
|
||||
|
||||
# Conditional PyTorch+CUDA install
|
||||
ARG CUDA_AVAILABLE=false
|
||||
RUN . /app/talemate_env/bin/activate && \
|
||||
RUN . /app/.venv/bin/activate && \
|
||||
if [ "$CUDA_AVAILABLE" = "true" ]; then \
|
||||
echo "Installing PyTorch with CUDA support..." && \
|
||||
pip uninstall torch torchaudio -y && \
|
||||
pip install torch~=2.4.1 torchaudio~=2.4.1 --index-url https://download.pytorch.org/whl/cu121; \
|
||||
uv pip uninstall torch torchaudio && \
|
||||
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128; \
|
||||
fi
|
||||
|
||||
# Stage 3: Final image
|
||||
@@ -57,8 +56,11 @@ RUN apt-get update && apt-get install -y \
|
||||
bash \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv in the final stage
|
||||
RUN pip install uv
|
||||
|
||||
# Copy virtual environment from backend-build stage
|
||||
COPY --from=backend-build /app/talemate_env /app/talemate_env
|
||||
COPY --from=backend-build /app/.venv /app/.venv
|
||||
|
||||
# Copy Python source code
|
||||
COPY --from=backend-build /app/src /app/src
|
||||
@@ -83,4 +85,4 @@ EXPOSE 5050
|
||||
EXPOSE 8080
|
||||
|
||||
# Use bash as the shell, activate the virtual environment, and run backend server
|
||||
CMD ["/bin/bash", "-c", "source /app/talemate_env/bin/activate && python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050 --frontend-host 0.0.0.0 --frontend-port 8080"]
|
||||
CMD ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]
|
||||
@@ -39,12 +39,14 @@ Need help? Join the new [Discord community](https://discord.gg/8bGNRmFxMj)
|
||||
- [Cohere](https://www.cohere.com/)
|
||||
- [Groq](https://www.groq.com/)
|
||||
- [Google Gemini](https://console.cloud.google.com/)
|
||||
- [OpenRouter](https://openrouter.ai/)
|
||||
|
||||
Supported self-hosted APIs:
|
||||
- [KoboldCpp](https://koboldai.org/cpp) ([Local](https://koboldai.org/cpp), [Runpod](https://koboldai.org/runpodcpp), [VastAI](https://koboldai.org/vastcpp), also includes image gen support)
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
- [TabbyAPI](https://github.com/theroyallab/tabbyAPI/)
|
||||
- [Ollama](https://ollama.com/)
|
||||
|
||||
Generic OpenAI api implementations (tested and confirmed working):
|
||||
- [DeepInfra](https://deepinfra.com/)
|
||||
|
||||
@@ -18,4 +18,4 @@ services:
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- PYTHONPATH=/app/src:$PYTHONPATH
|
||||
command: ["/bin/bash", "-c", "source /app/talemate_env/bin/activate && python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050 --frontend-host 0.0.0.0 --frontend-port 8080"]
|
||||
command: ["uv", "run", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050", "--frontend-host", "0.0.0.0", "--frontend-port", "8080"]
|
||||
166
docs/dev/howto/add-a-worldstate-template-type.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# Adding a new world-state template
|
||||
|
||||
I am writing this up as I add phrase detection functionality to the `Writing Style` template, so that in the future, hopefully when new template types need to be added this document can just given to the LLM of the month, to do it.
|
||||
|
||||
## Introduction
|
||||
|
||||
World state templates are reusable components that plug in various parts of talemate.
|
||||
|
||||
At this point there are following types:
|
||||
|
||||
- Character Attribute
|
||||
- Character Detail
|
||||
- Writing Style
|
||||
- Spice (for randomization of content during generation)
|
||||
- Scene Type
|
||||
- State Reinforcement
|
||||
|
||||
Basically whenever we want to add something reusable and customizable by the user, a world state template is likely a good solution.
|
||||
|
||||
## Steps to creating a new template type
|
||||
|
||||
### 1. Add a pydantic schema (python)
|
||||
|
||||
In `src/talemate/world_state/templates` create a new `.py` file with reasonable name.
|
||||
|
||||
In this example I am extending the `Writing Style` template to include phrase detection functionality, which will be used by the `Editor` agent to detect certain phrases and then act upon them.
|
||||
|
||||
There already is a `content.py` file - so it makes sense to just add this new functionality to this file.
|
||||
|
||||
```python
|
||||
class PhraseDetection(pydantic.BaseModel):
|
||||
phrase: str
|
||||
instructions: str
|
||||
# can be "unwanted" for now, more added later
|
||||
classification: Literal["unwanted"] = "unwanted"
|
||||
|
||||
@register("writing_style")
|
||||
class WritingStyle(Template):
|
||||
description: str | None = None
|
||||
phrases: list[PhraseDetection] = pydantic.Field(default_factory=list)
|
||||
|
||||
def render(self, scene: "Scene", character_name: str):
|
||||
return self.formatted("instructions", scene, character_name)
|
||||
```
|
||||
|
||||
If I were to create a new file I'd still want to read one of the existing files first to understand imports and style.
|
||||
|
||||
### 2. Add a vue component to allow management (vue, js)
|
||||
|
||||
Next we need to add a new vue component that exposes a UX for us to manage this new template type.
|
||||
|
||||
For this I am creating `talemate_frontend/src/components/WorldStateManagerTemplateWritingStyle.vue`.
|
||||
|
||||
## Bare Minimum Understanding for New Template Components
|
||||
|
||||
When adding a new component for managing a template type, you need to understand:
|
||||
|
||||
### Component Structure
|
||||
|
||||
1. **Props**: The component always receives an `immutableTemplate` prop with the template data.
|
||||
2. **Data Management**: Create a local copy of the template data for editing before saving back.
|
||||
3. **Emits**: Use the `update` event to send modified template data back to the parent.
|
||||
|
||||
### Core Implementation Requirements
|
||||
|
||||
1. **Template Properties**: Always include fields for `name`, `description`, and `favorite` status.
|
||||
2. **Data Binding**: Implement two-way binding with `v-model` for all editable fields.
|
||||
3. **Dirty State Tracking**: Track when changes are made but not yet saved.
|
||||
4. **Save Method**: Implement a `save()` method that emits the updated template.
|
||||
|
||||
### Component Lifecycle
|
||||
|
||||
1. **Initialization**: Use the `created` hook to initialize the local template copy.
|
||||
2. **Watching for Changes**: Set up a watcher for the `immutableTemplate` to handle external updates.
|
||||
|
||||
### UI Patterns
|
||||
|
||||
1. **Forms**: Use Vuetify form components with consistent validation.
|
||||
2. **Actions**: Provide clear user actions for editing and managing template items.
|
||||
3. **Feedback**: Give visual feedback when changes are being made or saved.
|
||||
|
||||
The WorldStateManagerTemplate components follow a consistent pattern where they:
|
||||
- Display and edit general template metadata (name, description, favorite status)
|
||||
- Provide specialized UI for the template's unique properties
|
||||
- Handle the create, read, update, delete (CRUD) operations for template items
|
||||
- Maintain data integrity by properly handling template updates
|
||||
|
||||
You absolutely should read an existing component like `WorldStateManagerTemplateWritingStyle.vue` first to get a good understanding of the implementation.
|
||||
|
||||
## Integrating with WorldStateManagerTemplates
|
||||
|
||||
After creating your template component, you need to integrate it with the WorldStateManagerTemplates component:
|
||||
|
||||
### 1. Import the Component
|
||||
|
||||
Edit `talemate_frontend/src/components/WorldStateManagerTemplates.vue` and add an import for your new component:
|
||||
|
||||
```javascript
|
||||
import WorldStateManagerTemplateWritingStyle from './WorldStateManagerTemplateWritingStyle.vue'
|
||||
```
|
||||
|
||||
### 2. Register the Component
|
||||
|
||||
Add your component to the components section of the WorldStateManagerTemplates:
|
||||
|
||||
```javascript
|
||||
components: {
|
||||
// ... existing components
|
||||
WorldStateManagerTemplateWritingStyle
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Add Conditional Rendering
|
||||
|
||||
In the template section, add a new conditional block to render your component when the template type matches:
|
||||
|
||||
```html
|
||||
<WorldStateManagerTemplateWritingStyle v-else-if="template.template_type === 'writing_style'"
|
||||
:immutableTemplate="template"
|
||||
@update="(template) => applyAndSaveTemplate(template)"
|
||||
/>
|
||||
```
|
||||
|
||||
### 4. Add Icon and Color
|
||||
|
||||
Add cases for your template type in the `iconForTemplate` and `colorForTemplate` methods:
|
||||
|
||||
```javascript
|
||||
iconForTemplate(template) {
|
||||
// ... existing conditions
|
||||
else if (template.template_type == 'writing_style') {
|
||||
return 'mdi-script-text';
|
||||
}
|
||||
return 'mdi-cube-scan';
|
||||
},
|
||||
|
||||
colorForTemplate(template) {
|
||||
// ... existing conditions
|
||||
else if (template.template_type == 'writing_style') {
|
||||
return 'highlight5';
|
||||
}
|
||||
return 'grey';
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Add Help Message
|
||||
|
||||
Add a help message for your template type in the `helpMessages` object in the data section:
|
||||
|
||||
```javascript
|
||||
helpMessages: {
|
||||
// ... existing messages
|
||||
writing_style: "Writing style templates are used to define a writing style that can be applied to the generated content. They can be used to add a specific flavor or tone. A template must explicitly support writing styles to be able to use a writing style template.",
|
||||
}
|
||||
```
|
||||
|
||||
### 6. Update Template Type Selection
|
||||
|
||||
Add your template type to the `templateTypes` array in the data section:
|
||||
|
||||
```javascript
|
||||
templateTypes: [
|
||||
// ... existing types
|
||||
{ "title": "Writing style", "value": 'writing_style'},
|
||||
]
|
||||
```
|
||||
@@ -10,20 +10,19 @@ To run the server on a different host and port, you need to change the values pa
|
||||
|
||||
#### :material-linux: Linux
|
||||
|
||||
Copy `start.sh` to `start_custom.sh` and edit the `--host` and `--port` parameters in the `uvicorn` command.
|
||||
Copy `start.sh` to `start_custom.sh` and edit the `--host` and `--port` parameters.
|
||||
|
||||
```bash
|
||||
#!/bin/sh
|
||||
. talemate_env/bin/activate
|
||||
python src/talemate/server/run.py runserver --host 0.0.0.0 --port 1234
|
||||
uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 1234
|
||||
```
|
||||
|
||||
#### :material-microsoft-windows: Windows
|
||||
|
||||
Copy `start.bat` to `start_custom.bat` and edit the `--host` and `--port` parameters in the `uvicorn` command.
|
||||
Copy `start.bat` to `start_custom.bat` and edit the `--host` and `--port` parameters.
|
||||
|
||||
```batch
|
||||
start cmd /k "cd talemate_env\Scripts && activate && cd ../../ && python src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234"
|
||||
uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 1234
|
||||
```
|
||||
|
||||
### Letting the frontend know about the new host and port
|
||||
@@ -71,8 +70,7 @@ Copy `start.sh` to `start_custom.sh` and edit the `--frontend-host` and `--front
|
||||
|
||||
```bash
|
||||
#!/bin/sh
|
||||
. talemate_env/bin/activate
|
||||
python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
|
||||
uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
|
||||
--frontend-host localhost --frontend-port 8082
|
||||
```
|
||||
|
||||
@@ -81,7 +79,7 @@ python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5055 \
|
||||
Copy `start.bat` to `start_custom.bat` and edit the `--frontend-host` and `--frontend-port` parameters.
|
||||
|
||||
```batch
|
||||
start cmd /k "cd talemate_env\Scripts && activate && cd ../../ && python src\talemate\server\run.py runserver --host 0.0.0.0 --port 5055 --frontend-host localhost --frontend-port 8082"
|
||||
uv run src\talemate\server\run.py runserver --host 0.0.0.0 --port 5055 --frontend-host localhost --frontend-port 8082
|
||||
```
|
||||
|
||||
### Start the backend and frontend
|
||||
@@ -98,5 +96,4 @@ Start the backend and frontend as usual.
|
||||
|
||||
```batch
|
||||
start_custom.bat
|
||||
```
|
||||
|
||||
```
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
## Quick install instructions
|
||||
|
||||
### Dependencies
|
||||
@@ -7,6 +6,7 @@
|
||||
|
||||
1. node.js and npm - see instructions [here](https://nodejs.org/en/download/package-manager/)
|
||||
1. python- see instructions [here](https://www.python.org/downloads/)
|
||||
1. uv - see instructions [here](https://github.com/astral-sh/uv#installation)
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -25,19 +25,15 @@ If everything went well, you can proceed to [connect a client](../../connect-a-c
|
||||
|
||||
1. Open a terminal.
|
||||
2. Navigate to the project directory.
|
||||
3. Create a virtual environment by running `python3 -m venv talemate_env`.
|
||||
4. Activate the virtual environment by running `source talemate_env/bin/activate`.
|
||||
3. uv will automatically create a virtual environment when you run `uv venv`.
|
||||
|
||||
### Installing Dependencies
|
||||
|
||||
1. With the virtual environment activated, install poetry by running `pip install poetry`.
|
||||
2. Use poetry to install dependencies by running `poetry install`.
|
||||
1. Use uv to install dependencies by running `uv pip install -e ".[dev]"`.
|
||||
|
||||
### Running the Backend
|
||||
|
||||
1. With the virtual environment activated and dependencies installed, you can start the backend server.
|
||||
2. Navigate to the `src/talemate/server` directory.
|
||||
3. Run the server with `python run.py runserver --host 0.0.0.0 --port 5050`.
|
||||
1. You can start the backend server using `uv run src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
|
||||
|
||||
### Running the Frontend
|
||||
|
||||
@@ -45,4 +41,4 @@ If everything went well, you can proceed to [connect a client](../../connect-a-c
|
||||
2. If you haven't already, install npm dependencies by running `npm install`.
|
||||
3. Start the server with `npm run serve`.
|
||||
|
||||
Please note that you may need to set environment variables or modify the host and port as per your setup. You can refer to the `runserver.sh` and `frontend.sh` files for more details.
|
||||
Please note that you may need to set environment variables or modify the host and port as per your setup. You can refer to the various start scripts for more details.
|
||||
@@ -2,16 +2,9 @@
|
||||
|
||||
## Windows
|
||||
|
||||
### Installation fails with "Microsoft Visual C++" or "ValueError: The onnxruntime python package is not installed." errors
|
||||
|
||||
If your installation errors with a notification to upgrade "Microsoft Visual C++" go to [https://visualstudio.microsoft.com/visual-cpp-build-tools/](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and click "Download Build Tools" and run it.
|
||||
|
||||
- During installation make sure you select the C++ development package (upper left corner)
|
||||
- Run `reinstall.bat` inside talemate directory
|
||||
|
||||
### Frontend fails with errors
|
||||
|
||||
- ensure none of the directories have special characters in them, this can cause issues with the frontend. so no `(1)` in the directory name.
|
||||
- ensure none of the directories leading to your talemate directory have special characters in them, this can cause issues with the frontend. so no `(1)` in the directory name.
|
||||
|
||||
## Docker
|
||||
|
||||
|
||||
@@ -1,53 +1,32 @@
|
||||
## Quick install instructions
|
||||
|
||||
1. Download and install Python 3.10 - 3.13 from the [official Python website](https://www.python.org/downloads/windows/).
|
||||
- [Click here for direct link to python 3.11.9 download](https://www.python.org/downloads/release/python-3119/)
|
||||
- June 2025: people have reported issues with python 3.13 still, due to some dependencies not being available yet, if you run into issues during installation try downgrading.
|
||||
1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/prebuilt-installer). This will also install npm.
|
||||
1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
|
||||
1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
|
||||
1. **Optional:** If you are using an nvidia graphics card with CUDA support you may want to also run `install-cuda.bat` **afterwards**, to install the cuda enabled version of torch - although this is only needed if you want to run some bigger embedding models where CUDA can be helpful.
|
||||
1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
|
||||
1. Once the talemate logo shows up, navigate your browser to http://localhost:8080
|
||||
1. Download the latest Talemate release ZIP from the [Releases page](https://github.com/vegu-ai/talemate/releases) and extract it anywhere on your system (for example, `C:\Talemate`).
|
||||
2. Double-click **`start.bat`**.
|
||||
- On the very first run Talemate will automatically:
|
||||
1. Download a portable build of Python 3 and Node.js (no global installs required).
|
||||
2. Create and configure a Python virtual environment.
|
||||
3. Install all back-end and front-end dependencies with the included *uv* and *npm*.
|
||||
4. Build the web client.
|
||||
3. When the console window prints **"Talemate is now running"** and the logo appears, open your browser at **http://localhost:8080**.
|
||||
|
||||
!!! note "First start up may take a while"
|
||||
We have seen cases where the first start of talemate will sit at a black screen for a minute or two. Just wait it out, eventually the Talemate logo should show up.
|
||||
!!! note "First start can take a while"
|
||||
The initial download and dependency installation may take several minutes, especially on slow internet connections. The console will keep you updated – just wait until the Talemate logo shows up.
|
||||
|
||||
If everything went well, you can proceed to [connect a client](../../connect-a-client).
|
||||
### Optional: CUDA support
|
||||
|
||||
## Additional Information
|
||||
If you have an NVIDIA GPU and want CUDA acceleration for larger embedding models:
|
||||
|
||||
### How to Install Python
|
||||
1. Close Talemate (if it is running).
|
||||
2. Double-click **`install-cuda.bat`**. This script swaps the CPU-only Torch build for the CUDA 12.8 build.
|
||||
3. Start Talemate again via **`start.bat`**.
|
||||
|
||||
--8<-- "docs/snippets/common.md:python-versions"
|
||||
## Maintenance & advanced usage
|
||||
|
||||
1. Visit the official Python website's download page for Windows at [https://www.python.org/downloads/windows/](https://www.python.org/downloads/windows/).
|
||||
2. Find the latest updated of Python 3.13 and click on one of the download links. (You will likely want the Windows installer (64-bit))
|
||||
4. Run the installer file and follow the setup instructions. Make sure to check the box that says Add Python 3.13 to PATH before you click Install Now.
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| **`start.bat`** | Primary entry point – performs the initial install if needed and then starts Talemate. |
|
||||
| **`install.bat`** | Runs the installer without launching the server. Useful for automated setups or debugging. |
|
||||
| **`install-cuda.bat`** | Installs the CUDA-enabled Torch build (run after the regular install). |
|
||||
| **`update.bat`** | Pulls the latest changes from GitHub, updates dependencies, rebuilds the web client. |
|
||||
|
||||
### How to Install npm
|
||||
|
||||
1. Download Node.js from the official site [https://nodejs.org/en/download/prebuilt-installer](https://nodejs.org/en/download/prebuilt-installer).
|
||||
2. Run the installer (the .msi installer is recommended).
|
||||
3. Follow the prompts in the installer (Accept the license agreement, click the NEXT button a bunch of times and accept the default installation settings).
|
||||
|
||||
### Usage of the Supplied bat Files
|
||||
|
||||
#### install.bat
|
||||
|
||||
This batch file is used to set up the project on your local machine. It creates a virtual environment, activates it, installs poetry, and uses poetry to install dependencies. It then navigates to the frontend directory and installs the necessary npm packages.
|
||||
|
||||
To run this file, simply double click on it or open a command prompt in the same directory and type `install.bat`.
|
||||
|
||||
#### update.bat
|
||||
|
||||
If you are inside a git checkout of talemate you can use this to pull and reinstall talemate if there have been updates.
|
||||
|
||||
!!! note "CUDA needs to be reinstalled manually"
|
||||
Running `update.bat` will downgrade your torch install to the non-CUDA version, so if you want CUDA support you will need to run the `install-cuda.bat` script after the update is finished.
|
||||
|
||||
#### start.bat
|
||||
|
||||
This batch file is used to start the backend and frontend servers. It opens two command prompts, one for the frontend and one for the backend.
|
||||
|
||||
To run this file, simply double click on it or open a command prompt in the same directory and type `start.bat`.
|
||||
No system-wide Python or Node.js is required – Talemate uses the embedded runtimes it downloads automatically.
|
||||
BIN
docs/img/0.31.0/client-endpoint-override.png
Normal file
|
After Width: | Height: | Size: 51 KiB |
BIN
docs/img/0.31.0/client-ollama-no-model.png
Normal file
|
After Width: | Height: | Size: 6.4 KiB |
BIN
docs/img/0.31.0/client-ollama-offline.png
Normal file
|
After Width: | Height: | Size: 6.3 KiB |
BIN
docs/img/0.31.0/client-ollama-ready.png
Normal file
|
After Width: | Height: | Size: 6.5 KiB |
BIN
docs/img/0.31.0/client-ollama-select-model.png
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
docs/img/0.31.0/client-ollama.png
Normal file
|
After Width: | Height: | Size: 53 KiB |
BIN
docs/img/0.31.0/client-openrouter-no-api-key.png
Normal file
|
After Width: | Height: | Size: 8.2 KiB |
BIN
docs/img/0.31.0/client-openrouter-ready.png
Normal file
|
After Width: | Height: | Size: 7.8 KiB |
BIN
docs/img/0.31.0/client-openrouter-select-model.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
docs/img/0.31.0/client-openrouter.png
Normal file
|
After Width: | Height: | Size: 38 KiB |
BIN
docs/img/0.31.0/history-add-entry.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
docs/img/0.31.0/history-regenerate-all.png
Normal file
|
After Width: | Height: | Size: 9.2 KiB |
BIN
docs/img/0.31.0/history.png
Normal file
|
After Width: | Height: | Size: 96 KiB |
BIN
docs/img/0.31.0/koboldcpp-embeddings.png
Normal file
|
After Width: | Height: | Size: 4.5 KiB |
BIN
docs/img/0.31.0/openrouter-settings.png
Normal file
|
After Width: | Height: | Size: 42 KiB |
25
docs/user-guide/agents/memory/koboldcpp.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# KoboldCpp Embeddings
|
||||
|
||||
Talemate can leverage an embeddings model that is already loaded in your KoboldCpp instance.
|
||||
|
||||
## 1. Start KoboldCpp with an embeddings model
|
||||
|
||||
Launch KoboldCpp with the `--embeddingsmodel` flag so that it loads an embeddings-capable GGUF model alongside the main LLM:
|
||||
|
||||
```bash
|
||||
koboldcpp_cu12.exe --model google_gemma-3-27b-it-Q6_K.gguf --embeddingsmodel bge-large-en-v1.5.Q8_0.gguf
|
||||
```
|
||||
|
||||
## 2. Talemate will pick it up automatically
|
||||
|
||||
When Talemate starts, the **Memory** agent probes every connected client that advertises embedding support. If it detects that your KoboldCpp instance has an embeddings model loaded:
|
||||
|
||||
1. The Memory backend switches the current embedding to **Client API**.
|
||||
2. The `client` field in the agent details shows the name of the KoboldCpp client.
|
||||
3. A banner informs you that Talemate has switched to the new embedding. <!-- stub: screenshot -->
|
||||
|
||||

|
||||
|
||||
## 3. Reverting to a local embedding
|
||||
|
||||
Open the memory agent settings and pick a different embedding. See [Memory agent settings](/talemate/user-guide/agents/memory/settings).
|
||||
@@ -5,5 +5,6 @@ nav:
|
||||
- Google Cloud: google.md
|
||||
- Groq: groq.md
|
||||
- Mistral.ai: mistral.md
|
||||
- OpenRouter: openrouter.md
|
||||
- OpenAI: openai.md
|
||||
- ...
|
||||
11
docs/user-guide/apis/openrouter.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# OpenRouter API Setup
|
||||
|
||||
Talemate can use any model accessible through OpenRouter.
|
||||
|
||||
You need an OpenRouter API key and must set it in the application config. You can create and manage keys in your OpenRouter dashboard at [https://openrouter.ai/keys](https://openrouter.ai/keys).
|
||||
|
||||
Once you have generated a key open the Talemate settings, switch to the `APPLICATION` tab and then select the `OPENROUTER API` category. Paste your key in the **API Key** field.
|
||||
|
||||

|
||||
|
||||
Finally click **Save** to store the credentials.
|
||||
@@ -4,4 +4,5 @@ nav:
|
||||
- Recommended Local Models: recommended-models.md
|
||||
- Inference Presets: presets.md
|
||||
- Client Types: types
|
||||
- Endpoint Override: endpoint-override.md
|
||||
- ...
|
||||
24
docs/user-guide/clients/endpoint-override.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Endpoint Override
|
||||
|
||||
Starting in version 0.31.0 it is now possible for some of the remote clients to override the endpoint used for the API.
|
||||
|
||||
THis is helpful wehn you want to point the client at a proxy gateway to serve the api instead (LiteLLM for example).
|
||||
|
||||
!!! warning "Only use trusted endpoints"
|
||||
Only use endpoints that you trust and NEVER used your actual API key with them, unless you are hosting your endpoint proxy yourself.
|
||||
|
||||
If you need to provide an api key there is a separate field for that specifically in the endpoint override settings.
|
||||
|
||||
## How to use
|
||||
|
||||
Clients that support it will have a tab in their settings that allows you to override the endpoint.
|
||||
|
||||

|
||||
|
||||
##### Base URL
|
||||
|
||||
The base URL of the endpoint. For example, `http://localhost:4000` if you're running a local LiteLLM gateway,
|
||||
|
||||
##### API Key
|
||||
|
||||
The API key to use for the endpoint. This is only required if the endpoint requires an API key. This is **NOT** the API key you would use for the official API. For LiteLLM for example this could be the `general_settings.master_key` value.
|
||||
@@ -8,6 +8,8 @@ nav:
|
||||
- Mistral.ai: mistral.md
|
||||
- OpenAI: openai.md
|
||||
- OpenAI Compatible: openai-compatible.md
|
||||
- Ollama: ollama.md
|
||||
- OpenRouter: openrouter.md
|
||||
- TabbyAPI: tabbyapi.md
|
||||
- Text-Generation-WebUI: text-generation-webui.md
|
||||
- ...
|
||||
59
docs/user-guide/clients/types/ollama.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Ollama Client
|
||||
|
||||
If you want to add an Ollama client, change the `Client Type` to `Ollama`.
|
||||
|
||||

|
||||
|
||||
Click `Save` to add the client.
|
||||
|
||||
### Ollama Server
|
||||
|
||||
The client should appear in the clients list. Talemate will ping the Ollama server to verify that it is running. If the server is not reachable you will see a warning.
|
||||
|
||||

|
||||
|
||||
Make sure that the Ollama server is running (by default at `http://localhost:11434`) and that the model you want to use has been pulled.
|
||||
|
||||
It may also show a yellow dot next to it, saying that there is no model loaded.
|
||||
|
||||

|
||||
|
||||
Open the client settings by clicking the :material-cogs: icon, to select a model.
|
||||
|
||||

|
||||
|
||||
Click save and the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||

|
||||
|
||||
### Settings
|
||||
|
||||
##### Client Name
|
||||
|
||||
A unique name for the client that makes sense to you.
|
||||
|
||||
##### API URL
|
||||
|
||||
The base URL where the Ollama HTTP endpoint is running. Defaults to `http://localhost:11434`.
|
||||
|
||||
##### Model
|
||||
|
||||
Name of the Ollama model to use. Talemate will automatically fetch the list of models that are currently available in your local Ollama instance.
|
||||
|
||||
##### API handles prompt template
|
||||
|
||||
If enabled, Talemate will send the raw prompt and let Ollama apply its own built-in prompt template. If you are unsure leave this disabled – Talemate's own prompt template generally produces better results.
|
||||
|
||||
##### Allow thinking
|
||||
|
||||
If enabled Talemate will allow models that support "thinking" (`assistant:thinking` messages) to deliberate before forming the final answer. At the moment Talemate has limited support for this feature when talemate is handling the prompt template. Its probably ok to turn it on if you let Ollama handle the prompt template.
|
||||
|
||||
!!! tip
|
||||
You can quickly refresh the list of models by making sure the Ollama server is running and then hitting **Save** again in the client settings.
|
||||
|
||||
### Common issues
|
||||
|
||||
#### Generations are weird / bad
|
||||
|
||||
If letting talemate handle the prompt template, make sure the [correct prompt template is assigned](/talemate/user-guide/clients/prompt-templates/).
|
||||
|
||||
48
docs/user-guide/clients/types/openrouter.md
Normal file
@@ -0,0 +1,48 @@
|
||||
# OpenRouter Client
|
||||
|
||||
If you want to add an OpenRouter client, change the `Client Type` to `OpenRouter`.
|
||||
|
||||

|
||||
|
||||
Click `Save` to add the client.
|
||||
|
||||
### OpenRouter API Key
|
||||
|
||||
The client should appear in the clients list. If you haven't set up OpenRouter before, you will see a warning that the API key is missing.
|
||||
|
||||

|
||||
|
||||
Click the `SET API KEY` button. This will open the API settings window where you can add your OpenRouter API key.
|
||||
|
||||
For additional instructions on obtaining and setting your OpenRouter API key, see [OpenRouter API instructions](/talemate/user-guide/apis/openrouter/).
|
||||
|
||||

|
||||
|
||||
Click `Save` and after a moment the client should have a red dot next to it, saying that there is no model loaded.
|
||||
|
||||
Click the :material-cogs: icon to open the client settings and select a model.
|
||||
|
||||
.
|
||||
|
||||
Click save and the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||
### Ready to use
|
||||
|
||||

|
||||
|
||||
### Settings
|
||||
|
||||
##### Client Name
|
||||
|
||||
A unique name for the client that makes sense to you.
|
||||
|
||||
##### Model
|
||||
|
||||
Choose any model available via your OpenRouter account. Talemate dynamically fetches the list of models associated with your API key so new models will show up automatically.
|
||||
|
||||
##### Max token length
|
||||
|
||||
Maximum context length (in tokens) that OpenRouter should consider. If you are not sure leave the default value.
|
||||
|
||||
!!! note "Available models are fetched automatically"
|
||||
Talemate fetches the list of available OpenRouter models when you save the configuration (if a valid API key is present). If you add or remove models to your account later, simply click **Save** in the application settings again to refresh the list.
|
||||
|
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 34 KiB |
|
Before Width: | Height: | Size: 4.7 KiB After Width: | Height: | Size: 35 KiB |
|
Before Width: | Height: | Size: 9.2 KiB After Width: | Height: | Size: 7.4 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 22 KiB |
|
Before Width: | Height: | Size: 2.3 KiB After Width: | Height: | Size: 2.7 KiB |
|
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 8.7 KiB |
|
Before Width: | Height: | Size: 2.8 KiB After Width: | Height: | Size: 13 KiB |
|
Before Width: | Height: | Size: 34 KiB |
@@ -2,17 +2,6 @@
|
||||
|
||||
This tutorial will show you how to use the `Dynamic Storyline` module (added in `0.30`) to randomize the scene introduction for ANY scene.
|
||||
|
||||
!!! note "A more streamlined approach is coming soon"
|
||||
I am aware that some people may not want to touch the node editor at all, so a more streamlined approach is planned.
|
||||
|
||||
For now this will lay out the simplest way to set this up while still using the node editor.
|
||||
|
||||
!!! learn-more "For those interested..."
|
||||
|
||||
There is tutorial on how the `Dynamic Storyline` module was made (or at least the beginnings of it).
|
||||
|
||||
If you are interested in the process, you can find it [here](/talemate/user-guide/howto/infinity-quest-dynamic).
|
||||
|
||||
## Save a foundation scene copy
|
||||
|
||||
This should be a save of your scene that has had NO progress made to it yet. We are generating a new scene introduction after all.
|
||||
@@ -21,59 +10,52 @@ The introduction is only generated once. So you should maintain a save-file of t
|
||||
|
||||
To ensure this foundation scene save isn't overwritten you can go to the scene settings in the world editor and turn on the Locked save file flag:
|
||||
|
||||

|
||||

|
||||
|
||||
Save the scene.
|
||||
|
||||
## Switch to the node editor
|
||||
## Install the module
|
||||
|
||||
In your scene tools find the :material-puzzle-edit: creative menu and click on the **Node Editor** option.
|
||||
Click the `Mods` tab in the world editor.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
Find the `COPY AS EDITABLE MODULE FOR ..` button beneath the node editor.
|
||||
Find the `Dynamic Storyline` module and click **Install**.
|
||||
|
||||

|
||||
It will say installed (not configured)
|
||||
|
||||
Click it.
|
||||

|
||||
|
||||
In the next window, don't even read any of the stuff, just click **Continue**.
|
||||
Click **Configure** and set topic to something like `Sci-fi adventure with lovecraftian horror`.
|
||||
|
||||
## Find a blank area
|
||||

|
||||
|
||||
Use the mousewheel to zoom out a bit, then click the canvas and drag it to the side so you're looking at some blank space. Literally anywhere that's grey background is fine.
|
||||
!!! note "Optional settings"
|
||||
|
||||
Double click the empty area to bring up the module searcand type in "Dynamic Story" into th
|
||||
##### Max intro text length
|
||||
How many tokens to generate for the intro text.
|
||||
|
||||

|
||||
##### Additional instructions for topic analysis task
|
||||
If topic analysis is enabled, this will be used to augment the topic analysis task with further instructions
|
||||
|
||||
Select the `Dynamic Storyline` node to add it to the scene.
|
||||
##### Enable topic analysis
|
||||
This will enable the topic analysis task
|
||||
|
||||

|
||||
**Save** the module configuration.
|
||||
|
||||
Click the `topic` input and type in a general genre or thematic guide for the story.
|
||||
Finally click "Reload Scene" in the left sidebar.
|
||||
|
||||
Some examples
|
||||

|
||||
|
||||
- `sci-fi with cosmic horror elements`
|
||||
- `dungeons and dragons campaign ideas`
|
||||
- `slice of life story ideas`
|
||||
If everything is configured correctly, the storyline generation will begin immediately.
|
||||
|
||||
Whatever you enter will be used to generate a list of story ideas, of which one will be chosen at random to bootstrap a new story, taking the scene context that exists already into account.
|
||||

|
||||
|
||||
This will NOT create new characters or world context.
|
||||
!!! note "Switch out of edit mode"
|
||||
|
||||
It simply bootstraps a story premise based on the random topic and what's already there.
|
||||
|
||||
Once the topic is set, save the changes by clicking the node editor's **Save** button in the upper right corner.
|
||||
|
||||

|
||||
|
||||
Exit the node editor through the same menu as before.
|
||||
|
||||

|
||||
|
||||
Once back in the scene, if everythign was done correctly you should see it working on setting the scene introduction.
|
||||
|
||||

|
||||
If nothing is happening after configuration and reloading the scene, make sure you are not in edit mode.
|
||||
|
||||
You can leave edit mode by clicking the "Exit Node Editor" button in the creative menu.
|
||||
|
||||

|
||||
105
docs/user-guide/node-editor/core-concepts/package.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Installable Packages
|
||||
|
||||
It is possible to "package" your node modules so they can be installed into a scene.
|
||||
|
||||
This allows for easier controlled set up and makes your node module more sharable as users no longer need to use the node editor to install it.
|
||||
|
||||
Installable packages show up in the Mods list once a scene is loaded.
|
||||
|
||||

|
||||
|
||||
## 1. Create a package module
|
||||
|
||||
To create a package - click the **:material-plus: Create Module** button in the node editor and select **Package**.
|
||||
|
||||

|
||||
|
||||
The package itself is a special kind of node module that will let Talemate know that your node module is installable and how to install it.
|
||||
|
||||
## 2. Open the module properties
|
||||
|
||||
With the package module open find the module properties in the upper left corner of the node editor.
|
||||
|
||||

|
||||
|
||||
Fill in the fields:
|
||||
|
||||
##### The name of the node module
|
||||
|
||||
This is what the module package will be called in the Mods list.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
Dynamic Storyline
|
||||
```
|
||||
|
||||
##### The author of the node module
|
||||
|
||||
Your name or handle. This is arbitrary and just lets people know who made the package.
|
||||
|
||||
##### The description of the node module
|
||||
|
||||
A short description of the package. This is displayed in the Mods list.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
Generate a random story premise at the beginning of the scene.
|
||||
```
|
||||
|
||||
##### Whether the node module is installable to the scene
|
||||
|
||||
A checkbox to indicate if the package is installable to the scene.
|
||||
|
||||
Right now this should always be checked, there are no other package types currently.
|
||||
|
||||
##### Whether the scene loop should be restarted when the package is installed
|
||||
|
||||
If checked, installing this package will restart the scene loop. This is mostly important for modules that require to hook into the scene loop init event.
|
||||
|
||||
## 3. Install instructions
|
||||
|
||||
Currently there are only two nodes relevant for the node module install process.
|
||||
|
||||
|
||||
1. `Install Node Module` - this node is used to make sure the target node module is added to the scene loop when installing the package. You can have more than one of these nodes in your package.
|
||||
1. `Promote Config` - this node is used to promote your node module's properties to configurable fields in the mods list. E.g., this dictates what the user can configure when installing the package.
|
||||
|
||||
### Install Node Module
|
||||
|
||||

|
||||
|
||||
!!! payload "Install Node Module"
|
||||
|
||||
| Property | Value |
|
||||
|----------|-------|
|
||||
| node_registry | the registry path of the node module to install |
|
||||
|
||||
### Promote Config
|
||||
|
||||

|
||||
|
||||
!!! payload "Promote Config"
|
||||
|
||||
| Property | Value |
|
||||
|----------|-------|
|
||||
| node_registry | the registry path of the node module |
|
||||
| property_name | the name of the property to promote (as it is set in the node module) |
|
||||
| exposed_property_name | expose as this name in the mods list, this can be the same as the property name or a different name - this is important if youre installing multiple node modules with the same property name, so you can differentiate between them |
|
||||
| required | whether the property is required to be set when installing the package |
|
||||
| label | a user friendly label for the property |
|
||||
|
||||
### Make talemate aware of the package
|
||||
|
||||
For talemate to be aware of the package, you need to copy it to the public node module directory, which exists as `templates/modules/`.
|
||||
|
||||
Create a new sub directory:
|
||||
|
||||
```
|
||||
./templates/modules/<your-package-name>/
|
||||
```
|
||||
|
||||
Copy the package module and your node module files into the directory.
|
||||
|
||||
Restart talemate and the package should now show up in the Mods list.
|
||||
BIN
docs/user-guide/node-editor/img/package-0001.png
Normal file
|
After Width: | Height: | Size: 19 KiB |
BIN
docs/user-guide/node-editor/img/package-0002.png
Normal file
|
After Width: | Height: | Size: 24 KiB |
BIN
docs/user-guide/node-editor/img/package-0003.png
Normal file
|
After Width: | Height: | Size: 35 KiB |
BIN
docs/user-guide/node-editor/img/package-0004.png
Normal file
|
After Width: | Height: | Size: 6.1 KiB |
BIN
docs/user-guide/node-editor/img/package-0005.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
@@ -1,14 +1,63 @@
|
||||
# History
|
||||
|
||||
Will hold the archived history for the scene.
|
||||
Will hold historical events for the scene.
|
||||
|
||||
This historical archive is extended everytime the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) summarizes the scene.
|
||||
There are three types of historical entries:
|
||||
|
||||
Summarization happens when a certain progress treshold is reached (tokens) or when a time passage event occurs.
|
||||
- **Archived history (static)** - These are entries are manually defined and dated before the starting point of the scene. Things that happened in the past that are **IMPORTANT** for the understanding of the world. For anything that is not **VITAL**, use world entries instead.
|
||||
- **Archived history (from summary)** - These are historical entries generated from the progress in the scene. Whenever a certain token (length) threshold is reached, the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) will generate a summary of the progression and add it to the history.
|
||||
- **Layered history (from summary)** - As summary archives are generated, they themselves will be summarized again and added to the history, leading to a natural compression of the history sent with the context while also keeping track of the most important events. (hopefully)
|
||||
|
||||
All archived history is a potential candidate to be included in the context sent to the AI based on relevancy. This is handled by the [Memory Agent](/talemate/user-guide/agents/memory/).
|
||||

|
||||
|
||||
You can use the **:material-refresh: Regenerate History** button to force a new summarization of the scene.
|
||||
## Layers
|
||||
|
||||
!!! warning
|
||||
If there has been lots of progress this will potentially take a long time to complete.
|
||||
There is always the **BASE** layer, which is where the archived history (both static and from summary) is stored. For all intents and purposes, this is layer 0.
|
||||
|
||||
At the beginning of a scene, there won't be any additional layers, as any layer past layer 0 will come from summarization down the line.
|
||||
|
||||
Note that layered history is managed by the [Summarizer Agent](/talemate/user-guide/agents/summarizer/) and can be disabled in its settings.
|
||||
|
||||
### Managing entries
|
||||
|
||||
- **All entries** can be edited by double-clicking the text.
|
||||
- **Static entries** can be deleted by clicking the **:material-close-box-outline: Delete** button.
|
||||
- **Summarized entries** can be regenerated by clicking the **:material-refresh: Regenerate** button. This will cause the LLM to re-summarize the entry and update the text.
|
||||
- **Summarized entries** can be inspected by clicking the **:material-magnify-expand: Inspect** button. This will expand the entry and show the source entries that were used to generate the summary.
|
||||
|
||||
### Adding static entries
|
||||
|
||||
Static entries can be added by clicking the **:material-plus: Add Entry** button.
|
||||
|
||||
!!! note "Static entries must be older than any summary entries"
|
||||
Static entries must be older than any summary entries. This is to ensure that the history is always chronological.
|
||||
|
||||
Trying to add a static entry that is more recent than any summary entry will result in an error.
|
||||
|
||||
##### Entry Text
|
||||
|
||||
The text of the entry. Should be at most 1 - 2 paragraphs. Less is more. Anything that needs great detail should be a world entry instead.
|
||||
|
||||
##### Unit
|
||||
|
||||
Defines the duration unit of the entry. So minutes, hours, days, weeks, months or years.
|
||||
|
||||
##### Amount
|
||||
|
||||
Defines the duration unit amount of the entry.
|
||||
|
||||
So if you want to define something that happened 10 months ago (from the current moment in the scene), you would set the unit to months and the amount to 10.
|
||||
|
||||

|
||||
|
||||
## Regenerate everything
|
||||
|
||||
It is possible to regenerate the entire history by clicking the **:material-refresh: Regenerate All History** button in the left sidebar. Static entries will remain unchanged.
|
||||
|
||||

|
||||
|
||||
!!! warning "This can take a long time"
|
||||
|
||||
This will go through the entire scene progress and regenerate all summarized entries.
|
||||
|
||||
If you have a lot of progress, be ready to wait for a while.
|
||||
|
||||
@@ -1,8 +1,23 @@
|
||||
REM activate the virtual environment
|
||||
call talemate_env\Scripts\activate
|
||||
@echo off
|
||||
|
||||
REM uninstall torch and torchaudio
|
||||
python -m pip uninstall torch torchaudio -y
|
||||
REM Check if .venv exists
|
||||
IF NOT EXIST ".venv" (
|
||||
echo [ERROR] .venv directory not found. Please run install.bat first.
|
||||
goto :eof
|
||||
)
|
||||
|
||||
REM install torch and torchaudio
|
||||
python -m pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
REM Check if embedded Python exists
|
||||
IF NOT EXIST "embedded_python\python.exe" (
|
||||
echo [ERROR] embedded_python not found. Please run install.bat first.
|
||||
goto :eof
|
||||
)
|
||||
|
||||
REM uninstall torch and torchaudio using embedded Python's uv
|
||||
embedded_python\python.exe -m uv pip uninstall torch torchaudio --python .venv\Scripts\python.exe
|
||||
|
||||
REM install torch and torchaudio with CUDA support using embedded Python's uv
|
||||
embedded_python\python.exe -m uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128 --python .venv\Scripts\python.exe
|
||||
|
||||
echo.
|
||||
echo CUDA versions of torch and torchaudio installed!
|
||||
echo You may need to restart your application for changes to take effect.
|
||||
@@ -1,10 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# activate the virtual environment
|
||||
source talemate_env/bin/activate
|
||||
|
||||
# uninstall torch and torchaudio
|
||||
python -m pip uninstall torch torchaudio -y
|
||||
uv pip uninstall torch torchaudio
|
||||
|
||||
# install torch and torchaudio
|
||||
python -m pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
# install torch and torchaudio with CUDA support
|
||||
uv pip install torch~=2.7.0 torchaudio~=2.7.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
@@ -1,4 +0,0 @@
|
||||
REM activate the virtual environment
|
||||
call talemate_env\Scripts\activate
|
||||
|
||||
call pip install "TTS>=0.21.1"
|
||||
28579
install-utils/get-pip.py
Normal file
280
install.bat
@@ -1,65 +1,227 @@
|
||||
@echo off
|
||||
|
||||
REM Check for Python version and use a supported version if available
|
||||
SET PYTHON=python
|
||||
python -c "import sys; sys.exit(0 if sys.version_info[:2] in [(3, 10), (3, 11), (3, 12), (3, 13)] else 1)" 2>nul
|
||||
IF NOT ERRORLEVEL 1 (
|
||||
echo Selected Python version: %PYTHON%
|
||||
GOTO EndVersionCheck
|
||||
)
|
||||
REM ===============================
|
||||
REM Talemate project installer
|
||||
REM ===============================
|
||||
REM 1. Detect CPU architecture and pick the best-fitting embedded Python build.
|
||||
REM 2. Download & extract that build into .\embedded_python\
|
||||
REM 3. Bootstrap pip via install-utils\get-pip.py
|
||||
REM 4. Install virtualenv and create .\talemate_env\ using the embedded Python.
|
||||
REM 5. Activate the venv and proceed with Poetry + frontend installation.
|
||||
REM ---------------------------------------------------------------
|
||||
|
||||
SET PYTHON=python
|
||||
FOR /F "tokens=*" %%i IN ('py --list') DO (
|
||||
echo %%i | findstr /C:"-V:3.11 " >nul && SET PYTHON=py -3.11 && GOTO EndPythonCheck
|
||||
echo %%i | findstr /C:"-V:3.10 " >nul && SET PYTHON=py -3.10 && GOTO EndPythonCheck
|
||||
)
|
||||
:EndPythonCheck
|
||||
%PYTHON% -c "import sys; sys.exit(0 if sys.version_info[:2] in [(3, 10), (3, 11), (3, 12), (3, 13)] else 1)" 2>nul
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Unsupported Python version. Please install Python 3.10 or 3.11.
|
||||
exit /b 1
|
||||
)
|
||||
IF "%PYTHON%"=="python" (
|
||||
echo Default Python version is being used: %PYTHON%
|
||||
) ELSE (
|
||||
echo Selected Python version: %PYTHON%
|
||||
)
|
||||
SETLOCAL ENABLEDELAYEDEXPANSION
|
||||
|
||||
:EndVersionCheck
|
||||
REM Define fatal-error handler
|
||||
REM Usage: CALL :die "Message explaining what failed"
|
||||
goto :after_die
|
||||
|
||||
IF ERRORLEVEL 1 (
|
||||
echo Unsupported Python version. Please install Python 3.10 or 3.11.
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
REM create a virtual environment
|
||||
%PYTHON% -m venv talemate_env
|
||||
|
||||
REM activate the virtual environment
|
||||
call talemate_env\Scripts\activate
|
||||
|
||||
REM upgrade pip and setuptools
|
||||
python -m pip install --upgrade pip setuptools
|
||||
|
||||
REM install poetry
|
||||
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
|
||||
|
||||
REM use poetry to install dependencies
|
||||
python -m poetry install
|
||||
|
||||
REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
|
||||
IF NOT EXIST config.yaml copy config.example.yaml config.yaml
|
||||
|
||||
REM navigate to the frontend directory
|
||||
echo Installing frontend dependencies...
|
||||
cd talemate_frontend
|
||||
call npm install
|
||||
|
||||
echo Building frontend...
|
||||
call npm run build
|
||||
|
||||
REM return to the root directory
|
||||
cd ..
|
||||
|
||||
echo Installation completed successfully.
|
||||
:die
|
||||
echo.
|
||||
echo ============================================================
|
||||
echo !!! INSTALL FAILED !!!
|
||||
echo %*
|
||||
echo ============================================================
|
||||
pause
|
||||
exit 1
|
||||
|
||||
:after_die
|
||||
|
||||
REM ---------[ Check Prerequisites ]---------
|
||||
ECHO Checking prerequisites...
|
||||
where tar >nul 2>&1 || CALL :die "tar command not found. Please ensure Windows 10 version 1803+ or install tar manually."
|
||||
where curl >nul 2>&1
|
||||
IF %ERRORLEVEL% NEQ 0 (
|
||||
where bitsadmin >nul 2>&1 || CALL :die "Neither curl nor bitsadmin found. Cannot download files."
|
||||
)
|
||||
|
||||
REM ---------[ Remove legacy Poetry venv if present ]---------
|
||||
IF EXIST "talemate_env" (
|
||||
ECHO Detected legacy Poetry virtual environment 'talemate_env'. Removing...
|
||||
RD /S /Q "talemate_env"
|
||||
IF ERRORLEVEL 1 (
|
||||
ECHO [WARNING] Failed to fully remove legacy 'talemate_env' directory. Continuing installation.
|
||||
)
|
||||
)
|
||||
|
||||
REM ---------[ Clean reinstall check ]---------
|
||||
SET "NEED_CLEAN=0"
|
||||
IF EXIST ".venv" SET "NEED_CLEAN=1"
|
||||
IF EXIST "embedded_python" SET "NEED_CLEAN=1"
|
||||
IF EXIST "embedded_node" SET "NEED_CLEAN=1"
|
||||
|
||||
IF "%NEED_CLEAN%"=="1" (
|
||||
ECHO.
|
||||
ECHO Detected existing Talemate environments.
|
||||
REM Prompt user (empty input defaults to Y)
|
||||
SET "ANSWER=Y"
|
||||
SET /P "ANSWER=Perform a clean reinstall of the python and node.js environments? [Y/n] "
|
||||
IF /I "!ANSWER!"=="N" (
|
||||
ECHO Installation aborted by user.
|
||||
GOTO :EOF
|
||||
)
|
||||
ECHO Removing previous installation...
|
||||
IF EXIST ".venv" RD /S /Q ".venv"
|
||||
IF EXIST "embedded_python" RD /S /Q "embedded_python"
|
||||
IF EXIST "embedded_node" RD /S /Q "embedded_node"
|
||||
ECHO Cleanup complete.
|
||||
)
|
||||
|
||||
REM ---------[ Version configuration ]---------
|
||||
SET "PYTHON_VERSION=3.11.9"
|
||||
SET "NODE_VERSION=22.16.0"
|
||||
|
||||
REM ---------[ Detect architecture & choose download URL ]---------
|
||||
REM Prefer PROCESSOR_ARCHITEW6432 when the script is run from a 32-bit shell on 64-bit Windows
|
||||
IF DEFINED PROCESSOR_ARCHITEW6432 (
|
||||
SET "ARCH=%PROCESSOR_ARCHITEW6432%"
|
||||
) ELSE (
|
||||
SET "ARCH=%PROCESSOR_ARCHITECTURE%"
|
||||
)
|
||||
|
||||
REM Map architecture to download URL
|
||||
IF /I "%ARCH%"=="AMD64" (
|
||||
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
|
||||
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x64.zip"
|
||||
) ELSE IF /I "%ARCH%"=="IA64" (
|
||||
REM Itanium systems are rare, but AMD64 build works with WoW64 layer
|
||||
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
|
||||
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x64.zip"
|
||||
) ELSE IF /I "%ARCH%"=="ARM64" (
|
||||
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-arm64.zip"
|
||||
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-arm64.zip"
|
||||
) ELSE (
|
||||
REM Fallback to 64-bit build for x86 / unknown architectures
|
||||
SET "PY_URL=https://www.python.org/ftp/python/%PYTHON_VERSION%/python-%PYTHON_VERSION%-embed-amd64.zip"
|
||||
SET "NODE_URL=https://nodejs.org/dist/v%NODE_VERSION%/node-v%NODE_VERSION%-win-x86.zip"
|
||||
)
|
||||
ECHO Detected architecture: %ARCH%
|
||||
ECHO Downloading embedded Python from: %PY_URL%
|
||||
|
||||
REM ---------[ Download ]---------
|
||||
SET "PY_ZIP=python_embed.zip"
|
||||
|
||||
where curl >nul 2>&1
|
||||
IF %ERRORLEVEL% EQU 0 (
|
||||
ECHO Using curl to download Python...
|
||||
curl -L -# -o "%PY_ZIP%" "%PY_URL%" || CALL :die "Failed to download Python embed package with curl."
|
||||
) ELSE (
|
||||
ECHO curl not found, falling back to bitsadmin...
|
||||
bitsadmin /transfer "DownloadPython" /download /priority normal "%PY_URL%" "%CD%\%PY_ZIP%" || CALL :die "Failed to download Python embed package (curl & bitsadmin unavailable)."
|
||||
)
|
||||
|
||||
REM ---------[ Extract ]---------
|
||||
SET "PY_DIR=embedded_python"
|
||||
IF EXIST "%PY_DIR%" RD /S /Q "%PY_DIR%"
|
||||
mkdir "%PY_DIR%" || CALL :die "Could not create directory %PY_DIR%."
|
||||
|
||||
where tar >nul 2>&1
|
||||
IF %ERRORLEVEL% EQU 0 (
|
||||
ECHO Extracting with tar...
|
||||
tar -xf "%PY_ZIP%" -C "%PY_DIR%" || CALL :die "Failed to extract Python embed package with tar."
|
||||
) ELSE (
|
||||
CALL :die "tar utility not found (required to unpack zip without PowerShell)."
|
||||
)
|
||||
|
||||
DEL /F /Q "%PY_ZIP%"
|
||||
|
||||
SET "PYTHON=%PY_DIR%\python.exe"
|
||||
ECHO Using embedded Python at %PYTHON%
|
||||
|
||||
REM ---------[ Enable site-packages in embedded Python ]---------
|
||||
FOR %%f IN ("%PY_DIR%\python*._pth") DO (
|
||||
ECHO Adding 'import site' to %%~nxf ...
|
||||
echo import site>>"%%~ff"
|
||||
)
|
||||
|
||||
REM ---------[ Ensure pip ]---------
|
||||
ECHO Installing pip...
|
||||
"%PYTHON%" install-utils\get-pip.py || (
|
||||
CALL :die "pip installation failed."
|
||||
)
|
||||
|
||||
REM Upgrade pip to latest
|
||||
"%PYTHON%" -m pip install --no-warn-script-location --upgrade pip || CALL :die "Failed to upgrade pip in embedded Python."
|
||||
|
||||
REM ---------[ Install uv ]---------
|
||||
ECHO Installing uv...
|
||||
"%PYTHON%" -m pip install uv || (
|
||||
CALL :die "uv installation failed."
|
||||
)
|
||||
|
||||
REM ---------[ Create virtual environment with uv ]---------
|
||||
ECHO Creating virtual environment with uv...
|
||||
"%PYTHON%" -m uv venv || (
|
||||
CALL :die "Virtual environment creation failed."
|
||||
)
|
||||
|
||||
REM ---------[ Install dependencies using embedded Python's uv ]---------
|
||||
ECHO Installing backend dependencies with uv...
|
||||
"%PYTHON%" -m uv sync || CALL :die "Failed to install backend dependencies with uv."
|
||||
|
||||
REM Activate the venv for the remainder of the script
|
||||
CALL .venv\Scripts\activate
|
||||
|
||||
REM echo python version
|
||||
python --version
|
||||
|
||||
REM ---------[ Config file ]---------
|
||||
IF NOT EXIST config.yaml COPY config.example.yaml config.yaml
|
||||
|
||||
REM ---------[ Node.js portable runtime ]---------
|
||||
ECHO.
|
||||
ECHO Downloading portable Node.js runtime...
|
||||
|
||||
REM Node download variables already set earlier based on %ARCH%.
|
||||
ECHO Downloading Node.js from: %NODE_URL%
|
||||
|
||||
SET "NODE_ZIP=node_embed.zip"
|
||||
|
||||
where curl >nul 2>&1
|
||||
IF %ERRORLEVEL% EQU 0 (
|
||||
ECHO Using curl to download Node.js...
|
||||
curl -L -# -o "%NODE_ZIP%" "%NODE_URL%" || CALL :die "Failed to download Node.js package with curl."
|
||||
) ELSE (
|
||||
ECHO curl not found, falling back to bitsadmin...
|
||||
bitsadmin /transfer "DownloadNode" /download /priority normal "%NODE_URL%" "%CD%\%NODE_ZIP%" || CALL :die "Failed to download Node.js package (curl & bitsadmin unavailable)."
|
||||
)
|
||||
|
||||
REM ---------[ Extract Node.js ]---------
|
||||
SET "NODE_DIR=embedded_node"
|
||||
IF EXIST "%NODE_DIR%" RD /S /Q "%NODE_DIR%"
|
||||
mkdir "%NODE_DIR%" || CALL :die "Could not create directory %NODE_DIR%."
|
||||
|
||||
where tar >nul 2>&1
|
||||
IF %ERRORLEVEL% EQU 0 (
|
||||
ECHO Extracting Node.js...
|
||||
tar -xf "%NODE_ZIP%" -C "%NODE_DIR%" --strip-components 1 || CALL :die "Failed to extract Node.js package with tar."
|
||||
) ELSE (
|
||||
CALL :die "tar utility not found (required to unpack zip without PowerShell)."
|
||||
)
|
||||
|
||||
DEL /F /Q "%NODE_ZIP%"
|
||||
|
||||
REM Prepend Node.js folder to PATH so npm & node are available
|
||||
SET "PATH=%CD%\%NODE_DIR%;%PATH%"
|
||||
ECHO Using portable Node.js at %CD%\%NODE_DIR%\node.exe
|
||||
ECHO Node.js version:
|
||||
node -v
|
||||
|
||||
REM ---------[ Frontend ]---------
|
||||
ECHO Installing frontend dependencies...
|
||||
CD talemate_frontend
|
||||
CALL npm install || CALL :die "npm install failed."
|
||||
|
||||
ECHO Building frontend...
|
||||
CALL npm run build || CALL :die "Frontend build failed."
|
||||
|
||||
REM Return to repo root
|
||||
CD ..
|
||||
|
||||
ECHO.
|
||||
ECHO ==============================
|
||||
ECHO Installation completed!
|
||||
ECHO ==============================
|
||||
PAUSE
|
||||
|
||||
ENDLOCAL
|
||||
|
||||
16
install.sh
@@ -1,20 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
# create a virtual environment
|
||||
echo "Creating a virtual environment..."
|
||||
python3 -m venv talemate_env
|
||||
# create a virtual environment with uv
|
||||
echo "Creating a virtual environment with uv..."
|
||||
uv venv
|
||||
|
||||
# activate the virtual environment
|
||||
echo "Activating the virtual environment..."
|
||||
source talemate_env/bin/activate
|
||||
source .venv/bin/activate
|
||||
|
||||
# install poetry
|
||||
echo "Installing poetry..."
|
||||
pip install poetry
|
||||
|
||||
# use poetry to install dependencies
|
||||
# install dependencies with uv
|
||||
echo "Installing dependencies..."
|
||||
poetry install
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
# copy config.example.yaml to config.yaml only if config.yaml doesn't exist
|
||||
if [ ! -f config.yaml ]; then
|
||||
|
||||
6554
poetry.lock
generated
144
pyproject.toml
@@ -1,77 +1,82 @@
|
||||
[build-system]
|
||||
requires = ["poetry>=0.12"]
|
||||
build-backend = "poetry.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
[project]
|
||||
name = "talemate"
|
||||
version = "0.30.0"
|
||||
version = "0.31.0"
|
||||
description = "AI-backed roleplay and narrative tools"
|
||||
authors = ["VeguAITools"]
|
||||
license = "GNU Affero General Public License v3.0"
|
||||
authors = [{name = "VeguAITools"}]
|
||||
license = {text = "GNU Affero General Public License v3.0"}
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"astroid>=2.8",
|
||||
"jedi>=0.18",
|
||||
"black",
|
||||
"rope>=0.22",
|
||||
"isort>=5.10",
|
||||
"jinja2>=3.0",
|
||||
"openai>=1",
|
||||
"mistralai>=0.1.8",
|
||||
"cohere>=5.2.2",
|
||||
"anthropic>=0.19.1",
|
||||
"groq>=0.5.0",
|
||||
"requests>=2.26",
|
||||
"colorama>=0.4.6",
|
||||
"Pillow>=9.5",
|
||||
"httpx<1",
|
||||
"piexif>=1.1",
|
||||
"typing-inspect==0.8.0",
|
||||
"typing_extensions>=4.5.0",
|
||||
"uvicorn>=0.23",
|
||||
"blinker>=1.6.2",
|
||||
"pydantic<3",
|
||||
"beautifulsoup4>=4.12.2",
|
||||
"python-dotenv>=1.0.0",
|
||||
"structlog>=23.1.0",
|
||||
# 1.7.11 breaks subprocess stuff ???
|
||||
"runpod==1.7.10",
|
||||
"google-genai>=1.20.0",
|
||||
"nest_asyncio>=1.5.7",
|
||||
"isodate>=0.6.1",
|
||||
"thefuzz>=0.20.0",
|
||||
"tiktoken>=0.5.1",
|
||||
"nltk>=3.8.1",
|
||||
"huggingface-hub>=0.20.2",
|
||||
"RestrictedPython>7.1",
|
||||
"numpy>=2",
|
||||
"aiofiles>=24.1.0",
|
||||
"pyyaml>=6.0",
|
||||
"limits>=5.0",
|
||||
"diff-match-patch>=20241021",
|
||||
"sseclient-py>=1.8.0",
|
||||
"ollama>=0.5.1",
|
||||
# ChromaDB
|
||||
"chromadb>=1.0.12",
|
||||
"InstructorEmbedding @ https://github.com/vegu-ai/instructor-embedding/archive/refs/heads/202506-fixes.zip",
|
||||
"torch>=2.7.0",
|
||||
"torchaudio>=2.7.0",
|
||||
# locked for instructor embeddings
|
||||
#sentence-transformers==2.2.2
|
||||
"sentence_transformers>=2.7.0",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.14"
|
||||
astroid = "^2.8"
|
||||
jedi = "^0.18"
|
||||
black = "*"
|
||||
rope = "^0.22"
|
||||
isort = "^5.10"
|
||||
jinja2 = ">=3.0"
|
||||
openai = ">=1"
|
||||
mistralai = ">=0.1.8"
|
||||
cohere = ">=5.2.2"
|
||||
anthropic = ">=0.19.1"
|
||||
groq = ">=0.5.0"
|
||||
requests = "^2.26"
|
||||
colorama = ">=0.4.6"
|
||||
Pillow = ">=9.5"
|
||||
httpx = "<1"
|
||||
piexif = "^1.1"
|
||||
typing-inspect = "0.8.0"
|
||||
typing_extensions = "^4.5.0"
|
||||
uvicorn = "^0.23"
|
||||
blinker = "^1.6.2"
|
||||
pydantic = "<3"
|
||||
beautifulsoup4 = "^4.12.2"
|
||||
python-dotenv = "^1.0.0"
|
||||
websockets = "^11.0.3"
|
||||
structlog = "^23.1.0"
|
||||
runpod = "^1.2.0"
|
||||
google-cloud-aiplatform = ">=1.50.0"
|
||||
nest_asyncio = "^1.5.7"
|
||||
isodate = ">=0.6.1"
|
||||
thefuzz = ">=0.20.0"
|
||||
tiktoken = ">=0.5.1"
|
||||
nltk = ">=3.8.1"
|
||||
huggingface-hub = ">=0.20.2"
|
||||
RestrictedPython = ">7.1"
|
||||
numpy = "^2"
|
||||
aiofiles = ">=24.1.0"
|
||||
pyyaml = ">=6.0"
|
||||
limits = ">=5.0"
|
||||
diff-match-patch = ">=20241021"
|
||||
sseclient-py = "^1.8.0"
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=6.2",
|
||||
"pytest-asyncio>=0.25.3",
|
||||
"mypy>=0.910",
|
||||
"mkdocs-material>=9.5.27",
|
||||
"mkdocs-awesome-pages-plugin>=2.9.2",
|
||||
"mkdocs-glightbox>=0.4.0",
|
||||
]
|
||||
|
||||
# ChromaDB
|
||||
chromadb = ">=0.4.17,<1"
|
||||
InstructorEmbedding = "^1.0.1"
|
||||
torch = "^2.7.0"
|
||||
torchaudio = "^2.7.0"
|
||||
# locked for instructor embeddings
|
||||
#sentence-transformers="==2.2.2"
|
||||
sentence_transformers=">=2.7.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = ">=6.2"
|
||||
pytest-asyncio = ">=0.25.3"
|
||||
mypy = "^0.910"
|
||||
mkdocs-material = ">=9.5.27"
|
||||
mkdocs-awesome-pages-plugin = ">=2.9.2"
|
||||
mkdocs-glightbox = ">=0.4.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
[project.scripts]
|
||||
talemate = "talemate:cli.main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py38']
|
||||
@@ -87,6 +92,7 @@ exclude = '''
|
||||
| buck-out
|
||||
| build
|
||||
| dist
|
||||
| talemate_env
|
||||
)/
|
||||
'''
|
||||
|
||||
@@ -97,4 +103,4 @@ include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
ensure_newline_before_comments = true
|
||||
line_length = 88
|
||||
line_length = 88
|
||||
@@ -1,18 +0,0 @@
|
||||
@echo off
|
||||
|
||||
IF EXIST talemate_env rmdir /s /q "talemate_env"
|
||||
|
||||
REM create a virtual environment
|
||||
python -m venv talemate_env
|
||||
|
||||
REM activate the virtual environment
|
||||
call talemate_env\Scripts\activate
|
||||
|
||||
REM install poetry
|
||||
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
|
||||
|
||||
REM use poetry to install dependencies
|
||||
python -m poetry install
|
||||
|
||||
echo Virtual environment re-created.
|
||||
pause
|
||||
65
scenes/infinity-quest-dynamic-story-v2/info/modules.json
Normal file
@@ -0,0 +1,65 @@
|
||||
{
|
||||
"packages": [
|
||||
{
|
||||
"name": "Dynamic Storyline",
|
||||
"author": "Talemate",
|
||||
"description": "Generate a random story premise at the beginning of the scene.",
|
||||
"installable": true,
|
||||
"registry": "package/talemate/DynamicStoryline",
|
||||
"status": "installed",
|
||||
"errors": [],
|
||||
"package_properties": {
|
||||
"topic": {
|
||||
"module": "scene/dynamicStoryline",
|
||||
"name": "topic",
|
||||
"label": "Topic",
|
||||
"description": "The overarching topic - will be used to generate a theme that falls within this category. Example - 'Sci-fi adventure with cosmic horror'.",
|
||||
"type": "str",
|
||||
"default": "",
|
||||
"value": "Sci-fi episodic adventures onboard of a spaceship, with focus on AI, alien contact, ancient creators and cosmic horror.",
|
||||
"required": true,
|
||||
"choices": []
|
||||
},
|
||||
"intro_length": {
|
||||
"module": "scene/dynamicStoryline",
|
||||
"name": "intro_length",
|
||||
"label": "Max. intro text length (tokens)",
|
||||
"description": "Length of the introduction",
|
||||
"type": "int",
|
||||
"default": 512,
|
||||
"value": 512,
|
||||
"required": true,
|
||||
"choices": []
|
||||
},
|
||||
"analysis_instructions": {
|
||||
"module": "scene/dynamicStoryline",
|
||||
"name": "analysis_instructions",
|
||||
"label": "Additional instructions for topic analysis task",
|
||||
"description": "Additional instructions for topic analysis task - if topic analysis is enabled, this will be used to augment the topic analysis task with further instructions.",
|
||||
"type": "text",
|
||||
"default": "",
|
||||
"value": "",
|
||||
"required": false,
|
||||
"choices": []
|
||||
},
|
||||
"analysis_enabled": {
|
||||
"module": "scene/dynamicStoryline",
|
||||
"name": "analysis_enabled",
|
||||
"label": "Enable topic analysis",
|
||||
"description": "Theme analysis",
|
||||
"type": "bool",
|
||||
"default": true,
|
||||
"value": true,
|
||||
"required": false,
|
||||
"choices": []
|
||||
}
|
||||
},
|
||||
"install_nodes": [
|
||||
"scene/dynamicStoryline"
|
||||
],
|
||||
"installed_nodes": [],
|
||||
"restart_scene_loop": true,
|
||||
"configured": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"title": "Scene Loop",
|
||||
"id": "71652a76-5db3-4836-8f00-1085977cd8e8",
|
||||
"id": "af468414-b30d-4f67-b08e-5b7cfd139adc",
|
||||
"properties": {
|
||||
"trigger_game_loop": true
|
||||
},
|
||||
@@ -11,50 +11,10 @@
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "scene/SceneLoop",
|
||||
"nodes": {
|
||||
"ede29db4-700d-4edc-b93b-bf7c79f6a6a5": {
|
||||
"title": "Dynamic Storyline",
|
||||
"id": "ede29db4-700d-4edc-b93b-bf7c79f6a6a5",
|
||||
"properties": {
|
||||
"event_name": "scene_loop_init",
|
||||
"analysis_instructions": "",
|
||||
"reset": false,
|
||||
"topic": "sci-fi with cosmic horror elements",
|
||||
"analysis_enabled": true,
|
||||
"intro_length": 512
|
||||
},
|
||||
"x": 32,
|
||||
"y": -249,
|
||||
"width": 295,
|
||||
"height": 158,
|
||||
"collapsed": false,
|
||||
"inherited": false,
|
||||
"registry": "scene/dynamicStoryline",
|
||||
"base_type": "core/Event"
|
||||
}
|
||||
},
|
||||
"nodes": {},
|
||||
"edges": {},
|
||||
"groups": [
|
||||
{
|
||||
"title": "Randomize Story",
|
||||
"x": 8,
|
||||
"y": -321,
|
||||
"width": 619,
|
||||
"height": 257,
|
||||
"color": "#a1309b",
|
||||
"font_size": 24,
|
||||
"inherited": false
|
||||
}
|
||||
],
|
||||
"comments": [
|
||||
{
|
||||
"text": "Will generate a randomized story line based on the topic given",
|
||||
"x": 352,
|
||||
"y": -269,
|
||||
"width": 215,
|
||||
"inherited": false
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"comments": [],
|
||||
"extends": "src/talemate/game/engine/nodes/modules/scene/scene-loop.json",
|
||||
"sleep": 0.001,
|
||||
"base_type": "scene/SceneLoop",
|
||||
|
||||
@@ -19,6 +19,7 @@ from talemate.agents.context import ActiveAgent, active_agent
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopStartEvent
|
||||
from talemate.context import active_scene
|
||||
import talemate.config as config
|
||||
from talemate.client.context import (
|
||||
ClientContext,
|
||||
set_client_context_attribute,
|
||||
@@ -438,6 +439,29 @@ class Agent(ABC):
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
async def save_config(self, app_config: config.Config | None = None):
|
||||
"""
|
||||
Saves the agent config to the config file.
|
||||
|
||||
If no config object is provided, the config is loaded from the config file.
|
||||
"""
|
||||
|
||||
if not app_config:
|
||||
app_config:config.Config = config.load_config(as_model=True)
|
||||
|
||||
app_config.agents[self.agent_type] = config.Agent(
|
||||
name=self.agent_type,
|
||||
client=self.client.name if self.client else None,
|
||||
enabled=self.enabled,
|
||||
actions={action_key: config.AgentAction(
|
||||
enabled=action.enabled,
|
||||
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
|
||||
) for action_key, action in self.actions.items()}
|
||||
)
|
||||
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
|
||||
config.save_config(app_config)
|
||||
|
||||
async def on_game_loop_start(self, event: GameLoopStartEvent):
|
||||
"""
|
||||
Finds all ActionConfigs that have a scope of "scene" and resets them to their default values
|
||||
|
||||
@@ -232,11 +232,21 @@ class ConversationAgent(
|
||||
@property
|
||||
def generation_settings_actor_instructions_offset(self):
|
||||
return self.actions["generation_override"].config["actor_instructions_offset"].value
|
||||
|
||||
|
||||
@property
|
||||
def generation_settings_response_length(self):
|
||||
return self.actions["generation_override"].config["length"].value
|
||||
|
||||
@property
|
||||
def generation_settings_override_enabled(self):
|
||||
return self.actions["generation_override"].enabled
|
||||
|
||||
@property
|
||||
def content_use_writing_style(self) -> bool:
|
||||
return self.actions["content"].config["use_writing_style"].value
|
||||
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
|
||||
@@ -322,6 +332,7 @@ class ConversationAgent(
|
||||
"actor_instructions_offset": self.generation_settings_actor_instructions_offset,
|
||||
"direct_instruction": instruction,
|
||||
"decensor": self.client.decensor_enabled,
|
||||
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import random
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
import dataclasses
|
||||
import traceback
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
@@ -328,7 +329,14 @@ class AssistantMixin:
|
||||
if not content.startswith(generation_context.character + ":"):
|
||||
content = generation_context.character + ": " + content
|
||||
content = util.strip_partial_sentences(content)
|
||||
emission.response = await editor.cleanup_character_message(content, generation_context.character.name)
|
||||
|
||||
character = self.scene.get_character(generation_context.character)
|
||||
|
||||
if not character:
|
||||
log.warning("Character not found", character=generation_context.character)
|
||||
return content
|
||||
|
||||
emission.response = await editor.cleanup_character_message(content, character)
|
||||
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
|
||||
return emission.response
|
||||
|
||||
@@ -447,6 +455,7 @@ class AssistantMixin:
|
||||
)
|
||||
|
||||
continuing_message = False
|
||||
message = None
|
||||
|
||||
try:
|
||||
message = self.scene.history[-1]
|
||||
@@ -470,6 +479,7 @@ class AssistantMixin:
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
"response_length": response_length,
|
||||
"continuing_message": continuing_message,
|
||||
"message": message,
|
||||
"anchor": anchor,
|
||||
"non_anchor": non_anchor,
|
||||
"prefix": prefix,
|
||||
@@ -675,7 +685,7 @@ class AssistantMixin:
|
||||
|
||||
emit("status", f"Scene forked", status="success")
|
||||
except Exception as e:
|
||||
log.exception("Scene fork failed", exc=e)
|
||||
log.error("Scene fork failed", exc=traceback.format_exc())
|
||||
emit("status", "Scene fork failed", status="error")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import random
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import structlog
|
||||
import traceback
|
||||
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
@@ -259,7 +260,7 @@ class DirectorAgent(
|
||||
except Exception as e:
|
||||
loading_status.done(message="Character creation failed", status="error")
|
||||
await scene.remove_actor(actor)
|
||||
log.exception("Error persisting character", error=e)
|
||||
log.error("Error persisting character", error=traceback.format_exc())
|
||||
|
||||
async def log_action(self, action: str, action_description: str):
|
||||
message = DirectorMessage(message=action_description, action=action)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pydantic
|
||||
import asyncio
|
||||
import structlog
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from talemate.instance import get_agent
|
||||
@@ -98,7 +99,7 @@ class DirectorWebsocketHandler(Plugin):
|
||||
task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
|
||||
async def handle_task_done(task):
|
||||
if task.exception():
|
||||
log.exception("Error persisting character", error=task.exception())
|
||||
log.error("Error persisting character", error=task.exception())
|
||||
await self.signal_operation_failed("Error persisting character")
|
||||
else:
|
||||
self.websocket_handler.queue_put(
|
||||
|
||||
@@ -54,7 +54,7 @@ class EditorAgent(
|
||||
type="text",
|
||||
label="Formatting",
|
||||
description="The formatting to use for exposition.",
|
||||
value="chat",
|
||||
value="novel",
|
||||
choices=[
|
||||
{"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
|
||||
{"label": "Novel: \"Speech\" narration", "value": "novel"},
|
||||
|
||||
@@ -27,6 +27,7 @@ from talemate.agents.conversation import ConversationAgentEmission
|
||||
from talemate.agents.narrator import NarratorAgentEmission
|
||||
from talemate.agents.creator.assistant import ContextualGenerateEmission
|
||||
from talemate.agents.summarize import SummarizeEmission
|
||||
from talemate.agents.summarize.layered_history import LayeredHistoryFinalizeEmission
|
||||
from talemate.scene_message import CharacterMessage
|
||||
from talemate.util.dedupe import (
|
||||
dedupe_sentences,
|
||||
@@ -387,13 +388,16 @@ class RevisionMixin:
|
||||
async_signals.get("agent.summarization.summarize.after").connect(
|
||||
self.revision_on_generation
|
||||
)
|
||||
async_signals.get("agent.summarization.layered_history.finalize").connect(
|
||||
self.revision_on_generation
|
||||
)
|
||||
# connect to the super class AFTER so these run first.
|
||||
super().connect(scene)
|
||||
|
||||
|
||||
async def revision_on_generation(
|
||||
self,
|
||||
emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission,
|
||||
emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission | LayeredHistoryFinalizeEmission,
|
||||
):
|
||||
"""
|
||||
Called when a conversation or narrator message is generated
|
||||
@@ -411,7 +415,15 @@ class RevisionMixin:
|
||||
if isinstance(emission, NarratorAgentEmission) and "narrator" not in self.revision_automatic_targets:
|
||||
return
|
||||
|
||||
if isinstance(emission, SummarizeEmission) and "summarization" not in self.revision_automatic_targets:
|
||||
if isinstance(emission, SummarizeEmission):
|
||||
if emission.summarization_type == "dialogue" and "summarization" not in self.revision_automatic_targets:
|
||||
return
|
||||
if emission.summarization_type == "events":
|
||||
# event summarization is very pragmatic and doesn't really benefit
|
||||
# from revision, so we skip it
|
||||
return
|
||||
|
||||
if isinstance(emission, LayeredHistoryFinalizeEmission) and "summarization" not in self.revision_automatic_targets:
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -428,7 +440,7 @@ class RevisionMixin:
|
||||
context_name = getattr(emission, "context_name", None),
|
||||
)
|
||||
|
||||
if isinstance(emission, SummarizeEmission):
|
||||
if isinstance(emission, (SummarizeEmission, LayeredHistoryFinalizeEmission)):
|
||||
info.summarization_history = emission.summarization_history or []
|
||||
|
||||
if isinstance(emission, ContextualGenerateEmission) and info.context_type not in CONTEXTUAL_GENERATION_TYPES:
|
||||
@@ -489,7 +501,8 @@ class RevisionMixin:
|
||||
log.warning("revision_revise: generation cancelled", text=info.text)
|
||||
return info.text
|
||||
except Exception as e:
|
||||
log.exception("revision_revise: error", error=e)
|
||||
import traceback
|
||||
log.error("revision_revise: error", error=traceback.format_exc())
|
||||
return info.text
|
||||
finally:
|
||||
info.loading_status.done()
|
||||
@@ -871,8 +884,14 @@ class RevisionMixin:
|
||||
|
||||
if loading_status:
|
||||
loading_status("Editor - Issues identified, analyzing text...")
|
||||
|
||||
template_vars = {
|
||||
|
||||
emission = RevisionEmission(
|
||||
agent=self,
|
||||
info=info,
|
||||
issues=issues,
|
||||
)
|
||||
|
||||
emission.template_vars = {
|
||||
"text": text,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
@@ -880,14 +899,11 @@ class RevisionMixin:
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"repetition": issues.repetition,
|
||||
"bad_prose": issues.bad_prose,
|
||||
"dynamic_instructions": emission.dynamic_instructions,
|
||||
"context_type": info.context_type,
|
||||
"context_name": info.context_name,
|
||||
}
|
||||
|
||||
emission = RevisionEmission(
|
||||
agent=self,
|
||||
template_vars=template_vars,
|
||||
info=info,
|
||||
issues=issues,
|
||||
)
|
||||
|
||||
await async_signals.get("agent.editor.revision-revise.before").send(
|
||||
emission
|
||||
@@ -898,18 +914,7 @@ class RevisionMixin:
|
||||
"editor.revision-analysis",
|
||||
self.client,
|
||||
f"edit_768",
|
||||
vars={
|
||||
"text": text,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"response_length": token_count,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"repetition": issues.repetition,
|
||||
"bad_prose": issues.bad_prose,
|
||||
"dynamic_instructions": emission.dynamic_instructions,
|
||||
"context_type": info.context_type,
|
||||
"context_name": info.context_name,
|
||||
},
|
||||
vars=emission.template_vars,
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
@@ -1016,39 +1021,43 @@ class RevisionMixin:
|
||||
|
||||
log.debug("revision_unslop: issues", issues=issues, template=template)
|
||||
|
||||
|
||||
|
||||
emission = RevisionEmission(
|
||||
agent=self,
|
||||
info=info,
|
||||
issues=issues,
|
||||
)
|
||||
|
||||
emission.template_vars = {
|
||||
"text": text,
|
||||
"scene_analysis": scene_analysis,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"response_length": response_length,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"repetition": issues.repetition,
|
||||
"bad_prose": issues.bad_prose,
|
||||
"dynamic_instructions": emission.dynamic_instructions,
|
||||
"context_type": info.context_type,
|
||||
"context_name": info.context_name,
|
||||
"summarization_history": info.summarization_history,
|
||||
}
|
||||
|
||||
await async_signals.get("agent.editor.revision-revise.before").send(emission)
|
||||
|
||||
response = await Prompt.request(
|
||||
template,
|
||||
self.client,
|
||||
"edit_768",
|
||||
vars={
|
||||
"text": text,
|
||||
"scene_analysis": scene_analysis,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"response_length": response_length,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"repetition": issues.repetition,
|
||||
"bad_prose": issues.bad_prose,
|
||||
"dynamic_instructions": emission.dynamic_instructions,
|
||||
"context_type": info.context_type,
|
||||
"context_name": info.context_name,
|
||||
"summarization_history": info.summarization_history,
|
||||
},
|
||||
vars=emission.template_vars,
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
# extract <FIX>...</FIX>
|
||||
|
||||
if "<FIX>" not in response:
|
||||
log.error("revision_unslop: no <FIX> found in response", response=response)
|
||||
log.debug("revision_unslop: no <FIX> found in response", response=response)
|
||||
return original_text
|
||||
|
||||
fix = response.split("<FIX>", 1)[1]
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import uuid
|
||||
import traceback
|
||||
import numpy as np
|
||||
from typing import Callable
|
||||
|
||||
@@ -12,6 +14,8 @@ from chromadb.config import Settings
|
||||
|
||||
import talemate.events as events
|
||||
import talemate.util as util
|
||||
from talemate.client import ClientBase
|
||||
import talemate.instance as instance
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
@@ -23,6 +27,7 @@ from talemate.config import load_config
|
||||
from talemate.context import scene_is_loading, active_scene
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
import talemate.emit.async_signals as async_signals
|
||||
from talemate.agents.memory.context import memory_request, MemoryRequest
|
||||
from talemate.agents.memory.exceptions import (
|
||||
EmbeddingsModelLoadError,
|
||||
@@ -31,19 +36,23 @@ from talemate.agents.memory.exceptions import (
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
from chromadb.utils import embedding_functions
|
||||
except ImportError:
|
||||
chromadb = None
|
||||
pass
|
||||
|
||||
from talemate.agents.registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.client.base import ClientEmbeddingsStatus
|
||||
|
||||
log = structlog.get_logger("talemate.agents.memory")
|
||||
|
||||
if not chromadb:
|
||||
log.info("ChromaDB not found, disabling Chroma agent")
|
||||
|
||||
|
||||
from talemate.agents.registry import register
|
||||
|
||||
class MemoryDocument(str):
|
||||
def __new__(cls, text, meta, id, raw):
|
||||
inst = super().__new__(cls, text)
|
||||
@@ -105,8 +114,9 @@ class MemoryAgent(Agent):
|
||||
self.memory_tracker = {}
|
||||
self.config = load_config()
|
||||
self._ready_to_add = False
|
||||
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)
|
||||
|
||||
self.actions = MemoryAgent.init_actions(presets=self.get_presets)
|
||||
|
||||
@@ -125,8 +135,16 @@ class MemoryAgent(Agent):
|
||||
|
||||
@property
|
||||
def get_presets(self):
|
||||
|
||||
def _label(embedding:dict):
|
||||
prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
|
||||
if embedding['model']:
|
||||
return f"{prefix}: {embedding['model']}"
|
||||
else:
|
||||
return f"{prefix}"
|
||||
|
||||
return [
|
||||
{"value": k, "label": f"{v['embeddings']}: {v['model']}"} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
|
||||
{"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -150,6 +168,10 @@ class MemoryAgent(Agent):
|
||||
def using_sentence_transformer_embeddings(self):
|
||||
return self.embeddings == "default" or self.embeddings == "sentence-transformer"
|
||||
|
||||
@property
|
||||
def using_client_api_embeddings(self):
|
||||
return self.embeddings == "client-api"
|
||||
|
||||
@property
|
||||
def using_local_embeddings(self):
|
||||
return self.embeddings in [
|
||||
@@ -158,6 +180,11 @@ class MemoryAgent(Agent):
|
||||
"default"
|
||||
]
|
||||
|
||||
|
||||
@property
|
||||
def embeddings_client(self):
|
||||
return self.embeddings_config.get("client")
|
||||
|
||||
@property
|
||||
def max_distance(self) -> float:
|
||||
distance = float(self.embeddings_config.get("distance", 1.0))
|
||||
@@ -186,7 +213,10 @@ class MemoryAgent(Agent):
|
||||
"""
|
||||
Returns a unique fingerprint for the current configuration
|
||||
"""
|
||||
return f"{self.embeddings}-{self.model.replace('/','-')}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
|
||||
|
||||
model_name = self.model.replace('/','-') if self.model else "none"
|
||||
|
||||
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
|
||||
@@ -205,7 +235,11 @@ class MemoryAgent(Agent):
|
||||
@set_processing
|
||||
async def handle_embeddings_change(self):
|
||||
scene = active_scene.get()
|
||||
|
||||
|
||||
# if sentence-transformer and no model-name, set embeddings to default
|
||||
if self.using_sentence_transformer_embeddings and not self.model:
|
||||
self.actions["_config"].config["embeddings"].value = "default"
|
||||
|
||||
if not scene or not scene.get_helper("memory"):
|
||||
return
|
||||
|
||||
@@ -216,21 +250,49 @@ class MemoryAgent(Agent):
|
||||
await scene.save(auto=True)
|
||||
emit("status", "Context database re-imported", status="success")
|
||||
|
||||
def sync_presets(self) -> list[dict]:
|
||||
self.actions["_config"].config["embeddings"].choices = self.get_presets
|
||||
return self.actions["_config"].config["embeddings"].choices
|
||||
|
||||
def on_config_saved(self, event):
|
||||
loop = asyncio.get_running_loop()
|
||||
openai_key = self.openai_api_key
|
||||
|
||||
fingerprint = self.fingerprint
|
||||
|
||||
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
|
||||
|
||||
self.config = load_config()
|
||||
|
||||
new_presets = self.sync_presets()
|
||||
if fingerprint != self.fingerprint:
|
||||
log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
|
||||
loop.run_until_complete(self.handle_embeddings_change())
|
||||
|
||||
|
||||
emit_status = False
|
||||
|
||||
if openai_key != self.openai_api_key:
|
||||
emit_status = True
|
||||
|
||||
if old_presets != new_presets:
|
||||
emit_status = True
|
||||
|
||||
if emit_status:
|
||||
loop.run_until_complete(self.emit_status())
|
||||
|
||||
|
||||
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
|
||||
current_embeddings = self.actions["_config"].config["embeddings"].value
|
||||
|
||||
if current_embeddings == event.client.embeddings_identifier:
|
||||
return
|
||||
|
||||
if not self.using_client_api_embeddings or not self.ready:
|
||||
log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
|
||||
self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
|
||||
await self.emit_status()
|
||||
await self.handle_embeddings_change()
|
||||
await self.save_config()
|
||||
|
||||
@set_processing
|
||||
async def set_db(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -239,7 +301,7 @@ class MemoryAgent(Agent):
|
||||
except EmbeddingsModelLoadError:
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error("memory agent", error="failed to set db", details=e)
|
||||
log.error("memory agent", error="failed to set db", details=traceback.format_exc())
|
||||
|
||||
if "torchvision::nms does not exist" in str(e):
|
||||
raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")
|
||||
@@ -379,14 +441,12 @@ class MemoryAgent(Agent):
|
||||
def _get_document(self, id):
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_archive_add(self, event: events.ArchiveEvent):
|
||||
asyncio.ensure_future(
|
||||
self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
|
||||
)
|
||||
async def on_archive_add(self, event: events.ArchiveEvent):
|
||||
await self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
scene.signals["archive_add"].connect(self.on_archive_add)
|
||||
async_signals.get("archive_add").connect(self.on_archive_add)
|
||||
|
||||
async def memory_context(
|
||||
self,
|
||||
@@ -453,29 +513,72 @@ class MemoryAgent(Agent):
|
||||
Get the character memory context for a given character
|
||||
"""
|
||||
|
||||
memory_context = []
|
||||
# First, collect results for each individual query (respecting the
|
||||
# per-query `iterate` limit) so that we have them available before we
|
||||
# start filling the final context. This prevents early queries from
|
||||
# monopolising the token budget.
|
||||
|
||||
per_query_results: list[list[str]] = []
|
||||
|
||||
for query in queries:
|
||||
# Skip empty queries so that we keep indexing consistent for the
|
||||
# round-robin step that follows.
|
||||
if not query:
|
||||
per_query_results.append([])
|
||||
continue
|
||||
|
||||
i = 0
|
||||
for memory in await self.get(formatter(query), limit=limit, **where):
|
||||
if memory in memory_context:
|
||||
continue
|
||||
# Fetch potential memories for this query.
|
||||
raw_results = await self.get(
|
||||
formatter(query), limit=limit, **where
|
||||
)
|
||||
|
||||
# Apply filter and respect the `iterate` limit for this query.
|
||||
accepted: list[str] = []
|
||||
for memory in raw_results:
|
||||
if filter and not filter(memory):
|
||||
continue
|
||||
|
||||
accepted.append(memory)
|
||||
if len(accepted) >= iterate:
|
||||
break
|
||||
|
||||
per_query_results.append(accepted)
|
||||
|
||||
# Now interleave the results in a round-robin fashion so that each
|
||||
# query gets a fair chance to contribute, until we hit the token
|
||||
# budget.
|
||||
|
||||
memory_context: list[str] = []
|
||||
idx = 0
|
||||
while True:
|
||||
added_any = False
|
||||
|
||||
for result_list in per_query_results:
|
||||
if idx >= len(result_list):
|
||||
# No more items remaining for this query at this depth.
|
||||
continue
|
||||
|
||||
memory = result_list[idx]
|
||||
|
||||
# Avoid duplicates in the final context.
|
||||
if memory in memory_context:
|
||||
continue
|
||||
|
||||
memory_context.append(memory)
|
||||
added_any = True
|
||||
|
||||
i += 1
|
||||
if i >= iterate:
|
||||
break
|
||||
|
||||
# Check token budget after each addition.
|
||||
if util.count_tokens(memory_context) >= max_tokens:
|
||||
break
|
||||
if util.count_tokens(memory_context) >= max_tokens:
|
||||
return memory_context
|
||||
|
||||
if not added_any:
|
||||
# We iterated over all query result lists without adding
|
||||
# anything. That means we have exhausted all available
|
||||
# memories.
|
||||
break
|
||||
|
||||
idx += 1
|
||||
|
||||
return memory_context
|
||||
|
||||
@property
|
||||
@@ -587,9 +690,32 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
if getattr(self, "db_client", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def client_api_ready(self) -> bool:
|
||||
if self.using_client_api_embeddings:
|
||||
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
|
||||
if not embeddings_client:
|
||||
return False
|
||||
|
||||
if not embeddings_client.supports_embeddings:
|
||||
return False
|
||||
|
||||
if not embeddings_client.embeddings_status:
|
||||
return False
|
||||
|
||||
if embeddings_client.current_status not in ["idle", "busy"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.using_client_api_embeddings and not self.client_api_ready:
|
||||
return "error"
|
||||
|
||||
if self.ready:
|
||||
return "active" if not getattr(self, "processing", False) else "busy"
|
||||
|
||||
@@ -612,12 +738,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
value=self.embeddings,
|
||||
description="The embeddings type.",
|
||||
).model_dump(),
|
||||
"model": AgentDetail(
|
||||
|
||||
}
|
||||
|
||||
if self.model:
|
||||
details["model"] = AgentDetail(
|
||||
icon="mdi-brain",
|
||||
value=self.model,
|
||||
description="The embeddings model.",
|
||||
).model_dump(),
|
||||
}
|
||||
).model_dump()
|
||||
|
||||
if self.embeddings_client:
|
||||
details["client"] = AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
value=self.embeddings_client,
|
||||
description="The client to use for embeddings.",
|
||||
).model_dump()
|
||||
|
||||
if self.using_local_embeddings:
|
||||
details["device"] = AgentDetail(
|
||||
@@ -634,6 +770,37 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
|
||||
"color": "error",
|
||||
}
|
||||
|
||||
if self.using_client_api_embeddings:
|
||||
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
|
||||
|
||||
if not embeddings_client:
|
||||
details["error"] = {
|
||||
"icon": "mdi-alert",
|
||||
"value": f"Client {self.embeddings_client} not found",
|
||||
"description": f"Client {self.embeddings_client} not found",
|
||||
"color": "error",
|
||||
}
|
||||
return details
|
||||
|
||||
client_name = embeddings_client.name
|
||||
|
||||
if not embeddings_client.supports_embeddings:
|
||||
error_message = f"{client_name} does not support embeddings"
|
||||
elif embeddings_client.current_status not in ["idle", "busy"]:
|
||||
error_message = f"{client_name} is not ready"
|
||||
elif not embeddings_client.embeddings_status:
|
||||
error_message = f"{client_name} has no embeddings model loaded"
|
||||
else:
|
||||
error_message = None
|
||||
|
||||
if error_message:
|
||||
details["error"] = {
|
||||
"icon": "mdi-alert",
|
||||
"value": error_message,
|
||||
"description": error_message,
|
||||
"color": "error",
|
||||
}
|
||||
|
||||
return details
|
||||
|
||||
@@ -686,7 +853,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
self.collection_name = collection_name = self.make_collection_name(self.scene)
|
||||
|
||||
log.info(
|
||||
"chromadb agent", status="setting up db", collection_name=collection_name
|
||||
"chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
|
||||
)
|
||||
|
||||
distance_function = self.distance_function
|
||||
@@ -713,6 +880,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=openai_ef, metadata=collection_metadata
|
||||
)
|
||||
elif self.using_client_api_embeddings:
|
||||
log.info(
|
||||
"chromadb",
|
||||
embeddings="Client API",
|
||||
client=self.embeddings_client,
|
||||
)
|
||||
|
||||
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
|
||||
if not embeddings_client:
|
||||
raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")
|
||||
|
||||
if not embeddings_client.supports_embeddings:
|
||||
raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")
|
||||
|
||||
ef = embeddings_client.embeddings_function
|
||||
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=ef, metadata=collection_metadata
|
||||
)
|
||||
|
||||
elif self.using_instructor_embeddings:
|
||||
log.info(
|
||||
"chromadb",
|
||||
@@ -722,7 +909,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
)
|
||||
|
||||
ef = embedding_functions.InstructorEmbeddingFunction(
|
||||
model_name=model_name, device=device
|
||||
model_name=model_name, device=device, instruction="Represent the document for retrieval:"
|
||||
)
|
||||
|
||||
log.info("chromadb", status="embedding function ready")
|
||||
@@ -801,6 +988,10 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
)
|
||||
try:
|
||||
self.db_client.delete_collection(collection_name)
|
||||
except chromadb.errors.NotFoundError as exc:
|
||||
log.error(
|
||||
"chromadb agent", error="collection not found", details=exc
|
||||
)
|
||||
except ValueError as exc:
|
||||
log.error(
|
||||
"chromadb agent", error="failed to delete collection", details=exc
|
||||
|
||||
@@ -510,53 +510,6 @@ class NarratorAgent(
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def augment_context(self):
|
||||
"""
|
||||
Takes a context history generated via scene.context_history() and augments it with additional information
|
||||
by asking and answering questions with help from the long term memory.
|
||||
"""
|
||||
memory = self.scene.get_helper("memory").agent
|
||||
|
||||
questions = await Prompt.request(
|
||||
"narrator.context-questions",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("context_questions", questions=questions)
|
||||
|
||||
questions = [q for q in questions.split("\n") if q.strip()]
|
||||
|
||||
memory_context = await memory.multi_query(
|
||||
questions, iterate=2, max_tokens=self.client.max_token_length - 1000
|
||||
)
|
||||
|
||||
answers = await Prompt.request(
|
||||
"narrator.context-answers",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"memory": memory_context,
|
||||
"questions": questions,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("context_answers", answers=answers)
|
||||
|
||||
answers = [a for a in answers.split("\n") if a.strip()]
|
||||
|
||||
# return questions and answers
|
||||
return list(zip(questions, answers))
|
||||
|
||||
@set_processing
|
||||
@store_context_state('narrative_direction', time_narration=True)
|
||||
async def narrate_time_passage(
|
||||
|
||||
@@ -4,8 +4,7 @@ import re
|
||||
import dataclasses
|
||||
|
||||
import structlog
|
||||
from typing import TYPE_CHECKING
|
||||
import talemate.data_objects as data_objects
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.emit import emit
|
||||
@@ -35,6 +34,8 @@ from talemate.agents.base import (
|
||||
from talemate.agents.registry import register
|
||||
from talemate.agents.memory.rag import MemoryRAGMixin
|
||||
|
||||
from talemate.history import ArchiveEntry
|
||||
|
||||
from .analyze_scene import SceneAnalyzationMixin
|
||||
from .context_investigation import ContextInvestigationMixin
|
||||
from .layered_history import LayeredHistoryMixin
|
||||
@@ -63,6 +64,7 @@ class SummarizeEmission(AgentTemplateEmission):
|
||||
extra_instructions: str | None = None
|
||||
generation_options: GenerationOptions | None = None
|
||||
summarization_history: list[str] | None = None
|
||||
summarization_type: Literal["dialogue", "events"] = "dialogue"
|
||||
|
||||
@register()
|
||||
class SummarizeAgent(
|
||||
@@ -189,6 +191,34 @@ class SummarizeAgent(
|
||||
|
||||
return emission.sub_instruction
|
||||
|
||||
|
||||
# SUMMARIZATION HELPERS
|
||||
|
||||
async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:
|
||||
|
||||
num_previous = self.archive_include_previous
|
||||
|
||||
# find entry by .id
|
||||
entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None)
|
||||
if entry_index is None:
|
||||
raise ValueError("Entry not found")
|
||||
end = entry_index - 1
|
||||
|
||||
previous_summaries = []
|
||||
|
||||
if entry and num_previous > 0:
|
||||
if self.layered_history_available:
|
||||
previous_summaries = self.compile_layered_history(
|
||||
include_base_layer=True,
|
||||
base_layer_end_id=entry.id
|
||||
)[-num_previous:]
|
||||
else:
|
||||
previous_summaries = [
|
||||
entry.text for entry in self.scene.archived_history[end-num_previous:end]
|
||||
]
|
||||
|
||||
return previous_summaries
|
||||
|
||||
# SUMMARIZE
|
||||
|
||||
@set_processing
|
||||
@@ -352,7 +382,7 @@ class SummarizeAgent(
|
||||
|
||||
# determine the appropariate timestamp for the summarization
|
||||
|
||||
scene.push_archive(data_objects.ArchiveEntry(summarized, start, end, ts=ts))
|
||||
await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts))
|
||||
|
||||
scene.ts=ts
|
||||
scene.emit_status()
|
||||
@@ -478,7 +508,8 @@ class SummarizeAgent(
|
||||
extra_instructions=extra_instructions,
|
||||
generation_options=generation_options,
|
||||
template_vars=template_vars,
|
||||
summarization_history=extra_context or []
|
||||
summarization_history=extra_context or [],
|
||||
summarization_type="dialogue",
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
|
||||
@@ -562,7 +593,8 @@ class SummarizeAgent(
|
||||
extra_instructions=extra_instructions,
|
||||
generation_options=generation_options,
|
||||
template_vars=template_vars,
|
||||
summarization_history=[extra_context] if extra_context else []
|
||||
summarization_history=[extra_context] if extra_context else [],
|
||||
summarization_type="events",
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import structlog
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from talemate.agents.base import (
|
||||
set_processing,
|
||||
AgentAction,
|
||||
AgentActionConfig
|
||||
AgentActionConfig,
|
||||
AgentEmission,
|
||||
)
|
||||
from talemate.prompts import Prompt
|
||||
import dataclasses
|
||||
import talemate.emit.async_signals
|
||||
from talemate.exceptions import GenerationCancelled
|
||||
from talemate.world_state.templates import GenerationOptions
|
||||
from talemate.emit import emit
|
||||
from talemate.context import handle_generation_cancelled
|
||||
from talemate.history import LayeredArchiveEntry, HistoryEntry, entry_contained
|
||||
import talemate.util as util
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -19,6 +20,24 @@ if TYPE_CHECKING:
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.summarization.layered_history.finalize",
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LayeredHistoryFinalizeEmission(AgentEmission):
|
||||
entry: LayeredArchiveEntry | None = None
|
||||
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])
|
||||
|
||||
@property
|
||||
def response(self) -> str | None:
|
||||
return self.entry.text if self.entry else None
|
||||
|
||||
@response.setter
|
||||
def response(self, value: str):
|
||||
if self.entry:
|
||||
self.entry.text = value
|
||||
|
||||
class SummaryLongerThanOriginalError(ValueError):
|
||||
def __init__(self, original_length:int, summarized_length:int):
|
||||
self.original_length = original_length
|
||||
@@ -155,7 +174,102 @@ class LayeredHistoryMixin:
|
||||
await self.summarize_to_layered_history(
|
||||
generation_options=emission.generation_options
|
||||
)
|
||||
|
||||
# helpers
|
||||
|
||||
async def _lh_split_and_summarize_chunks(
|
||||
self,
|
||||
chunks: list[dict],
|
||||
extra_context: str,
|
||||
generation_options: GenerationOptions | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split chunks based on max_process_tokens and summarize each part.
|
||||
Returns a list of summary texts.
|
||||
"""
|
||||
summaries = []
|
||||
current_chunk = chunks.copy()
|
||||
|
||||
while current_chunk:
|
||||
partial_chunk = []
|
||||
max_process_tokens = self.layered_history_max_process_tokens
|
||||
|
||||
# Build partial chunk up to max_process_tokens
|
||||
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
|
||||
partial_chunk.append(current_chunk.pop(0))
|
||||
|
||||
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
|
||||
|
||||
log.debug("_split_and_summarize_chunks",
|
||||
tokens_in_chunk=util.count_tokens(text_to_summarize),
|
||||
max_process_tokens=max_process_tokens)
|
||||
|
||||
summary_text = await self.summarize_events(
|
||||
text_to_summarize,
|
||||
extra_context=extra_context + "\n\n".join(summaries),
|
||||
generation_options=generation_options,
|
||||
response_length=self.layered_history_response_length,
|
||||
analyze_chunks=self.layered_history_analyze_chunks,
|
||||
chunk_size=self.layered_history_chunk_size,
|
||||
)
|
||||
summaries.append(summary_text)
|
||||
|
||||
return summaries
|
||||
|
||||
def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
|
||||
"""
|
||||
Validates that the summarized text is not longer than the original.
|
||||
Raises SummaryLongerThanOriginalError if validation fails.
|
||||
"""
|
||||
summarized_length = util.count_tokens(summaries)
|
||||
if summarized_length > original_length:
|
||||
raise SummaryLongerThanOriginalError(original_length, summarized_length)
|
||||
|
||||
log.debug("_validate_summary_length",
|
||||
original_length=original_length,
|
||||
summarized_length=summarized_length)
|
||||
|
||||
def _lh_build_extra_context(self, layer_index: int) -> str:
|
||||
"""
|
||||
Builds extra context from compiled layered history for the given layer.
|
||||
"""
|
||||
return "\n\n".join(self.compile_layered_history(layer_index))
|
||||
|
||||
def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
|
||||
"""
|
||||
Extracts timestamps from a chunk of entries.
|
||||
Returns (ts, ts_start, ts_end)
|
||||
"""
|
||||
if not chunk:
|
||||
return "PT1S", "PT1S", "PT1S"
|
||||
|
||||
ts = chunk[0].get('ts', 'PT1S')
|
||||
ts_start = chunk[0].get('ts_start', ts)
|
||||
ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts))
|
||||
|
||||
return ts, ts_start, ts_end
|
||||
|
||||
|
||||
async def _lh_finalize_archive_entry(
|
||||
self,
|
||||
entry: LayeredArchiveEntry,
|
||||
summarization_history: list[str] | None = None,
|
||||
) -> LayeredArchiveEntry:
|
||||
"""
|
||||
Finalizes an archive entry by summarizing it and adding it to the layered history.
|
||||
"""
|
||||
|
||||
emission = LayeredHistoryFinalizeEmission(
|
||||
agent=self,
|
||||
entry=entry,
|
||||
summarization_history=summarization_history,
|
||||
)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)
|
||||
|
||||
return emission.entry
|
||||
|
||||
|
||||
# methods
|
||||
|
||||
def compile_layered_history(
|
||||
@@ -164,6 +278,7 @@ class LayeredHistoryMixin:
|
||||
as_objects:bool=False,
|
||||
include_base_layer:bool=False,
|
||||
max:int = None,
|
||||
base_layer_end_id: str | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Starts at the last layer and compiles the layered history into a single
|
||||
@@ -194,6 +309,17 @@ class LayeredHistoryMixin:
|
||||
entry_num = 1
|
||||
|
||||
for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]:
|
||||
|
||||
if base_layer_end_id:
|
||||
contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry(
|
||||
index=0,
|
||||
layer=i+1,
|
||||
**layered_history_entry)
|
||||
)
|
||||
if contained:
|
||||
log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id)
|
||||
break
|
||||
|
||||
text = f"{layered_history_entry['text']}"
|
||||
|
||||
if for_layer_index == i and max is not None and max <= layered_history_entry["end"]:
|
||||
@@ -212,8 +338,8 @@ class LayeredHistoryMixin:
|
||||
entry_num += 1
|
||||
else:
|
||||
compiled.append(text)
|
||||
|
||||
next_layer_start = layered_history_entry["end"] + 1
|
||||
|
||||
next_layer_start = layered_history_entry["end"] + 1
|
||||
|
||||
if i == 0 and include_base_layer:
|
||||
# we are are at layered history layer zero and inclusion of base layer (archived history) is requested
|
||||
@@ -222,7 +348,10 @@ class LayeredHistoryMixin:
|
||||
|
||||
entry_num = 1
|
||||
|
||||
for ah in self.scene.archived_history[next_layer_start:]:
|
||||
for ah in self.scene.archived_history[next_layer_start or 0:]:
|
||||
|
||||
if base_layer_end_id and ah["id"] == base_layer_end_id:
|
||||
break
|
||||
|
||||
text = f"{ah['text']}"
|
||||
if as_objects:
|
||||
@@ -291,8 +420,6 @@ class LayeredHistoryMixin:
|
||||
return # No base layer summaries to work with
|
||||
|
||||
token_threshold = self.layered_history_threshold
|
||||
method = self.actions["archive"].config["method"].value
|
||||
max_process_tokens = self.layered_history_max_process_tokens
|
||||
max_layers = self.layered_history_max_layers
|
||||
|
||||
if not hasattr(self.scene, 'layered_history'):
|
||||
@@ -329,15 +456,9 @@ class LayeredHistoryMixin:
|
||||
log.debug("summarize_to_layered_history", created_layer=next_layer_index)
|
||||
next_layer = layered_history[next_layer_index]
|
||||
|
||||
ts = current_chunk[0]['ts']
|
||||
ts_start = current_chunk[0]['ts_start'] if 'ts_start' in current_chunk[0] else ts
|
||||
ts_end = current_chunk[-1]['ts_end'] if 'ts_end' in current_chunk[-1] else ts
|
||||
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
|
||||
|
||||
summaries = []
|
||||
|
||||
extra_context = "\n\n".join(
|
||||
self.compile_layered_history(next_layer_index)
|
||||
)
|
||||
extra_context = self._lh_build_extra_context(next_layer_index)
|
||||
|
||||
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
|
||||
|
||||
@@ -345,44 +466,24 @@ class LayeredHistoryMixin:
|
||||
|
||||
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})
|
||||
|
||||
while current_chunk:
|
||||
summaries = await self._lh_split_and_summarize_chunks(
|
||||
current_chunk,
|
||||
extra_context,
|
||||
generation_options=generation_options,
|
||||
)
|
||||
noop = False
|
||||
|
||||
log.debug("summarize_to_layered_history", tokens_in_chunk=util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk)), max_process_tokens=max_process_tokens)
|
||||
|
||||
partial_chunk = []
|
||||
|
||||
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
|
||||
partial_chunk.append(current_chunk.pop(0))
|
||||
|
||||
text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)
|
||||
# validate summary length
|
||||
self._lh_validate_summary_length(summaries, text_length)
|
||||
|
||||
|
||||
summary_text = await self.summarize_events(
|
||||
text_to_summarize,
|
||||
extra_context=extra_context + "\n\n".join(summaries),
|
||||
generation_options=generation_options,
|
||||
response_length=self.layered_history_response_length,
|
||||
analyze_chunks=self.layered_history_analyze_chunks,
|
||||
chunk_size=self.layered_history_chunk_size,
|
||||
)
|
||||
noop = False
|
||||
summaries.append(summary_text)
|
||||
|
||||
# if summarized text is longer than the original, we will
|
||||
# raise an error
|
||||
if util.count_tokens(summaries) > text_length:
|
||||
raise SummaryLongerThanOriginalError(text_length, util.count_tokens(summaries))
|
||||
|
||||
log.debug("summarize_to_layered_history", original_length=text_length, summarized_length=util.count_tokens(summaries))
|
||||
|
||||
next_layer.append({
|
||||
next_layer.append(LayeredArchiveEntry(**{
|
||||
"start": start_index,
|
||||
"end": i - 1,
|
||||
"end": i,
|
||||
"ts": ts,
|
||||
"ts_start": ts_start,
|
||||
"ts_end": ts_end,
|
||||
"text": "\n\n".join(summaries)
|
||||
})
|
||||
"text": "\n\n".join(summaries),
|
||||
}).model_dump(exclude_none=True))
|
||||
|
||||
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}")
|
||||
|
||||
@@ -412,7 +513,7 @@ class LayeredHistoryMixin:
|
||||
last_entry = layered_history[0][-1]
|
||||
end = last_entry["end"]
|
||||
log.debug("summarize_to_layered_history", layer="base", start=end)
|
||||
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end + 1)
|
||||
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end)
|
||||
else:
|
||||
log.debug("summarize_to_layered_history", layer="base", empty=True)
|
||||
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
|
||||
@@ -445,7 +546,7 @@ class LayeredHistoryMixin:
|
||||
end = next_layer[-1]["end"] if next_layer else 0
|
||||
|
||||
log.debug("summarize_to_layered_history", layer=index, start=end)
|
||||
summarized = await summarize_layer(layered_history[index], index + 1, end + 1 if end else 0)
|
||||
summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0)
|
||||
|
||||
if summarized:
|
||||
noop = False
|
||||
@@ -466,4 +567,107 @@ class LayeredHistoryMixin:
|
||||
log.info("Generation cancelled, stopping rebuild of historical layered history")
|
||||
emit("status", message="Rebuilding of layered history cancelled", status="info")
|
||||
handle_generation_cancelled(e)
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
async def summarize_entries_to_layered_history(
|
||||
self,
|
||||
entries: list[dict],
|
||||
next_layer_index: int,
|
||||
start_index: int,
|
||||
end_index: int,
|
||||
generation_options: GenerationOptions | None = None,
|
||||
) -> list[LayeredArchiveEntry]:
|
||||
"""
|
||||
Summarizes a list of entries into layered history entries.
|
||||
|
||||
This method is used for regenerating specific history entries by processing
|
||||
their source entries. It chunks the entries based on the token threshold and
|
||||
summarizes each chunk into a LayeredArchiveEntry.
|
||||
|
||||
Args:
|
||||
entries: List of dictionaries containing the text entries to summarize.
|
||||
Each entry should have at least a 'text' field and optionally
|
||||
'ts', 'ts_start', and 'ts_end' fields.
|
||||
next_layer_index: The index of the layer where the summarized entries
|
||||
will be placed.
|
||||
start_index: The starting index in the source layer that these entries
|
||||
correspond to.
|
||||
end_index: The ending index in the source layer that these entries
|
||||
correspond to.
|
||||
generation_options: Optional generation options to pass to the summarization
|
||||
process.
|
||||
|
||||
Returns:
|
||||
List of LayeredArchiveEntry objects containing the summarized text along
|
||||
with timestamp and index information. Currently returns a list with a
|
||||
single entry, but the structure supports multiple entries if needed.
|
||||
|
||||
Notes:
|
||||
- The method respects the layered_history_threshold for chunking
|
||||
- Uses helper methods for timestamp extraction, context building, and
|
||||
chunk summarization
|
||||
- Validates that summaries are not longer than the original text
|
||||
- The last entry is always included in the final chunk if it doesn't
|
||||
exceed the token threshold
|
||||
"""
|
||||
|
||||
token_threshold = self.layered_history_threshold
|
||||
|
||||
archive_entries = []
|
||||
summaries = []
|
||||
current_chunk = []
|
||||
current_tokens = 0
|
||||
|
||||
ts = "PT1S"
|
||||
ts_start = "PT1S"
|
||||
ts_end = "PT1S"
|
||||
|
||||
|
||||
for entry_index, entry in enumerate(entries):
|
||||
is_last_entry = entry_index == len(entries) - 1
|
||||
entry_tokens = util.count_tokens(entry['text'])
|
||||
|
||||
log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens)
|
||||
|
||||
if current_tokens + entry_tokens > token_threshold or is_last_entry:
|
||||
|
||||
if is_last_entry and current_tokens + entry_tokens <= token_threshold:
|
||||
# if we are here because this is the last entry and adding it to
|
||||
# the current chunk would not exceed the token threshold, we will
|
||||
# add it to the current chunk
|
||||
current_chunk.append(entry)
|
||||
|
||||
if current_chunk:
|
||||
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
|
||||
|
||||
extra_context = self._lh_build_extra_context(next_layer_index)
|
||||
|
||||
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
|
||||
|
||||
summaries = await self._lh_split_and_summarize_chunks(
|
||||
current_chunk,
|
||||
extra_context,
|
||||
generation_options=generation_options,
|
||||
)
|
||||
|
||||
# validate summary length
|
||||
self._lh_validate_summary_length(summaries, text_length)
|
||||
|
||||
archive_entry = LayeredArchiveEntry(**{
|
||||
"start": start_index,
|
||||
"end": end_index,
|
||||
"ts": ts,
|
||||
"ts_start": ts_start,
|
||||
"ts_end": ts_end,
|
||||
"text": "\n\n".join(summaries),
|
||||
})
|
||||
|
||||
archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n"))
|
||||
|
||||
archive_entries.append(archive_entry)
|
||||
|
||||
current_chunk.append(entry)
|
||||
current_tokens += entry_tokens
|
||||
|
||||
return archive_entries
|
||||
|
||||
@@ -23,7 +23,7 @@ from talemate.emit.signals import handlers as signal_handlers
|
||||
from talemate.prompts.base import Prompt
|
||||
|
||||
from .commands import * # noqa
|
||||
from .context import VIS_TYPES, VisualContext, visual_context
|
||||
from .context import VIS_TYPES, VisualContext, VisualContextState, visual_context
|
||||
from .handlers import HANDLERS
|
||||
from .schema import RESOLUTION_MAP, RenderSettings
|
||||
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
|
||||
@@ -40,6 +40,14 @@ BACKENDS = [
|
||||
for mixin_backend, mixin in HANDLERS.items()
|
||||
]
|
||||
|
||||
PROMPT_OUTPUT_FORMAT = """
|
||||
### Positive
|
||||
{positive_prompt}
|
||||
|
||||
### Negative
|
||||
{negative_prompt}
|
||||
"""
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual")
|
||||
|
||||
|
||||
@@ -284,7 +292,7 @@ class VisualBase(Agent):
|
||||
|
||||
try:
|
||||
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
|
||||
except KeyError:
|
||||
except (KeyError, TypeError):
|
||||
backend = self.backend
|
||||
|
||||
backend_changed = backend != self.backend
|
||||
@@ -425,10 +433,9 @@ class VisualBase(Agent):
|
||||
self, format: str = "portrait", prompt: str = None, automatic: bool = False
|
||||
):
|
||||
|
||||
context = visual_context.get()
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
context:VisualContextState = visual_context.get()
|
||||
|
||||
log.debug("visual generate", context=context)
|
||||
|
||||
if automatic and not self.allow_automatic_generation:
|
||||
return
|
||||
@@ -459,7 +466,7 @@ class VisualBase(Agent):
|
||||
|
||||
thematic_style = self.default_style
|
||||
vis_type_styles = self.vis_type_styles(context.vis_type)
|
||||
prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
|
||||
prompt:Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
|
||||
|
||||
if context.vis_type == VIS_TYPES.CHARACTER:
|
||||
prompt.keywords.append("character portrait")
|
||||
@@ -481,7 +488,34 @@ class VisualBase(Agent):
|
||||
format = "portrait"
|
||||
|
||||
context.format = format
|
||||
|
||||
|
||||
can_generate_image = self.enabled and self.backend_ready
|
||||
|
||||
if not context.prompt_only and not can_generate_image:
|
||||
emit("status", "Visual agent is not ready for image generation, will output prompt instead.", status="warning")
|
||||
|
||||
# if prompt_only, we don't need to generate an image
|
||||
# instead we emit a system message with the prompt
|
||||
if context.prompt_only or not can_generate_image:
|
||||
emit(
|
||||
"system",
|
||||
message=PROMPT_OUTPUT_FORMAT.format(
|
||||
positive_prompt=prompt.positive_prompt,
|
||||
negative_prompt=prompt.negative_prompt,
|
||||
),
|
||||
meta={
|
||||
"icon": "mdi-image-text",
|
||||
"color": "highlight7",
|
||||
"title": f"Visual Prompt - {context.title}",
|
||||
"display": "tonal",
|
||||
"as_markdown": True,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not can_generate_image:
|
||||
return
|
||||
|
||||
# Call the backend specific generate function
|
||||
|
||||
backend = self.backend
|
||||
@@ -541,8 +575,16 @@ class VisualBase(Agent):
|
||||
|
||||
return response.strip()
|
||||
|
||||
async def generate_environment_background(self, instructions: str = None):
|
||||
with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
|
||||
async def generate_environment_background(
|
||||
self,
|
||||
instructions: str = None,
|
||||
prompt_only: bool = False,
|
||||
):
|
||||
with VisualContext(
|
||||
vis_type=VIS_TYPES.ENVIRONMENT,
|
||||
instructions=instructions,
|
||||
prompt_only=prompt_only,
|
||||
):
|
||||
await self.generate(format="landscape")
|
||||
|
||||
async def generate_character_portrait(
|
||||
@@ -550,12 +592,14 @@ class VisualBase(Agent):
|
||||
character_name: str,
|
||||
instructions: str = None,
|
||||
replace: bool = False,
|
||||
prompt_only: bool = False,
|
||||
):
|
||||
with VisualContext(
|
||||
vis_type=VIS_TYPES.CHARACTER,
|
||||
character_name=character_name,
|
||||
instructions=instructions,
|
||||
replace=replace,
|
||||
prompt_only=prompt_only,
|
||||
):
|
||||
await self.generate(format="portrait")
|
||||
|
||||
|
||||
@@ -29,6 +29,15 @@ class VisualContextState(pydantic.BaseModel):
|
||||
prepared_prompt: Union[str, None] = None
|
||||
format: Union[str, None] = None
|
||||
replace: bool = False
|
||||
prompt_only: bool = False
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
if self.vis_type == VIS_TYPES.ENVIRONMENT:
|
||||
return "Environment"
|
||||
elif self.vis_type == VIS_TYPES.CHARACTER:
|
||||
return f"Character: {self.character_name}"
|
||||
return "Visual Context"
|
||||
|
||||
|
||||
class VisualContext:
|
||||
|
||||
@@ -90,12 +90,16 @@ class VisualWebsocketHandler(Plugin):
|
||||
payload = GeneratePayload(**data)
|
||||
visual = get_agent("visual")
|
||||
await visual.generate_character_portrait(
|
||||
payload.context.character_name, payload.context.instructions, replace=True
|
||||
payload.context.character_name,
|
||||
payload.context.instructions,
|
||||
replace=True,
|
||||
prompt_only=payload.context.prompt_only,
|
||||
)
|
||||
|
||||
async def handle_visualize_environment(self, data: dict):
|
||||
payload = GeneratePayload(**data)
|
||||
visual = get_agent("visual")
|
||||
await visual.generate_environment_background(
|
||||
instructions=payload.context.instructions
|
||||
instructions=payload.context.instructions,
|
||||
prompt_only=payload.context.prompt_only,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from talemate.scene_message import (
|
||||
ReinforcementMessage,
|
||||
TimePassageMessage,
|
||||
)
|
||||
from talemate.util.response import extract_list
|
||||
|
||||
|
||||
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
|
||||
@@ -76,6 +77,12 @@ class WorldStateAgent(
|
||||
label="Update world state",
|
||||
description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
|
||||
config={
|
||||
"initial": AgentActionConfig(
|
||||
type="bool",
|
||||
label="When a new scene is started",
|
||||
description="Whether to update the world state on scene start.",
|
||||
value=True,
|
||||
),
|
||||
"turns": AgentActionConfig(
|
||||
type="number",
|
||||
label="Turns",
|
||||
@@ -133,10 +140,15 @@ class WorldStateAgent(
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def initial_update(self):
|
||||
return self.actions["update_world_state"].config["initial"].value
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
talemate.emit.async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init_after)
|
||||
|
||||
async def advance_time(self, duration: str, narrative: str = None):
|
||||
"""
|
||||
@@ -162,6 +174,22 @@ class WorldStateAgent(
|
||||
)
|
||||
)
|
||||
|
||||
async def on_scene_loop_init_after(self, emission):
|
||||
"""
|
||||
Called when a scene is initialized
|
||||
"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if not self.initial_update:
|
||||
return
|
||||
|
||||
if self.get_scene_state("inital_update_done"):
|
||||
return
|
||||
|
||||
await self.scene.world_state.request_update()
|
||||
self.set_scene_states(inital_update_done=True)
|
||||
|
||||
async def on_game_loop(self, emission: GameLoopEvent):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
@@ -305,7 +333,7 @@ class WorldStateAgent(
|
||||
},
|
||||
)
|
||||
|
||||
queries = response.split("\n")
|
||||
queries = extract_list(response)
|
||||
|
||||
memory_agent = get_agent("memory")
|
||||
|
||||
|
||||
@@ -10,7 +10,9 @@ from talemate.client.groq import GroqClient
|
||||
from talemate.client.koboldcpp import KoboldCppClient
|
||||
from talemate.client.lmstudio import LMStudioClient
|
||||
from talemate.client.mistral import MistralAIClient
|
||||
from talemate.client.ollama import OllamaClient
|
||||
from talemate.client.openai import OpenAIClient
|
||||
from talemate.client.openrouter import OpenRouterClient
|
||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||
from talemate.client.tabbyapi import TabbyAPIClient
|
||||
|
||||
@@ -2,8 +2,14 @@ import pydantic
|
||||
import structlog
|
||||
from anthropic import AsyncAnthropic, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
|
||||
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.client.remote import (
|
||||
EndpointOverride,
|
||||
EndpointOverrideMixin,
|
||||
endpoint_override_extra_fields,
|
||||
)
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
@@ -28,13 +34,17 @@ SUPPORTED_MODELS = [
|
||||
]
|
||||
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "claude-3-5-sonnet-latest"
|
||||
double_coercion: str = None
|
||||
|
||||
class ClientConfig(EndpointOverride, BaseClientConfig):
|
||||
pass
|
||||
|
||||
|
||||
@register()
|
||||
class AnthropicClient(ClientBase):
|
||||
class AnthropicClient(EndpointOverrideMixin, ClientBase):
|
||||
"""
|
||||
Anthropic client for generating text.
|
||||
"""
|
||||
@@ -44,6 +54,7 @@ class AnthropicClient(ClientBase):
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Anthropic"
|
||||
@@ -52,15 +63,21 @@ class AnthropicClient(ClientBase):
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
|
||||
|
||||
def __init__(self, model="claude-3-5-sonnet-latest", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def can_be_coerced(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def anthropic_api_key(self):
|
||||
return self.config.get("anthropic", {}).get("api_key")
|
||||
@@ -103,6 +120,7 @@ class AnthropicClient(ClientBase):
|
||||
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"double_coercion": self.double_coercion,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
@@ -117,7 +135,7 @@ class AnthropicClient(ClientBase):
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.anthropic_api_key:
|
||||
if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
|
||||
self.client = AsyncAnthropic(api_key="sk-1111")
|
||||
log.error("No anthropic API key set")
|
||||
if self.api_key_status:
|
||||
@@ -134,7 +152,7 @@ class AnthropicClient(ClientBase):
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
|
||||
self.client = AsyncAnthropic(api_key=self.api_key, base_url=self.base_url)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
@@ -158,7 +176,11 @@ class AnthropicClient(ClientBase):
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
if "double_coercion" in kwargs:
|
||||
self.double_coercion = kwargs["double_coercion"]
|
||||
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
@@ -175,13 +197,10 @@ class AnthropicClient(ClientBase):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
"""
|
||||
Anthropic handles the prompt template internally, so we just
|
||||
give the prompt as is.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
@@ -189,19 +208,19 @@ class AnthropicClient(ClientBase):
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.anthropic_api_key:
|
||||
if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
|
||||
raise Exception("No anthropic API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
|
||||
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
|
||||
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": prompt.strip()}
|
||||
]
|
||||
|
||||
if coercion_prompt:
|
||||
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
@@ -209,28 +228,39 @@ class AnthropicClient(ClientBase):
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
prompt_tokens = 0
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(
|
||||
stream = await self.client.messages.create(
|
||||
model=self.model_name,
|
||||
system=system_message,
|
||||
messages=[human_message],
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
response = ""
|
||||
|
||||
async for event in stream:
|
||||
|
||||
if event.type == "content_block_delta":
|
||||
content = event.delta.text
|
||||
response += content
|
||||
self.update_request_tokens(self.count_tokens(content))
|
||||
|
||||
elif event.type == "message_start":
|
||||
prompt_tokens = event.message.usage.input_tokens
|
||||
|
||||
elif event.type == "message_delta":
|
||||
completion_tokens += event.usage.output_tokens
|
||||
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
self._returned_prompt_tokens = prompt_tokens
|
||||
self._returned_response_tokens = completion_tokens
|
||||
|
||||
log.debug("generated response", response=response.content)
|
||||
|
||||
response = response.content[0].text
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
log.debug("generated response", response=response)
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
|
||||
@@ -6,10 +6,12 @@ import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
import asyncio
|
||||
from typing import Callable, Union, Literal
|
||||
|
||||
import pydantic
|
||||
import dataclasses
|
||||
import structlog
|
||||
import urllib3
|
||||
from openai import AsyncOpenAI, PermissionDeniedError
|
||||
@@ -23,7 +25,10 @@ from talemate.client.model_prompts import model_prompt
|
||||
from talemate.client.ratelimit import CounterRateLimiter
|
||||
from talemate.context import active_scene
|
||||
from talemate.emit import emit
|
||||
from talemate.config import load_config, save_config, EmbeddingFunctionPreset
|
||||
import talemate.emit.async_signals as async_signals
|
||||
from talemate.exceptions import SceneInactiveError, GenerationCancelled
|
||||
import talemate.ux.schema as ux_schema
|
||||
|
||||
from talemate.client.system_prompts import SystemPrompts
|
||||
|
||||
@@ -77,13 +82,20 @@ class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
double_coercion: str = None
|
||||
|
||||
|
||||
class FieldGroup(pydantic.BaseModel):
|
||||
name: str
|
||||
label: str
|
||||
description: str
|
||||
icon: str = "mdi-cog"
|
||||
|
||||
class ExtraField(pydantic.BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
label: str
|
||||
required: bool
|
||||
description: str
|
||||
|
||||
group: FieldGroup | None = None
|
||||
note: ux_schema.Note | None = None
|
||||
|
||||
class ParameterReroute(pydantic.BaseModel):
|
||||
talemate_parameter: str
|
||||
@@ -101,6 +113,56 @@ class ParameterReroute(pydantic.BaseModel):
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class RequestInformation(pydantic.BaseModel):
|
||||
start_time: float = pydantic.Field(default_factory=time.time)
|
||||
end_time: float | None = None
|
||||
tokens: int = 0
|
||||
|
||||
@pydantic.computed_field(description="Duration")
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
end_time = self.end_time or time.time()
|
||||
return end_time - self.start_time
|
||||
|
||||
@pydantic.computed_field(description="Tokens per second")
|
||||
@property
|
||||
def rate(self) -> float:
|
||||
try:
|
||||
end_time = self.end_time or time.time()
|
||||
return self.tokens / (end_time - self.start_time)
|
||||
except:
|
||||
pass
|
||||
return 0
|
||||
|
||||
@pydantic.computed_field(description="Status")
|
||||
@property
|
||||
def status(self) -> str:
|
||||
if self.end_time:
|
||||
return "completed"
|
||||
elif self.start_time:
|
||||
if self.duration > 1 and self.rate == 0:
|
||||
return "stopped"
|
||||
return "in progress"
|
||||
else:
|
||||
return "pending"
|
||||
|
||||
@pydantic.computed_field(description="Age")
|
||||
@property
|
||||
def age(self) -> float:
|
||||
if not self.end_time:
|
||||
return -1
|
||||
return time.time() - self.end_time
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ClientEmbeddingsStatus:
|
||||
client: "ClientBase | None" = None
|
||||
embedding_name: str | None = None
|
||||
|
||||
async_signals.register(
|
||||
"client.embeddings_available",
|
||||
)
|
||||
|
||||
class ClientBase:
|
||||
api_url: str
|
||||
model_name: str
|
||||
@@ -120,6 +182,7 @@ class ClientBase:
|
||||
data_format: Literal["yaml", "json"] | None = None
|
||||
rate_limit: int | None = None
|
||||
client_type = "base"
|
||||
request_information: RequestInformation | None = None
|
||||
|
||||
status_request_timeout:int = 2
|
||||
|
||||
@@ -171,6 +234,13 @@ class ClientBase:
|
||||
"""
|
||||
return self.Meta().requires_prompt_template
|
||||
|
||||
@property
|
||||
def can_think(self) -> bool:
|
||||
"""
|
||||
Allow reasoning models to think before responding.
|
||||
"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def max_tokens_param_name(self):
|
||||
return "max_tokens"
|
||||
@@ -182,9 +252,87 @@ class ClientBase:
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def embeddings_function(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def embeddings_status(self) -> bool:
|
||||
return getattr(self, "_embeddings_status", False)
|
||||
|
||||
@property
|
||||
def embeddings_model_name(self) -> str | None:
|
||||
return getattr(self, "_embeddings_model_name", None)
|
||||
|
||||
@property
|
||||
def embeddings_url(self) -> str:
|
||||
return None
|
||||
|
||||
@property
|
||||
def embeddings_identifier(self) -> str:
|
||||
return f"client-api/{self.name}/{self.embeddings_model_name}"
|
||||
|
||||
async def destroy(self, config:dict):
|
||||
"""
|
||||
This is called before the client is removed from talemate.instance.clients
|
||||
|
||||
Use this to perform any cleanup that is necessary.
|
||||
|
||||
If a subclass overrides this method, it should call super().destroy(config) in the
|
||||
end of the method.
|
||||
"""
|
||||
|
||||
if self.supports_embeddings:
|
||||
self.remove_embeddings(config)
|
||||
|
||||
def reset_embeddings(self):
|
||||
self._embeddings_model_name = None
|
||||
self._embeddings_status = False
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
|
||||
|
||||
def set_embeddings(self):
|
||||
|
||||
log.debug("setting embeddings", client=self.name, supports_embeddings=self.supports_embeddings, embeddings_status=self.embeddings_status)
|
||||
|
||||
if not self.supports_embeddings or not self.embeddings_status:
|
||||
return
|
||||
|
||||
config = load_config(as_model=True)
|
||||
|
||||
key = self.embeddings_identifier
|
||||
|
||||
if key in config.presets.embeddings:
|
||||
log.debug("embeddings already set", client=self.name, key=key)
|
||||
return config.presets.embeddings[key]
|
||||
|
||||
|
||||
log.debug("setting embeddings", client=self.name, key=key)
|
||||
|
||||
config.presets.embeddings[key] = EmbeddingFunctionPreset(
|
||||
embeddings="client-api",
|
||||
client=self.name,
|
||||
model=self.embeddings_model_name,
|
||||
distance=1,
|
||||
distance_function="cosine",
|
||||
local=False,
|
||||
custom=True,
|
||||
)
|
||||
|
||||
save_config(config)
|
||||
|
||||
def remove_embeddings(self, config:dict | None = None):
|
||||
# remove all embeddings for this client
|
||||
for key, value in list(config["presets"]["embeddings"].items()):
|
||||
if value["client"] == self.name and value["embeddings"] == "client-api":
|
||||
log.warning("!!! removing embeddings", client=self.name, key=key)
|
||||
config["presets"]["embeddings"].pop(key)
|
||||
|
||||
def set_system_prompts(self, system_prompts: dict | SystemPrompts):
|
||||
if isinstance(system_prompts, dict):
|
||||
@@ -222,6 +370,19 @@ class ClientBase:
|
||||
self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
|
||||
)
|
||||
|
||||
def split_prompt_for_coercion(self, prompt: str) -> tuple[str, str]:
|
||||
"""
|
||||
Splits the prompt and the prefill/coercion prompt.
|
||||
"""
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
|
||||
if self.double_coercion:
|
||||
right = f"{self.double_coercion}\n\n{right}"
|
||||
|
||||
return prompt, right
|
||||
return prompt, None
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
"""
|
||||
Reconfigures the client.
|
||||
@@ -241,6 +402,8 @@ class ClientBase:
|
||||
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
if not self.enabled and self.supports_embeddings and self.embeddings_status:
|
||||
self.reset_embeddings()
|
||||
|
||||
if "double_coercion" in kwargs:
|
||||
self.double_coercion = kwargs["double_coercion"]
|
||||
@@ -327,7 +490,7 @@ class ClientBase:
|
||||
"""
|
||||
Sets and emits the client status.
|
||||
"""
|
||||
|
||||
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
@@ -388,6 +551,8 @@ class ClientBase:
|
||||
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
|
||||
data[field_name] = getattr(self, field_name, None)
|
||||
|
||||
data = self.finalize_status(data)
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
@@ -400,13 +565,31 @@ class ClientBase:
|
||||
if status_change:
|
||||
instance.emit_agent_status_by_client(self)
|
||||
|
||||
def finalize_status(self, data: dict):
|
||||
"""
|
||||
Finalizes the status data for the client.
|
||||
"""
|
||||
return data
|
||||
|
||||
def _common_status_data(self):
|
||||
return {
|
||||
common_data = {
|
||||
"can_be_coerced": self.can_be_coerced,
|
||||
"preset_group": self.preset_group or "",
|
||||
"rate_limit": self.rate_limit,
|
||||
"data_format": self.data_format,
|
||||
"manual_model_choices": getattr(self.Meta(), "manual_model_choices", []),
|
||||
"supports_embeddings": self.supports_embeddings,
|
||||
"embeddings_status": self.embeddings_status,
|
||||
"embeddings_model_name": self.embeddings_model_name,
|
||||
"request_information": self.request_information.model_dump() if self.request_information else None,
|
||||
}
|
||||
|
||||
|
||||
extra_fields = getattr(self.Meta(), "extra_fields", {})
|
||||
for field_name in extra_fields.keys():
|
||||
common_data[field_name] = getattr(self, field_name, None)
|
||||
|
||||
return common_data
|
||||
|
||||
def populate_extra_fields(self, data: dict):
|
||||
"""
|
||||
Updates data with the extra fields from the client's Meta
|
||||
@@ -438,6 +621,7 @@ class ClientBase:
|
||||
:return: None
|
||||
"""
|
||||
if self.processing:
|
||||
self.emit_status()
|
||||
return
|
||||
|
||||
if not self.enabled:
|
||||
@@ -618,8 +802,29 @@ class ClientBase:
|
||||
at the other side of the client.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def new_request(self):
|
||||
"""
|
||||
Creates a new request information object.
|
||||
"""
|
||||
self.request_information = RequestInformation()
|
||||
|
||||
def end_request(self):
|
||||
"""
|
||||
Ends the request information object.
|
||||
"""
|
||||
self.request_information.end_time = time.time()
|
||||
|
||||
def update_request_tokens(self, tokens: int, replace: bool = False):
|
||||
"""
|
||||
Updates the request information object with the number of tokens received.
|
||||
"""
|
||||
if self.request_information:
|
||||
if replace:
|
||||
self.request_information.tokens = tokens
|
||||
else:
|
||||
self.request_information.tokens += tokens
|
||||
|
||||
async def send_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -690,7 +895,7 @@ class ClientBase:
|
||||
except GenerationCancelled:
|
||||
raise
|
||||
except Exception as e:
|
||||
log.exception("Error during rate limit check", e=e)
|
||||
log.error("Error during rate limit check", e=traceback.format_exc())
|
||||
|
||||
|
||||
if not active_scene.get():
|
||||
@@ -736,8 +941,12 @@ class ClientBase:
|
||||
)
|
||||
prompt_sent = self.repetition_adjustment(finalized_prompt)
|
||||
|
||||
self.new_request()
|
||||
|
||||
response = await self._cancelable_generate(prompt_sent, prompt_param, kind)
|
||||
|
||||
self.end_request()
|
||||
|
||||
if isinstance(response, GenerationCancelled):
|
||||
# generation was cancelled
|
||||
raise response
|
||||
@@ -786,7 +995,7 @@ class ClientBase:
|
||||
except GenerationCancelled as e:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.log.exception("send_prompt error", e=e)
|
||||
self.log.error("send_prompt error", e=traceback.format_exc())
|
||||
emit(
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from cohere import AsyncClient
|
||||
from cohere import AsyncClientV2
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.client.remote import (
|
||||
EndpointOverride,
|
||||
EndpointOverrideMixin,
|
||||
endpoint_override_extra_fields,
|
||||
)
|
||||
from talemate.config import Client as BaseClientConfig, load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.util import count_tokens
|
||||
@@ -26,13 +31,17 @@ SUPPORTED_MODELS = [
|
||||
]
|
||||
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "command-r-plus"
|
||||
|
||||
|
||||
class ClientConfig(EndpointOverride, BaseClientConfig):
|
||||
pass
|
||||
|
||||
|
||||
@register()
|
||||
class CohereClient(ClientBase):
|
||||
class CohereClient(EndpointOverrideMixin, ClientBase):
|
||||
"""
|
||||
Cohere client for generating text.
|
||||
"""
|
||||
@@ -41,18 +50,21 @@ class CohereClient(ClientBase):
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
decensor_enabled = True
|
||||
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Cohere"
|
||||
title: str = "Cohere"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="command-r-plus", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -119,8 +131,8 @@ class CohereClient(ClientBase):
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.cohere_api_key:
|
||||
self.client = AsyncClient("sk-1111")
|
||||
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
|
||||
self.client = AsyncClientV2("sk-1111")
|
||||
log.error("No cohere API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
@@ -136,7 +148,7 @@ class CohereClient(ClientBase):
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncClient(self.cohere_api_key)
|
||||
self.client = AsyncClientV2(self.api_key, base_url=self.base_url)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
@@ -161,6 +173,7 @@ class CohereClient(ClientBase):
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
@@ -168,7 +181,7 @@ class CohereClient(ClientBase):
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return count_tokens(response.text)
|
||||
return count_tokens(response)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
return count_tokens(prompt)
|
||||
@@ -207,7 +220,7 @@ class CohereClient(ClientBase):
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.cohere_api_key:
|
||||
if not self.cohere_api_key and not self.endpoint_override_base_url_configured:
|
||||
raise Exception("No cohere API key set")
|
||||
|
||||
right = None
|
||||
@@ -227,21 +240,43 @@ class CohereClient(ClientBase):
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_message,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": human_message,
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.client.chat(
|
||||
# Cohere's `chat_stream` returns an **asynchronous generator** that can be
|
||||
# consumed directly with `async for`. It is not an asynchronous context
|
||||
# manager, so attempting to use `async with` raises a `TypeError` as seen
|
||||
# in issue logs. We therefore iterate over the generator directly.
|
||||
|
||||
stream = self.client.chat_stream(
|
||||
model=self.model_name,
|
||||
preamble=system_message,
|
||||
message=human_message,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
response = ""
|
||||
|
||||
async for event in stream:
|
||||
if event and event.type == "content-delta":
|
||||
chunk = event.delta.message.content.text
|
||||
response += chunk
|
||||
# Track token usage incrementally
|
||||
self.update_request_tokens(self.count_tokens(chunk))
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
log.debug("generated response", response=response.text)
|
||||
|
||||
response = response.text
|
||||
log.debug("generated response", response=response)
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
|
||||
@@ -187,6 +187,14 @@ class DeepSeekClient(ClientBase):
|
||||
|
||||
return prompt
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
# Count tokens in a response string using the util.count_tokens helper
|
||||
return self.count_tokens(response)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
# Count tokens in a prompt string using the util.count_tokens helper
|
||||
return self.count_tokens(prompt)
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
@@ -221,13 +229,30 @@ class DeepSeekClient(ClientBase):
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
# Use streaming so we can update_Request_tokens incrementally
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[system_message, human_message],
|
||||
stream=True,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
response = ""
|
||||
|
||||
# Iterate over streamed chunks
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if delta and getattr(delta, "content", None):
|
||||
content_piece = delta.content
|
||||
response += content_piece
|
||||
# Incrementally track token usage
|
||||
self.update_request_tokens(self.count_tokens(content_piece))
|
||||
|
||||
# Save token accounting for whole request
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
|
||||
@@ -3,19 +3,18 @@ import os
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import vertexai
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from vertexai.generative_models import (
|
||||
ChatSession,
|
||||
GenerationConfig,
|
||||
GenerativeModel,
|
||||
ResponseValidationError,
|
||||
SafetySetting,
|
||||
)
|
||||
from google import genai
|
||||
import google.genai.types as genai_types
|
||||
from google.genai.errors import APIError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults
|
||||
from talemate.client.registry import register
|
||||
from talemate.client.remote import RemoteServiceMixin
|
||||
from talemate.client.remote import (
|
||||
RemoteServiceMixin,
|
||||
EndpointOverride,
|
||||
EndpointOverrideMixin,
|
||||
endpoint_override_extra_fields,
|
||||
)
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
@@ -31,23 +30,29 @@ log = structlog.get_logger("talemate")
|
||||
SUPPORTED_MODELS = [
|
||||
"gemini-1.0-pro",
|
||||
"gemini-1.5-pro-preview-0409",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-flash-8b",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash-preview-04-17",
|
||||
"gemini-2.5-flash-preview-05-20",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
"gemini-2.5-pro-preview-06-05",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "gemini-1.0-pro"
|
||||
model: str = "gemini-2.0-flash"
|
||||
disable_safety_settings: bool = False
|
||||
double_coercion: str = None
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
class ClientConfig(EndpointOverride, BaseClientConfig):
|
||||
disable_safety_settings: bool = False
|
||||
|
||||
|
||||
@register()
|
||||
class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
|
||||
"""
|
||||
Google client for generating text.
|
||||
"""
|
||||
@@ -74,19 +79,26 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
description="Disable Google's safety settings for responses generated by the model.",
|
||||
),
|
||||
}
|
||||
extra_fields.update(endpoint_override_extra_fields())
|
||||
|
||||
|
||||
def __init__(self, model="gemini-1.0-pro", **kwargs):
|
||||
def __init__(self, model="gemini-2.0-flash", **kwargs):
|
||||
self.model_name = model
|
||||
self.setup_status = None
|
||||
self.model_instance = None
|
||||
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
|
||||
self.google_credentials_read = False
|
||||
self.google_project_id = None
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def can_be_coerced(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def google_credentials(self):
|
||||
path = self.google_credentials_path
|
||||
@@ -102,16 +114,36 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
@property
|
||||
def google_location(self):
|
||||
return self.config.get("google").get("gcloud_location")
|
||||
|
||||
@property
|
||||
def google_api_key(self):
|
||||
return self.config.get("google").get("api_key")
|
||||
|
||||
@property
|
||||
def vertexai_ready(self) -> bool:
|
||||
return all([
|
||||
self.google_credentials_path,
|
||||
self.google_location,
|
||||
])
|
||||
|
||||
@property
|
||||
def developer_api_ready(self) -> bool:
|
||||
return all([
|
||||
self.google_api_key,
|
||||
])
|
||||
|
||||
@property
|
||||
def using(self) -> str:
|
||||
if self.developer_api_ready:
|
||||
return "API"
|
||||
if self.vertexai_ready:
|
||||
return "VertexAI"
|
||||
return "Unknown"
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
# all google settings must be set
|
||||
return all(
|
||||
[
|
||||
self.google_credentials_path,
|
||||
self.google_location,
|
||||
]
|
||||
)
|
||||
return self.vertexai_ready or self.developer_api_ready or self.endpoint_override_base_url_configured
|
||||
|
||||
@property
|
||||
def safety_settings(self):
|
||||
@@ -119,30 +151,39 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
return None
|
||||
|
||||
safety_settings = [
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
genai_types.SafetySetting(
|
||||
category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
threshold="BLOCK_NONE",
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
genai_types.SafetySetting(
|
||||
category="HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
threshold="BLOCK_NONE",
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
genai_types.SafetySetting(
|
||||
category="HARM_CATEGORY_HARASSMENT",
|
||||
threshold="BLOCK_NONE",
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
genai_types.SafetySetting(
|
||||
category="HARM_CATEGORY_HATE_SPEECH",
|
||||
threshold="BLOCK_NONE",
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
genai_types.SafetySetting(
|
||||
category="HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
threshold="BLOCK_NONE",
|
||||
),
|
||||
]
|
||||
|
||||
return safety_settings
|
||||
|
||||
@property
|
||||
def http_options(self) -> genai_types.HttpOptions | None:
|
||||
if not self.endpoint_override_base_url_configured:
|
||||
return None
|
||||
|
||||
return genai_types.HttpOptions(
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
@@ -184,6 +225,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
|
||||
self.current_status = status
|
||||
data = {
|
||||
"double_coercion": self.double_coercion,
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"enabled": self.enabled,
|
||||
@@ -191,15 +233,27 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
data.update(self._common_status_data())
|
||||
self.populate_extra_fields(data)
|
||||
|
||||
if self.using == "VertexAI":
|
||||
details = f"{model_name} (VertexAI)"
|
||||
else:
|
||||
details = model_name
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
details=details,
|
||||
status=status if self.enabled else "disabled",
|
||||
data=data,
|
||||
)
|
||||
|
||||
def set_client_base_url(self, base_url: str | None):
|
||||
if getattr(self, "client", None):
|
||||
try:
|
||||
self.client.http_options.base_url = base_url
|
||||
except Exception as e:
|
||||
log.error("Error setting client base URL", error=e, client=self.client_type)
|
||||
|
||||
def set_client(self, max_token_length: int = None, **kwargs):
|
||||
if not self.ready:
|
||||
log.error("Google cloud setup incomplete")
|
||||
@@ -210,7 +264,7 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "gemini-1.0-pro"
|
||||
self.model_name = "gemini-2.0-flash"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
@@ -222,17 +276,14 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.setup_status:
|
||||
if self.setup_status is False:
|
||||
project_id = self.google_credentials.get("project_id")
|
||||
self.google_project_id = project_id
|
||||
if self.google_credentials_path:
|
||||
vertexai.init(project=project_id, location=self.google_location)
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.setup_status = True
|
||||
|
||||
self.model_instance = GenerativeModel(model_name=model)
|
||||
if self.vertexai_ready and not self.developer_api_ready:
|
||||
self.client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self.google_project_id,
|
||||
location=self.google_location,
|
||||
)
|
||||
else:
|
||||
self.client = genai.Client(api_key=self.api_key or None, http_options=self.http_options)
|
||||
|
||||
log.info(
|
||||
"google set client",
|
||||
@@ -241,8 +292,9 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
model=model,
|
||||
)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return count_tokens(response.text)
|
||||
def response_tokens(self, response:str):
|
||||
"""Return token count for a response which may be a string or SDK object."""
|
||||
return count_tokens(response)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
return count_tokens(prompt)
|
||||
@@ -258,6 +310,9 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
if "double_coercion" in kwargs:
|
||||
self.double_coercion = kwargs["double_coercion"]
|
||||
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
|
||||
def clean_prompt_parameters(self, parameters: dict):
|
||||
@@ -267,27 +322,53 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
if "top_k" in parameters and parameters["top_k"] == 0:
|
||||
del parameters["top_k"]
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
"""
|
||||
Google handles the prompt template internally, so we just
|
||||
give the prompt as is.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.ready:
|
||||
raise Exception("Google cloud setup incomplete")
|
||||
raise Exception("Google setup incomplete")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
|
||||
|
||||
human_message = prompt.strip()
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
contents = [
|
||||
genai_types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
genai_types.Part.from_text(
|
||||
text=human_message,
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
if coercion_prompt:
|
||||
contents.append(
|
||||
genai_types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
genai_types.Part.from_text(
|
||||
text=coercion_prompt,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
base_url=self.base_url,
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
@@ -296,48 +377,53 @@ class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
)
|
||||
|
||||
try:
|
||||
# Use streaming so we can update_Request_tokens incrementally
|
||||
#stream = await chat.send_message_async(
|
||||
# human_message,
|
||||
# safety_settings=self.safety_settings,
|
||||
# generation_config=parameters,
|
||||
# stream=True
|
||||
#)
|
||||
|
||||
chat = self.model_instance.start_chat()
|
||||
|
||||
response = await chat.send_message_async(
|
||||
human_message,
|
||||
safety_settings=self.safety_settings,
|
||||
generation_config=parameters,
|
||||
|
||||
stream = await self.client.aio.models.generate_content_stream(
|
||||
model=self.model_name,
|
||||
contents=contents,
|
||||
config=genai_types.GenerateContentConfig(
|
||||
safety_settings=self.safety_settings,
|
||||
http_options=self.http_options,
|
||||
**parameters
|
||||
),
|
||||
)
|
||||
|
||||
response = ""
|
||||
|
||||
async for chunk in stream:
|
||||
# For each streamed chunk, append content and update token counts
|
||||
content_piece = getattr(chunk, "text", None)
|
||||
if not content_piece:
|
||||
# Some SDK versions wrap text under candidates[0].text
|
||||
try:
|
||||
content_piece = chunk.candidates[0].text # type: ignore
|
||||
except Exception:
|
||||
content_piece = None
|
||||
|
||||
if content_piece:
|
||||
response += content_piece
|
||||
# Incrementally update token usage
|
||||
self.update_request_tokens(count_tokens(content_piece))
|
||||
|
||||
# Store total token accounting for prompt/response
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
response = response.text
|
||||
|
||||
log.debug("generated response", response=response)
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
|
||||
# except PermissionDeniedError as e:
|
||||
# self.log.error("generate error", e=e)
|
||||
# emit("status", message="google API: Permission Denied", status="error")
|
||||
# return ""
|
||||
except ResourceExhausted as e:
|
||||
except APIError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="google API: Quota Limit reached", status="error")
|
||||
emit("status", message="google API: API Error", status="error")
|
||||
return ""
|
||||
except ResponseValidationError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status",
|
||||
message="google API: Response Validation Error",
|
||||
status="error",
|
||||
)
|
||||
if not self.disable_safety_settings:
|
||||
return "Failed to generate response. Probably due to safety settings, you can turn them off in the client settings."
|
||||
return "Failed to generate response. Please check logs."
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
@@ -2,11 +2,16 @@ import pydantic
|
||||
import structlog
|
||||
from groq import AsyncGroq, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.client.remote import (
|
||||
EndpointOverride,
|
||||
EndpointOverrideMixin,
|
||||
endpoint_override_extra_fields,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GroqClient",
|
||||
@@ -23,13 +28,13 @@ SUPPORTED_MODELS = [
|
||||
JSON_OBJECT_RESPONSE_MODELS = []
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
class Defaults(EndpointOverride, pydantic.BaseModel):
|
||||
max_token_length: int = 8192
|
||||
model: str = "llama3-70b-8192"
|
||||
|
||||
|
||||
@register()
|
||||
class GroqClient(ClientBase):
|
||||
class GroqClient(EndpointOverrideMixin, ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
@@ -47,10 +52,13 @@ class GroqClient(ClientBase):
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
|
||||
|
||||
def __init__(self, model="llama3-70b-8192", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
# Apply any endpoint override parameters provided via kwargs before creating client
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -100,21 +108,27 @@ class GroqClient(ClientBase):
|
||||
|
||||
self.current_status = status
|
||||
|
||||
data = {
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
# Include shared/common status data (rate limit, etc.)
|
||||
data.update(self._common_status_data())
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status if self.enabled else "disabled",
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"enabled": self.enabled,
|
||||
},
|
||||
data=data,
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.groq_api_key:
|
||||
# Determine if we should use the globally configured API key or the override key
|
||||
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
|
||||
# No API key and no endpoint override – cannot initialize client correctly
|
||||
self.client = AsyncGroq(api_key="sk-1111")
|
||||
log.error("No groq.ai API key set")
|
||||
if self.api_key_status:
|
||||
@@ -131,7 +145,8 @@ class GroqClient(ClientBase):
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncGroq(api_key=self.groq_api_key)
|
||||
# Use the override values (if any) when constructing the Groq client
|
||||
self.client = AsyncGroq(api_key=self.api_key, base_url=self.base_url)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
@@ -155,6 +170,11 @@ class GroqClient(ClientBase):
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
# Allow dynamic reconfiguration of endpoint override parameters
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
# Reconfigure any common parameters (rate limit, data format, etc.)
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
@@ -184,7 +204,7 @@ class GroqClient(ClientBase):
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.groq_api_key:
|
||||
if not self.groq_api_key and not self.endpoint_override_base_url_configured:
|
||||
raise Exception("No groq.ai API key set")
|
||||
|
||||
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import random
|
||||
import re
|
||||
import json
|
||||
import sseclient
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
import requests
|
||||
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
|
||||
|
||||
# import urljoin
|
||||
from urllib.parse import urljoin, urlparse
|
||||
@@ -10,12 +14,14 @@ import structlog
|
||||
|
||||
import talemate.util as util
|
||||
from talemate.client.base import (
|
||||
STOPPING_STRINGS,
|
||||
ClientBase,
|
||||
Defaults,
|
||||
ParameterReroute,
|
||||
ClientEmbeddingsStatus
|
||||
)
|
||||
from talemate.client.registry import register
|
||||
import talemate.emit.async_signals as async_signals
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.agents.visual import VisualBase
|
||||
@@ -28,6 +34,37 @@ class KoboldCppClientDefaults(Defaults):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class KoboldEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, api_url: str, model_name: str = None):
|
||||
"""
|
||||
Initialize the embedding function with the KoboldCPP API endpoint.
|
||||
"""
|
||||
self.api_url = api_url
|
||||
self.model_name = model_name
|
||||
|
||||
def __call__(self, texts: Documents) -> Embeddings:
|
||||
"""
|
||||
Embed a list of input texts using the KoboldCPP embeddings endpoint.
|
||||
"""
|
||||
|
||||
log.debug("KoboldCppEmbeddingFunction", api_url=self.api_url, model_name=self.model_name)
|
||||
|
||||
# Prepare the request payload for KoboldCPP. Include model name if required.
|
||||
payload = {"input": texts}
|
||||
if self.model_name is not None:
|
||||
payload["model"] = self.model_name # e.g. the model's name/ID if needed
|
||||
|
||||
# Send POST request to the local KoboldCPP embeddings endpoint
|
||||
response = requests.post(self.api_url, json=payload)
|
||||
response.raise_for_status() # Throw an error if the request failed (e.g., connection issue)
|
||||
|
||||
# Parse the JSON response to extract embedding vectors
|
||||
data = response.json()
|
||||
# The 'data' field contains a list of embeddings (one per input)
|
||||
embedding_results = data.get("data", [])
|
||||
embeddings = [item["embedding"] for item in embedding_results]
|
||||
|
||||
return embeddings
|
||||
@register()
|
||||
class KoboldCppClient(ClientBase):
|
||||
auto_determine_prompt_template: bool = True
|
||||
@@ -58,7 +95,7 @@ class KoboldCppClient(ClientBase):
|
||||
kcpp has two apis
|
||||
|
||||
open-ai implementation at /v1
|
||||
their own implenation at /api/v1
|
||||
their own implementation at /api/v1
|
||||
"""
|
||||
return "/api/v1" not in self.api_url
|
||||
|
||||
@@ -77,8 +114,8 @@ class KoboldCppClient(ClientBase):
|
||||
# join /v1/completions
|
||||
return urljoin(self.api_url, "completions")
|
||||
else:
|
||||
# join /api/v1/generate
|
||||
return urljoin(self.api_url, "generate")
|
||||
# join /api/extra/generate/stream
|
||||
return urljoin(self.api_url.replace("v1", "extra"), "generate/stream")
|
||||
|
||||
@property
|
||||
def max_tokens_param_name(self):
|
||||
@@ -132,6 +169,21 @@ class KoboldCppClient(ClientBase):
|
||||
"temperature",
|
||||
]
|
||||
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def embeddings_url(self) -> str:
|
||||
if self.is_openai:
|
||||
return urljoin(self.api_url, "embeddings")
|
||||
else:
|
||||
return urljoin(self.api_url, "api/extra/embeddings")
|
||||
|
||||
@property
|
||||
def embeddings_function(self):
|
||||
return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)
|
||||
|
||||
def api_endpoint_specified(self, url: str) -> bool:
|
||||
return "/v1" in self.api_url
|
||||
|
||||
@@ -152,14 +204,62 @@ class KoboldCppClient(ClientBase):
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
async def get_model_name(self):
|
||||
self.ensure_api_endpoint_specified()
|
||||
async def get_embeddings_model_name(self):
|
||||
# if self._embeddings_model_name is set, return it
|
||||
if self.embeddings_model_name:
|
||||
return self.embeddings_model_name
|
||||
|
||||
# otherwise, get the model name by doing a request to
|
||||
# the embeddings endpoint with a single character
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.api_url_for_model,
|
||||
response = await client.post(
|
||||
self.embeddings_url,
|
||||
json={"input": ["test"]},
|
||||
timeout=2,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
self._embeddings_model_name = response_data.get("model")
|
||||
return self._embeddings_model_name
|
||||
|
||||
async def get_embeddings_status(self):
|
||||
url_version = urljoin(self.api_url, "api/extra/version")
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url_version, timeout=2)
|
||||
response_data = response.json()
|
||||
self._embeddings_status = response_data.get("embeddings", False)
|
||||
|
||||
if not self.embeddings_status or self.embeddings_model_name:
|
||||
return
|
||||
|
||||
await self.get_embeddings_model_name()
|
||||
|
||||
log.debug("KoboldCpp embeddings are enabled, suggesting embeddings", model_name=self.embeddings_model_name)
|
||||
|
||||
self.set_embeddings()
|
||||
|
||||
await async_signals.get("client.embeddings_available").send(
|
||||
ClientEmbeddingsStatus(
|
||||
client=self,
|
||||
embedding_name=self.embeddings_model_name,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_model_name(self):
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.api_url_for_model,
|
||||
timeout=2,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
except Exception:
|
||||
self._embeddings_model_name = None
|
||||
raise
|
||||
|
||||
if response.status_code == 404:
|
||||
raise KeyError(f"Could not find model info at: {self.api_url_for_model}")
|
||||
@@ -175,6 +275,8 @@ class KoboldCppClient(ClientBase):
|
||||
# split by "/" and take last
|
||||
if model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
await self.get_embeddings_status()
|
||||
|
||||
return model_name
|
||||
|
||||
@@ -223,11 +325,48 @@ class KoboldCppClient(ClientBase):
|
||||
url_abort,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
if self.is_openai:
|
||||
return await self._generate_openai(prompt, parameters, kind)
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self._generate_kcpp_stream, prompt, parameters, kind)
|
||||
|
||||
def _generate_kcpp_stream(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
parameters["prompt"] = prompt.strip(" ")
|
||||
|
||||
response = ""
|
||||
parameters["stream"] = True
|
||||
stream_response = requests.post(
|
||||
self.api_url_for_generation,
|
||||
json=parameters,
|
||||
timeout=None,
|
||||
headers=self.request_headers,
|
||||
stream=True,
|
||||
)
|
||||
stream_response.raise_for_status()
|
||||
|
||||
sse = sseclient.SSEClient(stream_response)
|
||||
|
||||
for event in sse.events():
|
||||
payload = json.loads(event.data)
|
||||
chunk = payload['token']
|
||||
response += chunk
|
||||
self.update_request_tokens(self.count_tokens(chunk))
|
||||
|
||||
return response
|
||||
|
||||
async def _generate_openai(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
parameters["prompt"] = prompt.strip(" ")
|
||||
|
||||
|
||||
@@ -54,18 +54,55 @@ class LMStudioClient(ClientBase):
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
Generates text from the given prompt and parameters using a streaming
|
||||
request so that token usage can be tracked incrementally via
|
||||
`update_request_tokens`.
|
||||
"""
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[human_message], **parameters
|
||||
# Send the request in streaming mode so we can update token counts
|
||||
stream = await self.client.completions.create(
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
response = ""
|
||||
|
||||
# Iterate over streamed chunks and accumulate the response while
|
||||
# incrementally updating the token counter
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
content_piece = chunk.choices[0].text
|
||||
response += content_piece
|
||||
# Track token usage incrementally
|
||||
self.update_request_tokens(self.count_tokens(content_piece))
|
||||
|
||||
# Store overall token accounting once the stream is finished
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
return ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
"""Count tokens in a model response string."""
|
||||
return self.count_tokens(response)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
"""Count tokens in a prompt string."""
|
||||
return self.count_tokens(prompt)
|
||||
|
||||
@@ -4,9 +4,14 @@ from typing import Literal
|
||||
from mistralai import Mistral
|
||||
from mistralai.models.sdkerror import SDKError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.client.remote import (
|
||||
EndpointOverride,
|
||||
EndpointOverrideMixin,
|
||||
endpoint_override_extra_fields,
|
||||
)
|
||||
from talemate.config import Client as BaseClientConfig, load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
@@ -35,14 +40,15 @@ JSON_OBJECT_RESPONSE_MODELS = [
|
||||
]
|
||||
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "open-mixtral-8x22b"
|
||||
|
||||
|
||||
class ClientConfig(EndpointOverride, BaseClientConfig):
|
||||
pass
|
||||
|
||||
@register()
|
||||
class MistralAIClient(ClientBase):
|
||||
class MistralAIClient(EndpointOverrideMixin, ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
@@ -52,6 +58,7 @@ class MistralAIClient(ClientBase):
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = True
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "MistralAI"
|
||||
@@ -60,16 +67,18 @@ class MistralAIClient(ClientBase):
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
|
||||
|
||||
def __init__(self, model="open-mixtral-8x22b", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def mistralai_api_key(self):
|
||||
def mistral_api_key(self):
|
||||
return self.config.get("mistralai", {}).get("api_key")
|
||||
|
||||
@property
|
||||
@@ -85,7 +94,7 @@ class MistralAIClient(ClientBase):
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.mistralai_api_key:
|
||||
if self.mistral_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
@@ -122,7 +131,7 @@ class MistralAIClient(ClientBase):
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.mistralai_api_key:
|
||||
if not self.mistral_api_key and not self.endpoint_override_base_url_configured:
|
||||
self.client = Mistral(api_key="sk-1111")
|
||||
log.error("No mistral.ai API key set")
|
||||
if self.api_key_status:
|
||||
@@ -139,7 +148,7 @@ class MistralAIClient(ClientBase):
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = Mistral(api_key=self.mistralai_api_key)
|
||||
self.client = Mistral(api_key=self.api_key, server_url=self.base_url)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
@@ -160,7 +169,8 @@ class MistralAIClient(ClientBase):
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
@@ -201,7 +211,7 @@ class MistralAIClient(ClientBase):
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.mistralai_api_key:
|
||||
if not self.mistral_api_key:
|
||||
raise Exception("No mistral.ai API key set")
|
||||
|
||||
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
@@ -224,22 +234,36 @@ class MistralAIClient(ClientBase):
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
base_url=self.base_url,
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.complete_async(
|
||||
event_stream = await self.client.chat.stream_async(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
response = ""
|
||||
|
||||
completion_tokens = 0
|
||||
prompt_tokens = 0
|
||||
|
||||
response = response.choices[0].message.content
|
||||
async for event in event_stream:
|
||||
if event.data.choices:
|
||||
response += event.data.choices[0].delta.content
|
||||
self.update_request_tokens(self.count_tokens(event.data.choices[0].delta.content))
|
||||
if event.data.usage:
|
||||
completion_tokens += event.data.usage.completion_tokens
|
||||
prompt_tokens += event.data.usage.prompt_tokens
|
||||
|
||||
self._returned_prompt_tokens = prompt_tokens
|
||||
self._returned_response_tokens = completion_tokens
|
||||
|
||||
#response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
|
||||
@@ -129,6 +129,12 @@ class ModelPrompt:
|
||||
|
||||
return prompt
|
||||
|
||||
def clean_model_name(self, model_name: str):
|
||||
"""
|
||||
Clean the model name to be used in the template file name.
|
||||
"""
|
||||
return model_name.replace("/", "__").replace(":", "_")
|
||||
|
||||
def get_template(self, model_name: str):
|
||||
"""
|
||||
Will attempt to load an LLM prompt template - this supports
|
||||
@@ -137,7 +143,7 @@ class ModelPrompt:
|
||||
|
||||
matches = []
|
||||
|
||||
cleaned_model_name = model_name.replace("/", "__")
|
||||
cleaned_model_name = self.clean_model_name(model_name)
|
||||
|
||||
# Iterate over all templates in the loader's directory
|
||||
for template_name in self.env.list_templates():
|
||||
@@ -166,7 +172,7 @@ class ModelPrompt:
|
||||
|
||||
template_name = template_name.split(".jinja2")[0]
|
||||
|
||||
cleaned_model_name = model_name.replace("/", "__")
|
||||
cleaned_model_name = self.clean_model_name(model_name)
|
||||
|
||||
shutil.copyfile(
|
||||
os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
|
||||
|
||||
313
src/talemate/client/ollama.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import asyncio
|
||||
import structlog
|
||||
import httpx
|
||||
import ollama
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
from talemate.client.base import STOPPING_STRINGS, ClientBase, CommonDefaults, ErrorAction, ParameterReroute, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
|
||||
log = structlog.get_logger("talemate.client.ollama")
|
||||
|
||||
|
||||
FETCH_MODELS_INTERVAL = 15
|
||||
|
||||
class OllamaClientDefaults(CommonDefaults):
|
||||
api_url: str = "http://localhost:11434" # Default Ollama URL
|
||||
model: str = "" # Allow empty default, will fetch from Ollama
|
||||
api_handles_prompt_template: bool = False
|
||||
allow_thinking: bool = False
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
api_handles_prompt_template: bool = False
|
||||
allow_thinking: bool = False
|
||||
|
||||
@register()
|
||||
class OllamaClient(ClientBase):
|
||||
"""
|
||||
Ollama client for generating text using locally hosted models.
|
||||
"""
|
||||
|
||||
auto_determine_prompt_template: bool = True
|
||||
client_type = "ollama"
|
||||
conversation_retries = 0
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Ollama"
|
||||
title: str = "Ollama"
|
||||
enable_api_auth: bool = False
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = [] # Will be overridden by finalize_status
|
||||
defaults: OllamaClientDefaults = OllamaClientDefaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"api_handles_prompt_template": ExtraField(
|
||||
name="api_handles_prompt_template",
|
||||
type="bool",
|
||||
label="API handles prompt template",
|
||||
required=False,
|
||||
description="Let Ollama handle the prompt template. Only do this if you don't know which prompt template to use. Letting talemate handle the prompt template will generally lead to improved responses.",
|
||||
),
|
||||
"allow_thinking": ExtraField(
|
||||
name="allow_thinking",
|
||||
type="bool",
|
||||
label="Allow thinking",
|
||||
required=False,
|
||||
description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
# Parameters supported by Ollama's generate endpoint
|
||||
# Based on the API documentation
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"min_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
ParameterReroute(
|
||||
talemate_parameter="repetition_penalty",
|
||||
client_parameter="repeat_penalty"
|
||||
),
|
||||
ParameterReroute(
|
||||
talemate_parameter="max_tokens",
|
||||
client_parameter="num_predict"
|
||||
),
|
||||
"stopping_strings",
|
||||
# internal parameters that will be removed before sending
|
||||
"extra_stopping_strings",
|
||||
]
|
||||
|
||||
@property
|
||||
def can_be_coerced(self):
|
||||
"""
|
||||
Determines whether or not his client can pass LLM coercion. (e.g., is able
|
||||
to predefine partial LLM output in the prompt)
|
||||
"""
|
||||
return not self.api_handles_prompt_template
|
||||
|
||||
@property
|
||||
def can_think(self) -> bool:
|
||||
"""
|
||||
Allow reasoning models to think before responding.
|
||||
"""
|
||||
return self.allow_thinking
|
||||
|
||||
def __init__(self, model=None, api_handles_prompt_template=False, allow_thinking=False, **kwargs):
|
||||
self.model_name = model
|
||||
self.api_handles_prompt_template = api_handles_prompt_template
|
||||
self.allow_thinking = allow_thinking
|
||||
self._available_models = []
|
||||
self._models_last_fetched = 0
|
||||
self.client = None
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
"""
|
||||
Initialize the Ollama client with the API URL.
|
||||
"""
|
||||
# Update model if provided
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
|
||||
# Create async client with the configured API URL
|
||||
# Ollama's AsyncClient expects just the base URL without any path
|
||||
self.client = ollama.AsyncClient(host=self.api_url)
|
||||
self.api_handles_prompt_template = kwargs.get(
|
||||
"api_handles_prompt_template", self.api_handles_prompt_template
|
||||
)
|
||||
self.allow_thinking = kwargs.get(
|
||||
"allow_thinking", self.allow_thinking
|
||||
)
|
||||
|
||||
async def status(self):
|
||||
"""
|
||||
Send a request to the API to retrieve the loaded AI model name.
|
||||
Raises an error if no model name is returned.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
if self.processing:
|
||||
self.emit_status()
|
||||
return
|
||||
|
||||
if not self.enabled:
|
||||
self.connected = False
|
||||
self.emit_status()
|
||||
return
|
||||
|
||||
try:
|
||||
# instead of using the client (which apparently cannot set a timeout per endpoint)
|
||||
# we use httpx to check {api_url}/api/version to see if the server is running
|
||||
# use a timeout of 2 seconds
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{self.api_url}/api/version", timeout=2)
|
||||
response.raise_for_status()
|
||||
|
||||
# if the server is running, fetch the available models
|
||||
await self.fetch_available_models()
|
||||
except Exception as e:
|
||||
log.error("Failed to fetch models from Ollama", error=str(e))
|
||||
self.connected = False
|
||||
self.emit_status()
|
||||
return
|
||||
|
||||
await super().status()
|
||||
|
||||
async def fetch_available_models(self):
|
||||
"""
|
||||
Fetch list of available models from Ollama.
|
||||
"""
|
||||
if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
|
||||
return self._available_models
|
||||
|
||||
response = await self.client.list()
|
||||
models = response.get("models", [])
|
||||
model_names = [model.model for model in models]
|
||||
self._available_models = sorted(model_names)
|
||||
self._models_last_fetched = time.time()
|
||||
return self._available_models
|
||||
|
||||
def finalize_status(self, data: dict):
|
||||
"""
|
||||
Finalizes the status data for the client.
|
||||
"""
|
||||
data["manual_model_choices"] = self._available_models
|
||||
return data
|
||||
|
||||
async def get_model_name(self):
|
||||
return self.model_name
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if not self.api_handles_prompt_template:
|
||||
return super().prompt_template(system_message, prompt)
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
"""
|
||||
Tune parameters for Ollama's generate endpoint.
|
||||
"""
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
# Build stopping strings list
|
||||
parameters["stop"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
|
||||
|
||||
# Ollama uses num_predict instead of max_tokens
|
||||
if "max_tokens" in parameters:
|
||||
parameters["num_predict"] = parameters["max_tokens"]
|
||||
|
||||
def clean_prompt_parameters(self, parameters: dict):
|
||||
"""
|
||||
Clean and prepare parameters for Ollama API.
|
||||
"""
|
||||
# First let parent class handle parameter reroutes and cleanup
|
||||
super().clean_prompt_parameters(parameters)
|
||||
|
||||
# Remove our internal parameters
|
||||
if "extra_stopping_strings" in parameters:
|
||||
del parameters["extra_stopping_strings"]
|
||||
if "stopping_strings" in parameters:
|
||||
del parameters["stopping_strings"]
|
||||
if "stream" in parameters:
|
||||
del parameters["stream"]
|
||||
|
||||
# Remove max_tokens as we've already converted it to num_predict
|
||||
if "max_tokens" in parameters:
|
||||
del parameters["max_tokens"]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generate text using Ollama's generate endpoint.
|
||||
"""
|
||||
if not self.model_name:
|
||||
# Try to get a model name
|
||||
await self.get_model_name()
|
||||
if not self.model_name:
|
||||
raise Exception("No model specified or available in Ollama")
|
||||
|
||||
# Prepare options for Ollama
|
||||
options = parameters
|
||||
|
||||
options["num_ctx"] = self.max_token_length
|
||||
|
||||
try:
|
||||
# Use generate endpoint for completion
|
||||
stream = await self.client.generate(
|
||||
model=self.model_name,
|
||||
prompt=prompt.strip(),
|
||||
options=options,
|
||||
raw=self.can_be_coerced,
|
||||
think=self.can_think,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
response = ""
|
||||
|
||||
async for part in stream:
|
||||
content = part.response
|
||||
response += content
|
||||
self.update_request_tokens(self.count_tokens(content))
|
||||
|
||||
# Extract the response text
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
log.error("Ollama generation error", error=str(e), model=self.model_name)
|
||||
raise ErrorAction(
|
||||
message=f"Ollama generation failed: {str(e)}",
|
||||
title="Generation Error"
|
||||
)
|
||||
|
||||
async def abort_generation(self):
|
||||
"""
|
||||
Ollama doesn't have a direct abort endpoint, but we can try to stop the model.
|
||||
"""
|
||||
# This is a no-op for now as Ollama doesn't expose an abort endpoint
|
||||
# in the Python client
|
||||
pass
|
||||
|
||||
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
|
||||
"""
|
||||
Adjusts temperature and repetition_penalty by random values.
|
||||
"""
|
||||
import random
|
||||
|
||||
temp = prompt_config["temperature"]
|
||||
rep_pen = prompt_config.get("repetition_penalty", 1.0)
|
||||
|
||||
min_offset = offset * 0.3
|
||||
|
||||
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
|
||||
prompt_config["repetition_penalty"] = random.uniform(
|
||||
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
"""
|
||||
Reconfigure the client with new settings.
|
||||
"""
|
||||
# Handle model update
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
|
||||
super().reconfigure(**kwargs)
|
||||
|
||||
# Re-initialize client if API URL changed or model changed
|
||||
if "api_url" in kwargs or "model" in kwargs:
|
||||
self.set_client(**kwargs)
|
||||
|
||||
if "api_handles_prompt_template" in kwargs:
|
||||
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
|
||||
@@ -5,9 +5,14 @@ import structlog
|
||||
import tiktoken
|
||||
from openai import AsyncOpenAI, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
|
||||
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.client.remote import (
|
||||
EndpointOverride,
|
||||
EndpointOverrideMixin,
|
||||
endpoint_override_extra_fields,
|
||||
)
|
||||
from talemate.config import Client as BaseClientConfig, load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
@@ -79,9 +84,6 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
|
||||
elif "gpt-3.5-turbo" in model:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
|
||||
elif "gpt-4" in model or "o1" in model or "o3" in model:
|
||||
print(
|
||||
"Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
|
||||
)
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0613")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -102,13 +104,15 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
|
||||
return num_tokens
|
||||
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "gpt-4o"
|
||||
|
||||
class ClientConfig(EndpointOverride, BaseClientConfig):
|
||||
pass
|
||||
|
||||
@register()
|
||||
class OpenAIClient(ClientBase):
|
||||
class OpenAIClient(EndpointOverrideMixin, ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
@@ -118,7 +122,8 @@ class OpenAIClient(ClientBase):
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "OpenAI"
|
||||
title: str = "OpenAI"
|
||||
@@ -126,10 +131,11 @@ class OpenAIClient(ClientBase):
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
|
||||
def __init__(self, model="gpt-4o", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -192,7 +198,7 @@ class OpenAIClient(ClientBase):
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.openai_api_key:
|
||||
if not self.openai_api_key and not self.endpoint_override_base_url_configured:
|
||||
self.client = AsyncOpenAI(api_key="sk-1111")
|
||||
log.error("No OpenAI API key set")
|
||||
if self.api_key_status:
|
||||
@@ -209,7 +215,7 @@ class OpenAIClient(ClientBase):
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
if model == "gpt-3.5-turbo":
|
||||
self.max_token_length = min(max_token_length or 4096, 4096)
|
||||
elif model == "gpt-4":
|
||||
@@ -247,6 +253,7 @@ class OpenAIClient(ClientBase):
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
self._reconfigure_endpoint_override(**kwargs)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
@@ -278,7 +285,7 @@ class OpenAIClient(ClientBase):
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.openai_api_key:
|
||||
if not self.openai_api_key and not self.endpoint_override_base_url_configured:
|
||||
raise Exception("No OpenAI API key set")
|
||||
|
||||
# only gpt-4-* supports enforcing json object
|
||||
@@ -333,13 +340,28 @@ class OpenAIClient(ClientBase):
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
response = ""
|
||||
|
||||
response = response.choices[0].message.content
|
||||
# Iterate over streamed chunks
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if delta and getattr(delta, "content", None):
|
||||
content_piece = delta.content
|
||||
response += content_piece
|
||||
# Incrementally track token usage
|
||||
self.update_request_tokens(self.count_tokens(content_piece))
|
||||
|
||||
#self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
#self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
|
||||
329
src/talemate/client/openrouter.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
import httpx
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"OpenRouterClient",
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.client.openrouter")
|
||||
|
||||
# Available models will be populated when first client with API key is initialized
|
||||
AVAILABLE_MODELS = []
|
||||
DEFAULT_MODEL = ""
|
||||
MODELS_FETCHED = False
|
||||
|
||||
async def fetch_available_models(api_key: str = None):
|
||||
"""Fetch available models from OpenRouter API"""
|
||||
global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED
|
||||
|
||||
if not api_key:
|
||||
return []
|
||||
|
||||
if MODELS_FETCHED:
|
||||
return AVAILABLE_MODELS
|
||||
|
||||
# Only fetch if we haven't already or if explicitly requested
|
||||
if AVAILABLE_MODELS and not api_key:
|
||||
return AVAILABLE_MODELS
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
"https://openrouter.ai/api/v1/models",
|
||||
timeout=10.0
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
for model in data.get("data", []):
|
||||
model_id = model.get("id")
|
||||
if model_id:
|
||||
models.append(model_id)
|
||||
AVAILABLE_MODELS = sorted(models)
|
||||
log.debug(f"Fetched {len(AVAILABLE_MODELS)} models from OpenRouter")
|
||||
else:
|
||||
log.warning(f"Failed to fetch models from OpenRouter: {response.status_code}")
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching models from OpenRouter: {e}")
|
||||
|
||||
MODELS_FETCHED = True
|
||||
return AVAILABLE_MODELS
|
||||
|
||||
|
||||
|
||||
def fetch_models_sync(event):
|
||||
api_key = event.data.get("openrouter", {}).get("api_key")
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(fetch_available_models(api_key))
|
||||
|
||||
handlers["config_saved"].connect(fetch_models_sync)
|
||||
handlers["talemate_started"].connect(fetch_models_sync)
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = DEFAULT_MODEL
|
||||
|
||||
|
||||
@register()
|
||||
class OpenRouterClient(ClientBase):
|
||||
"""
|
||||
OpenRouter client for generating text using various models.
|
||||
"""
|
||||
|
||||
client_type = "openrouter"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "OpenRouter"
|
||||
title: str = "OpenRouter"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = pydantic.Field(default_factory=lambda: AVAILABLE_MODELS)
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model=None, **kwargs):
|
||||
self.model_name = model or DEFAULT_MODEL
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
self._models_fetched = False
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def can_be_coerced(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def openrouter_api_key(self):
|
||||
return self.config.get("openrouter", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"min_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"repetition_penalty",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.openrouter_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"openrouter_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
data = {
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
data.update(self._common_status_data())
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status if self.enabled else "disabled",
|
||||
data=data,
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
# Unlike other clients, we don't need to set up a client instance
|
||||
# We'll use httpx directly in the generate method
|
||||
|
||||
if not self.openrouter_api_key:
|
||||
log.error("No OpenRouter API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = DEFAULT_MODEL
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
# Set max token length (default to 16k if not specified)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"openrouter set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=self.model_name,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
self._reconfigure_common_parameters(**kwargs)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
async def status(self):
|
||||
# Fetch models if we have an API key and haven't fetched yet
|
||||
if self.openrouter_api_key and not self._models_fetched:
|
||||
self._models_fetched = True
|
||||
# Update the Meta class with new model choices
|
||||
self.Meta.manual_model_choices = AVAILABLE_MODELS
|
||||
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
"""
|
||||
Open-router handles the prompt template internally, so we just
|
||||
give the prompt as is.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters using OpenRouter API.
|
||||
"""
|
||||
|
||||
if not self.openrouter_api_key:
|
||||
raise Exception("No OpenRouter API key set")
|
||||
|
||||
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
|
||||
|
||||
# Prepare messages for chat completion
|
||||
messages = [
|
||||
{"role": "system", "content": self.get_system_message(kind)},
|
||||
{"role": "user", "content": prompt.strip()}
|
||||
]
|
||||
|
||||
if coercion_prompt:
|
||||
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
|
||||
|
||||
# Prepare request payload
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
**parameters
|
||||
}
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
model=self.model_name,
|
||||
)
|
||||
|
||||
response_text = ""
|
||||
buffer = ""
|
||||
completion_tokens = 0
|
||||
prompt_tokens = 0
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.openrouter_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=payload,
|
||||
timeout=120.0 # 2 minute timeout for generation
|
||||
) as response:
|
||||
async for chunk in response.aiter_text():
|
||||
buffer += chunk
|
||||
|
||||
while True:
|
||||
# Find the next complete SSE line
|
||||
line_end = buffer.find('\n')
|
||||
if line_end == -1:
|
||||
break
|
||||
|
||||
line = buffer[:line_end].strip()
|
||||
buffer = buffer[line_end + 1:]
|
||||
|
||||
if line.startswith('data: '):
|
||||
data = line[6:]
|
||||
if data == '[DONE]':
|
||||
break
|
||||
|
||||
try:
|
||||
data_obj = json.loads(data)
|
||||
content = data_obj["choices"][0]["delta"].get("content")
|
||||
usage = data_obj.get("usage", {})
|
||||
completion_tokens += usage.get("completion_tokens", 0)
|
||||
prompt_tokens += usage.get("prompt_tokens", 0)
|
||||
if content:
|
||||
response_text += content
|
||||
# Update tokens as content streams in
|
||||
self.update_request_tokens(self.count_tokens(content))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Extract the response content
|
||||
response_content = response_text
|
||||
self._returned_prompt_tokens = prompt_tokens
|
||||
self._returned_response_tokens = completion_tokens
|
||||
|
||||
return response_content
|
||||
|
||||
except httpx.ConnectTimeout:
|
||||
self.log.error("OpenRouter API timeout")
|
||||
emit("status", message="OpenRouter API: Request timed out", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message=f"OpenRouter API Error: {str(e)}", status="error")
|
||||
raise
|
||||
@@ -1,5 +1,109 @@
|
||||
__all__ = ["RemoteServiceMixin"]
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
from .base import FieldGroup, ExtraField
|
||||
|
||||
import talemate.ux.schema as ux_schema
|
||||
|
||||
|
||||
log = structlog.get_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"RemoteServiceMixin",
|
||||
"EndpointOverrideMixin",
|
||||
"endpoint_override_extra_fields",
|
||||
"EndpointOverrideGroup",
|
||||
"EndpointOverrideField",
|
||||
"EndpointOverrideBaseURLField",
|
||||
"EndpointOverrideAPIKeyField",
|
||||
]
|
||||
|
||||
def endpoint_override_extra_fields():
|
||||
return {
|
||||
"override_base_url": EndpointOverrideBaseURLField(),
|
||||
"override_api_key": EndpointOverrideAPIKeyField(),
|
||||
}
|
||||
|
||||
class EndpointOverride(pydantic.BaseModel):
|
||||
override_base_url: str | None = None
|
||||
override_api_key: str | None = None
|
||||
|
||||
class EndpointOverrideGroup(FieldGroup):
|
||||
name: str = "endpoint_override"
|
||||
label: str = "Endpoint Override"
|
||||
description: str = ("Override the default base URL used by this client to access the {client_type} service API.\n\n"
|
||||
"IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
|
||||
"If the override endpoint requires an API key, enter it below.")
|
||||
icon: str = "mdi-api"
|
||||
|
||||
class EndpointOverrideField(ExtraField):
|
||||
group: EndpointOverrideGroup = pydantic.Field(default_factory=EndpointOverrideGroup)
|
||||
|
||||
class EndpointOverrideBaseURLField(EndpointOverrideField):
|
||||
name: str = "override_base_url"
|
||||
type: str = "text"
|
||||
label: str = "Base URL"
|
||||
required: bool = False
|
||||
description: str = "Override the base URL for the remote service"
|
||||
|
||||
class EndpointOverrideAPIKeyField(EndpointOverrideField):
|
||||
name: str = "override_api_key"
|
||||
type: str = "password"
|
||||
label: str = "API Key"
|
||||
required: bool = False
|
||||
description: str = "Override the API key for the remote service"
|
||||
note: ux_schema.Note = pydantic.Field(default_factory=lambda: ux_schema.Note(
|
||||
text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
|
||||
color="warning",
|
||||
))
|
||||
|
||||
|
||||
class EndpointOverrideMixin:
|
||||
override_base_url: str | None = None
|
||||
override_api_key: str | None = None
|
||||
|
||||
def set_client_api_key(self, api_key: str | None):
|
||||
if getattr(self, "client", None):
|
||||
try:
|
||||
self.client.api_key = api_key
|
||||
except Exception as e:
|
||||
log.error("Error setting client API key", error=e, client=self.client_type)
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
if self.endpoint_override_base_url_configured:
|
||||
return self.override_api_key
|
||||
return getattr(self, f"{self.client_type}_api_key", None)
|
||||
|
||||
@property
|
||||
def base_url(self) -> str | None:
|
||||
if self.override_base_url and self.override_base_url.strip():
|
||||
return self.override_base_url
|
||||
return None
|
||||
|
||||
@property
|
||||
def endpoint_override_base_url_configured(self) -> bool:
|
||||
return self.override_base_url and self.override_base_url.strip()
|
||||
|
||||
@property
|
||||
def endpoint_override_api_key_configured(self) -> bool:
|
||||
return self.override_api_key and self.override_api_key.strip()
|
||||
|
||||
@property
|
||||
def endpoint_override_fully_configured(self) -> bool:
|
||||
return self.endpoint_override_base_url_configured and self.endpoint_override_api_key_configured
|
||||
|
||||
def _reconfigure_endpoint_override(self, **kwargs):
|
||||
if "override_base_url" in kwargs:
|
||||
orig = getattr(self, "override_base_url", None)
|
||||
self.override_base_url = kwargs["override_base_url"]
|
||||
if getattr(self, "client", None) and orig != self.override_base_url:
|
||||
log.info("Reconfiguring client base URL", new=self.override_base_url)
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
if "override_api_key" in kwargs:
|
||||
self.override_api_key = kwargs["override_api_key"]
|
||||
self.set_client_api_key(self.override_api_key)
|
||||
|
||||
class RemoteServiceMixin:
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import random
|
||||
import urllib
|
||||
from typing import Literal
|
||||
import aiohttp
|
||||
import json
|
||||
import httpx
|
||||
import pydantic
|
||||
import structlog
|
||||
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
|
||||
from openai import PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ExtraField, CommonDefaults
|
||||
from talemate.client.registry import register
|
||||
@@ -17,61 +17,6 @@ log = structlog.get_logger("talemate.client.tabbyapi")
|
||||
EXPERIMENTAL_DESCRIPTION = """Use this client to use all of TabbyAPI's features"""
|
||||
|
||||
|
||||
class CustomAPIClient:
|
||||
def __init__(self, base_url, api_key):
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
|
||||
async def get_model_name(self):
|
||||
url = urljoin(self.base_url, "model")
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Request failed: {response.status}")
|
||||
response_data = await response.json()
|
||||
model_name = response_data.get("id")
|
||||
# split by "/" and take last
|
||||
if model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
return model_name
|
||||
|
||||
async def create_chat_completion(self, model, messages, **parameters):
|
||||
url = urljoin(self.base_url, "chat/completions")
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Request failed: {response.status}")
|
||||
return await response.json()
|
||||
|
||||
async def create_completion(self, model, **parameters):
|
||||
url = urljoin(self.base_url, "completions")
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"model": model,
|
||||
**parameters,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Request failed: {response.status}")
|
||||
return await response.json()
|
||||
|
||||
|
||||
class Defaults(CommonDefaults, pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:5000/v1"
|
||||
api_key: str = ""
|
||||
@@ -153,7 +98,6 @@ class TabbyAPIClient(ClientBase):
|
||||
self.api_handles_prompt_template = kwargs.get(
|
||||
"api_handles_prompt_template", self.api_handles_prompt_template
|
||||
)
|
||||
self.client = CustomAPIClient(base_url=self.api_url, api_key=self.api_key)
|
||||
self.model_name = (
|
||||
kwargs.get("model") or kwargs.get("model_name") or self.model_name
|
||||
)
|
||||
@@ -178,49 +122,150 @@ class TabbyAPIClient(ClientBase):
|
||||
return prompt
|
||||
|
||||
async def get_model_name(self):
|
||||
return await self.client.get_model_name()
|
||||
url = urljoin(self.api_url, "model")
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, timeout=10.0)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Request failed: {response.status_code}")
|
||||
response_data = response.json()
|
||||
model_name = response_data.get("id")
|
||||
# split by "/" and take last
|
||||
if model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
return model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
Generates text from the given prompt and parameters using streaming responses.
|
||||
"""
|
||||
|
||||
# Determine whether we are using chat or completions endpoint
|
||||
is_chat = self.api_handles_prompt_template
|
||||
|
||||
try:
|
||||
if self.api_handles_prompt_template:
|
||||
# Custom API handles prompt template
|
||||
# Use the chat completions endpoint
|
||||
if is_chat:
|
||||
# Chat completions endpoint
|
||||
self.log.debug(
|
||||
"generate (chat/completions)",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
response = await self.client.create_chat_completion(
|
||||
self.model_name, [human_message], **parameters
|
||||
)
|
||||
response = response["choices"][0]["message"]["content"]
|
||||
return self.process_response_for_indirect_coercion(prompt, response)
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"messages": [human_message],
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
**parameters,
|
||||
}
|
||||
endpoint = "chat/completions"
|
||||
else:
|
||||
# Talemate handles prompt template
|
||||
# Use the completions endpoint
|
||||
# Completions endpoint
|
||||
self.log.debug(
|
||||
"generate (completions)",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
)
|
||||
parameters["prompt"] = prompt
|
||||
response = await self.client.create_completion(
|
||||
self.model_name, **parameters
|
||||
)
|
||||
return response["choices"][0]["text"]
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
**parameters,
|
||||
}
|
||||
endpoint = "completions"
|
||||
|
||||
url = urljoin(self.api_url, endpoint)
|
||||
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response_text = ""
|
||||
buffer = ""
|
||||
completion_tokens = 0
|
||||
prompt_tokens = 0
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=120.0
|
||||
) as response:
|
||||
async for chunk in response.aiter_text():
|
||||
buffer += chunk
|
||||
|
||||
while True:
|
||||
line_end = buffer.find('\n')
|
||||
if line_end == -1:
|
||||
break
|
||||
|
||||
line = buffer[:line_end].strip()
|
||||
buffer = buffer[line_end + 1:]
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
if data == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
data_obj = json.loads(data)
|
||||
|
||||
choice = data_obj.get("choices", [{}])[0]
|
||||
|
||||
# Chat completions use delta -> content.
|
||||
delta = choice.get("delta", {})
|
||||
content = (
|
||||
delta.get("content")
|
||||
or delta.get("text")
|
||||
or choice.get("text")
|
||||
)
|
||||
|
||||
usage = data_obj.get("usage", {})
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
|
||||
if content:
|
||||
response_text += content
|
||||
self.update_request_tokens(self.count_tokens(content))
|
||||
except json.JSONDecodeError:
|
||||
# ignore malformed json chunks
|
||||
pass
|
||||
|
||||
# Save token stats for logging
|
||||
self._returned_prompt_tokens = prompt_tokens
|
||||
self._returned_response_tokens = completion_tokens
|
||||
|
||||
if is_chat:
|
||||
# Process indirect coercion
|
||||
response_text = self.process_response_for_indirect_coercion(prompt, response_text)
|
||||
|
||||
return response_text
|
||||
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="Client API: Permission Denied", status="error")
|
||||
return ""
|
||||
except httpx.ConnectTimeout:
|
||||
self.log.error("API timeout")
|
||||
emit("status", message="TabbyAPI: Request timed out", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
emit("status", message="Error during generation (check logs)", status="error")
|
||||
return ""
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
|
||||
@@ -195,7 +195,8 @@ class TextGeneratorWebuiClient(ClientBase):
|
||||
payload = json.loads(event.data)
|
||||
chunk = payload['choices'][0]['text']
|
||||
response += chunk
|
||||
|
||||
self.update_request_tokens(self.count_tokens(chunk))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,10 @@ class CmdSetEnvironmentToScene(TalemateCommand):
|
||||
player_character = self.scene.get_player_character()
|
||||
|
||||
if not player_character:
|
||||
self.system_message("No player character found")
|
||||
self.system_message("No characters found - cannot switch to gameplay mode.", meta={
|
||||
"icon": "mdi-alert",
|
||||
"color": "warning",
|
||||
})
|
||||
return True
|
||||
|
||||
self.scene.set_environment("scene")
|
||||
|
||||
@@ -93,6 +93,7 @@ class General(BaseModel):
|
||||
auto_save: bool = True
|
||||
auto_progress: bool = True
|
||||
max_backscroll: int = 512
|
||||
add_default_character: bool = True
|
||||
|
||||
|
||||
class StateReinforcementTemplate(BaseModel):
|
||||
@@ -161,6 +162,9 @@ class DeepSeekConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
|
||||
class OpenRouterConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
class RunPodConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
@@ -177,6 +181,7 @@ class CoquiConfig(BaseModel):
|
||||
class GoogleConfig(BaseModel):
|
||||
gcloud_credentials_path: Union[str, None] = None
|
||||
gcloud_location: Union[str, None] = None
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
|
||||
class TTSVoiceSamples(BaseModel):
|
||||
@@ -209,6 +214,7 @@ class EmbeddingFunctionPreset(BaseModel):
|
||||
gpu_recommendation: bool = False
|
||||
local: bool = True
|
||||
custom: bool = False
|
||||
client: str | None = None
|
||||
|
||||
|
||||
|
||||
@@ -506,6 +512,8 @@ class Config(BaseModel):
|
||||
|
||||
anthropic: AnthropicConfig = AnthropicConfig()
|
||||
|
||||
openrouter: OpenRouterConfig = OpenRouterConfig()
|
||||
|
||||
cohere: CohereConfig = CohereConfig()
|
||||
|
||||
groq: GroqConfig = GroqConfig()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
__all__ = [
|
||||
"ArchiveEntry",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchiveEntry:
|
||||
text: str
|
||||
start: int = None
|
||||
end: int = None
|
||||
ts: str = None
|
||||
@@ -25,7 +25,7 @@ class AsyncSignal:
|
||||
async def send(self, emission):
|
||||
for receiver in self.receivers:
|
||||
await receiver(emission)
|
||||
|
||||
|
||||
|
||||
def _register(name: str):
|
||||
"""
|
||||
|
||||
@@ -180,11 +180,11 @@ class Emitter:
|
||||
def setup_emitter(self, scene: Scene = None):
|
||||
self.emit_for_scene = scene
|
||||
|
||||
def emit(self, typ: str, message: str, character: Character = None):
|
||||
emit(typ, message, character=character, scene=self.emit_for_scene)
|
||||
def emit(self, typ: str, message: str, character: Character = None, **kwargs):
|
||||
emit(typ, message, character=character, scene=self.emit_for_scene, **kwargs)
|
||||
|
||||
def system_message(self, message: str):
|
||||
self.emit("system", message)
|
||||
def system_message(self, message: str, **kwargs):
|
||||
self.emit("system", message, **kwargs)
|
||||
|
||||
def narrator_message(self, message: str):
|
||||
self.emit("narrator", message)
|
||||
|
||||
@@ -49,6 +49,8 @@ SpiceApplied = signal("spice_applied")
|
||||
|
||||
WorldSateManager = signal("world_state_manager")
|
||||
|
||||
TalemateStarted = signal("talemate_started")
|
||||
|
||||
handlers = {
|
||||
"system": SystemMessage,
|
||||
"narrator": NarratorMessage,
|
||||
@@ -86,4 +88,5 @@ handlers = {
|
||||
"memory_request": MemoryRequest,
|
||||
"player_choice": PlayerChoiceMessage,
|
||||
"world_state_manager": WorldSateManager,
|
||||
"talemate_started": TalemateStarted,
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
import pydantic
|
||||
|
||||
import talemate.emit.async_signals as async_signals
|
||||
|
||||
@@ -29,10 +28,9 @@ class HistoryEvent(Event):
|
||||
@dataclass
|
||||
class ArchiveEvent(Event):
|
||||
text: str
|
||||
memory_id: str = None
|
||||
memory_id: str
|
||||
ts: str = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CharacterStateEvent(Event):
|
||||
state: str
|
||||
|
||||