Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de16feeed5 | ||
|
|
cdcc804ffa | ||
|
|
9a2bbd78a4 | ||
|
|
ddfbd6891b | ||
|
|
143dd47e02 | ||
|
|
cc7cb773d1 | ||
|
|
02c88f75a1 | ||
|
|
419371e0fb | ||
|
|
6e847bf283 | ||
|
|
ceedd3019f | ||
|
|
a28cf2a029 | ||
|
|
60cb271e30 | ||
|
|
1874234d2c | ||
|
|
ef99539e69 | ||
|
|
39bd02722d | ||
|
|
f0b627b900 | ||
|
|
95ae00e01f | ||
|
|
83027b3a0f | ||
|
|
27eba3bd63 | ||
|
|
ba64050eab | ||
|
|
199ffd1095 | ||
|
|
88b9fcb8bb | ||
|
|
2f5944bc09 | ||
|
|
abdfb1abbf | ||
|
|
2f07248211 | ||
|
|
9ae6fc822b | ||
|
|
5094359c4e | ||
|
|
28801b54bf | ||
|
|
4d69f0e837 | ||
|
|
d91b3f8042 | ||
|
|
03a0ab2fcf | ||
|
|
d860d62972 | ||
|
|
add4893939 | ||
|
|
eb251d6e37 | ||
|
|
4ba635497b | ||
|
|
bdbf14c1ed | ||
|
|
c340fc085c | ||
|
|
94f8d0f242 | ||
|
|
1d8a9b113c | ||
|
|
1837796852 | ||
|
|
c5c53c056e | ||
|
|
f1b1190f0b |
1
.gitignore
vendored
@@ -16,3 +16,4 @@ scenes/
|
||||
!scenes/infinity-quest-dynamic-scenario/infinity-quest.json
|
||||
!scenes/infinity-quest/assets/
|
||||
!scenes/infinity-quest/infinity-quest.json
|
||||
tts_voice_samples/*.wav
|
||||
25
Dockerfile.backend
Normal file
@@ -0,0 +1,25 @@
|
||||
# Use an official Python runtime as a parent image
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the current directory contents into the container at /app
|
||||
COPY ./src /app/src
|
||||
|
||||
# Copy poetry files
|
||||
COPY pyproject.toml /app/
|
||||
# If there's a poetry lock file, include the following line
|
||||
COPY poetry.lock /app/
|
||||
|
||||
# Install poetry
|
||||
RUN pip install poetry
|
||||
|
||||
# Install dependencies
|
||||
RUN poetry install --no-dev
|
||||
|
||||
# Make port 5050 available to the world outside this container
|
||||
EXPOSE 5050
|
||||
|
||||
# Run backend server
|
||||
CMD ["poetry", "run", "python", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050"]
|
||||
23
Dockerfile.frontend
Normal file
@@ -0,0 +1,23 @@
|
||||
# Use an official node runtime as a parent image
|
||||
FROM node:20
|
||||
|
||||
# Make sure we are in a development environment (this isn't a production ready Dockerfile)
|
||||
ENV NODE_ENV=development
|
||||
|
||||
# Echo that this isn't a production ready Dockerfile
|
||||
RUN echo "This Dockerfile is not production ready. It is intended for development purposes only."
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the frontend directory contents into the container at /app
|
||||
COPY ./talemate_frontend /app
|
||||
|
||||
# Install all dependencies
|
||||
RUN npm install
|
||||
|
||||
# Make port 8080 available to the world outside this container
|
||||
EXPOSE 8080
|
||||
|
||||
# Run frontend server
|
||||
CMD ["npm", "run", "serve"]
|
||||
275
README.md
@@ -1,70 +1,68 @@
|
||||
# Talemate
|
||||
|
||||
Allows you to play roleplay scenarios with large language models.
|
||||
Roleplay with AI with a focus on strong narration and consistent world and game state tracking.
|
||||
|
||||
|
||||
|||
|
||||
|||
|
||||
|------------------------------------------|------------------------------------------|
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|
||||
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio.**
|
||||
Supported APIs:
|
||||
- [OpenAI](https://platform.openai.com/overview)
|
||||
- [Anthropic](https://www.anthropic.com/)
|
||||
- [mistral.ai](https://mistral.ai/)
|
||||
- [Cohere](https://www.cohere.com/)
|
||||
- [Groq](https://www.groq.com/)
|
||||
- [Google Gemini](https://console.cloud.google.com/)
|
||||
|
||||
This means you need to either have:
|
||||
- an [OpenAI](https://platform.openai.com/overview) api key
|
||||
- OR setup local (or remote via runpod) LLM inference via one of these options:
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
Supported self-hosted APIs:
|
||||
- [KoboldCpp](https://koboldai.org/cpp) ([Local](https://koboldai.org/cpp), [Runpod](https://koboldai.org/runpodcpp), [VastAI](https://koboldai.org/vastcpp), also includes image gen support)
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
|
||||
## Current features
|
||||
Generic OpenAI api implementations (tested and confirmed working):
|
||||
- [DeepInfra](https://deepinfra.com/)
|
||||
- [llamacpp](https://github.com/ggerganov/llama.cpp) with the `api_like_OAI.py` wrapper
|
||||
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
|
||||
|
||||
- responive modern ui
|
||||
- agents
|
||||
- conversation: handles character dialogue
|
||||
- narration: handles narrative exposition
|
||||
- summarization: handles summarization to compress context while maintain history
|
||||
- director: can be used to direct the story / characters
|
||||
- editor: improves AI responses (very hit and miss at the moment)
|
||||
- world state: generates world snapshot and handles passage of time (objects and characters)
|
||||
- creator: character / scenario creator
|
||||
- tts: text to speech via elevenlabs, coqui studio, coqui local
|
||||
- multi-client support (agents can be connected to separate APIs)
|
||||
- long term memory
|
||||
- chromadb integration
|
||||
- passage of time
|
||||
- narrative world state
|
||||
- Automatically keep track and reinforce selected character and world truths / states.
|
||||
- narrative tools
|
||||
- creative tools
|
||||
- AI backed character creation with template support (jinja2)
|
||||
- AI backed scenario creation
|
||||
- context managegement
|
||||
- Manage character details and attributes
|
||||
- Manage world information / past events
|
||||
- Pin important information to the context (Manually or conditionally through AI)
|
||||
- runpod integration
|
||||
- overridable templates for all prompts. (jinja2)
|
||||
## Core Features
|
||||
|
||||
## Planned features
|
||||
- Multiple AI agents for dialogue, narration, summarization, direction, editing, world state management, character/scenario creation, text-to-speech, and visual generation
|
||||
- Support for multiple AI clients and APIs
|
||||
- Long-term memory using ChromaDB and passage of time tracking
|
||||
- Narrative world state management to reinforce character and world truths
|
||||
- Creative tools for managing NPCs, AI-assisted character, and scenario creation with template support
|
||||
- Context management for character details, world information, past events, and pinned information
|
||||
- Integration with Runpod
|
||||
- Customizable templates for all prompts using Jinja2
|
||||
- Modern, responsive UI
|
||||
|
||||
Kinda making it up as i go along, but i want to lean more into gameplay through AI, keeping track of gamestates, moving away from simply roleplaying towards a more game-ified experience.
|
||||
# Instructions
|
||||
|
||||
In no particular order:
|
||||
Please read the documents in the `docs` folder for more advanced configuration and usage.
|
||||
|
||||
|
||||
- Extension support
|
||||
- modular agents and clients
|
||||
- Improved world state
|
||||
- Dynamic player choice generation
|
||||
- Better creative tools
|
||||
- node based scenario / character creation
|
||||
- Improved and consistent long term memory and accurate current state of the world
|
||||
- Improved director agent
|
||||
- Right now this doesn't really work well on anything but GPT-4 (and even there it's debatable). It tends to steer the story in a way that introduces pacing issues. It needs a model that is creative but also reasons really well i think.
|
||||
- Gameplay loop governed by AI
|
||||
- objectives
|
||||
- quests
|
||||
- win / lose conditions
|
||||
- stable-diffusion client for in place visual generation
|
||||
- [Quickstart](#quickstart)
|
||||
- [Installation](#installation)
|
||||
- [Windows](#windows)
|
||||
- [Linux](#linux)
|
||||
- [Docker](#docker)
|
||||
- [Connecting to an LLM](#connecting-to-an-llm)
|
||||
- [OpenAI / mistral.ai / Anthropic](#openai--mistralai--anthropic)
|
||||
- [Text-generation-webui / LMStudio](#text-generation-webui--lmstudio)
|
||||
- [Specifying the correct prompt template](#specifying-the-correct-prompt-template)
|
||||
- [Recommended Models](#recommended-models)
|
||||
- [DeepInfra via OpenAI Compatible client](#deepinfra-via-openai-compatible-client)
|
||||
- [Google Gemini](#google-gemini)
|
||||
- [Google Cloud Setup](#google-cloud-setup)
|
||||
- [Ready to go](#ready-to-go)
|
||||
- [Load the introductory scenario "Infinity Quest"](#load-the-introductory-scenario-infinity-quest)
|
||||
- [Loading character cards](#loading-character-cards)
|
||||
- [Configure for hosting](#configure-for-hosting)
|
||||
- [Text-to-Speech (TTS)](docs/tts.md)
|
||||
- [Visual Generation](docs/visual.md)
|
||||
- [ChromaDB (long term memory) configuration](docs/chromadb.md)
|
||||
- [Runpod Integration](docs/runpod.md)
|
||||
- [Prompt template overrides](docs/templates.md)
|
||||
|
||||
# Quickstart
|
||||
|
||||
@@ -77,7 +75,7 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
|
||||
### Windows
|
||||
|
||||
1. Download and install Python 3.10 or Python 3.11 from the [official Python website](https://www.python.org/downloads/windows/). :warning: python3.12 is currently not supported.
|
||||
1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm.
|
||||
1. Download and install Node.js v20 from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm. :warning: v21 is currently not supported.
|
||||
1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
|
||||
1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
|
||||
1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
|
||||
@@ -87,70 +85,149 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
|
||||
|
||||
`python 3.10` or `python 3.11` is required. :warning: `python 3.12` not supported yet.
|
||||
|
||||
1. `git clone git@github.com:vegu-ai/talemate`
|
||||
`nodejs v19 or v20` :warning: `v21` not supported yet.
|
||||
|
||||
1. `git clone https://github.com/vegu-ai/talemate.git`
|
||||
1. `cd talemate`
|
||||
1. `source install.sh`
|
||||
1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
|
||||
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.
|
||||
|
||||
## Configuration
|
||||
### Docker
|
||||
|
||||
### OpenAI
|
||||
:warning: Some users currently experience issues with missing dependencies inside the docker container, issue tracked at [#114](https://github.com/vegu-ai/talemate/issues/114)
|
||||
|
||||
To set your openai api key, open `config.yaml` in any text editor and uncomment / add
|
||||
1. `git clone https://github.com/vegu-ai/talemate.git`
|
||||
1. `cd talemate`
|
||||
1. `cp config.example.yaml config.yaml`
|
||||
1. `docker compose up`
|
||||
1. Navigate your browser to http://localhost:8080
|
||||
|
||||
```yaml
|
||||
openai:
|
||||
api_key: sk-my-api-key-goes-here
|
||||
```
|
||||
:warning: When connecting local APIs running on the hostmachine (e.g. text-generation-webui), you need to use `host.docker.internal` as the hostname.
|
||||
|
||||
You will need to restart the backend for this change to take effect.
|
||||
#### To shut down the Docker container
|
||||
|
||||
### RunPod
|
||||
Just closing the terminal window will not stop the Docker container. You need to run `docker compose down` to stop the container.
|
||||
|
||||
To set your runpod api key, open `config.yaml` in any text editor and uncomment / add
|
||||
#### How to install Docker
|
||||
|
||||
```yaml
|
||||
runpod:
|
||||
api_key: my-api-key-goes-here
|
||||
```
|
||||
You will need to restart the backend for this change to take effect.
|
||||
1. Download and install Docker Desktop from the [official Docker website](https://www.docker.com/products/docker-desktop).
|
||||
|
||||
Once the api key is set Pods loaded from text-generation-webui templates (or the bloke's runpod llm template) will be autoamtically added to your client list in talemate.
|
||||
|
||||
**ATTENTION**: Talemate is not a suitable for way for you to determine whether your pod is currently running or not. **Always** check the runpod dashboard to see if your pod is running or not.
|
||||
|
||||
## Recommended Models
|
||||
(as of2023.10.25)
|
||||
|
||||
Any of the top models in any of the size classes here should work well:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/17fhp9k/huge_llm_comparisontest_39_models_tested_7b70b/
|
||||
|
||||
## Connecting to an LLM
|
||||
# Connecting to an LLM
|
||||
|
||||
On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
|
||||
|
||||

|
||||
|
||||
### Text-generation-webui
|
||||

|
||||
|
||||
## OpenAI / mistral.ai / Anthropic
|
||||
|
||||
The setup is the same for all three, the example below is for OpenAI.
|
||||
|
||||
If you want to add an OpenAI client, just change the client type and select the appropriate model.
|
||||
|
||||

|
||||
|
||||
If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.
|
||||
|
||||

|
||||
|
||||
Click the `SET API KEY` button. This will open a modal where you can enter your API key.
|
||||
|
||||

|
||||
|
||||
Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||

|
||||
|
||||
## Text-generation-webui / LMStudio
|
||||
|
||||
> :warning: As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.
|
||||
|
||||
In the modal if you're planning to connect to text-generation-webui, you can likely leave everything as is and just click Save.
|
||||
|
||||

|
||||

|
||||
|
||||
### OpenAI
|
||||
### Specifying the correct prompt template
|
||||
|
||||
If you want to add an OpenAI client, just change the client type and select the apropriate model.
|
||||
For good results it is **vital** that the correct prompt template is specified for whichever model you have loaded.
|
||||
|
||||

|
||||
Talemate does come with a set of pre-defined templates for some popular models, but going forward, due to the sheet number of models released every day, understanding and specifying the correct prompt template is something you should familiarize yourself with.
|
||||
|
||||
### Ready to go
|
||||
If the text-gen-webui client shows a yellow triangle next to it, it means that the prompt template is not set, and it is currently using the default `VICUNA` style prompt template.
|
||||
|
||||

|
||||
|
||||
Click the two cogwheels to the right of the triangle to open the client settings.
|
||||
|
||||

|
||||
|
||||
You can first try by clicking the `DETERMINE VIA HUGGINGFACE` button, depending on the model's README file, it may be able to determine the correct prompt template for you. (basically the readme needs to contain an example of the template)
|
||||
|
||||
If that doesn't work, you can manually select the prompt template from the dropdown.
|
||||
|
||||
In the case for `bartowski_Nous-Hermes-2-Mistral-7B-DPO-exl2_8_0` that is `ChatML` - select it from the dropdown and click `Save`.
|
||||
|
||||

|
||||
|
||||
### Recommended Models
|
||||
|
||||
Any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):
|
||||
|
||||
[https://oobabooga.github.io/benchmark.html](https://oobabooga.github.io/benchmark.html)
|
||||
|
||||
## DeepInfra via OpenAI Compatible client
|
||||
|
||||
You can use the OpenAI compatible client to connect to [DeepInfra](https://deepinfra.com/).
|
||||
|
||||

|
||||
|
||||
```
|
||||
API URL: https://api.deepinfra.com/v1/openai
|
||||
```
|
||||
|
||||
Models on DeepInfra that work well with Talemate:
|
||||
|
||||
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://deepinfra.com/mistralai/Mixtral-8x7B-Instruct-v0.1) (max context 32k, 8k recommended)
|
||||
- [cognitivecomputations/dolphin-2.6-mixtral-8x7b](https://deepinfra.com/cognitivecomputations/dolphin-2.6-mixtral-8x7b) (max context 32k, 8k recommended)
|
||||
- [lizpreciatior/lzlv_70b_fp16_hf](https://deepinfra.com/lizpreciatior/lzlv_70b_fp16_hf) (max context 4k)
|
||||
|
||||
## Google Gemini
|
||||
|
||||
### Google Cloud Setup
|
||||
|
||||
Unlike the other clients the setup for Google Gemini is a bit more involved as you will need to set up a google cloud project and credentials for it.
|
||||
|
||||
Please follow their [instructions for setup](https://cloud.google.com/vertex-ai/docs/start/client-libraries) - which includes setting up a project, enabling the Vertex AI API, creating a service account, and downloading the credentials.
|
||||
|
||||
Once you have downloaded the credentials, copy the JSON file into the talemate directory. You can rename it to something that's easier to remember, like `my-credentials.json`.
|
||||
|
||||
### Add the client
|
||||
|
||||

|
||||
|
||||
The `Disable Safety Settings` option will turn off the google reponse validation for what they consider harmful content. Use at your own risk.
|
||||
|
||||
### Conmplete the google cloud setup in talemate
|
||||
|
||||

|
||||
|
||||
Click the `SETUP GOOGLE API CREDENTIALS` button that will appear on the client.
|
||||
|
||||
The google cloud setup modal will appear, fill in the path to the credentials file and select a location that is close to you.
|
||||
|
||||

|
||||
|
||||
Click save and after a moment the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||

|
||||
|
||||
## Ready to go
|
||||
|
||||
You will know you are good to go when the client and all the agents have a green dot next to them.
|
||||
|
||||

|
||||

|
||||
|
||||
## Load the introductory scenario "Infinity Quest"
|
||||
|
||||
@@ -172,12 +249,16 @@ Once a character is uploaded, talemate may actually take a moment because it nee
|
||||
|
||||
Make sure you save the scene after the character is loaded as it can then be loaded as normal talemate scenario in the future.
|
||||
|
||||
## Further documentation
|
||||
## Configure for hosting
|
||||
|
||||
Please read the documents in the `docs` folder for more advanced configuration and usage.
|
||||
By default talemate is configured to run locally. If you want to host it behind a reverse proxy or on a server, you will need create some environment variables in the `talemate_frontend/.env.development.local` file
|
||||
|
||||
- [Prompt template overrides](docs/templates.md)
|
||||
- [Text-to-Speech (TTS)](docs/tts.md)
|
||||
- [ChromaDB (long term memory)](docs/chromadb.md)
|
||||
- [Runpod Integration](docs/runpod.md)
|
||||
- Creative mode
|
||||
Start by copying `talemate_frontend/example.env.development.local` to `talemate_frontend/.env.development.local`.
|
||||
|
||||
Then open the file and edit the `ALLOWED_HOSTS` and `VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL` variables.
|
||||
|
||||
```sh
|
||||
ALLOWED_HOSTS=example.com
|
||||
# wss if behind ssl, ws if not
|
||||
VUE_APP_TALEMATE_BACKEND_WEBSOCKET_URL=wss://example.com:5050
|
||||
```
|
||||
|
||||
@@ -48,6 +48,7 @@ game:
|
||||
# embeddings: instructor
|
||||
# instructor_device: cuda
|
||||
# instructor_model: hkunlp/instructor-xl
|
||||
# openai_model: text-embedding-3-small
|
||||
|
||||
## Remote LLMs
|
||||
|
||||
|
||||
27
docker-compose.yml
Normal file
@@ -0,0 +1,27 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
talemate-backend:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.backend
|
||||
ports:
|
||||
- "5050:5050"
|
||||
volumes:
|
||||
# can uncomment for dev purposes
|
||||
#- ./src/talemate:/app/src/talemate
|
||||
- ./config.yaml:/app/config.yaml
|
||||
- ./scenes:/app/scenes
|
||||
- ./templates:/app/templates
|
||||
- ./chroma:/app/chroma
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
|
||||
talemate-frontend:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.frontend
|
||||
ports:
|
||||
- "8080:8080"
|
||||
#volumes:
|
||||
# - ./talemate_frontend:/app
|
||||
@@ -56,6 +56,7 @@ Then add the following to `config.yaml` for chromadb:
|
||||
```yaml
|
||||
chromadb:
|
||||
embeddings: openai
|
||||
openai_model: text-embedding-3-small
|
||||
```
|
||||
|
||||
**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. ALSO - if you send super explicit content they may flag / ban your key, so keep that in mind (i hear they usually send warnings first though), and always monitor your usage on their dashboard.
|
||||
**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. Always monitor your usage on their dashboard.
|
||||
|
||||
48
docs/dev/agents/example/test/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from talemate.agents.base import Agent, AgentAction
|
||||
from talemate.agents.registry import register
|
||||
from talemate.events import GameLoopEvent
|
||||
import talemate.emit.async_signals
|
||||
from talemate.emit import emit
|
||||
|
||||
@register()
|
||||
class TestAgent(Agent):
|
||||
|
||||
agent_type = "test"
|
||||
verbose_name = "Test"
|
||||
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
self.actions = {
|
||||
"test": AgentAction(
|
||||
enabled=True,
|
||||
label="Test",
|
||||
description="Test",
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
async def on_game_loop(self, emission: GameLoopEvent):
|
||||
"""
|
||||
Called on the beginning of every game loop
|
||||
"""
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
emit("status", status="info", message="Annoying you with a test message every game loop.")
|
||||
130
docs/dev/client/example/runpod_vllm/__init__.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
An attempt to write a client against the runpod serverless vllm worker.
|
||||
|
||||
This is close to functional, but since runpod serverless gpu availability is currently terrible, i have
|
||||
been unable to properly test it.
|
||||
|
||||
Putting it here for now since i think it makes a decent example of how to write a client against a new service.
|
||||
"""
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import runpod
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from talemate.client.base import ClientBase, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.emit import emit
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
|
||||
log = structlog.get_logger("talemate.client.runpod_vllm")
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 4096
|
||||
model: str = ""
|
||||
runpod_id: str = ""
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
runpod_id: str = ""
|
||||
|
||||
@register()
|
||||
class RunPodVLLMClient(ClientBase):
|
||||
client_type = "runpod_vllm"
|
||||
conversation_retries = 5
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
title: str = "Runpod VLLM"
|
||||
name_prefix: str = "Runpod VLLM"
|
||||
enable_api_auth: bool = True
|
||||
manual_model: bool = True
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"runpod_id": ExtraField(
|
||||
name="runpod_id",
|
||||
type="text",
|
||||
label="Runpod ID",
|
||||
required=True,
|
||||
description="The Runpod ID to connect to.",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, model=None, runpod_id=None, **kwargs):
|
||||
self.model_name = model
|
||||
self.runpod_id = runpod_id
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
|
||||
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
|
||||
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def get_model_name(self):
|
||||
return self.model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
prompt = prompt.strip()
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
|
||||
|
||||
run_request = await endpoint.run({
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
}
|
||||
#"parameters": parameters
|
||||
})
|
||||
|
||||
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
|
||||
status = await run_request.status()
|
||||
log.debug("generate", status=status)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
status = await run_request.status()
|
||||
|
||||
log.debug("generate", status=status)
|
||||
|
||||
response = await run_request.output()
|
||||
|
||||
log.debug("generate", response=response)
|
||||
|
||||
return response["choices"][0]["tokens"][0]
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
return ""
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
if "runpod_id" in kwargs:
|
||||
self.api_auth = kwargs["runpod_id"]
|
||||
log.warning("reconfigure", kwargs=kwargs)
|
||||
self.set_client(**kwargs)
|
||||
67
docs/dev/client/example/test/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pydantic
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.client.registry import register
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:1234"
|
||||
max_token_length: int = 4096
|
||||
|
||||
@register()
|
||||
class TestClient(ClientBase):
|
||||
client_type = "test"
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "test"
|
||||
title: str = "Test"
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
|
||||
"""
|
||||
Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.
|
||||
|
||||
This method is called before the prompt is sent to the client, and it allows the client to remove
|
||||
any parameters that it doesn't support.
|
||||
"""
|
||||
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def get_model_name(self):
|
||||
|
||||
"""
|
||||
This should return the name of the model that is being used.
|
||||
"""
|
||||
|
||||
return "Mock test model"
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[human_message], **parameters
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
return ""
|
||||
BIN
docs/img/0.18.0/openai-api-key-1.png
Normal file
|
After Width: | Height: | Size: 5.6 KiB |
BIN
docs/img/0.18.0/openai-api-key-2.png
Normal file
|
After Width: | Height: | Size: 24 KiB |
BIN
docs/img/0.18.0/openai-api-key-3.png
Normal file
|
After Width: | Height: | Size: 4.7 KiB |
BIN
docs/img/0.19.0/Screenshot_15.png
Normal file
|
After Width: | Height: | Size: 418 KiB |
BIN
docs/img/0.19.0/Screenshot_16.png
Normal file
|
After Width: | Height: | Size: 413 KiB |
BIN
docs/img/0.19.0/Screenshot_17.png
Normal file
|
After Width: | Height: | Size: 364 KiB |
BIN
docs/img/0.20.0/comfyui-base-workflow.png
Normal file
|
After Width: | Height: | Size: 128 KiB |
BIN
docs/img/0.20.0/visual-config-a1111.png
Normal file
|
After Width: | Height: | Size: 32 KiB |
BIN
docs/img/0.20.0/visual-config-comfyui.png
Normal file
|
After Width: | Height: | Size: 34 KiB |
BIN
docs/img/0.20.0/visual-config-openai.png
Normal file
|
After Width: | Height: | Size: 30 KiB |
BIN
docs/img/0.20.0/visual-queue.png
Normal file
|
After Width: | Height: | Size: 933 KiB |
BIN
docs/img/0.20.0/visualize-scene-tools.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
docs/img/0.20.0/visualizer-busy.png
Normal file
|
After Width: | Height: | Size: 3.5 KiB |
BIN
docs/img/0.20.0/visualizer-ready.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
docs/img/0.20.0/visualze-new-images.png
Normal file
|
After Width: | Height: | Size: 1.8 KiB |
BIN
docs/img/0.21.0/deepinfra-setup.png
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
docs/img/0.21.0/no-clients.png
Normal file
|
After Width: | Height: | Size: 7.1 KiB |
BIN
docs/img/0.21.0/openai-add-api-key.png
Normal file
|
After Width: | Height: | Size: 35 KiB |
BIN
docs/img/0.21.0/openai-setup.png
Normal file
|
After Width: | Height: | Size: 20 KiB |
BIN
docs/img/0.21.0/prompt-template-default.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
docs/img/0.21.0/ready-to-go.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
docs/img/0.21.0/select-prompt-template.png
Normal file
|
After Width: | Height: | Size: 47 KiB |
BIN
docs/img/0.21.0/selected-prompt-template.png
Normal file
|
After Width: | Height: | Size: 49 KiB |
BIN
docs/img/0.21.0/text-gen-webui-setup.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
docs/img/0.25.0/google-add-client.png
Normal file
|
After Width: | Height: | Size: 25 KiB |
BIN
docs/img/0.25.0/google-cloud-setup.png
Normal file
|
After Width: | Height: | Size: 59 KiB |
BIN
docs/img/0.25.0/google-ready.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
docs/img/0.25.0/google-setup-incomplete.png
Normal file
|
After Width: | Height: | Size: 7.5 KiB |
15
docs/tts.md
@@ -17,21 +17,6 @@ elevenlabs:
|
||||
api_key: <YOUR_ELEVENLABS_API_KEY>
|
||||
```
|
||||
|
||||
## Configuring Coqui TTS
|
||||
|
||||
To use Coqui TTS with Talemate, follow these steps:
|
||||
|
||||
1. Visit [Coqui](https://app.coqui.ai) and sign up for an account.
|
||||
2. Go to the [account page](https://app.coqui.ai/account) and scroll to the bottom to find your API key.
|
||||
3. In the `config.yaml` file, under the `coqui` section, set the `api_key` field with your Coqui API key.
|
||||
|
||||
Example configuration snippet:
|
||||
|
||||
```yaml
|
||||
coqui:
|
||||
api_key: <YOUR_COQUI_API_KEY>
|
||||
```
|
||||
|
||||
## Configuring Local TTS API
|
||||
|
||||
For running a local TTS API, Talemate requires specific dependencies to be installed.
|
||||
|
||||
117
docs/visual.md
Normal file
@@ -0,0 +1,117 @@
|
||||
# Visual Agent
|
||||
|
||||
The visual agent currently allows for some bare bones visual generation using various stable-diffusion APIs. This is early development and experimental.
|
||||
|
||||
Its important to note that the visualization agent actually specifies two clients. One is the backend for the visual generation, and the other is the text generation client to use for prompt generation.
|
||||
|
||||
The client for prompt generation can be assigned to the agent as you would for any other agent. The client for visual generation is assigned in the Visualizer config.
|
||||
|
||||
## Index
|
||||
|
||||
- [OpenAI](#openai)
|
||||
- [AUTOMATIC1111](#automatic1111)
|
||||
- [ComfyUI](#comfyui)
|
||||
- [How to use](#how-to-use)
|
||||
|
||||
## OpenAI
|
||||
|
||||
Most straightforward to use, as it runs on the OpenAI API. You will need to have an API key and set it in the application config.
|
||||
|
||||

|
||||
|
||||
Then open the Visualizer config by clicking the agent's name in the agent list and choose `OpenAI` as the backend.
|
||||
|
||||

|
||||
|
||||
Note: `Client` here refers to the text-generation client to use for prompt generation. While `Backend` refers to the visual generation backend. You are **NOT** required to use the OpenAI client for prompt generation even if you are using the OpenAI backend for image generation.
|
||||
|
||||
## AUTOMATIC1111
|
||||
|
||||
This requires you to setup a local instance of the AUTOMATIC1111 API. Follow the instructions from their [GitHub](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to get it running.
|
||||
|
||||
Once you have it running, you will want to adjust the `webui-user.bat` in the AUTOMATIC1111 directory to include the following command arguments:
|
||||
|
||||
```bat
|
||||
set COMMANDLINE_ARGS=--api --listen --port 7861
|
||||
```
|
||||
|
||||
Then run the `webui-user.bat` to start the API.
|
||||
|
||||
Once your AUTOAMTIC1111 API is running (check with your browser) you can set the Visualizer config to use the `AUTOMATIC1111` backend
|
||||
|
||||

|
||||
|
||||
#### Extra Configuration
|
||||
|
||||
- `api url`: the url of the API, usually `http://localhost:7861`
|
||||
- `steps`: render steps
|
||||
- `model type`: sdxl or sd1.5 - this will dictate the resolution of the image generation and actually matters for the quality so make sure this is set to the correct model type for the model you are using.
|
||||
|
||||
## ComfyUI
|
||||
|
||||
This requires you to setup a local instance of the ComfyUI API. Follow the instructions from their [GitHub](https://github.com/comfyanonymous/ComfyUI) to get it running.
|
||||
|
||||
Once you're setup, copy their `start.bat` file to a new `start-listen.bat` file and change the contents to.
|
||||
|
||||
```bat
|
||||
call venv\Scripts\activate
|
||||
call python main.py --port 8188 --listen 0.0.0.0
|
||||
```
|
||||
|
||||
Then run the `start-listen.bat` to start the API.
|
||||
|
||||
Once your ComfyUI API is running (check with your browser) you can set the Visualizer config to use the `ComfyUI` backend.
|
||||
|
||||

|
||||
|
||||
### Extra Configuration
|
||||
|
||||
- `api url`: the url of the API, usually `http://localhost:8188`
|
||||
- `workflow`: the workflow file to use. This is a comfyui api workflow file that needs to exist in `./templates/comfyui-workflows` inside the talemate directory. Talemate provides two very barebones workflows with `default-sdxl.json` and `default-sd15.json`. You can create your own workflows and place them in this directory to use them. :warning: The workflow file must be generated using the API Workflow export not the UI export. Please refer to their documentation for more information.
|
||||
- `checkpoint`: the model to use - this will load a list of all available models in your comfyui instance. Select which one you want to use for the image generation.
|
||||
|
||||
### Custom Workflows
|
||||
|
||||
When creating custom workflows for ideal compatibility with Talemate, ensure the following.
|
||||
|
||||
- A `CheckpointLoaderSimple` node named `Talemate Load Checkpoint`
|
||||
- A `EmptyLatentImage` node name `Talemate Resolution`
|
||||
- A `ClipTextEncode` node named `Talemate Positive Prompt`
|
||||
- A `ClipTextEncode` node named `Talemate Negative Prompt`
|
||||
- A `SaveImage` node at the end of the workflow.
|
||||
|
||||

|
||||
|
||||
## How to use
|
||||
|
||||
Once you're done setting up the visualizer agent should have a green dot next to it and display both the selected image generation backend and the selected prompt generation client.
|
||||
|
||||

|
||||
|
||||
Your hotbar should then also enable the visualization menu for you to use (once you have a scene loaded).
|
||||
|
||||

|
||||
|
||||
Right now you can generate a portrait for any NPC in the scene or a background image for the scene itself.
|
||||
|
||||
Image generation by default will actually happen in the background, allowing you to continue using Talemate while the image is being generated.
|
||||
|
||||
You can tell if an image is being generated by the blueish spinner next to the visualization agent.
|
||||
|
||||

|
||||
|
||||
Once the image is generated, it will be avaible for you to view via the visual queue button on top of the screen.
|
||||
|
||||

|
||||
|
||||
Click it to open the visual queue and view the generated images.
|
||||
|
||||

|
||||
|
||||
### Character Portrait
|
||||
|
||||
For character potraits you can chose whether or not to replace the main portrait for the character (the one being displated in the left sidebar when a talemate scene is active).
|
||||
|
||||
### Background Image
|
||||
|
||||
Right now there is nothing to do with the background image, other than to view it in the visual queue. More functionality will be added in the future.
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# create a virtual environment
|
||||
python -m venv talemate_env
|
||||
python3 -m venv talemate_env
|
||||
|
||||
# activate the virtual environment
|
||||
source talemate_env/bin/activate
|
||||
|
||||
4455
poetry.lock
generated
@@ -4,13 +4,13 @@ build-backend = "poetry.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "talemate"
|
||||
version = "0.18.0"
|
||||
version = "0.25.6"
|
||||
description = "AI-backed roleplay and narrative tools"
|
||||
authors = ["FinalWombat"]
|
||||
license = "GNU Affero General Public License v3.0"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
python = ">=3.10,<3.12"
|
||||
astroid = "^2.8"
|
||||
jedi = "^0.18"
|
||||
black = "*"
|
||||
@@ -18,9 +18,13 @@ rope = "^0.22"
|
||||
isort = "^5.10"
|
||||
jinja2 = "^3.0"
|
||||
openai = ">=1"
|
||||
mistralai = ">=0.1.8"
|
||||
cohere = ">=5.2.2"
|
||||
anthropic = ">=0.19.1"
|
||||
groq = ">=0.5.0"
|
||||
requests = "^2.26"
|
||||
colorama = ">=0.4.6"
|
||||
Pillow = "^9.5"
|
||||
Pillow = ">=9.5"
|
||||
httpx = "<1"
|
||||
piexif = "^1.1"
|
||||
typing-inspect = "0.8.0"
|
||||
@@ -33,18 +37,22 @@ python-dotenv = "^1.0.0"
|
||||
websockets = "^11.0.3"
|
||||
structlog = "^23.1.0"
|
||||
runpod = "^1.2.0"
|
||||
google-cloud-aiplatform = ">=1.50.0"
|
||||
nest_asyncio = "^1.5.7"
|
||||
isodate = ">=0.6.1"
|
||||
thefuzz = ">=0.20.0"
|
||||
tiktoken = ">=0.5.1"
|
||||
nltk = ">=3.8.1"
|
||||
huggingface-hub = ">=0.20.2"
|
||||
RestrictedPython = ">7.1"
|
||||
|
||||
# ChromaDB
|
||||
chromadb = ">=0.4.17,<1"
|
||||
InstructorEmbedding = "^1.0.1"
|
||||
torch = ">=2.1.0"
|
||||
sentence-transformers="^2.2.2"
|
||||
torchaudio = ">=2.3.0"
|
||||
# locked for instructor embeddings
|
||||
sentence-transformers="==2.2.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^6.2"
|
||||
|
||||
|
After Width: | Height: | Size: 1.6 MiB |
@@ -98,6 +98,7 @@
|
||||
}
|
||||
],
|
||||
"immutable_save": true,
|
||||
"experimental": true,
|
||||
"goal": null,
|
||||
"goals": [],
|
||||
"context": "an epic sci-fi adventure aimed at an adult audience.",
|
||||
@@ -109,10 +110,10 @@
|
||||
"variables": {}
|
||||
},
|
||||
"assets": {
|
||||
"cover_image": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
|
||||
"cover_image": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
|
||||
"assets": {
|
||||
"52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df": {
|
||||
"id": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
|
||||
"e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404": {
|
||||
"id": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
|
||||
"file_type": "png",
|
||||
"media_type": "image/png"
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
{%- set _ = emit_system("warning", "This is a dynamic scenario generation experiment for Infinity Quest. It will likely require a strong LLM to generate something coherent. GPT-4 or 34B+ if local. Temper your expectations.") -%}
|
||||
|
||||
{#- emit status update to the UX -#}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [1/3]") -%}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [1/3]", as_scene_message=True) -%}
|
||||
|
||||
{#- thematic tags will be used to randomize generation -#}
|
||||
{%- set tags = thematic_generator.generate("color", "state_of_matter", "scifi_trope") -%}
|
||||
@@ -17,17 +17,17 @@
|
||||
|
||||
|
||||
{#- generate introductory text -#}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [2/3]") -%}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [2/3]", as_scene_message=True) -%}
|
||||
{%- set tmpl__scenario_intro = render_template('generate-scenario-intro', premise=instr__premise) %}
|
||||
{%- set instr__intro = "*"+render_and_request(tmpl__scenario_intro)+"*" -%}
|
||||
|
||||
{#- generate win conditions -#}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [3/3]") -%}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [3/3]", as_scene_message=True) -%}
|
||||
{%- set tmpl__win_conditions = render_template('generate-win-conditions', premise=instr__premise) %}
|
||||
{%- set instr__win_conditions = render_and_request(tmpl__win_conditions) -%}
|
||||
|
||||
{#- emit status update to the UX -#}
|
||||
{%- set status = emit_status("info", "Scenario ready.") -%}
|
||||
{%- set status = emit_status("success", "Scenario ready.", as_scene_message=True) -%}
|
||||
|
||||
{# set gamestate variables #}
|
||||
{%- set _ = game_state.set_var("instr.premise", instr__premise, commit=True) -%}
|
||||
|
||||
|
After Width: | Height: | Size: 1.7 MiB |
535
scenes/simulation-suite/game.py
Normal file
@@ -0,0 +1,535 @@
|
||||
|
||||
def game(TM):
|
||||
|
||||
MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"
|
||||
|
||||
MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
|
||||
|
||||
PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."
|
||||
|
||||
PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice."
|
||||
|
||||
CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
|
||||
|
||||
AUTO_NARRATE_INTERVAL = 10
|
||||
|
||||
def parse_sim_call_arguments(call:str) -> str:
|
||||
"""
|
||||
Returns the value between the parentheses of a simulation call
|
||||
|
||||
Example:
|
||||
|
||||
call = 'change_environment("a house")'
|
||||
|
||||
parse_sim_call_arguments(call) -> "a house"
|
||||
"""
|
||||
|
||||
try:
|
||||
return call.split("(", 1)[1].split(")")[0]
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
class SimulationSuite:
|
||||
|
||||
def __init__(self):
|
||||
# do we update the world state at the end of the round
|
||||
self.update_world_state = False
|
||||
|
||||
self.simulation_reset = False
|
||||
|
||||
self.added_npcs = []
|
||||
|
||||
TM.log.debug("SIMULATION SUITE INIT...")
|
||||
|
||||
self.player_character = TM.scene.get_player_character()
|
||||
self.player_message = TM.scene.last_player_message()
|
||||
self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
|
||||
self.player_message_is_instruction = (
|
||||
self.player_message and
|
||||
self.player_message.raw.lower().startswith("computer") and
|
||||
not self.player_message.hidden and
|
||||
not self.last_processed_call > self.player_message.id
|
||||
)
|
||||
|
||||
|
||||
def run(self):
|
||||
if not TM.game_state.has_var("instr.simulation_stopped"):
|
||||
self.simulation()
|
||||
|
||||
self.finalize_round()
|
||||
|
||||
def simulation(self):
|
||||
|
||||
if not TM.game_state.has_var("instr.simulation_started"):
|
||||
self.startup()
|
||||
else:
|
||||
self.simulation_calls()
|
||||
|
||||
if self.update_world_state:
|
||||
self.run_update_world_state(force=True)
|
||||
|
||||
|
||||
def startup(self):
|
||||
TM.emit_status("busy", "Simulation suite powering up.", as_scene_message=True)
|
||||
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=PROMPT_STARTUP,
|
||||
emit_message=False
|
||||
)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="passthrough",
|
||||
narration=MSG_HELP
|
||||
)
|
||||
TM.agents.world_state.manager(
|
||||
action_name="save_world_entry",
|
||||
entry_id="sim.quarantined",
|
||||
text=CTX_PIN_UNAWARE,
|
||||
meta={},
|
||||
pin=True
|
||||
)
|
||||
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
|
||||
TM.emit_status("success", "Simulation suite ready", as_scene_message=True)
|
||||
self.update_world_state = True
|
||||
|
||||
def simulation_calls(self):
|
||||
"""
|
||||
Calls the simulation suite main prompt to determine the appropriate
|
||||
simulation calls
|
||||
"""
|
||||
|
||||
if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
|
||||
return
|
||||
|
||||
# First instruction?
|
||||
if not TM.game_state.has_var("instr.has_issued_instructions"):
|
||||
|
||||
# determine the context of the simulation
|
||||
|
||||
context_context = TM.agents.creator.determine_content_context_for_description(
|
||||
description=self.player_message.raw,
|
||||
)
|
||||
TM.scene.set_content_context(context_context)
|
||||
|
||||
|
||||
calls = TM.client.render_and_request(
|
||||
"computer",
|
||||
dedupe_enabled=False,
|
||||
player_instruction=self.player_message.raw,
|
||||
scene=TM.scene,
|
||||
)
|
||||
|
||||
self.calls = calls = calls.split("\n")
|
||||
|
||||
calls = self.prepare_calls(calls)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE CALLS", callse=calls)
|
||||
|
||||
# calls that are processed
|
||||
processed = []
|
||||
|
||||
for call in calls:
|
||||
processed_call = self.process_call(call)
|
||||
if processed_call:
|
||||
processed.append(processed_call)
|
||||
|
||||
|
||||
if processed:
|
||||
TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
|
||||
TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
|
||||
|
||||
TM.emit_status("busy", "Simulation suite altering environment.", as_scene_message=True)
|
||||
compiled = "\n".join(processed)
|
||||
if not self.simulation_reset and compiled:
|
||||
narration = TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
# on the first narration we update the scene description and remove any mention of the computer
|
||||
# or the simulation from the previous narration
|
||||
is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
|
||||
if not is_initial_narration:
|
||||
TM.scene.set_description(str(narration))
|
||||
TM.scene.set_intro(str(narration))
|
||||
TM.log.debug("SIMULATION SUITE: initial narration", intro=str(narration))
|
||||
TM.scene.pop_history(typ="narrator", all=True, reverse=True)
|
||||
TM.scene.pop_history(typ="director", all=True, reverse=True)
|
||||
TM.game_state.set_var("instr.intro_narration", True, commit=False)
|
||||
|
||||
self.update_world_state = True
|
||||
|
||||
self.set_simulation_title(compiled)
|
||||
|
||||
def set_simulation_title(self, compiled_calls):
|
||||
|
||||
"""
|
||||
Generates a fitting title for the simulation based on the user's instructions
|
||||
"""
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
|
||||
|
||||
if not compiled_calls:
|
||||
return
|
||||
|
||||
if TM.scene.title != "Simulation Suite":
|
||||
# name already changed, no need to do it again
|
||||
return
|
||||
|
||||
title = TM.agents.creator.contextual_generate_from_args(
|
||||
"scene:simulation title",
|
||||
"Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
|
||||
length=75
|
||||
)
|
||||
|
||||
title = title.strip('"').strip()
|
||||
|
||||
TM.scene.set_title(title)
|
||||
|
||||
def prepare_calls(self, calls):
|
||||
"""
|
||||
Loops through calls and if a `set_player_name` call and a `set_player_persona` call are both
|
||||
found, ensure that the `set_player_name` call is processed first by moving it in front of the
|
||||
`set_player_persona` call.
|
||||
"""
|
||||
|
||||
set_player_name_call_exists = -1
|
||||
set_player_persona_call_exists = -1
|
||||
|
||||
i = 0
|
||||
for call in calls:
|
||||
if "set_player_name" in call:
|
||||
set_player_name_call_exists = i
|
||||
elif "set_player_persona" in call:
|
||||
set_player_persona_call_exists = i
|
||||
i = i + 1
|
||||
|
||||
if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
|
||||
|
||||
if set_player_name_call_exists > set_player_persona_call_exists:
|
||||
calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists))
|
||||
TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls)
|
||||
|
||||
return calls
|
||||
|
||||
def process_call(self, call:str) -> str:
|
||||
"""
|
||||
Processes a simulation call
|
||||
|
||||
Simulation alls are pseudo functions that are called by the simulation suite
|
||||
|
||||
We grab the function name by splitting against ( and taking the first element
|
||||
if the SimulationSuite has a method with the name _call_{function_name} then we call it
|
||||
|
||||
if a function name could be found but we do not have a method to call we dont do anything
|
||||
but we still return it as procssed as the AI can still interpret it as something later on
|
||||
"""
|
||||
|
||||
if "(" not in call:
|
||||
return None
|
||||
|
||||
function_name = call.split("(")[0]
|
||||
|
||||
if hasattr(self, f"call_{function_name}"):
|
||||
TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
|
||||
|
||||
inject = f"The computer executes the function `{call}`"
|
||||
|
||||
return getattr(self, f"call_{function_name}")(call, inject)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_set_simulation_goal(self, call:str, inject:str) -> str:
|
||||
"""
|
||||
Set's the simulation goal as a permanent pin
|
||||
"""
|
||||
TM.emit_status("busy", "Simulation suite setting goal.", as_scene_message=True)
|
||||
TM.agents.world_state.manager(
|
||||
action_name="save_world_entry",
|
||||
entry_id="sim.goal",
|
||||
text=self.player_message.raw,
|
||||
meta={},
|
||||
pin=True
|
||||
)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer sets the goal for the simulation.",
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def call_change_environment(self, call:str, inject:str) -> str:
|
||||
"""
|
||||
Simulation changes the environment, this is entirely interpreted by the AI
|
||||
and we dont need to do any logic on our end, so we just return the call
|
||||
"""
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer changes the environment of the simulation."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_answer_question(self, call:str, inject:str) -> str:
|
||||
"""
|
||||
The player asked the simulation a query, we need to process this and have
|
||||
the AI produce an answer
|
||||
"""
|
||||
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
|
||||
def call_set_player_persona(self, call:str, inject:str) -> str:
|
||||
|
||||
"""
|
||||
The simulation suite is altering the player persona
|
||||
"""
|
||||
|
||||
TM.emit_status("busy", "Simulation suite altering user persona.", as_scene_message=True)
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(
|
||||
name=self.player_character.name, text=inject, alteration_instructions=self.player_message.raw
|
||||
)
|
||||
self.player_character.update(base_attributes=character_attributes)
|
||||
|
||||
character_description = TM.agents.creator.determine_character_description(character=self.player_character)
|
||||
self.player_character.update(description=character_description)
|
||||
TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer transforms the player persona."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_set_player_name(self, call:str, inject:str) -> str:
|
||||
|
||||
"""
|
||||
The simulation suite is altering the player name
|
||||
"""
|
||||
|
||||
TM.emit_status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
|
||||
TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
|
||||
if character_name != self.player_character.name:
|
||||
self.player_character.rename(character_name)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer changes the player's identity to {character_name}."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_add_ai_character(self, call:str, inject:str) -> str:
|
||||
|
||||
# sometimes the AI will call this function an pass an inanimate object as the parameter
|
||||
# we need to determine if this is the case and just ignore it
|
||||
is_inanimate = TM.client.query_text_eval(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
|
||||
|
||||
if is_inanimate:
|
||||
TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
|
||||
return
|
||||
|
||||
# sometimes the AI will ask if the function adds a group of characters, we need to
|
||||
# determine if this is the case
|
||||
adds_group = TM.client.query_text_eval(f"does the function `{call}` add MULTIPLE ai characters?", call)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)
|
||||
|
||||
TM.emit_status("busy", "Simulation suite adding character.", as_scene_message=True)
|
||||
|
||||
if not adds_group:
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.")
|
||||
else:
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
|
||||
|
||||
# sometimes add_ai_character and change_ai_character are called in the same instruction targeting
|
||||
# the same character, if this happens we need to combine into a single add_ai_character call
|
||||
|
||||
has_change_ai_character_call = TM.client.query_text_eval(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
|
||||
|
||||
if has_change_ai_character_call:
|
||||
|
||||
combined_arg = TM.client.render_and_request(
|
||||
"combine-add-and-alter-ai-character",
|
||||
dedupe_enabled=False,
|
||||
calls="\n".join(self.calls),
|
||||
character_name=character_name,
|
||||
scene=TM.scene,
|
||||
).replace("COMBINED ARGUMENT:", "").strip()
|
||||
|
||||
call = f"add_ai_character({combined_arg})"
|
||||
inject = f"The computer executes the function `{call}`"
|
||||
|
||||
|
||||
TM.emit_status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: add npc", name=character_name)
|
||||
|
||||
npc = TM.agents.director.persist_character(name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
|
||||
|
||||
self.added_npcs.append(npc.name)
|
||||
|
||||
TM.agents.world_state.manager(
|
||||
action_name="add_detail_reinforcement",
|
||||
character_name=npc.name,
|
||||
question="Goal",
|
||||
instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
|
||||
interval=25,
|
||||
run_immediately=True
|
||||
)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: added npc", npc=npc)
|
||||
|
||||
TM.agents.visual.generate_character_portrait(character_name=npc.name)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer adds {npc.name} to the simulation."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_remove_ai_character(self, call:str, inject:str) -> str:
|
||||
TM.emit_status("busy", "Simulation suite removing character.", as_scene_message=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names())
|
||||
|
||||
npc = TM.scene.get_character(character_name)
|
||||
|
||||
if npc:
|
||||
TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
|
||||
TM.agents.world_state.manager(action_name="deactivate_character", character_name=npc.name)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer removes {npc.name} from the simulation."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def call_change_ai_character(self, call:str, inject:str) -> str:
|
||||
TM.emit_status("busy", "Simulation suite altering character.", as_scene_message=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names())
|
||||
|
||||
if character_name in self.added_npcs:
|
||||
# we dont want to change the character if it was just added
|
||||
return
|
||||
|
||||
character_name_after = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
|
||||
|
||||
npc = TM.scene.get_character(character_name)
|
||||
|
||||
if npc:
|
||||
TM.emit_status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)
|
||||
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(name=npc.name, alteration_instructions=self.player_message.raw)
|
||||
|
||||
npc.update(base_attributes=character_attributes)
|
||||
character_description = TM.agents.creator.determine_character_description(character=npc)
|
||||
|
||||
npc.update(description=character_description)
|
||||
TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
|
||||
|
||||
if character_name_after != character_name:
|
||||
npc.rename(character_name_after)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer transforms {npc.name}."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def call_end_simulation(self, call:str, inject:str) -> str:
|
||||
|
||||
explicit_command = TM.client.query_text_eval("has the player explicitly asked to end the simulation?", self.player_message.raw)
|
||||
|
||||
if explicit_command:
|
||||
TM.emit_status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names())}. The player is also transformed back to their normal, non-descript persona as the form of {self.player_character.name} ceases to exist.",
|
||||
emit_message=True
|
||||
)
|
||||
TM.scene.restore()
|
||||
|
||||
self.simulation_reset = True
|
||||
|
||||
TM.game_state.unset_var("instr.has_issued_instructions")
|
||||
TM.game_state.unset_var("instr.lastprocessed_call")
|
||||
TM.game_state.unset_var("instr.simulation_started")
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer ends the simulation."
|
||||
)
|
||||
|
||||
def finalize_round(self):
|
||||
|
||||
# track rounds
|
||||
rounds = TM.game_state.get_var("instr.rounds", 0)
|
||||
|
||||
# increase rounds
|
||||
TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)
|
||||
|
||||
has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
|
||||
|
||||
if self.update_world_state:
|
||||
self.run_update_world_state()
|
||||
|
||||
if self.player_message_is_instruction:
|
||||
self.player_message.hide()
|
||||
TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
|
||||
TM.emit_status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
|
||||
|
||||
elif self.player_message and not has_issued_instructions:
|
||||
# simulation started, player message is NOT an instruction, and player has not given
|
||||
# any instructions
|
||||
self.guide_player()
|
||||
|
||||
elif self.player_message and not TM.scene.npc_character_names():
|
||||
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
|
||||
self.narrate_round()
|
||||
|
||||
elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names() and has_issued_instructions:
|
||||
# every N rounds, narrate the round
|
||||
self.narrate_round()
|
||||
|
||||
def guide_player(self):
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="paraphrase",
|
||||
narration=MSG_HELP,
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
def narrate_round(self):
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=PROMPT_NARRATE_ROUND,
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
def run_update_world_state(self, force=False):
|
||||
TM.log.debug("SIMULATION SUITE: update world state", force=force)
|
||||
TM.emit_status("busy", "Simulation suite updating world state.", as_scene_message=True)
|
||||
TM.agents.world_state.update_world_state(force=force)
|
||||
TM.emit_status("success", "Simulation suite updated world state.", as_scene_message=True)
|
||||
|
||||
SimulationSuite().run()
|
||||
53
scenes/simulation-suite/simulation-suite.json
Normal file
@@ -0,0 +1,53 @@
|
||||
{
|
||||
"name": "Simulation Suite",
|
||||
"title": "Simulation Suite",
|
||||
"environment": "scene",
|
||||
"immutable_save": true,
|
||||
"restore_from": "simulation-suite.json",
|
||||
"experimental": true,
|
||||
"help": "Address the computer by starting your statements with 'Computer, ' followed by an instruction.\n\nExamples:\n'Computer, i would like to experience an adventure on a derelict space station'\n'Computer, add a horrific alien creature that is chasing me.'",
|
||||
"description": "",
|
||||
"intro": "*You have entered the simulation suite. No simulation is currently active and you are in a non-descript space with paneled walls surrounding you. The control panel next to you is pulsating with a green light, indicating readiness to receive a prompt to start the simulation.*",
|
||||
"archived_history": [],
|
||||
"history": [],
|
||||
"ts": "PT1S",
|
||||
"characters": [
|
||||
{
|
||||
"name": "You",
|
||||
"gender": "unknown",
|
||||
"color": "cornflowerblue",
|
||||
"base_attributes": {},
|
||||
"is_player": true
|
||||
}
|
||||
],
|
||||
"context": "a simulated experience",
|
||||
"game_state": {
|
||||
"ops":{
|
||||
"run_on_start": true,
|
||||
"always_direct": true
|
||||
},
|
||||
"variables": {}
|
||||
},
|
||||
"world_state": {
|
||||
"character_name_mappings": {
|
||||
"You": [
|
||||
"user",
|
||||
"player",
|
||||
"player character",
|
||||
"user character",
|
||||
"the user",
|
||||
"the player"
|
||||
]
|
||||
}
|
||||
},
|
||||
"assets": {
|
||||
"cover_image": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
|
||||
"assets": {
|
||||
"4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103": {
|
||||
"id": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
|
||||
"file_type": "png",
|
||||
"media_type": "image/png"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
<|SECTION:EXAMPLES|>
|
||||
combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "Sarah" into a single text string argument to be passed to a single `add_ai_character` function call.
|
||||
```
|
||||
set_simulation_goal("player experiences a rollercoaster ride")
|
||||
change_environment("theme park, riding a rollercoaster")
|
||||
set_player_persona("young female experiencing rollercoaster ride")
|
||||
set_player_name("Susanne")
|
||||
add_ai_character("a female friend of player named Sarah")
|
||||
change_ai_character("Sarah hates rollercoasters")
|
||||
```
|
||||
COMBINED ARGUMENT: "a female friend of player named Sarah, Sarah hates rollercoasters"
|
||||
|
||||
TASK: combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "George" into a single text string argument to be passed to a single `add_ai_character` function call.
|
||||
```
|
||||
change_environment("building on fire")
|
||||
change_ai_character("George is injured")
|
||||
add_ai_character("a firefighter named Stephen")
|
||||
change_ai_character("Stephen is afraid of heights")
|
||||
```
|
||||
COMBINED ARGUMENT: "a firefighter named Stephen, Stephen is afraid of heights"
|
||||
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:TASK|>
|
||||
TASK: combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "{{ character_name }}" into a single text string argument to be passed to a single `add_ai_character` function call.
|
||||
```
|
||||
{{ calls }}
|
||||
```
|
||||
{{ set_prepared_response("COMBINED ARGUMENT:") }}
|
||||
132
scenes/simulation-suite/templates/computer.jinja2
Normal file
@@ -0,0 +1,132 @@
|
||||
<|SECTION:CONTEXT|>
|
||||
{% set scene_history=scene.context_history(budget=1024) %}
|
||||
{% for scene_context in scene_history -%}
|
||||
{{ loop.index }}. {{ scene_context }}
|
||||
{% endfor %}
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:FUNCTIONS|>
|
||||
The player has instructed the computer to alter the current simulation.
|
||||
|
||||
You have access to the following functions, you can call as many as you want to fulfill the player's requests.
|
||||
|
||||
You must at least call one of the following functions:
|
||||
|
||||
- change_environment
|
||||
- add_ai_character
|
||||
- change_ai_character
|
||||
- remove_ai_character
|
||||
- set_player_persona
|
||||
- set_player_name
|
||||
- end_simulation
|
||||
- answer_question
|
||||
- set_simulation_goal
|
||||
|
||||
`add_ai_character` and `change_ai_character` are exclusive if they are targeting the same character.
|
||||
|
||||
Set the player persona at the beginning of a new simulation or if the player requests a change.
|
||||
|
||||
Only end the simulation if the player requests it explicitly.
|
||||
|
||||
Your response MUST ONLY CONTAIN the new simulation stack.
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:EXAMPLES|>
|
||||
Request: Computer, I want to be on a mountain top
|
||||
```simulation-stack
|
||||
change_environment("mountain top")
|
||||
set_player_persona("mountain climber")
|
||||
set_player_name("Hank")
|
||||
```
|
||||
|
||||
Request: Computer, I want to be more muscular and taller
|
||||
```simulation-stack
|
||||
set_player_persona("make player more muscular and taller")
|
||||
```
|
||||
|
||||
Request: Computer, the building should be on fire
|
||||
```simulation-stack
|
||||
change_environment("building on fire")
|
||||
```
|
||||
|
||||
Request: Computer, a rocket hits the building and George is now injured
|
||||
```simulation-stack
|
||||
change_environment("building on fire")
|
||||
change_ai_character("George is injured")
|
||||
```
|
||||
|
||||
Request: Computer, I want to experience a rollercoaster ride with a friend
|
||||
```simulation-stack
|
||||
set_simulation_goal("player experiences a rollercoaster ride")
|
||||
change_environment("theme park, riding a rollercoaster")
|
||||
set_player_persona("young female experiencing rollercoaster ride")
|
||||
set_player_name("Susanne")
|
||||
add_ai_character("a female friend of player named Sarah")
|
||||
```
|
||||
|
||||
Request: Computer, I want to experience the international space station, to experience the overview effect
|
||||
```simulation-stack
|
||||
set_simulation_goal("player experiences the overview effect")
|
||||
change_environment("international space station")
|
||||
set_player_persona("astronaut experiencing first trip to ISS")
|
||||
set_player_name("George")
|
||||
add_ai_character("astronaut named Henry")
|
||||
```
|
||||
|
||||
Request: Computer, remove the goblin and add an elven woman instead
|
||||
```simulation-stack
|
||||
remove_ai_character("goblin")
|
||||
add_ai_character("elven woman named Elune")
|
||||
```
|
||||
|
||||
Request: Computer, change the skiing instructor to be older.
|
||||
```simulation-stack
|
||||
change_ai_character("make skiing instructor older")
|
||||
```
|
||||
|
||||
Request: Computer, change my grandma to my grandpa
|
||||
```simulation-stack
|
||||
remove_ai_character("grandma")
|
||||
add_ai_character("grandpa named Steven")
|
||||
```
|
||||
|
||||
Request: Computer, remove the skiing instructor and add my friend instead.
|
||||
```simulation-stack
|
||||
remove_ai_character("skiing instructor")
|
||||
add_ai_character("player's friend named Tara")
|
||||
```
|
||||
|
||||
Request: Computer, replace the skiing instructor with my friend.
|
||||
```simulation-stack
|
||||
remove_ai_character("skiing instructor")
|
||||
add_ai_character("player's friend named Lisa")
|
||||
```
|
||||
|
||||
Request: Computer, I want to end the simulation
|
||||
```simulation-stack
|
||||
end_simulation("simulation ended")
|
||||
```
|
||||
|
||||
Request: Computer, shut down the simulation
|
||||
```simulation-stack
|
||||
end_simulation("simulation ended")
|
||||
```
|
||||
|
||||
Request: Computer, what do you know about the game of thrones?
|
||||
```simulation-stack
|
||||
answer_question("what do you know about the game of thrones?")
|
||||
```
|
||||
|
||||
Request: Computer, i want to be a wizard in a dark goblin infested dungeon in a fantasy world, looking for secret treasure and fighting goblins.
|
||||
```simulation-stack
|
||||
set_simulation_goal("player wants to find secret treasure and fight creatures")
|
||||
change_environment("dark dungeon in a fantasy world")
|
||||
set_player_persona("powerful wizard")
|
||||
set_player_name("Lanadel")
|
||||
add_ai_character("a goblin named Gobbo")
|
||||
```
|
||||
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:TASK|>
|
||||
Respond with the simulation stack for the following request:
|
||||
|
||||
Request: {{ player_instruction }}
|
||||
{{ bot_token }}```simulation-stack
|
||||
@@ -2,4 +2,4 @@ from .agents import Agent
|
||||
from .client import TextGeneratorWebuiClient
|
||||
from .tale_mate import *
|
||||
|
||||
VERSION = "0.18.0"
|
||||
VERSION = "0.25.6"
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from .base import Agent
|
||||
from .creator import CreatorAgent
|
||||
from .conversation import ConversationAgent
|
||||
from .creator import CreatorAgent
|
||||
from .director import DirectorAgent
|
||||
from .editor import EditorAgent
|
||||
from .memory import ChromaDBMemoryAgent, MemoryAgent
|
||||
from .narrator import NarratorAgent
|
||||
from .registry import AGENT_CLASSES, get_agent_class, register
|
||||
from .summarize import SummarizeAgent
|
||||
from .editor import EditorAgent
|
||||
from .tts import TTSAgent
|
||||
from .visual import VisualAgent
|
||||
from .world_state import WorldStateAgent
|
||||
from .tts import TTSAgent
|
||||
@@ -1,24 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import re
|
||||
from abc import ABC
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
from blinker import signal
|
||||
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
import talemate.util as util
|
||||
from talemate.agents.context import ActiveAgent
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopStartEvent
|
||||
import talemate.emit.async_signals
|
||||
import dataclasses
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AgentAction",
|
||||
"AgentActionConditional",
|
||||
"AgentActionConfig",
|
||||
"AgentDetail",
|
||||
"AgentEmission",
|
||||
"set_processing",
|
||||
]
|
||||
|
||||
@@ -37,26 +43,41 @@ class AgentActionConfig(pydantic.BaseModel):
|
||||
scope: str = "global"
|
||||
choices: Union[list[dict[str, str]], None] = None
|
||||
note: Union[str, None] = None
|
||||
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
|
||||
class AgentActionConditional(pydantic.BaseModel):
|
||||
attribute: str
|
||||
value: Union[int, float, str, bool, None] = None
|
||||
|
||||
|
||||
class AgentAction(pydantic.BaseModel):
|
||||
enabled: bool = True
|
||||
label: str
|
||||
description: str = ""
|
||||
config: Union[dict[str, AgentActionConfig], None] = None
|
||||
|
||||
condition: Union[AgentActionConditional, None] = None
|
||||
|
||||
|
||||
class AgentDetail(pydantic.BaseModel):
|
||||
value: Union[str, None] = None
|
||||
description: Union[str, None] = None
|
||||
icon: Union[str, None] = None
|
||||
color: str = "grey"
|
||||
|
||||
|
||||
def set_processing(fn):
|
||||
"""
|
||||
decorator that emits the agent status as processing while the function
|
||||
is running.
|
||||
|
||||
|
||||
Done via a try - final block to ensure the status is reset even if
|
||||
the function fails.
|
||||
"""
|
||||
|
||||
|
||||
@wraps(fn)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
with ActiveAgent(self, fn):
|
||||
try:
|
||||
@@ -69,9 +90,8 @@ def set_processing(fn):
|
||||
# not sure why this happens
|
||||
# some concurrency error?
|
||||
log.error("error emitting agent status", exc=exc)
|
||||
|
||||
wrapper.__name__ = fn.__name__
|
||||
|
||||
|
||||
wrapper.exposed = True
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -85,6 +105,9 @@ class Agent(ABC):
|
||||
set_processing = set_processing
|
||||
requires_llm_client = True
|
||||
auto_break_repetition = False
|
||||
websocket_handler = None
|
||||
essential = True
|
||||
ready_check_error = None
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
@@ -97,46 +120,51 @@ class Agent(ABC):
|
||||
def verbose_name(self):
|
||||
return self.agent_type.capitalize()
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
if not getattr(self.client, "enabled", True):
|
||||
return False
|
||||
|
||||
|
||||
if self.client and self.client.current_status in ["error", "warning"]:
|
||||
return False
|
||||
|
||||
|
||||
return self.client is not None
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.ready:
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
return "idle" if getattr(self, "processing", 0) == 0 else "busy"
|
||||
else:
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
|
||||
if not self.ready:
|
||||
return "uninitialized"
|
||||
|
||||
if getattr(self, "processing", 0) > 0:
|
||||
return "busy"
|
||||
|
||||
if getattr(self, "processing_bg", 0) > 0:
|
||||
return "busy_bg"
|
||||
|
||||
return "idle"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
# by default, agents are enabled, an agent class that
|
||||
# is disableable should override this property
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def disable(self):
|
||||
# by default, agents are enabled, an agent class that
|
||||
# is disableable should override this property to
|
||||
# is disableable should override this property to
|
||||
# disable the agent
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
# by default, agents do not have toggles to enable / disable
|
||||
# an agent class that is disableable should override this property
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
# by default, agents are not experimental, an agent class that
|
||||
@@ -153,100 +181,180 @@ class Agent(ABC):
|
||||
"requires_llm_client": cls.requires_llm_client,
|
||||
}
|
||||
actions = getattr(agent, "actions", None)
|
||||
|
||||
|
||||
if actions:
|
||||
config_options["actions"] = {k: v.model_dump() for k, v in actions.items()}
|
||||
else:
|
||||
config_options["actions"] = {}
|
||||
|
||||
|
||||
return config_options
|
||||
|
||||
def apply_config(self, *args, **kwargs):
|
||||
@property
|
||||
def meta(self):
|
||||
return {
|
||||
"essential": self.essential,
|
||||
}
|
||||
|
||||
@property
|
||||
def sanitized_action_config(self):
|
||||
if not getattr(self, "actions", None):
|
||||
return {}
|
||||
|
||||
return {k: v.model_dump() for k, v in self.actions.items()}
|
||||
|
||||
async def _handle_ready_check(self, fut: asyncio.Future):
|
||||
callback_failure = getattr(self, "on_ready_check_failure", None)
|
||||
if fut.cancelled():
|
||||
if callback_failure:
|
||||
await callback_failure()
|
||||
return
|
||||
|
||||
if fut.exception():
|
||||
exc = fut.exception()
|
||||
self.ready_check_error = exc
|
||||
log.error("agent ready check error", agent=self.agent_type, exc=exc)
|
||||
if callback_failure:
|
||||
await callback_failure(exc)
|
||||
return
|
||||
|
||||
callback = getattr(self, "on_ready_check_success", None)
|
||||
if callback:
|
||||
await callback()
|
||||
|
||||
async def setup_check(self):
|
||||
return False
|
||||
|
||||
async def ready_check(self, task: asyncio.Task = None):
|
||||
self.ready_check_error = None
|
||||
if task:
|
||||
task.add_done_callback(
|
||||
lambda fut: asyncio.create_task(self._handle_ready_check(fut))
|
||||
)
|
||||
return
|
||||
return True
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
if self.has_toggle and "enabled" in kwargs:
|
||||
self.is_enabled = kwargs.get("enabled", False)
|
||||
|
||||
|
||||
if not getattr(self, "actions", None):
|
||||
return
|
||||
|
||||
|
||||
for action_key, action in self.actions.items():
|
||||
|
||||
if not kwargs.get("actions"):
|
||||
continue
|
||||
|
||||
action.enabled = kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
|
||||
|
||||
|
||||
action.enabled = (
|
||||
kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
|
||||
)
|
||||
|
||||
if not action.config:
|
||||
continue
|
||||
|
||||
|
||||
for config_key, config in action.config.items():
|
||||
try:
|
||||
config.value = kwargs.get("actions", {}).get(action_key, {}).get("config", {}).get(config_key, {}).get("value", config.value)
|
||||
config.value = (
|
||||
kwargs.get("actions", {})
|
||||
.get(action_key, {})
|
||||
.get("config", {})
|
||||
.get(config_key, {})
|
||||
.get("value", config.value)
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
async def on_game_loop_start(self, event:GameLoopStartEvent):
|
||||
|
||||
|
||||
async def on_game_loop_start(self, event: GameLoopStartEvent):
|
||||
"""
|
||||
Finds all ActionConfigs that have a scope of "scene" and resets them to their default values
|
||||
"""
|
||||
|
||||
|
||||
if not getattr(self, "actions", None):
|
||||
return
|
||||
|
||||
|
||||
for _, action in self.actions.items():
|
||||
if not action.config:
|
||||
continue
|
||||
|
||||
|
||||
for _, config in action.config.items():
|
||||
if config.scope == "scene":
|
||||
# if default_value is None, just use the `type` of the current
|
||||
# if default_value is None, just use the `type` of the current
|
||||
# value
|
||||
if config.default_value is None:
|
||||
default_value = type(config.value)()
|
||||
else:
|
||||
default_value = config.default_value
|
||||
|
||||
log.debug("resetting config", config=config, default_value=default_value)
|
||||
|
||||
log.debug(
|
||||
"resetting config", config=config, default_value=default_value
|
||||
)
|
||||
config.value = default_value
|
||||
|
||||
|
||||
await self.emit_status()
|
||||
|
||||
|
||||
async def emit_status(self, processing: bool = None):
|
||||
|
||||
# should keep a count of processing requests, and when the
|
||||
# number is 0 status is "idle", if the number is greater than 0
|
||||
# status is "busy"
|
||||
#
|
||||
# increase / decrease based on value of `processing`
|
||||
|
||||
|
||||
if getattr(self, "processing", None) is None:
|
||||
self.processing = 0
|
||||
|
||||
if not processing:
|
||||
|
||||
if processing is False:
|
||||
self.processing -= 1
|
||||
self.processing = max(0, self.processing)
|
||||
else:
|
||||
elif processing is True:
|
||||
self.processing += 1
|
||||
|
||||
status = "busy" if self.processing > 0 else "idle"
|
||||
if not self.enabled:
|
||||
status = "disabled"
|
||||
|
||||
|
||||
emit(
|
||||
"agent_status",
|
||||
message=self.verbose_name or "",
|
||||
id=self.agent_type,
|
||||
status=status,
|
||||
status=self.status,
|
||||
details=self.agent_details,
|
||||
meta=self.meta,
|
||||
data=self.config_options(agent=self),
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def _handle_background_processing(self, fut: asyncio.Future):
|
||||
try:
|
||||
if fut.cancelled():
|
||||
return
|
||||
|
||||
if fut.exception():
|
||||
log.error(
|
||||
"background processing error",
|
||||
agent=self.agent_type,
|
||||
exc=fut.exception(),
|
||||
)
|
||||
await self.emit_status()
|
||||
return
|
||||
|
||||
log.info("background processing done", agent=self.agent_type)
|
||||
finally:
|
||||
self.processing_bg -= 1
|
||||
await self.emit_status()
|
||||
|
||||
async def set_background_processing(self, task: asyncio.Task):
|
||||
log.info("set_background_processing", agent=self.agent_type)
|
||||
if not hasattr(self, "processing_bg"):
|
||||
self.processing_bg = 0
|
||||
|
||||
self.processing_bg += 1
|
||||
|
||||
await self.emit_status()
|
||||
task.add_done_callback(
|
||||
lambda fut: asyncio.create_task(self._handle_background_processing(fut))
|
||||
)
|
||||
|
||||
def connect(self, scene):
|
||||
self.scene = scene
|
||||
talemate.emit.async_signals.get("game_loop_start").connect(self.on_game_loop_start)
|
||||
|
||||
talemate.emit.async_signals.get("game_loop_start").connect(
|
||||
self.on_game_loop_start
|
||||
)
|
||||
|
||||
def clean_result(self, result):
|
||||
if "#" in result:
|
||||
@@ -291,23 +399,28 @@ class Agent(ABC):
|
||||
|
||||
current_memory_context.append(memory)
|
||||
return current_memory_context
|
||||
|
||||
|
||||
# LLM client related methods. These are called during or after the client
|
||||
# sends the prompt to the API.
|
||||
|
||||
def inject_prompt_paramters(self, prompt_param:dict, kind:str, agent_function_name:str):
|
||||
def inject_prompt_paramters(
|
||||
self, prompt_param: dict, kind: str, agent_function_name: str
|
||||
):
|
||||
"""
|
||||
Injects prompt parameters before the client sends off the prompt
|
||||
Override as needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def allow_repetition_break(self, kind:str, agent_function_name:str, auto:bool=False):
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
"""
|
||||
Returns True if repetition breaking is allowed, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentEmission:
|
||||
agent: Agent
|
||||
agent: Agent
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""
|
||||
Code has been moved.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
import contextvars
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import pydantic
|
||||
|
||||
__all__ = [
|
||||
@@ -9,25 +9,38 @@ __all__ = [
|
||||
|
||||
active_agent = contextvars.ContextVar("active_agent", default=None)
|
||||
|
||||
|
||||
class ActiveAgentContext(pydantic.BaseModel):
|
||||
agent: object
|
||||
fn: Callable
|
||||
|
||||
agent_stack: list = pydantic.Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed=True
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def action(self):
|
||||
return self.fn.__name__
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.agent.verbose_name}.{self.action}"
|
||||
|
||||
|
||||
class ActiveAgent:
|
||||
|
||||
def __init__(self, agent, fn):
|
||||
self.agent = ActiveAgentContext(agent=agent, fn=fn)
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
previous_agent = active_agent.get()
|
||||
|
||||
if previous_agent:
|
||||
self.agent.agent_stack = previous_agent.agent_stack + [str(self.agent)]
|
||||
else:
|
||||
self.agent.agent_stack = [str(self.agent)]
|
||||
|
||||
self.token = active_agent.set(self.agent)
|
||||
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
active_agent.reset(self.token)
|
||||
return False
|
||||
|
||||
@@ -1,40 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
import random
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.client as client
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
import talemate.util as util
|
||||
import structlog
|
||||
from talemate.client.context import (
|
||||
client_context_attribute,
|
||||
set_client_context_attribute,
|
||||
set_conversation_context_attribute,
|
||||
)
|
||||
from talemate.emit import emit
|
||||
import talemate.emit.async_signals
|
||||
from talemate.scene_message import CharacterMessage, DirectorMessage
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.events import GameLoopEvent
|
||||
from talemate.client.context import set_conversation_context_attribute, client_context_attribute, set_client_context_attribute
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import CharacterMessage, DirectorMessage
|
||||
from talemate.exceptions import LLMAccuracyError
|
||||
|
||||
from .base import Agent, AgentEmission, set_processing, AgentAction, AgentActionConfig
|
||||
from .base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
AgentEmission,
|
||||
set_processing,
|
||||
)
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character, Scene, Actor
|
||||
|
||||
from talemate.tale_mate import Actor, Character, Scene
|
||||
|
||||
log = structlog.get_logger("talemate.agents.conversation")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConversationAgentEmission(AgentEmission):
|
||||
actor: Actor
|
||||
character: Character
|
||||
generation: list[str]
|
||||
|
||||
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.conversation.before_generate",
|
||||
"agent.conversation.generated"
|
||||
"agent.conversation.before_generate", "agent.conversation.generated"
|
||||
)
|
||||
|
||||
|
||||
@register()
|
||||
class ConversationAgent(Agent):
|
||||
"""
|
||||
@@ -45,7 +61,7 @@ class ConversationAgent(Agent):
|
||||
|
||||
agent_type = "conversation"
|
||||
verbose_name = "Conversation"
|
||||
|
||||
|
||||
min_dialogue_length = 75
|
||||
|
||||
def __init__(
|
||||
@@ -60,28 +76,37 @@ class ConversationAgent(Agent):
|
||||
self.logging_enabled = logging_enabled
|
||||
self.logging_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
self.current_memory_context = None
|
||||
|
||||
|
||||
# several agents extend this class, but we only want to initialize
|
||||
# these actions for the conversation agent
|
||||
|
||||
|
||||
if self.agent_type != "conversation":
|
||||
return
|
||||
|
||||
|
||||
self.actions = {
|
||||
"generation_override": AgentAction(
|
||||
enabled = True,
|
||||
label = "Generation Override",
|
||||
description = "Override generation parameters",
|
||||
config = {
|
||||
enabled=True,
|
||||
label="Generation Settings",
|
||||
config={
|
||||
"format": AgentActionConfig(
|
||||
type="text",
|
||||
label="Format",
|
||||
description="The generation format of the scene context, as seen by the AI.",
|
||||
choices=[
|
||||
{"label": "Screenplay", "value": "movie_script"},
|
||||
{"label": "Chat (legacy)", "value": "chat"},
|
||||
],
|
||||
value="movie_script",
|
||||
),
|
||||
"length": AgentActionConfig(
|
||||
type="number",
|
||||
label="Generation Length (tokens)",
|
||||
description="Maximum number of tokens to generate for a conversation response.",
|
||||
value=96,
|
||||
value=128,
|
||||
min=32,
|
||||
max=512,
|
||||
step=32,
|
||||
),#
|
||||
), #
|
||||
"instructions": AgentActionConfig(
|
||||
type="text",
|
||||
label="Instructions",
|
||||
@@ -96,24 +121,24 @@ class ConversationAgent(Agent):
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.1,
|
||||
)
|
||||
}
|
||||
),
|
||||
},
|
||||
),
|
||||
"auto_break_repetition": AgentAction(
|
||||
enabled = True,
|
||||
label = "Auto Break Repetition",
|
||||
description = "Will attempt to automatically break AI repetition.",
|
||||
enabled=True,
|
||||
label="Auto Break Repetition",
|
||||
description="Will attempt to automatically break AI repetition.",
|
||||
),
|
||||
"natural_flow": AgentAction(
|
||||
enabled = True,
|
||||
label = "Natural Flow",
|
||||
description = "Will attempt to generate a more natural flow of conversation between multiple characters.",
|
||||
config = {
|
||||
enabled=True,
|
||||
label="Natural Flow",
|
||||
description="Will attempt to generate a more natural flow of conversation between multiple characters.",
|
||||
config={
|
||||
"max_auto_turns": AgentActionConfig(
|
||||
type="number",
|
||||
label="Max. Auto Turns",
|
||||
description="The maximum number of turns the AI is allowed to generate before it stops and waits for the player to respond.",
|
||||
value=4,
|
||||
value=4,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
@@ -122,72 +147,114 @@ class ConversationAgent(Agent):
|
||||
type="number",
|
||||
label="Max. Idle Turns",
|
||||
description="The maximum number of turns a character can go without speaking before they are considered overdue to speak.",
|
||||
value=8,
|
||||
value=8,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
),
|
||||
}
|
||||
},
|
||||
),
|
||||
"use_long_term_memory": AgentAction(
|
||||
enabled = True,
|
||||
label = "Long Term Memory",
|
||||
description = "Will augment the conversation prompt with long term memory.",
|
||||
config = {
|
||||
enabled=True,
|
||||
label="Long Term Memory",
|
||||
description="Will augment the conversation prompt with long term memory.",
|
||||
config={
|
||||
"retrieval_method": AgentActionConfig(
|
||||
type="text",
|
||||
label="Context Retrieval Method",
|
||||
description="How relevant context is retrieved from the long term memory.",
|
||||
value="direct",
|
||||
choices=[
|
||||
{"label": "Context queries based on recent dialogue (fast)", "value": "direct"},
|
||||
{"label": "Context queries generated by AI", "value": "queries"},
|
||||
{"label": "AI compiled question and answers (slow)", "value": "questions"},
|
||||
]
|
||||
{
|
||||
"label": "Context queries based on recent dialogue (fast)",
|
||||
"value": "direct",
|
||||
},
|
||||
{
|
||||
"label": "Context queries generated by AI",
|
||||
"value": "queries",
|
||||
},
|
||||
{
|
||||
"label": "AI compiled question and answers (slow)",
|
||||
"value": "questions",
|
||||
},
|
||||
],
|
||||
),
|
||||
}
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
def conversation_format(self):
|
||||
if self.actions["generation_override"].enabled:
|
||||
return self.actions["generation_override"].config["format"].value
|
||||
return "movie_script"
|
||||
|
||||
@property
|
||||
def conversation_format_label(self):
|
||||
value = self.conversation_format
|
||||
|
||||
choices = self.actions["generation_override"].config["format"].choices
|
||||
|
||||
for choice in choices:
|
||||
if choice["value"] == value:
|
||||
return choice["label"]
|
||||
|
||||
return value
|
||||
|
||||
@property
|
||||
def agent_details(self) -> dict:
|
||||
|
||||
details = {
|
||||
"client": AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
value=self.client.name if self.client else None,
|
||||
description="The client to use for prompt generation",
|
||||
).model_dump(),
|
||||
"format": AgentDetail(
|
||||
icon="mdi-format-float-none",
|
||||
value=self.conversation_format_label,
|
||||
description="Generation format of the scene context, as seen by the AI",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
return details
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
def last_spoken(self):
|
||||
|
||||
"""
|
||||
Returns the last time each character spoke
|
||||
"""
|
||||
|
||||
|
||||
last_turn = {}
|
||||
turns = 0
|
||||
character_names = self.scene.character_names
|
||||
max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
|
||||
|
||||
|
||||
for idx in range(len(self.scene.history) - 1, -1, -1):
|
||||
|
||||
if isinstance(self.scene.history[idx], CharacterMessage):
|
||||
|
||||
if turns >= max_idle_turns:
|
||||
break
|
||||
|
||||
|
||||
character = self.scene.history[idx].character_name
|
||||
|
||||
|
||||
if character in character_names:
|
||||
last_turn[character] = turns
|
||||
character_names.remove(character)
|
||||
|
||||
|
||||
if not character_names:
|
||||
break
|
||||
|
||||
|
||||
turns += 1
|
||||
|
||||
|
||||
if character_names and turns >= max_idle_turns:
|
||||
for character in character_names:
|
||||
last_turn[character] = max_idle_turns
|
||||
last_turn[character] = max_idle_turns
|
||||
|
||||
return last_turn
|
||||
|
||||
|
||||
def repeated_speaker(self):
|
||||
"""
|
||||
Counts the amount of times the most recent speaker has spoken in a row
|
||||
@@ -203,125 +270,164 @@ class ConversationAgent(Agent):
|
||||
else:
|
||||
break
|
||||
return count
|
||||
|
||||
async def on_game_loop(self, event:GameLoopEvent):
|
||||
|
||||
async def on_game_loop(self, event: GameLoopEvent):
|
||||
await self.apply_natural_flow()
|
||||
|
||||
|
||||
async def apply_natural_flow(self, force: bool = False, npcs_only: bool = False):
|
||||
"""
|
||||
If the natural flow action is enabled, this will attempt to determine
|
||||
the ideal character to talk next.
|
||||
|
||||
|
||||
This will let the AI pick a character to talk to, but if the AI can't figure
|
||||
it out it will apply rules based on max_idle_turns and max_auto_turns.
|
||||
|
||||
|
||||
If all fails it will just pick a random character.
|
||||
|
||||
|
||||
Repetition is also taken into account, so if a character has spoken twice in a row
|
||||
they will not be picked again until someone else has spoken.
|
||||
"""
|
||||
|
||||
|
||||
scene = self.scene
|
||||
|
||||
|
||||
if not scene.auto_progress and not force:
|
||||
# we only apply natural flow if auto_progress is enabled
|
||||
return
|
||||
|
||||
|
||||
if self.actions["natural_flow"].enabled and len(scene.character_names) > 2:
|
||||
|
||||
# last time each character spoke (turns ago)
|
||||
max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
|
||||
max_auto_turns = self.actions["natural_flow"].config["max_auto_turns"].value
|
||||
last_turn = self.last_spoken()
|
||||
player_name = scene.get_player_character().name
|
||||
last_turn_player = last_turn.get(player_name, 0)
|
||||
|
||||
|
||||
if last_turn_player >= max_auto_turns and not npcs_only:
|
||||
self.scene.next_actor = scene.get_player_character().name
|
||||
log.debug("conversation_agent.natural_flow", next_actor="player", overdue=True, player_character=scene.get_player_character().name)
|
||||
log.debug(
|
||||
"conversation_agent.natural_flow",
|
||||
next_actor="player",
|
||||
overdue=True,
|
||||
player_character=scene.get_player_character().name,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
log.debug("conversation_agent.natural_flow", last_turn=last_turn)
|
||||
|
||||
|
||||
# determine random character to talk, this will be the fallback in case
|
||||
# the AI can't figure out who should talk next
|
||||
|
||||
|
||||
if scene.prev_actor:
|
||||
|
||||
# we dont want to talk to the same person twice in a row
|
||||
character_names = scene.character_names
|
||||
character_names.remove(scene.prev_actor)
|
||||
|
||||
|
||||
if npcs_only:
|
||||
character_names = [c for c in character_names if c != player_name]
|
||||
|
||||
|
||||
random_character_name = random.choice(character_names)
|
||||
else:
|
||||
character_names = scene.character_names
|
||||
character_names = scene.character_names
|
||||
# no one has talked yet, so we just pick a random character
|
||||
|
||||
|
||||
if npcs_only:
|
||||
character_names = [c for c in character_names if c != player_name]
|
||||
|
||||
|
||||
random_character_name = random.choice(scene.character_names)
|
||||
|
||||
overdue_characters = [character for character, turn in last_turn.items() if turn >= max_idle_turns]
|
||||
|
||||
|
||||
overdue_characters = [
|
||||
character
|
||||
for character, turn in last_turn.items()
|
||||
if turn >= max_idle_turns
|
||||
]
|
||||
|
||||
if npcs_only:
|
||||
overdue_characters = [c for c in overdue_characters if c != player_name]
|
||||
|
||||
|
||||
if overdue_characters and self.scene.history:
|
||||
# Pick a random character from the overdue characters
|
||||
scene.next_actor = random.choice(overdue_characters)
|
||||
elif scene.history:
|
||||
scene.next_actor = None
|
||||
|
||||
|
||||
# AI will attempt to figure out who should talk next
|
||||
next_actor = await self.select_talking_actor(character_names)
|
||||
next_actor = next_actor.strip().strip('"').strip(".")
|
||||
|
||||
next_actor = next_actor.split("\n")[0].strip().strip('"').strip(".")
|
||||
|
||||
for character_name in scene.character_names:
|
||||
if next_actor.lower() in character_name.lower() or character_name.lower() in next_actor.lower():
|
||||
if (
|
||||
next_actor.lower() in character_name.lower()
|
||||
or character_name.lower() in next_actor.lower()
|
||||
):
|
||||
scene.next_actor = character_name
|
||||
break
|
||||
|
||||
|
||||
if not scene.next_actor:
|
||||
# AI couldn't figure out who should talk next, so we just pick a random character
|
||||
log.debug("conversation_agent.natural_flow", next_actor="random", random_character_name=random_character_name)
|
||||
log.debug(
|
||||
"conversation_agent.natural_flow",
|
||||
next_actor="random",
|
||||
random_character_name=random_character_name,
|
||||
)
|
||||
scene.next_actor = random_character_name
|
||||
else:
|
||||
log.debug("conversation_agent.natural_flow", next_actor="picked", ai_next_actor=scene.next_actor)
|
||||
log.debug(
|
||||
"conversation_agent.natural_flow",
|
||||
next_actor="picked",
|
||||
ai_next_actor=scene.next_actor,
|
||||
)
|
||||
else:
|
||||
# always start with main character (TODO: configurable?)
|
||||
player_character = scene.get_player_character()
|
||||
log.debug("conversation_agent.natural_flow", next_actor="main_character", main_character=player_character)
|
||||
scene.next_actor = player_character.name if player_character else random_character_name
|
||||
|
||||
scene.log.debug("conversation_agent.natural_flow", next_actor=scene.next_actor)
|
||||
log.debug(
|
||||
"conversation_agent.natural_flow",
|
||||
next_actor="main_character",
|
||||
main_character=player_character,
|
||||
)
|
||||
scene.next_actor = (
|
||||
player_character.name if player_character else random_character_name
|
||||
)
|
||||
|
||||
scene.log.debug(
|
||||
"conversation_agent.natural_flow", next_actor=scene.next_actor
|
||||
)
|
||||
|
||||
|
||||
# same character cannot go thrice in a row, if this is happening, pick a random character that
|
||||
# isnt the same as the last character
|
||||
|
||||
if self.repeated_speaker() >= 2 and self.scene.prev_actor == self.scene.next_actor:
|
||||
scene.next_actor = random.choice([c for c in scene.character_names if c != scene.prev_actor])
|
||||
scene.log.debug("conversation_agent.natural_flow", next_actor="random (repeated safeguard)", random_character_name=scene.next_actor)
|
||||
|
||||
|
||||
if (
|
||||
self.repeated_speaker() >= 2
|
||||
and self.scene.prev_actor == self.scene.next_actor
|
||||
):
|
||||
scene.next_actor = random.choice(
|
||||
[c for c in scene.character_names if c != scene.prev_actor]
|
||||
)
|
||||
scene.log.debug(
|
||||
"conversation_agent.natural_flow",
|
||||
next_actor="random (repeated safeguard)",
|
||||
random_character_name=scene.next_actor,
|
||||
)
|
||||
|
||||
else:
|
||||
scene.next_actor = None
|
||||
|
||||
|
||||
@set_processing
|
||||
async def select_talking_actor(self, character_names: list[str]=None):
|
||||
result = await Prompt.request("conversation.select-talking-actor", self.client, "conversation_select_talking_actor", vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character_names": character_names or self.scene.character_names,
|
||||
"character_names_formatted": ", ".join(character_names or self.scene.character_names),
|
||||
})
|
||||
|
||||
async def select_talking_actor(self, character_names: list[str] = None):
|
||||
result = await Prompt.request(
|
||||
"conversation.select-talking-actor",
|
||||
self.client,
|
||||
"conversation_select_talking_actor",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character_names": character_names or self.scene.character_names,
|
||||
"character_names_formatted": ", ".join(
|
||||
character_names or self.scene.character_names
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def build_prompt_default(
|
||||
self,
|
||||
@@ -335,17 +441,17 @@ class ConversationAgent(Agent):
|
||||
# we subtract 200 to account for the response
|
||||
|
||||
scene = character.actor.scene
|
||||
|
||||
|
||||
total_token_budget = self.client.max_token_length - 200
|
||||
scene_and_dialogue_budget = total_token_budget - 500
|
||||
long_term_memory_budget = min(int(total_token_budget * 0.05), 200)
|
||||
|
||||
|
||||
scene_and_dialogue = scene.context_history(
|
||||
budget=scene_and_dialogue_budget,
|
||||
budget=scene_and_dialogue_budget,
|
||||
keep_director=True,
|
||||
sections=False,
|
||||
)
|
||||
|
||||
|
||||
memory = await self.build_prompt_default_memory(character)
|
||||
|
||||
main_character = scene.main_character.character
|
||||
@@ -360,36 +466,41 @@ class ConversationAgent(Agent):
|
||||
)
|
||||
else:
|
||||
formatted_names = character_names[0] if character_names else ""
|
||||
|
||||
|
||||
try:
|
||||
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
|
||||
except IndexError:
|
||||
director_message = False
|
||||
|
||||
|
||||
extra_instructions = ""
|
||||
if self.actions["generation_override"].enabled:
|
||||
extra_instructions = self.actions["generation_override"].config["instructions"].value
|
||||
extra_instructions = (
|
||||
self.actions["generation_override"].config["instructions"].value
|
||||
)
|
||||
|
||||
conversation_format = self.conversation_format
|
||||
prompt = Prompt.get(
|
||||
f"conversation.dialogue-{conversation_format}",
|
||||
vars={
|
||||
"scene": scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene_and_dialogue_budget": scene_and_dialogue_budget,
|
||||
"scene_and_dialogue": scene_and_dialogue,
|
||||
"memory": memory,
|
||||
"characters": list(scene.get_characters()),
|
||||
"main_character": main_character,
|
||||
"formatted_names": formatted_names,
|
||||
"talking_character": character,
|
||||
"partial_message": char_message,
|
||||
"director_message": director_message,
|
||||
"extra_instructions": extra_instructions,
|
||||
"decensor": self.client.decensor_enabled,
|
||||
},
|
||||
)
|
||||
|
||||
prompt = Prompt.get("conversation.dialogue", vars={
|
||||
"scene": scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene_and_dialogue_budget": scene_and_dialogue_budget,
|
||||
"scene_and_dialogue": scene_and_dialogue,
|
||||
"memory": memory,
|
||||
"characters": list(scene.get_characters()),
|
||||
"main_character": main_character,
|
||||
"formatted_names": formatted_names,
|
||||
"talking_character": character,
|
||||
"partial_message": char_message,
|
||||
"director_message": director_message,
|
||||
"extra_instructions": extra_instructions,
|
||||
})
|
||||
|
||||
return str(prompt)
|
||||
|
||||
async def build_prompt_default_memory(
|
||||
self, character: Character
|
||||
):
|
||||
|
||||
async def build_prompt_default_memory(self, character: Character):
|
||||
"""
|
||||
Builds long term memory for the conversation prompt
|
||||
|
||||
@@ -404,39 +515,56 @@ class ConversationAgent(Agent):
|
||||
if not self.actions["use_long_term_memory"].enabled:
|
||||
return []
|
||||
|
||||
|
||||
if self.current_memory_context:
|
||||
return self.current_memory_context
|
||||
|
||||
self.current_memory_context = ""
|
||||
retrieval_method = self.actions["use_long_term_memory"].config["retrieval_method"].value
|
||||
|
||||
retrieval_method = (
|
||||
self.actions["use_long_term_memory"].config["retrieval_method"].value
|
||||
)
|
||||
|
||||
if retrieval_method != "direct":
|
||||
|
||||
world_state = instance.get_agent("world_state")
|
||||
history = self.scene.context_history(min_dialogue=3, max_dialogue=15, keep_director=False, sections=False, add_archieved_history=False)
|
||||
history = self.scene.context_history(
|
||||
min_dialogue=3,
|
||||
max_dialogue=15,
|
||||
keep_director=False,
|
||||
sections=False,
|
||||
add_archieved_history=False,
|
||||
)
|
||||
text = "\n".join(history)
|
||||
log.debug("conversation_agent.build_prompt_default_memory", direct=False, version=retrieval_method)
|
||||
|
||||
log.debug(
|
||||
"conversation_agent.build_prompt_default_memory",
|
||||
direct=False,
|
||||
version=retrieval_method,
|
||||
)
|
||||
|
||||
if retrieval_method == "questions":
|
||||
self.current_memory_context = (await world_state.analyze_text_and_extract_context(
|
||||
text, f"continue the conversation as {character.name}"
|
||||
)).split("\n")
|
||||
self.current_memory_context = (
|
||||
await world_state.analyze_text_and_extract_context(
|
||||
text, f"continue the conversation as {character.name}"
|
||||
)
|
||||
).split("\n")
|
||||
elif retrieval_method == "queries":
|
||||
self.current_memory_context = await world_state.analyze_text_and_extract_context_via_queries(
|
||||
text, f"continue the conversation as {character.name}"
|
||||
self.current_memory_context = (
|
||||
await world_state.analyze_text_and_extract_context_via_queries(
|
||||
text, f"continue the conversation as {character.name}"
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
history = list(map(str, self.scene.collect_messages(max_iterations=3)))
|
||||
log.debug("conversation_agent.build_prompt_default_memory", history=history, direct=True)
|
||||
log.debug(
|
||||
"conversation_agent.build_prompt_default_memory",
|
||||
history=history,
|
||||
direct=True,
|
||||
)
|
||||
memory = instance.get_agent("memory")
|
||||
|
||||
|
||||
context = await memory.multi_query(history, max_tokens=500, iterate=5)
|
||||
|
||||
|
||||
self.current_memory_context = context
|
||||
|
||||
|
||||
return self.current_memory_context
|
||||
|
||||
async def build_prompt(self, character, char_message: str = ""):
|
||||
@@ -445,29 +573,37 @@ class ConversationAgent(Agent):
|
||||
return await fn(character, char_message=char_message)
|
||||
|
||||
def clean_result(self, result, character):
|
||||
|
||||
if "#" in result:
|
||||
result = result.split("#")[0]
|
||||
|
||||
|
||||
if "(Internal" in result:
|
||||
result = result.split("(Internal")[0]
|
||||
|
||||
result = result.replace(" :", ":")
|
||||
result = result.replace("[", "*").replace("]", "*")
|
||||
result = result.replace("(", "*").replace(")", "*")
|
||||
result = result.replace("**", "*")
|
||||
|
||||
result = util.handle_endofline_special_delimiter(result)
|
||||
|
||||
return result
|
||||
|
||||
def set_generation_overrides(self):
|
||||
if not self.actions["generation_override"].enabled:
|
||||
return
|
||||
|
||||
set_conversation_context_attribute("length", self.actions["generation_override"].config["length"].value)
|
||||
|
||||
|
||||
set_conversation_context_attribute(
|
||||
"length", self.actions["generation_override"].config["length"].value
|
||||
)
|
||||
|
||||
if self.actions["generation_override"].config["jiggle"].value > 0.0:
|
||||
nuke_repetition = client_context_attribute("nuke_repetition")
|
||||
if nuke_repetition == 0.0:
|
||||
# we only apply the agent override if some other mechanism isn't already
|
||||
# setting the nuke_repetition value
|
||||
nuke_repetition = self.actions["generation_override"].config["jiggle"].value
|
||||
nuke_repetition = (
|
||||
self.actions["generation_override"].config["jiggle"].value
|
||||
)
|
||||
set_client_context_attribute("nuke_repetition", nuke_repetition)
|
||||
|
||||
@set_processing
|
||||
@@ -479,10 +615,14 @@ class ConversationAgent(Agent):
|
||||
self.current_memory_context = None
|
||||
|
||||
character = actor.character
|
||||
|
||||
emission = ConversationAgentEmission(agent=self, generation="", actor=actor, character=character)
|
||||
await talemate.emit.async_signals.get("agent.conversation.before_generate").send(emission)
|
||||
|
||||
|
||||
emission = ConversationAgentEmission(
|
||||
agent=self, generation="", actor=actor, character=character
|
||||
)
|
||||
await talemate.emit.async_signals.get(
|
||||
"agent.conversation.before_generate"
|
||||
).send(emission)
|
||||
|
||||
self.set_generation_overrides()
|
||||
|
||||
result = await self.client.send_prompt(await self.build_prompt(character))
|
||||
@@ -505,7 +645,7 @@ class ConversationAgent(Agent):
|
||||
|
||||
result = self.clean_result(result, character)
|
||||
|
||||
total_result += " "+result
|
||||
total_result += " " + result
|
||||
|
||||
if len(total_result) == 0 and max_loops < 10:
|
||||
max_loops += 1
|
||||
@@ -522,16 +662,34 @@ class ConversationAgent(Agent):
|
||||
empty_result_count += 1
|
||||
if empty_result_count >= 2:
|
||||
break
|
||||
|
||||
|
||||
# if result is empty, raise an error
|
||||
if not total_result:
|
||||
raise LLMAccuracyError("Received empty response from AI")
|
||||
|
||||
result = result.replace(" :", ":")
|
||||
|
||||
total_result = total_result.split("#")[0]
|
||||
total_result = total_result.split("#")[0].strip()
|
||||
|
||||
total_result = util.handle_endofline_special_delimiter(total_result)
|
||||
|
||||
log.info("conversation agent", total_result=total_result)
|
||||
|
||||
if total_result.startswith(":\n") or total_result.startswith(": "):
|
||||
total_result = total_result[2:]
|
||||
|
||||
# movie script format
|
||||
# {uppercase character name}
|
||||
# {dialogue}
|
||||
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
|
||||
|
||||
# chat format
|
||||
# {character name}: {dialogue}
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Removes partial sentence at the end
|
||||
total_result = util.clean_dialogue(total_result, main_name=character.name)
|
||||
|
||||
# Remove "{character.name}:" - all occurences
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Check if total_result starts with character name, if not, prepend it
|
||||
if not total_result.startswith(character.name):
|
||||
@@ -548,13 +706,17 @@ class ConversationAgent(Agent):
|
||||
)
|
||||
|
||||
response_message = util.parse_messages_from_str(total_result, [character.name])
|
||||
|
||||
log.info("conversation agent", result=response_message)
|
||||
|
||||
emission = ConversationAgentEmission(agent=self, generation=response_message, actor=actor, character=character)
|
||||
await talemate.emit.async_signals.get("agent.conversation.generated").send(emission)
|
||||
|
||||
#log.info("conversation agent", generation=emission.generation)
|
||||
log.info("conversation agent", result=response_message)
|
||||
|
||||
emission = ConversationAgentEmission(
|
||||
agent=self, generation=response_message, actor=actor, character=character
|
||||
)
|
||||
await talemate.emit.async_signals.get("agent.conversation.generated").send(
|
||||
emission
|
||||
)
|
||||
|
||||
# log.info("conversation agent", generation=emission.generation)
|
||||
|
||||
messages = [CharacterMessage(message) for message in emission.generation]
|
||||
|
||||
@@ -563,15 +725,17 @@ class ConversationAgent(Agent):
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def allow_repetition_break(self, kind: str, agent_function_name: str, auto: bool = False):
|
||||
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
if auto and not self.actions["auto_break_repetition"].enabled:
|
||||
return False
|
||||
|
||||
|
||||
return agent_function_name == "converse"
|
||||
|
||||
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
|
||||
|
||||
def inject_prompt_paramters(
|
||||
self, prompt_param: dict, kind: str, agent_function_name: str
|
||||
):
|
||||
if prompt_param.get("extra_stopping_strings") is None:
|
||||
prompt_param["extra_stopping_strings"] = []
|
||||
prompt_param["extra_stopping_strings"] += ['[']
|
||||
prompt_param["extra_stopping_strings"] += ["#"]
|
||||
|
||||
@@ -3,22 +3,23 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
|
||||
import talemate.client as client
|
||||
from talemate.agents.base import Agent, set_processing
|
||||
from talemate.agents.registry import register
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt
|
||||
import talemate.client as client
|
||||
|
||||
from .assistant import AssistantMixin
|
||||
from .character import CharacterCreatorMixin
|
||||
from .scenario import ScenarioCreatorMixin
|
||||
|
||||
|
||||
@register()
|
||||
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
|
||||
|
||||
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, AssistantMixin, Agent):
|
||||
"""
|
||||
Creates characters and scenarios and other fun stuff!
|
||||
"""
|
||||
|
||||
|
||||
agent_type = "creator"
|
||||
verbose_name = "Creator"
|
||||
|
||||
@@ -78,12 +79,14 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
|
||||
# Remove duplicates while preserving the order for list type keys
|
||||
for key, value in merged_data.items():
|
||||
if isinstance(value, list):
|
||||
merged_data[key] = [x for i, x in enumerate(value) if x not in value[:i]]
|
||||
merged_data[key] = [
|
||||
x for i, x in enumerate(value) if x not in value[:i]
|
||||
]
|
||||
|
||||
merged_data["context"] = context
|
||||
|
||||
return merged_data
|
||||
|
||||
|
||||
def load_templates_old(self, names: list, template_type: str = "character") -> dict:
|
||||
"""
|
||||
Loads multiple character creation templates from ./templates/character and merges them in order.
|
||||
@@ -128,8 +131,10 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
|
||||
|
||||
if "context" in template_data["instructions"]:
|
||||
context = template_data["instructions"]["context"]
|
||||
|
||||
merged_instructions[name]["questions"] = [q[0] for q in template_data.get("questions", [])]
|
||||
|
||||
merged_instructions[name]["questions"] = [
|
||||
q[0] for q in template_data.get("questions", [])
|
||||
]
|
||||
|
||||
# Remove duplicates while preserving the order
|
||||
merged_template = [
|
||||
@@ -158,24 +163,33 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
@set_processing
|
||||
async def generate_json_list(
|
||||
self,
|
||||
text:str,
|
||||
count:int=20,
|
||||
first_item:str=None,
|
||||
text: str,
|
||||
count: int = 20,
|
||||
first_item: str = None,
|
||||
):
|
||||
_, json_list = await Prompt.request(f"creator.generate-json-list", self.client, "create", vars={
|
||||
"text": text,
|
||||
"first_item": first_item,
|
||||
"count": count,
|
||||
})
|
||||
return json_list.get("items",[])
|
||||
|
||||
_, json_list = await Prompt.request(
|
||||
f"creator.generate-json-list",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
"text": text,
|
||||
"first_item": first_item,
|
||||
"count": count,
|
||||
},
|
||||
)
|
||||
return json_list.get("items", [])
|
||||
|
||||
@set_processing
|
||||
async def generate_title(self, text:str):
|
||||
title = await Prompt.request(f"creator.generate-title", self.client, "create_short", vars={
|
||||
"text": text,
|
||||
})
|
||||
return title
|
||||
async def generate_title(self, text: str):
|
||||
title = await Prompt.request(
|
||||
f"creator.generate-title",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"text": text,
|
||||
},
|
||||
)
|
||||
return title
|
||||
|
||||
182
src/talemate/agents/creator/assistant.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
import pydantic
|
||||
|
||||
import talemate.util as util
|
||||
from talemate.agents.base import set_processing
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character, Scene
|
||||
|
||||
|
||||
class ContentGenerationContext(pydantic.BaseModel):
|
||||
"""
|
||||
A context for generating content.
|
||||
"""
|
||||
|
||||
context: str
|
||||
instructions: str = ""
|
||||
length: int = 100
|
||||
character: Union[str, None] = None
|
||||
original: Union[str, None] = None
|
||||
partial: str = ""
|
||||
|
||||
@property
|
||||
def computed_context(self) -> Tuple[str, str]:
|
||||
typ, context = self.context.split(":", 1)
|
||||
return typ, context
|
||||
|
||||
|
||||
class AssistantMixin:
|
||||
"""
|
||||
Creator mixin that allows quick contextual generation of content.
|
||||
"""
|
||||
|
||||
async def contextual_generate_from_args(
|
||||
self,
|
||||
context: str,
|
||||
instructions: str = "",
|
||||
length: int = 100,
|
||||
character: Union[str, None] = None,
|
||||
original: Union[str, None] = None,
|
||||
partial: str = "",
|
||||
):
|
||||
"""
|
||||
Request content from the assistant.
|
||||
"""
|
||||
|
||||
generation_context = ContentGenerationContext(
|
||||
context=context,
|
||||
instructions=instructions,
|
||||
length=length,
|
||||
character=character,
|
||||
original=original,
|
||||
partial=partial,
|
||||
)
|
||||
|
||||
return await self.contextual_generate(generation_context)
|
||||
|
||||
contextual_generate_from_args.exposed = True
|
||||
|
||||
@set_processing
|
||||
async def contextual_generate(
|
||||
self,
|
||||
generation_context: ContentGenerationContext,
|
||||
):
|
||||
"""
|
||||
Request content from the assistant.
|
||||
"""
|
||||
|
||||
context_typ, context_name = generation_context.computed_context
|
||||
|
||||
if generation_context.length < 100:
|
||||
kind = "create_short"
|
||||
elif generation_context.length < 500:
|
||||
kind = "create_concise"
|
||||
else:
|
||||
kind = "create"
|
||||
|
||||
content = await Prompt.request(
|
||||
f"creator.contextual-generate",
|
||||
self.client,
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"generation_context": generation_context,
|
||||
"context_typ": context_typ,
|
||||
"context_name": context_name,
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
"character": (
|
||||
self.scene.get_character(generation_context.character)
|
||||
if generation_context.character
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
if not generation_context.partial:
|
||||
content = util.strip_partial_sentences(content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
@set_processing
|
||||
async def autocomplete_dialogue(
|
||||
self,
|
||||
input: str,
|
||||
character: "Character",
|
||||
emit_signal: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Autocomplete dialogue.
|
||||
"""
|
||||
|
||||
response = await Prompt.request(
|
||||
f"creator.autocomplete-dialogue",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"input": input.strip(),
|
||||
"character": character,
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
},
|
||||
pad_prepended_response=False,
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
response = util.clean_dialogue(response, character.name)[
|
||||
len(character.name + ":") :
|
||||
].strip()
|
||||
|
||||
if response.startswith(input):
|
||||
response = response[len(input) :]
|
||||
|
||||
self.scene.log.debug(
|
||||
"autocomplete_suggestion", suggestion=response, input=input
|
||||
)
|
||||
|
||||
if emit_signal:
|
||||
emit("autocomplete_suggestion", response)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def autocomplete_narrative(
|
||||
self,
|
||||
input: str,
|
||||
emit_signal: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Autocomplete narrative.
|
||||
"""
|
||||
|
||||
response = await Prompt.request(
|
||||
f"creator.autocomplete-narrative",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"input": input.strip(),
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
},
|
||||
pad_prepended_response=False,
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
if response.startswith(input):
|
||||
response = response[len(input) :]
|
||||
|
||||
self.scene.log.debug(
|
||||
"autocomplete_suggestion", suggestion=response, input=input
|
||||
)
|
||||
|
||||
if emit_signal:
|
||||
emit("autocomplete_suggestion", response)
|
||||
|
||||
return response
|
||||
@@ -1,42 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import random
|
||||
import structlog
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.util as util
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt, LoopedPrompt
|
||||
from talemate.exceptions import LLMAccuracyError
|
||||
from talemate.agents.base import set_processing
|
||||
from talemate.emit import emit
|
||||
from talemate.exceptions import LLMAccuracyError
|
||||
from talemate.prompts import LoopedPrompt, Prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character
|
||||
|
||||
log = structlog.get_logger("talemate.agents.creator.character")
|
||||
|
||||
def validate(k,v):
|
||||
|
||||
def validate(k, v):
|
||||
if k and k.lower() == "gender":
|
||||
return v.lower().strip()
|
||||
if k and k.lower() == "age":
|
||||
try:
|
||||
return int(v.split("\n")[0].strip())
|
||||
except (ValueError, TypeError):
|
||||
raise LLMAccuracyError("Was unable to get a valid age from the response", model_name=None)
|
||||
|
||||
raise LLMAccuracyError(
|
||||
"Was unable to get a valid age from the response", model_name=None
|
||||
)
|
||||
|
||||
return v.strip().strip("\n")
|
||||
|
||||
DEFAULT_CONTENT_CONTEXT="a fun and engaging adventure aimed at an adult audience."
|
||||
|
||||
DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audience."
|
||||
|
||||
|
||||
class CharacterCreatorMixin:
|
||||
"""
|
||||
Adds character creation functionality to the creator agent
|
||||
"""
|
||||
|
||||
|
||||
## NEW
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_character_attributes(
|
||||
self,
|
||||
@@ -48,8 +54,6 @@ class CharacterCreatorMixin:
|
||||
custom_attributes: dict[str, str] = dict(),
|
||||
predefined_attributes: dict[str, str] = dict(),
|
||||
):
|
||||
|
||||
|
||||
def spice(prompt, spices):
|
||||
# generate number from 0 to 1 and if its smaller than use_spice
|
||||
# select a random spice from the list and return it formatted
|
||||
@@ -57,69 +61,74 @@ class CharacterCreatorMixin:
|
||||
if random.random() < use_spice:
|
||||
spice = random.choice(spices)
|
||||
return prompt.format(spice=spice)
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
# drop any empty attributes from predefined_attributes
|
||||
|
||||
predefined_attributes = {k:v for k,v in predefined_attributes.items() if v}
|
||||
|
||||
prompt = Prompt.get(f"creator.character-attributes-{template}", vars={
|
||||
"character_prompt": character_prompt,
|
||||
"template": template,
|
||||
"spice": spice,
|
||||
"content_context": content_context,
|
||||
"custom_attributes": custom_attributes,
|
||||
"character_sheet": LoopedPrompt(
|
||||
validate_value=validate,
|
||||
on_update=attribute_callback,
|
||||
generated=predefined_attributes,
|
||||
),
|
||||
})
|
||||
|
||||
predefined_attributes = {k: v for k, v in predefined_attributes.items() if v}
|
||||
|
||||
prompt = Prompt.get(
|
||||
f"creator.character-attributes-{template}",
|
||||
vars={
|
||||
"character_prompt": character_prompt,
|
||||
"template": template,
|
||||
"spice": spice,
|
||||
"content_context": content_context,
|
||||
"custom_attributes": custom_attributes,
|
||||
"character_sheet": LoopedPrompt(
|
||||
validate_value=validate,
|
||||
on_update=attribute_callback,
|
||||
generated=predefined_attributes,
|
||||
),
|
||||
},
|
||||
)
|
||||
await prompt.loop(self.client, "character_sheet", kind="create_concise")
|
||||
|
||||
|
||||
return prompt.vars["character_sheet"].generated
|
||||
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_character_description(
|
||||
self,
|
||||
character:Character,
|
||||
self,
|
||||
character: Character,
|
||||
content_context: str = DEFAULT_CONTENT_CONTEXT,
|
||||
):
|
||||
|
||||
description = await Prompt.request(f"creator.character-description", self.client, "create", vars={
|
||||
"character": character,
|
||||
"content_context": content_context,
|
||||
})
|
||||
|
||||
description = await Prompt.request(
|
||||
f"creator.character-description",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
"character": character,
|
||||
"content_context": content_context,
|
||||
},
|
||||
)
|
||||
|
||||
return description.strip()
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_character_details(
|
||||
self,
|
||||
self,
|
||||
character: Character,
|
||||
template: str,
|
||||
detail_callback: Callable = lambda question, answer: None,
|
||||
questions: list[str] = None,
|
||||
content_context: str = DEFAULT_CONTENT_CONTEXT,
|
||||
):
|
||||
prompt = Prompt.get(f"creator.character-details-{template}", vars={
|
||||
"character_details": LoopedPrompt(
|
||||
validate_value=validate,
|
||||
on_update=detail_callback,
|
||||
),
|
||||
"template": template,
|
||||
"content_context": content_context,
|
||||
"character": character,
|
||||
"custom_questions": questions or [],
|
||||
})
|
||||
prompt = Prompt.get(
|
||||
f"creator.character-details-{template}",
|
||||
vars={
|
||||
"character_details": LoopedPrompt(
|
||||
validate_value=validate,
|
||||
on_update=detail_callback,
|
||||
),
|
||||
"template": template,
|
||||
"content_context": content_context,
|
||||
"character": character,
|
||||
"custom_questions": questions or [],
|
||||
},
|
||||
)
|
||||
await prompt.loop(self.client, "character_details", kind="create_concise")
|
||||
return prompt.vars["character_details"].generated
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_character_example_dialogue(
|
||||
self,
|
||||
@@ -131,97 +140,156 @@ class CharacterCreatorMixin:
|
||||
example_callback: Callable = lambda example: None,
|
||||
rules_callback: Callable = lambda rules: None,
|
||||
):
|
||||
|
||||
dialogue_rules = await Prompt.request(f"creator.character-dialogue-rules", self.client, "create", vars={
|
||||
"guide": guide,
|
||||
"character": character,
|
||||
"examples": examples or [],
|
||||
"content_context": content_context,
|
||||
})
|
||||
|
||||
dialogue_rules = await Prompt.request(
|
||||
f"creator.character-dialogue-rules",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
"guide": guide,
|
||||
"character": character,
|
||||
"examples": examples or [],
|
||||
"content_context": content_context,
|
||||
},
|
||||
)
|
||||
|
||||
log.info("dialogue_rules", dialogue_rules=dialogue_rules)
|
||||
|
||||
|
||||
if rules_callback:
|
||||
rules_callback(dialogue_rules)
|
||||
|
||||
example_dialogue_prompt = Prompt.get(f"creator.character-example-dialogue-{template}", vars={
|
||||
"guide": guide,
|
||||
"character": character,
|
||||
"examples": examples or [],
|
||||
"content_context": content_context,
|
||||
"dialogue_rules": dialogue_rules,
|
||||
"generated_examples": LoopedPrompt(
|
||||
validate_value=validate,
|
||||
on_update=example_callback,
|
||||
),
|
||||
})
|
||||
|
||||
await example_dialogue_prompt.loop(self.client, "generated_examples", kind="create")
|
||||
|
||||
|
||||
example_dialogue_prompt = Prompt.get(
|
||||
f"creator.character-example-dialogue-{template}",
|
||||
vars={
|
||||
"guide": guide,
|
||||
"character": character,
|
||||
"examples": examples or [],
|
||||
"content_context": content_context,
|
||||
"dialogue_rules": dialogue_rules,
|
||||
"generated_examples": LoopedPrompt(
|
||||
validate_value=validate,
|
||||
on_update=example_callback,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
await example_dialogue_prompt.loop(
|
||||
self.client, "generated_examples", kind="create"
|
||||
)
|
||||
|
||||
return example_dialogue_prompt.vars["generated_examples"].generated
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def determine_content_context_for_character(
|
||||
self,
|
||||
character: Character,
|
||||
):
|
||||
|
||||
content_context = await Prompt.request(f"creator.determine-content-context", self.client, "create", vars={
|
||||
"character": character,
|
||||
})
|
||||
content_context = await Prompt.request(
|
||||
f"creator.determine-content-context",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
"character": character,
|
||||
},
|
||||
)
|
||||
return content_context.strip()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def determine_character_dialogue_instructions(
|
||||
self,
|
||||
character: Character,
|
||||
):
|
||||
instructions = await Prompt.request(
|
||||
f"creator.determine-character-dialogue-instructions",
|
||||
self.client,
|
||||
"create_concise",
|
||||
vars={
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
r = instructions.strip().split("\n")[0].strip('"').strip()
|
||||
return r
|
||||
|
||||
@set_processing
|
||||
async def determine_character_attributes(
|
||||
self,
|
||||
character: Character,
|
||||
):
|
||||
|
||||
attributes = await Prompt.request(f"creator.determine-character-attributes", self.client, "analyze_long", vars={
|
||||
"character": character,
|
||||
})
|
||||
attributes = await Prompt.request(
|
||||
f"creator.determine-character-attributes",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars={
|
||||
"character": character,
|
||||
},
|
||||
)
|
||||
return attributes
|
||||
|
||||
|
||||
@set_processing
|
||||
async def determine_character_name(
|
||||
self,
|
||||
character_name: str,
|
||||
allowed_names: list[str] = None,
|
||||
group: bool = False,
|
||||
) -> str:
|
||||
name = await Prompt.request(
|
||||
f"creator.determine-character-name",
|
||||
self.client,
|
||||
"analyze_freeform_short",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character_name": character_name,
|
||||
"allowed_names": allowed_names or [],
|
||||
"group": group,
|
||||
},
|
||||
)
|
||||
return name.split('"', 1)[0].strip().strip(".").strip()
|
||||
|
||||
@set_processing
|
||||
async def determine_character_description(
|
||||
self,
|
||||
character: Character,
|
||||
text:str=""
|
||||
self, character: Character, text: str = ""
|
||||
):
|
||||
|
||||
description = await Prompt.request(f"creator.determine-character-description", self.client, "create", vars={
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"text": text,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
})
|
||||
description = await Prompt.request(
|
||||
f"creator.determine-character-description",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"text": text,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
return description.strip()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def determine_character_goals(
|
||||
self,
|
||||
character: Character,
|
||||
goal_instructions: str,
|
||||
):
|
||||
|
||||
goals = await Prompt.request(f"creator.determine-character-goals", self.client, "create", vars={
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"goal_instructions": goal_instructions,
|
||||
"npc_name": character.name,
|
||||
"player_name": self.scene.get_player_character().name,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
})
|
||||
|
||||
goals = await Prompt.request(
|
||||
f"creator.determine-character-goals",
|
||||
self.client,
|
||||
"create",
|
||||
vars={
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"goal_instructions": goal_instructions,
|
||||
"npc_name": character.name,
|
||||
"player_name": self.scene.get_player_character().name,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("determine_character_goals", goals=goals, character=character)
|
||||
await character.set_detail("goals", goals.strip())
|
||||
|
||||
|
||||
return goals.strip()
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def generate_character_from_text(
|
||||
self,
|
||||
@@ -229,11 +297,8 @@ class CharacterCreatorMixin:
|
||||
template: str,
|
||||
content_context: str = DEFAULT_CONTENT_CONTEXT,
|
||||
):
|
||||
|
||||
base_attributes = await self.create_character_attributes(
|
||||
character_prompt=text,
|
||||
template=template,
|
||||
content_context=content_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,36 +1,35 @@
|
||||
from talemate.emit import emit, wait_for_input_yesno
|
||||
import re
|
||||
import random
|
||||
import re
|
||||
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.agents.base import set_processing
|
||||
from talemate.emit import emit, wait_for_input_yesno
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
|
||||
class ScenarioCreatorMixin:
|
||||
|
||||
"""
|
||||
Adds scenario creation functionality to the creator agent
|
||||
"""
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_scene_description(
|
||||
self,
|
||||
prompt:str,
|
||||
content_context:str,
|
||||
prompt: str,
|
||||
content_context: str,
|
||||
):
|
||||
|
||||
"""
|
||||
Creates a new scene.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
prompt (str): The prompt to use to create the scene.
|
||||
|
||||
|
||||
content_context (str): The content context to use for the scene.
|
||||
|
||||
|
||||
callback (callable): A callback to call when the scene has been created.
|
||||
"""
|
||||
scene = self.scene
|
||||
|
||||
|
||||
description = await Prompt.request(
|
||||
"creator.scenario-description",
|
||||
self.client,
|
||||
@@ -40,35 +39,32 @@ class ScenarioCreatorMixin:
|
||||
"content_context": content_context,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": scene,
|
||||
}
|
||||
},
|
||||
)
|
||||
description = description.strip()
|
||||
|
||||
|
||||
return description
|
||||
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_scene_name(
|
||||
self,
|
||||
prompt:str,
|
||||
content_context:str,
|
||||
description:str,
|
||||
prompt: str,
|
||||
content_context: str,
|
||||
description: str,
|
||||
):
|
||||
|
||||
"""
|
||||
Generates a scene name.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
prompt (str): The prompt to use to generate the scene name.
|
||||
|
||||
|
||||
content_context (str): The content context to use for the scene.
|
||||
|
||||
|
||||
description (str): The description of the scene.
|
||||
"""
|
||||
scene = self.scene
|
||||
|
||||
|
||||
name = await Prompt.request(
|
||||
"creator.scenario-name",
|
||||
self.client,
|
||||
@@ -78,37 +74,35 @@ class ScenarioCreatorMixin:
|
||||
"content_context": content_context,
|
||||
"description": description,
|
||||
"scene": scene,
|
||||
}
|
||||
},
|
||||
)
|
||||
name = name.strip().strip('.!').replace('"','')
|
||||
name = name.strip().strip(".!").replace('"', "")
|
||||
return name
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def create_scene_intro(
|
||||
self,
|
||||
prompt:str,
|
||||
content_context:str,
|
||||
description:str,
|
||||
name:str,
|
||||
prompt: str,
|
||||
content_context: str,
|
||||
description: str,
|
||||
name: str,
|
||||
):
|
||||
|
||||
"""
|
||||
Generates a scene introduction.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
prompt (str): The prompt to use to generate the scene introduction.
|
||||
|
||||
|
||||
content_context (str): The content context to use for the scene.
|
||||
|
||||
|
||||
description (str): The description of the scene.
|
||||
|
||||
|
||||
name (str): The name of the scene.
|
||||
"""
|
||||
|
||||
|
||||
scene = self.scene
|
||||
|
||||
|
||||
intro = await Prompt.request(
|
||||
"creator.scenario-intro",
|
||||
self.client,
|
||||
@@ -119,17 +113,34 @@ class ScenarioCreatorMixin:
|
||||
"description": description,
|
||||
"name": name,
|
||||
"scene": scene,
|
||||
}
|
||||
},
|
||||
)
|
||||
intro = intro.strip()
|
||||
return intro
|
||||
|
||||
|
||||
@set_processing
|
||||
async def determine_scenario_description(
|
||||
async def determine_scenario_description(self, text: str):
|
||||
description = await Prompt.request(
|
||||
f"creator.determine-scenario-description",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars={
|
||||
"text": text,
|
||||
},
|
||||
)
|
||||
return description.strip()
|
||||
|
||||
@set_processing
|
||||
async def determine_content_context_for_description(
|
||||
self,
|
||||
text:str
|
||||
description: str,
|
||||
):
|
||||
description = await Prompt.request(f"creator.determine-scenario-description", self.client, "analyze_long", vars={
|
||||
"text": text,
|
||||
})
|
||||
return description
|
||||
content_context = await Prompt.request(
|
||||
f"creator.determine-content-context",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
return content_context.lstrip().split("\n")[0].strip('"').strip()
|
||||
|
||||
34
src/talemate/agents/custom/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger("talemate.agents.custom")
|
||||
|
||||
# import every submodule in this directory
|
||||
#
|
||||
# each directory in this directory is a submodule
|
||||
|
||||
# get the current directory
|
||||
current_directory = os.path.dirname(__file__)
|
||||
|
||||
# get all subdirectories
|
||||
subdirectories = [
|
||||
os.path.join(current_directory, name)
|
||||
for name in os.listdir(current_directory)
|
||||
if os.path.isdir(os.path.join(current_directory, name))
|
||||
]
|
||||
|
||||
# import every submodule
|
||||
|
||||
for subdirectory in subdirectories:
|
||||
# get the name of the submodule
|
||||
submodule_name = os.path.basename(subdirectory)
|
||||
|
||||
if submodule_name.startswith("__"):
|
||||
continue
|
||||
|
||||
log.info("activating custom agent", module=submodule_name)
|
||||
|
||||
# import the submodule
|
||||
importlib.import_module(f".{submodule_name}", __package__)
|
||||
@@ -0,0 +1,5 @@
|
||||
Each agent should be in its own subdirectory.
|
||||
|
||||
The subdirectory itself must be a valid python module.
|
||||
|
||||
Check out docs/dev/agents/example/test for a very simplistic custom agent example.
|
||||
@@ -1,227 +1,361 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import random
|
||||
import structlog
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import talemate.util as util
|
||||
from talemate.emit import wait_for_input, emit
|
||||
import talemate.emit.async_signals
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import NarratorMessage, DirectorMessage
|
||||
from talemate.automated_action import AutomatedAction
|
||||
import structlog
|
||||
|
||||
import talemate.automated_action as automated_action
|
||||
from talemate.agents.conversation import ConversationAgentEmission
|
||||
from .registry import register
|
||||
from .base import set_processing, AgentAction, AgentActionConfig, Agent
|
||||
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
|
||||
import talemate.emit.async_signals
|
||||
import talemate.instance as instance
|
||||
import talemate.util as util
|
||||
from talemate.agents.conversation import ConversationAgentEmission
|
||||
from talemate.automated_action import AutomatedAction
|
||||
from talemate.emit import emit, wait_for_input
|
||||
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
|
||||
from talemate.game.engine import GameInstructionsMixin
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import DirectorMessage, NarratorMessage
|
||||
|
||||
from .base import Agent, AgentAction, AgentActionConfig, set_processing
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate import Actor, Character, Player, Scene
|
||||
|
||||
log = structlog.get_logger("talemate.agent.director")
|
||||
|
||||
|
||||
@register()
|
||||
class DirectorAgent(Agent):
|
||||
class DirectorAgent(GameInstructionsMixin, Agent):
|
||||
agent_type = "director"
|
||||
verbose_name = "Director"
|
||||
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
self.is_enabled = True
|
||||
self.client = client
|
||||
self.next_direct_character = {}
|
||||
self.next_direct_scene = 0
|
||||
self.actions = {
|
||||
"direct": AgentAction(enabled=True, label="Direct", description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).", config={
|
||||
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before directing the sceen", value=5, min=1, max=100, step=1),
|
||||
"direct_scene": AgentActionConfig(type="bool", label="Direct Scene", description="If enabled, the scene will be directed through narration", value=True),
|
||||
"direct_actors": AgentActionConfig(type="bool", label="Direct Actors", description="If enabled, direction will be given to actors based on their goals.", value=True),
|
||||
}),
|
||||
"direct": AgentAction(
|
||||
enabled=True,
|
||||
label="Direct",
|
||||
description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).",
|
||||
config={
|
||||
"turns": AgentActionConfig(
|
||||
type="number",
|
||||
label="Turns",
|
||||
description="Number of turns to wait before directing the sceen",
|
||||
value=5,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
),
|
||||
"direct_scene": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Direct Scene",
|
||||
description="If enabled, the scene will be directed through narration",
|
||||
value=True,
|
||||
),
|
||||
"direct_actors": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Direct Actors",
|
||||
description="If enabled, direction will be given to actors based on their goals.",
|
||||
value=True,
|
||||
),
|
||||
"actor_direction_mode": AgentActionConfig(
|
||||
type="text",
|
||||
label="Actor Direction Mode",
|
||||
description="The mode to use when directing actors",
|
||||
value="direction",
|
||||
choices=[
|
||||
{
|
||||
"label": "Direction",
|
||||
"value": "direction",
|
||||
},
|
||||
{
|
||||
"label": "Inner Monologue",
|
||||
"value": "internal_monologue",
|
||||
},
|
||||
],
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def direct_enabled(self):
|
||||
return self.actions["direct"].enabled
|
||||
|
||||
@property
|
||||
def direct_actors_enabled(self):
|
||||
return self.actions["direct"].config["direct_actors"].value
|
||||
|
||||
@property
|
||||
def direct_scene_enabled(self):
|
||||
return self.actions["direct"].config["direct_scene"].value
|
||||
|
||||
@property
|
||||
def actor_direction_mode(self):
|
||||
return self.actions["direct"].config["actor_direction_mode"].value
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(self.on_conversation_before_generate)
|
||||
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_player_dialog)
|
||||
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(
|
||||
self.on_conversation_before_generate
|
||||
)
|
||||
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
|
||||
self.on_player_dialog
|
||||
)
|
||||
talemate.emit.async_signals.get("scene_init").connect(self.on_scene_init)
|
||||
|
||||
|
||||
async def on_scene_init(self, event: SceneStateEvent):
|
||||
"""
|
||||
If game state instructions specify to be run at the start of the game loop
|
||||
we will run them here.
|
||||
"""
|
||||
|
||||
|
||||
if not self.enabled:
|
||||
if self.scene.game_state.has_scene_instructions:
|
||||
self.is_enabled = True
|
||||
log.warning("on_scene_init - enabling director", scene=self.scene)
|
||||
if await self.scene_has_instructions(self.scene):
|
||||
self.is_enabled = True
|
||||
log.warning("on_scene_init - enabling director", scene=self.scene)
|
||||
else:
|
||||
return
|
||||
|
||||
if not self.scene.game_state.has_scene_instructions:
|
||||
|
||||
if not await self.scene_has_instructions(self.scene):
|
||||
return
|
||||
|
||||
if not self.scene.game_state.ops.run_on_start:
|
||||
return
|
||||
|
||||
|
||||
log.info("on_game_loop_start - running game state instructions")
|
||||
await self.run_gamestate_instructions()
|
||||
|
||||
async def on_conversation_before_generate(self, event:ConversationAgentEmission):
|
||||
|
||||
async def on_conversation_before_generate(self, event: ConversationAgentEmission):
|
||||
log.info("on_conversation_before_generate", director_enabled=self.enabled)
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
await self.direct(event.character)
|
||||
|
||||
async def on_player_dialog(self, event:GameLoopActorIterEvent):
|
||||
|
||||
|
||||
async def on_player_dialog(self, event: GameLoopActorIterEvent):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if not self.scene.game_state.has_scene_instructions:
|
||||
|
||||
if not await self.scene_has_instructions(self.scene):
|
||||
return
|
||||
|
||||
if not event.actor.character.is_player:
|
||||
return
|
||||
|
||||
|
||||
if event.game_loop.had_passive_narration:
|
||||
log.debug("director.on_player_dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
|
||||
log.debug(
|
||||
"director.on_player_dialog",
|
||||
skip=True,
|
||||
had_passive_narration=event.game_loop.had_passive_narration,
|
||||
)
|
||||
return
|
||||
|
||||
event.game_loop.had_passive_narration = await self.direct(None)
|
||||
|
||||
|
||||
async def direct(self, character: Character) -> bool:
|
||||
|
||||
if not self.actions["direct"].enabled:
|
||||
return False
|
||||
|
||||
|
||||
if character:
|
||||
|
||||
if not self.actions["direct"].config["direct_actors"].value:
|
||||
log.info("direct", skip=True, reason="direct_actors disabled", character=character)
|
||||
log.info(
|
||||
"direct",
|
||||
skip=True,
|
||||
reason="direct_actors disabled",
|
||||
character=character,
|
||||
)
|
||||
return False
|
||||
|
||||
# character direction, see if there are character goals
|
||||
|
||||
# character direction, see if there are character goals
|
||||
# defined
|
||||
character_goals = character.get_detail("goals")
|
||||
if not character_goals:
|
||||
log.info("direct", skip=True, reason="no goals", character=character)
|
||||
return False
|
||||
|
||||
|
||||
next_direct = self.next_direct_character.get(character.name, 0)
|
||||
|
||||
if next_direct % self.actions["direct"].config["turns"].value != 0 or next_direct == 0:
|
||||
log.info("direct", skip=True, next_direct=next_direct, character=character)
|
||||
|
||||
if (
|
||||
next_direct % self.actions["direct"].config["turns"].value != 0
|
||||
or next_direct == 0
|
||||
):
|
||||
log.info(
|
||||
"direct", skip=True, next_direct=next_direct, character=character
|
||||
)
|
||||
self.next_direct_character[character.name] = next_direct + 1
|
||||
return False
|
||||
|
||||
|
||||
self.next_direct_character[character.name] = 0
|
||||
await self.direct_scene(character, character_goals)
|
||||
return True
|
||||
else:
|
||||
|
||||
if not self.actions["direct"].config["direct_scene"].value:
|
||||
log.info("direct", skip=True, reason="direct_scene disabled")
|
||||
return False
|
||||
|
||||
|
||||
# no character, see if there are NPC characters at all
|
||||
# if not we always want to direct narration
|
||||
always_direct = (not self.scene.npc_character_names)
|
||||
|
||||
always_direct = (
|
||||
not self.scene.npc_character_names
|
||||
or self.scene.game_state.ops.always_direct
|
||||
)
|
||||
|
||||
next_direct = self.next_direct_scene
|
||||
|
||||
if next_direct % self.actions["direct"].config["turns"].value != 0 or next_direct == 0:
|
||||
|
||||
if (
|
||||
next_direct % self.actions["direct"].config["turns"].value != 0
|
||||
or next_direct == 0
|
||||
):
|
||||
if not always_direct:
|
||||
log.info("direct", skip=True, next_direct=next_direct)
|
||||
self.next_direct_scene += 1
|
||||
return False
|
||||
|
||||
|
||||
self.next_direct_scene = 0
|
||||
await self.direct_scene(None, None)
|
||||
return True
|
||||
|
||||
|
||||
@set_processing
|
||||
async def run_gamestate_instructions(self):
|
||||
"""
|
||||
Run game state instructions, if they exist.
|
||||
"""
|
||||
|
||||
if not self.scene.game_state.has_scene_instructions:
|
||||
|
||||
if not await self.scene_has_instructions(self.scene):
|
||||
return
|
||||
|
||||
|
||||
await self.direct_scene(None, None)
|
||||
|
||||
|
||||
@set_processing
|
||||
async def direct_scene(self, character: Character, prompt:str):
|
||||
|
||||
async def direct_scene(self, character: Character, prompt: str):
|
||||
if not character and self.scene.game_state.game_won:
|
||||
# we are not directing a character, and the game has been won
|
||||
# so we don't need to direct the scene any further
|
||||
return
|
||||
|
||||
|
||||
if character:
|
||||
|
||||
# direct a character
|
||||
|
||||
response = await Prompt.request("director.direct-character", self.client, "director", vars={
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": self.scene,
|
||||
"prompt": prompt,
|
||||
"character": character,
|
||||
"player_character": self.scene.get_player_character(),
|
||||
"game_state": self.scene.game_state,
|
||||
})
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
"director.direct-character",
|
||||
self.client,
|
||||
"director",
|
||||
vars={
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": self.scene,
|
||||
"prompt": prompt,
|
||||
"character": character,
|
||||
"player_character": self.scene.get_player_character(),
|
||||
"game_state": self.scene.game_state,
|
||||
},
|
||||
)
|
||||
|
||||
if "#" in response:
|
||||
response = response.split("#")[0]
|
||||
|
||||
log.info("direct_character", character=character, prompt=prompt, response=response)
|
||||
|
||||
|
||||
log.info(
|
||||
"direct_character",
|
||||
character=character,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
)
|
||||
|
||||
response = response.strip().split("\n")[0].strip()
|
||||
#response += f" (current story goal: {prompt})"
|
||||
# response += f" (current story goal: {prompt})"
|
||||
message = DirectorMessage(response, source=character.name)
|
||||
emit("director", message, character=character)
|
||||
self.scene.push_history(message)
|
||||
else:
|
||||
# run scene instructions
|
||||
self.scene.game_state.scene_instructions
|
||||
await self.run_scene_instructions(self.scene)
|
||||
|
||||
@set_processing
|
||||
async def persist_characters_from_worldstate(
|
||||
self, exclude: list[str] = None
|
||||
) -> List[Character]:
|
||||
log.warning(
|
||||
"persist_characters_from_worldstate",
|
||||
world_state_characters=self.scene.world_state.characters,
|
||||
scene_characters=self.scene.character_names,
|
||||
)
|
||||
|
||||
created_characters = []
|
||||
|
||||
for character_name in self.scene.world_state.characters.keys():
|
||||
|
||||
if exclude and character_name.lower() in exclude:
|
||||
continue
|
||||
|
||||
if character_name in self.scene.character_names:
|
||||
continue
|
||||
|
||||
character = await self.persist_character(name=character_name)
|
||||
|
||||
created_characters.append(character)
|
||||
|
||||
self.scene.emit_status()
|
||||
|
||||
return created_characters
|
||||
|
||||
@set_processing
|
||||
async def persist_character(
|
||||
self,
|
||||
name:str,
|
||||
content:str = None,
|
||||
attributes:str = None,
|
||||
self,
|
||||
name: str,
|
||||
content: str = None,
|
||||
attributes: str = None,
|
||||
determine_name: bool = True,
|
||||
):
|
||||
|
||||
world_state = instance.get_agent("world_state")
|
||||
creator = instance.get_agent("creator")
|
||||
|
||||
self.scene.log.debug("persist_character", name=name)
|
||||
|
||||
if determine_name:
|
||||
name = await creator.determine_character_name(name)
|
||||
self.scene.log.debug("persist_character", adjusted_name=name)
|
||||
|
||||
character = self.scene.Character(name=name)
|
||||
character.color = random.choice(['#F08080', '#FFD700', '#90EE90', '#ADD8E6', '#DDA0DD', '#FFB6C1', '#FAFAD2', '#D3D3D3', '#B0E0E6', '#FFDEAD'])
|
||||
character.color = random.choice(
|
||||
[
|
||||
"#F08080",
|
||||
"#FFD700",
|
||||
"#90EE90",
|
||||
"#ADD8E6",
|
||||
"#DDA0DD",
|
||||
"#FFB6C1",
|
||||
"#FAFAD2",
|
||||
"#D3D3D3",
|
||||
"#B0E0E6",
|
||||
"#FFDEAD",
|
||||
]
|
||||
)
|
||||
|
||||
if not attributes:
|
||||
attributes = await world_state.extract_character_sheet(name=name, text=content)
|
||||
attributes = await world_state.extract_character_sheet(
|
||||
name=name, text=content
|
||||
)
|
||||
else:
|
||||
attributes = world_state._parse_character_sheet(attributes)
|
||||
|
||||
|
||||
self.scene.log.debug("persist_character", attributes=attributes)
|
||||
|
||||
character.base_attributes = attributes
|
||||
@@ -232,35 +366,71 @@ class DirectorAgent(Agent):
|
||||
|
||||
self.scene.log.debug("persist_character", description=description)
|
||||
|
||||
actor = self.scene.Actor(character=character, agent=instance.get_agent("conversation"))
|
||||
dialogue_instructions = await creator.determine_character_dialogue_instructions(
|
||||
character
|
||||
)
|
||||
|
||||
character.dialogue_instructions = dialogue_instructions
|
||||
|
||||
self.scene.log.debug(
|
||||
"persist_character", dialogue_instructions=dialogue_instructions
|
||||
)
|
||||
|
||||
actor = self.scene.Actor(
|
||||
character=character, agent=instance.get_agent("conversation")
|
||||
)
|
||||
|
||||
await self.scene.add_actor(actor)
|
||||
self.scene.emit_status()
|
||||
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@set_processing
|
||||
async def update_content_context(self, content:str=None, extra_choices:list[str]=None):
|
||||
|
||||
async def update_content_context(
|
||||
self, content: str = None, extra_choices: list[str] = None
|
||||
):
|
||||
if not content:
|
||||
content = "\n".join(self.scene.context_history(sections=False, min_dialogue=25, budget=2048))
|
||||
|
||||
response = await Prompt.request("world_state.determine-content-context", self.client, "analyze_freeform", vars={
|
||||
"content": content,
|
||||
"extra_choices": extra_choices or [],
|
||||
})
|
||||
|
||||
content = "\n".join(
|
||||
self.scene.context_history(sections=False, min_dialogue=25, budget=2048)
|
||||
)
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.determine-content-context",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars={
|
||||
"content": content,
|
||||
"extra_choices": extra_choices or [],
|
||||
},
|
||||
)
|
||||
|
||||
self.scene.context = response.strip()
|
||||
self.scene.emit_status()
|
||||
|
||||
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
|
||||
log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
|
||||
|
||||
async def log_action(self, action: str, action_description: str):
|
||||
message = DirectorMessage(message=action_description, action=action)
|
||||
self.scene.push_history(message)
|
||||
emit("director", message)
|
||||
|
||||
log_action.exposed = True
|
||||
|
||||
def inject_prompt_paramters(
|
||||
self, prompt_param: dict, kind: str, agent_function_name: str
|
||||
):
|
||||
log.debug(
|
||||
"inject_prompt_paramters",
|
||||
prompt_param=prompt_param,
|
||||
kind=kind,
|
||||
agent_function_name=agent_function_name,
|
||||
)
|
||||
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
|
||||
if prompt_param.get("extra_stopping_strings") is None:
|
||||
prompt_param["extra_stopping_strings"] = []
|
||||
prompt_param["extra_stopping_strings"] += character_names + ["#"]
|
||||
if agent_function_name == "update_content_context":
|
||||
prompt_param["extra_stopping_strings"] += ["\n"]
|
||||
|
||||
def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
|
||||
return True
|
||||
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
return True
|
||||
|
||||
@@ -1,30 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.data_objects as data_objects
|
||||
import talemate.util as util
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import DirectorMessage, TimePassageMessage
|
||||
|
||||
from .base import Agent, set_processing, AgentAction, AgentActionConfig
|
||||
from .base import Agent, AgentAction, AgentActionConfig, set_processing
|
||||
from .registry import register
|
||||
|
||||
import structlog
|
||||
|
||||
import time
|
||||
import re
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Actor, Character, Scene
|
||||
from talemate.agents.conversation import ConversationAgentEmission
|
||||
from talemate.agents.narrator import NarratorAgentEmission
|
||||
from talemate.tale_mate import Actor, Character, Scene
|
||||
|
||||
log = structlog.get_logger("talemate.agents.editor")
|
||||
|
||||
|
||||
@register()
|
||||
class EditorAgent(Agent):
|
||||
"""
|
||||
@@ -35,175 +35,281 @@ class EditorAgent(Agent):
|
||||
|
||||
agent_type = "editor"
|
||||
verbose_name = "Editor"
|
||||
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
self.actions = {
|
||||
"edit_dialogue": AgentAction(enabled=False, label="Edit dialogue", description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue."),
|
||||
"fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.", config={
|
||||
"narrator": AgentActionConfig(type="bool", label="Fix narrator messages", description="Will attempt to fix exposition issues in narrator messages", value=True),
|
||||
}),
|
||||
"add_detail": AgentAction(enabled=False, label="Add detail", description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.")
|
||||
"fix_exposition": AgentAction(
|
||||
enabled=True,
|
||||
label="Fix exposition",
|
||||
description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.",
|
||||
config={
|
||||
"narrator": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Fix narrator messages",
|
||||
description="Will attempt to fix exposition issues in narrator messages",
|
||||
value=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
"add_detail": AgentAction(
|
||||
enabled=False,
|
||||
label="Add detail",
|
||||
description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
|
||||
),
|
||||
"check_continuity_errors": AgentAction(
|
||||
enabled=False,
|
||||
label="Check continuity errors",
|
||||
description="Will attempt to fix continuity errors in the dialogue. Runs automatically after each AI dialogue. (super experimental)",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
|
||||
talemate.emit.async_signals.get("agent.narrator.generated").connect(self.on_narrator_generated)
|
||||
|
||||
async def on_conversation_generated(self, emission:ConversationAgentEmission):
|
||||
talemate.emit.async_signals.get("agent.conversation.generated").connect(
|
||||
self.on_conversation_generated
|
||||
)
|
||||
talemate.emit.async_signals.get("agent.narrator.generated").connect(
|
||||
self.on_narrator_generated
|
||||
)
|
||||
|
||||
async def on_conversation_generated(self, emission: ConversationAgentEmission):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
"""
|
||||
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
log.info("editing conversation", emission=emission)
|
||||
|
||||
|
||||
edited = []
|
||||
for text in emission.generation:
|
||||
edit = await self.add_detail(text, emission.character)
|
||||
|
||||
edit = await self.fix_exposition(edit, emission.character)
|
||||
|
||||
edit = await self.check_continuity_errors(edit, emission.character)
|
||||
|
||||
|
||||
edit = await self.add_detail(
|
||||
text,
|
||||
emission.character
|
||||
)
|
||||
|
||||
edit = await self.edit_conversation(
|
||||
edit,
|
||||
emission.character
|
||||
)
|
||||
|
||||
edit = await self.fix_exposition(
|
||||
edit,
|
||||
emission.character
|
||||
)
|
||||
|
||||
edited.append(edit)
|
||||
|
||||
|
||||
emission.generation = edited
|
||||
|
||||
async def on_narrator_generated(self, emission:NarratorAgentEmission):
|
||||
|
||||
async def on_narrator_generated(self, emission: NarratorAgentEmission):
|
||||
"""
|
||||
Called when a narrator message is generated
|
||||
"""
|
||||
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
log.info("editing narrator", emission=emission)
|
||||
|
||||
|
||||
edited = []
|
||||
|
||||
|
||||
for text in emission.generation:
|
||||
edit = await self.fix_exposition_on_narrator(text)
|
||||
edited.append(edit)
|
||||
|
||||
|
||||
emission.generation = edited
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def edit_conversation(self, content:str, character:Character):
|
||||
"""
|
||||
Edits a conversation
|
||||
"""
|
||||
|
||||
if not self.actions["edit_dialogue"].enabled:
|
||||
return content
|
||||
|
||||
response = await Prompt.request("editor.edit-dialogue", self.client, "edit_dialogue", vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_length": self.client.max_token_length
|
||||
})
|
||||
|
||||
response = response.split("[end]")[0]
|
||||
|
||||
response = util.replace_exposition_markers(response)
|
||||
response = util.clean_dialogue(response, main_name=character.name)
|
||||
response = util.strip_partial_sentences(response)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def fix_exposition(self, content:str, character:Character):
|
||||
async def fix_exposition(self, content: str, character: Character):
|
||||
"""
|
||||
Edits a text to make sure all narrative exposition and emotes is encased in *
|
||||
"""
|
||||
|
||||
|
||||
if not self.actions["fix_exposition"].enabled:
|
||||
return content
|
||||
|
||||
# if not content was generated, return it as is
|
||||
if not content:
|
||||
return content
|
||||
|
||||
if not character.is_player:
|
||||
if '"' not in content and '*' not in content:
|
||||
if '"' not in content and "*" not in content:
|
||||
content = util.strip_partial_sentences(content)
|
||||
character_prefix = f"{character.name}: "
|
||||
message = content.split(character_prefix)[1]
|
||||
content = f"{character_prefix}*{message.strip('*')}*"
|
||||
content = f'{character_prefix}"{message.strip()}"'
|
||||
return content
|
||||
elif '"' in content:
|
||||
|
||||
# silly hack to clean up some LLMs that always start with a quote
|
||||
# even though the immediate next thing is a narration (indicated by *)
|
||||
content = content.replace(f"{character.name}: \"*", f"{character.name}: *")
|
||||
|
||||
content = util.clean_dialogue(content, main_name=character.name)
|
||||
content = content.replace(
|
||||
f'{character.name}: "*', f"{character.name}: *"
|
||||
)
|
||||
|
||||
content = util.clean_dialogue(content, main_name=character.name)
|
||||
content = util.strip_partial_sentences(content)
|
||||
content = util.ensure_dialog_format(content, talking_character=character.name)
|
||||
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@set_processing
|
||||
async def fix_exposition_on_narrator(self, content:str):
|
||||
|
||||
async def fix_exposition_on_narrator(self, content: str):
|
||||
if not self.actions["fix_exposition"].enabled:
|
||||
return content
|
||||
|
||||
|
||||
if not self.actions["fix_exposition"].config["narrator"].value:
|
||||
return content
|
||||
|
||||
|
||||
content = util.strip_partial_sentences(content)
|
||||
|
||||
|
||||
if '"' not in content:
|
||||
content = f"*{content.strip('*')}*"
|
||||
else:
|
||||
content = util.ensure_dialog_format(content)
|
||||
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@set_processing
|
||||
async def add_detail(self, content:str, character:Character):
|
||||
async def add_detail(self, content: str, character: Character):
|
||||
"""
|
||||
Edits a text to increase its length and add extra detail and exposition
|
||||
"""
|
||||
|
||||
|
||||
if not self.actions["add_detail"].enabled:
|
||||
return content
|
||||
|
||||
response = await Prompt.request("editor.add-detail", self.client, "edit_add_detail", vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_length": self.client.max_token_length
|
||||
})
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
"editor.add-detail",
|
||||
self.client,
|
||||
"edit_add_detail",
|
||||
vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_length": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
response = util.replace_exposition_markers(response)
|
||||
response = util.clean_dialogue(response, main_name=character.name)
|
||||
response = util.clean_dialogue(response, main_name=character.name)
|
||||
response = util.strip_partial_sentences(response)
|
||||
|
||||
return response
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def check_continuity_errors(
|
||||
self,
|
||||
content: str,
|
||||
character: Character,
|
||||
force: bool = False,
|
||||
fix: bool = True,
|
||||
message_id: int = None,
|
||||
) -> str:
|
||||
"""
|
||||
Edits a text to ensure that it is consistent with the scene
|
||||
so far
|
||||
"""
|
||||
|
||||
if not self.actions["check_continuity_errors"].enabled and not force:
|
||||
return content
|
||||
|
||||
MAX_CONTENT_LENGTH = 255
|
||||
count = util.count_tokens(content)
|
||||
|
||||
if count > MAX_CONTENT_LENGTH:
|
||||
log.warning(
|
||||
"check_continuity_errors content too long",
|
||||
length=count,
|
||||
max=MAX_CONTENT_LENGTH,
|
||||
content=content[:255],
|
||||
)
|
||||
return content
|
||||
|
||||
log.debug(
|
||||
"check_continuity_errors START",
|
||||
content=content,
|
||||
character=character,
|
||||
force=force,
|
||||
fix=fix,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
response = await Prompt.request(
|
||||
"editor.check-continuity-errors",
|
||||
self.client,
|
||||
"basic_analytical_medium2",
|
||||
vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
|
||||
# loop through response line by line, checking for lines beginning
|
||||
# with "ERROR {number}:
|
||||
|
||||
errors = []
|
||||
|
||||
for line in response.split("\n"):
|
||||
if "ERROR" not in line:
|
||||
continue
|
||||
|
||||
errors.append(line)
|
||||
|
||||
if not errors:
|
||||
log.debug("check_continuity_errors NO ERRORS")
|
||||
return content
|
||||
|
||||
log.debug("check_continuity_errors ERRORS", fix=fix, errors=errors)
|
||||
|
||||
if not fix:
|
||||
return content
|
||||
|
||||
state = {}
|
||||
|
||||
response = await Prompt.request(
|
||||
"editor.fix-continuity-errors",
|
||||
self.client,
|
||||
"editor_creative_medium2",
|
||||
vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"errors": errors,
|
||||
"set_state": lambda k, v: state.update({k: v}),
|
||||
},
|
||||
)
|
||||
|
||||
content_fix_identifer = state.get("content_fix_identifier")
|
||||
|
||||
try:
|
||||
content = response.strip().strip("```").split("```")[0].strip()
|
||||
content = content.replace(content_fix_identifer, "").strip()
|
||||
content = content.strip(":")
|
||||
|
||||
# if content doesnt start with {character_name}: then add it
|
||||
if not content.startswith(f"{character.name}:"):
|
||||
content = f"{character.name}: {content}"
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"check_continuity_errors FAILED",
|
||||
content_fix_identifer=content_fix_identifer,
|
||||
response=response,
|
||||
e=e,
|
||||
)
|
||||
return content
|
||||
|
||||
log.debug("check_continuity_errors FIXED", content=content)
|
||||
|
||||
return content
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import structlog
|
||||
from chromadb.config import Settings
|
||||
|
||||
import talemate.events as events
|
||||
import talemate.util as util
|
||||
from talemate.agents.base import set_processing
|
||||
from talemate.config import load_config
|
||||
from talemate.context import scene_is_loading
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.context import scene_is_loading
|
||||
from talemate.config import load_config
|
||||
from talemate.agents.base import set_processing
|
||||
import structlog
|
||||
import shutil
|
||||
import functools
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
@@ -28,19 +30,20 @@ if not chromadb:
|
||||
log.info("ChromaDB not found, disabling Chroma agent")
|
||||
|
||||
|
||||
from .base import Agent
|
||||
from .base import Agent, AgentDetail
|
||||
|
||||
|
||||
class MemoryDocument(str):
|
||||
|
||||
def __new__(cls, text, meta, id, raw):
|
||||
inst = super().__new__(cls, text)
|
||||
|
||||
inst.meta = meta
|
||||
inst.id = id
|
||||
inst.raw = raw
|
||||
|
||||
|
||||
return inst
|
||||
|
||||
|
||||
class MemoryAgent(Agent):
|
||||
"""
|
||||
An agent that can be used to maintain and access a memory of the world
|
||||
@@ -52,10 +55,11 @@ class MemoryAgent(Agent):
|
||||
|
||||
@property
|
||||
def readonly(self):
|
||||
|
||||
if scene_is_loading.get() and not getattr(self.scene, "_memory_never_persisted", False):
|
||||
if scene_is_loading.get() and not getattr(
|
||||
self.scene, "_memory_never_persisted", False
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
@@ -72,9 +76,9 @@ class MemoryAgent(Agent):
|
||||
self.memory_tracker = {}
|
||||
self.config = load_config()
|
||||
self._ready_to_add = False
|
||||
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
|
||||
def on_config_saved(self, event):
|
||||
openai_key = self.openai_api_key
|
||||
self.config = load_config()
|
||||
@@ -92,35 +96,68 @@ class MemoryAgent(Agent):
|
||||
raise NotImplementedError()
|
||||
|
||||
@set_processing
|
||||
async def add(self, text, character=None, uid=None, ts:str=None, **kwargs):
|
||||
async def add(self, text, character=None, uid=None, ts: str = None, **kwargs):
|
||||
if not text:
|
||||
return
|
||||
if self.readonly:
|
||||
log.debug("memory agent", status="readonly")
|
||||
return
|
||||
|
||||
|
||||
while not self._ready_to_add:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
log.debug("memory agent add", text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
|
||||
|
||||
|
||||
log.debug(
|
||||
"memory agent add",
|
||||
text=text[:50],
|
||||
character=character,
|
||||
uid=uid,
|
||||
ts=ts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
|
||||
try:
|
||||
await loop.run_in_executor(None, functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs))
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs),
|
||||
)
|
||||
except AttributeError as e:
|
||||
# not sure how this sometimes happens.
|
||||
# chromadb model None
|
||||
# race condition because we are forcing async context onto it?
|
||||
|
||||
log.error("memory agent", error="failed to add memory", details=e, text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
|
||||
|
||||
log.error(
|
||||
"memory agent",
|
||||
error="failed to add memory",
|
||||
details=e,
|
||||
text=text[:50],
|
||||
character=character,
|
||||
uid=uid,
|
||||
ts=ts,
|
||||
**kwargs,
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
try:
|
||||
await loop.run_in_executor(None, functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs))
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self._add, text, character, uid=uid, ts=ts, **kwargs
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error("memory agent", error="failed to add memory (retried)", details=e, text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
|
||||
log.error(
|
||||
"memory agent",
|
||||
error="failed to add memory (retried)",
|
||||
details=e,
|
||||
text=text[:50],
|
||||
character=character,
|
||||
uid=uid,
|
||||
ts=ts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _add(self, text, character=None, ts:str=None, **kwargs):
|
||||
def _add(self, text, character=None, ts: str = None, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@set_processing
|
||||
@@ -131,44 +168,46 @@ class MemoryAgent(Agent):
|
||||
|
||||
while not self._ready_to_add:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
log.debug("memory agent add many", len=len(objects))
|
||||
|
||||
|
||||
log.debug("memory agent add many", len=len(objects))
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._add_many, objects)
|
||||
|
||||
|
||||
def _add_many(self, objects: list[dict]):
|
||||
"""
|
||||
Add multiple objects to the memory
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _delete(self, meta:dict):
|
||||
def _delete(self, meta: dict):
|
||||
"""
|
||||
Delete an object from the memory
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def delete(self, meta:dict):
|
||||
async def delete(self, meta: dict):
|
||||
"""
|
||||
Delete an object from the memory
|
||||
"""
|
||||
if self.readonly:
|
||||
log.debug("memory agent", status="readonly")
|
||||
return
|
||||
|
||||
|
||||
while not self._ready_to_add:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._delete, meta)
|
||||
|
||||
@set_processing
|
||||
async def get(self, text, character=None, **query):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
return await loop.run_in_executor(None, functools.partial(self._get, text, character, **query))
|
||||
|
||||
return await loop.run_in_executor(
|
||||
None, functools.partial(self._get, text, character, **query)
|
||||
)
|
||||
|
||||
def _get(self, text, character=None, **query):
|
||||
raise NotImplementedError()
|
||||
@@ -177,12 +216,14 @@ class MemoryAgent(Agent):
|
||||
async def get_document(self, id):
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, self._get_document, id)
|
||||
|
||||
|
||||
def _get_document(self, id):
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_archive_add(self, event: events.ArchiveEvent):
|
||||
asyncio.ensure_future(self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history"))
|
||||
asyncio.ensure_future(
|
||||
self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
|
||||
)
|
||||
|
||||
def on_character_state(self, event: events.CharacterStateEvent):
|
||||
asyncio.ensure_future(
|
||||
@@ -222,10 +263,10 @@ class MemoryAgent(Agent):
|
||||
"""
|
||||
|
||||
memory_context = []
|
||||
|
||||
|
||||
if not query:
|
||||
return memory_context
|
||||
|
||||
|
||||
for memory in await self.get(query):
|
||||
if memory in memory_context:
|
||||
continue
|
||||
@@ -239,17 +280,26 @@ class MemoryAgent(Agent):
|
||||
break
|
||||
return memory_context
|
||||
|
||||
async def query(self, query:str, max_tokens:int=1000, filter:Callable=lambda x:True, **where):
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
max_tokens: int = 1000,
|
||||
filter: Callable = lambda x: True,
|
||||
**where,
|
||||
):
|
||||
"""
|
||||
Get the character memory context for a given character
|
||||
"""
|
||||
|
||||
try:
|
||||
return (await self.multi_query([query], max_tokens=max_tokens, filter=filter, **where))[0]
|
||||
return (
|
||||
await self.multi_query(
|
||||
[query], max_tokens=max_tokens, filter=filter, **where
|
||||
)
|
||||
)[0]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
|
||||
async def multi_query(
|
||||
self,
|
||||
queries: list[str],
|
||||
@@ -258,7 +308,7 @@ class MemoryAgent(Agent):
|
||||
filter: Callable = lambda x: True,
|
||||
formatter: Callable = lambda x: x,
|
||||
limit: int = 10,
|
||||
**where
|
||||
**where,
|
||||
):
|
||||
"""
|
||||
Get the character memory context for a given character
|
||||
@@ -266,10 +316,9 @@ class MemoryAgent(Agent):
|
||||
|
||||
memory_context = []
|
||||
for query in queries:
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
|
||||
i = 0
|
||||
for memory in await self.get(formatter(query), limit=limit, **where):
|
||||
if memory in memory_context:
|
||||
@@ -296,15 +345,13 @@ from .registry import register
|
||||
|
||||
@register(condition=lambda: chromadb is not None)
|
||||
class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
requires_llm_client = False
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
|
||||
if self.embeddings == "openai" and not self.openai_api_key:
|
||||
return False
|
||||
|
||||
|
||||
if getattr(self, "db_client", None):
|
||||
return True
|
||||
return False
|
||||
@@ -313,80 +360,110 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
def status(self):
|
||||
if self.ready:
|
||||
return "active" if not getattr(self, "processing", False) else "busy"
|
||||
|
||||
|
||||
if self.embeddings == "openai" and not self.openai_api_key:
|
||||
return "error"
|
||||
|
||||
|
||||
return "waiting"
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
|
||||
|
||||
details = {
|
||||
"backend": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value="ChromaDB",
|
||||
description="The backend to use for long-term memory",
|
||||
).model_dump(),
|
||||
"embeddings": AgentDetail(
|
||||
icon="mdi-cube-unfolded",
|
||||
value=self.embeddings,
|
||||
description="The embeddings model.",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
if self.embeddings == "openai" and not self.openai_api_key:
|
||||
return "No OpenAI API key set"
|
||||
|
||||
return f"ChromaDB: {self.embeddings}"
|
||||
|
||||
# return "No OpenAI API key set"
|
||||
details["error"] = {
|
||||
"icon": "mdi-alert",
|
||||
"value": "No OpenAI API key set",
|
||||
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
|
||||
"color": "error",
|
||||
}
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def embeddings(self):
|
||||
"""
|
||||
Returns which embeddings to use
|
||||
|
||||
|
||||
will read from TM_CHROMADB_EMBEDDINGS env variable and default to 'default' using
|
||||
the default embeddings specified by chromadb.
|
||||
|
||||
|
||||
other values are
|
||||
|
||||
|
||||
- openai: use openai embeddings
|
||||
- instructor: use instructor embeddings
|
||||
|
||||
|
||||
for `openai`:
|
||||
|
||||
|
||||
you will also need to provide an `OPENAI_API_KEY` env variable
|
||||
|
||||
|
||||
for `instructor`:
|
||||
|
||||
|
||||
you will also need to provide which instructor model to use with the `TM_INSTRUCTOR_MODEL` env variable, which defaults to hkunlp/instructor-xl
|
||||
|
||||
|
||||
additionally you can provide the `TM_INSTRUCTOR_DEVICE` env variable to specify which device to use, which defaults to cpu
|
||||
"""
|
||||
|
||||
|
||||
embeddings = self.config.get("chromadb").get("embeddings")
|
||||
|
||||
assert embeddings in ["default", "openai", "instructor"], f"Unknown embeddings {embeddings}"
|
||||
|
||||
|
||||
assert embeddings in [
|
||||
"default",
|
||||
"openai",
|
||||
"instructor",
|
||||
], f"Unknown embeddings {embeddings}"
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
@property
|
||||
def USE_OPENAI(self):
|
||||
return self.embeddings == "openai"
|
||||
|
||||
|
||||
@property
|
||||
def USE_INSTRUCTOR(self):
|
||||
return self.embeddings == "instructor"
|
||||
|
||||
|
||||
@property
|
||||
def db_name(self):
|
||||
return getattr(self, "collection_name", "<unnamed>")
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai",{}).get("api_key")
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
def make_collection_name(self, scene):
|
||||
|
||||
if self.USE_OPENAI:
|
||||
suffix = "-openai"
|
||||
model_name = self.config.get("chromadb").get(
|
||||
"openai_model", "text-embedding-3-small"
|
||||
)
|
||||
if model_name == "text-embedding-ada-002":
|
||||
suffix = "-openai"
|
||||
else:
|
||||
suffix = f"-openai-{model_name}"
|
||||
elif self.USE_INSTRUCTOR:
|
||||
suffix = "-instructor"
|
||||
model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
|
||||
model = self.config.get("chromadb").get(
|
||||
"instructor_model", "hkunlp/instructor-xl"
|
||||
)
|
||||
if "xl" in model:
|
||||
suffix += "-xl"
|
||||
elif "large" in model:
|
||||
suffix += "-large"
|
||||
else:
|
||||
suffix = ""
|
||||
|
||||
|
||||
return f"{scene.memory_id}-tm{suffix}"
|
||||
|
||||
async def count(self):
|
||||
@@ -399,9 +476,8 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
await loop.run_in_executor(None, self._set_db)
|
||||
|
||||
def _set_db(self):
|
||||
|
||||
self._ready_to_add = False
|
||||
|
||||
|
||||
if not getattr(self, "db_client", None):
|
||||
log.info("chromadb agent", status="setting up db client to persistent db")
|
||||
self.db_client = chromadb.PersistentClient(
|
||||
@@ -409,49 +485,67 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
)
|
||||
|
||||
openai_key = self.openai_api_key
|
||||
|
||||
|
||||
self.collection_name = collection_name = self.make_collection_name(self.scene)
|
||||
|
||||
log.info("chromadb agent", status="setting up db", collection_name=collection_name)
|
||||
|
||||
|
||||
log.info(
|
||||
"chromadb agent", status="setting up db", collection_name=collection_name
|
||||
)
|
||||
|
||||
if self.USE_OPENAI:
|
||||
|
||||
if not openai_key:
|
||||
raise ValueError("You must provide an the openai ai key in the config if you want to use it for chromadb embeddings")
|
||||
|
||||
raise ValueError(
|
||||
"You must provide an the openai ai key in the config if you want to use it for chromadb embeddings"
|
||||
)
|
||||
|
||||
model_name = self.config.get("chromadb").get(
|
||||
"openai_model", "text-embedding-3-small"
|
||||
)
|
||||
|
||||
log.info(
|
||||
"crhomadb", status="using openai", openai_key=openai_key[:5] + "..."
|
||||
"crhomadb",
|
||||
status="using openai",
|
||||
openai_key=openai_key[:5] + "...",
|
||||
model=model_name,
|
||||
)
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key = openai_key,
|
||||
model_name="text-embedding-ada-002",
|
||||
api_key=openai_key,
|
||||
model_name=model_name,
|
||||
)
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=openai_ef
|
||||
)
|
||||
elif self.USE_INSTRUCTOR:
|
||||
|
||||
instructor_device = self.config.get("chromadb").get("instructor_device", "cpu")
|
||||
instructor_model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
|
||||
|
||||
log.info("chromadb", status="using instructor", model=instructor_model, device=instructor_device)
|
||||
|
||||
instructor_device = self.config.get("chromadb").get(
|
||||
"instructor_device", "cpu"
|
||||
)
|
||||
instructor_model = self.config.get("chromadb").get(
|
||||
"instructor_model", "hkunlp/instructor-xl"
|
||||
)
|
||||
|
||||
log.info(
|
||||
"chromadb",
|
||||
status="using instructor",
|
||||
model=instructor_model,
|
||||
device=instructor_device,
|
||||
)
|
||||
|
||||
# ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
|
||||
ef = embedding_functions.InstructorEmbeddingFunction(
|
||||
model_name=instructor_model, device=instructor_device
|
||||
)
|
||||
|
||||
|
||||
log.info("chromadb", status="embedding function ready")
|
||||
|
||||
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=ef
|
||||
)
|
||||
|
||||
|
||||
log.info("chromadb", status="instructor db ready")
|
||||
else:
|
||||
log.info("chromadb", status="using default embeddings")
|
||||
self.db = self.db_client.get_or_create_collection(collection_name)
|
||||
|
||||
|
||||
self.scene._memory_never_persisted = self.db.count() == 0
|
||||
log.info("chromadb agent", status="db ready")
|
||||
self._ready_to_add = True
|
||||
@@ -459,17 +553,21 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
def clear_db(self):
|
||||
if not self.db:
|
||||
return
|
||||
|
||||
log.info("chromadb agent", status="clearing db", collection_name=self.collection_name)
|
||||
|
||||
|
||||
log.info(
|
||||
"chromadb agent", status="clearing db", collection_name=self.collection_name
|
||||
)
|
||||
|
||||
self.db.delete(where={"source": "talemate"})
|
||||
|
||||
|
||||
def drop_db(self):
|
||||
if not self.db:
|
||||
return
|
||||
|
||||
log.info("chromadb agent", status="dropping db", collection_name=self.collection_name)
|
||||
|
||||
|
||||
log.info(
|
||||
"chromadb agent", status="dropping db", collection_name=self.collection_name
|
||||
)
|
||||
|
||||
try:
|
||||
self.db_client.delete_collection(self.collection_name)
|
||||
except ValueError as exc:
|
||||
@@ -479,31 +577,43 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
def close_db(self, scene):
|
||||
if not self.db:
|
||||
return
|
||||
|
||||
log.info("chromadb agent", status="closing db", collection_name=self.collection_name)
|
||||
|
||||
|
||||
log.info(
|
||||
"chromadb agent", status="closing db", collection_name=self.collection_name
|
||||
)
|
||||
|
||||
if not scene.saved and not scene.saved_memory_session_id:
|
||||
# scene was never saved so we can discard the memory
|
||||
collection_name = self.make_collection_name(scene)
|
||||
log.info("chromadb agent", status="discarding memory", collection_name=collection_name)
|
||||
log.info(
|
||||
"chromadb agent",
|
||||
status="discarding memory",
|
||||
collection_name=collection_name,
|
||||
)
|
||||
try:
|
||||
self.db_client.delete_collection(collection_name)
|
||||
except ValueError as exc:
|
||||
log.error("chromadb agent", error="failed to delete collection", details=exc)
|
||||
log.error(
|
||||
"chromadb agent", error="failed to delete collection", details=exc
|
||||
)
|
||||
elif not scene.saved:
|
||||
# scene was saved but memory was never persisted
|
||||
# so we need to remove the memory from the db
|
||||
self._remove_unsaved_memory()
|
||||
|
||||
|
||||
self.db = None
|
||||
|
||||
def _add(self, text, character=None, uid=None, ts:str=None, **kwargs):
|
||||
|
||||
def _add(self, text, character=None, uid=None, ts: str = None, **kwargs):
|
||||
metadatas = []
|
||||
ids = []
|
||||
scene = self.scene
|
||||
|
||||
|
||||
if character:
|
||||
meta = {"character": character.name, "source": "talemate", "session": scene.memory_session_id}
|
||||
meta = {
|
||||
"character": character.name,
|
||||
"source": "talemate",
|
||||
"session": scene.memory_session_id,
|
||||
}
|
||||
if ts:
|
||||
meta["ts"] = ts
|
||||
meta.update(kwargs)
|
||||
@@ -513,7 +623,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
id = uid or f"{character.name}-{self.memory_tracker[character.name]}"
|
||||
ids = [id]
|
||||
else:
|
||||
meta = {"character": "__narrator__", "source": "talemate", "session": scene.memory_session_id}
|
||||
meta = {
|
||||
"character": "__narrator__",
|
||||
"source": "talemate",
|
||||
"session": scene.memory_session_id,
|
||||
}
|
||||
if ts:
|
||||
meta["ts"] = ts
|
||||
meta.update(kwargs)
|
||||
@@ -523,17 +637,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
id = uid or f"__narrator__-{self.memory_tracker['__narrator__']}"
|
||||
ids = [id]
|
||||
|
||||
#log.debug("chromadb agent add", text=text, meta=meta, id=id)
|
||||
# log.debug("chromadb agent add", text=text, meta=meta, id=id)
|
||||
|
||||
self.db.upsert(documents=[text], metadatas=metadatas, ids=ids)
|
||||
|
||||
def _add_many(self, objects: list[dict]):
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
scene = self.scene
|
||||
|
||||
|
||||
if not objects:
|
||||
return
|
||||
|
||||
@@ -552,52 +665,50 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
ids.append(uid)
|
||||
self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)
|
||||
|
||||
def _delete(self, meta:dict):
|
||||
|
||||
def _delete(self, meta: dict):
|
||||
if "ids" in meta:
|
||||
log.debug("chromadb agent delete", ids=meta["ids"])
|
||||
self.db.delete(ids=meta["ids"])
|
||||
return
|
||||
|
||||
where = {"$and": [{k:v} for k,v in meta.items()]}
|
||||
|
||||
where = {"$and": [{k: v} for k, v in meta.items()]}
|
||||
self.db.delete(where=where)
|
||||
log.debug("chromadb agent delete", meta=meta, where=where)
|
||||
|
||||
def _get(self, text, character=None, limit:int=15, **kwargs):
|
||||
def _get(self, text, character=None, limit: int = 15, **kwargs):
|
||||
where = {}
|
||||
|
||||
|
||||
# this doesn't work because chromadb currently doesn't match
|
||||
# non existing fields with $ne (or so it seems)
|
||||
# where.setdefault("$and", [{"pin_only": {"$ne": True}}])
|
||||
|
||||
|
||||
where.setdefault("$and", [])
|
||||
|
||||
|
||||
character_filtered = False
|
||||
|
||||
for k,v in kwargs.items():
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if k == "character":
|
||||
character_filtered = True
|
||||
where["$and"].append({k: v})
|
||||
|
||||
|
||||
if character and not character_filtered:
|
||||
where["$and"].append({"character": character.name})
|
||||
|
||||
|
||||
if len(where["$and"]) == 1:
|
||||
where = where["$and"][0]
|
||||
elif not where["$and"]:
|
||||
where = None
|
||||
|
||||
#log.debug("crhomadb agent get", text=text, where=where)
|
||||
# log.debug("crhomadb agent get", text=text, where=where)
|
||||
|
||||
_results = self.db.query(query_texts=[text], where=where, n_results=limit)
|
||||
|
||||
#import json
|
||||
#print(json.dumps(_results["ids"], indent=2))
|
||||
#print(json.dumps(_results["distances"], indent=2))
|
||||
|
||||
|
||||
# import json
|
||||
# print(json.dumps(_results["ids"], indent=2))
|
||||
# print(json.dumps(_results["distances"], indent=2))
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
max_distance = 1.5
|
||||
if self.USE_INSTRUCTOR:
|
||||
max_distance = 1
|
||||
@@ -606,24 +717,29 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
for i in range(len(_results["distances"][0])):
|
||||
distance = _results["distances"][0][i]
|
||||
|
||||
|
||||
doc = _results["documents"][0][i]
|
||||
meta = _results["metadatas"][0][i]
|
||||
|
||||
if not meta:
|
||||
log.warning("chromadb agent get", error="no meta", doc=doc)
|
||||
continue
|
||||
|
||||
ts = meta.get("ts")
|
||||
|
||||
|
||||
# skip pin_only entries
|
||||
if meta.get("pin_only", False):
|
||||
continue
|
||||
|
||||
|
||||
if distance < max_distance:
|
||||
date_prefix = self.convert_ts_to_date_prefix(ts)
|
||||
raw = doc
|
||||
|
||||
|
||||
if date_prefix:
|
||||
doc = f"{date_prefix}: {doc}"
|
||||
|
||||
|
||||
doc = MemoryDocument(doc, meta, _results["ids"][0][i], raw)
|
||||
|
||||
|
||||
results.append(doc)
|
||||
else:
|
||||
break
|
||||
@@ -635,45 +751,55 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def convert_ts_to_date_prefix(self, ts):
|
||||
if not ts:
|
||||
return None
|
||||
try:
|
||||
return util.iso8601_diff_to_human(ts, self.scene.ts)
|
||||
except Exception as e:
|
||||
log.error("chromadb agent", error="failed to get date prefix", details=e, ts=ts, scene_ts=self.scene.ts)
|
||||
log.error(
|
||||
"chromadb agent",
|
||||
error="failed to get date prefix",
|
||||
details=e,
|
||||
ts=ts,
|
||||
scene_ts=self.scene.ts,
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_document(self, id) -> dict:
|
||||
result = self.db.get(ids=[id] if isinstance(id, str) else id)
|
||||
documents = {}
|
||||
|
||||
|
||||
for idx, doc in enumerate(result["documents"]):
|
||||
date_prefix = self.convert_ts_to_date_prefix(result["metadatas"][idx].get("ts"))
|
||||
date_prefix = self.convert_ts_to_date_prefix(
|
||||
result["metadatas"][idx].get("ts")
|
||||
)
|
||||
if date_prefix:
|
||||
doc = f"{date_prefix}: {doc}"
|
||||
documents[result["ids"][idx]] = MemoryDocument(doc, result["metadatas"][idx], result["ids"][idx], doc)
|
||||
|
||||
documents[result["ids"][idx]] = MemoryDocument(
|
||||
doc, result["metadatas"][idx], result["ids"][idx], doc
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@set_processing
|
||||
async def remove_unsaved_memory(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._remove_unsaved_memory)
|
||||
|
||||
|
||||
def _remove_unsaved_memory(self):
|
||||
|
||||
scene = self.scene
|
||||
|
||||
|
||||
if not scene.memory_session_id:
|
||||
return
|
||||
|
||||
|
||||
if scene.saved_memory_session_id == self.scene.memory_session_id:
|
||||
return
|
||||
|
||||
log.info("chromadb agent", status="removing unsaved memory", session_id=scene.memory_session_id)
|
||||
|
||||
|
||||
log.info(
|
||||
"chromadb agent",
|
||||
status="removing unsaved memory",
|
||||
session_id=scene.memory_session_id,
|
||||
)
|
||||
|
||||
self._delete({"session": scene.memory_session_id, "source": "talemate"})
|
||||
|
||||
|
||||
|
||||
@@ -1,43 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
import dataclasses
|
||||
import structlog
|
||||
import random
|
||||
import talemate.util as util
|
||||
from talemate.emit import emit
|
||||
import talemate.emit.async_signals
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.agents.base import set_processing as _set_processing, Agent, AgentAction, AgentActionConfig, AgentEmission
|
||||
from talemate.agents.world_state import TimePassageEmission
|
||||
from talemate.scene_message import NarratorMessage
|
||||
from talemate.events import GameLoopActorIterEvent
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.client as client
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission
|
||||
from talemate.agents.base import set_processing as _set_processing
|
||||
from talemate.agents.world_state import TimePassageEmission
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopActorIterEvent
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import NarratorMessage
|
||||
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Actor, Player, Character
|
||||
from talemate.tale_mate import Actor, Character, Player
|
||||
|
||||
log = structlog.get_logger("talemate.agents.narrator")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NarratorAgentEmission(AgentEmission):
|
||||
generation: list[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
talemate.emit.async_signals.register(
|
||||
"agent.narrator.generated"
|
||||
)
|
||||
|
||||
|
||||
talemate.emit.async_signals.register("agent.narrator.generated")
|
||||
|
||||
|
||||
def set_processing(fn):
|
||||
|
||||
"""
|
||||
Custom decorator that emits the agent status as processing while the function
|
||||
is running and then emits the result of the function as a NarratorAgentEmission
|
||||
"""
|
||||
|
||||
|
||||
@_set_processing
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
@wraps(fn)
|
||||
async def narration_wrapper(self, *args, **kwargs):
|
||||
response = await fn(self, *args, **kwargs)
|
||||
emission = NarratorAgentEmission(
|
||||
agent=self,
|
||||
@@ -45,68 +50,68 @@ def set_processing(fn):
|
||||
)
|
||||
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
|
||||
return emission.generation[0]
|
||||
wrapper.__name__ = fn.__name__
|
||||
return wrapper
|
||||
|
||||
return narration_wrapper
|
||||
|
||||
|
||||
@register()
|
||||
class NarratorAgent(Agent):
|
||||
|
||||
"""
|
||||
Handles narration of the story
|
||||
"""
|
||||
|
||||
|
||||
agent_type = "narrator"
|
||||
verbose_name = "Narrator"
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: client.TaleMateClient,
|
||||
**kwargs,
|
||||
):
|
||||
self.client = client
|
||||
|
||||
|
||||
# agent actions
|
||||
|
||||
|
||||
self.actions = {
|
||||
"generation_override": AgentAction(
|
||||
enabled = True,
|
||||
label = "Generation Override",
|
||||
description = "Override generation parameters",
|
||||
config = {
|
||||
enabled=True,
|
||||
label="Generation Override",
|
||||
description="Override generation parameters",
|
||||
config={
|
||||
"instructions": AgentActionConfig(
|
||||
type="text",
|
||||
label="Instructions",
|
||||
value="Never wax poetic.",
|
||||
description="Extra instructions to give to the AI for narrative generation.",
|
||||
),
|
||||
}
|
||||
},
|
||||
),
|
||||
"auto_break_repetition": AgentAction(
|
||||
enabled = True,
|
||||
label = "Auto Break Repetition",
|
||||
description = "Will attempt to automatically break AI repetition.",
|
||||
enabled=True,
|
||||
label="Auto Break Repetition",
|
||||
description="Will attempt to automatically break AI repetition.",
|
||||
),
|
||||
"narrate_time_passage": AgentAction(
|
||||
enabled=True,
|
||||
label="Narrate Time Passage",
|
||||
enabled=True,
|
||||
label="Narrate Time Passage",
|
||||
description="Whenever you indicate passage of time, narrate right after",
|
||||
config = {
|
||||
config={
|
||||
"ask_for_prompt": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Guide time narration via prompt",
|
||||
label="Guide time narration via prompt",
|
||||
description="Ask the user for a prompt to generate the time passage narration",
|
||||
value=True,
|
||||
)
|
||||
}
|
||||
},
|
||||
),
|
||||
"narrate_dialogue": AgentAction(
|
||||
enabled=False,
|
||||
label="Narrate after Dialogue",
|
||||
enabled=False,
|
||||
label="Narrate after Dialogue",
|
||||
description="Narrator will get a chance to narrate after every line of dialogue",
|
||||
config = {
|
||||
config={
|
||||
"ai_dialog": AgentActionConfig(
|
||||
type="number",
|
||||
label="AI Dialogue",
|
||||
label="AI Dialogue",
|
||||
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
|
||||
value=0.0,
|
||||
min=0.0,
|
||||
@@ -115,7 +120,7 @@ class NarratorAgent(Agent):
|
||||
),
|
||||
"player_dialog": AgentActionConfig(
|
||||
type="number",
|
||||
label="Player Dialogue",
|
||||
label="Player Dialogue",
|
||||
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
|
||||
value=0.1,
|
||||
min=0.0,
|
||||
@@ -124,34 +129,32 @@ class NarratorAgent(Agent):
|
||||
),
|
||||
"generate_dialogue": AgentActionConfig(
|
||||
type="bool",
|
||||
label="Allow Dialogue in Narration",
|
||||
label="Allow Dialogue in Narration",
|
||||
description="Allow the narrator to generate dialogue in narration",
|
||||
value=False,
|
||||
),
|
||||
}
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def extra_instructions(self):
|
||||
if self.actions["generation_override"].enabled:
|
||||
return self.actions["generation_override"].config["instructions"].value
|
||||
return ""
|
||||
|
||||
|
||||
def clean_result(self, result):
|
||||
|
||||
"""
|
||||
Cleans the result of a narration
|
||||
"""
|
||||
|
||||
|
||||
result = result.strip().strip(":").strip()
|
||||
|
||||
|
||||
if "#" in result:
|
||||
result = result.split("#")[0]
|
||||
|
||||
|
||||
character_names = [c.name for c in self.scene.get_characters()]
|
||||
|
||||
|
||||
|
||||
cleaned = []
|
||||
for line in result.split("\n"):
|
||||
for character_name in character_names:
|
||||
@@ -160,71 +163,83 @@ class NarratorAgent(Agent):
|
||||
cleaned.append(line)
|
||||
|
||||
result = "\n".join(cleaned)
|
||||
#result = util.strip_partial_sentences(result)
|
||||
# result = util.strip_partial_sentences(result)
|
||||
return result
|
||||
|
||||
def connect(self, scene):
|
||||
|
||||
"""
|
||||
Connect to signals
|
||||
"""
|
||||
|
||||
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
|
||||
talemate.emit.async_signals.get("agent.world_state.time").connect(
|
||||
self.on_time_passage
|
||||
)
|
||||
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)
|
||||
|
||||
async def on_time_passage(self, event:TimePassageEmission):
|
||||
|
||||
async def on_time_passage(self, event: TimePassageEmission):
|
||||
"""
|
||||
Handles time passage narration, if enabled
|
||||
"""
|
||||
|
||||
|
||||
if not self.actions["narrate_time_passage"].enabled:
|
||||
return
|
||||
|
||||
response = await self.narrate_time_passage(event.duration, event.human_duration, event.narrative)
|
||||
narrator_message = NarratorMessage(response, source=f"narrate_time_passage:{event.duration};{event.narrative}")
|
||||
|
||||
response = await self.narrate_time_passage(
|
||||
event.duration, event.human_duration, event.narrative
|
||||
)
|
||||
narrator_message = NarratorMessage(
|
||||
response, source=f"narrate_time_passage:{event.duration};{event.narrative}"
|
||||
)
|
||||
emit("narrator", narrator_message)
|
||||
self.scene.push_history(narrator_message)
|
||||
|
||||
async def on_dialog(self, event:GameLoopActorIterEvent):
|
||||
|
||||
|
||||
async def on_dialog(self, event: GameLoopActorIterEvent):
|
||||
"""
|
||||
Handles dialogue narration, if enabled
|
||||
"""
|
||||
|
||||
|
||||
if not self.actions["narrate_dialogue"].enabled:
|
||||
return
|
||||
|
||||
|
||||
|
||||
if event.game_loop.had_passive_narration:
|
||||
log.debug("narrate on dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
|
||||
log.debug(
|
||||
"narrate on dialog",
|
||||
skip=True,
|
||||
had_passive_narration=event.game_loop.had_passive_narration,
|
||||
)
|
||||
return
|
||||
|
||||
narrate_on_ai_chance = self.actions["narrate_dialogue"].config["ai_dialog"].value
|
||||
narrate_on_player_chance = self.actions["narrate_dialogue"].config["player_dialog"].value
|
||||
|
||||
narrate_on_ai_chance = (
|
||||
self.actions["narrate_dialogue"].config["ai_dialog"].value
|
||||
)
|
||||
narrate_on_player_chance = (
|
||||
self.actions["narrate_dialogue"].config["player_dialog"].value
|
||||
)
|
||||
narrate_on_ai = random.random() < narrate_on_ai_chance
|
||||
narrate_on_player = random.random() < narrate_on_player_chance
|
||||
|
||||
log.debug(
|
||||
"narrate on dialog",
|
||||
narrate_on_ai=narrate_on_ai,
|
||||
narrate_on_ai_chance=narrate_on_ai_chance,
|
||||
"narrate on dialog",
|
||||
narrate_on_ai=narrate_on_ai,
|
||||
narrate_on_ai_chance=narrate_on_ai_chance,
|
||||
narrate_on_player=narrate_on_player,
|
||||
narrate_on_player_chance=narrate_on_player_chance,
|
||||
)
|
||||
|
||||
|
||||
if event.actor.character.is_player and not narrate_on_player:
|
||||
return
|
||||
|
||||
|
||||
if not event.actor.character.is_player and not narrate_on_ai:
|
||||
return
|
||||
|
||||
|
||||
response = await self.narrate_after_dialogue(event.actor.character)
|
||||
narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
|
||||
narrator_message = NarratorMessage(
|
||||
response, source=f"narrate_dialogue:{event.actor.character.name}"
|
||||
)
|
||||
emit("narrator", narrator_message)
|
||||
self.scene.push_history(narrator_message)
|
||||
|
||||
|
||||
event.game_loop.had_passive_narration = True
|
||||
|
||||
@set_processing
|
||||
@@ -237,22 +252,22 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-scene",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
response = response.strip("*")
|
||||
response = util.strip_partial_sentences(response)
|
||||
|
||||
|
||||
response = f"*{response.strip('*')}*"
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def progress_story(self, narrative_direction:str=None):
|
||||
async def progress_story(self, narrative_direction: str = None):
|
||||
"""
|
||||
Narrate the scene
|
||||
"""
|
||||
@@ -260,18 +275,20 @@ class NarratorAgent(Agent):
|
||||
scene = self.scene
|
||||
pc = scene.get_player_character()
|
||||
npcs = list(scene.get_npc_characters())
|
||||
npc_names= ", ".join([npc.name for npc in npcs])
|
||||
|
||||
npc_names = ", ".join([npc.name for npc in npcs])
|
||||
|
||||
if narrative_direction is None:
|
||||
narrative_direction = "Slightly move the current scene forward."
|
||||
|
||||
self.scene.log.info("narrative_direction", narrative_direction=narrative_direction)
|
||||
|
||||
self.scene.log.info(
|
||||
"narrative_direction", narrative_direction=narrative_direction
|
||||
)
|
||||
|
||||
response = await Prompt.request(
|
||||
"narrator.narrate-progress",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"narrative_direction": narrative_direction,
|
||||
@@ -279,7 +296,7 @@ class NarratorAgent(Agent):
|
||||
"npcs": npcs,
|
||||
"npc_names": npc_names,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.scene.log.info("progress_story", response=response)
|
||||
@@ -291,11 +308,13 @@ class NarratorAgent(Agent):
|
||||
if response.count("*") % 2 != 0:
|
||||
response = response.replace("*", "")
|
||||
response = f"*{response}*"
|
||||
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def narrate_query(self, query:str, at_the_end:bool=False, as_narrative:bool=True):
|
||||
async def narrate_query(
|
||||
self, query: str, at_the_end: bool = False, as_narrative: bool = True
|
||||
):
|
||||
"""
|
||||
Narrate a specific query
|
||||
"""
|
||||
@@ -303,21 +322,21 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-query",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"query": query,
|
||||
"at_the_end": at_the_end,
|
||||
"as_narrative": as_narrative,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
log.info("narrate_query", response=response)
|
||||
response = self.clean_result(response.strip())
|
||||
log.info("narrate_query (after clean)", response=response)
|
||||
if as_narrative:
|
||||
response = f"*{response}*"
|
||||
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@@ -330,12 +349,12 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-character",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"character": character,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response.strip())
|
||||
@@ -345,54 +364,55 @@ class NarratorAgent(Agent):
|
||||
|
||||
@set_processing
|
||||
async def augment_context(self):
|
||||
|
||||
"""
|
||||
Takes a context history generated via scene.context_history() and augments it with additional information
|
||||
by asking and answering questions with help from the long term memory.
|
||||
"""
|
||||
memory = self.scene.get_helper("memory").agent
|
||||
|
||||
|
||||
questions = await Prompt.request(
|
||||
"narrator.context-questions",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
self.scene.log.info("context_questions", questions=questions)
|
||||
|
||||
|
||||
questions = [q for q in questions.split("\n") if q.strip()]
|
||||
|
||||
|
||||
memory_context = await memory.multi_query(
|
||||
questions, iterate=2, max_tokens=self.client.max_token_length - 1000
|
||||
)
|
||||
|
||||
|
||||
answers = await Prompt.request(
|
||||
"narrator.context-answers",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"memory": memory_context,
|
||||
"questions": questions,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
self.scene.log.info("context_answers", answers=answers)
|
||||
|
||||
|
||||
answers = [a for a in answers.split("\n") if a.strip()]
|
||||
|
||||
|
||||
# return questions and answers
|
||||
return list(zip(questions, answers))
|
||||
|
||||
|
||||
@set_processing
|
||||
async def narrate_time_passage(self, duration:str, time_passed:str, narrative:str):
|
||||
async def narrate_time_passage(
|
||||
self, duration: str, time_passed: str, narrative: str
|
||||
):
|
||||
"""
|
||||
Narrate a specific character
|
||||
"""
|
||||
@@ -401,26 +421,25 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-time-passage",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"duration": duration,
|
||||
"time_passed": time_passed,
|
||||
"narrative": narrative,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
log.info("narrate_time_passage", response=response)
|
||||
|
||||
response = self.clean_result(response.strip())
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def narrate_after_dialogue(self, character:Character):
|
||||
async def narrate_after_dialogue(self, character: Character):
|
||||
"""
|
||||
Narrate after a line of dialogue
|
||||
"""
|
||||
@@ -429,22 +448,24 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-after-dialogue",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character": character,
|
||||
"last_line": str(self.scene.history[-1]),
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
log.info("narrate_after_dialogue", response=response)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
allow_dialogue = self.actions["narrate_dialogue"].config["generate_dialogue"].value
|
||||
|
||||
|
||||
allow_dialogue = (
|
||||
self.actions["narrate_dialogue"].config["generate_dialogue"].value
|
||||
)
|
||||
|
||||
if not allow_dialogue:
|
||||
response = response.split('"')[0].strip()
|
||||
response = response.replace("*", "")
|
||||
@@ -452,9 +473,11 @@ class NarratorAgent(Agent):
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def narrate_character_entry(self, character:Character, direction:str=None):
|
||||
async def narrate_character_entry(
|
||||
self, character: Character, direction: str = None
|
||||
):
|
||||
"""
|
||||
Narrate a character entering the scene
|
||||
"""
|
||||
@@ -463,22 +486,22 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-character-entry",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character": character,
|
||||
"direction": direction,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def narrate_character_exit(self, character:Character, direction:str=None):
|
||||
async def narrate_character_exit(self, character: Character, direction: str = None):
|
||||
"""
|
||||
Narrate a character exiting the scene
|
||||
"""
|
||||
@@ -487,47 +510,136 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-character-exit",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character": character,
|
||||
"direction": direction,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def paraphrase(self, narration: str):
|
||||
"""
|
||||
Paraphrase a narration
|
||||
"""
|
||||
|
||||
response = await Prompt.request(
|
||||
"narrator.paraphrase",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars={
|
||||
"text": narration,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
log.info("paraphrase", narration=narration, response=response)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
async def passthrough(self, narration: str) -> str:
|
||||
"""
|
||||
Pass through narration message as is
|
||||
"""
|
||||
narration = narration.replace("*", "")
|
||||
narration = f"*{narration}*"
|
||||
narration = util.ensure_dialog_format(narration)
|
||||
return narration
|
||||
|
||||
def action_to_source(
|
||||
self,
|
||||
action_name: str,
|
||||
parameters: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a source string for a given action and parameters
|
||||
|
||||
The source string is used to identify the source of a NarratorMessage
|
||||
and will also help regenerate the action and parameters from the source string
|
||||
later on
|
||||
"""
|
||||
|
||||
args = []
|
||||
|
||||
if action_name == "paraphrase":
|
||||
args.append(parameters.get("narration"))
|
||||
elif action_name == "narrate_character_entry":
|
||||
args.append(parameters.get("character").name)
|
||||
# args.append(parameters.get("direction"))
|
||||
elif action_name == "narrate_character_exit":
|
||||
args.append(parameters.get("character").name)
|
||||
# args.append(parameters.get("direction"))
|
||||
elif action_name == "narrate_character":
|
||||
args.append(parameters.get("character").name)
|
||||
elif action_name == "narrate_query":
|
||||
args.append(parameters.get("query"))
|
||||
elif action_name == "narrate_time_passage":
|
||||
args.append(parameters.get("duration"))
|
||||
args.append(parameters.get("time_passed"))
|
||||
args.append(parameters.get("narrative"))
|
||||
elif action_name == "progress_story":
|
||||
args.append(parameters.get("narrative_direction"))
|
||||
elif action_name == "narrate_after_dialogue":
|
||||
args.append(parameters.get("character"))
|
||||
|
||||
arg_str = ";".join(args) if args else ""
|
||||
|
||||
return f"{action_name}:{arg_str}".rstrip(":")
|
||||
|
||||
async def action_to_narration(
|
||||
self,
|
||||
action_name: str,
|
||||
*args,
|
||||
emit_message: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# calls self[action_name] and returns the result as a NarratorMessage
|
||||
# that is pushed to the history
|
||||
|
||||
|
||||
fn = getattr(self, action_name)
|
||||
narration = await fn(*args, **kwargs)
|
||||
narrator_message = NarratorMessage(narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":"))
|
||||
narration = await fn(**kwargs)
|
||||
source = self.action_to_source(action_name, kwargs)
|
||||
|
||||
narrator_message = NarratorMessage(narration, source=source)
|
||||
self.scene.push_history(narrator_message)
|
||||
|
||||
if emit_message:
|
||||
emit("narrator", narrator_message)
|
||||
|
||||
return narrator_message
|
||||
|
||||
|
||||
action_to_narration.exposed = True
|
||||
|
||||
# LLM client related methods. These are called during or after the client
|
||||
|
||||
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
|
||||
log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
|
||||
|
||||
def inject_prompt_paramters(
|
||||
self, prompt_param: dict, kind: str, agent_function_name: str
|
||||
):
|
||||
log.debug(
|
||||
"inject_prompt_paramters",
|
||||
prompt_param=prompt_param,
|
||||
kind=kind,
|
||||
agent_function_name=agent_function_name,
|
||||
)
|
||||
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
|
||||
if prompt_param.get("extra_stopping_strings") is None:
|
||||
prompt_param["extra_stopping_strings"] = []
|
||||
prompt_param["extra_stopping_strings"] += character_names
|
||||
|
||||
def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
|
||||
|
||||
def allow_repetition_break(
|
||||
self, kind: str, agent_function_name: str, auto: bool = False
|
||||
):
|
||||
if auto and not self.actions["auto_break_repetition"].enabled:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.data_objects as data_objects
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.events import GameLoopEvent
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import DirectorMessage, TimePassageMessage
|
||||
from talemate.events import GameLoopEvent
|
||||
|
||||
from .base import Agent, set_processing, AgentAction, AgentActionConfig
|
||||
from .base import Agent, AgentAction, AgentActionConfig, set_processing
|
||||
from .registry import register
|
||||
|
||||
import structlog
|
||||
|
||||
import time
|
||||
import re
|
||||
|
||||
log = structlog.get_logger("talemate.agents.summarize")
|
||||
|
||||
|
||||
@register()
|
||||
class SummarizeAgent(Agent):
|
||||
"""
|
||||
@@ -36,7 +36,7 @@ class SummarizeAgent(Agent):
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
self.client = client
|
||||
|
||||
|
||||
self.actions = {
|
||||
"archive": AgentAction(
|
||||
enabled=True,
|
||||
@@ -61,36 +61,43 @@ class SummarizeAgent(Agent):
|
||||
{"label": "Short & Concise", "value": "short"},
|
||||
{"label": "Balanced", "value": "balanced"},
|
||||
{"label": "Lengthy & Detailed", "value": "long"},
|
||||
{"label": "Factual List", "value": "facts"},
|
||||
],
|
||||
),
|
||||
"include_previous": AgentActionConfig(
|
||||
type="number",
|
||||
label="Use preceeding summaries to strengthen context",
|
||||
description="Number of entries",
|
||||
note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
|
||||
note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
|
||||
value=3,
|
||||
min=0,
|
||||
max=10,
|
||||
step=1,
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def threshold(self):
|
||||
return self.actions["archive"].config["threshold"].value
|
||||
|
||||
@property
|
||||
def estimated_entry_count(self):
|
||||
all_tokens = sum([util.count_tokens(entry) for entry in self.scene.history])
|
||||
return all_tokens // self.threshold
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
|
||||
async def on_game_loop(self, emission:GameLoopEvent):
|
||||
|
||||
async def on_game_loop(self, emission: GameLoopEvent):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
"""
|
||||
|
||||
await self.build_archive(self.scene)
|
||||
|
||||
|
||||
|
||||
def clean_result(self, result):
|
||||
if "#" in result:
|
||||
result = result.split("#")[0]
|
||||
@@ -104,10 +111,10 @@ class SummarizeAgent(Agent):
|
||||
@set_processing
|
||||
async def build_archive(self, scene):
|
||||
end = None
|
||||
|
||||
|
||||
if not self.actions["archive"].enabled:
|
||||
return
|
||||
|
||||
|
||||
if not scene.archived_history:
|
||||
start = 0
|
||||
recent_entry = None
|
||||
@@ -118,14 +125,16 @@ class SummarizeAgent(Agent):
|
||||
# meaning we are still at the beginning of the scene
|
||||
start = 0
|
||||
else:
|
||||
start = recent_entry.get("end", 0)+1
|
||||
|
||||
start = recent_entry.get("end", 0) + 1
|
||||
|
||||
# if there is a recent entry we also collect the 3 most recentries
|
||||
# as extra context
|
||||
|
||||
|
||||
num_previous = self.actions["archive"].config["include_previous"].value
|
||||
if recent_entry and num_previous > 0:
|
||||
extra_context = "\n\n".join([entry["text"] for entry in scene.archived_history[-num_previous:]])
|
||||
extra_context = "\n\n".join(
|
||||
[entry["text"] for entry in scene.archived_history[-num_previous:]]
|
||||
)
|
||||
else:
|
||||
extra_context = None
|
||||
|
||||
@@ -133,36 +142,44 @@ class SummarizeAgent(Agent):
|
||||
dialogue_entries = []
|
||||
ts = "PT0S"
|
||||
time_passage_termination = False
|
||||
|
||||
|
||||
token_threshold = self.actions["archive"].config["threshold"].value
|
||||
|
||||
|
||||
log.debug("build_archive", start=start, recent_entry=recent_entry)
|
||||
|
||||
|
||||
if recent_entry:
|
||||
ts = recent_entry.get("ts", ts)
|
||||
|
||||
for i in range(start, len(scene.history)):
|
||||
|
||||
# we ignore the most recent entry, as the user may still chose to
|
||||
# regenerate it
|
||||
for i in range(start, max(start, len(scene.history) - 1)):
|
||||
dialogue = scene.history[i]
|
||||
|
||||
#log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
|
||||
|
||||
|
||||
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
|
||||
|
||||
if isinstance(dialogue, DirectorMessage):
|
||||
if i == start:
|
||||
start += 1
|
||||
continue
|
||||
|
||||
|
||||
if isinstance(dialogue, TimePassageMessage):
|
||||
log.debug("build_archive", time_passage_message=dialogue)
|
||||
if i == start:
|
||||
ts = util.iso8601_add(ts, dialogue.ts)
|
||||
log.debug("build_archive", time_passage_message=dialogue, start=start, i=i, ts=ts)
|
||||
log.debug(
|
||||
"build_archive",
|
||||
time_passage_message=dialogue,
|
||||
start=start,
|
||||
i=i,
|
||||
ts=ts,
|
||||
)
|
||||
start += 1
|
||||
continue
|
||||
log.debug("build_archive", time_passage_message_termination=dialogue)
|
||||
time_passage_termination = True
|
||||
end = i - 1
|
||||
break
|
||||
|
||||
|
||||
tokens += util.count_tokens(dialogue)
|
||||
dialogue_entries.append(dialogue)
|
||||
if tokens > token_threshold: #
|
||||
@@ -172,39 +189,44 @@ class SummarizeAgent(Agent):
|
||||
if end is None:
|
||||
# nothing to archive yet
|
||||
return
|
||||
|
||||
log.debug("build_archive", start=start, end=end, ts=ts, time_passage_termination=time_passage_termination)
|
||||
|
||||
log.debug(
|
||||
"build_archive",
|
||||
start=start,
|
||||
end=end,
|
||||
ts=ts,
|
||||
time_passage_termination=time_passage_termination,
|
||||
)
|
||||
|
||||
# in order to summarize coherently, we need to determine if there is a favorable
|
||||
# cutoff point (e.g., the scene naturally ends or shifts meaninfully in the middle
|
||||
# of the dialogue)
|
||||
#
|
||||
# One way to do this is to check if the last line is a TimePassageMessage, which
|
||||
# indicates a scene change or a significant pause.
|
||||
#
|
||||
# indicates a scene change or a significant pause.
|
||||
#
|
||||
# If not, we can ask the AI to find a good point of
|
||||
# termination.
|
||||
|
||||
|
||||
if not time_passage_termination:
|
||||
|
||||
# No TimePassageMessage, so we need to ask the AI to find a good point of termination
|
||||
|
||||
|
||||
terminating_line = await self.analyze_dialoge(dialogue_entries)
|
||||
|
||||
if terminating_line:
|
||||
adjusted_dialogue = []
|
||||
for line in dialogue_entries:
|
||||
for line in dialogue_entries:
|
||||
if str(line) in terminating_line:
|
||||
break
|
||||
adjusted_dialogue.append(line)
|
||||
dialogue_entries = adjusted_dialogue
|
||||
end = start + len(dialogue_entries)-1
|
||||
|
||||
end = start + len(dialogue_entries) - 1
|
||||
|
||||
if dialogue_entries:
|
||||
summarized = await self.summarize(
|
||||
"\n".join(map(str, dialogue_entries)), extra_context=extra_context
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# AI has likely identified the first line as a scene change, so we can't summarize
|
||||
# just use the first line
|
||||
@@ -218,15 +240,20 @@ class SummarizeAgent(Agent):
|
||||
|
||||
@set_processing
|
||||
async def analyze_dialoge(self, dialogue):
|
||||
response = await Prompt.request("summarizer.analyze-dialogue", self.client, "analyze_freeform", vars={
|
||||
"dialogue": "\n".join(map(str, dialogue)),
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
})
|
||||
|
||||
response = await Prompt.request(
|
||||
"summarizer.analyze-dialogue",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars={
|
||||
"dialogue": "\n".join(map(str, dialogue)),
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response)
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def summarize(
|
||||
self,
|
||||
@@ -239,33 +266,42 @@ class SummarizeAgent(Agent):
|
||||
Summarize the given text
|
||||
"""
|
||||
|
||||
response = await Prompt.request("summarizer.summarize-dialogue", self.client, "summarize", vars={
|
||||
"dialogue": text,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"summarization_method": self.actions["archive"].config["method"].value if method is None else method,
|
||||
"extra_context": extra_context or "",
|
||||
"extra_instructions": extra_instructions or "",
|
||||
})
|
||||
|
||||
self.scene.log.info("summarize", dialogue_length=len(text), summarized_length=len(response))
|
||||
response = await Prompt.request(
|
||||
"summarizer.summarize-dialogue",
|
||||
self.client,
|
||||
"summarize",
|
||||
vars={
|
||||
"dialogue": text,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"summarization_method": (
|
||||
self.actions["archive"].config["method"].value
|
||||
if method is None
|
||||
else method
|
||||
),
|
||||
"extra_context": extra_context or "",
|
||||
"extra_instructions": extra_instructions or "",
|
||||
},
|
||||
)
|
||||
|
||||
self.scene.log.info(
|
||||
"summarize", dialogue_length=len(text), summarized_length=len(response)
|
||||
)
|
||||
|
||||
return self.clean_result(response)
|
||||
|
||||
|
||||
async def build_stepped_archive_for_level(self, level:int):
|
||||
|
||||
|
||||
async def build_stepped_archive_for_level(self, level: int):
|
||||
"""
|
||||
WIP - not yet used
|
||||
|
||||
|
||||
This will iterate over existing archived_history entries
|
||||
and stepped_archived_history entries and summarize based on time duration
|
||||
indicated between the entries.
|
||||
|
||||
|
||||
The lowest level of summarization (based on token threshold and any time passage)
|
||||
happens in build_archive. This method is for summarizing furhter levels based on
|
||||
long time pasages.
|
||||
|
||||
|
||||
Level 0: small timestap summarize (summarizes all token summarizations when time advances +1 day)
|
||||
Level 1: medium timestap summarize (summarizes all small timestep summarizations when time advances +1 week)
|
||||
Level 2: large timestap summarize (summarizes all medium timestep summarizations when time advances +1 month)
|
||||
@@ -273,7 +309,7 @@ class SummarizeAgent(Agent):
|
||||
Level 4: massive timestap summarize (summarizes all huge timestep summarizations when time advances +10 years)
|
||||
Level 5: epic timestap summarize (summarizes all massive timestep summarizations when time advances +100 years)
|
||||
and so on (increasing by a factor of 10 each time)
|
||||
|
||||
|
||||
```
|
||||
@dataclass
|
||||
class ArchiveEntry:
|
||||
@@ -282,35 +318,34 @@ class SummarizeAgent(Agent):
|
||||
end: int = None
|
||||
ts: str = None
|
||||
```
|
||||
|
||||
|
||||
Like token summarization this will use ArchiveEntry and start and end will refer to the entries in the
|
||||
lower level of summarization.
|
||||
|
||||
|
||||
Ts is the iso8601 timestamp of the start of the summarized period.
|
||||
"""
|
||||
|
||||
|
||||
# select the list to use for the entries
|
||||
|
||||
|
||||
if level == 0:
|
||||
entries = self.scene.archived_history
|
||||
else:
|
||||
entries = self.scene.stepped_archived_history[level-1]
|
||||
|
||||
entries = self.scene.stepped_archived_history[level - 1]
|
||||
|
||||
# select the list to summarize new entries to
|
||||
|
||||
|
||||
target = self.scene.stepped_archived_history[level]
|
||||
|
||||
|
||||
if not target:
|
||||
raise ValueError(f"Invalid level {level}")
|
||||
|
||||
|
||||
# determine the start and end of the period to summarize
|
||||
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
|
||||
|
||||
# determine the time threshold for this level
|
||||
|
||||
|
||||
# first calculate all possible thresholds in iso8601 format, starting with 1 day
|
||||
thresholds = [
|
||||
"P1D",
|
||||
@@ -318,61 +353,65 @@ class SummarizeAgent(Agent):
|
||||
"P1M",
|
||||
"P1Y",
|
||||
]
|
||||
|
||||
|
||||
# TODO: auto extend?
|
||||
|
||||
|
||||
time_threshold_in_seconds = util.iso8601_to_seconds(thresholds[level])
|
||||
|
||||
|
||||
if not time_threshold_in_seconds:
|
||||
raise ValueError(f"Invalid level {level}")
|
||||
|
||||
|
||||
# determine the most recent summarized entry time, and then find entries
|
||||
# that are newer than that in the lower list
|
||||
|
||||
|
||||
ts = target[-1].ts if target else entries[0].ts
|
||||
|
||||
|
||||
# determine the most recent entry at the lower level, if its not newer or
|
||||
# the difference is less than the threshold, then we don't need to summarize
|
||||
|
||||
|
||||
recent_entry = entries[-1]
|
||||
|
||||
|
||||
if util.iso8601_diff(recent_entry.ts, ts) < time_threshold_in_seconds:
|
||||
return
|
||||
|
||||
|
||||
log.debug("build_stepped_archive", level=level, ts=ts)
|
||||
|
||||
|
||||
# if target is empty, start is 0
|
||||
# otherwise start is the end of the last entry
|
||||
|
||||
|
||||
start = 0 if not target else target[-1].end
|
||||
|
||||
|
||||
# collect entries starting at start until the combined time duration
|
||||
# exceeds the threshold
|
||||
|
||||
|
||||
entries_to_summarize = []
|
||||
|
||||
|
||||
for entry in entries[start:]:
|
||||
entries_to_summarize.append(entry)
|
||||
if util.iso8601_diff(entry.ts, ts) > time_threshold_in_seconds:
|
||||
break
|
||||
|
||||
|
||||
# summarize the entries
|
||||
# we also collect N entries of previous summaries to use as context
|
||||
|
||||
|
||||
num_previous = self.actions["archive"].config["include_previous"].value
|
||||
if num_previous > 0:
|
||||
extra_context = "\n\n".join([entry["text"] for entry in target[-num_previous:]])
|
||||
extra_context = "\n\n".join(
|
||||
[entry["text"] for entry in target[-num_previous:]]
|
||||
)
|
||||
else:
|
||||
extra_context = None
|
||||
|
||||
|
||||
summarized = await self.summarize(
|
||||
"\n".join(map(str, entries_to_summarize)), extra_context=extra_context
|
||||
)
|
||||
|
||||
|
||||
# push summarized entry to target
|
||||
|
||||
|
||||
ts = entries_to_summarize[-1].ts
|
||||
|
||||
target.append(data_objects.ArchiveEntry(summarized, start, len(entries_to_summarize)-1, ts=ts))
|
||||
|
||||
|
||||
|
||||
target.append(
|
||||
data_objects.ArchiveEntry(
|
||||
summarized, start, len(entries_to_summarize) - 1, ts=ts
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
import asyncio
|
||||
import httpx
|
||||
import base64
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import pydantic
|
||||
import nltk
|
||||
import tempfile
|
||||
import base64
|
||||
import time
|
||||
import uuid
|
||||
import functools
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
import nltk
|
||||
import pydantic
|
||||
import structlog
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
import talemate.config as config
|
||||
import talemate.emit.async_signals
|
||||
@@ -21,91 +25,91 @@ from talemate.emit.signals import handlers
|
||||
from talemate.events import GameLoopNewMessageEvent
|
||||
from talemate.scene_message import CharacterMessage, NarratorMessage
|
||||
|
||||
from .base import Agent, set_processing, AgentAction, AgentActionConfig
|
||||
from .base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
from .registry import register
|
||||
|
||||
import structlog
|
||||
|
||||
import time
|
||||
|
||||
try:
|
||||
from TTS.api import TTS
|
||||
except ImportError:
|
||||
TTS = None
|
||||
|
||||
log = structlog.get_logger("talemate.agents.tts")#
|
||||
log = structlog.get_logger("talemate.agents.tts") #
|
||||
|
||||
if not TTS:
|
||||
# TTS installation is massive and requires a lot of dependencies
|
||||
# so we don't want to require it unless the user wants to use it
|
||||
log.info("TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api")
|
||||
log.info(
|
||||
"TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
|
||||
)
|
||||
|
||||
|
||||
def parse_chunks(text):
|
||||
|
||||
text = text.replace("...", "__ellipsis__")
|
||||
|
||||
|
||||
chunks = sent_tokenize(text)
|
||||
cleaned_chunks = []
|
||||
|
||||
|
||||
for chunk in chunks:
|
||||
chunk = chunk.replace("*","")
|
||||
chunk = chunk.replace("*", "")
|
||||
if not chunk:
|
||||
continue
|
||||
cleaned_chunks.append(chunk)
|
||||
|
||||
|
||||
|
||||
for i, chunk in enumerate(cleaned_chunks):
|
||||
chunk = chunk.replace("__ellipsis__", "...")
|
||||
|
||||
cleaned_chunks[i] = chunk
|
||||
|
||||
|
||||
return cleaned_chunks
|
||||
|
||||
def clean_quotes(chunk:str):
|
||||
|
||||
|
||||
def clean_quotes(chunk: str):
|
||||
# if there is an uneven number of quotes, remove the last one if its
|
||||
# at the end of the chunk. If its in the middle, add a quote to the end
|
||||
if chunk.count('"') % 2 == 1:
|
||||
|
||||
if chunk.endswith('"'):
|
||||
chunk = chunk[:-1]
|
||||
else:
|
||||
chunk += '"'
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
def rejoin_chunks(chunks:list[str], chunk_size:int=250):
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
|
||||
"""
|
||||
Will combine chunks split by punctuation into a single chunk until
|
||||
max chunk size is reached
|
||||
"""
|
||||
|
||||
|
||||
joined_chunks = []
|
||||
|
||||
|
||||
current_chunk = ""
|
||||
|
||||
|
||||
for chunk in chunks:
|
||||
|
||||
if len(current_chunk) + len(chunk) > chunk_size:
|
||||
joined_chunks.append(clean_quotes(current_chunk))
|
||||
current_chunk = ""
|
||||
|
||||
|
||||
current_chunk += chunk
|
||||
|
||||
|
||||
if current_chunk:
|
||||
joined_chunks.append(clean_quotes(current_chunk))
|
||||
return joined_chunks
|
||||
|
||||
|
||||
class Voice(pydantic.BaseModel):
|
||||
value:str
|
||||
label:str
|
||||
value: str
|
||||
label: str
|
||||
|
||||
|
||||
class VoiceLibrary(pydantic.BaseModel):
|
||||
|
||||
api: str
|
||||
voices: list[Voice] = pydantic.Field(default_factory=list)
|
||||
last_synced: float = None
|
||||
@@ -113,51 +117,50 @@ class VoiceLibrary(pydantic.BaseModel):
|
||||
|
||||
@register()
|
||||
class TTSAgent(Agent):
|
||||
|
||||
"""
|
||||
Text to speech agent
|
||||
"""
|
||||
|
||||
|
||||
agent_type = "tts"
|
||||
verbose_name = "Voice"
|
||||
requires_llm_client = False
|
||||
|
||||
essential = False
|
||||
|
||||
@classmethod
|
||||
def config_options(cls, agent=None):
|
||||
config_options = super().config_options(agent=agent)
|
||||
|
||||
|
||||
if agent:
|
||||
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
|
||||
voice.model_dump() for voice in agent.list_voices_sync()
|
||||
]
|
||||
|
||||
|
||||
return config_options
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
self.is_enabled = False
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
|
||||
self.voices = {
|
||||
"elevenlabs": VoiceLibrary(api="elevenlabs"),
|
||||
"coqui": VoiceLibrary(api="coqui"),
|
||||
"tts": VoiceLibrary(api="tts"),
|
||||
"openai": VoiceLibrary(api="openai"),
|
||||
}
|
||||
self.config = config.load_config()
|
||||
self.playback_done_event = asyncio.Event()
|
||||
self.preselect_voice = None
|
||||
self.actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
label="Configure",
|
||||
enabled=True,
|
||||
label="Configure",
|
||||
description="TTS agent configuration",
|
||||
config={
|
||||
"api": AgentActionConfig(
|
||||
type="text",
|
||||
choices=[
|
||||
# TODO at local TTS support
|
||||
{"value": "tts", "label": "TTS (Local)"},
|
||||
{"value": "elevenlabs", "label": "Eleven Labs"},
|
||||
{"value": "coqui", "label": "Coqui Studio"},
|
||||
{"value": "openai", "label": "OpenAI"},
|
||||
],
|
||||
value="tts",
|
||||
label="API",
|
||||
@@ -169,7 +172,7 @@ class TTSAgent(Agent):
|
||||
value="default",
|
||||
label="Narrator Voice",
|
||||
description="Voice ID/Name to use for TTS",
|
||||
choices=[]
|
||||
choices=[],
|
||||
),
|
||||
"generate_for_player": AgentActionConfig(
|
||||
type="bool",
|
||||
@@ -194,90 +197,125 @@ class TTSAgent(Agent):
|
||||
value=False,
|
||||
label="Split generation",
|
||||
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
|
||||
)
|
||||
}
|
||||
),
|
||||
},
|
||||
),
|
||||
"openai": AgentAction(
|
||||
enabled=True,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.api", value="openai"
|
||||
),
|
||||
label="OpenAI Settings",
|
||||
config={
|
||||
"model": AgentActionConfig(
|
||||
type="text",
|
||||
value="tts-1",
|
||||
choices=[
|
||||
{"value": "tts-1", "label": "TTS 1"},
|
||||
{"value": "tts-1-hd", "label": "TTS 1 HD"},
|
||||
],
|
||||
label="Model",
|
||||
description="TTS model to use",
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
self.actions["_config"].model_dump()
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def not_ready_reason(self) -> str:
|
||||
"""
|
||||
Returns a string explaining why the agent is not ready
|
||||
"""
|
||||
|
||||
|
||||
if self.ready:
|
||||
return ""
|
||||
|
||||
|
||||
if self.api == "tts":
|
||||
if not TTS:
|
||||
return "TTS not installed"
|
||||
|
||||
|
||||
elif self.requires_token and not self.token:
|
||||
return "No API token"
|
||||
|
||||
|
||||
elif not self.default_voice_id:
|
||||
return "No voice selected"
|
||||
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
suffix = ""
|
||||
|
||||
if not self.ready:
|
||||
suffix = f" - {self.not_ready_reason}"
|
||||
else:
|
||||
suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"
|
||||
|
||||
api = self.api
|
||||
choices = self.actions["_config"].config["api"].choices
|
||||
api_label = api
|
||||
for choice in choices:
|
||||
if choice["value"] == api:
|
||||
api_label = choice["label"]
|
||||
break
|
||||
|
||||
return f"{api_label}{suffix}"
|
||||
|
||||
details = {
|
||||
"api": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value=self.api_label,
|
||||
description="The backend to use for TTS",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
if self.ready and self.enabled:
|
||||
details["voice"] = AgentDetail(
|
||||
icon="mdi-account-voice",
|
||||
value=self.voice_id_to_label(self.default_voice_id) or "",
|
||||
description="The voice to use for TTS",
|
||||
color="info",
|
||||
).model_dump()
|
||||
elif self.enabled:
|
||||
details["error"] = AgentDetail(
|
||||
icon="mdi-alert",
|
||||
value=self.not_ready_reason,
|
||||
description=self.not_ready_reason,
|
||||
color="error",
|
||||
).model_dump()
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def api(self):
|
||||
return self.actions["_config"].config["api"].value
|
||||
|
||||
|
||||
@property
|
||||
def api_label(self):
|
||||
choices = self.actions["_config"].config["api"].choices
|
||||
api = self.api
|
||||
for choice in choices:
|
||||
if choice["value"] == api:
|
||||
return choice["label"]
|
||||
return api
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
api = self.api
|
||||
return self.config.get(api,{}).get("api_key")
|
||||
|
||||
return self.config.get(api, {}).get("api_key")
|
||||
|
||||
@property
|
||||
def default_voice_id(self):
|
||||
return self.actions["_config"].config["voice_id"].value
|
||||
|
||||
|
||||
@property
|
||||
def requires_token(self):
|
||||
return self.api != "tts"
|
||||
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
|
||||
if self.api == "tts":
|
||||
if not TTS:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
return (not self.requires_token or self.token) and self.default_voice_id
|
||||
|
||||
@property
|
||||
@@ -285,6 +323,8 @@ class TTSAgent(Agent):
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
if self.ready:
|
||||
if getattr(self, "processing_bg", 0) > 0:
|
||||
return "busy_bg" if not getattr(self, "processing", False) else "busy"
|
||||
return "active" if not getattr(self, "processing", False) else "busy"
|
||||
if self.requires_token and not self.token:
|
||||
return "error"
|
||||
@@ -299,106 +339,139 @@ class TTSAgent(Agent):
|
||||
return 1024
|
||||
elif self.api == "coqui":
|
||||
return 250
|
||||
|
||||
|
||||
return 250
|
||||
|
||||
def apply_config(self, *args, **kwargs):
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
try:
|
||||
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
|
||||
except KeyError:
|
||||
api = self.api
|
||||
|
||||
api_changed = api != self.api
|
||||
|
||||
log.debug("apply_config", api=api, api_changed=api != self.api, current_api=self.api)
|
||||
|
||||
super().apply_config(*args, **kwargs)
|
||||
|
||||
|
||||
api_changed = api != self.api
|
||||
|
||||
log.debug(
|
||||
"apply_config",
|
||||
api=api,
|
||||
api_changed=api != self.api,
|
||||
current_api=self.api,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
|
||||
"value"
|
||||
]
|
||||
except KeyError:
|
||||
self.preselect_voice = self.default_voice_id
|
||||
|
||||
await super().apply_config(*args, **kwargs)
|
||||
|
||||
if api_changed:
|
||||
try:
|
||||
self.actions["_config"].config["voice_id"].value = self.voices[api].voices[0].value
|
||||
self.actions["_config"].config["voice_id"].value = (
|
||||
self.voices[api].voices[0].value
|
||||
)
|
||||
except IndexError:
|
||||
self.actions["_config"].config["voice_id"].value = ""
|
||||
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop_new_message").connect(self.on_game_loop_new_message)
|
||||
|
||||
talemate.emit.async_signals.get("game_loop_new_message").connect(
|
||||
self.on_game_loop_new_message
|
||||
)
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
instance.emit_agent_status(self.__class__, self)
|
||||
|
||||
async def on_game_loop_new_message(self, emission:GameLoopNewMessageEvent):
|
||||
|
||||
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
"""
|
||||
|
||||
|
||||
if not self.enabled or not self.ready:
|
||||
return
|
||||
|
||||
|
||||
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
|
||||
return
|
||||
|
||||
if isinstance(emission.message, NarratorMessage) and not self.actions["_config"].config["generate_for_narration"].value:
|
||||
|
||||
if (
|
||||
isinstance(emission.message, NarratorMessage)
|
||||
and not self.actions["_config"].config["generate_for_narration"].value
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
if isinstance(emission.message, CharacterMessage):
|
||||
|
||||
if emission.message.source == "player" and not self.actions["_config"].config["generate_for_player"].value:
|
||||
if (
|
||||
emission.message.source == "player"
|
||||
and not self.actions["_config"].config["generate_for_player"].value
|
||||
):
|
||||
return
|
||||
elif emission.message.source == "ai" and not self.actions["_config"].config["generate_for_npc"].value:
|
||||
elif (
|
||||
emission.message.source == "ai"
|
||||
and not self.actions["_config"].config["generate_for_npc"].value
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
if isinstance(emission.message, CharacterMessage):
|
||||
character_prefix = emission.message.split(":", 1)[0]
|
||||
else:
|
||||
character_prefix = ""
|
||||
|
||||
log.info("reactive tts", message=emission.message, character_prefix=character_prefix)
|
||||
|
||||
await self.generate(str(emission.message).replace(character_prefix+": ", ""))
|
||||
|
||||
log.info(
|
||||
"reactive tts", message=emission.message, character_prefix=character_prefix
|
||||
)
|
||||
|
||||
def voice(self, voice_id:str) -> Union[Voice, None]:
|
||||
await self.generate(str(emission.message).replace(character_prefix + ": ", ""))
|
||||
|
||||
def voice(self, voice_id: str) -> Union[Voice, None]:
|
||||
for voice in self.voices[self.api].voices:
|
||||
if voice.value == voice_id:
|
||||
return voice
|
||||
return None
|
||||
|
||||
def voice_id_to_label(self, voice_id:str):
|
||||
|
||||
def voice_id_to_label(self, voice_id: str):
|
||||
for voice in self.voices[self.api].voices:
|
||||
if voice.value == voice_id:
|
||||
return voice.label
|
||||
return None
|
||||
|
||||
|
||||
def list_voices_sync(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(self.list_voices())
|
||||
|
||||
|
||||
async def list_voices(self):
|
||||
if self.requires_token and not self.token:
|
||||
return []
|
||||
|
||||
|
||||
library = self.voices[self.api]
|
||||
|
||||
|
||||
# TODO: allow re-syncing voices
|
||||
if library.last_synced:
|
||||
return library.voices
|
||||
|
||||
|
||||
list_fn = getattr(self, f"_list_voices_{self.api}")
|
||||
log.info("Listing voices", api=self.api)
|
||||
|
||||
|
||||
library.voices = await list_fn()
|
||||
library.last_synced = time.time()
|
||||
|
||||
|
||||
if self.preselect_voice:
|
||||
if self.voice(self.preselect_voice):
|
||||
self.actions["_config"].config["voice_id"].value = self.preselect_voice
|
||||
self.preselect_voice = None
|
||||
|
||||
# if the current voice cannot be found, reset it
|
||||
if not self.voice(self.default_voice_id):
|
||||
self.actions["_config"].config["voice_id"].value = ""
|
||||
|
||||
|
||||
# set loading to false
|
||||
return library.voices
|
||||
|
||||
@@ -407,11 +480,10 @@ class TTSAgent(Agent):
|
||||
if not self.enabled or not self.ready or not text:
|
||||
return
|
||||
|
||||
|
||||
self.playback_done_event.set()
|
||||
|
||||
|
||||
generate_fn = getattr(self, f"_generate_{self.api}")
|
||||
|
||||
|
||||
if self.actions["_config"].config["generate_chunks"].value:
|
||||
chunks = parse_chunks(text)
|
||||
chunks = rejoin_chunks(chunks)
|
||||
@@ -421,65 +493,78 @@ class TTSAgent(Agent):
|
||||
|
||||
# Start generating audio chunks in the background
|
||||
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
|
||||
await self.set_background_processing(generation_task)
|
||||
|
||||
# Wait for both tasks to complete
|
||||
await asyncio.gather(generation_task)
|
||||
# await asyncio.gather(generation_task)
|
||||
|
||||
async def generate_chunks(self, generate_fn, chunks):
|
||||
for chunk in chunks:
|
||||
chunk = chunk.replace("*","").strip()
|
||||
chunk = chunk.replace("*", "").strip()
|
||||
log.info("Generating audio", api=self.api, chunk=chunk)
|
||||
audio_data = await generate_fn(chunk)
|
||||
self.play_audio(audio_data)
|
||||
|
||||
def play_audio(self, audio_data):
|
||||
# play audio through the python audio player
|
||||
#play(audio_data)
|
||||
|
||||
emit("audio_queue", data={"audio_data": base64.b64encode(audio_data).decode("utf-8")})
|
||||
|
||||
# play(audio_data)
|
||||
|
||||
emit(
|
||||
"audio_queue",
|
||||
data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
|
||||
)
|
||||
|
||||
self.playback_done_event.set() # Signal that playback is finished
|
||||
|
||||
# LOCAL
|
||||
|
||||
|
||||
async def _generate_tts(self, text: str) -> Union[bytes, None]:
|
||||
|
||||
if not TTS:
|
||||
return
|
||||
|
||||
tts_config = self.config.get("tts",{})
|
||||
|
||||
tts_config = self.config.get("tts", {})
|
||||
model = tts_config.get("model")
|
||||
device = tts_config.get("device", "cpu")
|
||||
|
||||
|
||||
log.debug("tts local", model=model, device=device)
|
||||
|
||||
|
||||
if not hasattr(self, "tts_instance"):
|
||||
self.tts_instance = TTS(model).to(device)
|
||||
|
||||
|
||||
tts = self.tts_instance
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
voice = self.voice(self.default_voice_id)
|
||||
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
|
||||
|
||||
await loop.run_in_executor(None, functools.partial(tts.tts_to_file, text=text, speaker_wav=voice.value, language="en", file_path=file_path))
|
||||
#tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
|
||||
|
||||
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.tts_to_file,
|
||||
text=text,
|
||||
speaker_wav=voice.value,
|
||||
language="en",
|
||||
file_path=file_path,
|
||||
),
|
||||
)
|
||||
# tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
|
||||
async def _list_voices_tts(self) -> dict[str, str]:
|
||||
return [Voice(**voice) for voice in self.config.get("tts",{}).get("voices",[])]
|
||||
|
||||
return [
|
||||
Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
|
||||
]
|
||||
|
||||
# ELEVENLABS
|
||||
|
||||
async def _generate_elevenlabs(self, text: str, chunk_size: int = 1024) -> Union[bytes, None]:
|
||||
async def _generate_elevenlabs(
|
||||
self, text: str, chunk_size: int = 1024
|
||||
) -> Union[bytes, None]:
|
||||
api_key = self.token
|
||||
if not api_key:
|
||||
return
|
||||
@@ -493,11 +578,8 @@ class TTSAgent(Agent):
|
||||
}
|
||||
data = {
|
||||
"text": text,
|
||||
"model_id": self.config.get("elevenlabs",{}).get("model"),
|
||||
"voice_settings": {
|
||||
"stability": 0.5,
|
||||
"similarity_boost": 0.5
|
||||
}
|
||||
"model_id": self.config.get("elevenlabs", {}).get("model"),
|
||||
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
|
||||
}
|
||||
|
||||
response = await client.post(url, json=data, headers=headers, timeout=300)
|
||||
@@ -514,104 +596,57 @@ class TTSAgent(Agent):
|
||||
log.error(f"Error generating audio: {response.text}")
|
||||
|
||||
async def _list_voices_elevenlabs(self) -> dict[str, str]:
|
||||
|
||||
url_voices = "https://api.elevenlabs.io/v1/voices"
|
||||
|
||||
|
||||
voices = []
|
||||
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"xi-api-key": self.token,
|
||||
}
|
||||
response = await client.get(url_voices, headers=headers, params={"per_page":1000})
|
||||
response = await client.get(
|
||||
url_voices, headers=headers, params={"per_page": 1000}
|
||||
)
|
||||
speakers = response.json()["voices"]
|
||||
voices.extend([Voice(value=speaker["voice_id"], label=speaker["name"]) for speaker in speakers])
|
||||
|
||||
voices.extend(
|
||||
[
|
||||
Voice(value=speaker["voice_id"], label=speaker["name"])
|
||||
for speaker in speakers
|
||||
]
|
||||
)
|
||||
|
||||
# sort by name
|
||||
voices.sort(key=lambda x: x.label)
|
||||
|
||||
return voices
|
||||
|
||||
# COQUI STUDIO
|
||||
|
||||
async def _generate_coqui(self, text: str) -> Union[bytes, None]:
|
||||
api_key = self.token
|
||||
if not api_key:
|
||||
return
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = "https://app.coqui.ai/api/v2/samples/xtts/render/"
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
data = {
|
||||
"voice_id": self.default_voice_id,
|
||||
"text": text,
|
||||
"language": "en" # Assuming English language for simplicity; this could be parameterized
|
||||
}
|
||||
return voices
|
||||
|
||||
# Make the POST request to Coqui API
|
||||
response = await client.post(url, json=data, headers=headers, timeout=300)
|
||||
if response.status_code in [200, 201]:
|
||||
# Parse the JSON response to get the audio URL
|
||||
response_data = response.json()
|
||||
audio_url = response_data.get('audio_url')
|
||||
if audio_url:
|
||||
# Make a GET request to download the audio file
|
||||
audio_response = await client.get(audio_url)
|
||||
if audio_response.status_code == 200:
|
||||
# delete the sample from Coqui Studio
|
||||
# await self._cleanup_coqui(response_data.get('id'))
|
||||
return audio_response.content
|
||||
else:
|
||||
log.error(f"Error downloading audio: {audio_response.text}")
|
||||
else:
|
||||
log.error("No audio URL in response")
|
||||
else:
|
||||
log.error(f"Error generating audio: {response.text}")
|
||||
|
||||
async def _cleanup_coqui(self, sample_id: str):
|
||||
api_key = self.token
|
||||
if not api_key or not sample_id:
|
||||
return
|
||||
# OPENAI
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
async def _generate_openai(self, text: str, chunk_size: int = 1024):
|
||||
|
||||
# Make the DELETE request to Coqui API
|
||||
response = await client.delete(url, headers=headers)
|
||||
client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
|
||||
if response.status_code == 204:
|
||||
log.info(f"Successfully deleted sample with ID: {sample_id}")
|
||||
else:
|
||||
log.error(f"Error deleting sample with ID: {sample_id}: {response.text}")
|
||||
model = self.actions["openai"].config["model"].value
|
||||
|
||||
async def _list_voices_coqui(self) -> dict[str, str]:
|
||||
|
||||
url_speakers = "https://app.coqui.ai/api/v2/speakers"
|
||||
url_custom_voices = "https://app.coqui.ai/api/v2/voices"
|
||||
|
||||
voices = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.token}"
|
||||
}
|
||||
response = await client.get(url_speakers, headers=headers, params={"per_page":1000})
|
||||
speakers = response.json()["result"]
|
||||
voices.extend([Voice(value=speaker["id"], label=speaker["name"]) for speaker in speakers])
|
||||
|
||||
response = await client.get(url_custom_voices, headers=headers, params={"per_page":1000})
|
||||
custom_voices = response.json()["result"]
|
||||
voices.extend([Voice(value=voice["id"], label=voice["name"]) for voice in custom_voices])
|
||||
|
||||
# sort by name
|
||||
voices.sort(key=lambda x: x.label)
|
||||
|
||||
return voices
|
||||
response = await client.audio.speech.create(
|
||||
model=model, voice=self.default_voice_id, input=text
|
||||
)
|
||||
|
||||
bytes_io = io.BytesIO()
|
||||
for chunk in response.iter_bytes(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
bytes_io.write(chunk)
|
||||
|
||||
# Put the audio data in the queue for playback
|
||||
return bytes_io.getvalue()
|
||||
|
||||
async def _list_voices_openai(self) -> dict[str, str]:
|
||||
return [
|
||||
Voice(value="alloy", label="Alloy"),
|
||||
Voice(value="echo", label="Echo"),
|
||||
Voice(value="fable", label="Fable"),
|
||||
Voice(value="onyx", label="Onyx"),
|
||||
Voice(value="nova", label="Nova"),
|
||||
Voice(value="shimmer", label="Shimmer"),
|
||||
]
|
||||
|
||||
483
src/talemate/agents/visual/__init__.py
Normal file
@@ -0,0 +1,483 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.agents.visual.automatic1111
|
||||
import talemate.agents.visual.comfyui
|
||||
import talemate.agents.visual.openai_image
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
from talemate.agents.registry import register
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers as signal_handlers
|
||||
from talemate.prompts.base import Prompt
|
||||
|
||||
from .commands import * # noqa
|
||||
from .context import VIS_TYPES, VisualContext, visual_context
|
||||
from .handlers import HANDLERS
|
||||
from .schema import RESOLUTION_MAP, RenderSettings
|
||||
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
|
||||
from .websocket_handler import VisualWebsocketHandler
|
||||
|
||||
__all__ = [
|
||||
"VisualAgent",
|
||||
]
|
||||
|
||||
BACKENDS = [
|
||||
{"value": mixin_backend, "label": mixin["label"]}
|
||||
for mixin_backend, mixin in HANDLERS.items()
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual")
|
||||
|
||||
|
||||
class VisualBase(Agent):
|
||||
"""
|
||||
The visual agent
|
||||
"""
|
||||
|
||||
agent_type = "visual"
|
||||
verbose_name = "Visualizer"
|
||||
essential = False
|
||||
websocket_handler = VisualWebsocketHandler
|
||||
|
||||
ACTIONS = {}
|
||||
|
||||
def __init__(self, client: ClientBase, *kwargs):
|
||||
self.client = client
|
||||
self.is_enabled = False
|
||||
self.backend_ready = False
|
||||
self.initialized = False
|
||||
self.config = load_config()
|
||||
self.actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
label="Configure",
|
||||
description="Visual agent configuration",
|
||||
config={
|
||||
"backend": AgentActionConfig(
|
||||
type="text",
|
||||
choices=BACKENDS,
|
||||
value="automatic1111",
|
||||
label="Backend",
|
||||
description="The backend to use for visual processing",
|
||||
),
|
||||
"default_style": AgentActionConfig(
|
||||
type="text",
|
||||
value="graphic_novel",
|
||||
choices=MAJOR_STYLES,
|
||||
label="Default Style",
|
||||
description="The default style to use for visual processing",
|
||||
),
|
||||
},
|
||||
),
|
||||
"automatic_setup": AgentAction(
|
||||
enabled=True,
|
||||
label="Automatic Setup",
|
||||
description="Automatically setup the visual agent if the selected client has an implementation of the selected backend. (Like the KoboldCpp Automatic1111 api)",
|
||||
),
|
||||
"automatic_generation": AgentAction(
|
||||
enabled=False,
|
||||
label="Automatic Generation",
|
||||
description="Allow automatic generation of visual content",
|
||||
),
|
||||
"process_in_background": AgentAction(
|
||||
enabled=True,
|
||||
label="Process in Background",
|
||||
description="Process renders in the background",
|
||||
),
|
||||
}
|
||||
|
||||
for action_name, action in self.ACTIONS.items():
|
||||
self.actions[action_name] = action
|
||||
|
||||
signal_handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
return self.actions["_config"].config["backend"].value
|
||||
|
||||
@property
|
||||
def backend_name(self):
|
||||
key = self.actions["_config"].config["backend"].value
|
||||
|
||||
for backend in BACKENDS:
|
||||
if backend["value"] == key:
|
||||
return backend["label"]
|
||||
|
||||
@property
|
||||
def default_style(self):
|
||||
return STYLE_MAP.get(
|
||||
self.actions["_config"].config["default_style"].value, Style()
|
||||
)
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
return self.backend_ready
|
||||
|
||||
@property
|
||||
def api_url(self):
|
||||
try:
|
||||
return self.actions[self.backend].config["api_url"].value
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
details = {
|
||||
"backend": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value=self.backend_name,
|
||||
description="The backend to use for visual processing",
|
||||
).model_dump(),
|
||||
"client": AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
value=self.client.name if self.client else None,
|
||||
description="The client to use for prompt generation",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
if not self.ready and self.enabled:
|
||||
details["status"] = AgentDetail(
|
||||
icon="mdi-alert",
|
||||
value=f"{self.backend_name} not ready",
|
||||
color="error",
|
||||
description=self.ready_check_error
|
||||
or f"{self.backend_name} is not ready for processing",
|
||||
).model_dump()
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def process_in_background(self):
|
||||
return self.actions["process_in_background"].enabled
|
||||
|
||||
@property
|
||||
def allow_automatic_generation(self):
|
||||
return self.actions["automatic_generation"].enabled
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
asyncio.create_task(self.emit_status())
|
||||
|
||||
async def on_ready_check_success(self):
|
||||
prev_ready = self.backend_ready
|
||||
self.backend_ready = True
|
||||
if not prev_ready:
|
||||
await self.emit_status()
|
||||
|
||||
async def on_ready_check_failure(self, error):
|
||||
prev_ready = self.backend_ready
|
||||
self.backend_ready = False
|
||||
self.ready_check_error = str(error)
|
||||
await self.setup_check()
|
||||
if prev_ready:
|
||||
await self.emit_status()
|
||||
|
||||
|
||||
async def ready_check(self):
|
||||
if not self.enabled:
|
||||
return
|
||||
backend = self.backend
|
||||
fn = getattr(self, f"{backend.lower()}_ready", None)
|
||||
task = asyncio.create_task(fn())
|
||||
await super().ready_check(task)
|
||||
|
||||
async def setup_check(self):
|
||||
|
||||
if not self.actions["automatic_setup"].enabled:
|
||||
return
|
||||
|
||||
backend = self.backend
|
||||
if self.client and hasattr(self.client, f"visual_{backend.lower()}_setup"):
|
||||
await getattr(self.client, f"visual_{backend.lower()}_setup")(self)
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
|
||||
try:
|
||||
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
|
||||
except KeyError:
|
||||
backend = self.backend
|
||||
|
||||
backend_changed = backend != self.backend
|
||||
was_disabled = not self.enabled
|
||||
|
||||
if backend_changed:
|
||||
self.backend_ready = False
|
||||
|
||||
log.info(
|
||||
"apply_config",
|
||||
backend=backend,
|
||||
backend_changed=backend_changed,
|
||||
old_backend=self.backend,
|
||||
)
|
||||
|
||||
await super().apply_config(*args, **kwargs)
|
||||
|
||||
backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None)
|
||||
if backend_fn:
|
||||
|
||||
if not backend_changed and was_disabled and self.enabled:
|
||||
# If the backend has not changed, but the agent was previously disabled
|
||||
# and is now enabled, we need to trigger the backend apply_config function
|
||||
backend_changed = True
|
||||
|
||||
task = asyncio.create_task(
|
||||
backend_fn(backend_changed=backend_changed, *args, **kwargs)
|
||||
)
|
||||
await self.set_background_processing(task)
|
||||
|
||||
if not self.backend_ready:
|
||||
await self.ready_check()
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def resolution_from_format(self, format: str, model_type: str = "sdxl"):
|
||||
if model_type not in RESOLUTION_MAP:
|
||||
raise ValueError(f"Model type {model_type} not found in resolution map")
|
||||
return RESOLUTION_MAP[model_type].get(
|
||||
format, RESOLUTION_MAP[model_type]["portrait"]
|
||||
)
|
||||
|
||||
def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style:
|
||||
|
||||
prompt_style = Style()
|
||||
prompt_style.load(prompt)
|
||||
|
||||
if styles:
|
||||
prompt_style.prepend(*styles)
|
||||
|
||||
return prompt_style
|
||||
|
||||
def vis_type_styles(self, vis_type: str):
|
||||
if vis_type == VIS_TYPES.CHARACTER:
|
||||
portrait_style = STYLE_MAP["character_portrait"].copy()
|
||||
return portrait_style
|
||||
elif vis_type == VIS_TYPES.ENVIRONMENT:
|
||||
environment_style = STYLE_MAP["environment"].copy()
|
||||
return environment_style
|
||||
return Style()
|
||||
|
||||
async def apply_image(self, image: str):
|
||||
context = visual_context.get()
|
||||
|
||||
log.debug("apply_image", image=image[:100], context=context)
|
||||
|
||||
if context.vis_type == VIS_TYPES.CHARACTER:
|
||||
await self.apply_image_character(image, context.character_name)
|
||||
|
||||
async def apply_image_character(self, image: str, character_name: str):
|
||||
character = self.scene.get_character(character_name)
|
||||
|
||||
if not character:
|
||||
log.error("character not found", character_name=character_name)
|
||||
return
|
||||
|
||||
if character.cover_image:
|
||||
log.info("character cover image already set", character_name=character_name)
|
||||
return
|
||||
|
||||
asset = self.scene.assets.add_asset_from_image_data(
|
||||
f"data:image/png;base64,{image}"
|
||||
)
|
||||
character.cover_image = asset.id
|
||||
self.scene.assets.cover_image = asset.id
|
||||
self.scene.emit_status()
|
||||
|
||||
async def emit_image(self, image: str):
|
||||
context = visual_context.get()
|
||||
await self.apply_image(image)
|
||||
emit(
|
||||
"image_generated",
|
||||
websocket_passthrough=True,
|
||||
data={
|
||||
"base64": image,
|
||||
"context": context.model_dump() if context else None,
|
||||
},
|
||||
)
|
||||
|
||||
@set_processing
|
||||
async def generate(
|
||||
self, format: str = "portrait", prompt: str = None, automatic: bool = False
|
||||
):
|
||||
|
||||
context = visual_context.get()
|
||||
|
||||
if not self.enabled:
|
||||
log.warning("generate", skipped="Visual agent not enabled")
|
||||
return
|
||||
|
||||
if automatic and not self.allow_automatic_generation:
|
||||
log.warning(
|
||||
"generate",
|
||||
skipped="Automatic generation disabled",
|
||||
prompt=prompt,
|
||||
format=format,
|
||||
context=context,
|
||||
)
|
||||
return
|
||||
|
||||
if not context and not prompt:
|
||||
log.error("generate", error="No context or prompt provided")
|
||||
return
|
||||
|
||||
# Handle prompt generation based on context
|
||||
|
||||
if not prompt and context.prompt:
|
||||
prompt = context.prompt
|
||||
|
||||
if context.vis_type == VIS_TYPES.ENVIRONMENT and not prompt:
|
||||
prompt = await self.generate_environment_prompt(
|
||||
instructions=context.instructions
|
||||
)
|
||||
elif context.vis_type == VIS_TYPES.CHARACTER and not prompt:
|
||||
prompt = await self.generate_character_prompt(
|
||||
context.character_name, instructions=context.instructions
|
||||
)
|
||||
else:
|
||||
prompt = prompt or context.prompt
|
||||
|
||||
initial_prompt = prompt
|
||||
|
||||
# Augment the prompt with styles based on context
|
||||
|
||||
thematic_style = self.default_style
|
||||
vis_type_styles = self.vis_type_styles(context.vis_type)
|
||||
prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
|
||||
|
||||
if context.vis_type == VIS_TYPES.CHARACTER:
|
||||
prompt.keywords.append("character portrait")
|
||||
|
||||
if not prompt:
|
||||
log.error(
|
||||
"generate", error="No prompt provided and no context to generate from"
|
||||
)
|
||||
return
|
||||
|
||||
context.prompt = initial_prompt
|
||||
context.prepared_prompt = str(prompt)
|
||||
|
||||
# Handle format (can either come from context or be passed in)
|
||||
|
||||
if not format and context.format:
|
||||
format = context.format
|
||||
elif not format:
|
||||
format = "portrait"
|
||||
|
||||
context.format = format
|
||||
|
||||
# Call the backend specific generate function
|
||||
|
||||
backend = self.backend
|
||||
fn = f"{backend.lower()}_generate"
|
||||
|
||||
log.info(
|
||||
"generate", backend=backend, prompt=prompt, format=format, context=context
|
||||
)
|
||||
|
||||
if not hasattr(self, fn):
|
||||
log.error("generate", error=f"Backend {backend} does not support generate")
|
||||
|
||||
# add the function call to the asyncio task queue
|
||||
|
||||
if self.process_in_background:
|
||||
task = asyncio.create_task(getattr(self, fn)(prompt=prompt, format=format))
|
||||
await self.set_background_processing(task)
|
||||
else:
|
||||
await getattr(self, fn)(prompt=prompt, format=format)
|
||||
|
||||
@set_processing
|
||||
async def generate_environment_prompt(self, instructions: str = None):
|
||||
|
||||
response = await Prompt.request(
|
||||
"visual.generate-environment-prompt",
|
||||
self.client,
|
||||
"visualize",
|
||||
{
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
return response.strip()
|
||||
|
||||
@set_processing
|
||||
async def generate_character_prompt(
|
||||
self, character_name: str, instructions: str = None
|
||||
):
|
||||
|
||||
character = self.scene.get_character(character_name)
|
||||
|
||||
response = await Prompt.request(
|
||||
"visual.generate-character-prompt",
|
||||
self.client,
|
||||
"visualize",
|
||||
{
|
||||
"scene": self.scene,
|
||||
"character_name": character_name,
|
||||
"character": character,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"instructions": instructions or "",
|
||||
},
|
||||
)
|
||||
|
||||
return response.strip()
|
||||
|
||||
async def generate_environment_background(self, instructions: str = None):
|
||||
with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
|
||||
await self.generate(format="landscape")
|
||||
|
||||
generate_environment_background.exposed = True
|
||||
|
||||
async def generate_character_portrait(
|
||||
self,
|
||||
character_name: str,
|
||||
instructions: str = None,
|
||||
):
|
||||
with VisualContext(
|
||||
vis_type=VIS_TYPES.CHARACTER,
|
||||
character_name=character_name,
|
||||
instructions=instructions,
|
||||
):
|
||||
await self.generate(format="portrait")
|
||||
|
||||
generate_character_portrait.exposed = True
|
||||
|
||||
|
||||
# apply mixins to the agent (from HANDLERS dict[str, cls])
|
||||
|
||||
for mixin_backend, mixin in HANDLERS.items():
|
||||
mixin_cls = mixin["cls"]
|
||||
VisualBase = type("VisualAgent", (mixin_cls, VisualBase), {})
|
||||
|
||||
extend_actions = getattr(mixin_cls, "EXTEND_ACTIONS", {})
|
||||
|
||||
for action_name, action in extend_actions.items():
|
||||
VisualBase.ACTIONS[action_name] = action
|
||||
|
||||
|
||||
@register()
|
||||
class VisualAgent(VisualBase):
|
||||
pass
|
||||
117
src/talemate/agents/visual/automatic1111.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import base64
|
||||
import io
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
|
||||
from .handlers import register
|
||||
from .schema import RenderSettings, Resolution
|
||||
from .style import STYLE_MAP, Style
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.automatic1111")
|
||||
|
||||
|
||||
@register(backend_name="automatic1111", label="AUTOMATIC1111")
|
||||
class Automatic1111Mixin:
|
||||
|
||||
automatic1111_default_render_settings = RenderSettings()
|
||||
|
||||
EXTEND_ACTIONS = {
|
||||
"automatic1111": AgentAction(
|
||||
enabled=True,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.backend", value="automatic1111"
|
||||
),
|
||||
label="Automatic1111 Settings",
|
||||
description="Setting overrides for the automatic1111 backend",
|
||||
config={
|
||||
"api_url": AgentActionConfig(
|
||||
type="text",
|
||||
value="http://localhost:7860",
|
||||
label="API URL",
|
||||
description="The URL of the backend API",
|
||||
),
|
||||
"steps": AgentActionConfig(
|
||||
type="number",
|
||||
value=40,
|
||||
label="Steps",
|
||||
min=5,
|
||||
max=150,
|
||||
step=1,
|
||||
description="number of render steps",
|
||||
),
|
||||
"model_type": AgentActionConfig(
|
||||
type="text",
|
||||
value="sdxl",
|
||||
choices=[
|
||||
{"value": "sdxl", "label": "SDXL"},
|
||||
{"value": "sd15", "label": "SD1.5"},
|
||||
],
|
||||
label="Model Type",
|
||||
description="Right now just differentiates between sdxl and sd15 - affect generation resolution",
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def automatic1111_render_settings(self):
|
||||
if self.actions["automatic1111"].enabled:
|
||||
return RenderSettings(
|
||||
steps=self.actions["automatic1111"].config["steps"].value,
|
||||
type_model=self.actions["automatic1111"].config["model_type"].value,
|
||||
)
|
||||
else:
|
||||
return self.automatic1111_default_render_settings
|
||||
|
||||
async def automatic1111_generate(self, prompt: Style, format: str):
|
||||
url = self.api_url
|
||||
resolution = self.resolution_from_format(
|
||||
format, self.automatic1111_render_settings.type_model
|
||||
)
|
||||
render_settings = self.automatic1111_render_settings
|
||||
payload = {
|
||||
"prompt": prompt.positive_prompt,
|
||||
"negative_prompt": prompt.negative_prompt,
|
||||
"steps": render_settings.steps,
|
||||
"width": resolution.width,
|
||||
"height": resolution.height,
|
||||
}
|
||||
|
||||
log.info("automatic1111_generate", payload=payload, url=url)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=90
|
||||
)
|
||||
|
||||
r = response.json()
|
||||
|
||||
# image = Image.open(io.BytesIO(base64.b64decode(r['images'][0])))
|
||||
# image.save('a1111-test.png')
|
||||
|
||||
#'log.info("automatic1111_generate", saved_to="a1111-test.png")
|
||||
|
||||
for image in r["images"]:
|
||||
await self.emit_image(image)
|
||||
|
||||
async def automatic1111_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /sdapi/v1/memory and on 200 will return True
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=f"{self.api_url}/sdapi/v1/memory", timeout=2
|
||||
)
|
||||
return response.status_code == 200
|
||||
324
src/talemate/agents/visual/comfyui.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
import pydantic
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig
|
||||
|
||||
from .handlers import register
|
||||
from .schema import RenderSettings, Resolution
|
||||
from .style import STYLE_MAP, Style
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.comfyui")
|
||||
|
||||
|
||||
class Workflow(pydantic.BaseModel):
|
||||
nodes: dict
|
||||
|
||||
def set_resolution(self, resolution: Resolution):
|
||||
|
||||
# will collect all latent image nodes
|
||||
# if there is multiple will look for the one with the
|
||||
# title "Talemate Resolution"
|
||||
|
||||
# if there is no latent image node with the title "Talemate Resolution"
|
||||
# the first latent image node will be used
|
||||
|
||||
# resolution will be updated on the selected node
|
||||
|
||||
# if no latent image node is found a warning will be logged
|
||||
|
||||
latent_image_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
if node["class_type"] == "EmptyLatentImage":
|
||||
if not latent_image_node:
|
||||
latent_image_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Resolution":
|
||||
latent_image_node = node
|
||||
break
|
||||
|
||||
if not latent_image_node:
|
||||
log.warning("set_resolution", error="No latent image node found")
|
||||
return
|
||||
|
||||
latent_image_node["inputs"]["width"] = resolution.width
|
||||
latent_image_node["inputs"]["height"] = resolution.height
|
||||
|
||||
def set_prompt(self, prompt: str, negative_prompt: str = None):
|
||||
|
||||
# will collect all CLIPTextEncode nodes
|
||||
|
||||
# if there is multiple will look for the one with the
|
||||
# title "Talemate Positive Prompt" and "Talemate Negative Prompt"
|
||||
#
|
||||
# if there is no CLIPTextEncode node with the title "Talemate Positive Prompt"
|
||||
# the first CLIPTextEncode node will be used
|
||||
#
|
||||
# if there is no CLIPTextEncode node with the title "Talemate Negative Prompt"
|
||||
# the second CLIPTextEncode node will be used
|
||||
#
|
||||
# prompt will be updated on the selected node
|
||||
|
||||
# if no CLIPTextEncode node is found an exception will be raised for
|
||||
# the positive prompt
|
||||
|
||||
# if no CLIPTextEncode node is found an exception will be raised for
|
||||
# the negative prompt if it is not None
|
||||
|
||||
positive_prompt_node = None
|
||||
negative_prompt_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
|
||||
if node["class_type"] == "CLIPTextEncode":
|
||||
if not positive_prompt_node:
|
||||
positive_prompt_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Positive Prompt":
|
||||
positive_prompt_node = node
|
||||
elif not negative_prompt_node:
|
||||
negative_prompt_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Negative Prompt":
|
||||
negative_prompt_node = node
|
||||
|
||||
if not positive_prompt_node:
|
||||
raise ValueError("No positive prompt node found")
|
||||
|
||||
positive_prompt_node["inputs"]["text"] = prompt
|
||||
|
||||
if negative_prompt and not negative_prompt_node:
|
||||
raise ValueError("No negative prompt node found")
|
||||
|
||||
if negative_prompt:
|
||||
negative_prompt_node["inputs"]["text"] = negative_prompt
|
||||
|
||||
def set_checkpoint(self, checkpoint: str):
|
||||
|
||||
# will collect all CheckpointLoaderSimple nodes
|
||||
# if there is multiple will look for the one with the
|
||||
# title "Talemate Load Checkpoint"
|
||||
|
||||
# if there is no CheckpointLoaderSimple node with the title "Talemate Load Checkpoint"
|
||||
# the first CheckpointLoaderSimple node will be used
|
||||
|
||||
# checkpoint will be updated on the selected node
|
||||
|
||||
# if no CheckpointLoaderSimple node is found a warning will be logged
|
||||
|
||||
checkpoint_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
if node["class_type"] == "CheckpointLoaderSimple":
|
||||
if not checkpoint_node:
|
||||
checkpoint_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Load Checkpoint":
|
||||
checkpoint_node = node
|
||||
break
|
||||
|
||||
if not checkpoint_node:
|
||||
log.warning("set_checkpoint", error="No checkpoint node found")
|
||||
return
|
||||
|
||||
checkpoint_node["inputs"]["ckpt_name"] = checkpoint
|
||||
|
||||
def set_seeds(self):
|
||||
for node in self.nodes.values():
|
||||
for field in node.get("inputs", {}).keys():
|
||||
if field == "noise_seed":
|
||||
node["inputs"]["noise_seed"] = random.randint(0, 999999999999999)
|
||||
|
||||
|
||||
@register(backend_name="comfyui", label="ComfyUI")
|
||||
class ComfyUIMixin:
|
||||
|
||||
comfyui_default_render_settings = RenderSettings()
|
||||
|
||||
EXTEND_ACTIONS = {
|
||||
"comfyui": AgentAction(
|
||||
enabled=True,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.backend", value="comfyui"
|
||||
),
|
||||
label="ComfyUI Settings",
|
||||
description="Setting overrides for the comfyui backend",
|
||||
config={
|
||||
"api_url": AgentActionConfig(
|
||||
type="text",
|
||||
value="http://localhost:8188",
|
||||
label="API URL",
|
||||
description="The URL of the backend API",
|
||||
),
|
||||
"workflow": AgentActionConfig(
|
||||
type="text",
|
||||
value="default-sdxl.json",
|
||||
label="Workflow",
|
||||
description="The workflow to use for comfyui (workflow file name inside ./templates/comfyui-workflows)",
|
||||
),
|
||||
"checkpoint": AgentActionConfig(
|
||||
type="text",
|
||||
value="default",
|
||||
label="Checkpoint",
|
||||
choices=[],
|
||||
description="The main checkpoint to use.",
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def comfyui_workflow_filename(self):
|
||||
base_name = self.actions["comfyui"].config["workflow"].value
|
||||
|
||||
# make absolute path
|
||||
abs_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..",
|
||||
"..",
|
||||
"..",
|
||||
"..",
|
||||
"templates",
|
||||
"comfyui-workflows",
|
||||
base_name,
|
||||
)
|
||||
|
||||
return abs_path
|
||||
|
||||
@property
|
||||
def comfyui_workflow_is_sdxl(self) -> bool:
|
||||
"""
|
||||
Returns true if `sdxl` is in worhflow file name (case insensitive)
|
||||
"""
|
||||
|
||||
return "sdxl" in self.comfyui_workflow_filename.lower()
|
||||
|
||||
@property
|
||||
def comfyui_workflow(self) -> Workflow:
|
||||
workflow = self.comfyui_workflow_filename
|
||||
if not workflow:
|
||||
raise ValueError("No comfyui workflow file specified")
|
||||
|
||||
with open(workflow, "r") as f:
|
||||
return Workflow(nodes=json.load(f))
|
||||
|
||||
@property
|
||||
async def comfyui_object_info(self):
|
||||
if hasattr(self, "_comfyui_object_info"):
|
||||
return self._comfyui_object_info
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/object_info")
|
||||
self._comfyui_object_info = response.json()
|
||||
|
||||
return self._comfyui_object_info
|
||||
|
||||
@property
|
||||
async def comfyui_checkpoints(self):
|
||||
loader_node = (await self.comfyui_object_info)["CheckpointLoaderSimple"]
|
||||
_checkpoints = loader_node["input"]["required"]["ckpt_name"][0]
|
||||
return [
|
||||
{"label": checkpoint, "value": checkpoint} for checkpoint in _checkpoints
|
||||
]
|
||||
|
||||
async def comfyui_get_image(self, filename: str, subfolder: str, folder_type: str):
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
url_values = urllib.parse.urlencode(data)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/view?{url_values}")
|
||||
return response.content
|
||||
|
||||
async def comfyui_get_history(self, prompt_id: str):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/history/{prompt_id}")
|
||||
return response.json()
|
||||
|
||||
async def comfyui_get_images(self, prompt_id: str, max_wait: int = 60.0):
|
||||
output_images = {}
|
||||
history = {}
|
||||
|
||||
start = time.time()
|
||||
|
||||
while not history:
|
||||
log.info(
|
||||
"comfyui_get_images", waiting_for_history=True, prompt_id=prompt_id
|
||||
)
|
||||
history = await self.comfyui_get_history(prompt_id)
|
||||
await asyncio.sleep(1.0)
|
||||
if time.time() - start > max_wait:
|
||||
raise TimeoutError("Max wait time exceeded")
|
||||
|
||||
for node_id, node_output in history[prompt_id]["outputs"].items():
|
||||
if "images" in node_output:
|
||||
images_output = []
|
||||
for image in node_output["images"]:
|
||||
image_data = await self.comfyui_get_image(
|
||||
image["filename"], image["subfolder"], image["type"]
|
||||
)
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
return output_images
|
||||
|
||||
async def comfyui_generate(self, prompt: Style, format: str):
|
||||
url = self.api_url
|
||||
workflow = self.comfyui_workflow
|
||||
is_sdxl = self.comfyui_workflow_is_sdxl
|
||||
|
||||
resolution = self.resolution_from_format(format, "sdxl" if is_sdxl else "sd15")
|
||||
|
||||
workflow.set_resolution(resolution)
|
||||
workflow.set_prompt(prompt.positive_prompt, prompt.negative_prompt)
|
||||
workflow.set_seeds()
|
||||
workflow.set_checkpoint(self.actions["comfyui"].config["checkpoint"].value)
|
||||
|
||||
payload = {"prompt": workflow.model_dump().get("nodes")}
|
||||
|
||||
log.info("comfyui_generate", payload=payload, url=url)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url=f"{url}/prompt", json=payload, timeout=90)
|
||||
|
||||
log.info("comfyui_generate", response=response.text)
|
||||
|
||||
r = response.json()
|
||||
|
||||
prompt_id = r["prompt_id"]
|
||||
|
||||
images = await self.comfyui_get_images(prompt_id)
|
||||
for node_id, node_images in images.items():
|
||||
for i, image in enumerate(node_images):
|
||||
await self.emit_image(base64.b64encode(image).decode("utf-8"))
|
||||
# image = Image.open(io.BytesIO(image))
|
||||
# image.save(f'comfyui-test.png')
|
||||
|
||||
async def comfyui_apply_config(
|
||||
self, backend_changed: bool = False, *args, **kwargs
|
||||
):
|
||||
log.debug(
|
||||
"comfyui_apply_config",
|
||||
backend_changed=backend_changed,
|
||||
enabled=self.enabled,
|
||||
)
|
||||
if (not self.initialized or backend_changed) and self.enabled:
|
||||
checkpoints = await self.comfyui_checkpoints
|
||||
selected_checkpoint = self.actions["comfyui"].config["checkpoint"].value
|
||||
self.actions["comfyui"].config["checkpoint"].choices = checkpoints
|
||||
self.actions["comfyui"].config["checkpoint"].value = selected_checkpoint
|
||||
|
||||
async def comfyui_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /system_stats and on 200 will return True
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/system_stats", timeout=2)
|
||||
return response.status_code == 200
|
||||
68
src/talemate/agents/visual/commands.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from talemate.agents.visual.context import VIS_TYPES, VisualContext
|
||||
from talemate.commands.base import TalemateCommand
|
||||
from talemate.commands.manager import register
|
||||
from talemate.instance import get_agent
|
||||
|
||||
__all__ = [
|
||||
"CmdVisualizeTestGenerate",
|
||||
]
|
||||
|
||||
|
||||
@register
|
||||
class CmdVisualizeTestGenerate(TalemateCommand):
|
||||
"""
|
||||
Generates a visual test
|
||||
"""
|
||||
|
||||
name = "visual_test_generate"
|
||||
description = "Will generate a visual test"
|
||||
aliases = ["vis_test", "vtg"]
|
||||
|
||||
label = "Visualize test"
|
||||
|
||||
async def run(self):
|
||||
visual = get_agent("visual")
|
||||
prompt = self.args[0]
|
||||
with VisualContext(vis_type=VIS_TYPES.UNSPECIFIED):
|
||||
await visual.generate(prompt)
|
||||
return True
|
||||
|
||||
|
||||
@register
|
||||
class CmdVisualizeEnvironment(TalemateCommand):
|
||||
"""
|
||||
Shows the environment
|
||||
"""
|
||||
|
||||
name = "visual_environment"
|
||||
description = "Will show the environment"
|
||||
aliases = ["vis_env"]
|
||||
|
||||
label = "Visualize environment"
|
||||
|
||||
async def run(self):
|
||||
visual = get_agent("visual")
|
||||
await visual.generate_environment_background(
|
||||
instructions=self.args[0] if len(self.args) > 0 else None
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@register
|
||||
class CmdVisualizeCharacter(TalemateCommand):
|
||||
"""
|
||||
Shows a character
|
||||
"""
|
||||
|
||||
name = "visual_character"
|
||||
description = "Will show a character"
|
||||
aliases = ["vis_char"]
|
||||
|
||||
label = "Visualize character"
|
||||
|
||||
async def run(self):
|
||||
visual = get_agent("visual")
|
||||
character_name = self.args[0]
|
||||
instructions = self.args[1] if len(self.args) > 1 else None
|
||||
await visual.generate_character_portrait(character_name, instructions)
|
||||
return True
|
||||
55
src/talemate/agents/visual/context.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import contextvars
|
||||
import enum
|
||||
from typing import Union
|
||||
|
||||
import pydantic
|
||||
|
||||
__all__ = [
|
||||
"VIS_TYPES",
|
||||
"visual_context",
|
||||
"VisualContext",
|
||||
]
|
||||
|
||||
|
||||
class VIS_TYPES(str, enum.Enum):
|
||||
UNSPECIFIED = "UNSPECIFIED"
|
||||
ENVIRONMENT = "ENVIRONMENT"
|
||||
CHARACTER = "CHARACTER"
|
||||
ITEM = "ITEM"
|
||||
|
||||
|
||||
visual_context = contextvars.ContextVar("visual_context", default=None)
|
||||
|
||||
|
||||
class VisualContextState(pydantic.BaseModel):
|
||||
character_name: Union[str, None] = None
|
||||
instructions: Union[str, None] = None
|
||||
vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT
|
||||
prompt: Union[str, None] = None
|
||||
prepared_prompt: Union[str, None] = None
|
||||
format: Union[str, None] = None
|
||||
|
||||
|
||||
class VisualContext:
|
||||
def __init__(
|
||||
self,
|
||||
character_name: Union[str, None] = None,
|
||||
instructions: Union[str, None] = None,
|
||||
vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT,
|
||||
prompt: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.state = VisualContextState(
|
||||
character_name=character_name,
|
||||
instructions=instructions,
|
||||
vis_type=vis_type,
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self.token = visual_context.set(self.state)
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
visual_context.reset(self.token)
|
||||
return False
|
||||
17
src/talemate/agents/visual/handlers.py
Normal file
@@ -0,0 +1,17 @@
|
||||
__all__ = [
|
||||
"HANDLERS",
|
||||
"register",
|
||||
]
|
||||
|
||||
HANDLERS = {}
|
||||
|
||||
|
||||
class register:
|
||||
|
||||
def __init__(self, backend_name: str, label: str):
|
||||
self.backend_name = backend_name
|
||||
self.label = label
|
||||
|
||||
def __call__(self, mixin_cls):
|
||||
HANDLERS[self.backend_name] = {"label": self.label, "cls": mixin_cls}
|
||||
return mixin_cls
|
||||
125
src/talemate/agents/visual/openai_image.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import base64
|
||||
import io
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
|
||||
from .handlers import register
|
||||
from .schema import RenderSettings, Resolution
|
||||
from .style import STYLE_MAP, Style
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.openai_image")
|
||||
|
||||
|
||||
@register(backend_name="openai_image", label="OpenAI")
|
||||
class OpenAIImageMixin:
|
||||
|
||||
openai_image_default_render_settings = RenderSettings()
|
||||
|
||||
EXTEND_ACTIONS = {
|
||||
"openai_image": AgentAction(
|
||||
enabled=False,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.backend", value="openai_image"
|
||||
),
|
||||
label="OpenAI Image Generation Advanced Settings",
|
||||
description="Setting overrides for the openai backend",
|
||||
config={
|
||||
"model_type": AgentActionConfig(
|
||||
type="text",
|
||||
value="dall-e-3",
|
||||
choices=[
|
||||
{"value": "dall-e-3", "label": "DALL-E 3"},
|
||||
{"value": "dall-e-2", "label": "DALL-E 2"},
|
||||
],
|
||||
label="Model Type",
|
||||
description="Image generation model",
|
||||
),
|
||||
"quality": AgentActionConfig(
|
||||
type="text",
|
||||
value="standard",
|
||||
choices=[
|
||||
{"value": "standard", "label": "Standard"},
|
||||
{"value": "hd", "label": "HD"},
|
||||
],
|
||||
label="Quality",
|
||||
description="Image generation quality",
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def openai_model_type(self):
|
||||
return self.actions["openai_image"].config["model_type"].value
|
||||
|
||||
@property
|
||||
def openai_quality(self):
|
||||
return self.actions["openai_image"].config["quality"].value
|
||||
|
||||
async def openai_image_generate(self, prompt: Style, format: str):
|
||||
"""
|
||||
#
|
||||
from openai import OpenAI
|
||||
client = OpenAI()
|
||||
|
||||
response = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt="a white siamese cat",
|
||||
size="1024x1024",
|
||||
quality="standard",
|
||||
n=1,
|
||||
)
|
||||
|
||||
image_url = response.data[0].url
|
||||
"""
|
||||
|
||||
client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
|
||||
# When using DALL·E 3, images can have a size of 1024x1024, 1024x1792 or 1792x1024 pixels.#
|
||||
|
||||
if format == "portrait":
|
||||
resolution = Resolution(width=1024, height=1792)
|
||||
elif format == "landscape":
|
||||
resolution = Resolution(width=1792, height=1024)
|
||||
else:
|
||||
resolution = Resolution(width=1024, height=1024)
|
||||
|
||||
log.debug("openai_image_generate", resolution=resolution)
|
||||
|
||||
response = await client.images.generate(
|
||||
model=self.openai_model_type,
|
||||
prompt=prompt.positive_prompt,
|
||||
size=f"{resolution.width}x{resolution.height}",
|
||||
quality=self.openai_quality,
|
||||
n=1,
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
await self.emit_image(response.data[0].b64_json)
|
||||
|
||||
async def openai_image_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /sdapi/v1/memory and on 200 will return True
|
||||
"""
|
||||
|
||||
if not self.openai_api_key:
|
||||
raise ValueError("OpenAI API Key not set")
|
||||
|
||||
return True
|
||||
32
src/talemate/agents/visual/schema.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pydantic
|
||||
|
||||
__all__ = [
|
||||
"RenderSettings",
|
||||
"Resolution",
|
||||
"RESOLUTION_MAP",
|
||||
]
|
||||
|
||||
RESOLUTION_MAP = {}
|
||||
|
||||
|
||||
class RenderSettings(pydantic.BaseModel):
|
||||
type_model: str = "sdxl"
|
||||
steps: int = 40
|
||||
|
||||
|
||||
class Resolution(pydantic.BaseModel):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
RESOLUTION_MAP["sdxl"] = {
|
||||
"portrait": Resolution(width=832, height=1216),
|
||||
"landscape": Resolution(width=1216, height=832),
|
||||
"square": Resolution(width=1024, height=1024),
|
||||
}
|
||||
|
||||
RESOLUTION_MAP["sd15"] = {
|
||||
"portrait": Resolution(width=512, height=768),
|
||||
"landscape": Resolution(width=768, height=512),
|
||||
"square": Resolution(width=768, height=768),
|
||||
}
|
||||
136
src/talemate/agents/visual/style.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
__all__ = [
|
||||
"Style",
|
||||
"STYLE_MAP",
|
||||
"THEME_MAP",
|
||||
"MAJOR_STYLES",
|
||||
"combine_styles",
|
||||
]
|
||||
|
||||
STYLE_MAP = {}
|
||||
THEME_MAP = {}
|
||||
MAJOR_STYLES = {}
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.style")
|
||||
|
||||
|
||||
class Style(pydantic.BaseModel):
|
||||
keywords: list[str] = pydantic.Field(default_factory=list)
|
||||
negative_keywords: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def positive_prompt(self):
|
||||
return ", ".join(self.keywords)
|
||||
|
||||
@property
|
||||
def negative_prompt(self):
|
||||
return ", ".join(self.negative_keywords)
|
||||
|
||||
def __str__(self):
|
||||
return f"POSITIVE: {self.positive_prompt}\nNEGATIVE: {self.negative_prompt}"
|
||||
|
||||
def load(self, prompt: str, negative_prompt: str = ""):
|
||||
self.keywords = prompt.split(", ")
|
||||
self.negative_keywords = negative_prompt.split(", ")
|
||||
|
||||
# loop through keywords and drop any starting with "no " and add to negative_keywords
|
||||
# with "no " removed
|
||||
for kw in self.keywords:
|
||||
kw = kw.strip()
|
||||
log.debug("Checking keyword", keyword=kw)
|
||||
if kw.startswith("no "):
|
||||
log.debug("Transforming negative keyword", keyword=kw, to=kw[3:])
|
||||
self.keywords.remove(kw)
|
||||
self.negative_keywords.append(kw[3:])
|
||||
|
||||
return self
|
||||
|
||||
def prepend(self, *styles):
|
||||
for style in styles:
|
||||
for idx in range(len(style.keywords) - 1, -1, -1):
|
||||
kw = style.keywords[idx]
|
||||
if kw not in self.keywords:
|
||||
self.keywords.insert(0, kw)
|
||||
|
||||
for idx in range(len(style.negative_keywords) - 1, -1, -1):
|
||||
kw = style.negative_keywords[idx]
|
||||
if kw not in self.negative_keywords:
|
||||
self.negative_keywords.insert(0, kw)
|
||||
|
||||
return self
|
||||
|
||||
def append(self, *styles):
|
||||
for style in styles:
|
||||
for kw in style.keywords:
|
||||
if kw not in self.keywords:
|
||||
self.keywords.append(kw)
|
||||
|
||||
for kw in style.negative_keywords:
|
||||
if kw not in self.negative_keywords:
|
||||
self.negative_keywords.append(kw)
|
||||
|
||||
return self
|
||||
|
||||
def copy(self):
|
||||
return Style(
|
||||
keywords=self.keywords.copy(),
|
||||
negative_keywords=self.negative_keywords.copy(),
|
||||
)
|
||||
|
||||
|
||||
# Almost taken straight from some of the fooocus style presets, credit goes to the original author
|
||||
|
||||
STYLE_MAP["digital_art"] = Style(
|
||||
keywords="digital artwork, masterpiece, best quality, high detail".split(", "),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["concept_art"] = Style(
|
||||
keywords="concept art, conceptual sketch, masterpiece, best quality, high detail".split(
|
||||
", "
|
||||
),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["ink_illustration"] = Style(
|
||||
keywords="ink illustration, painting, masterpiece, best quality".split(", "),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["anime"] = Style(
|
||||
keywords="anime, masterpiece, best quality, illustration".split(", "),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["graphic_novel"] = Style(
|
||||
keywords="(stylized by Enki Bilal:0.7), best quality, graphic novels, detailed linework, digital art".split(
|
||||
", "
|
||||
),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo, 3d, cgi".split(
|
||||
", "
|
||||
),
|
||||
)
|
||||
|
||||
STYLE_MAP["character_portrait"] = Style(keywords="solo, looking at viewer".split(", "))
|
||||
|
||||
STYLE_MAP["environment"] = Style(
|
||||
keywords="scenery, environment, background, postcard".split(", "),
|
||||
negative_keywords="character, portrait, looking at viewer, people".split(", "),
|
||||
)
|
||||
|
||||
MAJOR_STYLES = [
|
||||
{"value": "digital_art", "label": "Digital Art"},
|
||||
{"value": "concept_art", "label": "Concept Art"},
|
||||
{"value": "ink_illustration", "label": "Ink Illustration"},
|
||||
{"value": "anime", "label": "Anime"},
|
||||
{"value": "graphic_novel", "label": "Graphic Novel"},
|
||||
]
|
||||
|
||||
|
||||
def combine_styles(*styles):
|
||||
keywords = []
|
||||
for style in styles:
|
||||
keywords.extend(style.keywords)
|
||||
return Style(keywords=list(set(keywords)))
|
||||
84
src/talemate/agents/visual/websocket_handler.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Union
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
from talemate.instance import get_agent
|
||||
from talemate.server.websocket_plugin import Plugin
|
||||
|
||||
from .context import VisualContext, VisualContextState
|
||||
|
||||
__all__ = [
|
||||
"VisualWebsocketHandler",
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.server.visual")
|
||||
|
||||
|
||||
class SetCoverImagePayload(pydantic.BaseModel):
|
||||
base64: str
|
||||
context: Union[VisualContextState, None] = None
|
||||
|
||||
|
||||
class RegeneratePayload(pydantic.BaseModel):
|
||||
context: Union[VisualContextState, None] = None
|
||||
|
||||
|
||||
class VisualWebsocketHandler(Plugin):
|
||||
router = "visual"
|
||||
|
||||
async def handle_regenerate(self, data: dict):
|
||||
"""
|
||||
Regenerate the image based on the context.
|
||||
"""
|
||||
|
||||
payload = RegeneratePayload(**data)
|
||||
|
||||
context = payload.context
|
||||
|
||||
visual = get_agent("visual")
|
||||
|
||||
with VisualContext(**context.model_dump()):
|
||||
await visual.generate(format="")
|
||||
|
||||
async def handle_cover_image(self, data: dict):
|
||||
"""
|
||||
Sets the cover image for a character and the scene.
|
||||
"""
|
||||
|
||||
payload = SetCoverImagePayload(**data)
|
||||
|
||||
context = payload.context
|
||||
scene = self.scene
|
||||
|
||||
if context and context.character_name:
|
||||
|
||||
character = scene.get_character(context.character_name)
|
||||
|
||||
if not character:
|
||||
log.error("character not found", character_name=context.character_name)
|
||||
return
|
||||
|
||||
asset = scene.assets.add_asset_from_image_data(payload.base64)
|
||||
|
||||
log.info("setting scene cover image", character_name=context.character_name)
|
||||
scene.assets.cover_image = asset.id
|
||||
|
||||
log.info(
|
||||
"setting character cover image", character_name=context.character_name
|
||||
)
|
||||
character.cover_image = asset.id
|
||||
|
||||
scene.emit_status()
|
||||
self.websocket_handler.request_scene_assets([asset.id])
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "scene_asset_character_cover_image",
|
||||
"asset_id": asset.id,
|
||||
"asset": self.scene.assets.get_asset_bytes_as_base64(asset.id),
|
||||
"media_type": asset.media_type,
|
||||
"character": character.name,
|
||||
}
|
||||
)
|
||||
return
|
||||
@@ -1,46 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import isodate
|
||||
import structlog
|
||||
|
||||
import talemate.emit.async_signals
|
||||
import talemate.util as util
|
||||
from talemate.world_state import InsertionMode
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import DirectorMessage, TimePassageMessage, ReinforcementMessage
|
||||
from talemate.emit import emit
|
||||
from talemate.events import GameLoopEvent
|
||||
from talemate.instance import get_agent
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import (
|
||||
DirectorMessage,
|
||||
ReinforcementMessage,
|
||||
TimePassageMessage,
|
||||
)
|
||||
from talemate.world_state import InsertionMode
|
||||
|
||||
from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
|
||||
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
|
||||
from .registry import register
|
||||
|
||||
import structlog
|
||||
import isodate
|
||||
import time
|
||||
|
||||
|
||||
log = structlog.get_logger("talemate.agents.world_state")
|
||||
|
||||
talemate.emit.async_signals.register("agent.world_state.time")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class WorldStateAgentEmission(AgentEmission):
|
||||
"""
|
||||
Emission class for world state agent
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TimePassageEmission(WorldStateAgentEmission):
|
||||
"""
|
||||
Emission class for time passage
|
||||
"""
|
||||
|
||||
duration: str
|
||||
narrative: str
|
||||
human_duration: str = None
|
||||
|
||||
|
||||
|
||||
@register()
|
||||
class WorldStateAgent(Agent):
|
||||
@@ -55,26 +63,57 @@ class WorldStateAgent(Agent):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
self.actions = {
|
||||
"update_world_state": AgentAction(enabled=True, label="Update world state", description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.", config={
|
||||
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before updating the world state.", value=5, min=1, max=100, step=1)
|
||||
}),
|
||||
"update_reinforcements": AgentAction(enabled=True, label="Update state reinforcements", description="Will attempt to update any due state reinforcements.", config={}),
|
||||
"check_pin_conditions": AgentAction(enabled=True, label="Update conditional context pins", description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.", config={
|
||||
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before checking conditions.", value=2, min=1, max=100, step=1)
|
||||
}),
|
||||
"update_world_state": AgentAction(
|
||||
enabled=True,
|
||||
label="Update world state",
|
||||
description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
|
||||
config={
|
||||
"turns": AgentActionConfig(
|
||||
type="number",
|
||||
label="Turns",
|
||||
description="Number of turns to wait before updating the world state.",
|
||||
value=5,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
)
|
||||
},
|
||||
),
|
||||
"update_reinforcements": AgentAction(
|
||||
enabled=True,
|
||||
label="Update state reinforcements",
|
||||
description="Will attempt to update any due state reinforcements.",
|
||||
config={},
|
||||
),
|
||||
"check_pin_conditions": AgentAction(
|
||||
enabled=True,
|
||||
label="Update conditional context pins",
|
||||
description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.",
|
||||
config={
|
||||
"turns": AgentActionConfig(
|
||||
type="number",
|
||||
label="Turns",
|
||||
description="Number of turns to wait before checking conditions.",
|
||||
value=2,
|
||||
min=1,
|
||||
max=100,
|
||||
step=1,
|
||||
)
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
self.next_update = 0
|
||||
self.next_pin_check = 0
|
||||
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
@@ -83,110 +122,123 @@ class WorldStateAgent(Agent):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
async def advance_time(self, duration:str, narrative:str=None):
|
||||
async def advance_time(self, duration: str, narrative: str = None):
|
||||
"""
|
||||
Emit a time passage message
|
||||
"""
|
||||
|
||||
|
||||
isodate.parse_duration(duration)
|
||||
human_duration = util.iso8601_duration_to_human(duration, suffix=" later")
|
||||
message = TimePassageMessage(ts=duration, message=human_duration)
|
||||
|
||||
|
||||
log.debug("world_state.advance_time", message=message)
|
||||
self.scene.push_history(message)
|
||||
self.scene.emit_status()
|
||||
|
||||
emit("time", message)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.world_state.time").send(
|
||||
TimePassageEmission(agent=self, duration=duration, narrative=narrative, human_duration=human_duration)
|
||||
)
|
||||
|
||||
|
||||
async def on_game_loop(self, emission:GameLoopEvent):
|
||||
emit("time", message)
|
||||
|
||||
await talemate.emit.async_signals.get("agent.world_state.time").send(
|
||||
TimePassageEmission(
|
||||
agent=self,
|
||||
duration=duration,
|
||||
narrative=narrative,
|
||||
human_duration=human_duration,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_game_loop(self, emission: GameLoopEvent):
|
||||
"""
|
||||
Called when a conversation is generated
|
||||
"""
|
||||
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
await self.update_world_state()
|
||||
await self.auto_update_reinforcments()
|
||||
await self.auto_check_pin_conditions()
|
||||
|
||||
|
||||
|
||||
async def auto_update_reinforcments(self):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
if not self.actions["update_reinforcements"].enabled:
|
||||
return
|
||||
|
||||
|
||||
await self.update_reinforcements()
|
||||
|
||||
|
||||
async def auto_check_pin_conditions(self):
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
if not self.actions["check_pin_conditions"].enabled:
|
||||
return
|
||||
|
||||
if self.next_pin_check % self.actions["check_pin_conditions"].config["turns"].value != 0 or self.next_pin_check == 0:
|
||||
|
||||
|
||||
if (
|
||||
self.next_pin_check
|
||||
% self.actions["check_pin_conditions"].config["turns"].value
|
||||
!= 0
|
||||
or self.next_pin_check == 0
|
||||
):
|
||||
self.next_pin_check += 1
|
||||
return
|
||||
|
||||
self.next_pin_check = 0
|
||||
|
||||
await self.check_pin_conditions()
|
||||
|
||||
|
||||
async def update_world_state(self):
|
||||
self.next_pin_check = 0
|
||||
|
||||
await self.check_pin_conditions()
|
||||
|
||||
async def update_world_state(self, force: bool = False):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
|
||||
if not self.actions["update_world_state"].enabled:
|
||||
return
|
||||
|
||||
log.debug("update_world_state", next_update=self.next_update, turns=self.actions["update_world_state"].config["turns"].value)
|
||||
|
||||
|
||||
log.debug(
|
||||
"update_world_state",
|
||||
next_update=self.next_update,
|
||||
turns=self.actions["update_world_state"].config["turns"].value,
|
||||
)
|
||||
|
||||
scene = self.scene
|
||||
|
||||
if self.next_update % self.actions["update_world_state"].config["turns"].value != 0 or self.next_update == 0:
|
||||
|
||||
if (
|
||||
self.next_update % self.actions["update_world_state"].config["turns"].value
|
||||
!= 0
|
||||
or self.next_update == 0
|
||||
) and not force:
|
||||
self.next_update += 1
|
||||
return
|
||||
|
||||
|
||||
self.next_update = 0
|
||||
await scene.world_state.request_update()
|
||||
|
||||
|
||||
update_world_state.exposed = True
|
||||
|
||||
@set_processing
|
||||
async def request_world_state(self):
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
_, world_state = await Prompt.request(
|
||||
"world_state.request-world-state-v2",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"object_type": "character",
|
||||
"object_type_plural": "characters",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.scene.log.debug("request_world_state", response=world_state, time=time.time() - t1)
|
||||
|
||||
|
||||
self.scene.log.debug(
|
||||
"request_world_state", response=world_state, time=time.time() - t1
|
||||
)
|
||||
|
||||
return world_state
|
||||
|
||||
|
||||
@set_processing
|
||||
async def request_world_state_inline(self):
|
||||
|
||||
"""
|
||||
EXPERIMENTAL, Overall the one shot request seems about as coherent as the inline request, but the inline request is is about twice as slow and would need to run on every dialogue line.
|
||||
"""
|
||||
@@ -199,14 +251,18 @@ class WorldStateAgent(Agent):
|
||||
"world_state.request-world-state-inline-items",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.scene.log.debug("request_world_state_inline", marked_items=marked_items_response, time=time.time() - t1)
|
||||
|
||||
|
||||
self.scene.log.debug(
|
||||
"request_world_state_inline",
|
||||
marked_items=marked_items_response,
|
||||
time=time.time() - t1,
|
||||
)
|
||||
|
||||
return marked_items_response
|
||||
|
||||
@set_processing
|
||||
@@ -214,99 +270,111 @@ class WorldStateAgent(Agent):
|
||||
self,
|
||||
text: str,
|
||||
):
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-time-passage",
|
||||
self.client,
|
||||
"analyze_freeform_short",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
duration = response.split("\n")[0].split(" ")[0].strip()
|
||||
|
||||
|
||||
if not duration.startswith("P"):
|
||||
duration = "P"+duration
|
||||
|
||||
duration = "P" + duration
|
||||
|
||||
return duration
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def analyze_text_and_extract_context(
|
||||
self,
|
||||
text: str,
|
||||
goal: str,
|
||||
):
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-text-and-extract-context",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
"goal": goal,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("analyze_text_and_extract_context", goal=goal, text=text, response=response)
|
||||
|
||||
|
||||
log.debug(
|
||||
"analyze_text_and_extract_context", goal=goal, text=text, response=response
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def analyze_text_and_extract_context_via_queries(
|
||||
self,
|
||||
text: str,
|
||||
goal: str,
|
||||
) -> list[str]:
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-text-and-generate-rag-queries",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
"goal": goal,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
queries = response.split("\n")
|
||||
|
||||
|
||||
memory_agent = get_agent("memory")
|
||||
|
||||
|
||||
context = await memory_agent.multi_query(queries, iterate=3)
|
||||
|
||||
log.debug("analyze_text_and_extract_context_via_queries", goal=goal, text=text, queries=queries, context=context)
|
||||
|
||||
|
||||
log.debug(
|
||||
"analyze_text_and_extract_context_via_queries",
|
||||
goal=goal,
|
||||
text=text,
|
||||
queries=queries,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@set_processing
|
||||
async def analyze_and_follow_instruction(
|
||||
self,
|
||||
text: str,
|
||||
instruction: str,
|
||||
short: bool = False,
|
||||
):
|
||||
|
||||
|
||||
kind = "analyze_freeform_short" if short else "analyze_freeform"
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-text-and-follow-instruction",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars = {
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
"instruction": instruction,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("analyze_and_follow_instruction", instruction=instruction, text=text, response=response)
|
||||
|
||||
|
||||
log.debug(
|
||||
"analyze_and_follow_instruction",
|
||||
instruction=instruction,
|
||||
text=text,
|
||||
response=response,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
@@ -314,51 +382,55 @@ class WorldStateAgent(Agent):
|
||||
self,
|
||||
text: str,
|
||||
query: str,
|
||||
short: bool = False,
|
||||
):
|
||||
|
||||
kind = "analyze_freeform_short" if short else "analyze_freeform"
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-text-and-answer-question",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars = {
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
"query": query,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
log.debug("analyze_text_and_answer_question", query=query, text=text, response=response)
|
||||
|
||||
|
||||
log.debug(
|
||||
"analyze_text_and_answer_question",
|
||||
query=query,
|
||||
text=text,
|
||||
response=response,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def identify_characters(
|
||||
self,
|
||||
text: str = None,
|
||||
):
|
||||
|
||||
"""
|
||||
Attempts to identify characters in the given text.
|
||||
"""
|
||||
|
||||
|
||||
_, data = await Prompt.request(
|
||||
"world_state.identify-characters",
|
||||
self.client,
|
||||
"analyze",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
log.debug("identify_characters", text=text, data=data)
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _parse_character_sheet(self, response):
|
||||
|
||||
data = {}
|
||||
for line in response.split("\n"):
|
||||
if not line.strip():
|
||||
@@ -367,128 +439,148 @@ class WorldStateAgent(Agent):
|
||||
break
|
||||
name, value = line.split(":", 1)
|
||||
data[name.strip()] = value.strip()
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@set_processing
|
||||
async def extract_character_sheet(
|
||||
self,
|
||||
name:str,
|
||||
text:str = None,
|
||||
name: str,
|
||||
text: str = None,
|
||||
alteration_instructions: str = None,
|
||||
):
|
||||
|
||||
"""
|
||||
Attempts to extract a character sheet from the given text.
|
||||
"""
|
||||
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.extract-character-sheet",
|
||||
self.client,
|
||||
"create",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
"name": name,
|
||||
}
|
||||
"character": self.scene.get_character(name),
|
||||
"alteration_instructions": alteration_instructions or "",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# loop through each line in response and if it contains a : then extract
|
||||
# the left side as an attribute name and the right side as the value
|
||||
#
|
||||
# break as soon as a non-empty line is found that doesn't contain a :
|
||||
|
||||
|
||||
return self._parse_character_sheet(response)
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def match_character_names(self, names:list[str]):
|
||||
|
||||
async def match_character_names(self, names: list[str]):
|
||||
"""
|
||||
Attempts to match character names.
|
||||
"""
|
||||
|
||||
|
||||
_, response = await Prompt.request(
|
||||
"world_state.match-character-names",
|
||||
self.client,
|
||||
"analyze_long",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"names": names,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
log.debug("match_character_names", names=names, response=response)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def update_reinforcements(self, force:bool=False):
|
||||
|
||||
async def update_reinforcements(self, force: bool = False):
|
||||
"""
|
||||
Queries due worldstate re-inforcements
|
||||
"""
|
||||
|
||||
|
||||
for reinforcement in self.scene.world_state.reinforce:
|
||||
if reinforcement.due <= 0 or force:
|
||||
await self.update_reinforcement(reinforcement.question, reinforcement.character)
|
||||
await self.update_reinforcement(
|
||||
reinforcement.question, reinforcement.character
|
||||
)
|
||||
else:
|
||||
reinforcement.due -= 1
|
||||
|
||||
|
||||
|
||||
@set_processing
|
||||
async def update_reinforcement(self, question:str, character:str=None, reset:bool=False):
|
||||
|
||||
async def update_reinforcement(
|
||||
self, question: str, character: str = None, reset: bool = False
|
||||
):
|
||||
"""
|
||||
Queries a single re-inforcement
|
||||
"""
|
||||
message = None
|
||||
idx, reinforcement = await self.scene.world_state.find_reinforcement(question, character)
|
||||
|
||||
idx, reinforcement = await self.scene.world_state.find_reinforcement(
|
||||
question, character
|
||||
)
|
||||
|
||||
if not reinforcement:
|
||||
return
|
||||
|
||||
|
||||
source = f"{reinforcement.question}:{reinforcement.character if reinforcement.character else ''}"
|
||||
|
||||
|
||||
if reset and reinforcement.insert == "sequential":
|
||||
self.scene.pop_history(typ="reinforcement", source=source, all=True)
|
||||
|
||||
|
||||
if reinforcement.insert == "sequential":
|
||||
kind = "analyze_freeform_medium_short"
|
||||
else:
|
||||
kind = "analyze_freeform"
|
||||
|
||||
answer = await Prompt.request(
|
||||
"world_state.update-reinforcements",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
vars = {
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"question": reinforcement.question,
|
||||
"instructions": reinforcement.instructions or "",
|
||||
"character": self.scene.get_character(reinforcement.character) if reinforcement.character else None,
|
||||
"character": (
|
||||
self.scene.get_character(reinforcement.character)
|
||||
if reinforcement.character
|
||||
else None
|
||||
),
|
||||
"answer": (reinforcement.answer if not reset else None) or "",
|
||||
"reinforcement": reinforcement,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# sequential reinforcment should be single sentence so we
|
||||
# split on line breaks and take the first line in case the
|
||||
# LLM did not understand the request and returned a longer response
|
||||
|
||||
if reinforcement.insert == "sequential":
|
||||
answer = answer.split("\n")[0]
|
||||
|
||||
reinforcement.answer = answer
|
||||
reinforcement.due = reinforcement.interval
|
||||
|
||||
|
||||
# remove any recent previous reinforcement message with same question
|
||||
# to avoid overloading the near history with reinforcement messages
|
||||
if not reset:
|
||||
self.scene.pop_history(typ="reinforcement", source=source, max_iterations=10)
|
||||
|
||||
self.scene.pop_history(
|
||||
typ="reinforcement", source=source, max_iterations=10
|
||||
)
|
||||
|
||||
if reinforcement.insert == "sequential":
|
||||
# insert the reinforcement message at the current position
|
||||
message = ReinforcementMessage(message=answer, source=source)
|
||||
log.debug("update_reinforcement", message=message, reset=reset)
|
||||
self.scene.push_history(message)
|
||||
|
||||
|
||||
# if reinforcement has a character name set, update the character detail
|
||||
if reinforcement.character:
|
||||
character = self.scene.get_character(reinforcement.character)
|
||||
await character.set_detail(reinforcement.question, answer)
|
||||
|
||||
|
||||
else:
|
||||
# set world entry
|
||||
await self.scene.world_state_manager.save_world_entry(
|
||||
@@ -496,20 +588,19 @@ class WorldStateAgent(Agent):
|
||||
reinforcement.as_context_line,
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
self.scene.world_state.emit()
|
||||
|
||||
return message
|
||||
|
||||
|
||||
return message
|
||||
|
||||
@set_processing
|
||||
async def check_pin_conditions(
|
||||
self,
|
||||
):
|
||||
|
||||
"""
|
||||
Checks if any context pin conditions
|
||||
"""
|
||||
|
||||
|
||||
pins_with_condition = {
|
||||
entry_id: {
|
||||
"condition": pin.condition,
|
||||
@@ -518,41 +609,47 @@ class WorldStateAgent(Agent):
|
||||
for entry_id, pin in self.scene.world_state.pins.items()
|
||||
if pin.condition
|
||||
}
|
||||
|
||||
|
||||
if not pins_with_condition:
|
||||
return
|
||||
|
||||
|
||||
first_entry_id = list(pins_with_condition.keys())[0]
|
||||
|
||||
|
||||
_, answers = await Prompt.request(
|
||||
"world_state.check-pin-conditions",
|
||||
self.client,
|
||||
"analyze",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"previous_states": json.dumps(pins_with_condition,indent=2),
|
||||
"coercion": {first_entry_id:{ "condition": "" }},
|
||||
}
|
||||
"previous_states": json.dumps(pins_with_condition, indent=2),
|
||||
"coercion": {first_entry_id: {"condition": ""}},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
world_state = self.scene.world_state
|
||||
state_change = False
|
||||
|
||||
state_change = False
|
||||
|
||||
for entry_id, answer in answers.items():
|
||||
|
||||
if entry_id not in world_state.pins:
|
||||
log.warning("check_pin_conditions", entry_id=entry_id, answer=answer, msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)")
|
||||
log.warning(
|
||||
"check_pin_conditions",
|
||||
entry_id=entry_id,
|
||||
answer=answer,
|
||||
msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)",
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
log.info("check_pin_conditions", entry_id=entry_id, answer=answer)
|
||||
state = answer.get("state")
|
||||
if state is True or (isinstance(state, str) and state.lower() in ["true", "yes", "y"]):
|
||||
if state is True or (
|
||||
isinstance(state, str) and state.lower() in ["true", "yes", "y"]
|
||||
):
|
||||
prev_state = world_state.pins[entry_id].condition_state
|
||||
|
||||
|
||||
world_state.pins[entry_id].condition_state = True
|
||||
world_state.pins[entry_id].active = True
|
||||
|
||||
|
||||
if prev_state != world_state.pins[entry_id].condition_state:
|
||||
state_change = True
|
||||
else:
|
||||
@@ -560,49 +657,50 @@ class WorldStateAgent(Agent):
|
||||
world_state.pins[entry_id].condition_state = False
|
||||
world_state.pins[entry_id].active = False
|
||||
state_change = True
|
||||
|
||||
|
||||
if state_change:
|
||||
await self.scene.load_active_pins()
|
||||
self.scene.emit_status()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def summarize_and_pin(self, message_id:int, num_messages:int=3) -> str:
|
||||
|
||||
async def summarize_and_pin(self, message_id: int, num_messages: int = 3) -> str:
|
||||
"""
|
||||
Will take a message index and then walk back N messages
|
||||
summarizing the scene and pinning it to the context.
|
||||
"""
|
||||
|
||||
|
||||
creator = get_agent("creator")
|
||||
summarizer = get_agent("summarizer")
|
||||
|
||||
|
||||
message_index = self.scene.message_index(message_id)
|
||||
|
||||
|
||||
text = self.scene.snapshot(lines=num_messages, start=message_index)
|
||||
|
||||
extra_context = self.scene.snapshot(lines=50, start=message_index-num_messages)
|
||||
|
||||
|
||||
extra_context = self.scene.snapshot(
|
||||
lines=50, start=message_index - num_messages
|
||||
)
|
||||
|
||||
summary = await summarizer.summarize(
|
||||
text,
|
||||
text,
|
||||
extra_context=extra_context,
|
||||
method="short",
|
||||
extra_instructions="Pay particularly close attention to decisions, agreements or promises made.",
|
||||
)
|
||||
|
||||
|
||||
entry_id = util.clean_id(await creator.generate_title(summary))
|
||||
|
||||
|
||||
ts = self.scene.ts
|
||||
|
||||
|
||||
log.debug(
|
||||
"summarize_and_pin",
|
||||
message_id=message_id,
|
||||
message_index=message_index,
|
||||
num_messages=num_messages,
|
||||
num_messages=num_messages,
|
||||
summary=summary,
|
||||
entry_id=entry_id,
|
||||
ts=ts,
|
||||
)
|
||||
|
||||
|
||||
await self.scene.world_state_manager.save_world_entry(
|
||||
entry_id,
|
||||
summary,
|
||||
@@ -610,49 +708,49 @@ class WorldStateAgent(Agent):
|
||||
"ts": ts,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
await self.scene.world_state_manager.set_pin(
|
||||
entry_id,
|
||||
active=True,
|
||||
)
|
||||
|
||||
|
||||
await self.scene.load_active_pins()
|
||||
self.scene.emit_status()
|
||||
|
||||
|
||||
@set_processing
|
||||
async def is_character_present(self, character:str) -> bool:
|
||||
async def is_character_present(self, character: str) -> bool:
|
||||
"""
|
||||
Check if a character is present in the scene
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
- `character`: The character to check.
|
||||
"""
|
||||
|
||||
if len(self.scene.history) < 10:
|
||||
text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
|
||||
else:
|
||||
text = self.scene.snapshot(lines=50)
|
||||
|
||||
is_present = await self.analyze_text_and_answer_question(
|
||||
text=text,
|
||||
query=f"Is {character} present AND active in the current scene? Answert with 'yes' or 'no'.",
|
||||
)
|
||||
|
||||
return is_present.lower().startswith("y")
|
||||
|
||||
@set_processing
|
||||
async def is_character_leaving(self, character:str) -> bool:
|
||||
"""
|
||||
Check if a character is leaving the scene
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
- `character`: The character to check.
|
||||
"""
|
||||
|
||||
if len(self.scene.history) < 10:
|
||||
text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
|
||||
text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
|
||||
else:
|
||||
text = self.scene.snapshot(lines=50)
|
||||
|
||||
is_present = await self.analyze_text_and_answer_question(
|
||||
text=text,
|
||||
query=f"Is {character} present AND active in the current scene? Answert with 'yes' or 'no'.",
|
||||
)
|
||||
|
||||
return is_present.lower().startswith("y")
|
||||
|
||||
@set_processing
|
||||
async def is_character_leaving(self, character: str) -> bool:
|
||||
"""
|
||||
Check if a character is leaving the scene
|
||||
|
||||
Arguments:
|
||||
|
||||
- `character`: The character to check.
|
||||
"""
|
||||
|
||||
if len(self.scene.history) < 10:
|
||||
text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
|
||||
else:
|
||||
text = self.scene.snapshot(lines=50)
|
||||
|
||||
@@ -660,5 +758,30 @@ class WorldStateAgent(Agent):
|
||||
text=text,
|
||||
query=f"Is {character} leaving the current scene? Answert with 'yes' or 'no'.",
|
||||
)
|
||||
|
||||
return is_leaving.lower().startswith("y")
|
||||
|
||||
return is_leaving.lower().startswith("y")
|
||||
|
||||
@set_processing
|
||||
async def manager(self, action_name: str, *args, **kwargs):
|
||||
"""
|
||||
Executes a world state manager action through self.scene.world_state_manager
|
||||
"""
|
||||
|
||||
manager = self.scene.world_state_manager
|
||||
|
||||
try:
|
||||
fn = getattr(manager, action_name, None)
|
||||
|
||||
if not fn:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
return await fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"worldstate.manager",
|
||||
action_name=action_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
error=e,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate import Scene
|
||||
|
||||
|
||||
import structlog
|
||||
|
||||
__all__ = ["AutomatedAction", "register", "initialize_for_scene"]
|
||||
@@ -13,50 +14,64 @@ log = structlog.get_logger("talemate.automated_action")
|
||||
|
||||
AUTOMATED_ACTIONS = {}
|
||||
|
||||
def initialize_for_scene(scene:Scene):
|
||||
|
||||
|
||||
def initialize_for_scene(scene: Scene):
|
||||
for uid, config in AUTOMATED_ACTIONS.items():
|
||||
scene.automated_actions[uid] = config.cls(
|
||||
scene,
|
||||
uid=uid,
|
||||
frequency=config.frequency,
|
||||
call_initially=config.call_initially,
|
||||
enabled=config.enabled
|
||||
enabled=config.enabled,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AutomatedActionConfig:
|
||||
uid:str
|
||||
cls:AutomatedAction
|
||||
frequency:int=5
|
||||
call_initially:bool=False
|
||||
enabled:bool=True
|
||||
|
||||
uid: str
|
||||
cls: AutomatedAction
|
||||
frequency: int = 5
|
||||
call_initially: bool = False
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class register:
|
||||
|
||||
def __init__(self, uid:str, frequency:int=5, call_initially:bool=False, enabled:bool=True):
|
||||
def __init__(
|
||||
self,
|
||||
uid: str,
|
||||
frequency: int = 5,
|
||||
call_initially: bool = False,
|
||||
enabled: bool = True,
|
||||
):
|
||||
self.uid = uid
|
||||
self.frequency = frequency
|
||||
self.call_initially = call_initially
|
||||
self.enabled = enabled
|
||||
|
||||
def __call__(self, action:AutomatedAction):
|
||||
|
||||
def __call__(self, action: AutomatedAction):
|
||||
AUTOMATED_ACTIONS[self.uid] = AutomatedActionConfig(
|
||||
self.uid,
|
||||
action,
|
||||
frequency=self.frequency,
|
||||
call_initially=self.call_initially,
|
||||
enabled=self.enabled
|
||||
self.uid,
|
||||
action,
|
||||
frequency=self.frequency,
|
||||
call_initially=self.call_initially,
|
||||
enabled=self.enabled,
|
||||
)
|
||||
return action
|
||||
|
||||
|
||||
|
||||
class AutomatedAction:
|
||||
"""
|
||||
An action that will be executed every n turns
|
||||
"""
|
||||
|
||||
def __init__(self, scene:Scene, frequency:int=5, call_initially:bool=False, uid:str=None, enabled:bool=True):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scene: Scene,
|
||||
frequency: int = 5,
|
||||
call_initially: bool = False,
|
||||
uid: str = None,
|
||||
enabled: bool = True,
|
||||
):
|
||||
self.scene = scene
|
||||
self.enabled = enabled
|
||||
self.frequency = frequency
|
||||
@@ -64,14 +79,19 @@ class AutomatedAction:
|
||||
self.uid = uid
|
||||
if call_initially:
|
||||
self.turns = frequency
|
||||
|
||||
|
||||
async def __call__(self):
|
||||
|
||||
log.debug("automated_action", uid=self.uid, enabled=self.enabled, frequency=self.frequency, turns=self.turns)
|
||||
|
||||
log.debug(
|
||||
"automated_action",
|
||||
uid=self.uid,
|
||||
enabled=self.enabled,
|
||||
frequency=self.frequency,
|
||||
turns=self.turns,
|
||||
)
|
||||
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
|
||||
if self.turns % self.frequency == 0:
|
||||
result = await self.action()
|
||||
log.debug("automated_action", result=result)
|
||||
@@ -79,10 +99,9 @@ class AutomatedAction:
|
||||
# action could not be performed at this turn, we will try again next turn
|
||||
return False
|
||||
self.turns += 1
|
||||
|
||||
|
||||
|
||||
async def action(self) -> Any:
|
||||
"""
|
||||
Override this method to implement your action.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
from typing import Union, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from talemate.instance import get_agent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene, Character, Actor
|
||||
|
||||
from talemate.tale_mate import Actor, Character, Scene
|
||||
|
||||
|
||||
__all__ = [
|
||||
"deactivate_character",
|
||||
"activate_character",
|
||||
]
|
||||
|
||||
async def deactivate_character(scene:"Scene", character:Union[str, "Character"]):
|
||||
|
||||
async def deactivate_character(scene: "Scene", character: Union[str, "Character"]):
|
||||
"""
|
||||
Deactivates a character
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
- `scene`: The scene to deactivate the character from
|
||||
- `character`: The character to deactivate. Can be a string (the character's name) or a Character object
|
||||
"""
|
||||
|
||||
|
||||
if isinstance(character, str):
|
||||
character = scene.get_character(character)
|
||||
|
||||
|
||||
if character.is_player:
|
||||
# can't deactivate the player
|
||||
return False
|
||||
|
||||
|
||||
if character.name in scene.inactive_characters:
|
||||
# already deactivated
|
||||
return False
|
||||
@@ -34,24 +36,24 @@ async def deactivate_character(scene:"Scene", character:Union[str, "Character"])
|
||||
await scene.remove_actor(character.actor)
|
||||
scene.inactive_characters[character.name] = character
|
||||
|
||||
async def activate_character(scene:"Scene", character:Union[str, "Character"]):
|
||||
|
||||
async def activate_character(scene: "Scene", character: Union[str, "Character"]):
|
||||
"""
|
||||
Activates a character
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
|
||||
- `scene`: The scene to activate the character in
|
||||
- `character`: The character to activate. Can be a string (the character's name) or a Character object
|
||||
"""
|
||||
|
||||
|
||||
if isinstance(character, str):
|
||||
character = scene.get_character(character)
|
||||
|
||||
|
||||
if character.name not in scene.inactive_characters:
|
||||
# already activated
|
||||
return False
|
||||
|
||||
|
||||
actor = scene.Actor(character, get_agent("conversation"))
|
||||
await scene.add_actor(actor)
|
||||
del scene.inactive_characters[character.name]
|
||||
|
||||
|
||||
@@ -2,15 +2,13 @@ import argparse
|
||||
import asyncio
|
||||
import glob
|
||||
import os
|
||||
import structlog
|
||||
|
||||
import structlog
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import talemate.instance as instance
|
||||
from talemate import Actor, Character, Helper, Player, Scene
|
||||
from talemate.agents import (
|
||||
ConversationAgent,
|
||||
)
|
||||
from talemate.agents import ConversationAgent
|
||||
from talemate.client import OpenAIClient, TextGeneratorWebuiClient
|
||||
from talemate.emit.console import Console
|
||||
from talemate.load import (
|
||||
@@ -129,7 +127,6 @@ async def run_console_session(parser, args):
|
||||
default_client = None
|
||||
|
||||
if "textgenwebui" in clients.values() or args.client == "textgenwebui":
|
||||
|
||||
# Init the TextGeneratorWebuiClient with ConversationAgent and create an actor
|
||||
textgenwebui_api_url = args.textgenwebui_url
|
||||
|
||||
@@ -145,7 +142,6 @@ async def run_console_session(parser, args):
|
||||
clients[client_name] = text_generator_webui_client
|
||||
|
||||
if "openai" in clients.values() or args.client == "openai":
|
||||
|
||||
openai_client = OpenAIClient()
|
||||
|
||||
for client_name, client_typ in clients.items():
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import os
|
||||
from talemate.client.openai import OpenAIClient
|
||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
||||
from talemate.client.lmstudio import LMStudioClient
|
||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||
|
||||
import talemate.client.runpod
|
||||
from talemate.client.anthropic import AnthropicClient
|
||||
from talemate.client.cohere import CohereClient
|
||||
from talemate.client.google import GoogleClient
|
||||
from talemate.client.groq import GroqClient
|
||||
from talemate.client.koboldcpp import KoboldCppClient
|
||||
from talemate.client.lmstudio import LMStudioClient
|
||||
from talemate.client.mistral import MistralAIClient
|
||||
from talemate.client.openai import OpenAIClient
|
||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||
from talemate.client.tabbyapi import TabbyAPIClient
|
||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
||||
226
src/talemate/client/anthropic.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from anthropic import AsyncAnthropic, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"AnthropicClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "claude-3-sonnet-20240229"
|
||||
|
||||
|
||||
@register()
|
||||
class AnthropicClient(ClientBase):
|
||||
"""
|
||||
Anthropic client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "anthropic"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Anthropic"
|
||||
title: str = "Anthropic"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="claude-3-sonnet-20240229", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def anthropic_api_key(self):
|
||||
return self.config.get("anthropic", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.anthropic_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"anthropic_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.anthropic_api_key:
|
||||
self.client = AsyncAnthropic(api_key="sk-1111")
|
||||
log.error("No anthropic API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "claude-3-opus-20240229"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"anthropic set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.output_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.input_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.anthropic_api_key:
|
||||
raise Exception("No anthropic API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(
|
||||
model=self.model_name,
|
||||
system=system_message,
|
||||
messages=[human_message],
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
log.debug("generated response", response=response.content)
|
||||
|
||||
response = response.content[0].text
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="anthropic API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,6 +1,7 @@
|
||||
import pydantic
|
||||
from enum import Enum
|
||||
|
||||
import pydantic
|
||||
|
||||
__all__ = [
|
||||
"ClientType",
|
||||
"ClientBootstrap",
|
||||
@@ -10,8 +11,10 @@ __all__ = [
|
||||
|
||||
LISTS = {}
|
||||
|
||||
|
||||
class ClientType(str, Enum):
|
||||
"""Client type enum."""
|
||||
|
||||
textgen = "textgenwebui"
|
||||
automatic1111 = "automatic1111"
|
||||
|
||||
@@ -20,43 +23,42 @@ class ClientBootstrap(pydantic.BaseModel):
|
||||
"""Client bootstrap model."""
|
||||
|
||||
# client type, currently supports "textgen" and "automatic1111"
|
||||
|
||||
|
||||
client_type: ClientType
|
||||
|
||||
|
||||
# unique client identifier
|
||||
|
||||
|
||||
uid: str
|
||||
|
||||
|
||||
# connection name
|
||||
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
# connection information for the client
|
||||
# REST api url
|
||||
|
||||
|
||||
api_url: str
|
||||
|
||||
|
||||
# service name (for example runpod)
|
||||
|
||||
|
||||
service_name: str
|
||||
|
||||
|
||||
|
||||
class register_list:
|
||||
|
||||
def __init__(self, service_name:str):
|
||||
def __init__(self, service_name: str):
|
||||
self.service_name = service_name
|
||||
|
||||
|
||||
def __call__(self, func):
|
||||
LISTS[self.service_name] = func
|
||||
return func
|
||||
|
||||
|
||||
|
||||
async def list_all(exclude_urls: list[str] = list()):
|
||||
"""
|
||||
Return a list of client bootstrap objects.
|
||||
"""
|
||||
|
||||
|
||||
for service_name, func in LISTS.items():
|
||||
async for item in func():
|
||||
if item.api_url not in exclude_urls:
|
||||
yield item.dict()
|
||||
yield item.dict()
|
||||
|
||||
245
src/talemate/client/cohere.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from cohere import AsyncClient
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.util import count_tokens
|
||||
|
||||
__all__ = [
|
||||
"CohereClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"command",
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "command-r-plus"
|
||||
|
||||
|
||||
@register()
|
||||
class CohereClient(ClientBase):
|
||||
"""
|
||||
Cohere client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "cohere"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
decensor_enabled = True
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Cohere"
|
||||
title: str = "Cohere"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="command-r-plus", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def cohere_api_key(self):
|
||||
return self.config.get("cohere", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
ParameterReroute(talemate_parameter="top_p", client_parameter="p"),
|
||||
ParameterReroute(talemate_parameter="top_k", client_parameter="k"),
|
||||
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop_sequences"),
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.cohere_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"cohere_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.cohere_api_key:
|
||||
self.client = AsyncClient("sk-1111")
|
||||
log.error("No cohere API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "command-r-plus"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncClient(self.cohere_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"cohere set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return count_tokens(response.text)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
return count_tokens(prompt)
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def clean_prompt_parameters(self, parameters: dict):
|
||||
|
||||
super().clean_prompt_parameters(parameters)
|
||||
|
||||
# if temperature is set, it needs to be clamped between 0 and 1.0
|
||||
if "temperature" in parameters:
|
||||
parameters["temperature"] = max(0.0, min(1.0, parameters["temperature"]))
|
||||
|
||||
# if stop_sequences is set, max 5 items
|
||||
if "stop_sequences" in parameters:
|
||||
parameters["stop_sequences"] = parameters["stop_sequences"][:5]
|
||||
|
||||
# if both frequency_penalty and presence_penalty are set, drop frequency_penalty
|
||||
if "presence_penalty" in parameters and "frequency_penalty" in parameters:
|
||||
del parameters["frequency_penalty"]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.cohere_api_key:
|
||||
raise Exception("No cohere API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = prompt.strip()
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat(
|
||||
model=self.model_name,
|
||||
preamble=system_message,
|
||||
message=human_message,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
log.debug("generated response", response=response.text)
|
||||
|
||||
response = response.text
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
# except PermissionDeniedError as e:
|
||||
# self.log.error("generate error", e=e)
|
||||
# emit("status", message="cohere API: Permission Denied", status="error")
|
||||
# return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -3,19 +3,20 @@ Context managers for various client-side operations.
|
||||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from pydantic import BaseModel, Field
|
||||
from copy import deepcopy
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = [
|
||||
'context_data',
|
||||
'client_context_attribute',
|
||||
'ContextModel',
|
||||
"context_data",
|
||||
"client_context_attribute",
|
||||
"ContextModel",
|
||||
]
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def model_to_dict_without_defaults(model_instance):
|
||||
model_dict = model_instance.dict()
|
||||
for field_name, field in model_instance.__class__.__fields__.items():
|
||||
@@ -23,20 +24,26 @@ def model_to_dict_without_defaults(model_instance):
|
||||
del model_dict[field_name]
|
||||
return model_dict
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
talking_character: str = None
|
||||
other_characters: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ContextModel(BaseModel):
|
||||
"""
|
||||
Pydantic model for the context data.
|
||||
"""
|
||||
|
||||
nuke_repetition: float = Field(0.0, ge=0.0, le=3.0)
|
||||
conversation: ConversationContext = Field(default_factory=ConversationContext)
|
||||
length: int = 96
|
||||
inference_preset: str = None
|
||||
|
||||
|
||||
# Define the context variable as an empty dictionary
|
||||
context_data = ContextVar('context_data', default=ContextModel().model_dump())
|
||||
context_data = ContextVar("context_data", default=ContextModel().model_dump())
|
||||
|
||||
|
||||
def client_context_attribute(name, default=None):
|
||||
"""
|
||||
@@ -47,6 +54,7 @@ def client_context_attribute(name, default=None):
|
||||
# Return the value of the key if it exists, otherwise return the default value
|
||||
return data.get(name, default)
|
||||
|
||||
|
||||
def set_client_context_attribute(name, value):
|
||||
"""
|
||||
Set the value of the context variable `context_data` for the given key.
|
||||
@@ -55,7 +63,8 @@ def set_client_context_attribute(name, value):
|
||||
data = context_data.get()
|
||||
# Set the value of the key
|
||||
data[name] = value
|
||||
|
||||
|
||||
|
||||
def set_conversation_context_attribute(name, value):
|
||||
"""
|
||||
Set the value of the context variable `context_data.conversation` for the given key.
|
||||
@@ -65,6 +74,7 @@ def set_conversation_context_attribute(name, value):
|
||||
# Set the value of the key
|
||||
data["conversation"][name] = value
|
||||
|
||||
|
||||
class ClientContext:
|
||||
"""
|
||||
A context manager to set values to the context variable `context_data`.
|
||||
@@ -82,10 +92,10 @@ class ClientContext:
|
||||
Set the key-value pairs to the context variable `context_data` when entering the context.
|
||||
"""
|
||||
# Get the current context data
|
||||
|
||||
|
||||
data = deepcopy(context_data.get()) if context_data.get() else {}
|
||||
data.update(self.values)
|
||||
|
||||
|
||||
# Update the context data
|
||||
self.token = context_data.set(data)
|
||||
|
||||
@@ -93,5 +103,5 @@ class ClientContext:
|
||||
"""
|
||||
Reset the context variable `context_data` to its previous values when exiting the context.
|
||||
"""
|
||||
|
||||
|
||||
context_data.reset(self.token)
|
||||
|
||||
34
src/talemate/client/custom/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger("talemate.client.custom")
|
||||
|
||||
# import every submodule in this directory
|
||||
#
|
||||
# each directory in this directory is a submodule
|
||||
|
||||
# get the current directory
|
||||
current_directory = os.path.dirname(__file__)
|
||||
|
||||
# get all subdirectories
|
||||
subdirectories = [
|
||||
os.path.join(current_directory, name)
|
||||
for name in os.listdir(current_directory)
|
||||
if os.path.isdir(os.path.join(current_directory, name))
|
||||
]
|
||||
|
||||
# import every submodule
|
||||
|
||||
for subdirectory in subdirectories:
|
||||
# get the name of the submodule
|
||||
submodule_name = os.path.basename(subdirectory)
|
||||
|
||||
if submodule_name.startswith("__"):
|
||||
continue
|
||||
|
||||
log.info("activating custom client", module=submodule_name)
|
||||
|
||||
# import the submodule
|
||||
importlib.import_module(f".{submodule_name}", __package__)
|
||||
@@ -0,0 +1,5 @@
|
||||
Each client should be in its own subdirectory.
|
||||
|
||||
The subdirectory itself must be a valid python module.
|
||||
|
||||
Check out docs/dev/client/example/test for a very simplistic custom client example.
|
||||
332
src/talemate/client/google.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import vertexai
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from vertexai.generative_models import (
|
||||
ChatSession,
|
||||
GenerativeModel,
|
||||
ResponseValidationError,
|
||||
SafetySetting,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute
|
||||
from talemate.client.registry import register
|
||||
from talemate.client.remote import RemoteServiceMixin
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.util import count_tokens
|
||||
|
||||
__all__ = [
|
||||
"GoogleClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"gemini-1.0-pro",
|
||||
"gemini-1.5-pro-preview-0409",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "gemini-1.0-pro"
|
||||
disable_safety_settings: bool = False
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
disable_safety_settings: bool = False
|
||||
|
||||
|
||||
@register()
|
||||
class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
"""
|
||||
Google client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "google"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
decensor_enabled = True
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Google"
|
||||
title: str = "Google"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"disable_safety_settings": ExtraField(
|
||||
name="disable_safety_settings",
|
||||
type="bool",
|
||||
label="Disable Safety Settings",
|
||||
required=False,
|
||||
description="Disable Google's safety settings for responses generated by the model.",
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, model="gemini-1.0-pro", **kwargs):
|
||||
self.model_name = model
|
||||
self.setup_status = None
|
||||
self.model_instance = None
|
||||
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
|
||||
self.google_credentials_read = False
|
||||
self.google_project_id = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def google_credentials(self):
|
||||
path = self.google_credentials_path
|
||||
if not path:
|
||||
return None
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
@property
|
||||
def google_credentials_path(self):
|
||||
return self.config.get("google").get("gcloud_credentials_path")
|
||||
|
||||
@property
|
||||
def google_location(self):
|
||||
return self.config.get("google").get("gcloud_location")
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
# all google settings must be set
|
||||
return all(
|
||||
[
|
||||
self.google_credentials_path,
|
||||
self.google_location,
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def safety_settings(self):
|
||||
if not self.disable_safety_settings:
|
||||
return None
|
||||
|
||||
safety_settings = [
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
]
|
||||
|
||||
return safety_settings
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
ParameterReroute(talemate_parameter="max_tokens", client_parameter="max_output_tokens"),
|
||||
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop_sequences"),
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.ready:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "Setup incomplete"
|
||||
error_action = ErrorAction(
|
||||
title="Setup Google API credentials",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"google_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
data = {
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
}
|
||||
|
||||
self.populate_extra_fields(data)
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None, **kwargs):
|
||||
if not self.ready:
|
||||
log.error("Google cloud setup incomplete")
|
||||
if self.setup_status:
|
||||
self.setup_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "gemini-1.0-pro"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
if self.google_credentials_path:
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_credentials_path
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.setup_status:
|
||||
if self.setup_status is False:
|
||||
project_id = self.google_credentials.get("project_id")
|
||||
self.google_project_id = project_id
|
||||
if self.google_credentials_path:
|
||||
vertexai.init(project=project_id, location=self.google_location)
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.setup_status = True
|
||||
|
||||
self.model_instance = GenerativeModel(model_name=model)
|
||||
|
||||
log.info(
|
||||
"google set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return count_tokens(response.text)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
return count_tokens(prompt)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
if "disable_safety_settings" in kwargs:
|
||||
self.disable_safety_settings = kwargs["disable_safety_settings"]
|
||||
|
||||
def clean_prompt_parameters(self, parameters: dict):
|
||||
super().clean_prompt_parameters(parameters)
|
||||
|
||||
log.warning("clean_prompt_parameters", parameters=parameters)
|
||||
# if top_k is 0, remove it
|
||||
if "top_k" in parameters and parameters["top_k"] == 0:
|
||||
del parameters["top_k"]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.ready:
|
||||
raise Exception("Google cloud setup incomplete")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = prompt.strip()
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
disable_safety_settings=self.disable_safety_settings,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
chat = self.model_instance.start_chat()
|
||||
|
||||
response = await chat.send_message_async(
|
||||
human_message,
|
||||
safety_settings=self.safety_settings,
|
||||
generation_config=parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
response = response.text
|
||||
|
||||
log.debug("generated response", response=response)
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
|
||||
# except PermissionDeniedError as e:
|
||||
# self.log.error("generate error", e=e)
|
||||
# emit("status", message="google API: Permission Denied", status="error")
|
||||
# return ""
|
||||
except ResourceExhausted as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="google API: Quota Limit reached", status="error")
|
||||
return ""
|
||||
except ResponseValidationError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status",
|
||||
message="google API: Response Validation Error",
|
||||
status="error",
|
||||
)
|
||||
if not self.disable_safety_settings:
|
||||
return "Failed to generate response. Probably due to safety settings, you can turn them off in the client settings."
|
||||
return "Failed to generate response. Please check logs."
|
||||
except Exception as e:
|
||||
raise
|
||||
238
src/talemate/client/groq.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from groq import AsyncGroq, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"GroqClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"mixtral-8x7b-32768",
|
||||
"llama3-8b-8192",
|
||||
"llama3-70b-8192",
|
||||
]
|
||||
|
||||
JSON_OBJECT_RESPONSE_MODELS = []
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 8192
|
||||
model: str = "llama3-70b-8192"
|
||||
|
||||
|
||||
@register()
|
||||
class GroqClient(ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "groq"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = True
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Groq"
|
||||
title: str = "Groq"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="llama3-70b-8192", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def groq_api_key(self):
|
||||
return self.config.get("groq", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop"),
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.groq_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"groq_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.groq_api_key:
|
||||
self.client = AsyncGroq(api_key="sk-1111")
|
||||
log.error("No groq.ai API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "llama3-70b-8192"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncGroq(api_key=self.groq_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"groq.ai set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.completion_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.prompt_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.groq_api_key:
|
||||
raise Exception("No groq.ai API key set")
|
||||
|
||||
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
if expected_response.startswith("{") and supports_json_object:
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if (
|
||||
not supports_json_object
|
||||
and expected_response
|
||||
and expected_response.startswith("{")
|
||||
):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="OpenAI API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,16 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Union
|
||||
|
||||
import requests
|
||||
|
||||
import talemate.util as util
|
||||
from talemate.client.registry import register
|
||||
import talemate.client.system_prompts as system_prompts
|
||||
from talemate.client.textgenwebui import RESTTaleMateClient
|
||||
from talemate.emit import Emission, emit
|
||||
|
||||
# NOT IMPLEMENTED AT THIS POINT
|
||||
293
src/talemate/client/koboldcpp.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import random
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# import urljoin
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ParameterReroute
|
||||
from talemate.client.registry import register
|
||||
import talemate.util as util
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.agents.visual import VisualBase
|
||||
|
||||
log = structlog.get_logger("talemate.client.koboldcpp")
|
||||
|
||||
|
||||
class KoboldCppClientDefaults(Defaults):
|
||||
api_url: str = "http://localhost:5001"
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
@register()
|
||||
class KoboldCppClient(ClientBase):
|
||||
auto_determine_prompt_template: bool = True
|
||||
client_type = "koboldcpp"
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "KoboldCpp"
|
||||
title: str = "KoboldCpp"
|
||||
enable_api_auth: bool = True
|
||||
defaults: KoboldCppClientDefaults = KoboldCppClientDefaults()
|
||||
|
||||
@property
|
||||
def request_headers(self):
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
parts = urlparse(self.api_url)
|
||||
return f"{parts.scheme}://{parts.netloc}"
|
||||
|
||||
@property
|
||||
def is_openai(self) -> bool:
|
||||
"""
|
||||
kcpp has two apis
|
||||
|
||||
open-ai implementation at /v1
|
||||
their own implenation at /api/v1
|
||||
"""
|
||||
return "/api/v1" not in self.api_url
|
||||
|
||||
@property
|
||||
def api_url_for_model(self) -> str:
|
||||
if self.is_openai:
|
||||
# join /model to url
|
||||
return urljoin(self.api_url, "models")
|
||||
else:
|
||||
# join /models to url
|
||||
return urljoin(self.api_url, "model")
|
||||
|
||||
@property
|
||||
def api_url_for_generation(self) -> str:
|
||||
if self.is_openai:
|
||||
# join /v1/completions
|
||||
return urljoin(self.api_url, "completions")
|
||||
else:
|
||||
# join /api/v1/generate
|
||||
return urljoin(self.api_url, "generate")
|
||||
|
||||
@property
|
||||
def max_tokens_param_name(self):
|
||||
if self.is_openai:
|
||||
return "max_tokens"
|
||||
else:
|
||||
return "max_length"
|
||||
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
if not self.is_openai:
|
||||
# koboldcpp united api
|
||||
|
||||
return [
|
||||
ParameterReroute(talemate_parameter="max_tokens", client_parameter="max_length"),
|
||||
"max_context_length",
|
||||
ParameterReroute(talemate_parameter="repetition_penalty", client_parameter="rep_pen"),
|
||||
ParameterReroute(talemate_parameter="repetition_penalty_range", client_parameter="rep_pen_range"),
|
||||
"top_p",
|
||||
"top_k",
|
||||
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop_sequence"),
|
||||
"temperature",
|
||||
]
|
||||
|
||||
else:
|
||||
# openai api
|
||||
|
||||
return [
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"top_p",
|
||||
"temperature",
|
||||
]
|
||||
|
||||
def api_endpoint_specified(self, url: str) -> bool:
|
||||
return "/v1" in self.api_url
|
||||
|
||||
def ensure_api_endpoint_specified(self):
|
||||
if not self.api_endpoint_specified(self.api_url):
|
||||
# url doesn't specify the api endpoint
|
||||
# use the koboldcpp united api
|
||||
self.api_url = urljoin(self.api_url.rstrip("/") + "/", "/api/v1/")
|
||||
if not self.api_url.endswith("/"):
|
||||
self.api_url += "/"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.api_key = kwargs.pop("api_key", "")
|
||||
super().__init__(**kwargs)
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
async def get_model_name(self):
|
||||
self.ensure_api_endpoint_specified()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.api_url_for_model,
|
||||
timeout=2,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
raise KeyError(f"Could not find model info at: {self.api_url_for_model}")
|
||||
|
||||
response_data = response.json()
|
||||
if self.is_openai:
|
||||
# {"object": "list", "data": [{"id": "koboldcpp/dolphin-2.8-mistral-7b", "object": "model", "created": 1, "owned_by": "koboldcpp", "permission": [], "root": "koboldcpp"}]}
|
||||
model_name = response_data.get("data")[0].get("id")
|
||||
else:
|
||||
# {"result": "koboldcpp/dolphin-2.8-mistral-7b"}
|
||||
model_name = response_data.get("result")
|
||||
|
||||
# split by "/" and take last
|
||||
if model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
return model_name
|
||||
|
||||
async def tokencount(self, content:str) -> int:
|
||||
"""
|
||||
KoboldCpp has a tokencount endpoint we can use to count tokens
|
||||
for the prompt and response
|
||||
|
||||
If the endpoint is not available, we will use the default token count estimate
|
||||
"""
|
||||
|
||||
# extract scheme and host from api url
|
||||
|
||||
parts = urlparse(self.api_url)
|
||||
|
||||
url_tokencount = f"{parts.scheme}://{parts.netloc}/api/extra/tokencount"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url_tokencount,
|
||||
json={"prompt":content},
|
||||
timeout=None,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
# kobold united doesn't have tokencount endpoint
|
||||
return util.count_tokens(content)
|
||||
|
||||
tokencount = len(response.json().get("ids",[]))
|
||||
return tokencount
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
parameters["prompt"] = prompt.strip(" ")
|
||||
|
||||
self._returned_prompt_tokens = await self.tokencount(parameters["prompt"] )
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.api_url_for_generation,
|
||||
json=parameters,
|
||||
timeout=None,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
response_data = response.json()
|
||||
try:
|
||||
if self.is_openai:
|
||||
response_text = response_data["choices"][0]["text"]
|
||||
else:
|
||||
response_text = response_data["results"][0]["text"]
|
||||
except (TypeError, KeyError) as exc:
|
||||
log.error("Failed to generate text", exc=exc, response_data=response_data, response_status=response.status_code)
|
||||
response_text = ""
|
||||
|
||||
self._returned_response_tokens = await self.tokencount(response_text)
|
||||
return response_text
|
||||
|
||||
|
||||
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
|
||||
"""
|
||||
adjusts temperature and repetition_penalty
|
||||
by random values using the base value as a center
|
||||
"""
|
||||
|
||||
temp = prompt_config["temperature"]
|
||||
|
||||
if "rep_pen" in prompt_config:
|
||||
rep_pen_key = "rep_pen"
|
||||
elif "presence_penalty" in prompt_config:
|
||||
rep_pen_key = "presence_penalty"
|
||||
else:
|
||||
rep_pen_key = "repetition_penalty"
|
||||
|
||||
min_offset = offset * 0.3
|
||||
|
||||
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
|
||||
try:
|
||||
if rep_pen_key == "presence_penalty":
|
||||
presence_penalty = prompt_config["presence_penalty"]
|
||||
prompt_config["presence_penalty"] = round(random.uniform(
|
||||
presence_penalty + 0.1, presence_penalty + offset
|
||||
),1)
|
||||
else:
|
||||
rep_pen = prompt_config[rep_pen_key]
|
||||
prompt_config[rep_pen_key] = random.uniform(
|
||||
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if "api_key" in kwargs:
|
||||
self.api_key = kwargs.pop("api_key")
|
||||
|
||||
super().reconfigure(**kwargs)
|
||||
|
||||
|
||||
async def visual_automatic1111_setup(self, visual_agent:"VisualBase") -> bool:
|
||||
|
||||
"""
|
||||
Automatically configure the visual agent for automatic1111
|
||||
if the koboldcpp server has a SD model available
|
||||
"""
|
||||
|
||||
if not self.connected:
|
||||
return False
|
||||
|
||||
sd_models_url = urljoin(self.url, "/sdapi/v1/sd-models")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
||||
try:
|
||||
response = await client.get(
|
||||
url=sd_models_url, timeout=2
|
||||
)
|
||||
except Exception as exc:
|
||||
log.error(f"Failed to fetch sd models from {sd_models_url}", exc=exc)
|
||||
return False
|
||||
|
||||
if response.status_code != 200:
|
||||
return False
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
sd_model = response_data[0].get("model_name") if response_data else None
|
||||
|
||||
if not sd_model:
|
||||
return False
|
||||
|
||||
log.info("automatic1111_setup", sd_model=sd_model)
|
||||
|
||||
visual_agent.actions["automatic1111"].config["api_url"].value = self.url
|
||||
visual_agent.is_enabled = True
|
||||
return True
|
||||
|
||||
@@ -1,65 +1,63 @@
|
||||
import pydantic
|
||||
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.client.registry import register
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from talemate.client.base import ClientBase, ParameterReroute
|
||||
from talemate.client.registry import register
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url:str = "http://localhost:1234"
|
||||
api_url: str = "http://localhost:1234"
|
||||
max_token_length: int = 8192
|
||||
|
||||
|
||||
@register()
|
||||
class LMStudioClient(ClientBase):
|
||||
|
||||
auto_determine_prompt_template: bool = True
|
||||
client_type = "lmstudio"
|
||||
conversation_retries = 5
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix:str = "LMStudio"
|
||||
title:str = "LMStudio"
|
||||
defaults:Defaults = Defaults()
|
||||
name_prefix: str = "LMStudio"
|
||||
title: str = "LMStudio"
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
ParameterReroute(talemate_parameter="stopping_strings", client_parameter="stop"),
|
||||
]
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
|
||||
|
||||
def tune_prompt_parameters(self, parameters:dict, kind:str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
|
||||
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
|
||||
|
||||
async def get_model_name(self):
|
||||
model_name = await super().get_model_name()
|
||||
|
||||
# model name comes back as a file path, so we need to extract the model name
|
||||
# the path could be windows or linux so it needs to handle both backslash and forward slash
|
||||
|
||||
|
||||
if model_name:
|
||||
model_name = model_name.replace("\\", "/").split("/")[-1]
|
||||
|
||||
return model_name
|
||||
|
||||
async def generate(self, prompt:str, parameters:dict, kind:str):
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
human_message = {'role': 'user', 'content': prompt.strip()}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[human_message], **parameters
|
||||
)
|
||||
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
return ""
|
||||
return ""
|
||||
|
||||
259
src/talemate/client/mistral.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from mistralai.async_client import MistralAsyncClient
|
||||
from mistralai.exceptions import MistralAPIStatusException
|
||||
from mistralai.models.chat_completion import ChatMessage
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"MistralAIClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"open-mixtral-8x22b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
]
|
||||
|
||||
JSON_OBJECT_RESPONSE_MODELS = [
|
||||
"open-mixtral-8x22b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
]
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "open-mixtral-8x22b"
|
||||
|
||||
|
||||
@register()
|
||||
class MistralAIClient(ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "mistral"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = True
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "MistralAI"
|
||||
title: str = "MistralAI"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="open-mixtral-8x22b", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def mistralai_api_key(self):
|
||||
return self.config.get("mistralai", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.mistralai_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"mistralai_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.mistralai_api_key:
|
||||
self.client = MistralAsyncClient(api_key="sk-1111")
|
||||
log.error("No mistral.ai API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "open-mixtral-8x22b"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = MistralAsyncClient(api_key=self.mistralai_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"mistral.ai set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.completion_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.prompt_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def clean_prompt_parameters(self, parameters: dict):
|
||||
super().clean_prompt_parameters(parameters)
|
||||
# clamp temperature to 0.1 and 1.0
|
||||
# Unhandled Error: Status: 422. Message: {"object":"error","message":{"detail":[{"type":"less_than_equal","loc":["body","temperature"],"msg":"Input should be less than or equal to 1","input":1.31,"ctx":{"le":1.0},"url":"https://errors.pydantic.dev/2.6/v/less_than_equal"}]},"type":"invalid_request_error","param":null,"code":null}
|
||||
if "temperature" in parameters:
|
||||
parameters["temperature"] = min(1.0, max(0.1, parameters["temperature"]))
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.mistralai_api_key:
|
||||
raise Exception("No mistral.ai API key set")
|
||||
|
||||
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
if expected_response.startswith("{") and supports_json_object:
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="system", content=system_message),
|
||||
ChatMessage(role="user", content=prompt.strip()),
|
||||
]
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if (
|
||||
not supports_json_object
|
||||
and expected_response
|
||||
and expected_response.startswith("{")
|
||||
):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except MistralAPIStatusException as e:
|
||||
self.log.error("generate error", e=e)
|
||||
if e.http_status in [403, 401]:
|
||||
emit(
|
||||
"status",
|
||||
message="mistral.ai API: Permission Denied",
|
||||
status="error",
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,17 +1,24 @@
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import json
|
||||
import os
|
||||
import structlog
|
||||
import shutil
|
||||
import huggingface_hub
|
||||
import tempfile
|
||||
|
||||
import huggingface_hub
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
__all__ = ["model_prompt"]
|
||||
|
||||
BASE_TEMPLATE_PATH = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "templates", "llm-prompt"
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"..",
|
||||
"..",
|
||||
"..",
|
||||
"templates",
|
||||
"llm-prompt",
|
||||
)
|
||||
|
||||
# holds the default templates
|
||||
# holds the default templates
|
||||
STD_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "std")
|
||||
|
||||
# llm prompt templates provided by talemate
|
||||
@@ -22,137 +29,189 @@ USER_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "user")
|
||||
|
||||
TEMPLATE_IDENTIFIERS = []
|
||||
|
||||
|
||||
def register_template_identifier(cls):
|
||||
TEMPLATE_IDENTIFIERS.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
log = structlog.get_logger("talemate.model_prompts")
|
||||
|
||||
|
||||
class ModelPrompt:
|
||||
|
||||
"""
|
||||
Will attempt to load an LLM prompt template based on the model name
|
||||
|
||||
|
||||
If the model name is not found, it will default to the 'default' template
|
||||
"""
|
||||
|
||||
|
||||
template_map = {}
|
||||
|
||||
@property
|
||||
def env(self):
|
||||
if not hasattr(self, "_env"):
|
||||
log.info("modal prompt", base_template_path=BASE_TEMPLATE_PATH)
|
||||
self._env = Environment(loader=FileSystemLoader([
|
||||
USER_TEMPLATE_PATH,
|
||||
TALEMATE_TEMPLATE_PATH,
|
||||
]))
|
||||
|
||||
self._env = Environment(
|
||||
loader=FileSystemLoader(
|
||||
[
|
||||
USER_TEMPLATE_PATH,
|
||||
TALEMATE_TEMPLATE_PATH,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return self._env
|
||||
|
||||
|
||||
@property
|
||||
def std_templates(self) -> list[str]:
|
||||
env = Environment(loader=FileSystemLoader(STD_TEMPLATE_PATH))
|
||||
return sorted(env.list_templates())
|
||||
|
||||
def __call__(self, model_name:str, system_message:str, prompt:str):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
model_name: str,
|
||||
system_message: str,
|
||||
prompt: str,
|
||||
double_coercion: str = None,
|
||||
):
|
||||
template, template_file = self.get_template(model_name)
|
||||
if not template:
|
||||
template_file = "default.jinja2"
|
||||
template = self.env.get_template(template_file)
|
||||
|
||||
|
||||
if not double_coercion:
|
||||
double_coercion = ""
|
||||
|
||||
if "<|BOT|>" not in prompt and double_coercion:
|
||||
prompt = f"{prompt}<|BOT|>"
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
user_message, coercion_message = prompt.split("<|BOT|>", 1)
|
||||
coercion_message = f"{double_coercion}{coercion_message}"
|
||||
else:
|
||||
user_message = prompt
|
||||
coercion_message = ""
|
||||
|
||||
return template.render({
|
||||
"system_message": system_message,
|
||||
"prompt": prompt,
|
||||
"user_message": user_message,
|
||||
"coercion_message": coercion_message,
|
||||
"set_response" : self.set_response
|
||||
}), template_file
|
||||
|
||||
def set_response(self, prompt:str, response_str:str):
|
||||
|
||||
|
||||
return (
|
||||
template.render(
|
||||
{
|
||||
"system_message": system_message,
|
||||
"prompt": prompt.strip(),
|
||||
"user_message": user_message.strip(),
|
||||
"coercion_message": coercion_message,
|
||||
"set_response": lambda prompt, response_str: self.set_response(
|
||||
prompt, response_str, double_coercion
|
||||
),
|
||||
}
|
||||
),
|
||||
template_file,
|
||||
)
|
||||
|
||||
def set_response(self, prompt: str, response_str: str, double_coercion: str = None):
|
||||
prompt = prompt.strip("\n").strip()
|
||||
|
||||
|
||||
if not double_coercion:
|
||||
double_coercion = ""
|
||||
|
||||
if "<|BOT|>" not in prompt and double_coercion:
|
||||
prompt = f"{prompt}<|BOT|>"
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
|
||||
response_str = f"{double_coercion}{response_str}"
|
||||
|
||||
if "\n<|BOT|>" in prompt:
|
||||
prompt = prompt.replace("\n<|BOT|>", response_str)
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", response_str)
|
||||
else:
|
||||
prompt = prompt.rstrip("\n") + response_str
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
def get_template(self, model_name:str):
|
||||
def get_template(self, model_name: str):
|
||||
"""
|
||||
Will attempt to load an LLM prompt template - this supports
|
||||
partial filename matching on the template file name.
|
||||
"""
|
||||
|
||||
|
||||
matches = []
|
||||
|
||||
cleaned_model_name = model_name.replace("/", "__")
|
||||
|
||||
# Iterate over all templates in the loader's directory
|
||||
for template_name in self.env.list_templates():
|
||||
# strip extension
|
||||
template_name_match = os.path.splitext(template_name)[0]
|
||||
# Check if the model name is in the template filename
|
||||
if template_name_match.lower() in model_name.lower():
|
||||
if template_name_match.lower() in cleaned_model_name.lower():
|
||||
matches.append(template_name)
|
||||
|
||||
|
||||
# If there are no matches, return None
|
||||
if not matches:
|
||||
return None, None
|
||||
|
||||
|
||||
# If there is only one match, return it
|
||||
if len(matches) == 1:
|
||||
return self.env.get_template(matches[0]), matches[0]
|
||||
|
||||
|
||||
# If there are multiple matches, return the one with the longest name
|
||||
sorted_matches = sorted(matches, key=lambda x: len(x), reverse=True)
|
||||
return self.env.get_template(sorted_matches[0]), sorted_matches[0]
|
||||
|
||||
|
||||
def create_user_override(self, template_name:str, model_name:str):
|
||||
|
||||
|
||||
def create_user_override(self, template_name: str, model_name: str):
|
||||
"""
|
||||
Will copy STD_TEMPLATE_PATH/template_name to USER_TEMPLATE_PATH/model_name.jinja2
|
||||
"""
|
||||
|
||||
|
||||
template_name = template_name.split(".jinja2")[0]
|
||||
|
||||
|
||||
cleaned_model_name = model_name.replace("/", "__")
|
||||
|
||||
shutil.copyfile(
|
||||
os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
|
||||
os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
|
||||
os.path.join(USER_TEMPLATE_PATH, cleaned_model_name + ".jinja2"),
|
||||
)
|
||||
|
||||
return os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
|
||||
|
||||
def query_hf_for_prompt_template_suggestion(self, model_name:str):
|
||||
print("query_hf_for_prompt_template_suggestion", model_name)
|
||||
|
||||
return os.path.join(USER_TEMPLATE_PATH, cleaned_model_name + ".jinja2")
|
||||
|
||||
def query_hf_for_prompt_template_suggestion(self, model_name: str):
|
||||
api = huggingface_hub.HfApi()
|
||||
|
||||
|
||||
try:
|
||||
author, model_name = model_name.split("_", 1)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
models = list(api.list_models(
|
||||
filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
|
||||
))
|
||||
|
||||
|
||||
branch_name = "main"
|
||||
|
||||
# special popular cases
|
||||
|
||||
# bartowski
|
||||
|
||||
if author == "bartowski" and "exl2" in model_name:
|
||||
# split model_name by exl2 and take the first part with "exl2" readded
|
||||
# the second part is the branch name
|
||||
model_name, branch_name = model_name.split("exl2_", 1)
|
||||
model_name = f"{model_name}exl2"
|
||||
|
||||
models = list(api.list_models(model_name=model_name, author=author))
|
||||
|
||||
if not models:
|
||||
return None
|
||||
|
||||
|
||||
model = models[0]
|
||||
|
||||
|
||||
repo_id = f"{author}/{model_name}"
|
||||
|
||||
# Check README.md
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
readme_path = huggingface_hub.hf_hub_download(repo_id=repo_id, filename="README.md", cache_dir=tmpdir)
|
||||
readme_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="README.md",
|
||||
cache_dir=tmpdir,
|
||||
revision=branch_name,
|
||||
)
|
||||
if not readme_path:
|
||||
return None
|
||||
with open(readme_path) as f:
|
||||
@@ -162,25 +221,54 @@ class ModelPrompt:
|
||||
if identifier(readme):
|
||||
return f"{identifier.template_str}.jinja2"
|
||||
|
||||
# Check tokenizer_config.json
|
||||
# "chat_template" key
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="tokenizer_config.json",
|
||||
cache_dir=tmpdir,
|
||||
revision=branch_name,
|
||||
)
|
||||
if not config_path:
|
||||
return None
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
for identifer_cls in TEMPLATE_IDENTIFIERS:
|
||||
identifier = identifer_cls()
|
||||
if identifier(config.get("chat_template", "")):
|
||||
return f"{identifier.template_str}.jinja2"
|
||||
|
||||
|
||||
|
||||
model_prompt = ModelPrompt()
|
||||
|
||||
|
||||
class TemplateIdentifier:
|
||||
def __call__(self, content:str):
|
||||
def __call__(self, content: str):
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class Llama2Identifier(TemplateIdentifier):
|
||||
template_str = "Llama2"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
return "[INST]" in content and "[/INST]" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class Llama3Identifier(TemplateIdentifier):
|
||||
template_str = "Llama3"
|
||||
|
||||
def __call__(self, content: str):
|
||||
return "<|start_header_id|>" in content and "<|end_header_id|>" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class ChatMLIdentifier(TemplateIdentifier):
|
||||
template_str = "ChatML"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
<|im_start|>system
|
||||
{{ system_message }}<|im_end|>
|
||||
@@ -189,28 +277,63 @@ class ChatMLIdentifier(TemplateIdentifier):
|
||||
<|im_start|>assistant
|
||||
{{ coercion_message }}
|
||||
"""
|
||||
|
||||
|
||||
return "<|im_start|>" in content and "<|im_end|>" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class CommandRIdentifier(TemplateIdentifier):
|
||||
template_str = "CommandR"
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ system_message }}
|
||||
{{ user_message }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|>
|
||||
<|CHATBOT_TOKEN|>{{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"<|im_start|>system" in content
|
||||
and "<|im_end|>" in content
|
||||
and "<|im_start|>user" in content
|
||||
and "<|im_start|>assistant" in content
|
||||
"<|START_OF_TURN_TOKEN|>" in content
|
||||
and "<|END_OF_TURN_TOKEN|>" in content
|
||||
and "<|SYSTEM_TOKEN|>" not in content
|
||||
)
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class CommandRPlusIdentifier(TemplateIdentifier):
|
||||
template_str = "CommandRPlus"
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ system_message }}
|
||||
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ user_message }}
|
||||
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"<|START_OF_TURN_TOKEN|>" in content
|
||||
and "<|END_OF_TURN_TOKEN|>" in content
|
||||
and "<|SYSTEM_TOKEN|>" in content
|
||||
)
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class InstructionInputResponseIdentifier(TemplateIdentifier):
|
||||
template_str = "InstructionInputResponse"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
return (
|
||||
"### Instruction:" in content
|
||||
and "### Input:" in content
|
||||
and "### Response:" in content
|
||||
)
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class AlpacaIdentifier(TemplateIdentifier):
|
||||
template_str = "Alpaca"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
{{ system_message }}
|
||||
|
||||
@@ -220,20 +343,19 @@ class AlpacaIdentifier(TemplateIdentifier):
|
||||
### Response:
|
||||
{{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"### Instruction:" in content
|
||||
and "### Response:" in content
|
||||
)
|
||||
|
||||
|
||||
return "### Instruction:" in content and "### Response:" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class OpenChatIdentifier(TemplateIdentifier):
|
||||
template_str = "OpenChat"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
GPT4 Correct System: {{ system_message }}<|end_of_turn|>GPT4 Correct User: {{ user_message }}<|end_of_turn|>GPT4 Correct Assistant: {{ coercion_message }}
|
||||
"""
|
||||
|
||||
|
||||
return (
|
||||
"<|end_of_turn|>" in content
|
||||
and "GPT4 Correct System:" in content
|
||||
@@ -241,54 +363,51 @@ class OpenChatIdentifier(TemplateIdentifier):
|
||||
and "GPT4 Correct Assistant:" in content
|
||||
)
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class VicunaIdentifier(TemplateIdentifier):
|
||||
template_str = "Vicuna"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
SYSTEM: {{ system_message }}
|
||||
USER: {{ user_message }}
|
||||
ASSISTANT: {{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"SYSTEM:" in content
|
||||
and "USER:" in content
|
||||
and "ASSISTANT:" in content
|
||||
)
|
||||
|
||||
|
||||
return "SYSTEM:" in content and "USER:" in content and "ASSISTANT:" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class USER_ASSISTANTIdentifier(TemplateIdentifier):
|
||||
template_str = "USER_ASSISTANT"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
USER: {{ system_message }} {{ user_message }} ASSISTANT: {{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"USER:" in content
|
||||
and "ASSISTANT:" in content
|
||||
)
|
||||
|
||||
|
||||
return "USER:" in content and "ASSISTANT:" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class UserAssistantIdentifier(TemplateIdentifier):
|
||||
template_str = "UserAssistant"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
User: {{ system_message }} {{ user_message }}
|
||||
Assistant: {{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"User:" in content
|
||||
and "Assistant:" in content
|
||||
)
|
||||
|
||||
return "User:" in content and "Assistant:" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class ZephyrIdentifier(TemplateIdentifier):
|
||||
template_str = "Zephyr"
|
||||
def __call__(self, content:str):
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
<|system|>
|
||||
{{ system_message }}</s>
|
||||
@@ -297,9 +416,9 @@ class ZephyrIdentifier(TemplateIdentifier):
|
||||
<|assistant|>
|
||||
{{ coercion_message }}
|
||||
"""
|
||||
|
||||
|
||||
return (
|
||||
"<|system|>" in content
|
||||
and "<|user|>" in content
|
||||
and "<|assistant|>" in content
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,21 +1,46 @@
|
||||
import json
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import tiktoken
|
||||
from openai import AsyncOpenAI, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.config import load_config
|
||||
import structlog
|
||||
import tiktoken
|
||||
|
||||
__all__ = [
|
||||
"OpenAIClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o-2024-05-13",
|
||||
"gpt-4o",
|
||||
]
|
||||
|
||||
# any model starting with gpt-4- is assumed to support 'json_object'
|
||||
# for others we need to explicitly state the model name
|
||||
JSON_OBJECT_RESPONSE_MODELS = [
|
||||
"gpt-4o",
|
||||
"gpt-3.5-turbo-0125",
|
||||
]
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
@@ -66,9 +91,11 @@ def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length:int = 16384
|
||||
model:str = "gpt-4-turbo-preview"
|
||||
max_token_length: int = 16384
|
||||
model: str = "gpt-4o"
|
||||
|
||||
|
||||
@register()
|
||||
class OpenAIClient(ClientBase):
|
||||
@@ -79,35 +106,37 @@ class OpenAIClient(ClientBase):
|
||||
client_type = "openai"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix:str = "OpenAI"
|
||||
title:str = "OpenAI"
|
||||
manual_model:bool = True
|
||||
manual_model_choices:list[str] = [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-turbo-preview",
|
||||
]
|
||||
requires_prompt_template: bool = False
|
||||
defaults:Defaults = Defaults()
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
def __init__(self, model="gpt-4-turbo-preview", **kwargs):
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "OpenAI"
|
||||
title: str = "OpenAI"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="gpt-4o", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai",{}).get("api_key")
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
@@ -127,13 +156,13 @@ class OpenAIClient(ClientBase):
|
||||
arguments=[
|
||||
"application",
|
||||
"openai_api",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
@@ -145,25 +174,27 @@ class OpenAIClient(ClientBase):
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length:int=None):
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.openai_api_key:
|
||||
self.client = AsyncOpenAI(api_key="sk-1111")
|
||||
log.error("No OpenAI API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit('request_client_status')
|
||||
emit('request_agent_status')
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "gpt-3.5-turbo-16k"
|
||||
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
|
||||
self.client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
if model == "gpt-3.5-turbo":
|
||||
self.max_token_length = min(max_token_length or 4096, 4096)
|
||||
@@ -175,16 +206,20 @@ class OpenAIClient(ClientBase):
|
||||
self.max_token_length = min(max_token_length or 128000, 128000)
|
||||
else:
|
||||
self.max_token_length = max_token_length or 2048
|
||||
|
||||
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit('request_client_status')
|
||||
emit('request_agent_status')
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info("openai set client", max_token_length=self.max_token_length, provided_max_token_length=max_token_length, model=model)
|
||||
|
||||
|
||||
log.info(
|
||||
"openai set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
@@ -203,69 +238,78 @@ class OpenAIClient(ClientBase):
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
|
||||
def prompt_template(self, system_message:str, prompt:str):
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
# only gpt-4-1106-preview supports json_object response coersion
|
||||
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters:dict, kind:str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def generate(self, prompt:str, parameters:dict, kind:str):
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
|
||||
if not self.openai_api_key:
|
||||
raise Exception("No OpenAI API key set")
|
||||
|
||||
|
||||
# only gpt-4-* supports enforcing json object
|
||||
supports_json_object = self.model_name.startswith("gpt-4-")
|
||||
supports_json_object = (
|
||||
self.model_name.startswith("gpt-4-")
|
||||
or self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
)
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nContinue this response: ")
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
if expected_response.startswith("{") and supports_json_object:
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = {'role': 'user', 'content': prompt.strip()}
|
||||
system_message = {'role': 'system', 'content': self.get_system_message(kind)}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
|
||||
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
system_message = {"role": "system", "content": self.get_system_message(kind)}
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[system_message, human_message], **parameters
|
||||
model=self.model_name,
|
||||
messages=[system_message, human_message],
|
||||
**parameters,
|
||||
)
|
||||
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if (
|
||||
not supports_json_object
|
||||
and expected_response
|
||||
and expected_response.startswith("{")
|
||||
):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right):].strip()
|
||||
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="OpenAI API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
raise
|
||||
|
||||
@@ -1,101 +1,151 @@
|
||||
import urllib
|
||||
import random
|
||||
import pydantic
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.client.registry import register
|
||||
import structlog
|
||||
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
|
||||
|
||||
from openai import AsyncOpenAI, PermissionDeniedError, NotFoundError
|
||||
from talemate.client.base import ClientBase, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
from talemate.emit import emit
|
||||
|
||||
log = structlog.get_logger("talemate.client.openai_compat")
|
||||
|
||||
EXPERIMENTAL_DESCRIPTION = """Use this client if you want to connect to a service implementing an OpenAI-compatible API. Success is going to depend on the level of compatibility. Use the actual OpenAI client if you want to connect to OpenAI's API."""
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url:str = "http://localhost:5000"
|
||||
api_key:str = ""
|
||||
max_token_length:int = 4096
|
||||
model:str = ""
|
||||
api_url: str = "http://localhost:5000"
|
||||
api_key: str = ""
|
||||
max_token_length: int = 8192
|
||||
model: str = ""
|
||||
api_handles_prompt_template: bool = False
|
||||
double_coercion: str = None
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
api_handles_prompt_template: bool = False
|
||||
|
||||
|
||||
@register()
|
||||
class OpenAICompatibleClient(ClientBase):
|
||||
|
||||
client_type = "openai_compat"
|
||||
conversation_retries = 5
|
||||
conversation_retries = 0
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
title:str = "OpenAI Compatible API"
|
||||
name_prefix:str = "OpenAI Compatible API"
|
||||
experimental:str = EXPERIMENTAL_DESCRIPTION
|
||||
enable_api_auth:bool = True
|
||||
manual_model:bool = True
|
||||
defaults:Defaults = Defaults()
|
||||
|
||||
def __init__(self, model=None, **kwargs):
|
||||
title: str = "OpenAI Compatible API"
|
||||
name_prefix: str = "OpenAI Compatible API"
|
||||
experimental: str = EXPERIMENTAL_DESCRIPTION
|
||||
enable_api_auth: bool = True
|
||||
manual_model: bool = True
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"api_handles_prompt_template": ExtraField(
|
||||
name="api_handles_prompt_template",
|
||||
type="bool",
|
||||
label="API handles prompt template (chat/completions)",
|
||||
required=False,
|
||||
description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored. This is not recommended and should only be used if the API does not support the `completions` andpoint or you don't know which prompt template to use.",
|
||||
)
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
|
||||
):
|
||||
self.model_name = model
|
||||
self.api_key = api_key
|
||||
self.api_handles_prompt_template = api_handles_prompt_template
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return EXPERIMENTAL_DESCRIPTION
|
||||
|
||||
@property
|
||||
def can_be_coerced(self):
|
||||
"""
|
||||
Determines whether or not his client can pass LLM coercion. (e.g., is able
|
||||
to predefine partial LLM output in the prompt)
|
||||
"""
|
||||
return not self.api_handles_prompt_template
|
||||
|
||||
@property
|
||||
def supported_parameters(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"max_tokens",
|
||||
]
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.api_key = kwargs.get("api_key")
|
||||
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key=self.api_key)
|
||||
self.model_name = kwargs.get("model") or kwargs.get("model_name") or self.model_name
|
||||
|
||||
def tune_prompt_parameters(self, parameters:dict, kind:str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
self.api_handles_prompt_template = kwargs.get(
|
||||
"api_handles_prompt_template", self.api_handles_prompt_template
|
||||
)
|
||||
url = self.api_url
|
||||
self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
|
||||
self.model_name = (
|
||||
kwargs.get("model") or kwargs.get("model_name") or self.model_name
|
||||
)
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
|
||||
log.debug(
|
||||
"IS API HANDLING PROMPT TEMPLATE",
|
||||
api_handles_prompt_template=self.api_handles_prompt_template,
|
||||
)
|
||||
|
||||
if not self.api_handles_prompt_template:
|
||||
return super().prompt_template(system_message, prompt)
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
async def get_model_name(self):
|
||||
|
||||
try:
|
||||
model_name = await super().get_model_name()
|
||||
except NotFoundError as e:
|
||||
# api does not implement model listing
|
||||
return self.model_name
|
||||
except Exception as e:
|
||||
self.log.error("get_model_name error", e=e)
|
||||
return self.model_name
|
||||
return self.model_name
|
||||
|
||||
# model name may be a file path, so we need to extract the model name
|
||||
# the path could be windows or linux so it needs to handle both backslash and forward slash
|
||||
|
||||
is_filepath = "/" in model_name
|
||||
is_filepath_windows = "\\" in model_name
|
||||
|
||||
if is_filepath or is_filepath_windows:
|
||||
model_name = model_name.replace("\\", "/").split("/")[-1]
|
||||
|
||||
return model_name
|
||||
|
||||
async def generate(self, prompt:str, parameters:dict, kind:str):
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
human_message = {'role': 'user', 'content': prompt.strip()}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
|
||||
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[human_message], **parameters
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
if self.api_handles_prompt_template:
|
||||
# OpenAI API handles prompt template
|
||||
# Use the chat completions endpoint
|
||||
self.log.debug("generate (chat/completions)", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[human_message], **parameters
|
||||
)
|
||||
response = response.choices[0].message.content
|
||||
return self.process_response_for_indirect_coercion(prompt, response)
|
||||
else:
|
||||
# Talemate handles prompt template
|
||||
# Use the completions endpoint
|
||||
self.log.debug("generate (completions)", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
parameters["prompt"] = prompt
|
||||
response = await self.client.completions.create(
|
||||
model=self.model_name, **parameters
|
||||
)
|
||||
return response.choices[0].text
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="Client API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="Error during generation (check logs)", status="error")
|
||||
emit(
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
return ""
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
@@ -104,8 +154,40 @@ class OpenAICompatibleClient(ClientBase):
|
||||
if "api_url" in kwargs:
|
||||
self.api_url = kwargs["api_url"]
|
||||
if "max_token_length" in kwargs:
|
||||
self.max_token_length = kwargs["max_token_length"]
|
||||
self.max_token_length = (
|
||||
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
|
||||
)
|
||||
if "api_key" in kwargs:
|
||||
self.api_auth = kwargs["api_key"]
|
||||
self.api_key = kwargs["api_key"]
|
||||
if "api_handles_prompt_template" in kwargs:
|
||||
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
|
||||
# TODO: why isn't this calling super()?
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
if "double_coercion" in kwargs:
|
||||
self.double_coercion = kwargs["double_coercion"]
|
||||
|
||||
log.warning("reconfigure", kwargs=kwargs)
|
||||
|
||||
self.set_client(**kwargs)
|
||||
|
||||
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
|
||||
"""
|
||||
adjusts temperature and presence penalty
|
||||
by random values using the base value as a center
|
||||
"""
|
||||
|
||||
temp = prompt_config["temperature"]
|
||||
|
||||
min_offset = offset * 0.3
|
||||
|
||||
self.set_client(**kwargs)
|
||||
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
|
||||
|
||||
try:
|
||||
presence_penalty = prompt_config["presence_penalty"]
|
||||
prompt_config["presence_penalty"] = round(random.uniform(
|
||||
presence_penalty + 0.1, presence_penalty + offset
|
||||
),1)
|
||||
except KeyError:
|
||||
pass
|
||||