Compare commits

...

30 Commits

Author SHA1 Message Date
veguAI
1874234d2c Prep 0.25.1 (#103)
* remove auto client disable

* 0.25.1
2024-05-05 23:23:30 +03:00
veguAI
ef99539e69 Update README.md 2024-05-05 22:30:24 +03:00
veguAI
39bd02722d 0.25.0 (#100)
* flip title and name in recent scenes

* fix issue where a message could not be regenerated after applying continuity error fixes

* prompt tweaks

* allow json parameters for commands

* autocomplete improvements

* dialogue cleanup fixes

* fix issue with narrate after dialogue and llama3 (and other models that don't have a line break after the user prompt in their prompt template)

* expose ability to auto generate dialogue instructions to wsm character ux

* use b64_json response type

* move tag checks up so they match first

* fix typo

* prompt tweak

* api key support

* prompt tweaks

* editable parameters in prompt debugger / tester

* allow resetting of prompt params

* codemirror for prompt editor

* prompt tweaks

* more prompt debug tool tweaks

* some extra control for `context_history`

* new analytical preset (testing)

* add `join` and `llm_can_be_coerced` to jinja env

* support factual list summaries

* prompt tweaks to continuity check and fix

* new summarization method `facts` exposed to ux

* clamp mistral ai temperature according to their new requirements

* prompt tweaks

* better parsing of fixed dialogue response

* prompt tweaks

* fix intermittent empty meta issue

* history regen status progression and small ux tweaks

* summary entries should always be condensed

* google gemini support

* relock to install google-cloud-aiplatform for vertex ai inference

* fix instruction link

* better error handling of google safety validation and allow disabling of safety validation

* docs

* clarify credentials path requirements

* tweak error line identification

* handle quota limit error

* autocomplete ux wired to assistant plugin instead of command

* autocomplete narrative editing and fixes to autocomplete during dialog edit

* main input autocomplete tweaks

* allow new lines in main input

* 0.25.0 and relock

* fix issue with autocomplete elsewhere locking out main input

* better way to determine remote service

* prompt tweak

* fix rubberbanding issue when editing character attributes

* add open mistral 8x22

* fix continuity error check summary inclusion of target entry

* docs

* default context length to 8192

* linting
2024-05-05 22:16:03 +03:00
veguAI
f0b627b900 Update README.md 2024-04-27 00:46:39 +03:00
veguAI
95ae00e01f 0.24.0 (#97)
* groq client

* adjust max token length

* more openai image download fixes

* graphic novel style

* dialogue cleanup

* fix issue where auto-break repetition would trigger on empty responses

* reduce default convo retries to 1

* prompt tweaks

* fix some clients not handling autocomplete well

* screenplay dialogue generation tweaks

* message flags

* better cleanup of redundant change_ai_character calls

* super experimental continuity error fix mode for editor agent

* clamp temperature

* tweaks to continuity error fixing and expose to ux

* expose to ux

* allow CmdFixContinuityErrors to work even if editor has check_continuity_errors disabled

* prompt tweak

* support --endofline-- as well

* double coercion client option added

* fix issue with double coercion inserting "None" if not set

* client ux refactor to make room for coercion config

* rest of -- can be treated as *

* disable double coercion when json coercion is active since it kills accuracy

* prompt tweaks

* prompt tweaks

* show coercion status in client list

* change preset for edit_fix_continuity

* interim commit of continuity error handling progress

* tag based presets

* special tokens to keep trailing whitespace if needed

* fix continuity errors finalized for now

* change double coercion formatting

* 0.24.0 and relock

* add groq and cohere to supported services

* linting
2024-04-27 00:24:53 +03:00
veguAI
83027b3a0f 0.23.0 (#91)
* dockerfiles and docker-compose

* containerization fixes

* docker instructions

* readme

* readme

* dont mount src by default, readme

* hf template determine fixes

* auto determine prompt template

* script to start talemate listening only to 127.0.0.1

* prompt tweaks

* auto narrate round every 3 rounds

* tweaks

* Add return to startscreen button

* Only show return to start screen button if scene is active

* improvements to character creation

* dedicated property for scene title separate from the save directory name

* filter out negations into negative keywords

* increase auto narrate delay

* add character portrait keyword

* summarization should ignore most recent message, as it is often regenerated.

* cohere client

* specify python3

* improve viable runpod text gen detection

* fix formatting in template preview

* cohere command-r plus template that I am not sure is correct or not

* mistral client set to decensor

* fix issue with parsing json responses

* command-r prompts updated

* use official mistralai python client

* send max_tokens

* new input autocomplete functionality

* prompt tweaks

* llama 3 templates

* add <|eot_id|> to stopping strings

* prompt tweak

* tooltip

* llama-3 identifier

* command-r and command-r plus prompt identifiers

* text-gen-webui client tweaks to make llama3 eos tokens work correctly

* better llama-3 detection

* better llama-3 finalizing of parameters

* streamline client prompt finalizers
reduce Yi model smoothing factor from 0.3 to 0.1 for text-generation-webui client

* relock

* linting

* set 0.23.0

* add new gpt-4 models

* set 0.23.0

* add note about connecting to text-gen-webui from docker

* fix openai image generation no longer working

* default to concept_art
2024-04-20 01:01:06 +03:00
veguAI
27eba3bd63 0.22.0 2024-03-29 21:41:45 +02:00
veguAI
ba64050eab 0.22.0 (#89)
* linux dev instance shortcuts

* add voice samples to gitignore

* direction mode: inner monologue

* actor direction fixes

* py script support for scene logic

* fix end_simulation call

* port sim suite logic to python

* remove dupe log

* fix typing

* section off the text

* fix end simulation command

* simulation goal, prompt tweaks

* prompt tweaks

* dialogue format improvements

* director action logged with message

* call director action log and other fixes

* generate character dialogue instructions, prompt fixes, director action ux

* fix question / answer call

* generate dialogue instructions when loading from character cards

* more dialogue format improvements

* set scene content context more reliably.

* fix innermonologue perspective

* conversation prompt should honor the client's decensor setting

* fix comfyui checkpoint list not loading

* more dialogue format fixes

* prompt tweaks

* fix sim suite group characters, prompt fixes

* npm relock

* handle inanimate objects, handle player name change issues

* don't rename details if the original name was "You"

* As the conversation goes on, dialogue instructions should be moved further back to have a weaker effect on immediate generations.

* add more context to character creation prompt

* fix select next talking actor when natural language flow is turned on and the LLM returns multiple character names

* prompt fixes for dialogue generation

* summarization fixes

* default to script format

* separate dialogue prompt by formatting style, tweak conversation system prompt

* remove cruft

* add gen format to agent details

* relock

* relock

* prep 0.22.0

* add claude-3-haiku-20240307

* readme
2024-03-29 21:37:28 +02:00
veguAI
199ffd1095 Update README.md 2024-03-17 01:09:59 +02:00
veguAI
88b9fcb8bb Update README.md 2024-03-11 00:42:42 +02:00
vegu-ai-tools
2f5944bc09 remove unnecessary link 2024-03-10 18:05:33 +02:00
veguAI
abdfb1abbf WIP: Prep 0.21.0 (#83)
* cleanup

* refactor clean_dialogue

* prompt fixes

* prompt fixes

* conversation format types - movie script and chat (legacy)

* stopping strings updated

* mistral.ai client

* prompt tweaks

* mistral client return token counts

* anthropic client

* archive history emits whole object so we can inspect time stamps

* show timestamp in history dialog

* openai compat fixes to stop trying to coerce openai url path schema and to never attempt to retrieve the model name automatically, hopefully improving compatibility with the various openai api implementations across the board

* openai compat client let api control prompt template via config option

* fix custom client configs and implement max backscroll

* fix backscroll limit

* remove debug message

* prep 0.21.0

* include model name in prompt template selection label

* use tabs for side nav in app config modal

* readme / docs

* fix issue where "No API key set" could be persisted as the selected model name to the config

* deepinfra example

* linting
2024-03-10 18:03:12 +02:00
veguAI
2f07248211 Prep 0.20.0 (#77)
* fix issue where recent save cover images would sometimes not load

* paraphrase prompt tweaks

* action_to_narration regenerate compatibility fixes

* sim suite add asnwer question instruction

* more sim suite tweaks

* refactor agent details display in agent bar

* visual agent progress (a1111 support)

* visual gen prompt tweaks

* openai compat client pass max_tokens

* world state sequential reinforcement max tokens tightened

* improve item names

* Improve item names

* attempt to remove "changed from.." notes when altering an existing character sheet

* prompt improvements for single character portraits

* visual agent progress

* fix issue where character.update wouldn't update long-term memory

* remove experimental flag for now

* add better instructions for updating existing character sheet

* background processing for agents, visual and tts

* fix selected voice not saving between restarts for elevenlabs

* lessen timeout

* clean up agent status logic

* conditional agent configs

* comfyui support

* visualization queue

* refactor visual styles, comfyui progress

* regen images
auto cover image assign
websocket handler plugin abstraction
agent websocket handler

* automatic1111 fixes
agent status and ready checks

* tweaks to character portrait prompt

* system prompt for visualize

* textgenwebui use temp smoothing on yi models

* comment out api key for now

* fixes issues with openai compat client for retaining api key and auto fixing urls

* update_reinforcment tweaks

* agent status emit from one place

* emit agent status as asyncio task

* remove debug output

* tts add openai support

* openai img gen support

* fix issue with comfyui checkbox list not loading

* tts model selection for openai

* narrate_query include character sheet if character is referenced in query
improve visual character portrait generation prompt

* client implementation extra field support and runpod vllm client example

* relock

* fix issue where changing context length would cause next generation to error

* visual agent tweaks and auto gen character cover image in sim suite

* fix issue with readiness lock when there weren't any clients defined

* load scene readiness fixes

* linting

* docs

* notes for the runpod vllm example
2024-02-16 13:57:45 +02:00
veguAI
9ae6fc822b Update README.md 2024-02-12 18:31:49 +02:00
veguAI
5094359c4e Update README.md 2024-02-10 23:07:30 +02:00
veguAI
28801b54bf Update README.md 2024-02-07 03:12:56 +02:00
veguAI
4d69f0e837 Update README.md 2024-02-06 09:15:55 +02:00
veguAI
d91b3f8042 Update README.md 2024-02-06 09:15:11 +02:00
veguAI
03a0ab2fcf Update README.md 2024-02-06 01:01:00 +02:00
veguAI
d860d62972 Update README.md 2024-02-06 01:00:35 +02:00
veguAI
add4893939 Prep 0.19.0 (#67)
* linting

* improve prompt devtools: test changes, show more information

* some more polish for the new prompt devtools

* up default conversation gen length to 128

* openai client tweaks, talemate sets max_tokens on gpt-3.5 generations

* support new openai embeddings (and default to text-embedding-3-small)

* ux polish for character sheet and character state ux

* actor instructions

* experiment using # for context / instructions

* fix bug where regenerating history would mess up time stamps

* remove trailing ]

* prevent client ctx from being unset

* fix issue where sometimes you'd need to delete a client twice for it to disappear

* upgrade dependencies

* set 0.19.0

* fix performance degradation caused by circular loading animation

* remove coqui studio support

* fix issue when switching from unsaved creative mode to loading a scene

* third party client / agent support

* edit dialogue examples through character / actor editor

* remove "edit dialogue" action from editor - replaced by character actor instructions

* different icon for delete

* prompt adjustment for acting instructions

* adhoc context generation for character attributes and details

* add adhoc generation for character description

* contextual generation tweaks

* contextual generation for dialogue examples
fix some formatting issues

* contextual generation for world entries

* prepopulate initial recent scenarios with demo scenes
add experimental holodeck scenario

* scene info
scene experimental

* assortment of fixes for holodeck improvements

* more holodeck fixes

* refactor holodeck instructions

* rename holodeck to simulation suite

* better scene status messages

* add new gpt-3.5-turbo model, better json response coercion for older models

* allow exclusion of characters when persisting based on world state

* better error handling of world state response

* better error handling of world state response

* more simulation suite fixes

* progress color

* world state character name mapping support

* if neither quote nor asterisk is in message default to quotes

* fix rerun of new paraphrase op

* sim suite ping that ensures characters are not aware of sim

* fixes for better character name assessment
simulation suite can now give the player character a proper name

* fix bug with new status notifications

* sim suite adjustments and fixes and tuning

* sim suite tweaks

* impl scene restore from file

* prompting tweaks for reinforcement messages and acting instructions

* more tweaks

* dialogue prompt tweaks for rerun + rewrite

* fix bug with character entry / exit with narration

* linting

* simsuite screenshots

* screenshots
2024-02-06 00:40:55 +02:00
veguAI
eb251d6e37 fix gpt-4 censorship triggered by system message (#74) 2024-02-01 12:15:30 +02:00
veguAI
4ba635497b Prep 0.18.1 (#72)
* prevent client ctx from being unset

* fix issue with LMStudio client ctx size not sticking

* 0.18.1
2024-01-31 09:46:51 +02:00
veguAI
bdbf14c1ed Update README.md 2024-01-31 01:47:52 +02:00
veguAI
c340fc085c Update README.md 2024-01-31 01:47:29 +02:00
veguAI
94f8d0f242 Update README.md 2024-01-31 01:00:59 +02:00
veguAI
1d8a9b113c Update README.md 2024-01-30 08:08:45 +02:00
vegu-ai-tools
1837796852 readme 2024-01-26 14:41:59 +02:00
vegu-ai-tools
c5c53c056e readme updates 2024-01-26 13:29:21 +02:00
veguAI
f1b1190f0b linting (#63) 2024-01-26 12:46:55 +02:00
257 changed files with 21276 additions and 8141 deletions

1
.gitignore vendored

@@ -16,3 +16,4 @@ scenes/
!scenes/infinity-quest-dynamic-scenario/infinity-quest.json
!scenes/infinity-quest/assets/
!scenes/infinity-quest/infinity-quest.json
tts_voice_samples/*.wav

25
Dockerfile.backend Normal file

@@ -0,0 +1,25 @@
# Use an official Python runtime as a parent image
FROM python:3.11-slim

# Set the working directory in the container
WORKDIR /app

# Copy the current directory contents into the container at /app
COPY ./src /app/src

# Copy poetry files
COPY pyproject.toml /app/

# If there's a poetry lock file, include the following line
COPY poetry.lock /app/

# Install poetry
RUN pip install poetry

# Install dependencies
RUN poetry install --no-dev

# Make port 5050 available to the world outside this container
EXPOSE 5050

# Run backend server
CMD ["poetry", "run", "python", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050"]

17
Dockerfile.frontend Normal file

@@ -0,0 +1,17 @@
# Use an official node runtime as a parent image
FROM node:20

# Set the working directory in the container
WORKDIR /app

# Copy the frontend directory contents into the container at /app
COPY ./talemate_frontend /app

# Install any needed packages specified in package.json
RUN npm install

# Make port 8080 available to the world outside this container
EXPOSE 8080

# Run frontend server
CMD ["npm", "run", "serve"]

271
README.md

@@ -1,70 +1,66 @@
# Talemate
Allows you to play roleplay scenarios with large language models.
Roleplay with AI with a focus on strong narration and consistent world and game state tracking.
|![Screenshot 1](docs/img/0.17.0/ss-1.png)|![Screenshot 2](docs/img/0.17.0/ss-2.png)|
|![Screenshot 3](docs/img/0.17.0/ss-1.png)|![Screenshot 3](docs/img/0.17.0/ss-2.png)|
|------------------------------------------|------------------------------------------|
|![Screenshot 1](docs/img/0.17.0/ss-4.png)|![Screenshot 2](docs/img/0.17.0/ss-3.png)|
|![Screenshot 4](docs/img/0.17.0/ss-4.png)|![Screenshot 1](docs/img/0.19.0/Screenshot_15.png)|
|![Screenshot 2](docs/img/0.19.0/Screenshot_16.png)|![Screenshot 3](docs/img/0.19.0/Screenshot_17.png)|
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio.**
Supported APIs:
- [OpenAI](https://platform.openai.com/overview)
- [Anthropic](https://www.anthropic.com/)
- [mistral.ai](https://mistral.ai/)
- [Cohere](https://www.cohere.com/)
- [Groq](https://www.groq.com/)
- [Google Gemini](https://console.cloud.google.com/)
This means you need to either have:
- an [OpenAI](https://platform.openai.com/overview) api key
- OR setup local (or remote via runpod) LLM inference via one of these options:
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
- [LMStudio](https://lmstudio.ai/)
Supported self-hosted APIs:
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
- [LMStudio](https://lmstudio.ai/)
## Current features
Generic OpenAI api implementations (tested and confirmed working):
- [DeepInfra](https://deepinfra.com/)
- [llamacpp](https://github.com/ggerganov/llama.cpp) with the `api_like_OAI.py` wrapper
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
- responsive modern ui
- agents
- conversation: handles character dialogue
- narration: handles narrative exposition
- summarization: handles summarization to compress context while maintaining history
- director: can be used to direct the story / characters
- editor: improves AI responses (very hit and miss at the moment)
- world state: generates world snapshot and handles passage of time (objects and characters)
- creator: character / scenario creator
- tts: text to speech via elevenlabs, coqui studio, coqui local
- multi-client support (agents can be connected to separate APIs)
- long term memory
- chromadb integration
- passage of time
- narrative world state
- Automatically keep track and reinforce selected character and world truths / states.
- narrative tools
- creative tools
- AI backed character creation with template support (jinja2)
- AI backed scenario creation
- context management
- Manage character details and attributes
- Manage world information / past events
- Pin important information to the context (Manually or conditionally through AI)
- runpod integration
- overridable templates for all prompts. (jinja2)
## Core Features
## Planned features
- Multiple AI agents for dialogue, narration, summarization, direction, editing, world state management, character/scenario creation, text-to-speech, and visual generation
- Support for multiple AI clients and APIs
- Long-term memory using ChromaDB and passage of time tracking
- Narrative world state management to reinforce character and world truths
- Creative tools for managing NPCs, plus AI-assisted character and scenario creation with template support
- Context management for character details, world information, past events, and pinned information
- Integration with Runpod
- Customizable templates for all prompts using Jinja2
- Modern, responsive UI
Kinda making it up as I go along, but I want to lean more into gameplay through AI, keeping track of game states, moving away from simply roleplaying towards a more game-ified experience.
# Instructions
In no particular order:
Please read the documents in the `docs` folder for more advanced configuration and usage.
- Extension support
- modular agents and clients
- Improved world state
- Dynamic player choice generation
- Better creative tools
- node based scenario / character creation
- Improved and consistent long term memory and accurate current state of the world
- Improved director agent
- Right now this doesn't really work well on anything but GPT-4 (and even there it's debatable). It tends to steer the story in a way that introduces pacing issues. It needs a model that is creative but also reasons really well, I think.
- Gameplay loop governed by AI
- objectives
- quests
- win / lose conditions
- stable-diffusion client for in place visual generation
- [Quickstart](#quickstart)
- [Installation](#installation)
- [Windows](#windows)
- [Linux](#linux)
- [Docker](#docker)
- [Connecting to an LLM](#connecting-to-an-llm)
- [OpenAI / mistral.ai / Anthropic](#openai--mistralai--anthropic)
- [Text-generation-webui / LMStudio](#text-generation-webui--lmstudio)
- [Specifying the correct prompt template](#specifying-the-correct-prompt-template)
- [Recommended Models](#recommended-models)
- [DeepInfra via OpenAI Compatible client](#deepinfra-via-openai-compatible-client)
- [Google Gemini](#google-gemini)
- [Google Cloud Setup](#google-cloud-setup)
- [Ready to go](#ready-to-go)
- [Load the introductory scenario "Infinity Quest"](#load-the-introductory-scenario-infinity-quest)
- [Loading character cards](#loading-character-cards)
- [Text-to-Speech (TTS)](docs/tts.md)
- [Visual Generation](docs/visual.md)
- [ChromaDB (long term memory) configuration](docs/chromadb.md)
- [Runpod Integration](docs/runpod.md)
- [Prompt template overrides](docs/templates.md)
# Quickstart
@@ -77,7 +73,7 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
### Windows
1. Download and install Python 3.10 or Python 3.11 from the [official Python website](https://www.python.org/downloads/windows/). :warning: python3.12 is currently not supported.
1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm.
1. Download and install Node.js v20 from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm. :warning: v21 is currently not supported.
1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
@@ -87,70 +83,153 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
`python 3.10` or `python 3.11` is required. :warning: `python 3.12` not supported yet.
1. `git clone git@github.com:vegu-ai/talemate`
`nodejs v19 or v20` :warning: `v21` not supported yet.
1. `git clone https://github.com/vegu-ai/talemate.git`
1. `cd talemate`
1. `source install.sh`
1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.
## Configuration
### Docker
### OpenAI
1. `git clone https://github.com/vegu-ai/talemate.git`
1. `cd talemate`
1. `docker-compose up`
1. Navigate your browser to http://localhost:8080
To set your openai api key, open `config.yaml` in any text editor and uncomment / add
:warning: When connecting to local APIs running on the host machine (e.g. text-generation-webui), you need to use `host.docker.internal` as the hostname.
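For example, if text-generation-webui is running on the host with its OpenAI-compatible API enabled, the client's API URL inside the container would look like `http://host.docker.internal:5000` (the port here is an assumption; use whatever port your local API actually listens on).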
```yaml
openai:
  api_key: sk-my-api-key-goes-here
```
#### To shut down the Docker container
You will need to restart the backend for this change to take effect.
Just closing the terminal window will not stop the Docker container. You need to run `docker-compose down` to stop the container.
### RunPod
#### How to install Docker
To set your runpod api key, open `config.yaml` in any text editor and uncomment / add
1. Download and install Docker Desktop from the [official Docker website](https://www.docker.com/products/docker-desktop).
```yaml
runpod:
  api_key: my-api-key-goes-here
```
You will need to restart the backend for this change to take effect.
Once the api key is set, pods loaded from text-generation-webui templates (or the bloke's runpod llm template) will be automatically added to your client list in talemate.
**ATTENTION**: Talemate is not a suitable way for you to determine whether your pod is currently running or not. **Always** check the runpod dashboard to see if your pod is running.
## Recommended Models
(as of 2023.10.25)
Any of the top models in any of the size classes here should work well:
https://www.reddit.com/r/LocalLLaMA/comments/17fhp9k/huge_llm_comparisontest_39_models_tested_7b70b/
## Connecting to an LLM
# Connecting to an LLM
On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
![Client options](docs/img/client-options-toggle.png)
### Text-generation-webui
![No clients](docs/img/0.21.0/no-clients.png)
## OpenAI / mistral.ai / Anthropic
The setup is the same for all three; the example below is for OpenAI.
If you want to add an OpenAI client, just change the client type and select the appropriate model.
![Add client modal](docs/img/0.21.0/openai-setup.png)
If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.
![OpenAI API Key missing](docs/img/0.18.0/openai-api-key-1.png)
Click the `SET API KEY` button. This will open a modal where you can enter your API key.
![OpenAI API Key missing](docs/img/0.21.0/openai-add-api-key.png)
Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.
![OpenAI API Key set](docs/img/0.18.0/openai-api-key-3.png)
## Text-generation-webui / LMStudio
> :warning: As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported; please use their new `--extension openai` api implementation instead.
In the modal, if you're planning to connect to text-generation-webui, you can likely leave everything as is and just click Save.
![Add client modal](docs/img/client-setup-0.13.png)
![Add client modal](docs/img/0.21.0/text-gen-webui-setup.png)
### OpenAI
### Specifying the correct prompt template
If you want to add an OpenAI client, just change the client type and select the appropriate model.
For good results it is **vital** that the correct prompt template is specified for whichever model you have loaded.
![Add client modal](docs/img/add-client-modal-openai.png)
Talemate does come with a set of pre-defined templates for some popular models, but going forward, due to the sheer number of models released every day, understanding and specifying the correct prompt template is something you should familiarize yourself with.
### Ready to go
If the text-gen-webui client shows a yellow triangle next to it, it means that the prompt template is not set, and it is currently using the default `VICUNA` style prompt template.
![Default prompt template](docs/img/0.21.0/prompt-template-default.png)
Click the two cogwheels to the right of the triangle to open the client settings.
![Client settings](docs/img/0.21.0/select-prompt-template.png)
You can first try clicking the `DETERMINE VIA HUGGINGFACE` button; depending on the model's README file, it may be able to determine the correct prompt template for you (basically, the readme needs to contain an example of the template).
If that doesn't work, you can manually select the prompt template from the dropdown.
In the case of `bartowski_Nous-Hermes-2-Mistral-7B-DPO-exl2_8_0` that is `ChatML` - select it from the dropdown and click `Save`.
![Client settings](docs/img/0.21.0/selected-prompt-template.png)
### Recommended Models
As of 2024.05.06 my personal regular drivers (the ones I test with) are:
- meta-llama_Meta-Llama-3-8B-Instruct
- brucethemoose_Yi-34B-200K-RPMerge
- rAIfle_Verdict-8x7B
- meta-llama_Meta-Llama-3-70B-Instruct
That said, any of the top models in any of the size classes here should work well (I wouldn't recommend going lower than 7B):
[https://oobabooga.github.io/benchmark.html](https://oobabooga.github.io/benchmark.html)
## DeepInfra via OpenAI Compatible client
You can use the OpenAI compatible client to connect to [DeepInfra](https://deepinfra.com/).
![DeepInfra](docs/img/0.21.0/deepinfra-setup.png)
```
API URL: https://api.deepinfra.com/v1/openai
```
Models on DeepInfra that work well with Talemate:
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://deepinfra.com/mistralai/Mixtral-8x7B-Instruct-v0.1) (max context 32k, 8k recommended)
- [cognitivecomputations/dolphin-2.6-mixtral-8x7b](https://deepinfra.com/cognitivecomputations/dolphin-2.6-mixtral-8x7b) (max context 32k, 8k recommended)
- [lizpreciatior/lzlv_70b_fp16_hf](https://deepinfra.com/lizpreciatior/lzlv_70b_fp16_hf) (max context 4k)
## Google Gemini
### Google Cloud Setup
Unlike the other clients, the setup for Google Gemini is a bit more involved, as you will need to set up a google cloud project and credentials for it.
Please follow their [instructions for setup](https://cloud.google.com/vertex-ai/docs/start/client-libraries) - which includes setting up a project, enabling the Vertex AI API, creating a service account, and downloading the credentials.
Once you have downloaded the credentials, copy the JSON file into the talemate directory. You can rename it to something that's easier to remember, like `my-credentials.json`.
### Add the client
![Google Gemini](docs/img/0.25.0/google-add-client.png)
The `Disable Safety Settings` option will turn off the google response validation for what they consider harmful content. Use at your own risk.
### Complete the google cloud setup in talemate
![Google Gemini](docs/img/0.25.0/google-setup-incomplete.png)
Click the `SETUP GOOGLE API CREDENTIALS` button that will appear on the client.
The google cloud setup modal will appear; fill in the path to the credentials file and select a location that is close to you.
![Google Gemini](docs/img/0.25.0/google-cloud-setup.png)
Click save and after a moment the client should have a green dot next to it, indicating that it is ready to go.
![Google Gemini](docs/img/0.25.0/google-ready.png)
## Ready to go
You will know you are good to go when the client and all the agents have a green dot next to them.
![Ready to go](docs/img/client-setup-complete.png)
![Ready to go](docs/img/0.21.0/ready-to-go.png)
## Load the introductory scenario "Infinity Quest"
@@ -171,13 +250,3 @@ Expand the "Load" menu in the top left corner and either click on "Upload a char
Once a character is uploaded, talemate may actually take a moment because it needs to convert it to a talemate format and will also run additional LLM prompts to generate character attributes and world state.
Make sure you save the scene after the character is loaded, as it can then be loaded as a normal talemate scenario in the future.
## Further documentation
Please read the documents in the `docs` folder for more advanced configuration and usage.
- [Prompt template overrides](docs/templates.md)
- [Text-to-Speech (TTS)](docs/tts.md)
- [ChromaDB (long term memory)](docs/chromadb.md)
- [Runpod Integration](docs/runpod.md)
- Creative mode


@@ -48,6 +48,7 @@ game:
# embeddings: instructor
# instructor_device: cuda
# instructor_model: hkunlp/instructor-xl
# openai_model: text-embedding-3-small
## Remote LLMs

27
docker-compose.yml Normal file

@@ -0,0 +1,27 @@
version: '3.8'

services:
  talemate-backend:
    build:
      context: .
      dockerfile: Dockerfile.backend
    ports:
      - "5050:5050"
    volumes:
      # can uncomment for dev purposes
      #- ./src/talemate:/app/src/talemate
      - ./config.yaml:/app/config.yaml
      - ./scenes:/app/scenes
      - ./templates:/app/templates
      - ./chroma:/app/chroma
    environment:
      - PYTHONUNBUFFERED=1

  talemate-frontend:
    build:
      context: .
      dockerfile: Dockerfile.frontend
    ports:
      - "8080:8080"
    volumes:
      - ./talemate_frontend:/app
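
For orientation, the typical lifecycle with this compose file (standard docker-compose commands, nothing Talemate-specific assumed):

```bash
docker-compose up           # build on first run, then start backend and frontend
docker-compose up --build   # force a rebuild after dependency changes
docker-compose down         # stop and remove the containers (closing the terminal does not)
```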


@@ -56,6 +56,7 @@ Then add the following to `config.yaml` for chromadb:
```yaml
chromadb:
  embeddings: openai
  openai_model: text-embedding-3-small
```
**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. ALSO - if you send super explicit content they may flag / ban your key, so keep that in mind (i hear they usually send warnings first though), and always monitor your usage on their dashboard.
**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. Always monitor your usage on their dashboard.


@@ -0,0 +1,48 @@
from talemate.agents.base import Agent, AgentAction
from talemate.agents.registry import register
from talemate.events import GameLoopEvent
import talemate.emit.async_signals
from talemate.emit import emit


@register()
class TestAgent(Agent):
    agent_type = "test"
    verbose_name = "Test"

    def __init__(self, client):
        self.client = client
        self.is_enabled = True
        self.actions = {
            "test": AgentAction(
                enabled=True,
                label="Test",
                description="Test",
            ),
        }

    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return True

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

    async def on_game_loop(self, emission: GameLoopEvent):
        """
        Called at the beginning of every game loop
        """
        if not self.enabled:
            return

        emit("status", status="info", message="Annoying you with a test message every game loop.")


@@ -0,0 +1,130 @@
"""
An attempt to write a client against the runpod serverless vllm worker.
This is close to functional, but since runpod serverless gpu availability is currently terrible, i have
been unable to properly test it.
Putting it here for now since i think it makes a decent example of how to write a client against a new service.
"""
import pydantic
import structlog
import runpod
import asyncio
import aiohttp
from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
from talemate.emit import emit
from talemate.config import Client as BaseClientConfig
log = structlog.get_logger("talemate.client.runpod_vllm")
class Defaults(pydantic.BaseModel):
max_token_length: int = 4096
model: str = ""
runpod_id: str = ""
class ClientConfig(BaseClientConfig):
runpod_id: str = ""
@register()
class RunPodVLLMClient(ClientBase):
client_type = "runpod_vllm"
conversation_retries = 5
config_cls = ClientConfig
class Meta(ClientBase.Meta):
title: str = "Runpod VLLM"
name_prefix: str = "Runpod VLLM"
enable_api_auth: bool = True
manual_model: bool = True
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = {
"runpod_id": ExtraField(
name="runpod_id",
type="text",
label="Runpod ID",
required=True,
description="The Runpod ID to connect to.",
)
}
def __init__(self, model=None, runpod_id=None, **kwargs):
self.model_name = model
self.runpod_id = runpod_id
super().__init__(**kwargs)
@property
def experimental(self):
return False
def set_client(self, **kwargs):
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
return self.model_name
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
prompt = prompt.strip()
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
async with aiohttp.ClientSession() as session:
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
run_request = await endpoint.run({
"input": {
"prompt": prompt,
}
#"parameters": parameters
})
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
status = await run_request.status()
log.debug("generate", status=status)
await asyncio.sleep(0.1)
status = await run_request.status()
log.debug("generate", status=status)
response = await run_request.output()
log.debug("generate", response=response)
return response["choices"][0]["tokens"][0]
except Exception as e:
self.log.error("generate error", e=e)
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
if "runpod_id" in kwargs:
self.api_auth = kwargs["runpod_id"]
log.warning("reconfigure", kwargs=kwargs)
self.set_client(**kwargs)


@@ -0,0 +1,67 @@
import pydantic
from openai import AsyncOpenAI

from talemate.client.base import ClientBase
from talemate.client.registry import register


class Defaults(pydantic.BaseModel):
    api_url: str = "http://localhost:1234"
    max_token_length: int = 4096


@register()
class TestClient(ClientBase):
    client_type = "test"

    class Meta(ClientBase.Meta):
        name_prefix: str = "test"
        title: str = "Test"
        defaults: Defaults = Defaults()

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """
        Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.
        This method is called before the prompt is sent to the client, and it allows the client to remove
        any parameters that it doesn't support.
        """
        super().tune_prompt_parameters(parameters, kind)
        keys = list(parameters.keys())
        valid_keys = ["temperature", "top_p"]
        for key in keys:
            if key not in valid_keys:
                del parameters[key]

    async def get_model_name(self):
        """
        This should return the name of the model that is being used.
        """
        return "Mock test model"

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """
        human_message = {"role": "user", "content": prompt.strip()}
        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[human_message], **parameters
            )
            return response.choices[0].message.content
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""

[28 new binary image files added; contents not shown]


@@ -17,21 +17,6 @@ elevenlabs:
api_key: <YOUR_ELEVENLABS_API_KEY>
```
## Configuring Coqui TTS
To use Coqui TTS with Talemate, follow these steps:
1. Visit [Coqui](https://app.coqui.ai) and sign up for an account.
2. Go to the [account page](https://app.coqui.ai/account) and scroll to the bottom to find your API key.
3. In the `config.yaml` file, under the `coqui` section, set the `api_key` field with your Coqui API key.
Example configuration snippet:
```yaml
coqui:
api_key: <YOUR_COQUI_API_KEY>
```
## Configuring Local TTS API
For running a local TTS API, Talemate requires specific dependencies to be installed.

117
docs/visual.md Normal file

@@ -0,0 +1,117 @@
# Visual Agent
The visual agent currently allows for some bare bones visual generation using various stable-diffusion APIs. This is early development and experimental.
It's important to note that the visualization agent actually specifies two clients. One is the backend for the visual generation, and the other is the text generation client to use for prompt generation.
The client for prompt generation can be assigned to the agent as you would for any other agent. The client for visual generation is assigned in the Visualizer config.
## Index
- [OpenAI](#openai)
- [AUTOMATIC1111](#automatic1111)
- [ComfyUI](#comfyui)
- [How to use](#how-to-use)
## OpenAI
Most straightforward to use, as it runs on the OpenAI API. You will need to have an API key and set it in the application config.
![Set OpenAI Api Key](img/0.18.0/openai-api-key-2.png)
Then open the Visualizer config by clicking the agent's name in the agent list and choose `OpenAI` as the backend.
![OpenAI Visualizer Config](img/0.20.0/visual-config-openai.png)
Note: `Client` here refers to the text-generation client to use for prompt generation. While `Backend` refers to the visual generation backend. You are **NOT** required to use the OpenAI client for prompt generation even if you are using the OpenAI backend for image generation.
## AUTOMATIC1111
This requires you to setup a local instance of the AUTOMATIC1111 API. Follow the instructions from their [GitHub](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to get it running.
Once you have it running, you will want to adjust the `webui-user.bat` in the AUTOMATIC1111 directory to include the following command arguments:
```bat
set COMMANDLINE_ARGS=--api --listen --port 7861
```
Then run the `webui-user.bat` to start the API.
Once your AUTOMATIC1111 API is running (check with your browser) you can set the Visualizer config to use the `AUTOMATIC1111` backend.
![AUTOMATIC1111 Visualizer Config](img/0.20.0/visual-config-a1111.png)
#### Extra Configuration
- `api url`: the url of the API, usually `http://localhost:7861`
- `steps`: render steps
- `model type`: sdxl or sd1.5 - this will dictate the resolution of the image generation and actually matters for the quality, so make sure this is set to the correct model type for the model you are using.
## ComfyUI
This requires you to setup a local instance of the ComfyUI API. Follow the instructions from their [GitHub](https://github.com/comfyanonymous/ComfyUI) to get it running.
Once you're set up, copy their `start.bat` file to a new `start-listen.bat` file and change the contents to:
```bat
call venv\Scripts\activate
call python main.py --port 8188 --listen 0.0.0.0
```
Then run the `start-listen.bat` to start the API.
Once your ComfyUI API is running (check with your browser) you can set the Visualizer config to use the `ComfyUI` backend.
![ComfyUI Visualizer Config](img/0.20.0/visual-config-comfyui.png)
### Extra Configuration
- `api url`: the url of the API, usually `http://localhost:8188`
- `workflow`: the workflow file to use. This is a comfyui api workflow file that needs to exist in `./templates/comfyui-workflows` inside the talemate directory. Talemate provides two very barebones workflows with `default-sdxl.json` and `default-sd15.json`. You can create your own workflows and place them in this directory to use them. :warning: The workflow file must be generated using the API Workflow export, not the UI export. Please refer to their documentation for more information.
- `checkpoint`: the model to use - this will load a list of all available models in your comfyui instance. Select which one you want to use for the image generation.
### Custom Workflows
When creating custom workflows for ideal compatibility with Talemate, ensure the following:
- A `CheckpointLoaderSimple` node named `Talemate Load Checkpoint`
- An `EmptyLatentImage` node named `Talemate Resolution`
- A `ClipTextEncode` node named `Talemate Positive Prompt`
- A `ClipTextEncode` node named `Talemate Negative Prompt`
- A `SaveImage` node at the end of the workflow.
![ComfyUI Base workflow example](img/0.20.0/comfyui-base-workflow.png)
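
Presumably the node titles are how Talemate locates the inputs it needs to rewrite (checkpoint, resolution, prompts), so they should match the list above exactly.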
## How to use
Once you're done setting up, the visualizer agent should have a green dot next to it and display both the selected image generation backend and the selected prompt generation client.
![Visualizer ready](img/0.20.0/visualizer-ready.png)
Your hotbar should then also enable the visualization menu for you to use (once you have a scene loaded).
![Visualization actions](img/0.20.0/visualize-scene-tools.png)
Right now you can generate a portrait for any NPC in the scene or a background image for the scene itself.
Image generation by default will actually happen in the background, allowing you to continue using Talemate while the image is being generated.
You can tell if an image is being generated by the blueish spinner next to the visualization agent.
![Visualizer busy](img/0.20.0/visualizer-busy.png)
Once the image is generated, it will be available for you to view via the visual queue button at the top of the screen.
![Images ready](img/0.20.0/visualze-new-images.png)
Click it to open the visual queue and view the generated images.
![Visual queue](img/0.20.0/visual-queue.png)
### Character Portrait
For character portraits you can choose whether or not to replace the main portrait for the character (the one displayed in the left sidebar when a talemate scene is active).
### Background Image
Right now there is nothing to do with the background image, other than to view it in the visual queue. More functionality will be added in the future.


@@ -1,7 +1,7 @@
#!/bin/bash
# create a virtual environment
python -m venv talemate_env
python3 -m venv talemate_env
# activate the virtual environment
source talemate_env/bin/activate

3931
poetry.lock generated

File diff suppressed because it is too large


@@ -4,13 +4,13 @@ build-backend = "poetry.masonry.api"
[tool.poetry]
name = "talemate"
version = "0.18.0"
version = "0.25.1"
description = "AI-backed roleplay and narrative tools"
authors = ["FinalWombat"]
license = "GNU Affero General Public License v3.0"
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
python = ">=3.10,<3.12"
astroid = "^2.8"
jedi = "^0.18"
black = "*"
@@ -18,9 +18,13 @@ rope = "^0.22"
isort = "^5.10"
jinja2 = "^3.0"
openai = ">=1"
mistralai = ">=0.1.8"
cohere = ">=5.2.2"
anthropic = ">=0.19.1"
groq = ">=0.5.0"
requests = "^2.26"
colorama = ">=0.4.6"
Pillow = "^9.5"
Pillow = ">=9.5"
httpx = "<1"
piexif = "^1.1"
typing-inspect = "0.8.0"
@@ -33,17 +37,20 @@ python-dotenv = "^1.0.0"
websockets = "^11.0.3"
structlog = "^23.1.0"
runpod = "^1.2.0"
google-cloud-aiplatform = ">=1.50.0"
nest_asyncio = "^1.5.7"
isodate = ">=0.6.1"
thefuzz = ">=0.20.0"
tiktoken = ">=0.5.1"
nltk = ">=3.8.1"
huggingface-hub = ">=0.20.2"
RestrictedPython = ">7.1"
# ChromaDB
chromadb = ">=0.4.17,<1"
InstructorEmbedding = "^1.0.1"
torch = ">=2.1.0"
torchaudio = ">=2.3.0"
sentence-transformers="^2.2.2"
[tool.poetry.dev-dependencies]
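
Note that the tightened `python = ">=3.10,<3.12"` constraint lines up with the README's warning above that python 3.12 is not yet supported.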

[binary image file added; contents not shown]


@@ -98,6 +98,7 @@
}
],
"immutable_save": true,
"experimental": true,
"goal": null,
"goals": [],
"context": "an epic sci-fi adventure aimed at an adult audience.",
@@ -109,10 +110,10 @@
"variables": {}
},
"assets": {
"cover_image": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
"cover_image": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
"assets": {
"52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df": {
"id": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
"e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404": {
"id": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
"file_type": "png",
"media_type": "image/png"
}


@@ -5,7 +5,7 @@
{%- set _ = emit_system("warning", "This is a dynamic scenario generation experiment for Infinity Quest. It will likely require a strong LLM to generate something coherent. GPT-4 or 34B+ if local. Temper your expectations.") -%}
{#- emit status update to the UX -#}
{%- set _ = emit_status("busy", "Generating scenario ... [1/3]") -%}
{%- set _ = emit_status("busy", "Generating scenario ... [1/3]", as_scene_message=True) -%}
{#- thematic tags will be used to randomize generation -#}
{%- set tags = thematic_generator.generate("color", "state_of_matter", "scifi_trope") -%}
@@ -17,17 +17,17 @@
{#- generate introductory text -#}
{%- set _ = emit_status("busy", "Generating scenario ... [2/3]") -%}
{%- set _ = emit_status("busy", "Generating scenario ... [2/3]", as_scene_message=True) -%}
{%- set tmpl__scenario_intro = render_template('generate-scenario-intro', premise=instr__premise) %}
{%- set instr__intro = "*"+render_and_request(tmpl__scenario_intro)+"*" -%}
{#- generate win conditions -#}
{%- set _ = emit_status("busy", "Generating scenario ... [3/3]") -%}
{%- set _ = emit_status("busy", "Generating scenario ... [3/3]", as_scene_message=True) -%}
{%- set tmpl__win_conditions = render_template('generate-win-conditions', premise=instr__premise) %}
{%- set instr__win_conditions = render_and_request(tmpl__win_conditions) -%}
{#- emit status update to the UX -#}
{%- set status = emit_status("info", "Scenario ready.") -%}
{%- set status = emit_status("success", "Scenario ready.", as_scene_message=True) -%}
{# set gamestate variables #}
{%- set _ = game_state.set_var("instr.premise", instr__premise, commit=True) -%}

[binary image file added; contents not shown]


@@ -0,0 +1,535 @@
def game(TM):
MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"
MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."
PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice."
CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
AUTO_NARRATE_INTERVAL = 10
def parse_sim_call_arguments(call:str) -> str:
"""
Returns the value between the parentheses of a simulation call
Example:
call = 'change_environment("a house")'
parse_sim_call_arguments(call) -> "a house"
"""
try:
return call.split("(", 1)[1].split(")")[0]
except Exception:
return ""
class SimulationSuite:
def __init__(self):
# do we update the world state at the end of the round
self.update_world_state = False
self.simulation_reset = False
self.added_npcs = []
TM.log.debug("SIMULATION SUITE INIT...")
self.player_character = TM.scene.get_player_character()
self.player_message = TM.scene.last_player_message()
self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
self.player_message_is_instruction = (
self.player_message and
self.player_message.raw.lower().startswith("computer") and
not self.player_message.hidden and
not self.last_processed_call > self.player_message.id
)
def run(self):
if not TM.game_state.has_var("instr.simulation_stopped"):
self.simulation()
self.finalize_round()
def simulation(self):
if not TM.game_state.has_var("instr.simulation_started"):
self.startup()
else:
self.simulation_calls()
if self.update_world_state:
self.run_update_world_state(force=True)
def startup(self):
TM.emit_status("busy", "Simulation suite powering up.", as_scene_message=True)
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=PROMPT_STARTUP,
emit_message=False
)
TM.agents.narrator.action_to_narration(
action_name="passthrough",
narration=MSG_HELP
)
TM.agents.world_state.manager(
action_name="save_world_entry",
entry_id="sim.quarantined",
text=CTX_PIN_UNAWARE,
meta={},
pin=True
)
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
TM.emit_status("success", "Simulation suite ready", as_scene_message=True)
self.update_world_state = True
def simulation_calls(self):
"""
Calls the simulation suite main prompt to determine the appropriate
simulation calls
"""
if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
return
# First instruction?
if not TM.game_state.has_var("instr.has_issued_instructions"):
# determine the context of the simulation
context_context = TM.agents.creator.determine_content_context_for_description(
description=self.player_message.raw,
)
TM.scene.set_content_context(context_context)
calls = TM.client.render_and_request(
"computer",
dedupe_enabled=False,
player_instruction=self.player_message.raw,
scene=TM.scene,
)
self.calls = calls = calls.split("\n")
calls = self.prepare_calls(calls)
TM.log.debug("SIMULATION SUITE CALLS", callse=calls)
# calls that are processed
processed = []
for call in calls:
processed_call = self.process_call(call)
if processed_call:
processed.append(processed_call)
if processed:
TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
TM.emit_status("busy", "Simulation suite altering environment.", as_scene_message=True)
compiled = "\n".join(processed)
if not self.simulation_reset and compiled:
narration = TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
emit_message=True
)
# on the first narration we update the scene description and remove any mention of the computer
# or the simulation from the previous narration
is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
if not is_initial_narration:
TM.scene.set_description(str(narration))
TM.scene.set_intro(str(narration))
TM.log.debug("SIMULATION SUITE: initial narration", intro=str(narration))
TM.scene.pop_history(typ="narrator", all=True, reverse=True)
TM.scene.pop_history(typ="director", all=True, reverse=True)
TM.game_state.set_var("instr.intro_narration", True, commit=False)
self.update_world_state = True
self.set_simulation_title(compiled)
def set_simulation_title(self, compiled_calls):
"""
Generates a fitting title for the simulation based on the user's instructions
"""
TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
if not compiled_calls:
return
if TM.scene.title != "Simulation Suite":
# name already changed, no need to do it again
return
title = TM.agents.creator.contextual_generate_from_args(
"scene:simulation title",
"Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
length=75
)
title = title.strip('"').strip()
TM.scene.set_title(title)
def prepare_calls(self, calls):
"""
Loops through the calls; if both a `set_player_name` call and a `set_player_persona` call
are found, ensures that the `set_player_name` call is processed first by moving it in front
of the `set_player_persona` call.
"""
set_player_name_call_index = -1
set_player_persona_call_index = -1
for i, call in enumerate(calls):
if "set_player_name" in call:
set_player_name_call_index = i
elif "set_player_persona" in call:
set_player_persona_call_index = i
if set_player_name_call_index > -1 and set_player_persona_call_index > -1:
if set_player_name_call_index > set_player_persona_call_index:
calls.insert(set_player_persona_call_index, calls.pop(set_player_name_call_index))
TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_name call", calls=calls)
return calls
def process_call(self, call:str) -> str:
"""
Processes a simulation call.
Simulation calls are pseudo functions invoked by the simulation suite.
We grab the function name by splitting on "(" and taking the first element.
If the SimulationSuite has a method named call_{function_name}, we call it.
If a function name could be found but there is no matching method, we do nothing,
but we still return the call as processed, since the AI can interpret it later on.
"""
if "(" not in call:
return None
function_name = call.split("(")[0]
if hasattr(self, f"call_{function_name}"):
TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
inject = f"The computer executes the function `{call}`"
return getattr(self, f"call_{function_name}")(call, inject)
return call
def call_set_simulation_goal(self, call:str, inject:str) -> str:
"""
Sets the simulation goal as a permanent pin
"""
TM.emit_status("busy", "Simulation suite setting goal.", as_scene_message=True)
TM.agents.world_state.manager(
action_name="save_world_entry",
entry_id="sim.goal",
text=self.player_message.raw,
meta={},
pin=True
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer sets the goal for the simulation.",
)
return call
def call_change_environment(self, call:str, inject:str) -> str:
"""
Simulation changes the environment. This is entirely interpreted by the AI,
so we don't need any logic on our end and just return the call.
"""
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer changes the environment of the simulation."
)
return call
def call_answer_question(self, call:str, inject:str) -> str:
"""
The player asked the simulation a question; we need to process this and have
the AI produce an answer.
"""
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
emit_message=True
)
def call_set_player_persona(self, call:str, inject:str) -> str:
"""
The simulation suite is altering the player persona
"""
TM.emit_status("busy", "Simulation suite altering user persona.", as_scene_message=True)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=self.player_character.name, text=inject, alteration_instructions=self.player_message.raw
)
self.player_character.update(base_attributes=character_attributes)
character_description = TM.agents.creator.determine_character_description(character=self.player_character)
self.player_character.update(description=character_description)
TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer transforms the player persona."
)
return call
def call_set_player_name(self, call:str, inject:str) -> str:
"""
The simulation suite is altering the player name
"""
TM.emit_status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
if character_name != self.player_character.name:
self.player_character.rename(character_name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer changes the player's identity to {character_name}."
)
return call
def call_add_ai_character(self, call:str, inject:str) -> str:
# sometimes the AI will call this function and pass an inanimate object as the parameter;
# we need to determine if this is the case and just ignore it
is_inanimate = TM.client.query_text_eval(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
if is_inanimate:
TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
return
# sometimes the call adds a group of characters rather than a single one;
# we need to determine if this is the case
adds_group = TM.client.query_text_eval(f"does the function `{call}` add MULTIPLE ai characters?", call)
TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)
TM.emit_status("busy", "Simulation suite adding character.", as_scene_message=True)
if not adds_group:
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.")
else:
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
# sometimes add_ai_character and change_ai_character are called in the same instruction,
# targeting the same character; if this happens, we combine them into a single add_ai_character call
has_change_ai_character_call = TM.client.query_text_eval(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
if has_change_ai_character_call:
combined_arg = TM.client.render_and_request(
"combine-add-and-alter-ai-character",
dedupe_enabled=False,
calls="\n".join(self.calls),
character_name=character_name,
scene=TM.scene,
).replace("COMBINED ARGUMENT:", "").strip()
call = f"add_ai_character({combined_arg})"
inject = f"The computer executes the function `{call}`"
TM.emit_status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
TM.log.debug("SIMULATION SUITE: add npc", name=character_name)
npc = TM.agents.director.persist_character(name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
self.added_npcs.append(npc.name)
TM.agents.world_state.manager(
action_name="add_detail_reinforcement",
character_name=npc.name,
question="Goal",
instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
interval=25,
run_immediately=True
)
TM.log.debug("SIMULATION SUITE: added npc", npc=npc)
TM.agents.visual.generate_character_portrait(character_name=npc.name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer adds {npc.name} to the simulation."
)
return call
def call_remove_ai_character(self, call:str, inject:str) -> str:
TM.emit_status("busy", "Simulation suite removing character.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names())
npc = TM.scene.get_character(character_name)
if npc:
TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
TM.agents.world_state.manager(action_name="deactivate_character", character_name=npc.name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer removes {npc.name} from the simulation."
)
return call
def call_change_ai_character(self, call:str, inject:str) -> str:
TM.emit_status("busy", "Simulation suite altering character.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names())
if character_name in self.added_npcs:
# we don't want to change the character if it was just added
return
character_name_after = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
npc = TM.scene.get_character(character_name)
if npc:
TM.emit_status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)
character_attributes = TM.agents.world_state.extract_character_sheet(name=npc.name, alteration_instructions=self.player_message.raw)
npc.update(base_attributes=character_attributes)
character_description = TM.agents.creator.determine_character_description(character=npc)
npc.update(description=character_description)
TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
if character_name_after != character_name:
npc.rename(character_name_after)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer transforms {npc.name}."
)
return call
def call_end_simulation(self, call:str, inject:str) -> str:
explicit_command = TM.client.query_text_eval("has the player explicitly asked to end the simulation?", self.player_message.raw)
if explicit_command:
TM.emit_status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names())}. The player is also transformed back to their normal, non-descript persona as the form of {self.player_character.name} ceases to exist.",
emit_message=True
)
TM.scene.restore()
self.simulation_reset = True
TM.game_state.unset_var("instr.has_issued_instructions")
TM.game_state.unset_var("instr.lastprocessed_call")
TM.game_state.unset_var("instr.simulation_started")
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer ends the simulation."
)
def finalize_round(self):
# track rounds
rounds = TM.game_state.get_var("instr.rounds", 0)
# increase rounds
TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)
has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
if self.update_world_state:
self.run_update_world_state()
if self.player_message_is_instruction:
self.player_message.hide()
TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
TM.emit_status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
elif self.player_message and not has_issued_instructions:
# simulation started, player message is NOT an instruction, and player has not given
# any instructions
self.guide_player()
elif self.player_message and not TM.scene.npc_character_names():
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
self.narrate_round()
elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names() and has_issued_instructions:
# every N rounds, narrate the round
self.narrate_round()
def guide_player(self):
TM.agents.narrator.action_to_narration(
action_name="paraphrase",
narration=MSG_HELP,
emit_message=True
)
def narrate_round(self):
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=PROMPT_NARRATE_ROUND,
emit_message=True
)
def run_update_world_state(self, force=False):
TM.log.debug("SIMULATION SUITE: update world state", force=force)
TM.emit_status("busy", "Simulation suite updating world state.", as_scene_message=True)
TM.agents.world_state.update_world_state(force=force)
TM.emit_status("success", "Simulation suite updated world state.", as_scene_message=True)
SimulationSuite().run()
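
The script above relies on a `parse_sim_call_arguments` helper defined earlier in the file (not shown in this diff). A minimal sketch of what such a helper could look like, assuming each pseudo call carries a single quoted string argument:

```python
import re

def parse_sim_call_arguments(call: str) -> str:
    # Hypothetical sketch: pull the quoted argument out of a pseudo call
    # such as add_ai_character("a goblin named Gobbo"); fall back to the
    # raw call string when no quoted argument is present.
    match = re.search(r'\(\s*"(.*)"\s*\)', call)
    return match.group(1) if match else call

assert parse_sim_call_arguments('set_player_name("Susanne")') == "Susanne"
```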

View File

@@ -0,0 +1,53 @@
{
"name": "Simulation Suite",
"title": "Simulation Suite",
"environment": "scene",
"immutable_save": true,
"restore_from": "simulation-suite.json",
"experimental": true,
"help": "Address the computer by starting your statements with 'Computer, ' followed by an instruction.\n\nExamples:\n'Computer, i would like to experience an adventure on a derelict space station'\n'Computer, add a horrific alien creature that is chasing me.'",
"description": "",
"intro": "*You have entered the simulation suite. No simulation is currently active and you are in a non-descript space with paneled walls surrounding you. The control panel next to you is pulsating with a green light, indicating readiness to receive a prompt to start the simulation.*",
"archived_history": [],
"history": [],
"ts": "PT1S",
"characters": [
{
"name": "You",
"gender": "unknown",
"color": "cornflowerblue",
"base_attributes": {},
"is_player": true
}
],
"context": "a simulated experience",
"game_state": {
"ops":{
"run_on_start": true,
"always_direct": true
},
"variables": {}
},
"world_state": {
"character_name_mappings": {
"You": [
"user",
"player",
"player character",
"user character",
"the user",
"the player"
]
}
},
"assets": {
"cover_image": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
"assets": {
"4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103": {
"id": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
"file_type": "png",
"media_type": "image/png"
}
}
}
}
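
A scene file like this can be sanity-checked by loading it and asserting the flags the suite depends on; a minimal sketch, assuming the file is saved under the name given in `restore_from` and that `run_on_start` / `always_direct` gate the script as their names suggest:

```python
import json

# Load the scene definition and verify the flags that let the
# simulation suite script take over on scene start.
with open("simulation-suite.json") as fp:
    scene = json.load(fp)

ops = scene["game_state"]["ops"]
assert ops["run_on_start"] is True
assert ops["always_direct"] is True
assert scene["immutable_save"] is True
```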

View File

@@ -0,0 +1,28 @@
<|SECTION:EXAMPLES|>
combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "Sarah" into a single text string argument to be passed to a single `add_ai_character` function call.
```
set_simulation_goal("player experiences a rollercoaster ride")
change_environment("theme park, riding a rollercoaster")
set_player_persona("young female experiencing rollercoaster ride")
set_player_name("Susanne")
add_ai_character("a female friend of player named Sarah")
change_ai_character("Sarah hates rollercoasters")
```
COMBINED ARGUMENT: "a female friend of player named Sarah, Sarah hates rollercoasters"
TASK: combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "George" into a single text string argument to be passed to a single `add_ai_character` function call.
```
change_environment("building on fire")
change_ai_character("George is injured")
add_ai_character("a firefighter named Stephen")
change_ai_character("Stephen is afraid of heights")
```
COMBINED ARGUMENT: "a firefighter named Stephen, Stephen is afraid of heights"
<|CLOSE_SECTION|>
<|SECTION:TASK|>
TASK: combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "{{ character_name }}" into a single text string argument to be passed to a single `add_ai_character` function call.
```
{{ calls }}
```
{{ set_prepared_response("COMBINED ARGUMENT:") }}

View File

@@ -0,0 +1,132 @@
<|SECTION:CONTEXT|>
{% set scene_history=scene.context_history(budget=1024) %}
{% for scene_context in scene_history -%}
{{ loop.index }}. {{ scene_context }}
{% endfor %}
<|CLOSE_SECTION|>
<|SECTION:FUNCTIONS|>
The player has instructed the computer to alter the current simulation.
You have access to the following functions; you can call as many as you want to fulfill the player's requests.
You must call at least one of the following functions:
- change_environment
- add_ai_character
- change_ai_character
- remove_ai_character
- set_player_persona
- set_player_name
- end_simulation
- answer_question
- set_simulation_goal
`add_ai_character` and `change_ai_character` are exclusive if they are targeting the same character.
Set the player persona at the beginning of a new simulation or if the player requests a change.
Only end the simulation if the player requests it explicitly.
Your response MUST ONLY CONTAIN the new simulation stack.
<|CLOSE_SECTION|>
<|SECTION:EXAMPLES|>
Request: Computer, I want to be on a mountain top
```simulation-stack
change_environment("mountain top")
set_player_persona("mountain climber")
set_player_name("Hank")
```
Request: Computer, I want to be more muscular and taller
```simulation-stack
set_player_persona("make player more muscular and taller")
```
Request: Computer, the building should be on fire
```simulation-stack
change_environment("building on fire")
```
Request: Computer, a rocket hits the building and George is now injured
```simulation-stack
change_environment("building on fire")
change_ai_character("George is injured")
```
Request: Computer, I want to experience a rollercoaster ride with a friend
```simulation-stack
set_simulation_goal("player experiences a rollercoaster ride")
change_environment("theme park, riding a rollercoaster")
set_player_persona("young female experiencing rollercoaster ride")
set_player_name("Susanne")
add_ai_character("a female friend of player named Sarah")
```
Request: Computer, I want to experience the international space station, to experience the overview effect
```simulation-stack
set_simulation_goal("player experiences the overview effect")
change_environment("international space station")
set_player_persona("astronaut experiencing first trip to ISS")
set_player_name("George")
add_ai_character("astronaut named Henry")
```
Request: Computer, remove the goblin and add an elven woman instead
```simulation-stack
remove_ai_character("goblin")
add_ai_character("elven woman named Elune")
```
Request: Computer, change the skiing instructor to be older.
```simulation-stack
change_ai_character("make skiing instructor older")
```
Request: Computer, change my grandma to my grandpa
```simulation-stack
remove_ai_character("grandma")
add_ai_character("grandpa named Steven")
```
Request: Computer, remove the skiing instructor and add my friend instead.
```simulation-stack
remove_ai_character("skiing instructor")
add_ai_character("player's friend named Tara")
```
Request: Computer, replace the skiing instructor with my friend.
```simulation-stack
remove_ai_character("skiing instructor")
add_ai_character("player's friend named Lisa")
```
Request: Computer, I want to end the simulation
```simulation-stack
end_simulation("simulation ended")
```
Request: Computer, shut down the simulation
```simulation-stack
end_simulation("simulation ended")
```
Request: Computer, what do you know about the game of thrones?
```simulation-stack
answer_question("what do you know about the game of thrones?")
```
Request: Computer, i want to be a wizard in a dark goblin infested dungeon in a fantasy world, looking for secret treasure and fighting goblins.
```simulation-stack
set_simulation_goal("player wants to find secret treasure and fight creatures")
change_environment("dark dungeon in a fantasy world")
set_player_persona("powerful wizard")
set_player_name("Lanadel")
add_ai_character("a goblin named Gobbo")
```
<|CLOSE_SECTION|>
<|SECTION:TASK|>
Respond with the simulation stack for the following request:
Request: {{ player_instruction }}
{{ bot_token }}```simulation-stack
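
The completion that follows the `{{ bot_token }}` coercion is consumed by `simulation_calls` (shown earlier), which splits it on newlines and treats any line containing `(` as a pseudo call. A minimal sketch of that parsing, assuming the model closes the fence:

```python
def parse_simulation_stack(response: str) -> list[str]:
    # Keep lines that look like pseudo calls; drop fence markers
    # and blank lines, mirroring simulation_calls / process_call.
    calls = []
    for line in response.splitlines():
        line = line.strip().strip("`")
        if "(" in line:
            calls.append(line)
    return calls

response = 'change_environment("building on fire")\nchange_ai_character("George is injured")\n```'
assert parse_simulation_stack(response) == [
    'change_environment("building on fire")',
    'change_ai_character("George is injured")',
]
```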

View File

@@ -2,4 +2,4 @@ from .agents import Agent
from .client import TextGeneratorWebuiClient
from .tale_mate import *
VERSION = "0.18.0"
VERSION = "0.25.1"

View File

@@ -1,11 +1,12 @@
from .base import Agent
from .creator import CreatorAgent
from .conversation import ConversationAgent
from .creator import CreatorAgent
from .director import DirectorAgent
from .editor import EditorAgent
from .memory import ChromaDBMemoryAgent, MemoryAgent
from .narrator import NarratorAgent
from .registry import AGENT_CLASSES, get_agent_class, register
from .summarize import SummarizeAgent
from .editor import EditorAgent
from .tts import TTSAgent
from .visual import VisualAgent
from .world_state import WorldStateAgent
from .tts import TTSAgent

View File

@@ -1,24 +1,30 @@
from __future__ import annotations
import asyncio
import dataclasses
import re
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import pydantic
import structlog
from blinker import signal
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import ActiveAgent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
import talemate.emit.async_signals
import dataclasses
import pydantic
import structlog
__all__ = [
"Agent",
"AgentAction",
"AgentActionConditional",
"AgentActionConfig",
"AgentDetail",
"AgentEmission",
"set_processing",
]
@@ -37,26 +43,41 @@ class AgentActionConfig(pydantic.BaseModel):
scope: str = "global"
choices: Union[list[dict[str, str]], None] = None
note: Union[str, None] = None
class Config:
arbitrary_types_allowed = True
class AgentActionConditional(pydantic.BaseModel):
attribute: str
value: Union[int, float, str, bool, None] = None
class AgentAction(pydantic.BaseModel):
enabled: bool = True
label: str
description: str = ""
config: Union[dict[str, AgentActionConfig], None] = None
condition: Union[AgentActionConditional, None] = None
class AgentDetail(pydantic.BaseModel):
value: Union[str, None] = None
description: Union[str, None] = None
icon: Union[str, None] = None
color: str = "grey"
def set_processing(fn):
"""
decorator that emits the agent status as processing while the function
is running.
Done via a try/finally block to ensure the status is reset even if
the function fails.
"""
@wraps(fn)
async def wrapper(self, *args, **kwargs):
with ActiveAgent(self, fn):
try:
@@ -69,9 +90,8 @@ def set_processing(fn):
# not sure why this happens
# some concurrency error?
log.error("error emitting agent status", exc=exc)
wrapper.__name__ = fn.__name__
wrapper.exposed = True
return wrapper
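
The hunks above only show fragments of `set_processing`; the underlying pattern is a busy/idle emission pair guarded by `try`/`finally` so the status always resets. A simplified, self-contained sketch of that pattern, not the exact talemate implementation:

```python
import asyncio
from functools import wraps

def set_processing(fn):
    # Flag the agent busy while the coroutine runs; always reset
    # the status, even if the wrapped coroutine raises.
    @wraps(fn)
    async def wrapper(self, *args, **kwargs):
        await self.emit_status(processing=True)
        try:
            return await fn(self, *args, **kwargs)
        finally:
            await self.emit_status(processing=False)
    return wrapper

class DemoAgent:
    processing = 0

    async def emit_status(self, processing=None):
        if processing is True:
            self.processing += 1
        elif processing is False:
            self.processing = max(0, self.processing - 1)

    @set_processing
    async def converse(self):
        return "..."

asyncio.run(DemoAgent().converse())
```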
@@ -85,6 +105,9 @@ class Agent(ABC):
set_processing = set_processing
requires_llm_client = True
auto_break_repetition = False
websocket_handler = None
essential = True
ready_check_error = None
@property
def agent_details(self):
@@ -97,46 +120,51 @@ class Agent(ABC):
def verbose_name(self):
return self.agent_type.capitalize()
@property
def ready(self):
if not getattr(self.client, "enabled", True):
return False
if self.client and self.client.current_status in ["error", "warning"]:
return False
return self.client is not None
@property
def status(self):
if self.ready:
if not self.enabled:
return "disabled"
return "idle" if getattr(self, "processing", 0) == 0 else "busy"
else:
if not self.enabled:
return "disabled"
if not self.ready:
return "uninitialized"
if getattr(self, "processing", 0) > 0:
return "busy"
if getattr(self, "processing_bg", 0) > 0:
return "busy_bg"
return "idle"
@property
def enabled(self):
# by default, agents are enabled, an agent class that
# is disableable should override this property
return True
@property
def disable(self):
# by default, agents are enabled, an agent class that
# is disableable should override this property to
# is disableable should override this property to
# disable the agent
pass
@property
def has_toggle(self):
# by default, agents do not have toggles to enable / disable
# an agent class that is disableable should override this property
return False
@property
def experimental(self):
# by default, agents are not experimental, an agent class that
@@ -153,100 +181,177 @@ class Agent(ABC):
"requires_llm_client": cls.requires_llm_client,
}
actions = getattr(agent, "actions", None)
if actions:
config_options["actions"] = {k: v.model_dump() for k, v in actions.items()}
else:
config_options["actions"] = {}
return config_options
def apply_config(self, *args, **kwargs):
@property
def meta(self):
return {
"essential": self.essential,
}
@property
def sanitized_action_config(self):
if not getattr(self, "actions", None):
return {}
return {k: v.model_dump() for k, v in self.actions.items()}
async def _handle_ready_check(self, fut: asyncio.Future):
callback_failure = getattr(self, "on_ready_check_failure", None)
if fut.cancelled():
if callback_failure:
await callback_failure()
return
if fut.exception():
exc = fut.exception()
self.ready_check_error = exc
log.error("agent ready check error", agent=self.agent_type, exc=exc)
if callback_failure:
await callback_failure(exc)
return
callback = getattr(self, "on_ready_check_success", None)
if callback:
await callback()
async def ready_check(self, task: asyncio.Task = None):
self.ready_check_error = None
if task:
task.add_done_callback(
lambda fut: asyncio.create_task(self._handle_ready_check(fut))
)
return
return True
async def apply_config(self, *args, **kwargs):
if self.has_toggle and "enabled" in kwargs:
self.is_enabled = kwargs.get("enabled", False)
if not getattr(self, "actions", None):
return
for action_key, action in self.actions.items():
if not kwargs.get("actions"):
continue
action.enabled = kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
action.enabled = (
kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
)
if not action.config:
continue
for config_key, config in action.config.items():
try:
config.value = kwargs.get("actions", {}).get(action_key, {}).get("config", {}).get(config_key, {}).get("value", config.value)
config.value = (
kwargs.get("actions", {})
.get(action_key, {})
.get("config", {})
.get(config_key, {})
.get("value", config.value)
)
except AttributeError:
pass
async def on_game_loop_start(self, event:GameLoopStartEvent):
async def on_game_loop_start(self, event: GameLoopStartEvent):
"""
Finds all ActionConfigs that have a scope of "scene" and resets them to their default values
"""
if not getattr(self, "actions", None):
return
for _, action in self.actions.items():
if not action.config:
continue
for _, config in action.config.items():
if config.scope == "scene":
# if default_value is None, just use the `type` of the current
# if default_value is None, just use the `type` of the current
# value
if config.default_value is None:
default_value = type(config.value)()
else:
default_value = config.default_value
log.debug("resetting config", config=config, default_value=default_value)
log.debug(
"resetting config", config=config, default_value=default_value
)
config.value = default_value
await self.emit_status()
async def emit_status(self, processing: bool = None):
# should keep a count of processing requests, and when the
# number is 0 status is "idle", if the number is greater than 0
# status is "busy"
#
# increase / decrease based on value of `processing`
if getattr(self, "processing", None) is None:
self.processing = 0
if not processing:
if processing is False:
self.processing -= 1
self.processing = max(0, self.processing)
else:
elif processing is True:
self.processing += 1
status = "busy" if self.processing > 0 else "idle"
if not self.enabled:
status = "disabled"
emit(
"agent_status",
message=self.verbose_name or "",
id=self.agent_type,
status=status,
status=self.status,
details=self.agent_details,
meta=self.meta,
data=self.config_options(agent=self),
)
await asyncio.sleep(0.01)
async def _handle_background_processing(self, fut: asyncio.Future):
try:
if fut.cancelled():
return
if fut.exception():
log.error(
"background processing error",
agent=self.agent_type,
exc=fut.exception(),
)
await self.emit_status()
return
log.info("background processing done", agent=self.agent_type)
finally:
self.processing_bg -= 1
await self.emit_status()
async def set_background_processing(self, task: asyncio.Task):
log.info("set_background_processing", agent=self.agent_type)
if not hasattr(self, "processing_bg"):
self.processing_bg = 0
self.processing_bg += 1
await self.emit_status()
task.add_done_callback(
lambda fut: asyncio.create_task(self._handle_background_processing(fut))
)
def connect(self, scene):
self.scene = scene
talemate.emit.async_signals.get("game_loop_start").connect(self.on_game_loop_start)
talemate.emit.async_signals.get("game_loop_start").connect(
self.on_game_loop_start
)
def clean_result(self, result):
if "#" in result:
@@ -291,23 +396,28 @@ class Agent(ABC):
current_memory_context.append(memory)
return current_memory_context
# LLM client related methods. These are called during or after the client
# sends the prompt to the API.
def inject_prompt_paramters(self, prompt_param:dict, kind:str, agent_function_name:str):
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
"""
Injects prompt parameters before the client sends off the prompt
Override as needed.
"""
pass
def allow_repetition_break(self, kind:str, agent_function_name:str, auto:bool=False):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
"""
Returns True if repetition breaking is allowed, False otherwise.
"""
return False
@dataclasses.dataclass
class AgentEmission:
agent: Agent
agent: Agent
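
The reworked `status` property gives the states a clear precedence: disabled beats uninitialized, which beats busy and busy_bg, which beat idle. A toy illustration of that ordering, with the agent reduced to plain arguments:

```python
def agent_status(enabled: bool, ready: bool, processing: int = 0, processing_bg: int = 0) -> str:
    # Mirrors the precedence in the Agent.status property above.
    if not enabled:
        return "disabled"
    if not ready:
        return "uninitialized"
    if processing > 0:
        return "busy"
    if processing_bg > 0:
        return "busy_bg"
    return "idle"

assert agent_status(enabled=False, ready=True, processing=3) == "disabled"
assert agent_status(enabled=True, ready=False) == "uninitialized"
assert agent_status(enabled=True, ready=True, processing_bg=1) == "busy_bg"
```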

View File

@@ -1,3 +1,3 @@
"""
Code has been moved.
"""
"""

View File

@@ -1,6 +1,6 @@
from typing import Callable, TYPE_CHECKING
import contextvars
from typing import TYPE_CHECKING, Callable
import pydantic
__all__ = [
@@ -9,25 +9,38 @@ __all__ = [
active_agent = contextvars.ContextVar("active_agent", default=None)
class ActiveAgentContext(pydantic.BaseModel):
agent: object
fn: Callable
agent_stack: list = pydantic.Field(default_factory=list)
class Config:
arbitrary_types_allowed=True
arbitrary_types_allowed = True
@property
def action(self):
return self.fn.__name__
def __str__(self):
return f"{self.agent.verbose_name}.{self.action}"
class ActiveAgent:
def __init__(self, agent, fn):
self.agent = ActiveAgentContext(agent=agent, fn=fn)
def __enter__(self):
previous_agent = active_agent.get()
if previous_agent:
self.agent.agent_stack = previous_agent.agent_stack + [str(self.agent)]
else:
self.agent.agent_stack = [str(self.agent)]
self.token = active_agent.set(self.agent)
def __exit__(self, *args, **kwargs):
active_agent.reset(self.token)
return False
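
`ActiveAgent` now records a stack of nested agent activations on top of the `contextvars` token handling; a small standalone illustration of the same pattern:

```python
import contextvars

active = contextvars.ContextVar("active", default=None)

class Activation:
    # Minimal stand-in for ActiveAgent / ActiveAgentContext.
    def __init__(self, name: str):
        self.name = name
        self.stack: list[str] = []

    def __enter__(self):
        previous = active.get()
        self.stack = previous.stack + [self.name] if previous else [self.name]
        self.token = active.set(self)
        return self

    def __exit__(self, *exc):
        active.reset(self.token)
        return False

with Activation("narrator.progress_story"):
    with Activation("creator.determine_character_name") as inner:
        assert inner.stack == [
            "narrator.progress_story",
            "creator.determine_character_name",
        ]
```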

View File

@@ -1,40 +1,55 @@
from __future__ import annotations
import dataclasses
import re
import random
import re
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Union
import structlog
import talemate.client as client
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
import structlog
from talemate.client.context import (
client_context_attribute,
set_client_context_attribute,
set_conversation_context_attribute,
)
from talemate.emit import emit
import talemate.emit.async_signals
from talemate.scene_message import CharacterMessage, DirectorMessage
from talemate.prompts import Prompt
from talemate.events import GameLoopEvent
from talemate.client.context import set_conversation_context_attribute, client_context_attribute, set_client_context_attribute
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage
from .base import Agent, AgentEmission, set_processing, AgentAction, AgentActionConfig
from .base import (
Agent,
AgentAction,
AgentActionConfig,
AgentDetail,
AgentEmission,
set_processing,
)
from .registry import register
if TYPE_CHECKING:
from talemate.tale_mate import Character, Scene, Actor
from talemate.tale_mate import Actor, Character, Scene
log = structlog.get_logger("talemate.agents.conversation")
@dataclasses.dataclass
class ConversationAgentEmission(AgentEmission):
actor: Actor
character: Character
generation: list[str]
talemate.emit.async_signals.register(
"agent.conversation.before_generate",
"agent.conversation.generated"
"agent.conversation.before_generate", "agent.conversation.generated"
)
@register()
class ConversationAgent(Agent):
"""
@@ -45,7 +60,7 @@ class ConversationAgent(Agent):
agent_type = "conversation"
verbose_name = "Conversation"
min_dialogue_length = 75
def __init__(
@@ -60,28 +75,37 @@ class ConversationAgent(Agent):
self.logging_enabled = logging_enabled
self.logging_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.current_memory_context = None
# several agents extend this class, but we only want to initialize
# these actions for the conversation agent
if self.agent_type != "conversation":
return
self.actions = {
"generation_override": AgentAction(
enabled = True,
label = "Generation Override",
description = "Override generation parameters",
config = {
enabled=True,
label="Generation Settings",
config={
"format": AgentActionConfig(
type="text",
label="Format",
description="The generation format of the scene context, as seen by the AI.",
choices=[
{"label": "Screenplay", "value": "movie_script"},
{"label": "Chat (legacy)", "value": "chat"},
],
value="movie_script",
),
"length": AgentActionConfig(
type="number",
label="Generation Length (tokens)",
description="Maximum number of tokens to generate for a conversation response.",
value=96,
value=128,
min=32,
max=512,
step=32,
),#
), #
"instructions": AgentActionConfig(
type="text",
label="Instructions",
@@ -96,24 +120,24 @@ class ConversationAgent(Agent):
min=0.0,
max=1.0,
step=0.1,
)
}
),
},
),
"auto_break_repetition": AgentAction(
enabled = True,
label = "Auto Break Repetition",
description = "Will attempt to automatically break AI repetition.",
enabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"natural_flow": AgentAction(
enabled = True,
label = "Natural Flow",
description = "Will attempt to generate a more natural flow of conversation between multiple characters.",
config = {
enabled=True,
label="Natural Flow",
description="Will attempt to generate a more natural flow of conversation between multiple characters.",
config={
"max_auto_turns": AgentActionConfig(
type="number",
label="Max. Auto Turns",
description="The maximum number of turns the AI is allowed to generate before it stops and waits for the player to respond.",
value=4,
value=4,
min=1,
max=100,
step=1,
@@ -122,72 +146,114 @@ class ConversationAgent(Agent):
type="number",
label="Max. Idle Turns",
description="The maximum number of turns a character can go without speaking before they are considered overdue to speak.",
value=8,
value=8,
min=1,
max=100,
step=1,
),
}
},
),
"use_long_term_memory": AgentAction(
enabled = True,
label = "Long Term Memory",
description = "Will augment the conversation prompt with long term memory.",
config = {
enabled=True,
label="Long Term Memory",
description="Will augment the conversation prompt with long term memory.",
config={
"retrieval_method": AgentActionConfig(
type="text",
label="Context Retrieval Method",
description="How relevant context is retrieved from the long term memory.",
value="direct",
choices=[
{"label": "Context queries based on recent dialogue (fast)", "value": "direct"},
{"label": "Context queries generated by AI", "value": "queries"},
{"label": "AI compiled question and answers (slow)", "value": "questions"},
]
{
"label": "Context queries based on recent dialogue (fast)",
"value": "direct",
},
{
"label": "Context queries generated by AI",
"value": "queries",
},
{
"label": "AI compiled question and answers (slow)",
"value": "questions",
},
],
),
}
),
},
),
}
@property
def conversation_format(self):
if self.actions["generation_override"].enabled:
return self.actions["generation_override"].config["format"].value
return "movie_script"
@property
def conversation_format_label(self):
value = self.conversation_format
choices = self.actions["generation_override"].config["format"].choices
for choice in choices:
if choice["value"] == value:
return choice["label"]
return value
@property
def agent_details(self) -> dict:
details = {
"client": AgentDetail(
icon="mdi-network-outline",
value=self.client.name if self.client else None,
description="The client to use for prompt generation",
).model_dump(),
"format": AgentDetail(
icon="mdi-format-float-none",
value=self.conversation_format_label,
description="Generation format of the scene context, as seen by the AI",
).model_dump(),
}
return details
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
def last_spoken(self):
"""
Returns the last time each character spoke
"""
last_turn = {}
turns = 0
character_names = self.scene.character_names
max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
for idx in range(len(self.scene.history) - 1, -1, -1):
if isinstance(self.scene.history[idx], CharacterMessage):
if turns >= max_idle_turns:
break
character = self.scene.history[idx].character_name
if character in character_names:
last_turn[character] = turns
character_names.remove(character)
if not character_names:
break
turns += 1
if character_names and turns >= max_idle_turns:
for character in character_names:
last_turn[character] = max_idle_turns
last_turn[character] = max_idle_turns
return last_turn
def repeated_speaker(self):
"""
Counts the number of times the most recent speaker has spoken in a row
@@ -203,125 +269,164 @@ class ConversationAgent(Agent):
else:
break
return count
async def on_game_loop(self, event:GameLoopEvent):
async def on_game_loop(self, event: GameLoopEvent):
await self.apply_natural_flow()
async def apply_natural_flow(self, force: bool = False, npcs_only: bool = False):
"""
If the natural flow action is enabled, this will attempt to determine
the ideal character to speak next.
This will let the AI pick a character, but if the AI can't figure
it out, it will apply rules based on max_idle_turns and max_auto_turns.
If all else fails, it will just pick a random character.
Repetition is also taken into account, so if a character has spoken twice in a row
they will not be picked again until someone else has spoken.
"""
scene = self.scene
if not scene.auto_progress and not force:
# we only apply natural flow if auto_progress is enabled
return
if self.actions["natural_flow"].enabled and len(scene.character_names) > 2:
# last time each character spoke (turns ago)
max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
max_auto_turns = self.actions["natural_flow"].config["max_auto_turns"].value
last_turn = self.last_spoken()
player_name = scene.get_player_character().name
last_turn_player = last_turn.get(player_name, 0)
if last_turn_player >= max_auto_turns and not npcs_only:
self.scene.next_actor = scene.get_player_character().name
log.debug("conversation_agent.natural_flow", next_actor="player", overdue=True, player_character=scene.get_player_character().name)
log.debug(
"conversation_agent.natural_flow",
next_actor="player",
overdue=True,
player_character=scene.get_player_character().name,
)
return
log.debug("conversation_agent.natural_flow", last_turn=last_turn)
# determine random character to talk, this will be the fallback in case
# the AI can't figure out who should talk next
if scene.prev_actor:
# we don't want to talk to the same person twice in a row
character_names = scene.character_names
character_names.remove(scene.prev_actor)
if npcs_only:
character_names = [c for c in character_names if c != player_name]
random_character_name = random.choice(character_names)
else:
character_names = scene.character_names
character_names = scene.character_names
# no one has talked yet, so we just pick a random character
if npcs_only:
character_names = [c for c in character_names if c != player_name]
random_character_name = random.choice(character_names)
overdue_characters = [character for character, turn in last_turn.items() if turn >= max_idle_turns]
overdue_characters = [
character
for character, turn in last_turn.items()
if turn >= max_idle_turns
]
if npcs_only:
overdue_characters = [c for c in overdue_characters if c != player_name]
if overdue_characters and self.scene.history:
# Pick a random character from the overdue characters
scene.next_actor = random.choice(overdue_characters)
elif scene.history:
scene.next_actor = None
# AI will attempt to figure out who should talk next
next_actor = await self.select_talking_actor(character_names)
next_actor = next_actor.strip().strip('"').strip(".")
next_actor = next_actor.split("\n")[0].strip().strip('"').strip(".")
for character_name in scene.character_names:
if next_actor.lower() in character_name.lower() or character_name.lower() in next_actor.lower():
if (
next_actor.lower() in character_name.lower()
or character_name.lower() in next_actor.lower()
):
scene.next_actor = character_name
break
if not scene.next_actor:
# AI couldn't figure out who should talk next, so we just pick a random character
log.debug("conversation_agent.natural_flow", next_actor="random", random_character_name=random_character_name)
log.debug(
"conversation_agent.natural_flow",
next_actor="random",
random_character_name=random_character_name,
)
scene.next_actor = random_character_name
else:
log.debug("conversation_agent.natural_flow", next_actor="picked", ai_next_actor=scene.next_actor)
log.debug(
"conversation_agent.natural_flow",
next_actor="picked",
ai_next_actor=scene.next_actor,
)
else:
# always start with main character (TODO: configurable?)
player_character = scene.get_player_character()
log.debug("conversation_agent.natural_flow", next_actor="main_character", main_character=player_character)
scene.next_actor = player_character.name if player_character else random_character_name
scene.log.debug("conversation_agent.natural_flow", next_actor=scene.next_actor)
log.debug(
"conversation_agent.natural_flow",
next_actor="main_character",
main_character=player_character,
)
scene.next_actor = (
player_character.name if player_character else random_character_name
)
scene.log.debug(
"conversation_agent.natural_flow", next_actor=scene.next_actor
)
# the same character cannot speak three times in a row; if this is happening, pick a random
# character that isn't the same as the last character
if self.repeated_speaker() >= 2 and self.scene.prev_actor == self.scene.next_actor:
scene.next_actor = random.choice([c for c in scene.character_names if c != scene.prev_actor])
scene.log.debug("conversation_agent.natural_flow", next_actor="random (repeated safeguard)", random_character_name=scene.next_actor)
if (
self.repeated_speaker() >= 2
and self.scene.prev_actor == self.scene.next_actor
):
scene.next_actor = random.choice(
[c for c in scene.character_names if c != scene.prev_actor]
)
scene.log.debug(
"conversation_agent.natural_flow",
next_actor="random (repeated safeguard)",
random_character_name=scene.next_actor,
)
else:
scene.next_actor = None
@set_processing
async def select_talking_actor(self, character_names: list[str]=None):
result = await Prompt.request("conversation.select-talking-actor", self.client, "conversation_select_talking_actor", vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character_names": character_names or self.scene.character_names,
"character_names_formatted": ", ".join(character_names or self.scene.character_names),
})
async def select_talking_actor(self, character_names: list[str] = None):
result = await Prompt.request(
"conversation.select-talking-actor",
self.client,
"conversation_select_talking_actor",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character_names": character_names or self.scene.character_names,
"character_names_formatted": ", ".join(
character_names or self.scene.character_names
),
},
)
return result
async def build_prompt_default(
self,
@@ -335,17 +440,17 @@ class ConversationAgent(Agent):
# we subtract 200 to account for the response
scene = character.actor.scene
total_token_budget = self.client.max_token_length - 200
scene_and_dialogue_budget = total_token_budget - 500
long_term_memory_budget = min(int(total_token_budget * 0.05), 200)
scene_and_dialogue = scene.context_history(
budget=scene_and_dialogue_budget,
budget=scene_and_dialogue_budget,
keep_director=True,
sections=False,
)
memory = await self.build_prompt_default_memory(character)
main_character = scene.main_character.character
@@ -360,36 +465,41 @@ class ConversationAgent(Agent):
)
else:
formatted_names = character_names[0] if character_names else ""
try:
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
except IndexError:
director_message = False
extra_instructions = ""
if self.actions["generation_override"].enabled:
extra_instructions = self.actions["generation_override"].config["instructions"].value
extra_instructions = (
self.actions["generation_override"].config["instructions"].value
)
conversation_format = self.conversation_format
prompt = Prompt.get(
f"conversation.dialogue-{conversation_format}",
vars={
"scene": scene,
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": memory,
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": extra_instructions,
"decensor": self.client.decensor_enabled,
},
)
prompt = Prompt.get("conversation.dialogue", vars={
"scene": scene,
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": memory,
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": extra_instructions,
})
return str(prompt)
async def build_prompt_default_memory(
self, character: Character
):
async def build_prompt_default_memory(self, character: Character):
"""
Builds long term memory for the conversation prompt
@@ -404,39 +514,56 @@ class ConversationAgent(Agent):
if not self.actions["use_long_term_memory"].enabled:
return []
if self.current_memory_context:
return self.current_memory_context
self.current_memory_context = ""
retrieval_method = self.actions["use_long_term_memory"].config["retrieval_method"].value
retrieval_method = (
self.actions["use_long_term_memory"].config["retrieval_method"].value
)
if retrieval_method != "direct":
world_state = instance.get_agent("world_state")
history = self.scene.context_history(min_dialogue=3, max_dialogue=15, keep_director=False, sections=False, add_archieved_history=False)
history = self.scene.context_history(
min_dialogue=3,
max_dialogue=15,
keep_director=False,
sections=False,
add_archieved_history=False,
)
text = "\n".join(history)
log.debug("conversation_agent.build_prompt_default_memory", direct=False, version=retrieval_method)
log.debug(
"conversation_agent.build_prompt_default_memory",
direct=False,
version=retrieval_method,
)
if retrieval_method == "questions":
self.current_memory_context = (await world_state.analyze_text_and_extract_context(
text, f"continue the conversation as {character.name}"
)).split("\n")
self.current_memory_context = (
await world_state.analyze_text_and_extract_context(
text, f"continue the conversation as {character.name}"
)
).split("\n")
elif retrieval_method == "queries":
self.current_memory_context = await world_state.analyze_text_and_extract_context_via_queries(
text, f"continue the conversation as {character.name}"
self.current_memory_context = (
await world_state.analyze_text_and_extract_context_via_queries(
text, f"continue the conversation as {character.name}"
)
)
else:
history = list(map(str, self.scene.collect_messages(max_iterations=3)))
log.debug("conversation_agent.build_prompt_default_memory", history=history, direct=True)
log.debug(
"conversation_agent.build_prompt_default_memory",
history=history,
direct=True,
)
memory = instance.get_agent("memory")
context = await memory.multi_query(history, max_tokens=500, iterate=5)
self.current_memory_context = context
return self.current_memory_context
async def build_prompt(self, character, char_message: str = ""):
@@ -445,29 +572,37 @@ class ConversationAgent(Agent):
return await fn(character, char_message=char_message)
def clean_result(self, result, character):
if "#" in result:
result = result.split("#")[0]
if "(Internal" in result:
result = result.split("(Internal")[0]
result = result.replace(" :", ":")
result = result.replace("[", "*").replace("]", "*")
result = result.replace("(", "*").replace(")", "*")
result = result.replace("**", "*")
result = util.handle_endofline_special_delimiter(result)
return result
def set_generation_overrides(self):
if not self.actions["generation_override"].enabled:
return
set_conversation_context_attribute("length", self.actions["generation_override"].config["length"].value)
set_conversation_context_attribute(
"length", self.actions["generation_override"].config["length"].value
)
if self.actions["generation_override"].config["jiggle"].value > 0.0:
nuke_repetition = client_context_attribute("nuke_repetition")
if nuke_repetition == 0.0:
# we only apply the agent override if some other mechanism isn't already
# setting the nuke_repetition value
nuke_repetition = self.actions["generation_override"].config["jiggle"].value
nuke_repetition = (
self.actions["generation_override"].config["jiggle"].value
)
set_client_context_attribute("nuke_repetition", nuke_repetition)
@set_processing
@@ -479,10 +614,14 @@ class ConversationAgent(Agent):
self.current_memory_context = None
character = actor.character
emission = ConversationAgentEmission(agent=self, generation="", actor=actor, character=character)
await talemate.emit.async_signals.get("agent.conversation.before_generate").send(emission)
emission = ConversationAgentEmission(
agent=self, generation="", actor=actor, character=character
)
await talemate.emit.async_signals.get(
"agent.conversation.before_generate"
).send(emission)
self.set_generation_overrides()
result = await self.client.send_prompt(await self.build_prompt(character))
@@ -505,7 +644,7 @@ class ConversationAgent(Agent):
result = self.clean_result(result, character)
total_result += " "+result
total_result += " " + result
if len(total_result) == 0 and max_loops < 10:
max_loops += 1
@@ -525,13 +664,24 @@ class ConversationAgent(Agent):
result = result.replace(" :", ":")
total_result = total_result.split("#")[0]
total_result = total_result.split("#")[0].strip()
total_result = util.handle_endofline_special_delimiter(total_result)
if total_result.startswith(":\n"):
total_result = total_result[2:]
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
# chat format
# {character name}: {dialogue}
total_result = total_result.replace(f"{character.name}:", "")
# Removes partial sentence at the end
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Remove "{character.name}:" - all occurences
total_result = total_result.replace(f"{character.name}:", "")
# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name):
@@ -548,13 +698,17 @@ class ConversationAgent(Agent):
)
response_message = util.parse_messages_from_str(total_result, [character.name])
log.info("conversation agent", result=response_message)
emission = ConversationAgentEmission(agent=self, generation=response_message, actor=actor, character=character)
await talemate.emit.async_signals.get("agent.conversation.generated").send(emission)
#log.info("conversation agent", generation=emission.generation)
log.info("conversation agent", result=response_message)
emission = ConversationAgentEmission(
agent=self, generation=response_message, actor=actor, character=character
)
await talemate.emit.async_signals.get("agent.conversation.generated").send(
emission
)
# log.info("conversation agent", generation=emission.generation)
messages = [CharacterMessage(message) for message in emission.generation]
@@ -563,15 +717,17 @@ class ConversationAgent(Agent):
return messages
def allow_repetition_break(self, kind: str, agent_function_name: str, auto: bool = False):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return agent_function_name == "converse"
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += ['[']
prompt_param["extra_stopping_strings"] += ["#"]

View File

@@ -3,22 +3,23 @@ from __future__ import annotations
import json
import os
import talemate.client as client
from talemate.agents.base import Agent, set_processing
from talemate.agents.registry import register
from talemate.emit import emit
from talemate.prompts import Prompt
import talemate.client as client
from .assistant import AssistantMixin
from .character import CharacterCreatorMixin
from .scenario import ScenarioCreatorMixin
@register()
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, AssistantMixin, Agent):
"""
Creates characters and scenarios and other fun stuff!
"""
agent_type = "creator"
verbose_name = "Creator"
@@ -78,12 +79,14 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
# Remove duplicates while preserving the order for list type keys
for key, value in merged_data.items():
if isinstance(value, list):
merged_data[key] = [x for i, x in enumerate(value) if x not in value[:i]]
merged_data[key] = [
x for i, x in enumerate(value) if x not in value[:i]
]
merged_data["context"] = context
return merged_data
def load_templates_old(self, names: list, template_type: str = "character") -> dict:
"""
Loads multiple character creation templates from ./templates/character and merges them in order.
@@ -128,8 +131,10 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
if "context" in template_data["instructions"]:
context = template_data["instructions"]["context"]
merged_instructions[name]["questions"] = [q[0] for q in template_data.get("questions", [])]
merged_instructions[name]["questions"] = [
q[0] for q in template_data.get("questions", [])
]
# Remove duplicates while preserving the order
merged_template = [
@@ -158,24 +163,33 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
return rv
@set_processing
async def generate_json_list(
self,
text:str,
count:int=20,
first_item:str=None,
text: str,
count: int = 20,
first_item: str = None,
):
_, json_list = await Prompt.request(f"creator.generate-json-list", self.client, "create", vars={
"text": text,
"first_item": first_item,
"count": count,
})
return json_list.get("items",[])
_, json_list = await Prompt.request(
f"creator.generate-json-list",
self.client,
"create",
vars={
"text": text,
"first_item": first_item,
"count": count,
},
)
return json_list.get("items", [])
@set_processing
async def generate_title(self, text:str):
title = await Prompt.request(f"creator.generate-title", self.client, "create_short", vars={
"text": text,
})
return title
async def generate_title(self, text: str):
title = await Prompt.request(
f"creator.generate-title",
self.client,
"create_short",
vars={
"text": text,
},
)
return title

View File

@@ -0,0 +1,182 @@
import asyncio
from typing import TYPE_CHECKING, Tuple, Union
import pydantic
import talemate.util as util
from talemate.agents.base import set_processing
from talemate.emit import emit
from talemate.prompts import Prompt
if TYPE_CHECKING:
from talemate.tale_mate import Character, Scene
class ContentGenerationContext(pydantic.BaseModel):
"""
A context for generating content.
"""
context: str
instructions: str = ""
length: int = 100
character: Union[str, None] = None
original: Union[str, None] = None
partial: str = ""
@property
def computed_context(self) -> Tuple[str, str]:
typ, context = self.context.split(":", 1)
return typ, context
class AssistantMixin:
"""
Creator mixin that allows quick contextual generation of content.
"""
async def contextual_generate_from_args(
self,
context: str,
instructions: str = "",
length: int = 100,
character: Union[str, None] = None,
original: Union[str, None] = None,
partial: str = "",
):
"""
Request content from the assistant.
"""
generation_context = ContentGenerationContext(
context=context,
instructions=instructions,
length=length,
character=character,
original=original,
partial=partial,
)
return await self.contextual_generate(generation_context)
contextual_generate_from_args.exposed = True
@set_processing
async def contextual_generate(
self,
generation_context: ContentGenerationContext,
):
"""
Request content from the assistant.
"""
context_typ, context_name = generation_context.computed_context
if generation_context.length < 100:
kind = "create_short"
elif generation_context.length < 500:
kind = "create_concise"
else:
kind = "create"
content = await Prompt.request(
f"creator.contextual-generate",
self.client,
kind,
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"generation_context": generation_context,
"context_typ": context_typ,
"context_name": context_name,
"can_coerce": self.client.can_be_coerced,
"character": (
self.scene.get_character(generation_context.character)
if generation_context.character
else None
),
},
)
if not generation_context.partial:
content = util.strip_partial_sentences(content)
return content.strip()
@set_processing
async def autocomplete_dialogue(
self,
input: str,
character: "Character",
emit_signal: bool = True,
) -> str:
"""
Autocomplete dialogue.
"""
response = await Prompt.request(
f"creator.autocomplete-dialogue",
self.client,
"create_short",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"input": input.strip(),
"character": character,
"can_coerce": self.client.can_be_coerced,
},
pad_prepended_response=False,
dedupe_enabled=False,
)
response = util.clean_dialogue(response, character.name)[
len(character.name + ":") :
].strip()
if response.startswith(input):
response = response[len(input) :]
self.scene.log.debug(
"autocomplete_suggestion", suggestion=response, input=input
)
if emit_signal:
emit("autocomplete_suggestion", response)
return response
@set_processing
async def autocomplete_narrative(
self,
input: str,
emit_signal: bool = True,
) -> str:
"""
Autocomplete narrative.
"""
response = await Prompt.request(
f"creator.autocomplete-narrative",
self.client,
"create_short",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"input": input.strip(),
"can_coerce": self.client.can_be_coerced,
},
pad_prepended_response=False,
dedupe_enabled=False,
)
if response.startswith(input):
response = response[len(input) :]
self.scene.log.debug(
"autocomplete_suggestion", suggestion=response, input=input
)
if emit_signal:
emit("autocomplete_suggestion", response)
return response
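Both autocomplete helpers assume the model may echo the partial input back, so only the continuation past the echoed prefix is emitted as a suggestion. A standalone sketch of that trimming step (the helper name is illustrative, not part of the module above):

def strip_echoed_prefix(input_text: str, response: str) -> str:
    # models often repeat the text they were asked to complete;
    # only the new continuation should be suggested
    if response.startswith(input_text):
        return response[len(input_text):]
    return response

assert strip_echoed_prefix("The door creaks", "The door creaks open slowly.") == " open slowly."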

View File

@@ -1,42 +1,48 @@
from __future__ import annotations
import re
import asyncio
import random
import structlog
import re
from typing import TYPE_CHECKING, Callable
import structlog
import talemate.util as util
from talemate.emit import emit
from talemate.prompts import Prompt, LoopedPrompt
from talemate.exceptions import LLMAccuracyError
from talemate.agents.base import set_processing
from talemate.emit import emit
from talemate.exceptions import LLMAccuracyError
from talemate.prompts import LoopedPrompt, Prompt
if TYPE_CHECKING:
from talemate.tale_mate import Character
log = structlog.get_logger("talemate.agents.creator.character")
def validate(k,v):
def validate(k, v):
if k and k.lower() == "gender":
return v.lower().strip()
if k and k.lower() == "age":
try:
return int(v.split("\n")[0].strip())
except (ValueError, TypeError):
raise LLMAccuracyError("Was unable to get a valid age from the response", model_name=None)
raise LLMAccuracyError(
"Was unable to get a valid age from the response", model_name=None
)
return v.strip().strip("\n")
DEFAULT_CONTENT_CONTEXT="a fun and engaging adventure aimed at an adult audience."
DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audience."
class CharacterCreatorMixin:
"""
Adds character creation functionality to the creator agent
"""
## NEW
@set_processing
async def create_character_attributes(
self,
@@ -48,8 +54,6 @@ class CharacterCreatorMixin:
custom_attributes: dict[str, str] = dict(),
predefined_attributes: dict[str, str] = dict(),
):
def spice(prompt, spices):
# generate number from 0 to 1 and if it's smaller than use_spice
# select a random spice from the list and return it formatted
@@ -57,69 +61,74 @@ class CharacterCreatorMixin:
if random.random() < use_spice:
spice = random.choice(spices)
return prompt.format(spice=spice)
return ""
return ""
# drop any empty attributes from predefined_attributes
predefined_attributes = {k:v for k,v in predefined_attributes.items() if v}
prompt = Prompt.get(f"creator.character-attributes-{template}", vars={
"character_prompt": character_prompt,
"template": template,
"spice": spice,
"content_context": content_context,
"custom_attributes": custom_attributes,
"character_sheet": LoopedPrompt(
validate_value=validate,
on_update=attribute_callback,
generated=predefined_attributes,
),
})
predefined_attributes = {k: v for k, v in predefined_attributes.items() if v}
prompt = Prompt.get(
f"creator.character-attributes-{template}",
vars={
"character_prompt": character_prompt,
"template": template,
"spice": spice,
"content_context": content_context,
"custom_attributes": custom_attributes,
"character_sheet": LoopedPrompt(
validate_value=validate,
on_update=attribute_callback,
generated=predefined_attributes,
),
},
)
await prompt.loop(self.client, "character_sheet", kind="create_concise")
return prompt.vars["character_sheet"].generated
@set_processing
async def create_character_description(
self,
character:Character,
self,
character: Character,
content_context: str = DEFAULT_CONTENT_CONTEXT,
):
description = await Prompt.request(f"creator.character-description", self.client, "create", vars={
"character": character,
"content_context": content_context,
})
description = await Prompt.request(
f"creator.character-description",
self.client,
"create",
vars={
"character": character,
"content_context": content_context,
},
)
return description.strip()
@set_processing
async def create_character_details(
self,
self,
character: Character,
template: str,
detail_callback: Callable = lambda question, answer: None,
questions: list[str] = None,
content_context: str = DEFAULT_CONTENT_CONTEXT,
):
prompt = Prompt.get(f"creator.character-details-{template}", vars={
"character_details": LoopedPrompt(
validate_value=validate,
on_update=detail_callback,
),
"template": template,
"content_context": content_context,
"character": character,
"custom_questions": questions or [],
})
prompt = Prompt.get(
f"creator.character-details-{template}",
vars={
"character_details": LoopedPrompt(
validate_value=validate,
on_update=detail_callback,
),
"template": template,
"content_context": content_context,
"character": character,
"custom_questions": questions or [],
},
)
await prompt.loop(self.client, "character_details", kind="create_concise")
return prompt.vars["character_details"].generated
@set_processing
async def create_character_example_dialogue(
self,
@@ -131,97 +140,156 @@ class CharacterCreatorMixin:
example_callback: Callable = lambda example: None,
rules_callback: Callable = lambda rules: None,
):
dialogue_rules = await Prompt.request(f"creator.character-dialogue-rules", self.client, "create", vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
})
dialogue_rules = await Prompt.request(
f"creator.character-dialogue-rules",
self.client,
"create",
vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
},
)
log.info("dialogue_rules", dialogue_rules=dialogue_rules)
if rules_callback:
rules_callback(dialogue_rules)
example_dialogue_prompt = Prompt.get(f"creator.character-example-dialogue-{template}", vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
"dialogue_rules": dialogue_rules,
"generated_examples": LoopedPrompt(
validate_value=validate,
on_update=example_callback,
),
})
await example_dialogue_prompt.loop(self.client, "generated_examples", kind="create")
example_dialogue_prompt = Prompt.get(
f"creator.character-example-dialogue-{template}",
vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
"dialogue_rules": dialogue_rules,
"generated_examples": LoopedPrompt(
validate_value=validate,
on_update=example_callback,
),
},
)
await example_dialogue_prompt.loop(
self.client, "generated_examples", kind="create"
)
return example_dialogue_prompt.vars["generated_examples"].generated
@set_processing
async def determine_content_context_for_character(
self,
character: Character,
):
content_context = await Prompt.request(f"creator.determine-content-context", self.client, "create", vars={
"character": character,
})
content_context = await Prompt.request(
f"creator.determine-content-context",
self.client,
"create",
vars={
"character": character,
},
)
return content_context.strip()
@set_processing
async def determine_character_dialogue_instructions(
self,
character: Character,
):
instructions = await Prompt.request(
f"creator.determine-character-dialogue-instructions",
self.client,
"create_concise",
vars={
"character": character,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
},
)
r = instructions.strip().split("\n")[0].strip('"').strip()
return r
@set_processing
async def determine_character_attributes(
self,
character: Character,
):
attributes = await Prompt.request(f"creator.determine-character-attributes", self.client, "analyze_long", vars={
"character": character,
})
attributes = await Prompt.request(
f"creator.determine-character-attributes",
self.client,
"analyze_long",
vars={
"character": character,
},
)
return attributes
@set_processing
async def determine_character_name(
self,
character_name: str,
allowed_names: list[str] = None,
group: bool = False,
) -> str:
name = await Prompt.request(
f"creator.determine-character-name",
self.client,
"analyze_freeform_short",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character_name": character_name,
"allowed_names": allowed_names or [],
"group": group,
},
)
return name.split('"', 1)[0].strip().strip(".").strip()
@set_processing
async def determine_character_description(
self,
character: Character,
text:str=""
self, character: Character, text: str = ""
):
description = await Prompt.request(f"creator.determine-character-description", self.client, "create", vars={
"character": character,
"scene": self.scene,
"text": text,
"max_tokens": self.client.max_token_length,
})
description = await Prompt.request(
f"creator.determine-character-description",
self.client,
"create",
vars={
"character": character,
"scene": self.scene,
"text": text,
"max_tokens": self.client.max_token_length,
},
)
return description.strip()
@set_processing
async def determine_character_goals(
self,
character: Character,
goal_instructions: str,
):
goals = await Prompt.request(f"creator.determine-character-goals", self.client, "create", vars={
"character": character,
"scene": self.scene,
"goal_instructions": goal_instructions,
"npc_name": character.name,
"player_name": self.scene.get_player_character().name,
"max_tokens": self.client.max_token_length,
})
goals = await Prompt.request(
f"creator.determine-character-goals",
self.client,
"create",
vars={
"character": character,
"scene": self.scene,
"goal_instructions": goal_instructions,
"npc_name": character.name,
"player_name": self.scene.get_player_character().name,
"max_tokens": self.client.max_token_length,
},
)
log.debug("determine_character_goals", goals=goals, character=character)
await character.set_detail("goals", goals.strip())
return goals.strip()
@set_processing
async def generate_character_from_text(
self,
@@ -229,11 +297,8 @@ class CharacterCreatorMixin:
template: str,
content_context: str = DEFAULT_CONTENT_CONTEXT,
):
base_attributes = await self.create_character_attributes(
character_prompt=text,
template=template,
content_context=content_context,
)

View File

@@ -1,36 +1,35 @@
from talemate.emit import emit, wait_for_input_yesno
import re
import random
import re
from talemate.prompts import Prompt
from talemate.agents.base import set_processing
from talemate.emit import emit, wait_for_input_yesno
from talemate.prompts import Prompt
class ScenarioCreatorMixin:
"""
Adds scenario creation functionality to the creator agent
"""
@set_processing
async def create_scene_description(
self,
prompt:str,
content_context:str,
prompt: str,
content_context: str,
):
"""
Creates a new scene.
Arguments:
prompt (str): The prompt to use to create the scene.
content_context (str): The content context to use for the scene.
"""
scene = self.scene
description = await Prompt.request(
"creator.scenario-description",
self.client,
@@ -40,35 +39,32 @@ class ScenarioCreatorMixin:
"content_context": content_context,
"max_tokens": self.client.max_token_length,
"scene": scene,
}
},
)
description = description.strip()
return description
@set_processing
async def create_scene_name(
self,
prompt:str,
content_context:str,
description:str,
prompt: str,
content_context: str,
description: str,
):
"""
Generates a scene name.
Arguments:
prompt (str): The prompt to use to generate the scene name.
content_context (str): The content context to use for the scene.
description (str): The description of the scene.
"""
scene = self.scene
name = await Prompt.request(
"creator.scenario-name",
self.client,
@@ -78,37 +74,35 @@ class ScenarioCreatorMixin:
"content_context": content_context,
"description": description,
"scene": scene,
}
},
)
name = name.strip().strip('.!').replace('"','')
name = name.strip().strip(".!").replace('"', "")
return name
@set_processing
async def create_scene_intro(
self,
prompt:str,
content_context:str,
description:str,
name:str,
prompt: str,
content_context: str,
description: str,
name: str,
):
"""
Generates a scene introduction.
Arguments:
prompt (str): The prompt to use to generate the scene introduction.
content_context (str): The content context to use for the scene.
description (str): The description of the scene.
name (str): The name of the scene.
"""
scene = self.scene
intro = await Prompt.request(
"creator.scenario-intro",
self.client,
@@ -119,17 +113,34 @@ class ScenarioCreatorMixin:
"description": description,
"name": name,
"scene": scene,
}
},
)
intro = intro.strip()
return intro
@set_processing
async def determine_scenario_description(
async def determine_scenario_description(self, text: str):
description = await Prompt.request(
f"creator.determine-scenario-description",
self.client,
"analyze_long",
vars={
"text": text,
},
)
return description.strip()
@set_processing
async def determine_content_context_for_description(
self,
text:str
description: str,
):
description = await Prompt.request(f"creator.determine-scenario-description", self.client, "analyze_long", vars={
"text": text,
})
return description
content_context = await Prompt.request(
f"creator.determine-content-context",
self.client,
"create_short",
vars={
"description": description,
},
)
return content_context.lstrip().split("\n")[0].strip('"').strip()
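The final line reduces a potentially chatty model response to a single clean label: only the first line is kept, and surrounding quotes and whitespace are stripped. A quick sketch with a made-up response:

raw = '"a slow-burn mystery"\nAdditional commentary from the model.'
cleaned = raw.lstrip().split("\n")[0].strip('"').strip()
assert cleaned == "a slow-burn mystery"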

View File

@@ -0,0 +1,34 @@
import importlib
import os
import structlog
log = structlog.get_logger("talemate.agents.custom")
# import every submodule in this directory
#
# each subdirectory of this directory is a submodule
# get the current directory
current_directory = os.path.dirname(__file__)
# get all subdirectories
subdirectories = [
os.path.join(current_directory, name)
for name in os.listdir(current_directory)
if os.path.isdir(os.path.join(current_directory, name))
]
# import every submodule
for subdirectory in subdirectories:
# get the name of the submodule
submodule_name = os.path.basename(subdirectory)
if submodule_name.startswith("__"):
continue
log.info("activating custom agent", module=submodule_name)
# import the submodule
importlib.import_module(f".{submodule_name}", __package__)

View File

@@ -0,0 +1,5 @@
Each agent should be in its own subdirectory.
The subdirectory itself must be a valid Python module.
Check out docs/dev/agents/example/test for a very simple custom agent example, sketched below.
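Based on the loader above and the Agent/register API used elsewhere in this changeset, a minimal custom agent package might look like the following sketch. The module and class names are made up for illustration:

# agents/custom/my_agent/__init__.py (relative to the custom agents directory above)
from talemate.agents.base import Agent
from talemate.agents.registry import register

@register()
class MyCustomAgent(Agent):
    agent_type = "my_custom"
    verbose_name = "My Custom Agent"

    def __init__(self, client, **kwargs):
        self.client = client
        self.is_enabled = True
        self.actions = {}

The loader imports each subdirectory at startup, and the @register() decorator exposes the agent under its agent_type.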

View File

@@ -1,227 +1,361 @@
from __future__ import annotations
import asyncio
import re
import random
import structlog
import re
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import talemate.util as util
from talemate.emit import wait_for_input, emit
import talemate.emit.async_signals
from talemate.prompts import Prompt
from talemate.scene_message import NarratorMessage, DirectorMessage
from talemate.automated_action import AutomatedAction
import structlog
import talemate.automated_action as automated_action
from talemate.agents.conversation import ConversationAgentEmission
from .registry import register
from .base import set_processing, AgentAction, AgentActionConfig, Agent
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.agents.conversation import ConversationAgentEmission
from talemate.automated_action import AutomatedAction
from talemate.emit import emit, wait_for_input
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
from talemate.game.engine import GameInstructionsMixin
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, NarratorMessage
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
if TYPE_CHECKING:
from talemate import Actor, Character, Player, Scene
log = structlog.get_logger("talemate.agent.director")
@register()
class DirectorAgent(Agent):
class DirectorAgent(GameInstructionsMixin, Agent):
agent_type = "director"
verbose_name = "Director"
def __init__(self, client, **kwargs):
self.is_enabled = True
self.client = client
self.next_direct_character = {}
self.next_direct_scene = 0
self.actions = {
"direct": AgentAction(enabled=True, label="Direct", description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).", config={
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before directing the sceen", value=5, min=1, max=100, step=1),
"direct_scene": AgentActionConfig(type="bool", label="Direct Scene", description="If enabled, the scene will be directed through narration", value=True),
"direct_actors": AgentActionConfig(type="bool", label="Direct Actors", description="If enabled, direction will be given to actors based on their goals.", value=True),
}),
"direct": AgentAction(
enabled=True,
label="Direct",
description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).",
config={
"turns": AgentActionConfig(
type="number",
label="Turns",
description="Number of turns to wait before directing the sceen",
value=5,
min=1,
max=100,
step=1,
),
"direct_scene": AgentActionConfig(
type="bool",
label="Direct Scene",
description="If enabled, the scene will be directed through narration",
value=True,
),
"direct_actors": AgentActionConfig(
type="bool",
label="Direct Actors",
description="If enabled, direction will be given to actors based on their goals.",
value=True,
),
"actor_direction_mode": AgentActionConfig(
type="text",
label="Actor Direction Mode",
description="The mode to use when directing actors",
value="direction",
choices=[
{
"label": "Direction",
"value": "direction",
},
{
"label": "Inner Monologue",
"value": "internal_monologue",
},
],
),
},
),
}
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return True
@property
def direct_enabled(self):
return self.actions["direct"].enabled
@property
def direct_actors_enabled(self):
return self.actions["direct"].config["direct_actors"].value
@property
def direct_scene_enabled(self):
return self.actions["direct"].config["direct_scene"].value
@property
def actor_direction_mode(self):
return self.actions["direct"].config["actor_direction_mode"].value
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(self.on_conversation_before_generate)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_player_dialog)
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(
self.on_conversation_before_generate
)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
self.on_player_dialog
)
talemate.emit.async_signals.get("scene_init").connect(self.on_scene_init)
async def on_scene_init(self, event: SceneStateEvent):
"""
If the game state instructions are set to run at the start of the game loop,
we run them here.
"""
if not self.enabled:
if self.scene.game_state.has_scene_instructions:
self.is_enabled = True
log.warning("on_scene_init - enabling director", scene=self.scene)
if await self.scene_has_instructions(self.scene):
self.is_enabled = True
log.warning("on_scene_init - enabling director", scene=self.scene)
else:
return
if not self.scene.game_state.has_scene_instructions:
if not await self.scene_has_instructions(self.scene):
return
if not self.scene.game_state.ops.run_on_start:
return
log.info("on_game_loop_start - running game state instructions")
await self.run_gamestate_instructions()
async def on_conversation_before_generate(self, event:ConversationAgentEmission):
async def on_conversation_before_generate(self, event: ConversationAgentEmission):
log.info("on_conversation_before_generate", director_enabled=self.enabled)
if not self.enabled:
return
await self.direct(event.character)
async def on_player_dialog(self, event:GameLoopActorIterEvent):
async def on_player_dialog(self, event: GameLoopActorIterEvent):
if not self.enabled:
return
if not self.scene.game_state.has_scene_instructions:
if not await self.scene_has_instructions(self.scene):
return
if not event.actor.character.is_player:
return
if event.game_loop.had_passive_narration:
log.debug("director.on_player_dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
log.debug(
"director.on_player_dialog",
skip=True,
had_passive_narration=event.game_loop.had_passive_narration,
)
return
event.game_loop.had_passive_narration = await self.direct(None)
async def direct(self, character: Character) -> bool:
if not self.actions["direct"].enabled:
return False
if character:
if not self.actions["direct"].config["direct_actors"].value:
log.info("direct", skip=True, reason="direct_actors disabled", character=character)
log.info(
"direct",
skip=True,
reason="direct_actors disabled",
character=character,
)
return False
# character direction, see if there are character goals
# character direction, see if there are character goals
# defined
character_goals = character.get_detail("goals")
if not character_goals:
log.info("direct", skip=True, reason="no goals", character=character)
return False
next_direct = self.next_direct_character.get(character.name, 0)
if next_direct % self.actions["direct"].config["turns"].value != 0 or next_direct == 0:
log.info("direct", skip=True, next_direct=next_direct, character=character)
if (
next_direct % self.actions["direct"].config["turns"].value != 0
or next_direct == 0
):
log.info(
"direct", skip=True, next_direct=next_direct, character=character
)
self.next_direct_character[character.name] = next_direct + 1
return False
self.next_direct_character[character.name] = 0
await self.direct_scene(character, character_goals)
return True
else:
if not self.actions["direct"].config["direct_scene"].value:
log.info("direct", skip=True, reason="direct_scene disabled")
return False
# no character, see if there are NPC characters at all
# if not we always want to direct narration
always_direct = (not self.scene.npc_character_names)
always_direct = (
not self.scene.npc_character_names
or self.scene.game_state.ops.always_direct
)
next_direct = self.next_direct_scene
if next_direct % self.actions["direct"].config["turns"].value != 0 or next_direct == 0:
if (
next_direct % self.actions["direct"].config["turns"].value != 0
or next_direct == 0
):
if not always_direct:
log.info("direct", skip=True, next_direct=next_direct)
self.next_direct_scene += 1
return False
self.next_direct_scene = 0
await self.direct_scene(None, None)
return True
@set_processing
async def run_gamestate_instructions(self):
"""
Run game state instructions, if they exist.
"""
if not self.scene.game_state.has_scene_instructions:
if not await self.scene_has_instructions(self.scene):
return
await self.direct_scene(None, None)
@set_processing
async def direct_scene(self, character: Character, prompt:str):
async def direct_scene(self, character: Character, prompt: str):
if not character and self.scene.game_state.game_won:
# we are not directing a character, and the game has been won
# so we don't need to direct the scene any further
return
if character:
# direct a character
response = await Prompt.request("director.direct-character", self.client, "director", vars={
"max_tokens": self.client.max_token_length,
"scene": self.scene,
"prompt": prompt,
"character": character,
"player_character": self.scene.get_player_character(),
"game_state": self.scene.game_state,
})
response = await Prompt.request(
"director.direct-character",
self.client,
"director",
vars={
"max_tokens": self.client.max_token_length,
"scene": self.scene,
"prompt": prompt,
"character": character,
"player_character": self.scene.get_player_character(),
"game_state": self.scene.game_state,
},
)
if "#" in response:
response = response.split("#")[0]
log.info("direct_character", character=character, prompt=prompt, response=response)
log.info(
"direct_character",
character=character,
prompt=prompt,
response=response,
)
response = response.strip().split("\n")[0].strip()
#response += f" (current story goal: {prompt})"
# response += f" (current story goal: {prompt})"
message = DirectorMessage(response, source=character.name)
emit("director", message, character=character)
self.scene.push_history(message)
else:
# run scene instructions
self.scene.game_state.scene_instructions
await self.run_scene_instructions(self.scene)
@set_processing
async def persist_characters_from_worldstate(
self, exclude: list[str] = None
) -> List[Character]:
log.warning(
"persist_characters_from_worldstate",
world_state_characters=self.scene.world_state.characters,
scene_characters=self.scene.character_names,
)
created_characters = []
for character_name in self.scene.world_state.characters.keys():
if exclude and character_name.lower() in exclude:
continue
if character_name in self.scene.character_names:
continue
character = await self.persist_character(name=character_name)
created_characters.append(character)
self.scene.emit_status()
return created_characters
@set_processing
async def persist_character(
self,
name:str,
content:str = None,
attributes:str = None,
self,
name: str,
content: str = None,
attributes: str = None,
determine_name: bool = True,
):
world_state = instance.get_agent("world_state")
creator = instance.get_agent("creator")
self.scene.log.debug("persist_character", name=name)
if determine_name:
name = await creator.determine_character_name(name)
self.scene.log.debug("persist_character", adjusted_name=name)
character = self.scene.Character(name=name)
character.color = random.choice(['#F08080', '#FFD700', '#90EE90', '#ADD8E6', '#DDA0DD', '#FFB6C1', '#FAFAD2', '#D3D3D3', '#B0E0E6', '#FFDEAD'])
character.color = random.choice(
[
"#F08080",
"#FFD700",
"#90EE90",
"#ADD8E6",
"#DDA0DD",
"#FFB6C1",
"#FAFAD2",
"#D3D3D3",
"#B0E0E6",
"#FFDEAD",
]
)
if not attributes:
attributes = await world_state.extract_character_sheet(name=name, text=content)
attributes = await world_state.extract_character_sheet(
name=name, text=content
)
else:
attributes = world_state._parse_character_sheet(attributes)
self.scene.log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
@@ -232,35 +366,71 @@ class DirectorAgent(Agent):
self.scene.log.debug("persist_character", description=description)
actor = self.scene.Actor(character=character, agent=instance.get_agent("conversation"))
dialogue_instructions = await creator.determine_character_dialogue_instructions(
character
)
character.dialogue_instructions = dialogue_instructions
self.scene.log.debug(
"persist_character", dialogue_instructions=dialogue_instructions
)
actor = self.scene.Actor(
character=character, agent=instance.get_agent("conversation")
)
await self.scene.add_actor(actor)
self.scene.emit_status()
return character
@set_processing
async def update_content_context(self, content:str=None, extra_choices:list[str]=None):
async def update_content_context(
self, content: str = None, extra_choices: list[str] = None
):
if not content:
content = "\n".join(self.scene.context_history(sections=False, min_dialogue=25, budget=2048))
response = await Prompt.request("world_state.determine-content-context", self.client, "analyze_freeform", vars={
"content": content,
"extra_choices": extra_choices or [],
})
content = "\n".join(
self.scene.context_history(sections=False, min_dialogue=25, budget=2048)
)
response = await Prompt.request(
"world_state.determine-content-context",
self.client,
"analyze_freeform",
vars={
"content": content,
"extra_choices": extra_choices or [],
},
)
self.scene.context = response.strip()
self.scene.emit_status()
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
async def log_action(self, action: str, action_description: str):
message = DirectorMessage(message=action_description, action=action)
self.scene.push_history(message)
emit("director", message)
log_action.exposed = True
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
log.debug(
"inject_prompt_paramters",
prompt_param=prompt_param,
kind=kind,
agent_function_name=agent_function_name,
)
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names + ["#"]
if agent_function_name == "update_content_context":
prompt_param["extra_stopping_strings"] += ["\n"]
def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
return True
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
return True
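The cadence gating in direct() is easy to misread: direction fires only when the per-character (or per-scene) counter is a positive multiple of the configured turn count; otherwise the counter is incremented and the call skipped. A standalone sketch of that predicate (the function name is illustrative):

def should_direct(next_direct: int, turns: int = 5) -> bool:
    # mirrors the check in direct(): skip when the counter is zero
    # or not an exact multiple of the configured turn count
    return next_direct != 0 and next_direct % turns == 0

assert [should_direct(n) for n in range(7)] == [False, False, False, False, False, True, False]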

View File

@@ -1,30 +1,30 @@
from __future__ import annotations
import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import talemate.data_objects as data_objects
import talemate.util as util
import talemate.emit.async_signals
import talemate.util as util
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
import structlog
import time
import re
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Character, Scene
from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.tale_mate import Actor, Character, Scene
log = structlog.get_logger("talemate.agents.editor")
@register()
class EditorAgent(Agent):
"""
@@ -35,175 +35,277 @@ class EditorAgent(Agent):
agent_type = "editor"
verbose_name = "Editor"
def __init__(self, client, **kwargs):
self.client = client
self.is_enabled = True
self.actions = {
"edit_dialogue": AgentAction(enabled=False, label="Edit dialogue", description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue."),
"fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.", config={
"narrator": AgentActionConfig(type="bool", label="Fix narrator messages", description="Will attempt to fix exposition issues in narrator messages", value=True),
}),
"add_detail": AgentAction(enabled=False, label="Add detail", description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.")
"fix_exposition": AgentAction(
enabled=True,
label="Fix exposition",
description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.",
config={
"narrator": AgentActionConfig(
type="bool",
label="Fix narrator messages",
description="Will attempt to fix exposition issues in narrator messages",
value=True,
),
},
),
"add_detail": AgentAction(
enabled=False,
label="Add detail",
description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
),
"check_continuity_errors": AgentAction(
enabled=False,
label="Check continuity errors",
description="Will attempt to fix continuity errors in the dialogue. Runs automatically after each AI dialogue. (super experimental)",
),
}
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return True
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
talemate.emit.async_signals.get("agent.narrator.generated").connect(self.on_narrator_generated)
async def on_conversation_generated(self, emission:ConversationAgentEmission):
talemate.emit.async_signals.get("agent.conversation.generated").connect(
self.on_conversation_generated
)
talemate.emit.async_signals.get("agent.narrator.generated").connect(
self.on_narrator_generated
)
async def on_conversation_generated(self, emission: ConversationAgentEmission):
"""
Called when a conversation is generated
"""
if not self.enabled:
return
log.info("editing conversation", emission=emission)
edited = []
for text in emission.generation:
edit = await self.add_detail(text, emission.character)
edit = await self.fix_exposition(edit, emission.character)
edit = await self.check_continuity_errors(edit, emission.character)
edit = await self.add_detail(
text,
emission.character
)
edit = await self.edit_conversation(
edit,
emission.character
)
edit = await self.fix_exposition(
edit,
emission.character
)
edited.append(edit)
emission.generation = edited
async def on_narrator_generated(self, emission:NarratorAgentEmission):
async def on_narrator_generated(self, emission: NarratorAgentEmission):
"""
Called when a narrator message is generated
"""
if not self.enabled:
return
log.info("editing narrator", emission=emission)
edited = []
for text in emission.generation:
edit = await self.fix_exposition_on_narrator(text)
edited.append(edit)
emission.generation = edited
@set_processing
async def edit_conversation(self, content:str, character:Character):
"""
Edits a conversation
"""
if not self.actions["edit_dialogue"].enabled:
return content
response = await Prompt.request("editor.edit-dialogue", self.client, "edit_dialogue", vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length
})
response = response.split("[end]")[0]
response = util.replace_exposition_markers(response)
response = util.clean_dialogue(response, main_name=character.name)
response = util.strip_partial_sentences(response)
return response
@set_processing
async def fix_exposition(self, content:str, character:Character):
async def fix_exposition(self, content: str, character: Character):
"""
Edits a text to make sure all narrative exposition and emotes are encased in *
"""
if not self.actions["fix_exposition"].enabled:
return content
if not character.is_player:
if '"' not in content and '*' not in content:
if '"' not in content and "*" not in content:
content = util.strip_partial_sentences(content)
character_prefix = f"{character.name}: "
message = content.split(character_prefix)[1]
content = f"{character_prefix}*{message.strip('*')}*"
content = f'{character_prefix}"{message.strip()}"'
return content
elif '"' in content:
# silly hack to clean up some LLMs that always start with a quote
# even though the immediate next thing is a narration (indicated by *)
content = content.replace(f"{character.name}: \"*", f"{character.name}: *")
content = util.clean_dialogue(content, main_name=character.name)
content = content.replace(
f'{character.name}: "*', f"{character.name}: *"
)
content = util.clean_dialogue(content, main_name=character.name)
content = util.strip_partial_sentences(content)
content = util.ensure_dialog_format(content, talking_character=character.name)
return content
@set_processing
async def fix_exposition_on_narrator(self, content:str):
async def fix_exposition_on_narrator(self, content: str):
if not self.actions["fix_exposition"].enabled:
return content
if not self.actions["fix_exposition"].config["narrator"].value:
return content
content = util.strip_partial_sentences(content)
if '"' not in content:
content = f"*{content.strip('*')}*"
else:
content = util.ensure_dialog_format(content)
return content
@set_processing
async def add_detail(self, content:str, character:Character):
async def add_detail(self, content: str, character: Character):
"""
Edits a text to increase its length and add extra detail and exposition
"""
if not self.actions["add_detail"].enabled:
return content
response = await Prompt.request("editor.add-detail", self.client, "edit_add_detail", vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length
})
response = await Prompt.request(
"editor.add-detail",
self.client,
"edit_add_detail",
vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length,
},
)
response = util.replace_exposition_markers(response)
response = util.clean_dialogue(response, main_name=character.name)
response = util.clean_dialogue(response, main_name=character.name)
response = util.strip_partial_sentences(response)
return response
return response
@set_processing
async def check_continuity_errors(
self,
content: str,
character: Character,
force: bool = False,
fix: bool = True,
message_id: int = None,
) -> str:
"""
Edits a text to ensure that it is consistent with the scene
so far
"""
if not self.actions["check_continuity_errors"].enabled and not force:
return content
MAX_CONTENT_LENGTH = 255
count = util.count_tokens(content)
if count > MAX_CONTENT_LENGTH:
log.warning(
"check_continuity_errors content too long",
length=count,
max=MAX_CONTENT_LENGTH,
content=content[:255],
)
return content
log.debug(
"check_continuity_errors START",
content=content,
character=character,
force=force,
fix=fix,
message_id=message_id,
)
response = await Prompt.request(
"editor.check-continuity-errors",
self.client,
"basic_analytical_medium2",
vars={
"content": content,
"character": character,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"message_id": message_id,
},
)
# loop through response line by line, checking for lines beginning
# with "ERROR {number}:
errors = []
for line in response.split("\n"):
if "ERROR" not in line:
continue
errors.append(line)
if not errors:
log.debug("check_continuity_errors NO ERRORS")
return content
log.debug("check_continuity_errors ERRORS", fix=fix, errors=errors)
if not fix:
return content
state = {}
response = await Prompt.request(
"editor.fix-continuity-errors",
self.client,
"editor_creative_medium2",
vars={
"content": content,
"character": character,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"errors": errors,
"set_state": lambda k, v: state.update({k: v}),
},
)
content_fix_identifier = state.get("content_fix_identifier")
try:
content = response.strip().strip("```").split("```")[0].strip()
content = content.replace(content_fix_identifier, "").strip()
content = content.strip(":")
# if content doesn't start with {character_name}: then add it
if not content.startswith(f"{character.name}:"):
content = f"{character.name}: {content}"
except Exception as e:
log.error(
"check_continuity_errors FAILED",
content_fix_identifier=content_fix_identifier,
response=response,
e=e,
)
return content
log.debug("check_continuity_errors FIXED", content=content)
return content

View File

@@ -1,19 +1,21 @@
from __future__ import annotations
import asyncio
import functools
import os
import shutil
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
from chromadb.config import Settings
import talemate.events as events
import talemate.util as util
from talemate.agents.base import set_processing
from talemate.config import load_config
from talemate.context import scene_is_loading
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.context import scene_is_loading
from talemate.config import load_config
from talemate.agents.base import set_processing
import structlog
import shutil
import functools
try:
import chromadb
@@ -28,19 +30,20 @@ if not chromadb:
log.info("ChromaDB not found, disabling Chroma agent")
from .base import Agent
from .base import Agent, AgentDetail
class MemoryDocument(str):
def __new__(cls, text, meta, id, raw):
inst = super().__new__(cls, text)
inst.meta = meta
inst.id = id
inst.raw = raw
return inst
class MemoryAgent(Agent):
"""
An agent that can be used to maintain and access a memory of the world
@@ -52,10 +55,11 @@ class MemoryAgent(Agent):
@property
def readonly(self):
if scene_is_loading.get() and not getattr(self.scene, "_memory_never_persisted", False):
if scene_is_loading.get() and not getattr(
self.scene, "_memory_never_persisted", False
):
return True
return False
@property
@@ -72,9 +76,9 @@ class MemoryAgent(Agent):
self.memory_tracker = {}
self.config = load_config()
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
def on_config_saved(self, event):
openai_key = self.openai_api_key
self.config = load_config()
@@ -92,35 +96,68 @@ class MemoryAgent(Agent):
raise NotImplementedError()
@set_processing
async def add(self, text, character=None, uid=None, ts:str=None, **kwargs):
async def add(self, text, character=None, uid=None, ts: str = None, **kwargs):
if not text:
return
if self.readonly:
log.debug("memory agent", status="readonly")
return
while not self._ready_to_add:
await asyncio.sleep(0.1)
log.debug("memory agent add", text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
log.debug(
"memory agent add",
text=text[:50],
character=character,
uid=uid,
ts=ts,
**kwargs,
)
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs))
await loop.run_in_executor(
None,
functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs),
)
except AttributeError as e:
# not sure how this sometimes happens.
# chromadb model None
# race condition because we are forcing async context onto it?
log.error("memory agent", error="failed to add memory", details=e, text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
log.error(
"memory agent",
error="failed to add memory",
details=e,
text=text[:50],
character=character,
uid=uid,
ts=ts,
**kwargs,
)
await asyncio.sleep(1.0)
try:
await loop.run_in_executor(None, functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs))
await loop.run_in_executor(
None,
functools.partial(
self._add, text, character, uid=uid, ts=ts, **kwargs
),
)
except Exception as e:
log.error("memory agent", error="failed to add memory (retried)", details=e, text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
log.error(
"memory agent",
error="failed to add memory (retried)",
details=e,
text=text[:50],
character=character,
uid=uid,
ts=ts,
**kwargs,
)
def _add(self, text, character=None, ts:str=None, **kwargs):
def _add(self, text, character=None, ts: str = None, **kwargs):
raise NotImplementedError()
@set_processing
@@ -131,44 +168,46 @@ class MemoryAgent(Agent):
while not self._ready_to_add:
await asyncio.sleep(0.1)
log.debug("memory agent add many", len=len(objects))
log.debug("memory agent add many", len=len(objects))
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._add_many, objects)
def _add_many(self, objects: list[dict]):
"""
Add multiple objects to the memory
"""
raise NotImplementedError()
def _delete(self, meta:dict):
def _delete(self, meta: dict):
"""
Delete an object from the memory
"""
raise NotImplementedError()
@set_processing
async def delete(self, meta:dict):
async def delete(self, meta: dict):
"""
Delete an object from the memory
"""
if self.readonly:
log.debug("memory agent", status="readonly")
return
while not self._ready_to_add:
await asyncio.sleep(0.1)
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._delete, meta)
@set_processing
async def get(self, text, character=None, **query):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, functools.partial(self._get, text, character, **query))
return await loop.run_in_executor(
None, functools.partial(self._get, text, character, **query)
)
def _get(self, text, character=None, **query):
raise NotImplementedError()
@@ -177,12 +216,14 @@ class MemoryAgent(Agent):
async def get_document(self, id):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self._get_document, id)
def _get_document(self, id):
raise NotImplementedError()
def on_archive_add(self, event: events.ArchiveEvent):
asyncio.ensure_future(self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history"))
asyncio.ensure_future(
self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
)
def on_character_state(self, event: events.CharacterStateEvent):
asyncio.ensure_future(
@@ -222,10 +263,10 @@ class MemoryAgent(Agent):
"""
memory_context = []
if not query:
return memory_context
for memory in await self.get(query):
if memory in memory_context:
continue
@@ -239,17 +280,26 @@ class MemoryAgent(Agent):
break
return memory_context
async def query(self, query:str, max_tokens:int=1000, filter:Callable=lambda x:True, **where):
async def query(
self,
query: str,
max_tokens: int = 1000,
filter: Callable = lambda x: True,
**where,
):
"""
Get the character memory context for a given character
"""
try:
return (await self.multi_query([query], max_tokens=max_tokens, filter=filter, **where))[0]
return (
await self.multi_query(
[query], max_tokens=max_tokens, filter=filter, **where
)
)[0]
except IndexError:
return None
async def multi_query(
self,
queries: list[str],
@@ -258,7 +308,7 @@ class MemoryAgent(Agent):
filter: Callable = lambda x: True,
formatter: Callable = lambda x: x,
limit: int = 10,
**where
**where,
):
"""
Get the character memory context for a given character
@@ -266,10 +316,9 @@ class MemoryAgent(Agent):
memory_context = []
for query in queries:
if not query:
continue
i = 0
for memory in await self.get(formatter(query), limit=limit, **where):
if memory in memory_context:
@@ -296,15 +345,13 @@ from .registry import register
@register(condition=lambda: chromadb is not None)
class ChromaDBMemoryAgent(MemoryAgent):
requires_llm_client = False
@property
def ready(self):
if self.embeddings == "openai" and not self.openai_api_key:
return False
if getattr(self, "db_client", None):
return True
return False
@@ -313,80 +360,110 @@ class ChromaDBMemoryAgent(MemoryAgent):
def status(self):
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
if self.embeddings == "openai" and not self.openai_api_key:
return "error"
return "waiting"
@property
def agent_details(self):
details = {
"backend": AgentDetail(
icon="mdi-server-outline",
value="ChromaDB",
description="The backend to use for long-term memory",
).model_dump(),
"embeddings": AgentDetail(
icon="mdi-cube-unfolded",
value=self.embeddings,
description="The embeddings model.",
).model_dump(),
}
if self.embeddings == "openai" and not self.openai_api_key:
return "No OpenAI API key set"
return f"ChromaDB: {self.embeddings}"
# return "No OpenAI API key set"
details["error"] = {
"icon": "mdi-alert",
"value": "No OpenAI API key set",
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
"color": "error",
}
return details
@property
def embeddings(self):
"""
Returns which embeddings to use.
Reads the TM_CHROMADB_EMBEDDINGS env variable and defaults to 'default',
using the default embeddings specified by chromadb.
other values are
- openai: use openai embeddings
- instructor: use instructor embeddings
for `openai`:
you will also need to provide an `OPENAI_API_KEY` env variable
for `instructor`:
you will also need to provide which instructor model to use with the `TM_INSTRUCTOR_MODEL` env variable, which defaults to hkunlp/instructor-xl
additionally you can provide the `TM_INSTRUCTOR_DEVICE` env variable to specify which device to use, which defaults to cpu
"""
embeddings = self.config.get("chromadb").get("embeddings")
assert embeddings in ["default", "openai", "instructor"], f"Unknown embeddings {embeddings}"
assert embeddings in [
"default",
"openai",
"instructor",
], f"Unknown embeddings {embeddings}"
return embeddings
@property
def USE_OPENAI(self):
return self.embeddings == "openai"
@property
def USE_INSTRUCTOR(self):
return self.embeddings == "instructor"
@property
def db_name(self):
return getattr(self, "collection_name", "<unnamed>")
@property
def openai_api_key(self):
return self.config.get("openai",{}).get("api_key")
return self.config.get("openai", {}).get("api_key")
def make_collection_name(self, scene):
if self.USE_OPENAI:
suffix = "-openai"
model_name = self.config.get("chromadb").get(
"openai_model", "text-embedding-3-small"
)
if model_name == "text-embedding-ada-002":
suffix = "-openai"
else:
suffix = f"-openai-{model_name}"
elif self.USE_INSTRUCTOR:
suffix = "-instructor"
model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
model = self.config.get("chromadb").get(
"instructor_model", "hkunlp/instructor-xl"
)
if "xl" in model:
suffix += "-xl"
elif "large" in model:
suffix += "-large"
else:
suffix = ""
return f"{scene.memory_id}-tm{suffix}"
async def count(self):
@@ -399,9 +476,8 @@ class ChromaDBMemoryAgent(MemoryAgent):
await loop.run_in_executor(None, self._set_db)
def _set_db(self):
self._ready_to_add = False
if not getattr(self, "db_client", None):
log.info("chromadb agent", status="setting up db client to persistent db")
self.db_client = chromadb.PersistentClient(
@@ -409,49 +485,67 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
openai_key = self.openai_api_key
self.collection_name = collection_name = self.make_collection_name(self.scene)
log.info("chromadb agent", status="setting up db", collection_name=collection_name)
log.info(
"chromadb agent", status="setting up db", collection_name=collection_name
)
if self.USE_OPENAI:
if not openai_key:
raise ValueError("You must provide an the openai ai key in the config if you want to use it for chromadb embeddings")
raise ValueError(
"You must provide an the openai ai key in the config if you want to use it for chromadb embeddings"
)
model_name = self.config.get("chromadb").get(
"openai_model", "text-embedding-3-small"
)
log.info(
"crhomadb", status="using openai", openai_key=openai_key[:5] + "..."
"crhomadb",
status="using openai",
openai_key=openai_key[:5] + "...",
model=model_name,
)
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key = openai_key,
model_name="text-embedding-ada-002",
api_key=openai_key,
model_name=model_name,
)
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=openai_ef
)
elif self.USE_INSTRUCTOR:
instructor_device = self.config.get("chromadb").get("instructor_device", "cpu")
instructor_model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
log.info("chromadb", status="using instructor", model=instructor_model, device=instructor_device)
instructor_device = self.config.get("chromadb").get(
"instructor_device", "cpu"
)
instructor_model = self.config.get("chromadb").get(
"instructor_model", "hkunlp/instructor-xl"
)
log.info(
"chromadb",
status="using instructor",
model=instructor_model,
device=instructor_device,
)
# ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
ef = embedding_functions.InstructorEmbeddingFunction(
model_name=instructor_model, device=instructor_device
)
log.info("chromadb", status="embedding function ready")
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef
)
log.info("chromadb", status="instructor db ready")
else:
log.info("chromadb", status="using default embeddings")
self.db = self.db_client.get_or_create_collection(collection_name)
self.scene._memory_never_persisted = self.db.count() == 0
log.info("chromadb agent", status="db ready")
self._ready_to_add = True
@@ -459,17 +553,21 @@ class ChromaDBMemoryAgent(MemoryAgent):
def clear_db(self):
if not self.db:
return
log.info("chromadb agent", status="clearing db", collection_name=self.collection_name)
log.info(
"chromadb agent", status="clearing db", collection_name=self.collection_name
)
self.db.delete(where={"source": "talemate"})
def drop_db(self):
if not self.db:
return
log.info("chromadb agent", status="dropping db", collection_name=self.collection_name)
log.info(
"chromadb agent", status="dropping db", collection_name=self.collection_name
)
try:
self.db_client.delete_collection(self.collection_name)
except ValueError as exc:
@@ -479,31 +577,43 @@ class ChromaDBMemoryAgent(MemoryAgent):
def close_db(self, scene):
if not self.db:
return
log.info("chromadb agent", status="closing db", collection_name=self.collection_name)
log.info(
"chromadb agent", status="closing db", collection_name=self.collection_name
)
if not scene.saved and not scene.saved_memory_session_id:
# scene was never saved so we can discard the memory
collection_name = self.make_collection_name(scene)
log.info("chromadb agent", status="discarding memory", collection_name=collection_name)
log.info(
"chromadb agent",
status="discarding memory",
collection_name=collection_name,
)
try:
self.db_client.delete_collection(collection_name)
except ValueError as exc:
log.error("chromadb agent", error="failed to delete collection", details=exc)
log.error(
"chromadb agent", error="failed to delete collection", details=exc
)
elif not scene.saved:
# scene was saved but memory was never persisted
# so we need to remove the memory from the db
self._remove_unsaved_memory()
self.db = None
def _add(self, text, character=None, uid=None, ts:str=None, **kwargs):
def _add(self, text, character=None, uid=None, ts: str = None, **kwargs):
metadatas = []
ids = []
scene = self.scene
if character:
meta = {"character": character.name, "source": "talemate", "session": scene.memory_session_id}
meta = {
"character": character.name,
"source": "talemate",
"session": scene.memory_session_id,
}
if ts:
meta["ts"] = ts
meta.update(kwargs)
@@ -513,7 +623,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
id = uid or f"{character.name}-{self.memory_tracker[character.name]}"
ids = [id]
else:
meta = {"character": "__narrator__", "source": "talemate", "session": scene.memory_session_id}
meta = {
"character": "__narrator__",
"source": "talemate",
"session": scene.memory_session_id,
}
if ts:
meta["ts"] = ts
meta.update(kwargs)
@@ -523,17 +637,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
id = uid or f"__narrator__-{self.memory_tracker['__narrator__']}"
ids = [id]
#log.debug("chromadb agent add", text=text, meta=meta, id=id)
# log.debug("chromadb agent add", text=text, meta=meta, id=id)
self.db.upsert(documents=[text], metadatas=metadatas, ids=ids)
def _add_many(self, objects: list[dict]):
documents = []
metadatas = []
ids = []
scene = self.scene
if not objects:
return
@@ -552,52 +665,50 @@ class ChromaDBMemoryAgent(MemoryAgent):
ids.append(uid)
self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)
def _delete(self, meta:dict):
def _delete(self, meta: dict):
if "ids" in meta:
log.debug("chromadb agent delete", ids=meta["ids"])
self.db.delete(ids=meta["ids"])
return
where = {"$and": [{k:v} for k,v in meta.items()]}
where = {"$and": [{k: v} for k, v in meta.items()]}
self.db.delete(where=where)
log.debug("chromadb agent delete", meta=meta, where=where)
def _get(self, text, character=None, limit:int=15, **kwargs):
def _get(self, text, character=None, limit: int = 15, **kwargs):
where = {}
# this doesn't work because chromadb currently doesn't match
# non-existent fields with $ne (or so it seems)
# where.setdefault("$and", [{"pin_only": {"$ne": True}}])
where.setdefault("$and", [])
character_filtered = False
for k,v in kwargs.items():
for k, v in kwargs.items():
if k == "character":
character_filtered = True
where["$and"].append({k: v})
if character and not character_filtered:
where["$and"].append({"character": character.name})
if len(where["$and"]) == 1:
where = where["$and"][0]
elif not where["$and"]:
where = None
#log.debug("crhomadb agent get", text=text, where=where)
# log.debug("crhomadb agent get", text=text, where=where)
_results = self.db.query(query_texts=[text], where=where, n_results=limit)
#import json
#print(json.dumps(_results["ids"], indent=2))
#print(json.dumps(_results["distances"], indent=2))
# import json
# print(json.dumps(_results["ids"], indent=2))
# print(json.dumps(_results["distances"], indent=2))
results = []
max_distance = 1.5
if self.USE_INSTRUCTOR:
max_distance = 1
@@ -606,24 +717,29 @@ class ChromaDBMemoryAgent(MemoryAgent):
for i in range(len(_results["distances"][0])):
distance = _results["distances"][0][i]
doc = _results["documents"][0][i]
meta = _results["metadatas"][0][i]
if not meta:
log.warning("chromadb agent get", error="no meta", doc=doc)
continue
ts = meta.get("ts")
# skip pin_only entries
if meta.get("pin_only", False):
continue
if distance < max_distance:
date_prefix = self.convert_ts_to_date_prefix(ts)
raw = doc
if date_prefix:
doc = f"{date_prefix}: {doc}"
doc = MemoryDocument(doc, meta, _results["ids"][0][i], raw)
results.append(doc)
else:
break
@@ -635,45 +751,55 @@ class ChromaDBMemoryAgent(MemoryAgent):
return results
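A standalone sketch of the distance gate above (documents and distances hypothetical; chromadb returns query results sorted by ascending distance, which is what makes the early break safe):
```
max_distance = 1.5  # tightened to 1.0 when instructor embeddings are active
results = []
for doc, distance in [("doc a", 0.4), ("doc b", 1.2), ("doc c", 1.9)]:
    if distance < max_distance:
        results.append(doc)
    else:
        break  # results are distance-sorted, nothing closer follows
```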
def convert_ts_to_date_prefix(self, ts):
if not ts:
return None
try:
return util.iso8601_diff_to_human(ts, self.scene.ts)
except Exception as e:
log.error("chromadb agent", error="failed to get date prefix", details=e, ts=ts, scene_ts=self.scene.ts)
log.error(
"chromadb agent",
error="failed to get date prefix",
details=e,
ts=ts,
scene_ts=self.scene.ts,
)
return None
def _get_document(self, id) -> dict:
result = self.db.get(ids=[id] if isinstance(id, str) else id)
documents = {}
for idx, doc in enumerate(result["documents"]):
date_prefix = self.convert_ts_to_date_prefix(result["metadatas"][idx].get("ts"))
date_prefix = self.convert_ts_to_date_prefix(
result["metadatas"][idx].get("ts")
)
if date_prefix:
doc = f"{date_prefix}: {doc}"
documents[result["ids"][idx]] = MemoryDocument(doc, result["metadatas"][idx], result["ids"][idx], doc)
documents[result["ids"][idx]] = MemoryDocument(
doc, result["metadatas"][idx], result["ids"][idx], doc
)
return documents
@set_processing
async def remove_unsaved_memory(self):
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._remove_unsaved_memory)
def _remove_unsaved_memory(self):
scene = self.scene
if not scene.memory_session_id:
return
if scene.saved_memory_session_id == self.scene.memory_session_id:
return
log.info("chromadb agent", status="removing unsaved memory", session_id=scene.memory_session_id)
log.info(
"chromadb agent",
status="removing unsaved memory",
session_id=scene.memory_session_id,
)
self._delete({"session": scene.memory_session_id, "source": "talemate"})

View File

@@ -1,43 +1,48 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import dataclasses
import structlog
import random
import talemate.util as util
from talemate.emit import emit
import talemate.emit.async_signals
from talemate.prompts import Prompt
from talemate.agents.base import set_processing as _set_processing, Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.world_state import TimePassageEmission
from talemate.scene_message import NarratorMessage
from talemate.events import GameLoopActorIterEvent
from functools import wraps
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import talemate.client as client
import talemate.emit.async_signals
import talemate.util as util
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.base import set_processing as _set_processing
from talemate.agents.world_state import TimePassageEmission
from talemate.emit import emit
from talemate.events import GameLoopActorIterEvent
from talemate.prompts import Prompt
from talemate.scene_message import NarratorMessage
from .registry import register
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Player, Character
from talemate.tale_mate import Actor, Character, Player
log = structlog.get_logger("talemate.agents.narrator")
@dataclasses.dataclass
class NarratorAgentEmission(AgentEmission):
generation: list[str] = dataclasses.field(default_factory=list)
talemate.emit.async_signals.register(
"agent.narrator.generated"
)
talemate.emit.async_signals.register("agent.narrator.generated")
def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
is running and then emits the result of the function as a NarratorAgentEmission
"""
@_set_processing
async def wrapper(self, *args, **kwargs):
@wraps(fn)
async def narration_wrapper(self, *args, **kwargs):
response = await fn(self, *args, **kwargs)
emission = NarratorAgentEmission(
agent=self,
@@ -45,68 +50,68 @@ def set_processing(fn):
)
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
return emission.generation[0]
wrapper.__name__ = fn.__name__
return wrapper
return narration_wrapper
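The switch to functools.wraps means decorated narration methods keep their metadata; a minimal sketch of the same pattern in isolation:
```
from functools import wraps

def sketch_decorator(fn):
    # @wraps copies __name__, __doc__, __module__ etc. from fn,
    # replacing the manual wrapper.__name__ = fn.__name__ assignment
    @wraps(fn)
    async def narration_wrapper(self, *args, **kwargs):
        return await fn(self, *args, **kwargs)

    return narration_wrapper
```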
@register()
class NarratorAgent(Agent):
"""
Handles narration of the story
"""
agent_type = "narrator"
verbose_name = "Narrator"
def __init__(
self,
client: client.TaleMateClient,
**kwargs,
):
self.client = client
# agent actions
self.actions = {
"generation_override": AgentAction(
enabled = True,
label = "Generation Override",
description = "Override generation parameters",
config = {
enabled=True,
label="Generation Override",
description="Override generation parameters",
config={
"instructions": AgentActionConfig(
type="text",
label="Instructions",
value="Never wax poetic.",
description="Extra instructions to give to the AI for narrative generation.",
),
}
},
),
"auto_break_repetition": AgentAction(
enabled = True,
label = "Auto Break Repetition",
description = "Will attempt to automatically break AI repetition.",
enabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"narrate_time_passage": AgentAction(
enabled=True,
label="Narrate Time Passage",
enabled=True,
label="Narrate Time Passage",
description="Whenever you indicate passage of time, narrate right after",
config = {
config={
"ask_for_prompt": AgentActionConfig(
type="bool",
label="Guide time narration via prompt",
label="Guide time narration via prompt",
description="Ask the user for a prompt to generate the time passage narration",
value=True,
)
}
},
),
"narrate_dialogue": AgentAction(
enabled=False,
label="Narrate after Dialogue",
enabled=False,
label="Narrate after Dialogue",
description="Narrator will get a chance to narrate after every line of dialogue",
config = {
config={
"ai_dialog": AgentActionConfig(
type="number",
label="AI Dialogue",
label="AI Dialogue",
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
value=0.0,
min=0.0,
@@ -115,7 +120,7 @@ class NarratorAgent(Agent):
),
"player_dialog": AgentActionConfig(
type="number",
label="Player Dialogue",
label="Player Dialogue",
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
value=0.1,
min=0.0,
@@ -124,34 +129,32 @@ class NarratorAgent(Agent):
),
"generate_dialogue": AgentActionConfig(
type="bool",
label="Allow Dialogue in Narration",
label="Allow Dialogue in Narration",
description="Allow the narrator to generate dialogue in narration",
value=False,
),
}
},
),
}
@property
def extra_instructions(self):
if self.actions["generation_override"].enabled:
return self.actions["generation_override"].config["instructions"].value
return ""
def clean_result(self, result):
"""
Cleans the result of a narration
"""
result = result.strip().strip(":").strip()
if "#" in result:
result = result.split("#")[0]
character_names = [c.name for c in self.scene.get_characters()]
cleaned = []
for line in result.split("\n"):
for character_name in character_names:
@@ -160,71 +163,83 @@ class NarratorAgent(Agent):
cleaned.append(line)
result = "\n".join(cleaned)
#result = util.strip_partial_sentences(result)
# result = util.strip_partial_sentences(result)
return result
def connect(self, scene):
"""
Connect to signals
"""
super().connect(scene)
talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
talemate.emit.async_signals.get("agent.world_state.time").connect(
self.on_time_passage
)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)
async def on_time_passage(self, event:TimePassageEmission):
async def on_time_passage(self, event: TimePassageEmission):
"""
Handles time passage narration, if enabled
"""
if not self.actions["narrate_time_passage"].enabled:
return
response = await self.narrate_time_passage(event.duration, event.human_duration, event.narrative)
narrator_message = NarratorMessage(response, source=f"narrate_time_passage:{event.duration};{event.narrative}")
response = await self.narrate_time_passage(
event.duration, event.human_duration, event.narrative
)
narrator_message = NarratorMessage(
response, source=f"narrate_time_passage:{event.duration};{event.narrative}"
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
async def on_dialog(self, event:GameLoopActorIterEvent):
async def on_dialog(self, event: GameLoopActorIterEvent):
"""
Handles dialogue narration, if enabled
"""
if not self.actions["narrate_dialogue"].enabled:
return
if event.game_loop.had_passive_narration:
log.debug("narrate on dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
log.debug(
"narrate on dialog",
skip=True,
had_passive_narration=event.game_loop.had_passive_narration,
)
return
narrate_on_ai_chance = self.actions["narrate_dialogue"].config["ai_dialog"].value
narrate_on_player_chance = self.actions["narrate_dialogue"].config["player_dialog"].value
narrate_on_ai_chance = (
self.actions["narrate_dialogue"].config["ai_dialog"].value
)
narrate_on_player_chance = (
self.actions["narrate_dialogue"].config["player_dialog"].value
)
narrate_on_ai = random.random() < narrate_on_ai_chance
narrate_on_player = random.random() < narrate_on_player_chance
log.debug(
"narrate on dialog",
narrate_on_ai=narrate_on_ai,
narrate_on_ai_chance=narrate_on_ai_chance,
"narrate on dialog",
narrate_on_ai=narrate_on_ai,
narrate_on_ai_chance=narrate_on_ai_chance,
narrate_on_player=narrate_on_player,
narrate_on_player_chance=narrate_on_player_chance,
)
if event.actor.character.is_player and not narrate_on_player:
return
if not event.actor.character.is_player and not narrate_on_ai:
return
response = await self.narrate_after_dialogue(event.actor.character)
narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
narrator_message = NarratorMessage(
response, source=f"narrate_dialogue:{event.actor.character.name}"
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
event.game_loop.had_passive_narration = True
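The gating above reduces to two independent dice rolls; a self-contained sketch (chance values mirror the defaults configured above):
```
import random

narrate_on_ai_chance = 0.0      # "AI Dialogue" default: never
narrate_on_player_chance = 0.1  # "Player Dialogue" default: ~1 in 10 lines

narrate_on_ai = random.random() < narrate_on_ai_chance
narrate_on_player = random.random() < narrate_on_player_chance
# narration is only attempted when the roll for the speaking side succeeds
```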
@set_processing
@@ -237,22 +252,22 @@ class NarratorAgent(Agent):
"narrator.narrate-scene",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
}
},
)
response = response.strip("*")
response = util.strip_partial_sentences(response)
response = f"*{response.strip('*')}*"
return response
@set_processing
async def progress_story(self, narrative_direction:str=None):
async def progress_story(self, narrative_direction: str = None):
"""
Progress the story, optionally following a given narrative direction
"""
@@ -260,18 +275,20 @@ class NarratorAgent(Agent):
scene = self.scene
pc = scene.get_player_character()
npcs = list(scene.get_npc_characters())
npc_names= ", ".join([npc.name for npc in npcs])
npc_names = ", ".join([npc.name for npc in npcs])
if narrative_direction is None:
narrative_direction = "Slightly move the current scene forward."
self.scene.log.info("narrative_direction", narrative_direction=narrative_direction)
self.scene.log.info(
"narrative_direction", narrative_direction=narrative_direction
)
response = await Prompt.request(
"narrator.narrate-progress",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"narrative_direction": narrative_direction,
@@ -279,7 +296,7 @@ class NarratorAgent(Agent):
"npcs": npcs,
"npc_names": npc_names,
"extra_instructions": self.extra_instructions,
}
},
)
self.scene.log.info("progress_story", response=response)
@@ -291,11 +308,13 @@ class NarratorAgent(Agent):
if response.count("*") % 2 != 0:
response = response.replace("*", "")
response = f"*{response}*"
return response
@set_processing
async def narrate_query(self, query:str, at_the_end:bool=False, as_narrative:bool=True):
async def narrate_query(
self, query: str, at_the_end: bool = False, as_narrative: bool = True
):
"""
Narrate a specific query
"""
@@ -303,21 +322,21 @@ class NarratorAgent(Agent):
"narrator.narrate-query",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"query": query,
"at_the_end": at_the_end,
"as_narrative": as_narrative,
"extra_instructions": self.extra_instructions,
}
},
)
log.info("narrate_query", response=response)
response = self.clean_result(response.strip())
log.info("narrate_query (after clean)", response=response)
if as_narrative:
response = f"*{response}*"
return response
@set_processing
@@ -330,12 +349,12 @@ class NarratorAgent(Agent):
"narrator.narrate-character",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"character": character,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
}
},
)
response = self.clean_result(response.strip())
@@ -345,54 +364,55 @@ class NarratorAgent(Agent):
@set_processing
async def augment_context(self):
"""
Takes a context history generated via scene.context_history() and augments it with additional information
by asking and answering questions with help from the long term memory.
"""
memory = self.scene.get_helper("memory").agent
questions = await Prompt.request(
"narrator.context-questions",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
}
},
)
self.scene.log.info("context_questions", questions=questions)
questions = [q for q in questions.split("\n") if q.strip()]
memory_context = await memory.multi_query(
questions, iterate=2, max_tokens=self.client.max_token_length - 1000
)
answers = await Prompt.request(
"narrator.context-answers",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"memory": memory_context,
"questions": questions,
"extra_instructions": self.extra_instructions,
}
},
)
self.scene.log.info("context_answers", answers=answers)
answers = [a for a in answers.split("\n") if a.strip()]
# return questions and answers
return list(zip(questions, answers))
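Illustration of the returned shape (question and answer strings hypothetical):
```
questions = ["Where is the key?", "Who owns the house?"]
answers = ["On the mantel.", "The widow Marsh."]
list(zip(questions, answers))
# -> [("Where is the key?", "On the mantel."),
#     ("Who owns the house?", "The widow Marsh.")]
```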
@set_processing
async def narrate_time_passage(self, duration:str, time_passed:str, narrative:str):
async def narrate_time_passage(
self, duration: str, time_passed: str, narrative: str
):
"""
Narrate a passage of time
"""
@@ -401,26 +421,25 @@ class NarratorAgent(Agent):
"narrator.narrate-time-passage",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"duration": duration,
"time_passed": time_passed,
"narrative": narrative,
"extra_instructions": self.extra_instructions,
}
},
)
log.info("narrate_time_passage", response=response)
response = self.clean_result(response.strip())
response = f"*{response}*"
return response
@set_processing
async def narrate_after_dialogue(self, character:Character):
async def narrate_after_dialogue(self, character: Character):
"""
Narrate after a line of dialogue
"""
@@ -429,22 +448,24 @@ class NarratorAgent(Agent):
"narrator.narrate-after-dialogue",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"last_line": str(self.scene.history[-1]),
"extra_instructions": self.extra_instructions,
}
},
)
log.info("narrate_after_dialogue", response=response)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
allow_dialogue = self.actions["narrate_dialogue"].config["generate_dialogue"].value
allow_dialogue = (
self.actions["narrate_dialogue"].config["generate_dialogue"].value
)
if not allow_dialogue:
response = response.split('"')[0].strip()
response = response.replace("*", "")
@@ -452,9 +473,11 @@ class NarratorAgent(Agent):
response = f"*{response}*"
return response
@set_processing
async def narrate_character_entry(self, character:Character, direction:str=None):
async def narrate_character_entry(
self, character: Character, direction: str = None
):
"""
Narrate a character entering the scene
"""
@@ -463,22 +486,22 @@ class NarratorAgent(Agent):
"narrator.narrate-character-entry",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"direction": direction,
"extra_instructions": self.extra_instructions,
}
},
)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
return response
@set_processing
async def narrate_character_exit(self, character:Character, direction:str=None):
async def narrate_character_exit(self, character: Character, direction: str = None):
"""
Narrate a character exiting the scene
"""
@@ -487,47 +510,136 @@ class NarratorAgent(Agent):
"narrator.narrate-character-exit",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"direction": direction,
"extra_instructions": self.extra_instructions,
}
},
)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
return response
@set_processing
async def paraphrase(self, narration: str):
"""
Paraphrase a narration
"""
response = await Prompt.request(
"narrator.paraphrase",
self.client,
"narrate",
vars={
"text": narration,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
},
)
log.info("paraphrase", narration=narration, response=response)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
return response
async def passthrough(self, narration: str) -> str:
"""
Pass through narration message as is
"""
narration = narration.replace("*", "")
narration = f"*{narration}*"
narration = util.ensure_dialog_format(narration)
return narration
def action_to_source(
self,
action_name: str,
parameters: dict,
) -> str:
"""
Generate a source string for a given action and parameters
The source string is used to identify the source of a NarratorMessage
and will also help regenerate the action and parameters from the source string
later on
"""
args = []
if action_name == "paraphrase":
args.append(parameters.get("narration"))
elif action_name == "narrate_character_entry":
args.append(parameters.get("character").name)
# args.append(parameters.get("direction"))
elif action_name == "narrate_character_exit":
args.append(parameters.get("character").name)
# args.append(parameters.get("direction"))
elif action_name == "narrate_character":
args.append(parameters.get("character").name)
elif action_name == "narrate_query":
args.append(parameters.get("query"))
elif action_name == "narrate_time_passage":
args.append(parameters.get("duration"))
args.append(parameters.get("time_passed"))
args.append(parameters.get("narrative"))
elif action_name == "progress_story":
args.append(parameters.get("narrative_direction"))
elif action_name == "narrate_after_dialogue":
args.append(parameters.get("character"))
arg_str = ";".join(args) if args else ""
return f"{action_name}:{arg_str}".rstrip(":")
async def action_to_narration(
self,
action_name: str,
*args,
emit_message: bool = False,
**kwargs,
):
# calls getattr(self, action_name) and returns the result as a NarratorMessage
# that is pushed to the history
fn = getattr(self, action_name)
narration = await fn(*args, **kwargs)
narrator_message = NarratorMessage(narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":"))
narration = await fn(**kwargs)
source = self.action_to_source(action_name, kwargs)
narrator_message = NarratorMessage(narration, source=source)
self.scene.push_history(narrator_message)
if emit_message:
emit("narrator", narrator_message)
return narrator_message
action_to_narration.exposed = True
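A hypothetical call site; note that since the narration function is now invoked with keyword arguments only, callers should pass parameters by name:
```
# narrate a query, push it to history, and also emit it to the client
message = await narrator.action_to_narration(
    "narrate_query",
    query="What lies beyond the door?",
    emit_message=True,
)
```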
# LLM client related methods. These are called during or after the client
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
log.debug(
"inject_prompt_paramters",
prompt_param=prompt_param,
kind=kind,
agent_function_name=agent_function_name,
)
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names
def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return True

View File

@@ -1,26 +1,26 @@
from __future__ import annotations
import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import talemate.data_objects as data_objects
import talemate.emit.async_signals
import talemate.util as util
from talemate.events import GameLoopEvent
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
from talemate.events import GameLoopEvent
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
import structlog
import time
import re
log = structlog.get_logger("talemate.agents.summarize")
@register()
class SummarizeAgent(Agent):
"""
@@ -36,7 +36,7 @@ class SummarizeAgent(Agent):
def __init__(self, client, **kwargs):
self.client = client
self.actions = {
"archive": AgentAction(
enabled=True,
@@ -61,36 +61,43 @@ class SummarizeAgent(Agent):
{"label": "Short & Concise", "value": "short"},
{"label": "Balanced", "value": "balanced"},
{"label": "Lengthy & Detailed", "value": "long"},
{"label": "Factual List", "value": "facts"},
],
),
"include_previous": AgentActionConfig(
type="number",
label="Use preceeding summaries to strengthen context",
description="Number of entries",
note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
value=3,
min=0,
max=10,
step=1,
),
}
},
)
}
@property
def threshold(self):
return self.actions["archive"].config["threshold"].value
@property
def estimated_entry_count(self):
all_tokens = sum([util.count_tokens(entry) for entry in self.scene.history])
return all_tokens // self.threshold
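Illustrative arithmetic for the property above (token counts and threshold hypothetical):
```
# hypothetical numbers
all_tokens = 4500   # total tokens across scene.history
threshold = 1500    # configured archive threshold
estimated_entry_count = all_tokens // threshold  # == 3
```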
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def on_game_loop(self, emission:GameLoopEvent):
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called when a conversation is generated
"""
await self.build_archive(self.scene)
def clean_result(self, result):
if "#" in result:
result = result.split("#")[0]
@@ -104,10 +111,10 @@ class SummarizeAgent(Agent):
@set_processing
async def build_archive(self, scene):
end = None
if not self.actions["archive"].enabled:
return
if not scene.archived_history:
start = 0
recent_entry = None
@@ -118,14 +125,16 @@ class SummarizeAgent(Agent):
# meaning we are still at the beginning of the scene
start = 0
else:
start = recent_entry.get("end", 0)+1
start = recent_entry.get("end", 0) + 1
# if there is a recent entry we also collect the 3 most recent entries
# as extra context
num_previous = self.actions["archive"].config["include_previous"].value
if recent_entry and num_previous > 0:
extra_context = "\n\n".join([entry["text"] for entry in scene.archived_history[-num_previous:]])
extra_context = "\n\n".join(
[entry["text"] for entry in scene.archived_history[-num_previous:]]
)
else:
extra_context = None
@@ -133,36 +142,44 @@ class SummarizeAgent(Agent):
dialogue_entries = []
ts = "PT0S"
time_passage_termination = False
token_threshold = self.actions["archive"].config["threshold"].value
log.debug("build_archive", start=start, recent_entry=recent_entry)
if recent_entry:
ts = recent_entry.get("ts", ts)
for i in range(start, len(scene.history)):
# we ignore the most recent entry, as the user may still choose to
# regenerate it
for i in range(start, max(start, len(scene.history) - 1)):
dialogue = scene.history[i]
#log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
if isinstance(dialogue, DirectorMessage):
if i == start:
start += 1
continue
if isinstance(dialogue, TimePassageMessage):
log.debug("build_archive", time_passage_message=dialogue)
if i == start:
ts = util.iso8601_add(ts, dialogue.ts)
log.debug("build_archive", time_passage_message=dialogue, start=start, i=i, ts=ts)
log.debug(
"build_archive",
time_passage_message=dialogue,
start=start,
i=i,
ts=ts,
)
start += 1
continue
log.debug("build_archive", time_passage_message_termination=dialogue)
time_passage_termination = True
end = i - 1
break
tokens += util.count_tokens(dialogue)
dialogue_entries.append(dialogue)
if tokens > token_threshold: #
@@ -172,39 +189,44 @@ class SummarizeAgent(Agent):
if end is None:
# nothing to archive yet
return
log.debug("build_archive", start=start, end=end, ts=ts, time_passage_termination=time_passage_termination)
log.debug(
"build_archive",
start=start,
end=end,
ts=ts,
time_passage_termination=time_passage_termination,
)
# in order to summarize coherently, we need to determine if there is a favorable
# cutoff point (e.g., the scene naturally ends or shifts meaningfully in the middle
# of the dialogue)
#
# One way to do this is to check if the last line is a TimePassageMessage, which
# indicates a scene change or a significant pause.
#
# indicates a scene change or a significant pause.
#
# If not, we can ask the AI to find a good point of
# termination.
if not time_passage_termination:
# No TimePassageMessage, so we need to ask the AI to find a good point of termination
terminating_line = await self.analyze_dialoge(dialogue_entries)
if terminating_line:
adjusted_dialogue = []
for line in dialogue_entries:
for line in dialogue_entries:
if str(line) in terminating_line:
break
adjusted_dialogue.append(line)
dialogue_entries = adjusted_dialogue
end = start + len(dialogue_entries)-1
end = start + len(dialogue_entries) - 1
if dialogue_entries:
summarized = await self.summarize(
"\n".join(map(str, dialogue_entries)), extra_context=extra_context
)
else:
# AI has likely identified the first line as a scene change, so we can't summarize
# just use the first line
@@ -218,15 +240,20 @@ class SummarizeAgent(Agent):
@set_processing
async def analyze_dialoge(self, dialogue):
response = await Prompt.request("summarizer.analyze-dialogue", self.client, "analyze_freeform", vars={
"dialogue": "\n".join(map(str, dialogue)),
"scene": self.scene,
"max_tokens": self.client.max_token_length,
})
response = await Prompt.request(
"summarizer.analyze-dialogue",
self.client,
"analyze_freeform",
vars={
"dialogue": "\n".join(map(str, dialogue)),
"scene": self.scene,
"max_tokens": self.client.max_token_length,
},
)
response = self.clean_result(response)
return response
@set_processing
async def summarize(
self,
@@ -239,33 +266,42 @@ class SummarizeAgent(Agent):
Summarize the given text
"""
response = await Prompt.request("summarizer.summarize-dialogue", self.client, "summarize", vars={
"dialogue": text,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"summarization_method": self.actions["archive"].config["method"].value if method is None else method,
"extra_context": extra_context or "",
"extra_instructions": extra_instructions or "",
})
self.scene.log.info("summarize", dialogue_length=len(text), summarized_length=len(response))
response = await Prompt.request(
"summarizer.summarize-dialogue",
self.client,
"summarize",
vars={
"dialogue": text,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"summarization_method": (
self.actions["archive"].config["method"].value
if method is None
else method
),
"extra_context": extra_context or "",
"extra_instructions": extra_instructions or "",
},
)
self.scene.log.info(
"summarize", dialogue_length=len(text), summarized_length=len(response)
)
return self.clean_result(response)
async def build_stepped_archive_for_level(self, level:int):
async def build_stepped_archive_for_level(self, level: int):
"""
WIP - not yet used
This will iterate over existing archived_history entries
and stepped_archived_history entries and summarize based on time duration
indicated between the entries.
The lowest level of summarization (based on token threshold and any time passage)
happens in build_archive. This method is for summarizing further levels based on
long time passages.
Level 0: small timestep summarize (summarizes all token summarizations when time advances +1 day)
Level 1: medium timestep summarize (summarizes all small timestep summarizations when time advances +1 week)
Level 2: large timestep summarize (summarizes all medium timestep summarizations when time advances +1 month)
@@ -273,7 +309,7 @@ class SummarizeAgent(Agent):
Level 4: massive timestep summarize (summarizes all huge timestep summarizations when time advances +10 years)
Level 5: epic timestep summarize (summarizes all massive timestep summarizations when time advances +100 years)
and so on (increasing by a factor of 10 each time)
```
@dataclass
class ArchiveEntry:
@@ -282,35 +318,34 @@ class SummarizeAgent(Agent):
end: int = None
ts: str = None
```
Like token summarization, this will use ArchiveEntry; start and end refer to the entries in the
lower level of summarization.
ts is the iso8601 timestamp of the start of the summarized period.
"""
# select the list to use for the entries
if level == 0:
entries = self.scene.archived_history
else:
entries = self.scene.stepped_archived_history[level-1]
entries = self.scene.stepped_archived_history[level - 1]
# select the list to summarize new entries to
target = self.scene.stepped_archived_history[level]
if not target:
raise ValueError(f"Invalid level {level}")
# determine the start and end of the period to summarize
if not entries:
return
# determine the time threshold for this level
# first calculate all possible thresholds in iso8601 format, starting with 1 day
thresholds = [
"P1D",
@@ -318,61 +353,65 @@ class SummarizeAgent(Agent):
"P1M",
"P1Y",
]
# TODO: auto extend?
time_threshold_in_seconds = util.iso8601_to_seconds(thresholds[level])
if not time_threshold_in_seconds:
raise ValueError(f"Invalid level {level}")
# determine the most recent summarized entry time, and then find entries
# that are newer than that in the lower list
ts = target[-1].ts if target else entries[0].ts
# determine the most recent entry at the lower level, if it's not newer or
# the difference is less than the threshold, then we don't need to summarize
recent_entry = entries[-1]
if util.iso8601_diff(recent_entry.ts, ts) < time_threshold_in_seconds:
return
log.debug("build_stepped_archive", level=level, ts=ts)
# if target is empty, start is 0
# otherwise start is the end of the last entry
start = 0 if not target else target[-1].end
# collect entries starting at start until the combined time duration
# exceeds the threshold
entries_to_summarize = []
for entry in entries[start:]:
entries_to_summarize.append(entry)
if util.iso8601_diff(entry.ts, ts) > time_threshold_in_seconds:
break
# summarize the entries
# we also collect N entries of previous summaries to use as context
num_previous = self.actions["archive"].config["include_previous"].value
if num_previous > 0:
extra_context = "\n\n".join([entry["text"] for entry in target[-num_previous:]])
extra_context = "\n\n".join(
[entry["text"] for entry in target[-num_previous:]]
)
else:
extra_context = None
summarized = await self.summarize(
"\n".join(map(str, entries_to_summarize)), extra_context=extra_context
)
# push summarized entry to target
ts = entries_to_summarize[-1].ts
target.append(data_objects.ArchiveEntry(summarized, start, len(entries_to_summarize)-1, ts=ts))
target.append(
data_objects.ArchiveEntry(
summarized, start, len(entries_to_summarize) - 1, ts=ts
)
)

View File

@@ -1,17 +1,21 @@
from __future__ import annotations
from typing import Union
import asyncio
import httpx
import base64
import functools
import io
import os
import pydantic
import nltk
import tempfile
import base64
import time
import uuid
import functools
from typing import Union
import httpx
import nltk
import pydantic
import structlog
from nltk.tokenize import sent_tokenize
from openai import AsyncOpenAI
import talemate.config as config
import talemate.emit.async_signals
@@ -21,91 +25,91 @@ from talemate.emit.signals import handlers
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .registry import register
import structlog
import time
try:
from TTS.api import TTS
except ImportError:
TTS = None
log = structlog.get_logger("talemate.agents.tts")#
log = structlog.get_logger("talemate.agents.tts") #
if not TTS:
# TTS installation is massive and requires a lot of dependencies
# so we don't want to require it unless the user wants to use it
log.info("TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api")
log.info(
"TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
)
def parse_chunks(text):
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
chunk = chunk.replace("*","")
chunk = chunk.replace("*", "")
if not chunk:
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
def clean_quotes(chunk:str):
def clean_quotes(chunk: str):
# if there is an uneven number of quotes, remove the last one if it's
# at the end of the chunk. If it's in the middle, add a quote to the end
if chunk.count('"') % 2 == 1:
if chunk.endswith('"'):
chunk = chunk[:-1]
else:
chunk += '"'
return chunk
def rejoin_chunks(chunks:list[str], chunk_size:int=250):
return chunk
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(clean_quotes(current_chunk))
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(clean_quotes(current_chunk))
return joined_chunks
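A hypothetical run of the chunking pipeline defined above:
```
text = 'She paused... "Hello," she said. *The door creaked open.*'

chunks = parse_chunks(text)                     # sentence-split; "..." survives intact
chunks = rejoin_chunks(chunks, chunk_size=250)  # merge back up to ~250 chars

# with a small chunk_size, the quote balancing in clean_quotes kicks in,
# closing or trimming any dangling double quote per emitted chunk
```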
class Voice(pydantic.BaseModel):
value:str
label:str
value: str
label: str
class VoiceLibrary(pydantic.BaseModel):
api: str
voices: list[Voice] = pydantic.Field(default_factory=list)
last_synced: float = None
@@ -113,51 +117,50 @@ class VoiceLibrary(pydantic.BaseModel):
@register()
class TTSAgent(Agent):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Voice"
requires_llm_client = False
essential = False
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if agent:
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
voice.model_dump() for voice in agent.list_voices_sync()
]
return config_options
def __init__(self, **kwargs):
self.is_enabled = False
nltk.download("punkt", quiet=True)
self.voices = {
"elevenlabs": VoiceLibrary(api="elevenlabs"),
"coqui": VoiceLibrary(api="coqui"),
"tts": VoiceLibrary(api="tts"),
"openai": VoiceLibrary(api="openai"),
}
self.config = config.load_config()
self.playback_done_event = asyncio.Event()
self.preselect_voice = None
self.actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"api": AgentActionConfig(
type="text",
choices=[
# TODO: add local TTS support
{"value": "tts", "label": "TTS (Local)"},
{"value": "elevenlabs", "label": "Eleven Labs"},
{"value": "coqui", "label": "Coqui Studio"},
{"value": "openai", "label": "OpenAI"},
],
value="tts",
label="API",
@@ -169,7 +172,7 @@ class TTSAgent(Agent):
value="default",
label="Narrator Voice",
description="Voice ID/Name to use for TTS",
choices=[]
choices=[],
),
"generate_for_player": AgentActionConfig(
type="bool",
@@ -194,90 +197,125 @@ class TTSAgent(Agent):
value=False,
label="Split generation",
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
)
}
),
},
),
"openai": AgentAction(
enabled=True,
condition=AgentActionConditional(
attribute="_config.config.api", value="openai"
),
label="OpenAI Settings",
config={
"model": AgentActionConfig(
type="text",
value="tts-1",
choices=[
{"value": "tts-1", "label": "TTS 1"},
{"value": "tts-1-hd", "label": "TTS 1 HD"},
],
label="Model",
description="TTS model to use",
),
},
),
}
self.actions["_config"].model_dump()
handlers["config_saved"].connect(self.on_config_saved)
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def not_ready_reason(self) -> str:
"""
Returns a string explaining why the agent is not ready
"""
if self.ready:
return ""
if self.api == "tts":
if not TTS:
return "TTS not installed"
elif self.requires_token and not self.token:
return "No API token"
elif not self.default_voice_id:
return "No voice selected"
@property
def agent_details(self):
suffix = ""
if not self.ready:
suffix = f" - {self.not_ready_reason}"
else:
suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"
api = self.api
choices = self.actions["_config"].config["api"].choices
api_label = api
for choice in choices:
if choice["value"] == api:
api_label = choice["label"]
break
return f"{api_label}{suffix}"
details = {
"api": AgentDetail(
icon="mdi-server-outline",
value=self.api_label,
description="The backend to use for TTS",
).model_dump(),
}
if self.ready and self.enabled:
details["voice"] = AgentDetail(
icon="mdi-account-voice",
value=self.voice_id_to_label(self.default_voice_id) or "",
description="The voice to use for TTS",
color="info",
).model_dump()
elif self.enabled:
details["error"] = AgentDetail(
icon="mdi-alert",
value=self.not_ready_reason,
description=self.not_ready_reason,
color="error",
).model_dump()
return details
@property
def api(self):
return self.actions["_config"].config["api"].value
@property
def api_label(self):
choices = self.actions["_config"].config["api"].choices
api = self.api
for choice in choices:
if choice["value"] == api:
return choice["label"]
return api
@property
def token(self):
api = self.api
return self.config.get(api,{}).get("api_key")
return self.config.get(api, {}).get("api_key")
@property
def default_voice_id(self):
return self.actions["_config"].config["voice_id"].value
@property
def requires_token(self):
return self.api != "tts"
@property
def ready(self):
if self.api == "tts":
if not TTS:
return False
return True
return (not self.requires_token or self.token) and self.default_voice_id
@property
@@ -285,6 +323,8 @@ class TTSAgent(Agent):
if not self.enabled:
return "disabled"
if self.ready:
if getattr(self, "processing_bg", 0) > 0:
return "busy_bg" if not getattr(self, "processing", False) else "busy"
return "active" if not getattr(self, "processing", False) else "busy"
if self.requires_token and not self.token:
return "error"
@@ -299,106 +339,139 @@ class TTSAgent(Agent):
return 1024
elif self.api == "coqui":
return 250
return 250
def apply_config(self, *args, **kwargs):
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
async def apply_config(self, *args, **kwargs):
try:
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
except KeyError:
api = self.api
api_changed = api != self.api
log.debug("apply_config", api=api, api_changed=api != self.api, current_api=self.api)
super().apply_config(*args, **kwargs)
api_changed = api != self.api
log.debug(
"apply_config",
api=api,
api_changed=api != self.api,
current_api=self.api,
args=args,
kwargs=kwargs,
)
try:
self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
"value"
]
except KeyError:
self.preselect_voice = self.default_voice_id
await super().apply_config(*args, **kwargs)
if api_changed:
try:
self.actions["_config"].config["voice_id"].value = self.voices[api].voices[0].value
self.actions["_config"].config["voice_id"].value = (
self.voices[api].voices[0].value
)
except IndexError:
self.actions["_config"].config["voice_id"].value = ""
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_new_message").connect(self.on_game_loop_new_message)
talemate.emit.async_signals.get("game_loop_new_message").connect(
self.on_game_loop_new_message
)
def on_config_saved(self, event):
config = event.data
self.config = config
instance.emit_agent_status(self.__class__, self)
async def on_game_loop_new_message(self, emission:GameLoopNewMessageEvent):
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if not self.enabled or not self.ready:
return
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
return
if isinstance(emission.message, NarratorMessage) and not self.actions["_config"].config["generate_for_narration"].value:
if (
isinstance(emission.message, NarratorMessage)
and not self.actions["_config"].config["generate_for_narration"].value
):
return
if isinstance(emission.message, CharacterMessage):
if emission.message.source == "player" and not self.actions["_config"].config["generate_for_player"].value:
if (
emission.message.source == "player"
and not self.actions["_config"].config["generate_for_player"].value
):
return
elif emission.message.source == "ai" and not self.actions["_config"].config["generate_for_npc"].value:
elif (
emission.message.source == "ai"
and not self.actions["_config"].config["generate_for_npc"].value
):
return
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
else:
character_prefix = ""
log.info("reactive tts", message=emission.message, character_prefix=character_prefix)
await self.generate(str(emission.message).replace(character_prefix+": ", ""))
log.info(
"reactive tts", message=emission.message, character_prefix=character_prefix
)
def voice(self, voice_id:str) -> Union[Voice, None]:
await self.generate(str(emission.message).replace(character_prefix + ": ", ""))
def voice(self, voice_id: str) -> Union[Voice, None]:
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice
return None
def voice_id_to_label(self, voice_id:str):
def voice_id_to_label(self, voice_id: str):
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice.label
return None
def list_voices_sync(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.list_voices())
async def list_voices(self):
if self.requires_token and not self.token:
return []
library = self.voices[self.api]
# TODO: allow re-syncing voices
if library.last_synced:
return library.voices
list_fn = getattr(self, f"_list_voices_{self.api}")
log.info("Listing voices", api=self.api)
library.voices = await list_fn()
library.last_synced = time.time()
if self.preselect_voice:
if self.voice(self.preselect_voice):
self.actions["_config"].config["voice_id"].value = self.preselect_voice
self.preselect_voice = None
# if the current voice cannot be found, reset it
if not self.voice(self.default_voice_id):
self.actions["_config"].config["voice_id"].value = ""
# set loading to false
return library.voices
@@ -407,11 +480,10 @@ class TTSAgent(Agent):
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
generate_fn = getattr(self, f"_generate_{self.api}")
if self.actions["_config"].config["generate_chunks"].value:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks)
@@ -421,65 +493,78 @@ class TTSAgent(Agent):
# Start generating audio chunks in the background
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
await self.set_background_processing(generation_task)
# Wait for both tasks to complete
await asyncio.gather(generation_task)
# await asyncio.gather(generation_task)
async def generate_chunks(self, generate_fn, chunks):
for chunk in chunks:
chunk = chunk.replace("*","").strip()
chunk = chunk.replace("*", "").strip()
log.info("Generating audio", api=self.api, chunk=chunk)
audio_data = await generate_fn(chunk)
self.play_audio(audio_data)
def play_audio(self, audio_data):
# play audio through the python audio player
#play(audio_data)
emit("audio_queue", data={"audio_data": base64.b64encode(audio_data).decode("utf-8")})
# play(audio_data)
emit(
"audio_queue",
data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
)
self.playback_done_event.set() # Signal that playback is finished
# LOCAL
async def _generate_tts(self, text: str) -> Union[bytes, None]:
if not TTS:
return
tts_config = self.config.get("tts",{})
tts_config = self.config.get("tts", {})
model = tts_config.get("model")
device = tts_config.get("device", "cpu")
log.debug("tts local", model=model, device=device)
if not hasattr(self, "tts_instance"):
self.tts_instance = TTS(model).to(device)
tts = self.tts_instance
loop = asyncio.get_event_loop()
voice = self.voice(self.default_voice_id)
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(None, functools.partial(tts.tts_to_file, text=text, speaker_wav=voice.value, language="en", file_path=file_path))
#tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
await loop.run_in_executor(
None,
functools.partial(
tts.tts_to_file,
text=text,
speaker_wav=voice.value,
language="en",
file_path=file_path,
),
)
# tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
with open(file_path, "rb") as f:
return f.read()
async def _list_voices_tts(self) -> dict[str, str]:
return [Voice(**voice) for voice in self.config.get("tts",{}).get("voices",[])]
return [
Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
]
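A sketch of the config shape these two methods read, expressed as the equivalent Python dict (model name and paths are illustrative; xtts_v2 is one model that accepts a speaker_wav reference):
```
# hypothetical talemate config fragment consumed by _generate_tts
# and _list_voices_tts above; keys mirror the .get() calls
config = {
    "tts": {
        "model": "tts_models/multilingual/multi-dataset/xtts_v2",
        "device": "cuda",  # falls back to "cpu" when unset
        "voices": [
            {"value": "./voices/narrator.wav", "label": "Narrator"},
        ],
    }
}
```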
# ELEVENLABS
async def _generate_elevenlabs(self, text: str, chunk_size: int = 1024) -> Union[bytes, None]:
async def _generate_elevenlabs(
self, text: str, chunk_size: int = 1024
) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
@@ -493,11 +578,8 @@ class TTSAgent(Agent):
}
data = {
"text": text,
"model_id": self.config.get("elevenlabs",{}).get("model"),
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.5
}
"model_id": self.config.get("elevenlabs", {}).get("model"),
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
}
response = await client.post(url, json=data, headers=headers, timeout=300)
@@ -514,104 +596,57 @@ class TTSAgent(Agent):
log.error(f"Error generating audio: {response.text}")
async def _list_voices_elevenlabs(self) -> dict[str, str]:
url_voices = "https://api.elevenlabs.io/v1/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Accept": "application/json",
"xi-api-key": self.token,
}
response = await client.get(url_voices, headers=headers, params={"per_page":1000})
response = await client.get(
url_voices, headers=headers, params={"per_page": 1000}
)
speakers = response.json()["voices"]
voices.extend([Voice(value=speaker["voice_id"], label=speaker["name"]) for speaker in speakers])
voices.extend(
[
Voice(value=speaker["voice_id"], label=speaker["name"])
for speaker in speakers
]
)
# sort by name
voices.sort(key=lambda x: x.label)
return voices
# COQUI STUDIO
async def _generate_coqui(self, text: str) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
async with httpx.AsyncClient() as client:
url = "https://app.coqui.ai/api/v2/samples/xtts/render/"
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
data = {
"voice_id": self.default_voice_id,
"text": text,
"language": "en" # Assuming English language for simplicity; this could be parameterized
}
return voices
# Make the POST request to Coqui API
response = await client.post(url, json=data, headers=headers, timeout=300)
if response.status_code in [200, 201]:
# Parse the JSON response to get the audio URL
response_data = response.json()
audio_url = response_data.get('audio_url')
if audio_url:
# Make a GET request to download the audio file
audio_response = await client.get(audio_url)
if audio_response.status_code == 200:
# delete the sample from Coqui Studio
# await self._cleanup_coqui(response_data.get('id'))
return audio_response.content
else:
log.error(f"Error downloading audio: {audio_response.text}")
else:
log.error("No audio URL in response")
else:
log.error(f"Error generating audio: {response.text}")
async def _cleanup_coqui(self, sample_id: str):
api_key = self.token
if not api_key or not sample_id:
return
# OPENAI
async with httpx.AsyncClient() as client:
url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
headers = {
"Authorization": f"Bearer {api_key}"
}
async def _generate_openai(self, text: str, chunk_size: int = 1024):
# Make the DELETE request to Coqui API
response = await client.delete(url, headers=headers)
client = AsyncOpenAI(api_key=self.openai_api_key)
if response.status_code == 204:
log.info(f"Successfully deleted sample with ID: {sample_id}")
else:
log.error(f"Error deleting sample with ID: {sample_id}: {response.text}")
model = self.actions["openai"].config["model"].value
async def _list_voices_coqui(self) -> dict[str, str]:
url_speakers = "https://app.coqui.ai/api/v2/speakers"
url_custom_voices = "https://app.coqui.ai/api/v2/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"Bearer {self.token}"
}
response = await client.get(url_speakers, headers=headers, params={"per_page":1000})
speakers = response.json()["result"]
voices.extend([Voice(value=speaker["id"], label=speaker["name"]) for speaker in speakers])
response = await client.get(url_custom_voices, headers=headers, params={"per_page":1000})
custom_voices = response.json()["result"]
voices.extend([Voice(value=voice["id"], label=voice["name"]) for voice in custom_voices])
# sort by name
voices.sort(key=lambda x: x.label)
return voices
response = await client.audio.speech.create(
model=model, voice=self.default_voice_id, input=text
)
bytes_io = io.BytesIO()
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
# return the collected audio bytes; the caller queues them for playback
return bytes_io.getvalue()
async def _list_voices_openai(self) -> dict[str, str]:
return [
Voice(value="alloy", label="Alloy"),
Voice(value="echo", label="Echo"),
Voice(value="fable", label="Fable"),
Voice(value="onyx", label="Onyx"),
Voice(value="nova", label="Nova"),
Voice(value="shimmer", label="Shimmer"),
]

View File

@@ -0,0 +1,467 @@
import asyncio
import traceback
import structlog
import talemate.agents.visual.automatic1111
import talemate.agents.visual.comfyui
import talemate.agents.visual.openai_image
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from talemate.agents.registry import register
from talemate.client.base import ClientBase
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers as signal_handlers
from talemate.prompts.base import Prompt
from .commands import * # noqa
from .context import VIS_TYPES, VisualContext, visual_context
from .handlers import HANDLERS
from .schema import RESOLUTION_MAP, RenderSettings
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
from .websocket_handler import VisualWebsocketHandler
__all__ = [
"VisualAgent",
]
BACKENDS = [
{"value": mixin_backend, "label": mixin["label"]}
for mixin_backend, mixin in HANDLERS.items()
]
log = structlog.get_logger("talemate.agents.visual")
class VisualBase(Agent):
"""
The visual agent
"""
agent_type = "visual"
verbose_name = "Visualizer"
essential = False
websocket_handler = VisualWebsocketHandler
ACTIONS = {}
def __init__(self, client: ClientBase, **kwargs):
self.client = client
self.is_enabled = False
self.backend_ready = False
self.initialized = False
self.config = load_config()
self.actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
description="Visual agent configuration",
config={
"backend": AgentActionConfig(
type="text",
choices=BACKENDS,
value="automatic1111",
label="Backend",
description="The backend to use for visual processing",
),
"default_style": AgentActionConfig(
type="text",
value="graphic_novel",
choices=MAJOR_STYLES,
label="Default Style",
description="The default style to use for visual processing",
),
},
),
"automatic_generation": AgentAction(
enabled=False,
label="Automatic Generation",
description="Allow automatic generation of visual content",
),
"process_in_background": AgentAction(
enabled=True,
label="Process in Background",
description="Process renders in the background",
),
}
for action_name, action in self.ACTIONS.items():
self.actions[action_name] = action
signal_handlers["config_saved"].connect(self.on_config_saved)
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def backend(self):
return self.actions["_config"].config["backend"].value
@property
def backend_name(self):
key = self.actions["_config"].config["backend"].value
for backend in BACKENDS:
if backend["value"] == key:
return backend["label"]
@property
def default_style(self):
return STYLE_MAP.get(
self.actions["_config"].config["default_style"].value, Style()
)
@property
def ready(self):
return self.backend_ready
@property
def api_url(self):
try:
return self.actions[self.backend].config["api_url"].value
except KeyError:
return None
@property
def agent_details(self):
details = {
"backend": AgentDetail(
icon="mdi-server-outline",
value=self.backend_name,
description="The backend to use for visual processing",
).model_dump(),
"client": AgentDetail(
icon="mdi-network-outline",
value=self.client.name if self.client else None,
description="The client to use for prompt generation",
).model_dump(),
}
if not self.ready and self.enabled:
details["status"] = AgentDetail(
icon="mdi-alert",
value=f"{self.backend_name} not ready",
color="error",
description=self.ready_check_error
or f"{self.backend_name} is not ready for processing",
).model_dump()
return details
@property
def process_in_background(self):
return self.actions["process_in_background"].enabled
@property
def allow_automatic_generation(self):
return self.actions["automatic_generation"].enabled
def on_config_saved(self, event):
config = event.data
self.config = config
asyncio.create_task(self.emit_status())
async def on_ready_check_success(self):
prev_ready = self.backend_ready
self.backend_ready = True
if not prev_ready:
await self.emit_status()
async def on_ready_check_failure(self, error):
prev_ready = self.backend_ready
self.backend_ready = False
self.ready_check_error = str(error)
if prev_ready:
await self.emit_status()
async def ready_check(self):
if not self.enabled:
return
backend = self.backend
fn = getattr(self, f"{backend.lower()}_ready", None)
task = asyncio.create_task(fn())
await super().ready_check(task)
async def apply_config(self, *args, **kwargs):
try:
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
except KeyError:
backend = self.backend
backend_changed = backend != self.backend
was_disabled = not self.enabled
if backend_changed:
self.backend_ready = False
log.info(
"apply_config",
backend=backend,
backend_changed=backend_changed,
old_backend=self.backend,
)
await super().apply_config(*args, **kwargs)
backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None)
if backend_fn:
if not backend_changed and was_disabled and self.enabled:
# If the backend has not changed, but the agent was previously disabled
# and is now enabled, we need to trigger the backend apply_config function
backend_changed = True
task = asyncio.create_task(
backend_fn(backend_changed=backend_changed, *args, **kwargs)
)
await self.set_background_processing(task)
if not self.backend_ready:
await self.ready_check()
self.initialized = True
def resolution_from_format(self, format: str, model_type: str = "sdxl"):
if model_type not in RESOLUTION_MAP:
raise ValueError(f"Model type {model_type} not found in resolution map")
return RESOLUTION_MAP[model_type].get(
format, RESOLUTION_MAP[model_type]["portrait"]
)
def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style:
prompt_style = Style()
prompt_style.load(prompt)
if styles:
prompt_style.prepend(*styles)
return prompt_style
def vis_type_styles(self, vis_type: str):
if vis_type == VIS_TYPES.CHARACTER:
portrait_style = STYLE_MAP["character_portrait"].copy()
return portrait_style
elif vis_type == VIS_TYPES.ENVIRONMENT:
environment_style = STYLE_MAP["environment"].copy()
return environment_style
return Style()
async def apply_image(self, image: str):
context = visual_context.get()
log.debug("apply_image", image=image[:100], context=context)
if context.vis_type == VIS_TYPES.CHARACTER:
await self.apply_image_character(image, context.character_name)
async def apply_image_character(self, image: str, character_name: str):
character = self.scene.get_character(character_name)
if not character:
log.error("character not found", character_name=character_name)
return
if character.cover_image:
log.info("character cover image already set", character_name=character_name)
return
asset = self.scene.assets.add_asset_from_image_data(
f"data:image/png;base64,{image}"
)
character.cover_image = asset.id
self.scene.assets.cover_image = asset.id
self.scene.emit_status()
async def emit_image(self, image: str):
context = visual_context.get()
await self.apply_image(image)
emit(
"image_generated",
websocket_passthrough=True,
data={
"base64": image,
"context": context.model_dump() if context else None,
},
)
@set_processing
async def generate(
self, format: str = "portrait", prompt: str = None, automatic: bool = False
):
context = visual_context.get()
if not self.enabled:
log.warning("generate", skipped="Visual agent not enabled")
return
if automatic and not self.allow_automatic_generation:
log.warning(
"generate",
skipped="Automatic generation disabled",
prompt=prompt,
format=format,
context=context,
)
return
if not context and not prompt:
log.error("generate", error="No context or prompt provided")
return
# Handle prompt generation based on context
if not prompt and context.prompt:
prompt = context.prompt
if context.vis_type == VIS_TYPES.ENVIRONMENT and not prompt:
prompt = await self.generate_environment_prompt(
instructions=context.instructions
)
elif context.vis_type == VIS_TYPES.CHARACTER and not prompt:
prompt = await self.generate_character_prompt(
context.character_name, instructions=context.instructions
)
else:
prompt = prompt or context.prompt
initial_prompt = prompt
# Augment the prompt with styles based on context
thematic_style = self.default_style
vis_type_styles = self.vis_type_styles(context.vis_type)
prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
if context.vis_type == VIS_TYPES.CHARACTER:
prompt.keywords.append("character portrait")
if not prompt:
log.error(
"generate", error="No prompt provided and no context to generate from"
)
return
context.prompt = initial_prompt
context.prepared_prompt = str(prompt)
# Handle format (can either come from context or be passed in)
if not format and context.format:
format = context.format
elif not format:
format = "portrait"
context.format = format
# Call the backend specific generate function
backend = self.backend
fn = f"{backend.lower()}_generate"
log.info(
"generate", backend=backend, prompt=prompt, format=format, context=context
)
        if not hasattr(self, fn):
            log.error("generate", error=f"Backend {backend} does not support generate")
            return
# add the function call to the asyncio task queue
if self.process_in_background:
task = asyncio.create_task(getattr(self, fn)(prompt=prompt, format=format))
await self.set_background_processing(task)
else:
await getattr(self, fn)(prompt=prompt, format=format)
@set_processing
async def generate_environment_prompt(self, instructions: str = None):
response = await Prompt.request(
"visual.generate-environment-prompt",
self.client,
"visualize",
{
"scene": self.scene,
"max_tokens": self.client.max_token_length,
},
)
return response.strip()
@set_processing
async def generate_character_prompt(
self, character_name: str, instructions: str = None
):
character = self.scene.get_character(character_name)
response = await Prompt.request(
"visual.generate-character-prompt",
self.client,
"visualize",
{
"scene": self.scene,
"character_name": character_name,
"character": character,
"max_tokens": self.client.max_token_length,
"instructions": instructions or "",
},
)
return response.strip()
async def generate_environment_background(self, instructions: str = None):
with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
await self.generate(format="landscape")
generate_environment_background.exposed = True
async def generate_character_portrait(
self,
character_name: str,
instructions: str = None,
):
with VisualContext(
vis_type=VIS_TYPES.CHARACTER,
character_name=character_name,
instructions=instructions,
):
await self.generate(format="portrait")
generate_character_portrait.exposed = True
# apply mixins to the agent (from HANDLERS dict[str, cls])
for mixin_backend, mixin in HANDLERS.items():
mixin_cls = mixin["cls"]
VisualBase = type("VisualAgent", (mixin_cls, VisualBase), {})
extend_actions = getattr(mixin_cls, "EXTEND_ACTIONS", {})
for action_name, action in extend_actions.items():
VisualBase.ACTIONS[action_name] = action
@register()
class VisualAgent(VisualBase):
pass
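# Illustrative sketch (assumptions flagged inline): the name-based dispatch
# convention VisualBase relies on. For a configured backend value such as
# "automatic1111", ready_check()/generate()/apply_config() resolve the
# backend coroutines by name, so any mixin registered in HANDLERS that
# follows the naming scheme plugs in without further wiring. `agent` below
# is assumed to be an initialized VisualAgent instance:
#
#   generate = getattr(agent, f"{agent.backend.lower()}_generate", None)
#   ready = getattr(agent, f"{agent.backend.lower()}_ready", None)
#   apply_config = getattr(agent, f"{agent.backend.lower()}_apply_config", None)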

View File

@@ -0,0 +1,117 @@
import base64
import io
import httpx
import structlog
from PIL import Image
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style
log = structlog.get_logger("talemate.agents.visual.automatic1111")
@register(backend_name="automatic1111", label="AUTOMATIC1111")
class Automatic1111Mixin:
automatic1111_default_render_settings = RenderSettings()
EXTEND_ACTIONS = {
"automatic1111": AgentAction(
enabled=True,
condition=AgentActionConditional(
attribute="_config.config.backend", value="automatic1111"
),
label="Automatic1111 Settings",
description="Setting overrides for the automatic1111 backend",
config={
"api_url": AgentActionConfig(
type="text",
value="http://localhost:7860",
label="API URL",
description="The URL of the backend API",
),
"steps": AgentActionConfig(
type="number",
value=40,
label="Steps",
min=5,
max=150,
step=1,
description="number of render steps",
),
"model_type": AgentActionConfig(
type="text",
value="sdxl",
choices=[
{"value": "sdxl", "label": "SDXL"},
{"value": "sd15", "label": "SD1.5"},
],
label="Model Type",
description="Right now just differentiates between sdxl and sd15 - affect generation resolution",
),
},
)
}
@property
def automatic1111_render_settings(self):
if self.actions["automatic1111"].enabled:
return RenderSettings(
steps=self.actions["automatic1111"].config["steps"].value,
type_model=self.actions["automatic1111"].config["model_type"].value,
)
else:
return self.automatic1111_default_render_settings
async def automatic1111_generate(self, prompt: Style, format: str):
url = self.api_url
resolution = self.resolution_from_format(
format, self.automatic1111_render_settings.type_model
)
render_settings = self.automatic1111_render_settings
payload = {
"prompt": prompt.positive_prompt,
"negative_prompt": prompt.negative_prompt,
"steps": render_settings.steps,
"width": resolution.width,
"height": resolution.height,
}
log.info("automatic1111_generate", payload=payload, url=url)
async with httpx.AsyncClient() as client:
response = await client.post(
url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=90
)
r = response.json()
        # image = Image.open(io.BytesIO(base64.b64decode(r['images'][0])))
        # image.save('a1111-test.png')
        # log.info("automatic1111_generate", saved_to="a1111-test.png")
for image in r["images"]:
await self.emit_image(image)
async def automatic1111_ready(self) -> bool:
"""
Will send a GET to /sdapi/v1/memory and on 200 will return True
"""
async with httpx.AsyncClient() as client:
response = await client.get(
url=f"{self.api_url}/sdapi/v1/memory", timeout=2
)
return response.status_code == 200
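# Illustrative sketch: decoding one of the base64 strings that
# automatic1111_generate() passes to emit_image(), mirroring the
# commented-out debug lines above. Uses the base64/io/PIL imports already
# at the top of this module; the helper name and output path are example
# values, not part of the original diff.

def save_debug_image(image_b64: str, path: str = "a1111-test.png") -> None:
    # decode the API's base64 payload into a PIL image and write it to disk
    Image.open(io.BytesIO(base64.b64decode(image_b64))).save(path)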

View File

@@ -0,0 +1,324 @@
import asyncio
import base64
import io
import json
import os
import random
import time
import urllib.parse
import httpx
import pydantic
import structlog
from PIL import Image
from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig
from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style
log = structlog.get_logger("talemate.agents.visual.comfyui")
class Workflow(pydantic.BaseModel):
nodes: dict
def set_resolution(self, resolution: Resolution):
        # Collects all latent image nodes. If there are multiple, the one
        # titled "Talemate Resolution" is preferred; otherwise the first
        # latent image node is used. The resolution is updated on the
        # selected node; if no latent image node is found, a warning is logged.
latent_image_node = None
for node_id, node in self.nodes.items():
if node["class_type"] == "EmptyLatentImage":
if not latent_image_node:
latent_image_node = node
elif node["_meta"]["title"] == "Talemate Resolution":
latent_image_node = node
break
if not latent_image_node:
log.warning("set_resolution", error="No latent image node found")
return
latent_image_node["inputs"]["width"] = resolution.width
latent_image_node["inputs"]["height"] = resolution.height
def set_prompt(self, prompt: str, negative_prompt: str = None):
        # Collects all CLIPTextEncode nodes. If there are multiple, the ones
        # titled "Talemate Positive Prompt" and "Talemate Negative Prompt"
        # are preferred.
        #
        # If no CLIPTextEncode node is titled "Talemate Positive Prompt",
        # the first CLIPTextEncode node is used; if none is titled
        # "Talemate Negative Prompt", the second one is used.
        #
        # The prompt is updated on the selected node. A ValueError is raised
        # if no CLIPTextEncode node is found for the positive prompt, and
        # likewise for the negative prompt when one is given.
positive_prompt_node = None
negative_prompt_node = None
for node_id, node in self.nodes.items():
if node["class_type"] == "CLIPTextEncode":
if not positive_prompt_node:
positive_prompt_node = node
elif node["_meta"]["title"] == "Talemate Positive Prompt":
positive_prompt_node = node
elif not negative_prompt_node:
negative_prompt_node = node
elif node["_meta"]["title"] == "Talemate Negative Prompt":
negative_prompt_node = node
if not positive_prompt_node:
raise ValueError("No positive prompt node found")
positive_prompt_node["inputs"]["text"] = prompt
if negative_prompt and not negative_prompt_node:
raise ValueError("No negative prompt node found")
if negative_prompt:
negative_prompt_node["inputs"]["text"] = negative_prompt
def set_checkpoint(self, checkpoint: str):
        # Collects all CheckpointLoaderSimple nodes. If there are multiple,
        # the one titled "Talemate Load Checkpoint" is preferred; otherwise
        # the first CheckpointLoaderSimple node is used. The checkpoint is
        # updated on the selected node; if no checkpoint node is found, a
        # warning is logged.
checkpoint_node = None
for node_id, node in self.nodes.items():
if node["class_type"] == "CheckpointLoaderSimple":
if not checkpoint_node:
checkpoint_node = node
elif node["_meta"]["title"] == "Talemate Load Checkpoint":
checkpoint_node = node
break
if not checkpoint_node:
log.warning("set_checkpoint", error="No checkpoint node found")
return
checkpoint_node["inputs"]["ckpt_name"] = checkpoint
def set_seeds(self):
for node in self.nodes.values():
for field in node.get("inputs", {}).keys():
if field == "noise_seed":
node["inputs"]["noise_seed"] = random.randint(0, 999999999999999)
@register(backend_name="comfyui", label="ComfyUI")
class ComfyUIMixin:
comfyui_default_render_settings = RenderSettings()
EXTEND_ACTIONS = {
"comfyui": AgentAction(
enabled=True,
condition=AgentActionConditional(
attribute="_config.config.backend", value="comfyui"
),
label="ComfyUI Settings",
description="Setting overrides for the comfyui backend",
config={
"api_url": AgentActionConfig(
type="text",
value="http://localhost:8188",
label="API URL",
description="The URL of the backend API",
),
"workflow": AgentActionConfig(
type="text",
value="default-sdxl.json",
label="Workflow",
description="The workflow to use for comfyui (workflow file name inside ./templates/comfyui-workflows)",
),
"checkpoint": AgentActionConfig(
type="text",
value="default",
label="Checkpoint",
choices=[],
description="The main checkpoint to use.",
),
},
)
}
@property
def comfyui_workflow_filename(self):
base_name = self.actions["comfyui"].config["workflow"].value
# make absolute path
abs_path = os.path.join(
os.path.dirname(__file__),
"..",
"..",
"..",
"..",
"templates",
"comfyui-workflows",
base_name,
)
return abs_path
@property
def comfyui_workflow_is_sdxl(self) -> bool:
"""
        Returns True if `sdxl` is in the workflow file name (case-insensitive)
"""
return "sdxl" in self.comfyui_workflow_filename.lower()
@property
def comfyui_workflow(self) -> Workflow:
workflow = self.comfyui_workflow_filename
if not workflow:
raise ValueError("No comfyui workflow file specified")
with open(workflow, "r") as f:
return Workflow(nodes=json.load(f))
@property
async def comfyui_object_info(self):
if hasattr(self, "_comfyui_object_info"):
return self._comfyui_object_info
async with httpx.AsyncClient() as client:
response = await client.get(url=f"{self.api_url}/object_info")
self._comfyui_object_info = response.json()
return self._comfyui_object_info
@property
async def comfyui_checkpoints(self):
loader_node = (await self.comfyui_object_info)["CheckpointLoaderSimple"]
_checkpoints = loader_node["input"]["required"]["ckpt_name"][0]
return [
{"label": checkpoint, "value": checkpoint} for checkpoint in _checkpoints
]
async def comfyui_get_image(self, filename: str, subfolder: str, folder_type: str):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
async with httpx.AsyncClient() as client:
response = await client.get(url=f"{self.api_url}/view?{url_values}")
return response.content
async def comfyui_get_history(self, prompt_id: str):
async with httpx.AsyncClient() as client:
response = await client.get(url=f"{self.api_url}/history/{prompt_id}")
return response.json()
    async def comfyui_get_images(self, prompt_id: str, max_wait: float = 60.0):
output_images = {}
history = {}
start = time.time()
while not history:
log.info(
"comfyui_get_images", waiting_for_history=True, prompt_id=prompt_id
)
history = await self.comfyui_get_history(prompt_id)
await asyncio.sleep(1.0)
if time.time() - start > max_wait:
raise TimeoutError("Max wait time exceeded")
for node_id, node_output in history[prompt_id]["outputs"].items():
if "images" in node_output:
images_output = []
for image in node_output["images"]:
image_data = await self.comfyui_get_image(
image["filename"], image["subfolder"], image["type"]
)
images_output.append(image_data)
output_images[node_id] = images_output
return output_images
async def comfyui_generate(self, prompt: Style, format: str):
url = self.api_url
workflow = self.comfyui_workflow
is_sdxl = self.comfyui_workflow_is_sdxl
resolution = self.resolution_from_format(format, "sdxl" if is_sdxl else "sd15")
workflow.set_resolution(resolution)
workflow.set_prompt(prompt.positive_prompt, prompt.negative_prompt)
workflow.set_seeds()
workflow.set_checkpoint(self.actions["comfyui"].config["checkpoint"].value)
payload = {"prompt": workflow.model_dump().get("nodes")}
log.info("comfyui_generate", payload=payload, url=url)
async with httpx.AsyncClient() as client:
response = await client.post(url=f"{url}/prompt", json=payload, timeout=90)
log.info("comfyui_generate", response=response.text)
r = response.json()
prompt_id = r["prompt_id"]
images = await self.comfyui_get_images(prompt_id)
for node_id, node_images in images.items():
for i, image in enumerate(node_images):
await self.emit_image(base64.b64encode(image).decode("utf-8"))
# image = Image.open(io.BytesIO(image))
# image.save(f'comfyui-test.png')
async def comfyui_apply_config(
self, backend_changed: bool = False, *args, **kwargs
):
log.debug(
"comfyui_apply_config",
backend_changed=backend_changed,
enabled=self.enabled,
)
if (not self.initialized or backend_changed) and self.enabled:
checkpoints = await self.comfyui_checkpoints
selected_checkpoint = self.actions["comfyui"].config["checkpoint"].value
self.actions["comfyui"].config["checkpoint"].choices = checkpoints
self.actions["comfyui"].config["checkpoint"].value = selected_checkpoint
async def comfyui_ready(self) -> bool:
"""
Will send a GET to /system_stats and on 200 will return True
"""
async with httpx.AsyncClient() as client:
response = await client.get(url=f"{self.api_url}/system_stats", timeout=2)
return response.status_code == 200

View File

@@ -0,0 +1,68 @@
from talemate.agents.visual.context import VIS_TYPES, VisualContext
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.instance import get_agent
__all__ = [
"CmdVisualizeTestGenerate",
]
@register
class CmdVisualizeTestGenerate(TalemateCommand):
"""
Generates a visual test
"""
name = "visual_test_generate"
description = "Will generate a visual test"
aliases = ["vis_test", "vtg"]
label = "Visualize test"
async def run(self):
visual = get_agent("visual")
prompt = self.args[0]
with VisualContext(vis_type=VIS_TYPES.UNSPECIFIED):
await visual.generate(prompt)
return True
@register
class CmdVisualizeEnvironment(TalemateCommand):
"""
Shows the environment
"""
name = "visual_environment"
description = "Will show the environment"
aliases = ["vis_env"]
label = "Visualize environment"
async def run(self):
visual = get_agent("visual")
await visual.generate_environment_background(
instructions=self.args[0] if len(self.args) > 0 else None
)
return True
@register
class CmdVisualizeCharacter(TalemateCommand):
"""
Shows a character
"""
name = "visual_character"
description = "Will show a character"
aliases = ["vis_char"]
label = "Visualize character"
async def run(self):
visual = get_agent("visual")
character_name = self.args[0]
instructions = self.args[1] if len(self.args) > 1 else None
await visual.generate_character_portrait(character_name, instructions)
return True

View File

@@ -0,0 +1,55 @@
import contextvars
import enum
from typing import Union
import pydantic
__all__ = [
"VIS_TYPES",
"visual_context",
"VisualContext",
]
class VIS_TYPES(str, enum.Enum):
UNSPECIFIED = "UNSPECIFIED"
ENVIRONMENT = "ENVIRONMENT"
CHARACTER = "CHARACTER"
ITEM = "ITEM"
visual_context = contextvars.ContextVar("visual_context", default=None)
class VisualContextState(pydantic.BaseModel):
character_name: Union[str, None] = None
instructions: Union[str, None] = None
vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT
prompt: Union[str, None] = None
prepared_prompt: Union[str, None] = None
format: Union[str, None] = None
class VisualContext:
def __init__(
self,
character_name: Union[str, None] = None,
instructions: Union[str, None] = None,
vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT,
prompt: Union[str, None] = None,
**kwargs,
):
self.state = VisualContextState(
character_name=character_name,
instructions=instructions,
vis_type=vis_type,
prompt=prompt,
**kwargs,
)
def __enter__(self):
self.token = visual_context.set(self.state)
def __exit__(self, *args, **kwargs):
visual_context.reset(self.token)
return False
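# Illustrative usage sketch: VisualContext publishes a VisualContextState to
# the `visual_context` ContextVar for the duration of the block; this is how
# VisualBase.generate() discovers vis_type, character_name and instructions.
# The character name below is invented.

with VisualContext(vis_type=VIS_TYPES.CHARACTER, character_name="Alice"):
    assert visual_context.get().character_name == "Alice"
assert visual_context.get() is None  # back to the default outside the block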

View File

@@ -0,0 +1,17 @@
__all__ = [
"HANDLERS",
"register",
]
HANDLERS = {}
class register:
def __init__(self, backend_name: str, label: str):
self.backend_name = backend_name
self.label = label
def __call__(self, mixin_cls):
HANDLERS[self.backend_name] = {"label": self.label, "cls": mixin_cls}
return mixin_cls
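# Illustrative sketch: how a backend module uses this decorator. The mixin
# below is hypothetical; the real registrations live in the automatic1111,
# comfyui and openai_image modules.

@register(backend_name="example", label="Example Backend")
class ExampleBackendMixin:
    # a real mixin would also define example_generate / example_ready
    # coroutines following the agent's naming convention
    EXTEND_ACTIONS = {}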

View File

@@ -0,0 +1,125 @@
import base64
import io
from urllib.parse import parse_qs, unquote, urlparse
import httpx
import structlog
from openai import AsyncOpenAI
from PIL import Image
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style
log = structlog.get_logger("talemate.agents.visual.openai_image")
@register(backend_name="openai_image", label="OpenAI")
class OpenAIImageMixin:
openai_image_default_render_settings = RenderSettings()
EXTEND_ACTIONS = {
"openai_image": AgentAction(
enabled=False,
condition=AgentActionConditional(
attribute="_config.config.backend", value="openai_image"
),
label="OpenAI Image Generation Advanced Settings",
description="Setting overrides for the openai backend",
config={
"model_type": AgentActionConfig(
type="text",
value="dall-e-3",
choices=[
{"value": "dall-e-3", "label": "DALL-E 3"},
{"value": "dall-e-2", "label": "DALL-E 2"},
],
label="Model Type",
description="Image generation model",
),
"quality": AgentActionConfig(
type="text",
value="standard",
choices=[
{"value": "standard", "label": "Standard"},
{"value": "hd", "label": "HD"},
],
label="Quality",
description="Image generation quality",
),
},
)
}
@property
def openai_api_key(self):
return self.config.get("openai", {}).get("api_key")
@property
def openai_model_type(self):
return self.actions["openai_image"].config["model_type"].value
@property
def openai_quality(self):
return self.actions["openai_image"].config["quality"].value
async def openai_image_generate(self, prompt: Style, format: str):
"""
from openai import OpenAI
client = OpenAI()
response = client.images.generate(
model="dall-e-3",
prompt="a white siamese cat",
size="1024x1024",
quality="standard",
n=1,
)
image_url = response.data[0].url
"""
client = AsyncOpenAI(api_key=self.openai_api_key)
        # When using DALL·E 3, images can have a size of 1024x1024, 1024x1792 or 1792x1024 pixels.
if format == "portrait":
resolution = Resolution(width=1024, height=1792)
elif format == "landscape":
resolution = Resolution(width=1792, height=1024)
else:
resolution = Resolution(width=1024, height=1024)
log.debug("openai_image_generate", resolution=resolution)
response = await client.images.generate(
model=self.openai_model_type,
prompt=prompt.positive_prompt,
size=f"{resolution.width}x{resolution.height}",
quality=self.openai_quality,
n=1,
response_format="b64_json",
)
await self.emit_image(response.data[0].b64_json)
async def openai_image_ready(self) -> bool:
"""
        Checks that an OpenAI API key is configured; raises ValueError if it is not.
"""
if not self.openai_api_key:
raise ValueError("OpenAI API Key not set")
return True

View File

@@ -0,0 +1,32 @@
import pydantic
__all__ = [
"RenderSettings",
"Resolution",
"RESOLUTION_MAP",
]
RESOLUTION_MAP = {}
class RenderSettings(pydantic.BaseModel):
type_model: str = "sdxl"
steps: int = 40
class Resolution(pydantic.BaseModel):
width: int
height: int
RESOLUTION_MAP["sdxl"] = {
"portrait": Resolution(width=832, height=1216),
"landscape": Resolution(width=1216, height=832),
"square": Resolution(width=1024, height=1024),
}
RESOLUTION_MAP["sd15"] = {
"portrait": Resolution(width=512, height=768),
"landscape": Resolution(width=768, height=512),
"square": Resolution(width=768, height=768),
}
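# Illustrative sketch: the lookup VisualBase.resolution_from_format() performs
# on top of this map, falling back to the model type's portrait resolution
# for unknown format names. The helper name is invented for illustration.

def _resolution_for(format: str, model_type: str = "sdxl") -> Resolution:
    return RESOLUTION_MAP[model_type].get(
        format, RESOLUTION_MAP[model_type]["portrait"]
    )

# _resolution_for("landscape")       -> Resolution(width=1216, height=832)
# _resolution_for("banner", "sd15")  -> Resolution(width=512, height=768)  (fallback)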

View File

@@ -0,0 +1,136 @@
import pydantic
import structlog
__all__ = [
"Style",
"STYLE_MAP",
"THEME_MAP",
"MAJOR_STYLES",
"combine_styles",
]
STYLE_MAP = {}
THEME_MAP = {}
MAJOR_STYLES = {}
log = structlog.get_logger("talemate.agents.visual.style")
class Style(pydantic.BaseModel):
keywords: list[str] = pydantic.Field(default_factory=list)
negative_keywords: list[str] = pydantic.Field(default_factory=list)
@property
def positive_prompt(self):
return ", ".join(self.keywords)
@property
def negative_prompt(self):
return ", ".join(self.negative_keywords)
def __str__(self):
return f"POSITIVE: {self.positive_prompt}\nNEGATIVE: {self.negative_prompt}"
    def load(self, prompt: str, negative_prompt: str = ""):
        self.keywords = prompt.split(", ")
        self.negative_keywords = negative_prompt.split(", ") if negative_prompt else []
        # loop over a copy of the keywords, drop any starting with "no " and
        # add them to negative_keywords with "no " removed (iterating the
        # list directly while removing from it would skip entries)
        for kw in list(self.keywords):
            kw = kw.strip()
            log.debug("Checking keyword", keyword=kw)
            if kw.startswith("no "):
                log.debug("Transforming negative keyword", keyword=kw, to=kw[3:])
                self.keywords.remove(kw)
                self.negative_keywords.append(kw[3:])
return self
def prepend(self, *styles):
for style in styles:
for idx in range(len(style.keywords) - 1, -1, -1):
kw = style.keywords[idx]
if kw not in self.keywords:
self.keywords.insert(0, kw)
for idx in range(len(style.negative_keywords) - 1, -1, -1):
kw = style.negative_keywords[idx]
if kw not in self.negative_keywords:
self.negative_keywords.insert(0, kw)
return self
def append(self, *styles):
for style in styles:
for kw in style.keywords:
if kw not in self.keywords:
self.keywords.append(kw)
for kw in style.negative_keywords:
if kw not in self.negative_keywords:
self.negative_keywords.append(kw)
return self
def copy(self):
return Style(
keywords=self.keywords.copy(),
negative_keywords=self.negative_keywords.copy(),
)
# Almost taken straight from some of the Fooocus style presets; credit goes to the original author
STYLE_MAP["digital_art"] = Style(
keywords="digital artwork, masterpiece, best quality, high detail".split(", "),
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)
STYLE_MAP["concept_art"] = Style(
keywords="concept art, conceptual sketch, masterpiece, best quality, high detail".split(
", "
),
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)
STYLE_MAP["ink_illustration"] = Style(
keywords="ink illustration, painting, masterpiece, best quality".split(", "),
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)
STYLE_MAP["anime"] = Style(
keywords="anime, masterpiece, best quality, illustration".split(", "),
negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "),
)
STYLE_MAP["graphic_novel"] = Style(
keywords="(stylized by Enki Bilal:0.7), best quality, graphic novels, detailed linework, digital art".split(
", "
),
negative_keywords="text, watermark, low quality, blurry, photo, 3d, cgi".split(
", "
),
)
STYLE_MAP["character_portrait"] = Style(keywords="solo, looking at viewer".split(", "))
STYLE_MAP["environment"] = Style(
keywords="scenery, environment, background, postcard".split(", "),
negative_keywords="character, portrait, looking at viewer, people".split(", "),
)
MAJOR_STYLES = [
{"value": "digital_art", "label": "Digital Art"},
{"value": "concept_art", "label": "Concept Art"},
{"value": "ink_illustration", "label": "Ink Illustration"},
{"value": "anime", "label": "Anime"},
{"value": "graphic_novel", "label": "Graphic Novel"},
]
def combine_styles(*styles):
keywords = []
for style in styles:
keywords.extend(style.keywords)
return Style(keywords=list(set(keywords)))
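# Illustrative sketch: Style.load() splits a comma-separated prompt and moves
# "no ..." keywords into the negative list, while prepend() puts style
# keywords in front without duplicating existing ones. The prompt text is
# invented for demonstration.

_example = Style().load("castle, no people, night")
# _example.keywords          -> ["castle", "night"]
# _example.negative_keywords -> ["people"]
_example.prepend(STYLE_MAP["environment"].copy())
# _example.positive_prompt now begins with "scenery, environment, background, postcard"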

View File

@@ -0,0 +1,84 @@
from typing import Union
import pydantic
import structlog
from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
from .context import VisualContext, VisualContextState
__all__ = [
"VisualWebsocketHandler",
]
log = structlog.get_logger("talemate.server.visual")
class SetCoverImagePayload(pydantic.BaseModel):
base64: str
context: Union[VisualContextState, None] = None
class RegeneratePayload(pydantic.BaseModel):
context: Union[VisualContextState, None] = None
class VisualWebsocketHandler(Plugin):
router = "visual"
async def handle_regenerate(self, data: dict):
"""
Regenerate the image based on the context.
"""
payload = RegeneratePayload(**data)
context = payload.context
visual = get_agent("visual")
with VisualContext(**context.model_dump()):
await visual.generate(format="")
async def handle_cover_image(self, data: dict):
"""
Sets the cover image for a character and the scene.
"""
payload = SetCoverImagePayload(**data)
context = payload.context
scene = self.scene
if context and context.character_name:
character = scene.get_character(context.character_name)
if not character:
log.error("character not found", character_name=context.character_name)
return
asset = scene.assets.add_asset_from_image_data(payload.base64)
log.info("setting scene cover image", character_name=context.character_name)
scene.assets.cover_image = asset.id
log.info(
"setting character cover image", character_name=context.character_name
)
character.cover_image = asset.id
scene.emit_status()
self.websocket_handler.request_scene_assets([asset.id])
self.websocket_handler.queue_put(
{
"type": "scene_asset_character_cover_image",
"asset_id": asset.id,
"asset": self.scene.assets.get_asset_bytes_as_base64(asset.id),
"media_type": asset.media_type,
"character": character.name,
}
)
return
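# Illustrative sketch of the payload shape handle_cover_image() accepts,
# derived from SetCoverImagePayload/VisualContextState above; the base64
# value is elided and the character name is invented.

_example_payload = {
    "base64": "data:image/png;base64,...",
    "context": {"vis_type": "CHARACTER", "character_name": "Alice"},
}
_validated = SetCoverImagePayload(**_example_payload)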

View File

@@ -1,46 +1,54 @@
from __future__ import annotations
import dataclasses
import json
import time
import uuid
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import isodate
import structlog
import talemate.emit.async_signals
import talemate.util as util
from talemate.world_state import InsertionMode
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage, ReinforcementMessage
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.instance import get_agent
from talemate.prompts import Prompt
from talemate.scene_message import (
DirectorMessage,
ReinforcementMessage,
TimePassageMessage,
)
from talemate.world_state import InsertionMode
from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
from .registry import register
import structlog
import isodate
import time
log = structlog.get_logger("talemate.agents.world_state")
talemate.emit.async_signals.register("agent.world_state.time")
@dataclasses.dataclass
class WorldStateAgentEmission(AgentEmission):
"""
Emission class for world state agent
"""
pass
@dataclasses.dataclass
class TimePassageEmission(WorldStateAgentEmission):
"""
Emission class for time passage
"""
duration: str
narrative: str
human_duration: str = None
@register()
class WorldStateAgent(Agent):
@@ -55,26 +63,57 @@ class WorldStateAgent(Agent):
self.client = client
self.is_enabled = True
self.actions = {
"update_world_state": AgentAction(enabled=True, label="Update world state", description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.", config={
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before updating the world state.", value=5, min=1, max=100, step=1)
}),
"update_reinforcements": AgentAction(enabled=True, label="Update state reinforcements", description="Will attempt to update any due state reinforcements.", config={}),
"check_pin_conditions": AgentAction(enabled=True, label="Update conditional context pins", description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.", config={
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before checking conditions.", value=2, min=1, max=100, step=1)
}),
"update_world_state": AgentAction(
enabled=True,
label="Update world state",
description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
config={
"turns": AgentActionConfig(
type="number",
label="Turns",
description="Number of turns to wait before updating the world state.",
value=5,
min=1,
max=100,
step=1,
)
},
),
"update_reinforcements": AgentAction(
enabled=True,
label="Update state reinforcements",
description="Will attempt to update any due state reinforcements.",
config={},
),
"check_pin_conditions": AgentAction(
enabled=True,
label="Update conditional context pins",
description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.",
config={
"turns": AgentActionConfig(
type="number",
label="Turns",
description="Number of turns to wait before checking conditions.",
value=2,
min=1,
max=100,
step=1,
)
},
),
}
self.next_update = 0
self.next_pin_check = 0
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return True
@@ -83,110 +122,123 @@ class WorldStateAgent(Agent):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def advance_time(self, duration:str, narrative:str=None):
async def advance_time(self, duration: str, narrative: str = None):
"""
Emit a time passage message
"""
isodate.parse_duration(duration)
human_duration = util.iso8601_duration_to_human(duration, suffix=" later")
message = TimePassageMessage(ts=duration, message=human_duration)
log.debug("world_state.advance_time", message=message)
self.scene.push_history(message)
self.scene.emit_status()
emit("time", message)
await talemate.emit.async_signals.get("agent.world_state.time").send(
TimePassageEmission(agent=self, duration=duration, narrative=narrative, human_duration=human_duration)
)
async def on_game_loop(self, emission:GameLoopEvent):
emit("time", message)
await talemate.emit.async_signals.get("agent.world_state.time").send(
TimePassageEmission(
agent=self,
duration=duration,
narrative=narrative,
human_duration=human_duration,
)
)
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called when a conversation is generated
"""
if not self.enabled:
return
await self.update_world_state()
await self.auto_update_reinforcments()
await self.auto_check_pin_conditions()
async def auto_update_reinforcments(self):
if not self.enabled:
return
if not self.actions["update_reinforcements"].enabled:
return
await self.update_reinforcements()
async def auto_check_pin_conditions(self):
if not self.enabled:
return
if not self.actions["check_pin_conditions"].enabled:
return
if self.next_pin_check % self.actions["check_pin_conditions"].config["turns"].value != 0 or self.next_pin_check == 0:
if (
self.next_pin_check
% self.actions["check_pin_conditions"].config["turns"].value
!= 0
or self.next_pin_check == 0
):
self.next_pin_check += 1
return
self.next_pin_check = 0
await self.check_pin_conditions()
async def update_world_state(self):
self.next_pin_check = 0
await self.check_pin_conditions()
async def update_world_state(self, force: bool = False):
if not self.enabled:
return
if not self.actions["update_world_state"].enabled:
return
log.debug("update_world_state", next_update=self.next_update, turns=self.actions["update_world_state"].config["turns"].value)
log.debug(
"update_world_state",
next_update=self.next_update,
turns=self.actions["update_world_state"].config["turns"].value,
)
scene = self.scene
if self.next_update % self.actions["update_world_state"].config["turns"].value != 0 or self.next_update == 0:
if (
self.next_update % self.actions["update_world_state"].config["turns"].value
!= 0
or self.next_update == 0
) and not force:
self.next_update += 1
return
self.next_update = 0
await scene.world_state.request_update()
update_world_state.exposed = True
@set_processing
async def request_world_state(self):
t1 = time.time()
_, world_state = await Prompt.request(
"world_state.request-world-state-v2",
self.client,
"analyze_long",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"object_type": "character",
"object_type_plural": "characters",
}
},
)
self.scene.log.debug("request_world_state", response=world_state, time=time.time() - t1)
self.scene.log.debug(
"request_world_state", response=world_state, time=time.time() - t1
)
return world_state
@set_processing
async def request_world_state_inline(self):
"""
        EXPERIMENTAL. Overall the one-shot request seems about as coherent as the inline request, but the inline request is about twice as slow and would need to run on every dialogue line.
"""
@@ -199,14 +251,18 @@ class WorldStateAgent(Agent):
"world_state.request-world-state-inline-items",
self.client,
"analyze_long",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
}
},
)
self.scene.log.debug("request_world_state_inline", marked_items=marked_items_response, time=time.time() - t1)
self.scene.log.debug(
"request_world_state_inline",
marked_items=marked_items_response,
time=time.time() - t1,
)
return marked_items_response
@set_processing
@@ -214,99 +270,111 @@ class WorldStateAgent(Agent):
self,
text: str,
):
response = await Prompt.request(
"world_state.analyze-time-passage",
self.client,
"analyze_freeform_short",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
}
},
)
duration = response.split("\n")[0].split(" ")[0].strip()
if not duration.startswith("P"):
duration = "P"+duration
duration = "P" + duration
return duration
@set_processing
async def analyze_text_and_extract_context(
self,
text: str,
goal: str,
):
response = await Prompt.request(
"world_state.analyze-text-and-extract-context",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"goal": goal,
}
},
)
log.debug("analyze_text_and_extract_context", goal=goal, text=text, response=response)
log.debug(
"analyze_text_and_extract_context", goal=goal, text=text, response=response
)
return response
@set_processing
async def analyze_text_and_extract_context_via_queries(
self,
text: str,
goal: str,
) -> list[str]:
response = await Prompt.request(
"world_state.analyze-text-and-generate-rag-queries",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"goal": goal,
}
},
)
queries = response.split("\n")
memory_agent = get_agent("memory")
context = await memory_agent.multi_query(queries, iterate=3)
log.debug("analyze_text_and_extract_context_via_queries", goal=goal, text=text, queries=queries, context=context)
log.debug(
"analyze_text_and_extract_context_via_queries",
goal=goal,
text=text,
queries=queries,
context=context,
)
return context
@set_processing
async def analyze_and_follow_instruction(
self,
text: str,
instruction: str,
short: bool = False,
):
kind = "analyze_freeform_short" if short else "analyze_freeform"
response = await Prompt.request(
"world_state.analyze-text-and-follow-instruction",
self.client,
"analyze_freeform",
vars = {
kind,
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"instruction": instruction,
}
},
)
log.debug("analyze_and_follow_instruction", instruction=instruction, text=text, response=response)
log.debug(
"analyze_and_follow_instruction",
instruction=instruction,
text=text,
response=response,
)
return response
@set_processing
@@ -314,51 +382,55 @@ class WorldStateAgent(Agent):
self,
text: str,
query: str,
short: bool = False,
):
kind = "analyze_freeform_short" if short else "analyze_freeform"
response = await Prompt.request(
"world_state.analyze-text-and-answer-question",
self.client,
"analyze_freeform",
vars = {
kind,
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"query": query,
}
},
)
log.debug("analyze_text_and_answer_question", query=query, text=text, response=response)
log.debug(
"analyze_text_and_answer_question",
query=query,
text=text,
response=response,
)
return response
@set_processing
async def identify_characters(
self,
text: str = None,
):
"""
Attempts to identify characters in the given text.
"""
_, data = await Prompt.request(
"world_state.identify-characters",
self.client,
"analyze",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
}
},
)
log.debug("identify_characters", text=text, data=data)
return data
def _parse_character_sheet(self, response):
data = {}
for line in response.split("\n"):
if not line.strip():
@@ -367,128 +439,148 @@ class WorldStateAgent(Agent):
break
name, value = line.split(":", 1)
data[name.strip()] = value.strip()
return data
@set_processing
async def extract_character_sheet(
self,
name:str,
text:str = None,
name: str,
text: str = None,
alteration_instructions: str = None,
):
"""
Attempts to extract a character sheet from the given text.
"""
response = await Prompt.request(
"world_state.extract-character-sheet",
self.client,
"create",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"name": name,
}
"character": self.scene.get_character(name),
"alteration_instructions": alteration_instructions or "",
},
)
# loop through each line in response and if it contains a : then extract
# the left side as an attribute name and the right side as the value
#
# break as soon as a non-empty line is found that doesn't contain a :
return self._parse_character_sheet(response)
@set_processing
async def match_character_names(self, names:list[str]):
async def match_character_names(self, names: list[str]):
"""
Attempts to match character names.
"""
_, response = await Prompt.request(
"world_state.match-character-names",
self.client,
"analyze_long",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"names": names,
}
},
)
log.debug("match_character_names", names=names, response=response)
return response
@set_processing
async def update_reinforcements(self, force:bool=False):
async def update_reinforcements(self, force: bool = False):
"""
        Queries due world state reinforcements
"""
for reinforcement in self.scene.world_state.reinforce:
if reinforcement.due <= 0 or force:
await self.update_reinforcement(reinforcement.question, reinforcement.character)
await self.update_reinforcement(
reinforcement.question, reinforcement.character
)
else:
reinforcement.due -= 1
@set_processing
async def update_reinforcement(self, question:str, character:str=None, reset:bool=False):
async def update_reinforcement(
self, question: str, character: str = None, reset: bool = False
):
"""
        Queries a single reinforcement
"""
message = None
idx, reinforcement = await self.scene.world_state.find_reinforcement(question, character)
idx, reinforcement = await self.scene.world_state.find_reinforcement(
question, character
)
if not reinforcement:
return
source = f"{reinforcement.question}:{reinforcement.character if reinforcement.character else ''}"
if reset and reinforcement.insert == "sequential":
self.scene.pop_history(typ="reinforcement", source=source, all=True)
if reinforcement.insert == "sequential":
kind = "analyze_freeform_medium_short"
else:
kind = "analyze_freeform"
answer = await Prompt.request(
"world_state.update-reinforcements",
self.client,
"analyze_freeform",
vars = {
kind,
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"question": reinforcement.question,
"instructions": reinforcement.instructions or "",
"character": self.scene.get_character(reinforcement.character) if reinforcement.character else None,
"character": (
self.scene.get_character(reinforcement.character)
if reinforcement.character
else None
),
"answer": (reinforcement.answer if not reset else None) or "",
"reinforcement": reinforcement,
}
},
)
        # sequential reinforcement should be a single sentence, so we
# split on line breaks and take the first line in case the
# LLM did not understand the request and returned a longer response
if reinforcement.insert == "sequential":
answer = answer.split("\n")[0]
reinforcement.answer = answer
reinforcement.due = reinforcement.interval
# remove any recent previous reinforcement message with same question
# to avoid overloading the near history with reinforcement messages
if not reset:
self.scene.pop_history(typ="reinforcement", source=source, max_iterations=10)
self.scene.pop_history(
typ="reinforcement", source=source, max_iterations=10
)
if reinforcement.insert == "sequential":
# insert the reinforcement message at the current position
message = ReinforcementMessage(message=answer, source=source)
log.debug("update_reinforcement", message=message, reset=reset)
self.scene.push_history(message)
# if reinforcement has a character name set, update the character detail
if reinforcement.character:
character = self.scene.get_character(reinforcement.character)
await character.set_detail(reinforcement.question, answer)
else:
# set world entry
await self.scene.world_state_manager.save_world_entry(
@@ -496,20 +588,19 @@ class WorldStateAgent(Agent):
reinforcement.as_context_line,
{},
)
self.scene.world_state.emit()
return message
return message
@set_processing
async def check_pin_conditions(
self,
):
"""
        Checks whether any context pin conditions are met
"""
pins_with_condition = {
entry_id: {
"condition": pin.condition,
@@ -518,41 +609,47 @@ class WorldStateAgent(Agent):
for entry_id, pin in self.scene.world_state.pins.items()
if pin.condition
}
if not pins_with_condition:
return
first_entry_id = list(pins_with_condition.keys())[0]
_, answers = await Prompt.request(
"world_state.check-pin-conditions",
self.client,
"analyze",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"previous_states": json.dumps(pins_with_condition,indent=2),
"coercion": {first_entry_id:{ "condition": "" }},
}
"previous_states": json.dumps(pins_with_condition, indent=2),
"coercion": {first_entry_id: {"condition": ""}},
},
)
world_state = self.scene.world_state
state_change = False
state_change = False
for entry_id, answer in answers.items():
if entry_id not in world_state.pins:
log.warning("check_pin_conditions", entry_id=entry_id, answer=answer, msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)")
log.warning(
"check_pin_conditions",
entry_id=entry_id,
answer=answer,
msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)",
)
continue
log.info("check_pin_conditions", entry_id=entry_id, answer=answer)
state = answer.get("state")
if state is True or (isinstance(state, str) and state.lower() in ["true", "yes", "y"]):
if state is True or (
isinstance(state, str) and state.lower() in ["true", "yes", "y"]
):
prev_state = world_state.pins[entry_id].condition_state
world_state.pins[entry_id].condition_state = True
world_state.pins[entry_id].active = True
if prev_state != world_state.pins[entry_id].condition_state:
state_change = True
else:
@@ -560,49 +657,50 @@ class WorldStateAgent(Agent):
world_state.pins[entry_id].condition_state = False
world_state.pins[entry_id].active = False
state_change = True
if state_change:
await self.scene.load_active_pins()
self.scene.emit_status()
@set_processing
async def summarize_and_pin(self, message_id:int, num_messages:int=3) -> str:
async def summarize_and_pin(self, message_id: int, num_messages: int = 3) -> str:
"""
Will take a message index and then walk back N messages
summarizing the scene and pinning it to the context.
"""
creator = get_agent("creator")
summarizer = get_agent("summarizer")
message_index = self.scene.message_index(message_id)
text = self.scene.snapshot(lines=num_messages, start=message_index)
extra_context = self.scene.snapshot(lines=50, start=message_index-num_messages)
extra_context = self.scene.snapshot(
lines=50, start=message_index - num_messages
)
summary = await summarizer.summarize(
text,
text,
extra_context=extra_context,
method="short",
extra_instructions="Pay particularly close attention to decisions, agreements or promises made.",
)
entry_id = util.clean_id(await creator.generate_title(summary))
ts = self.scene.ts
log.debug(
"summarize_and_pin",
message_id=message_id,
message_index=message_index,
num_messages=num_messages,
num_messages=num_messages,
summary=summary,
entry_id=entry_id,
ts=ts,
)
await self.scene.world_state_manager.save_world_entry(
entry_id,
summary,
@@ -610,49 +708,49 @@ class WorldStateAgent(Agent):
"ts": ts,
},
)
await self.scene.world_state_manager.set_pin(
entry_id,
active=True,
)
await self.scene.load_active_pins()
self.scene.emit_status()
@set_processing
async def is_character_present(self, character:str) -> bool:
async def is_character_present(self, character: str) -> bool:
"""
Check if a character is present in the scene
Arguments:
- `character`: The character to check.
"""
if len(self.scene.history) < 10:
text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
else:
text = self.scene.snapshot(lines=50)
is_present = await self.analyze_text_and_answer_question(
text=text,
query=f"Is {character} present AND active in the current scene? Answert with 'yes' or 'no'.",
)
return is_present.lower().startswith("y")
@set_processing
async def is_character_leaving(self, character:str) -> bool:
"""
Check if a character is leaving the scene
Arguments:
- `character`: The character to check.
"""
if len(self.scene.history) < 10:
text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
else:
text = self.scene.snapshot(lines=50)
is_present = await self.analyze_text_and_answer_question(
text=text,
query=f"Is {character} present AND active in the current scene? Answert with 'yes' or 'no'.",
)
return is_present.lower().startswith("y")
@set_processing
async def is_character_leaving(self, character: str) -> bool:
"""
Check if a character is leaving the scene
Arguments:
- `character`: The character to check.
"""
if len(self.scene.history) < 10:
text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
else:
text = self.scene.snapshot(lines=50)
@@ -660,5 +758,30 @@ class WorldStateAgent(Agent):
text=text,
query=f"Is {character} leaving the current scene? Answert with 'yes' or 'no'.",
)
return is_leaving.lower().startswith("y")
return is_leaving.lower().startswith("y")
@set_processing
async def manager(self, action_name: str, *args, **kwargs):
"""
Executes a world state manager action through self.scene.world_state_manager
"""
manager = self.scene.world_state_manager
try:
fn = getattr(manager, action_name, None)
if not fn:
raise ValueError(f"Unknown action: {action_name}")
return await fn(*args, **kwargs)
except Exception as e:
log.error(
"worldstate.manager",
action_name=action_name,
args=args,
kwargs=kwargs,
error=e,
)
raise

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import dataclasses
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from talemate import Scene
import structlog
__all__ = ["AutomatedAction", "register", "initialize_for_scene"]
@@ -13,50 +14,64 @@ log = structlog.get_logger("talemate.automated_action")
AUTOMATED_ACTIONS = {}
def initialize_for_scene(scene:Scene):
def initialize_for_scene(scene: Scene):
for uid, config in AUTOMATED_ACTIONS.items():
scene.automated_actions[uid] = config.cls(
scene,
uid=uid,
frequency=config.frequency,
call_initially=config.call_initially,
enabled=config.enabled
enabled=config.enabled,
)
@dataclasses.dataclass
class AutomatedActionConfig:
uid:str
cls:AutomatedAction
frequency:int=5
call_initially:bool=False
enabled:bool=True
uid: str
cls: AutomatedAction
frequency: int = 5
call_initially: bool = False
enabled: bool = True
class register:
def __init__(self, uid:str, frequency:int=5, call_initially:bool=False, enabled:bool=True):
def __init__(
self,
uid: str,
frequency: int = 5,
call_initially: bool = False,
enabled: bool = True,
):
self.uid = uid
self.frequency = frequency
self.call_initially = call_initially
self.enabled = enabled
def __call__(self, action:AutomatedAction):
def __call__(self, action: AutomatedAction):
AUTOMATED_ACTIONS[self.uid] = AutomatedActionConfig(
self.uid,
action,
frequency=self.frequency,
call_initially=self.call_initially,
enabled=self.enabled
self.uid,
action,
frequency=self.frequency,
call_initially=self.call_initially,
enabled=self.enabled,
)
return action
class AutomatedAction:
"""
An action that will be executed every n turns
"""
def __init__(self, scene:Scene, frequency:int=5, call_initially:bool=False, uid:str=None, enabled:bool=True):
def __init__(
self,
scene: Scene,
frequency: int = 5,
call_initially: bool = False,
uid: str = None,
enabled: bool = True,
):
self.scene = scene
self.enabled = enabled
self.frequency = frequency
@@ -64,14 +79,19 @@ class AutomatedAction:
self.uid = uid
if call_initially:
self.turns = frequency
async def __call__(self):
log.debug("automated_action", uid=self.uid, enabled=self.enabled, frequency=self.frequency, turns=self.turns)
log.debug(
"automated_action",
uid=self.uid,
enabled=self.enabled,
frequency=self.frequency,
turns=self.turns,
)
if not self.enabled:
return False
if self.turns % self.frequency == 0:
result = await self.action()
log.debug("automated_action", result=result)
@@ -79,10 +99,9 @@ class AutomatedAction:
# action could not be performed at this turn, we will try again next turn
return False
self.turns += 1
async def action(self) -> Any:
"""
Override this method to implement your action.
"""
raise NotImplementedError()
raise NotImplementedError()
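# Illustrative sketch: a hypothetical action wired up via the register
# decorator above; it would run every 10 turns once initialize_for_scene()
# has instantiated it for a scene. The uid is invented for demonstration.

@register("example_cleanup", frequency=10)
class ExampleCleanupAction(AutomatedAction):
    async def action(self) -> Any:
        # periodic work against self.scene would go here; a truthy return
        # marks the action as performed for this cycle
        return True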

View File

@@ -1,32 +1,34 @@
from typing import Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from talemate.instance import get_agent
if TYPE_CHECKING:
from talemate.tale_mate import Scene, Character, Actor
from talemate.tale_mate import Actor, Character, Scene
__all__ = [
"deactivate_character",
"activate_character",
]
async def deactivate_character(scene:"Scene", character:Union[str, "Character"]):
async def deactivate_character(scene: "Scene", character: Union[str, "Character"]):
"""
Deactivates a character
Arguments:
- `scene`: The scene to deactivate the character from
- `character`: The character to deactivate. Can be a string (the character's name) or a Character object
"""
if isinstance(character, str):
character = scene.get_character(character)
if character.is_player:
# can't deactivate the player
return False
if character.name in scene.inactive_characters:
# already deactivated
return False
@@ -34,24 +36,24 @@ async def deactivate_character(scene:"Scene", character:Union[str, "Character"])
await scene.remove_actor(character.actor)
scene.inactive_characters[character.name] = character
async def activate_character(scene:"Scene", character:Union[str, "Character"]):
async def activate_character(scene: "Scene", character: Union[str, "Character"]):
"""
Activates a character
Arguments:
- `scene`: The scene to activate the character in
- `character`: The character to activate. Can be a string (the character's name) or a Character object
"""
if isinstance(character, str):
character = scene.get_character(character)
if character.name not in scene.inactive_characters:
# already activated
return False
actor = scene.Actor(character, get_agent("conversation"))
await scene.add_actor(actor)
del scene.inactive_characters[character.name]
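# Illustrative usage sketch (assumes an existing `scene`; the character name
# is invented):
#
#   await deactivate_character(scene, "Guard")  # removes the actor and parks the
#                                               # character in inactive_characters
#   await activate_character(scene, "Guard")    # re-adds an Actor for the character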

View File

@@ -2,15 +2,13 @@ import argparse
import asyncio
import glob
import os
import structlog
import structlog
from dotenv import load_dotenv
import talemate.instance as instance
from talemate import Actor, Character, Helper, Player, Scene
from talemate.agents import (
ConversationAgent,
)
from talemate.agents import ConversationAgent
from talemate.client import OpenAIClient, TextGeneratorWebuiClient
from talemate.emit.console import Console
from talemate.load import (
@@ -129,7 +127,6 @@ async def run_console_session(parser, args):
default_client = None
if "textgenwebui" in clients.values() or args.client == "textgenwebui":
# Init the TextGeneratorWebuiClient with ConversationAgent and create an actor
textgenwebui_api_url = args.textgenwebui_url
@@ -145,7 +142,6 @@ async def run_console_session(parser, args):
clients[client_name] = text_generator_webui_client
if "openai" in clients.values() or args.client == "openai":
openai_client = OpenAIClient()
for client_name, client_typ in clients.items():

View File

@@ -1,7 +1,13 @@
import os
import talemate.client.runpod
from talemate.client.anthropic import AnthropicClient
from talemate.client.cohere import CohereClient
from talemate.client.google import GoogleClient
from talemate.client.groq import GroqClient
from talemate.client.lmstudio import LMStudioClient
from talemate.client.mistral import MistralAIClient
from talemate.client.openai import OpenAIClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient

View File

@@ -0,0 +1,225 @@
import pydantic
import structlog
from anthropic import AsyncAnthropic, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"AnthropicClient",
]
log = structlog.get_logger("talemate")
# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
]
class Defaults(pydantic.BaseModel):
max_token_length: int = 16384
model: str = "claude-3-sonnet-20240229"
@register()
class AnthropicClient(ClientBase):
"""
Anthropic client for generating text.
"""
client_type = "anthropic"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = False
class Meta(ClientBase.Meta):
name_prefix: str = "Anthropic"
title: str = "Anthropic"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="claude-3-sonnet-20240229", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def anthropic_api_key(self):
return self.config.get("anthropic", {}).get("api_key")
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.anthropic_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"anthropic_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
},
)
def set_client(self, max_token_length: int = None):
if not self.anthropic_api_key:
self.client = AsyncAnthropic(api_key="sk-1111")
log.error("No anthropic API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "claude-3-opus-20240229"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"anthropic set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return response.usage.output_tokens
def prompt_tokens(self, response: str):
return response.usage.input_tokens
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
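
A standalone sketch of the <|BOT|> rewrite performed above; the prompt text is made up.

def rewrite_bot_marker(prompt: str) -> str:
    if "<|BOT|>" not in prompt:
        return prompt
    _, right = prompt.split("<|BOT|>", 1)
    # text after the marker becomes a coercion instruction; a bare
    # trailing marker is simply stripped
    return prompt.replace(
        "<|BOT|>", "\nStart your response with: " if right else ""
    )

# rewrite_bot_marker('Describe the room.<|BOT|>{"mood":')
# -> 'Describe the room.\nStart your response with: {"mood":'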
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.anthropic_api_key:
raise Exception("No anthropic API key set")
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
human_message = {"role": "user", "content": prompt.strip()}
system_message = self.get_system_message(kind)
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
try:
response = await self.client.messages.create(
model=self.model_name,
system=system_message,
messages=[human_message],
**parameters,
)
self._returned_prompt_tokens = self.prompt_tokens(response)
self._returned_response_tokens = self.response_tokens(response)
log.debug("generated response", response=response.content)
response = response.content[0].text
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
emit("status", message="anthropic API: Permission Denied", status="error")
return ""
except Exception as e:
raise

View File

@@ -1,195 +1,306 @@
"""
A unified client base, based on the openai API
"""
import ipaddress
import logging
import random
import time
from typing import Callable, Union
import pydantic
import structlog
import urllib3
from openai import AsyncOpenAI, PermissionDeniedError
import talemate.client.presets as presets
import talemate.client.system_prompts as system_prompts
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import active_agent
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
from talemate.emit import emit
# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger("httpx").setLevel(logging.WARNING)
REMOTE_SERVICES = [
# TODO: runpod.py should add this to the list
".runpod.net"
]
log = structlog.get_logger("client.base")
STOPPING_STRINGS = ["<|im_end|>", "</s>"]
class PromptData(pydantic.BaseModel):
kind: str
prompt: str
response: str
prompt_tokens: int
response_tokens: int
client_name: str
client_type: str
time: Union[float, int]
agent_stack: list[str] = pydantic.Field(default_factory=list)
generation_parameters: dict = pydantic.Field(default_factory=dict)
class ErrorAction(pydantic.BaseModel):
title: str
action_name: str
icon: str = "mdi-error"
arguments: list = []
class Defaults(pydantic.BaseModel):
api_url: str = "http://localhost:5000"
max_token_length: int = 8192
double_coercion: str = None
class ExtraField(pydantic.BaseModel):
name: str
type: str
label: str
required: bool
description: str
class ClientBase:
api_url: str
model_name: str
api_key: str = None
name: str = None
enabled: bool = True
current_status: str = None
max_token_length: int = 8192
processing: bool = False
connected: bool = False
conversation_retries: int = 0
auto_break_repetition_enabled: bool = True
decensor_enabled: bool = True
auto_determine_prompt_template: bool = False
finalizers: list[str] = []
double_coercion: Union[str, None] = None
client_type = "base"
class Meta(pydantic.BaseModel):
experimental: Union[None, str] = None
defaults: Defaults = Defaults()
title: str = "Client"
name_prefix: str = "Client"
enable_api_auth: bool = False
requires_prompt_template: bool = True
def __init__(
self,
api_url: str = None,
name=None,
**kwargs,
):
self.api_url = api_url
self.name = name or self.client_type
self.auto_determine_prompt_template_attempt = None
self.log = structlog.get_logger(f"client.{self.client_type}")
self.double_coercion = kwargs.get("double_coercion", None)
if "max_token_length" in kwargs:
self.max_token_length = (
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
)
self.set_client(max_token_length=self.max_token_length)
def __str__(self):
return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"
@property
def experimental(self):
return False
@property
def can_be_coerced(self):
"""
Determines whether or not this client can pass LLM coercion (e.g., is able
to predefine partial LLM output in the prompt)
"""
return self.Meta().requires_prompt_template
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def prompt_template(self, sys_msg: str, prompt: str):
"""
Applies the appropriate prompt template for the model.
"""
if not self.model_name:
self.log.warning("prompt template not applied", reason="no model loaded")
return f"{sys_msg}\n{prompt}"
# is JSON coercion active?
# Check for <|BOT|>{ in the prompt
json_coercion = "<|BOT|>{" in prompt
if self.can_be_coerced and self.double_coercion and not json_coercion:
double_coercion = self.double_coercion
double_coercion = f"{double_coercion}\n\n"
else:
double_coercion = None
return model_prompt(self.model_name, sys_msg, prompt, double_coercion)[0]
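
Schematically, when a client can be coerced and double coercion is configured (and the prompt is not already JSON-coerced), the configured string plus a blank line is handed to model_prompt() as a response prefix; the values below are made up.

double_coercion = "Certainly:"          # hypothetical configured value
prompt = "Continue the scene.<|BOT|>"
prefix = f"{double_coercion}\n\n" if "<|BOT|>{" not in prompt else None
# model_prompt(model_name, sys_msg, prompt, prefix) then seeds the
# model's response with the prefix inside the rendered template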
def prompt_template_example(self):
if not getattr(self, "model_name", None):
return None, None
return model_prompt(self.model_name, "sysmsg", "prompt<|BOT|>{LLM coercion}")
return model_prompt(
self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
)
def reconfigure(self, **kwargs):
"""
Reconfigures the client.
Keyword Arguments:
- api_url: the API URL to use
- max_token_length: the max token length to use
- enabled: whether the client is enabled
"""
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
if kwargs.get("max_token_length"):
self.max_token_length = int(kwargs["max_token_length"])
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
def host_is_remote(self, url: str) -> bool:
"""
Returns whether or not the host is a remote service.
It checks common local hostnames / ip prefixes.
- localhost
"""
host = urllib3.util.parse_url(url).host
if host.lower() == "localhost":
return False
# use ipaddress module to check for local ip prefixes
try:
ip = ipaddress.ip_address(host)
except ValueError:
return True
if ip.is_loopback or ip.is_private:
return False
return True
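
The same check as host_is_remote() above, restated as a standalone sketch; the example URLs are illustrative.

import ipaddress
import urllib3

def is_remote(url: str) -> bool:
    host = urllib3.util.parse_url(url).host
    if host.lower() == "localhost":
        return False
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        # hostnames that are not IP addresses are assumed remote
        return True
    return not (ip.is_loopback or ip.is_private)

# is_remote("http://localhost:5000")    -> False
# is_remote("http://192.168.1.5:5000")  -> False
# is_remote("https://foo.runpod.net")   -> True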
def toggle_disabled_if_remote(self):
"""
If the client is targeting a remote recognized service, this
will disable the client.
"""
if not self.api_url:
return False
if self.host_is_remote(self.api_url) and self.enabled:
self.log.warn(
"remote service unreachable, disabling client", client=self.name
)
self.enabled = False
return True
return False
def get_system_message(self, kind: str) -> str:
"""
Returns the appropriate system message for the given kind of generation
Arguments:
- kind: the kind of generation
"""
# TODO: make extensible
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
return system_prompts.NARRATOR
if "director" in kind:
return system_prompts.DIRECTOR
if "create" in kind:
return system_prompts.CREATOR
if "roleplay" in kind:
return system_prompts.ROLEPLAY
if "conversation" in kind:
return system_prompts.ROLEPLAY
if "editor" in kind:
return system_prompts.EDITOR
if "world_state" in kind:
return system_prompts.WORLD_STATE
if "analyze_freeform" in kind:
return system_prompts.ANALYST_FREEFORM
if "analyst" in kind:
return system_prompts.ANALYST
if "analyze" in kind:
return system_prompts.ANALYST
if "summarize" in kind:
return system_prompts.SUMMARIZE
if self.decensor_enabled:
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
return system_prompts.NARRATOR
if "director" in kind:
return system_prompts.DIRECTOR
if "create" in kind:
return system_prompts.CREATOR
if "roleplay" in kind:
return system_prompts.ROLEPLAY
if "conversation" in kind:
return system_prompts.ROLEPLAY
if "basic" in kind:
return system_prompts.BASIC
if "editor" in kind:
return system_prompts.EDITOR
if "edit" in kind:
return system_prompts.EDITOR
if "world_state" in kind:
return system_prompts.WORLD_STATE
if "analyze_freeform" in kind:
return system_prompts.ANALYST_FREEFORM
if "analyst" in kind:
return system_prompts.ANALYST
if "analyze" in kind:
return system_prompts.ANALYST
if "summarize" in kind:
return system_prompts.SUMMARIZE
if "visualize" in kind:
return system_prompts.VISUALIZE
else:
if "narrate" in kind:
return system_prompts.NARRATOR_NO_DECENSOR
if "story" in kind:
return system_prompts.NARRATOR_NO_DECENSOR
if "director" in kind:
return system_prompts.DIRECTOR_NO_DECENSOR
if "create" in kind:
return system_prompts.CREATOR_NO_DECENSOR
if "roleplay" in kind:
return system_prompts.ROLEPLAY_NO_DECENSOR
if "conversation" in kind:
return system_prompts.ROLEPLAY_NO_DECENSOR
if "basic" in kind:
return system_prompts.BASIC
if "editor" in kind:
return system_prompts.EDITOR_NO_DECENSOR
if "edit" in kind:
return system_prompts.EDITOR_NO_DECENSOR
if "world_state" in kind:
return system_prompts.WORLD_STATE_NO_DECENSOR
if "analyze_freeform" in kind:
return system_prompts.ANALYST_FREEFORM_NO_DECENSOR
if "analyst" in kind:
return system_prompts.ANALYST_NO_DECENSOR
if "analyze" in kind:
return system_prompts.ANALYST_NO_DECENSOR
if "summarize" in kind:
return system_prompts.SUMMARIZE_NO_DECENSOR
if "visualize" in kind:
return system_prompts.VISUALIZE_NO_DECENSOR
return system_prompts.BASIC
def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
if processing is not None:
self.processing = processing
@@ -205,39 +316,80 @@ class ClientBase:
else:
model_name = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
prompt_template_example, prompt_template_file = self.prompt_template_example()
has_prompt_template = (
prompt_template_file and prompt_template_file != "default.jinja2"
)
if not has_prompt_template and self.auto_determine_prompt_template:
# only attempt to determine the prompt template once per model and
# only if the model does not already have a prompt template
if self.auto_determine_prompt_template_attempt != self.model_name:
log.info("auto_determine_prompt_template", model_name=self.model_name)
self.auto_determine_prompt_template_attempt = self.model_name
self.determine_prompt_template()
prompt_template_example, prompt_template_file = (
self.prompt_template_example()
)
has_prompt_template = (
prompt_template_file and prompt_template_file != "default.jinja2"
)
data = {
"api_key": self.api_key,
"prompt_template_example": prompt_template_example,
"has_prompt_template": has_prompt_template,
"template_file": prompt_template_file,
"meta": self.Meta().model_dump(),
"error_action": None,
"double_coercion": self.double_coercion,
}
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
data[field_name] = getattr(self, field_name, None)
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
data=data,
)
if status_change:
instance.emit_agent_status_by_client(self)
def populate_extra_fields(self, data: dict):
"""
Updates data with the extra fields from the client's Meta
"""
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
data[field_name] = getattr(self, field_name, None)
def determine_prompt_template(self):
if not self.model_name:
return
template = model_prompt.query_hf_for_prompt_template_suggestion(self.model_name)
if template:
model_prompt.create_user_override(template, self.model_name)
async def get_model_name(self):
models = await self.client.models.list()
try:
return models.data[0].id
except IndexError:
return None
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
@@ -246,77 +398,95 @@ class ClientBase:
"""
if self.processing:
return
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
self.model_name = await self.get_model_name()
except Exception as e:
self.log.warning("client status error", e=e, client=self.name)
self.model_name = None
self.connected = False
self.toggle_disabled_if_remote()
self.emit_status()
return
self.connected = True
if not self.model_name or self.model_name == "None":
self.log.warning("client model not loaded", client=self)
self.emit_status()
return
self.emit_status()
def generate_prompt_parameters(self, kind: str):
parameters = {}
self.tune_prompt_parameters(
presets.configure(parameters, kind, self.max_token_length), kind
)
return parameters
def tune_prompt_parameters(self, parameters: dict, kind: str):
parameters["stream"] = False
if client_context_attribute("nuke_repetition") > 0.0 and self.jiggle_enabled_for(kind):
self.jiggle_randomness(parameters, offset=client_context_attribute("nuke_repetition"))
if client_context_attribute(
"nuke_repetition"
) > 0.0 and self.jiggle_enabled_for(kind):
self.jiggle_randomness(
parameters, offset=client_context_attribute("nuke_repetition")
)
fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
if fn_tune_kind:
fn_tune_kind(parameters)
agent_context = active_agent.get()
if agent_context.agent:
agent_context.agent.inject_prompt_paramters(
parameters, kind, agent_context.action
)
def tune_prompt_parameters_conversation(self, parameters: dict):
conversation_context = client_context_attribute("conversation")
parameters["max_tokens"] = conversation_context.get("length", 96)
dialog_stopping_strings = [
f"{character}:" for character in conversation_context["other_characters"]
]
dialog_stopping_strings += [
f"{character.upper()}\n"
for character in conversation_context["other_characters"]
]
if "extra_stopping_strings" in parameters:
parameters["extra_stopping_strings"] += dialog_stopping_strings
else:
parameters["extra_stopping_strings"] = dialog_stopping_strings
def finalize(self, parameters: dict, prompt: str):
prompt = util.replace_special_tokens(prompt)
for finalizer in self.finalizers:
fn = getattr(self, finalizer, None)
prompt, applied = fn(parameters, prompt)
if applied:
return prompt
return prompt
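
A hypothetical finalizer, assuming the convention above: each name in self.finalizers resolves to a method returning (prompt, applied), and the first one that applies short-circuits.

def clamp_max_tokens(self, parameters: dict, prompt: str):
    # hypothetical finalizer; would be listed as "clamp_max_tokens"
    # in the client's finalizers attribute
    if parameters.get("max_tokens", 0) > 1024:
        parameters["max_tokens"] = 1024
        return prompt, True
    return prompt, False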
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.completions.create(
prompt=prompt.strip(" "), **parameters
)
return response.get("choices", [{}])[0].get("text", "")
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -324,85 +494,111 @@ class ClientBase:
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x, retries:int=2
self,
prompt: str,
kind: str = "conversation",
finalize: Callable = lambda x: x,
retries: int = 2,
) -> str:
"""
Send a prompt to the AI and return its response.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
try:
self._returned_prompt_tokens = None
self._returned_response_tokens = None
self.emit_status(processing=True)
await self.status()
prompt_param = self.generate_prompt_parameters(kind)
finalized_prompt = self.prompt_template(
self.get_system_message(kind), prompt
).strip(" ")
finalized_prompt = self.finalize(prompt_param, finalized_prompt)
prompt_param = finalize(prompt_param)
token_length = self.count_tokens(finalized_prompt)
time_start = time.time()
extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])
self.log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length, parameters=prompt_param)
response = await self.generate(
self.repetition_adjustment(finalized_prompt),
prompt_param,
kind
self.log.debug(
"send_prompt",
token_length=token_length,
max_token_length=self.max_token_length,
parameters=prompt_param,
)
prompt_sent = self.repetition_adjustment(finalized_prompt)
response = await self.generate(prompt_sent, prompt_param, kind)
response, finalized_prompt = await self.auto_break_repetition(
finalized_prompt, prompt_param, response, kind, retries
)
time_end = time.time()
# stopping strings sometimes get appended to the end of the response anyways
# split the response by the first stopping string and take the first part
for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
if stopping_string in response:
response = response.split(stopping_string)[0]
break
emit("prompt_sent", data={
"kind": kind,
"prompt": finalized_prompt,
"response": response,
"prompt_tokens": token_length,
"response_tokens": self.count_tokens(response),
"time": time_end - time_start,
})
agent_context = active_agent.get()
emit(
"prompt_sent",
data=PromptData(
kind=kind,
prompt=prompt_sent,
response=response,
prompt_tokens=self._returned_prompt_tokens or token_length,
response_tokens=self._returned_response_tokens
or self.count_tokens(response),
agent_stack=agent_context.agent_stack if agent_context else [],
client_name=self.name,
client_type=self.client_type,
time=time_end - time_start,
generation_parameters=prompt_param,
).model_dump(),
)
return response
finally:
self.emit_status(processing=False)
self._returned_prompt_tokens = None
self._returned_response_tokens = None
async def auto_break_repetition(
self,
finalized_prompt: str,
prompt_param: dict,
response: str,
kind: str,
retries: int,
pad_max_tokens: int = 32,
) -> str:
"""
If repetition breaking is enabled, this will retry the prompt if its
response is too similar to other messages in the prompt
This requires the agent to have the allow_repetition_break method
and the jiggle_enabled_for method and the client to have the
auto_break_repetition_enabled attribute set to True
Arguments:
- finalized_prompt: the prompt that was sent
@@ -411,47 +607,46 @@ class ClientBase:
- kind: the kind of generation
- retries: the number of retries left
- pad_max_tokens: increase response max_tokens by this amount per iteration
Returns:
- the response
"""
if not self.auto_break_repetition_enabled or not response.strip():
return response, finalized_prompt
agent_context = active_agent.get()
if self.jiggle_enabled_for(kind, auto=True):
# check if the response is a repetition
# using the default similarity threshold of 98, meaning it needs
# to be really similar to be considered a repetition
is_repetition, similarity_score, matched_line = util.similarity_score(
response, finalized_prompt.split("\n"), similarity_threshold=80
)
if not is_repetition:
# not a repetition, return the response
self.log.debug("send_prompt no similarity", similarity_score=similarity_score)
finalized_prompt = self.repetition_adjustment(finalized_prompt, is_repetitive=False)
return response, finalized_prompt
while is_repetition and retries > 0:
# it's a repetition, retry the prompt with adjusted parameters
self.log.warn(
"send_prompt similarity retry",
agent=agent_context.agent.agent_type,
similarity_score=similarity_score,
retries=retries
self.log.debug(
"send_prompt no similarity", similarity_score=similarity_score
)
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=False
)
return response, finalized_prompt
while is_repetition and retries > 0:
# it's a repetition, retry the prompt with adjusted parameters
self.log.warn(
"send_prompt similarity retry",
agent=agent_context.agent.agent_type,
similarity_score=similarity_score,
retries=retries,
)
# first we apply the client's randomness jiggle which will adjust
# parameters like temperature and repetition_penalty, depending
# on the client
@@ -459,90 +654,93 @@ class ClientBase:
# this is a cumulative adjustment, so it will add to the previous
# iteration's adjustment, this also means retries should be kept low
# otherwise it will get out of hand and start generating nonsense
self.jiggle_randomness(prompt_param, offset=0.5)
# then we pad the max_tokens by the pad_max_tokens amount
prompt_param["max_tokens"] += pad_max_tokens
# send the prompt again
# we use the repetition_adjustment method to further encourage
# the AI to break the repetition on its own as well.
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=True
)
self.log.debug("send_prompt dedupe sentences", response=response, matched_line=matched_line)
response = retried_response = await self.generate(
finalized_prompt, prompt_param, kind
)
self.log.debug(
"send_prompt dedupe sentences",
response=response,
matched_line=matched_line,
)
# a lot of the times the response will now contain the repetition + something new
# so we dedupe the response to remove the repetition on sentences level
response = util.dedupe_sentences(
response, matched_line, similarity_threshold=85, debug=True
)
self.log.debug(
"send_prompt dedupe sentences (after)", response=response
)
# deduping may have removed the entire response, so we check for that
if not util.strip_partial_sentences(response).strip():
# if the response is empty, we set the response to the original
# and try again next loop
response = retried_response
# check if the response is a repetition again
is_repetition, similarity_score, matched_line = util.similarity_score(
response, finalized_prompt.split("\n"), similarity_threshold=80
)
retries -= 1
return response, finalized_prompt
def count_tokens(self, content: str):
return util.count_tokens(content)
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
def jiggle_enabled_for(self, kind: str, auto: bool = False) -> bool:
agent_context = active_agent.get()
agent = agent_context.agent
if not agent:
return False
return agent.allow_repetition_break(kind, agent_context.action, auto=auto)
def repetition_adjustment(self, prompt: str, is_repetitive: bool = False):
"""
Breaks the prompt into lines and checks each line for a match with
[$REPETITION|{repetition_adjustment}].
On match and if is_repetitive is True, the line is removed from the prompt and
replaced with the repetition_adjustment.
On match and if is_repetitive is False, the line is removed from the prompt.
"""
lines = prompt.split("\n")
new_lines = []
for line in lines:
if line.startswith("[$REPETITION|"):
if is_repetitive:
@@ -551,5 +749,5 @@ class ClientBase:
new_lines.append("")
else:
new_lines.append(line)
return "\n".join(new_lines)
return "\n".join(new_lines)

View File

@@ -1,6 +1,7 @@
from enum import Enum
import pydantic
__all__ = [
"ClientType",
"ClientBootstrap",
@@ -10,8 +11,10 @@ __all__ = [
LISTS = {}
class ClientType(str, Enum):
"""Client type enum."""
textgen = "textgenwebui"
automatic1111 = "automatic1111"
@@ -20,43 +23,42 @@ class ClientBootstrap(pydantic.BaseModel):
"""Client bootstrap model."""
# client type, currently supports "textgen" and "automatic1111"
client_type: ClientType
# unique client identifier
uid: str
# connection name
name: str
# connection information for the client
# REST api url
api_url: str
# service name (for example runpod)
service_name: str
class register_list:
def __init__(self, service_name: str):
self.service_name = service_name
def __call__(self, func):
LISTS[self.service_name] = func
return func
async def list_all(exclude_urls: list[str] = list()):
"""
Return a list of client bootstrap objects.
"""
for service_name, func in LISTS.items():
async for item in func():
if item.api_url not in exclude_urls:
yield item.dict()
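
Hypothetical usage of list_all(), assuming at least one registered service list; the excluded URL is an example value.

import asyncio

async def discover_clients():
    # collect bootstrap dicts from every registered service list,
    # skipping any api_url that is already configured
    found = []
    async for client_dict in list_all(exclude_urls=["http://localhost:5000"]):
        found.append(client_dict)
    return found

# asyncio.run(discover_clients())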

View File

@@ -0,0 +1,229 @@
import pydantic
import structlog
from cohere import AsyncClient
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens
__all__ = [
"CohereClient",
]
log = structlog.get_logger("talemate")
# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
"command",
"command-r",
"command-r-plus",
]
class Defaults(pydantic.BaseModel):
max_token_length: int = 16384
model: str = "command-r-plus"
@register()
class CohereClient(ClientBase):
"""
Cohere client for generating text.
"""
client_type = "cohere"
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
class Meta(ClientBase.Meta):
name_prefix: str = "Cohere"
title: str = "Cohere"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="command-r-plus", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def cohere_api_key(self):
return self.config.get("cohere", {}).get("api_key")
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.cohere_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"cohere_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
},
)
def set_client(self, max_token_length: int = None):
if not self.cohere_api_key:
self.client = AsyncClient("sk-1111")
log.error("No cohere API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "command-r-plus"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncClient(self.cohere_api_key)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"cohere set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return count_tokens(response.text)
def prompt_tokens(self, prompt: str):
return count_tokens(prompt)
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
# if temperature is set, it needs to be clamped between 0 and 1.0
if "temperature" in parameters:
parameters["temperature"] = max(0.0, min(1.0, parameters["temperature"]))
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.cohere_api_key:
raise Exception("No cohere API key set")
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
human_message = prompt.strip()
system_message = self.get_system_message(kind)
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
try:
response = await self.client.chat(
model=self.model_name,
preamble=system_message,
message=human_message,
**parameters,
)
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
log.debug("generated response", response=response.text)
response = response.text
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
# except PermissionDeniedError as e:
# self.log.error("generate error", e=e)
# emit("status", message="cohere API: Permission Denied", status="error")
# return ""
except Exception as e:
raise

View File

@@ -3,19 +3,20 @@ Context managers for various client-side operations.
"""
from contextvars import ContextVar
from copy import deepcopy
import structlog
from pydantic import BaseModel, Field
__all__ = [
"context_data",
"client_context_attribute",
"ContextModel",
]
log = structlog.get_logger()
def model_to_dict_without_defaults(model_instance):
model_dict = model_instance.dict()
for field_name, field in model_instance.__class__.__fields__.items():
@@ -23,20 +24,25 @@ def model_to_dict_without_defaults(model_instance):
del model_dict[field_name]
return model_dict
class ConversationContext(BaseModel):
talking_character: str = None
other_characters: list[str] = Field(default_factory=list)
class ContextModel(BaseModel):
"""
Pydantic model for the context data.
"""
nuke_repetition: float = Field(0.0, ge=0.0, le=3.0)
conversation: ConversationContext = Field(default_factory=ConversationContext)
length: int = 96
# Define the context variable as an empty dictionary
context_data = ContextVar("context_data", default=ContextModel().model_dump())
def client_context_attribute(name, default=None):
"""
@@ -47,6 +53,7 @@ def client_context_attribute(name, default=None):
# Return the value of the key if it exists, otherwise return the default value
return data.get(name, default)
def set_client_context_attribute(name, value):
"""
Set the value of the context variable `context_data` for the given key.
@@ -55,7 +62,8 @@ def set_client_context_attribute(name, value):
data = context_data.get()
# Set the value of the key
data[name] = value
def set_conversation_context_attribute(name, value):
"""
Set the value of the context variable `context_data.conversation` for the given key.
@@ -65,6 +73,7 @@ def set_conversation_context_attribute(name, value):
# Set the value of the key
data["conversation"][name] = value
class ClientContext:
"""
A context manager to set values to the context variable `context_data`.
@@ -82,10 +91,10 @@ class ClientContext:
Set the key-value pairs to the context variable `context_data` when entering the context.
"""
# Get the current context data
data = deepcopy(context_data.get()) if context_data.get() else {}
data.update(self.values)
# Update the context data
self.token = context_data.set(data)
@@ -93,5 +102,5 @@ class ClientContext:
"""
Reset the context variable `context_data` to its previous values when exiting the context.
"""
context_data.reset(self.token)

View File

@@ -0,0 +1,34 @@
import importlib
import os
import structlog
log = structlog.get_logger("talemate.client.custom")
# import every submodule in this directory
#
# each directory in this directory is a submodule
# get the current directory
current_directory = os.path.dirname(__file__)
# get all subdirectories
subdirectories = [
os.path.join(current_directory, name)
for name in os.listdir(current_directory)
if os.path.isdir(os.path.join(current_directory, name))
]
# import every submodule
for subdirectory in subdirectories:
# get the name of the submodule
submodule_name = os.path.basename(subdirectory)
if submodule_name.startswith("__"):
continue
log.info("activating custom client", module=submodule_name)
# import the submodule
importlib.import_module(f".{submodule_name}", __package__)

View File

@@ -0,0 +1,5 @@
Each client should be in its own subdirectory.
The subdirectory itself must be a valid python module.
Check out docs/dev/client/example/test for a very simplistic custom client example.
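
A hypothetical layout (directory and file names are placeholders):

talemate/client/custom/
    my_client/          # any subdirectory not starting with "__" is imported
        __init__.py     # e.g. registers a ClientBase subclass via @register()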

View File

@@ -0,0 +1,312 @@
import json
import os
import pydantic
import structlog
import vertexai
from google.api_core.exceptions import ResourceExhausted
from vertexai.generative_models import (
ChatSession,
GenerativeModel,
ResponseValidationError,
SafetySetting,
)
from talemate.client.base import ClientBase, ErrorAction, ExtraField
from talemate.client.registry import register
from talemate.client.remote import RemoteServiceMixin
from talemate.config import Client as BaseClientConfig
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.util import count_tokens
__all__ = [
"GoogleClient",
]
log = structlog.get_logger("talemate")
# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
"gemini-1.0-pro",
"gemini-1.5-pro-preview-0409",
]
class Defaults(pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gemini-1.0-pro"
disable_safety_settings: bool = False
class ClientConfig(BaseClientConfig):
disable_safety_settings: bool = False
@register()
class GoogleClient(RemoteServiceMixin, ClientBase):
"""
Google client for generating text.
"""
client_type = "google"
conversation_retries = 0
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Google"
title: str = "Google"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = {
"disable_safety_settings": ExtraField(
name="disable_safety_settings",
type="bool",
label="Disable Safety Settings",
required=False,
description="Disable Google's safety settings for responses generated by the model.",
),
}
def __init__(self, model="gemini-1.0-pro", **kwargs):
self.model_name = model
self.setup_status = None
self.model_instance = None
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
self.google_credentials_read = False
self.google_project_id = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def google_credentials(self):
path = self.google_credentials_path
if not path:
return None
with open(path) as f:
return json.load(f)
@property
def google_credentials_path(self):
return self.config.get("google").get("gcloud_credentials_path")
@property
def google_location(self):
return self.config.get("google").get("gcloud_location")
@property
def ready(self):
# all google settings must be set
return all(
[
self.google_credentials_path,
self.google_location,
]
)
@property
def safety_settings(self):
if not self.disable_safety_settings:
return None
safety_settings = [
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
),
SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
),
]
return safety_settings
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.ready:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "Setup incomplete"
error_action = ErrorAction(
title="Setup Google API credentials",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"google_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
}
self.populate_extra_fields(data)
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
data=data,
)
def set_client(self, max_token_length: int = None, **kwargs):
if not self.ready:
log.error("Google cloud setup incomplete")
if self.setup_status:
self.setup_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "gemini-1.0-pro"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
if self.google_credentials_path:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_credentials_path
model = self.model_name
self.max_token_length = max_token_length or 16384
if not self.setup_status:
if self.setup_status is False:
project_id = self.google_credentials.get("project_id")
self.google_project_id = project_id
if self.google_credentials_path:
vertexai.init(project=project_id, location=self.google_location)
emit("request_client_status")
emit("request_agent_status")
self.setup_status = True
self.model_instance = GenerativeModel(model_name=model)
log.info(
"google set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def response_tokens(self, response: str):
return count_tokens(response.text)
def prompt_tokens(self, prompt: str):
return count_tokens(prompt)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
if "disable_safety_settings" in kwargs:
self.disable_safety_settings = kwargs["disable_safety_settings"]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.ready:
raise Exception("Google cloud setup incomplete")
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
except (IndexError, ValueError):
pass
human_message = prompt.strip()
system_message = self.get_system_message(kind)
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
disable_safety_settings=self.disable_safety_settings,
safety_settings=self.safety_settings,
)
try:
chat = self.model_instance.start_chat()
response = await chat.send_message_async(
human_message,
safety_settings=self.safety_settings,
)
self._returned_prompt_tokens = self.prompt_tokens(prompt)
self._returned_response_tokens = self.response_tokens(response)
response = response.text
log.debug("generated response", response=response)
if expected_response and expected_response.startswith("{"):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
# except PermissionDeniedError as e:
# self.log.error("generate error", e=e)
# emit("status", message="google API: Permission Denied", status="error")
# return ""
except ResourceExhausted as e:
self.log.error("generate error", e=e)
emit("status", message="google API: Quota Limit reached", status="error")
return ""
except ResponseValidationError as e:
self.log.error("generate error", e=e)
emit(
"status",
message="google API: Response Validation Error",
status="error",
)
if not self.disable_safety_settings:
return "Failed to generate response. Probably due to safety settings, you can turn them off in the client settings."
return "Failed to generate response. Please check logs."
except Exception as e:
raise

src/talemate/client/groq.py
View File

@@ -0,0 +1,235 @@
import pydantic
import structlog
from groq import AsyncGroq, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"GroqClient",
]
log = structlog.get_logger("talemate")
# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
]
JSON_OBJECT_RESPONSE_MODELS = []
class Defaults(pydantic.BaseModel):
max_token_length: int = 8192
model: str = "llama3-70b-8192"
@register()
class GroqClient(ClientBase):
"""
Groq client for generating text.
"""
client_type = "groq"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = True
class Meta(ClientBase.Meta):
name_prefix: str = "Groq"
title: str = "Groq"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="llama3-70b-8192", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def groq_api_key(self):
return self.config.get("groq", {}).get("api_key")
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.groq_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"groq_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
},
)
def set_client(self, max_token_length: int = None):
if not self.groq_api_key:
self.client = AsyncGroq(api_key="sk-1111")
log.error("No groq.ai API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "llama3-70b-8192"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncGroq(api_key=self.groq_api_key)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"groq.ai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return response.usage.completion_tokens
def prompt_tokens(self, response: str):
return response.usage.prompt_tokens
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.groq_api_key:
raise Exception("No groq.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
system_message = self.get_system_message(kind)
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": prompt},
]
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
try:
response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
**parameters,
)
response = response.choices[0].message.content
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
emit("status", message="OpenAI API: Permission Denied", status="error")
return ""
except Exception as e:
raise

View File

@@ -1,16 +1,16 @@
import asyncio
import json
import logging
import random
from abc import ABC, abstractmethod
from typing import Callable, Union
import requests
import talemate.client.system_prompts as system_prompts
import talemate.util as util
from talemate.client.registry import register
from talemate.client.textgenwebui import RESTTaleMateClient
from talemate.emit import Emission, emit
# NOT IMPLEMENTED AT THIS POINT

View File

@@ -1,65 +1,64 @@
import pydantic
from openai import AsyncOpenAI
from talemate.client.base import ClientBase
from talemate.client.registry import register
class Defaults(pydantic.BaseModel):
api_url: str = "http://localhost:1234"
max_token_length: int = 8192
@register()
class LMStudioClient(ClientBase):
auto_determine_prompt_template: bool = True
client_type = "lmstudio"
conversation_retries = 5
class Meta(ClientBase.Meta):
name_prefix: str = "LMStudio"
title: str = "LMStudio"
defaults: Defaults = Defaults()
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
model_name = await super().get_model_name()
# model name comes back as a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash
if model_name:
model_name = model_name.replace("\\", "/").split("/")[-1]
return model_name
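
The same normalization restated as a standalone sketch; the file path is made up.

def normalize_model_name(model_name: str) -> str:
    # LMStudio reports the model as a file path; keep only the final
    # component, handling both Windows and POSIX separators
    return model_name.replace("\\", "/").split("/")[-1]

# normalize_model_name("C:\\models\\TheBloke\\mythomax.gguf")
# -> "mythomax.gguf"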
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {"role": "user", "content": prompt.strip()}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)
return response.choices[0].message.content
except Exception as e:
self.log.error("generate error", e=e)
return ""
return ""

View File

@@ -0,0 +1,254 @@
import pydantic
import structlog
from mistralai.async_client import MistralAsyncClient
from mistralai.exceptions import MistralAPIStatusException
from mistralai.models.chat_completion import ChatMessage
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
__all__ = [
"MistralAIClient",
]
log = structlog.get_logger("talemate")
# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
"open-mistral-7b",
"open-mixtral-8x7b",
"open-mixtral-8x22b",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest",
]
JSON_OBJECT_RESPONSE_MODELS = SUPPORTED_MODELS
class Defaults(pydantic.BaseModel):
max_token_length: int = 16384
model: str = "open-mixtral-8x7b"
@register()
class MistralAIClient(ClientBase):
"""
MistralAI client for generating text.
"""
client_type = "mistral"
conversation_retries = 0
auto_break_repetition_enabled = False
# TODO: make this configurable?
decensor_enabled = True
class Meta(ClientBase.Meta):
name_prefix: str = "MistralAI"
title: str = "MistralAI"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="open-mixtral-8x7b", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def mistralai_api_key(self):
return self.config.get("mistralai", {}).get("api_key")
def emit_status(self, processing: bool = None):
error_action = None
if processing is not None:
self.processing = processing
if self.mistralai_api_key:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
status = "error"
model_name = "No API key set"
error_action = ErrorAction(
title="Set API Key",
action_name="openAppConfig",
icon="mdi-key-variant",
arguments=[
"application",
"mistralai_api",
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
},
)
def set_client(self, max_token_length: int = None):
if not self.mistralai_api_key:
self.client = MistralAsyncClient(api_key="sk-1111")
log.error("No mistral.ai API key set")
if self.api_key_status:
self.api_key_status = False
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "open-mixtral-8x7b"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = MistralAsyncClient(api_key=self.mistralai_api_key)
self.max_token_length = max_token_length or 16384
if not self.api_key_status:
if self.api_key_status is False:
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info(
"mistral.ai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
def on_config_saved(self, event):
config = event.data
self.config = config
self.set_client(max_token_length=self.max_token_length)
def response_tokens(self, response: str):
return response.usage.completion_tokens
def prompt_tokens(self, response: str):
return response.usage.prompt_tokens
async def status(self):
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
# clamp temperature between 0.1 and 1.0
# Unhandled Error: Status: 422. Message: {"object":"error","message":{"detail":[{"type":"less_than_equal","loc":["body","temperature"],"msg":"Input should be less than or equal to 1","input":1.31,"ctx":{"le":1.0},"url":"https://errors.pydantic.dev/2.6/v/less_than_equal"}]},"type":"invalid_request_error","param":null,"code":null}
if "temperature" in parameters:
parameters["temperature"] = min(1.0, max(0.1, parameters["temperature"]))
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.mistralai_api_key:
raise Exception("No mistral.ai API key set")
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
right = None
expected_response = None
try:
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
system_message = self.get_system_message(kind)
messages = [
ChatMessage(role="system", content=system_message),
ChatMessage(role="user", content=prompt.strip()),
]
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
try:
response = await self.client.chat(
model=self.model_name,
messages=messages,
**parameters,
)
self._returned_prompt_tokens = self.prompt_tokens(response)
self._returned_response_tokens = self.response_tokens(response)
response = response.choices[0].message.content
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right) :].strip()
return response
except MistralAPIStatusException as e:
self.log.error("generate error", e=e)
if e.http_status in [403, 401]:
emit(
"status",
message="mistral.ai API: Permission Denied",
status="error",
)
return ""
except Exception as e:
raise
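
A small sketch of the temperature clamp in tune_prompt_parameters above: mistral.ai rejects temperature values above 1.0 with the 422 quoted in the comment, so presets such as DIVINE_INTELLECT (1.31) must be pulled into the accepted range before the request is sent. The helper name is illustrative.

def clamp_temperature(parameters: dict) -> dict:
    # clamp to mistral.ai's accepted range [0.1, 1.0]
    if "temperature" in parameters:
        parameters["temperature"] = min(1.0, max(0.1, parameters["temperature"]))
    return parameters

print(clamp_temperature({"temperature": 1.31}))  # {'temperature': 1.0}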

View File

@@ -1,17 +1,24 @@
from jinja2 import Environment, FileSystemLoader
import json
import os
import structlog
import shutil
import huggingface_hub
import tempfile
import huggingface_hub
import structlog
from jinja2 import Environment, FileSystemLoader
__all__ = ["model_prompt"]
BASE_TEMPLATE_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "templates", "llm-prompt"
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
"..",
"templates",
"llm-prompt",
)
# holds the default templates
STD_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "std")
# llm prompt templates provided by talemate
@@ -22,80 +29,114 @@ USER_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "user")
TEMPLATE_IDENTIFIERS = []
def register_template_identifier(cls):
TEMPLATE_IDENTIFIERS.append(cls)
return cls
log = structlog.get_logger("talemate.model_prompts")
class ModelPrompt:
"""
Will attempt to load an LLM prompt template based on the model name.
If the model name is not found, it will default to the 'default' template.
"""
template_map = {}
@property
def env(self):
if not hasattr(self, "_env"):
log.info("modal prompt", base_template_path=BASE_TEMPLATE_PATH)
self._env = Environment(loader=FileSystemLoader([
USER_TEMPLATE_PATH,
TALEMATE_TEMPLATE_PATH,
]))
self._env = Environment(
loader=FileSystemLoader(
[
USER_TEMPLATE_PATH,
TALEMATE_TEMPLATE_PATH,
]
)
)
return self._env
@property
def std_templates(self) -> list[str]:
env = Environment(loader=FileSystemLoader(STD_TEMPLATE_PATH))
return sorted(env.list_templates())
def __call__(self, model_name:str, system_message:str, prompt:str):
def __call__(
self,
model_name: str,
system_message: str,
prompt: str,
double_coercion: str = None,
):
template, template_file = self.get_template(model_name)
if not template:
template_file = "default.jinja2"
template = self.env.get_template(template_file)
if not double_coercion:
double_coercion = ""
if "<|BOT|>" not in prompt and double_coercion:
prompt = f"{prompt}<|BOT|>"
if "<|BOT|>" in prompt:
user_message, coercion_message = prompt.split("<|BOT|>", 1)
coercion_message = f"{double_coercion}{coercion_message}"
else:
user_message = prompt
coercion_message = ""
return template.render({
"system_message": system_message,
"prompt": prompt,
"user_message": user_message,
"coercion_message": coercion_message,
"set_response" : self.set_response
}), template_file
def set_response(self, prompt:str, response_str:str):
return (
template.render(
{
"system_message": system_message,
"prompt": prompt.strip(),
"user_message": user_message.strip(),
"coercion_message": coercion_message,
"set_response": lambda prompt, response_str: self.set_response(
prompt, response_str, double_coercion
),
}
),
template_file,
)
def set_response(self, prompt: str, response_str: str, double_coercion: str = None):
prompt = prompt.strip("\n").strip()
if not double_coercion:
double_coercion = ""
if "<|BOT|>" not in prompt and double_coercion:
prompt = f"{prompt}<|BOT|>"
if "<|BOT|>" in prompt:
response_str = f"{double_coercion}{response_str}"
if "\n<|BOT|>" in prompt:
prompt = prompt.replace("\n<|BOT|>", response_str)
else:
prompt = prompt.replace("<|BOT|>", response_str)
else:
prompt = prompt.rstrip("\n") + response_str
return prompt
def get_template(self, model_name:str):
def get_template(self, model_name: str):
"""
Will attempt to load an LLM prompt template - this supports
partial filename matching on the template file name.
"""
matches = []
# Iterate over all templates in the loader's directory
for template_name in self.env.list_templates():
# strip extension
@@ -103,56 +144,71 @@ class ModelPrompt:
# Check if the model name is in the template filename
if template_name_match.lower() in model_name.lower():
matches.append(template_name)
# If there are no matches, return None
if not matches:
return None, None
# If there is only one match, return it
if len(matches) == 1:
return self.env.get_template(matches[0]), matches[0]
# If there are multiple matches, return the one with the longest name
sorted_matches = sorted(matches, key=lambda x: len(x), reverse=True)
return self.env.get_template(sorted_matches[0]), sorted_matches[0]
def create_user_override(self, template_name:str, model_name:str):
def create_user_override(self, template_name: str, model_name: str):
"""
Will copy STD_TEMPLATE_PATH/template_name to USER_TEMPLATE_PATH/model_name.jinja2
"""
template_name = template_name.split(".jinja2")[0]
shutil.copyfile(
os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2"),
)
return os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
def query_hf_for_prompt_template_suggestion(self, model_name:str):
def query_hf_for_prompt_template_suggestion(self, model_name: str):
print("query_hf_for_prompt_template_suggestion", model_name)
api = huggingface_hub.HfApi()
try:
author, model_name = model_name.split("_", 1)
except ValueError:
return None
models = list(api.list_models(
filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
))
branch_name = "main"
# special popular cases
# bartowski
if author == "bartowski" and "exl2" in model_name:
# split model_name by exl2 and take the first part with "exl2" readded
# the second part is the branch name
model_name, branch_name = model_name.split("exl2_", 1)
model_name = f"{model_name}exl2"
models = list(api.list_models(model_name=model_name, author=author))
if not models:
return None
model = models[0]
repo_id = f"{author}/{model_name}"
# Check README.md
with tempfile.TemporaryDirectory() as tmpdir:
readme_path = huggingface_hub.hf_hub_download(repo_id=repo_id, filename="README.md", cache_dir=tmpdir)
readme_path = huggingface_hub.hf_hub_download(
repo_id=repo_id,
filename="README.md",
cache_dir=tmpdir,
revision=branch_name,
)
if not readme_path:
return None
with open(readme_path) as f:
@@ -162,25 +218,54 @@ class ModelPrompt:
if identifier(readme):
return f"{identifier.template_str}.jinja2"
# Check tokenizer_config.json
# "chat_template" key
with tempfile.TemporaryDirectory() as tmpdir:
config_path = huggingface_hub.hf_hub_download(
repo_id=repo_id,
filename="tokenizer_config.json",
cache_dir=tmpdir,
revision=branch_name,
)
if not config_path:
return None
with open(config_path) as f:
config = json.load(f)
for identifier_cls in TEMPLATE_IDENTIFIERS:
identifier = identifier_cls()
if identifier(config.get("chat_template", "")):
return f"{identifier.template_str}.jinja2"
model_prompt = ModelPrompt()
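
A hedged standalone sketch of the <|BOT|> splice performed by set_response above: the coercion marker in the rendered prompt is replaced with the response prefix, optionally prefixed again by double_coercion. The function name and inputs are illustrative.

def splice_response(prompt: str, response_str: str, double_coercion: str = "") -> str:
    prompt = prompt.strip("\n").strip()
    if "<|BOT|>" not in prompt and double_coercion:
        prompt = f"{prompt}<|BOT|>"
    if "<|BOT|>" in prompt:
        response_str = f"{double_coercion}{response_str}"
        # prefer the newline-prefixed marker so no stray blank line remains
        marker = "\n<|BOT|>" if "\n<|BOT|>" in prompt else "<|BOT|>"
        return prompt.replace(marker, response_str)
    return prompt.rstrip("\n") + response_str

print(splice_response("Describe the scene.\n<|BOT|>", '{"scene":'))  # Describe the scene.{"scene":
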
class TemplateIdentifier:
def __call__(self, content:str):
def __call__(self, content: str):
return False
@register_template_identifier
class Llama2Identifier(TemplateIdentifier):
template_str = "Llama2"
def __call__(self, content:str):
def __call__(self, content: str):
return "[INST]" in content and "[/INST]" in content
@register_template_identifier
class Llama3Identifier(TemplateIdentifier):
template_str = "Llama3"
def __call__(self, content: str):
return "<|start_header_id|>" in content and "<|end_header_id|>" in content
@register_template_identifier
class ChatMLIdentifier(TemplateIdentifier):
template_str = "ChatML"
def __call__(self, content:str):
def __call__(self, content: str):
"""
<|im_start|>system
{{ system_message }}<|im_end|>
@@ -189,28 +274,63 @@ class ChatMLIdentifier(TemplateIdentifier):
<|im_start|>assistant
{{ coercion_message }}
"""
return "<|im_start|>" in content and "<|im_end|>" in content
@register_template_identifier
class CommandRIdentifier(TemplateIdentifier):
template_str = "CommandR"
def __call__(self, content: str):
"""
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ system_message }}
{{ user_message }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|>
<|CHATBOT_TOKEN|>{{ coercion_message }}
"""
return (
"<|im_start|>system" in content
and "<|im_end|>" in content
and "<|im_start|>user" in content
and "<|im_start|>assistant" in content
"<|START_OF_TURN_TOKEN|>" in content
and "<|END_OF_TURN_TOKEN|>" in content
and "<|SYSTEM_TOKEN|>" not in content
)
@register_template_identifier
class CommandRPlusIdentifier(TemplateIdentifier):
template_str = "CommandRPlus"
def __call__(self, content: str):
"""
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ system_message }}
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ user_message }}
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{ coercion_message }}
"""
return (
"<|START_OF_TURN_TOKEN|>" in content
and "<|END_OF_TURN_TOKEN|>" in content
and "<|SYSTEM_TOKEN|>" in content
)
@register_template_identifier
class InstructionInputResponseIdentifier(TemplateIdentifier):
template_str = "InstructionInputResponse"
def __call__(self, content:str):
def __call__(self, content: str):
return (
"### Instruction:" in content
and "### Input:" in content
and "### Response:" in content
)
@register_template_identifier
class AlpacaIdentifier(TemplateIdentifier):
template_str = "Alpaca"
def __call__(self, content:str):
def __call__(self, content: str):
"""
{{ system_message }}
@@ -220,20 +340,19 @@ class AlpacaIdentifier(TemplateIdentifier):
### Response:
{{ coercion_message }}
"""
return (
"### Instruction:" in content
and "### Response:" in content
)
return "### Instruction:" in content and "### Response:" in content
@register_template_identifier
class OpenChatIdentifier(TemplateIdentifier):
template_str = "OpenChat"
def __call__(self, content:str):
def __call__(self, content: str):
"""
GPT4 Correct System: {{ system_message }}<|end_of_turn|>GPT4 Correct User: {{ user_message }}<|end_of_turn|>GPT4 Correct Assistant: {{ coercion_message }}
"""
return (
"<|end_of_turn|>" in content
and "GPT4 Correct System:" in content
@@ -241,54 +360,51 @@ class OpenChatIdentifier(TemplateIdentifier):
and "GPT4 Correct Assistant:" in content
)
@register_template_identifier
class VicunaIdentifier(TemplateIdentifier):
template_str = "Vicuna"
def __call__(self, content:str):
def __call__(self, content: str):
"""
SYSTEM: {{ system_message }}
USER: {{ user_message }}
ASSISTANT: {{ coercion_message }}
"""
return (
"SYSTEM:" in content
and "USER:" in content
and "ASSISTANT:" in content
)
return "SYSTEM:" in content and "USER:" in content and "ASSISTANT:" in content
@register_template_identifier
class USER_ASSISTANTIdentifier(TemplateIdentifier):
template_str = "USER_ASSISTANT"
def __call__(self, content:str):
def __call__(self, content: str):
"""
USER: {{ system_message }} {{ user_message }} ASSISTANT: {{ coercion_message }}
"""
return (
"USER:" in content
and "ASSISTANT:" in content
)
return "USER:" in content and "ASSISTANT:" in content
@register_template_identifier
class UserAssistantIdentifier(TemplateIdentifier):
template_str = "UserAssistant"
def __call__(self, content:str):
def __call__(self, content: str):
"""
User: {{ system_message }} {{ user_message }}
Assistant: {{ coercion_message }}
"""
return (
"User:" in content
and "Assistant:" in content
)
return "User:" in content and "Assistant:" in content
@register_template_identifier
class ZephyrIdentifier(TemplateIdentifier):
template_str = "Zephyr"
def __call__(self, content:str):
def __call__(self, content: str):
"""
<|system|>
{{ system_message }}</s>
@@ -297,9 +413,9 @@ class ZephyrIdentifier(TemplateIdentifier):
<|assistant|>
{{ coercion_message }}
"""
return (
"<|system|>" in content
and "<|user|>" in content
and "<|assistant|>" in content
)
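
A rough sketch of the selection strategy in ModelPrompt.get_template: every template whose file stem appears in the model name is a candidate, and the longest (most specific) filename wins. Template discovery is stubbed out here and the candidate list is hypothetical.

def pick_template(model_name: str, available: list[str]) -> str | None:
    matches = [
        name for name in available
        if name.rsplit(".jinja2", 1)[0].lower() in model_name.lower()
    ]
    # the longest filename is treated as the most specific match
    return max(matches, key=len) if matches else None

templates = ["Llama3.jinja2", "ChatML.jinja2", "CommandRPlus.jinja2"]
print(pick_template("Meta-Llama3-8B-Instruct-exl2", templates))  # Llama3.jinja2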

View File

@@ -1,21 +1,44 @@
import json
import pydantic
import structlog
import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.config import load_config
import structlog
import tiktoken
__all__ = [
"OpenAIClient",
]
log = structlog.get_logger("talemate")
def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo",
]
JSON_OBJECT_RESPONSE_MODELS = [
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview",
"gpt-3.5-turbo-0125",
]
def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
@@ -66,9 +89,11 @@ def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
class Defaults(pydantic.BaseModel):
max_token_length:int = 16384
model:str = "gpt-4-turbo-preview"
max_token_length: int = 16384
model: str = "gpt-4-turbo"
@register()
class OpenAIClient(ClientBase):
@@ -79,35 +104,28 @@ class OpenAIClient(ClientBase):
client_type = "openai"
conversation_retries = 0
auto_break_repetition_enabled = False
class Meta(ClientBase.Meta):
name_prefix:str = "OpenAI"
title:str = "OpenAI"
manual_model:bool = True
manual_model_choices:list[str] = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
"gpt-4-turbo-preview",
]
requires_prompt_template: bool = False
defaults:Defaults = Defaults()
# TODO: make this configurable?
decensor_enabled = False
def __init__(self, model="gpt-4-turbo-preview", **kwargs):
class Meta(ClientBase.Meta):
name_prefix: str = "OpenAI"
title: str = "OpenAI"
manual_model: bool = True
manual_model_choices: list[str] = SUPPORTED_MODELS
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
def __init__(self, model="gpt-4-turbo", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def openai_api_key(self):
return self.config.get("openai",{}).get("api_key")
return self.config.get("openai", {}).get("api_key")
def emit_status(self, processing: bool = None):
error_action = None
@@ -127,13 +145,13 @@ class OpenAIClient(ClientBase):
arguments=[
"application",
"openai_api",
]
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
emit(
@@ -145,25 +163,27 @@ class OpenAIClient(ClientBase):
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
}
},
)
def set_client(self, max_token_length:int=None):
def set_client(self, max_token_length: int = None):
if not self.openai_api_key:
self.client = AsyncOpenAI(api_key="sk-1111")
log.error("No OpenAI API key set")
if self.api_key_status:
self.api_key_status = False
emit('request_client_status')
emit('request_agent_status')
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "gpt-3.5-turbo-16k"
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
model = self.model_name
self.client = AsyncOpenAI(api_key=self.openai_api_key)
if model == "gpt-3.5-turbo":
self.max_token_length = min(max_token_length or 4096, 4096)
@@ -175,16 +195,20 @@ class OpenAIClient(ClientBase):
self.max_token_length = min(max_token_length or 128000, 128000)
else:
self.max_token_length = max_token_length or 2048
if not self.api_key_status:
if self.api_key_status is False:
emit('request_client_status')
emit('request_agent_status')
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info("openai set client", max_token_length=self.max_token_length, provided_max_token_length=max_token_length, model=model)
log.info(
"openai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
@@ -203,69 +227,98 @@ class OpenAIClient(ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message:str, prompt:str):
def prompt_template(self, system_message: str, prompt: str):
# only gpt-4-1106-preview supports json_object response coercion
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters:dict, kind:str):
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
# GPT-3.5 models tend to run away with the generated
# response size so we allow talemate to set the max_tokens
#
# GPT-4 on the other hand seems to benefit from letting it
# decide the generation length naturally and it will generally
# produce reasonably sized responses
if self.model_name.startswith("gpt-3.5-"):
valid_keys.append("max_tokens")
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.openai_api_key:
raise Exception("No OpenAI API key set")
# only gpt-4-* supports enforcing json object
supports_json_object = self.model_name.startswith("gpt-4-")
supports_json_object = (
self.model_name.startswith("gpt-4-")
or self.model_name in JSON_OBJECT_RESPONSE_MODELS
)
right = None
expected_response = None
try:
_, right = prompt.split("\nContinue this response: ")
_, right = prompt.split("\nStart your response with: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
human_message = {'role': 'user', 'content': prompt.strip()}
system_message = {'role': 'system', 'content': self.get_system_message(kind)}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
human_message = {"role": "user", "content": prompt.strip()}
system_message = {"role": "system", "content": self.get_system_message(kind)}
self.log.debug(
"generate",
prompt=prompt[:128] + " ...",
parameters=parameters,
system_message=system_message,
)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[system_message, human_message], **parameters
model=self.model_name,
messages=[system_message, human_message],
**parameters,
)
response = response.choices[0].message.content
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
# so we strip that out if the expected response is a json object
if (
not supports_json_object
and expected_response
and expected_response.startswith("{")
):
if response.startswith("```json") and response.endswith("```"):
response = response[7:-3].strip()
if right and response.startswith(right):
response = response[len(right):].strip()
response = response[len(right) :].strip()
return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
emit("status", message="OpenAI API: Permission Denied", status="error")
return ""
except Exception as e:
raise
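
A condensed sketch of the accounting in num_tokens_from_messages above, following OpenAI's published recipe: each chat message carries a small fixed token overhead on top of its encoded content, and the assistant reply is primed with three more tokens.

import tiktoken

def count_message_tokens(messages: list[dict], model: str = "gpt-3.5-turbo-0613") -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # unknown model name: fall back to the cl100k_base encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for message in messages:
        num_tokens += 3  # per-message overhead for recent chat models
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    return num_tokens + 3  # every reply is primed with <|start|>assistant<|message|>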

View File

@@ -1,93 +1,131 @@
import pydantic
from talemate.client.base import ClientBase
from talemate.client.registry import register
import urllib
from openai import AsyncOpenAI, PermissionDeniedError, NotFoundError
import pydantic
import structlog
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
from talemate.emit import emit
log = structlog.get_logger("talemate.client.openai_compat")
EXPERIMENTAL_DESCRIPTION = """Use this client if you want to connect to a service implementing an OpenAI-compatible API. Success is going to depend on the level of compatibility. Use the actual OpenAI client if you want to connect to OpenAI's API."""
class Defaults(pydantic.BaseModel):
api_url:str = "http://localhost:5000"
api_key:str = ""
max_token_length:int = 4096
model:str = ""
api_url: str = "http://localhost:5000"
api_key: str = ""
max_token_length: int = 8192
model: str = ""
api_handles_prompt_template: bool = False
class ClientConfig(BaseClientConfig):
api_handles_prompt_template: bool = False
@register()
class OpenAICompatibleClient(ClientBase):
client_type = "openai_compat"
conversation_retries = 5
conversation_retries = 0
config_cls = ClientConfig
class Meta(ClientBase.Meta):
title:str = "OpenAI Compatible API"
name_prefix:str = "OpenAI Compatible API"
experimental:str = EXPERIMENTAL_DESCRIPTION
enable_api_auth:bool = True
manual_model:bool = True
defaults:Defaults = Defaults()
def __init__(self, model=None, **kwargs):
title: str = "OpenAI Compatible API"
name_prefix: str = "OpenAI Compatible API"
experimental: str = EXPERIMENTAL_DESCRIPTION
enable_api_auth: bool = True
manual_model: bool = True
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = {
"api_handles_prompt_template": ExtraField(
name="api_handles_prompt_template",
type="bool",
label="API Handles Prompt Template",
required=False,
description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored.",
)
}
def __init__(
self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
):
self.model_name = model
self.api_key = api_key
self.api_handles_prompt_template = api_handles_prompt_template
super().__init__(**kwargs)
@property
def experimental(self):
return EXPERIMENTAL_DESCRIPTION
@property
def can_be_coerced(self):
"""
Determines whether or not this client can pass LLM coercion (e.g., is able
to predefine partial LLM output in the prompt).
"""
return not self.api_handles_prompt_template
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key")
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key=self.api_key)
self.model_name = kwargs.get("model") or kwargs.get("model_name") or self.model_name
def tune_prompt_parameters(self, parameters:dict, kind:str):
self.api_key = kwargs.get("api_key", self.api_key)
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
url = self.api_url
self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
self.model_name = (
kwargs.get("model") or kwargs.get("model_name") or self.model_name
)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
valid_keys = ["temperature", "top_p", "max_tokens"]
for key in keys:
if key not in valid_keys:
del parameters[key]
def prompt_template(self, system_message: str, prompt: str):
log.debug(
"IS API HANDLING PROMPT TEMPLATE",
api_handles_prompt_template=self.api_handles_prompt_template,
)
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
async def get_model_name(self):
try:
model_name = await super().get_model_name()
except NotFoundError as e:
# api does not implement model listing
return self.model_name
except Exception as e:
self.log.error("get_model_name error", e=e)
return self.model_name
# model name may be a file path, so we need to extract the model name
# the path could be Windows or Linux, so it needs to handle both backslash and forward slash
is_filepath = "/" in model_name
is_filepath_windows = "\\" in model_name
if is_filepath or is_filepath_windows:
model_name = model_name.replace("\\", "/").split("/")[-1]
return model_name
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {'role': 'user', 'content': prompt.strip()}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
human_message = {"role": "user", "content": prompt.strip()}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)
return response.choices[0].message.content
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -95,7 +133,9 @@ class OpenAICompatibleClient(ClientBase):
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""
def reconfigure(self, **kwargs):
@@ -104,8 +144,14 @@ class OpenAICompatibleClient(ClientBase):
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
self.max_token_length = (
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
)
if "api_key" in kwargs:
self.api_auth = kwargs["api_key"]
self.set_client(**kwargs)
self.api_key = kwargs["api_key"]
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
log.warning("reconfigure", kwargs=kwargs)
self.set_client(**kwargs)
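
For reference, a hedged example of pointing the OpenAI SDK at a self-hosted endpoint the way this client does. The URL, key, and model name are placeholders, and whether "/v1" belongs on the base URL depends on the serving backend.

from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://localhost:5000/v1", api_key="sk-1111")

async def complete(prompt: str) -> str:
    response = await client.chat.completions.create(
        model="local-model",  # placeholder; some backends ignore this field
        messages=[{"role": "user", "content": prompt.strip()}],
        temperature=0.7,
    )
    return response.choices[0].message.content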

View File

@@ -28,18 +28,25 @@ PRESET_TALEMATE_CREATOR = {
}
PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'top_k': 40,
'repetition_penalty': 1.18,
"temperature": 0.7,
"top_p": 0.1,
"top_k": 40,
"repetition_penalty": 1.18,
}
PRESET_DETERMINISTIC = {
"temperature": 0.1,
"top_p": 1,
"top_k": 0,
"repetition_penalty": 1.0,
}
PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
'top_k': 49,
"temperature": 1.31,
"top_p": 0.14,
"top_k": 49,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
"repetition_penalty": 1.17,
}
PRESET_SIMPLE_1 = {
@@ -49,7 +56,14 @@ PRESET_SIMPLE_1 = {
"repetition_penalty": 1.15,
}
def configure(config:dict, kind:str, total_budget:int):
PRESET_ANALYTICAL = {
"temperature": 0.1,
"top_p": 0.9,
"top_k": 20,
}
def configure(config: dict, kind: str, total_budget: int):
"""
Sets the config based on the kind of text to generate.
"""
@@ -57,21 +71,34 @@ def configure(config:dict, kind:str, total_budget:int):
set_max_tokens(config, kind, total_budget)
return config
def set_max_tokens(config:dict, kind:str, total_budget:int):
def set_max_tokens(config: dict, kind: str, total_budget: int):
"""
Sets the max_tokens in the config based on the kind of text to generate.
"""
config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return config
def set_preset(config:dict, kind:str):
def set_preset(config: dict, kind: str):
"""
Sets the preset in the config based on the kind of text to generate.
"""
config.update(preset_for_kind(kind))
def preset_for_kind(kind: str):
if kind == "conversation":
# tag based
if "deterministic" in kind:
return PRESET_DETERMINISTIC
elif "creative" in kind:
return PRESET_DIVINE_INTELLECT
elif "simple" in kind:
return PRESET_SIMPLE_1
elif "analytical" in kind:
return PRESET_ANALYTICAL
elif kind == "conversation":
return PRESET_TALEMATE_CONVERSATION
elif kind == "conversation_old":
return PRESET_TALEMATE_CONVERSATION # Assuming old conversation uses the same preset
@@ -104,62 +131,99 @@ def preset_for_kind(kind: str):
elif kind == "director":
return PRESET_SIMPLE_1
elif kind == "director_short":
return PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
return (
PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
)
elif kind == "director_yesno":
return PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
return (
PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
)
elif kind == "edit_dialogue":
return PRESET_DIVINE_INTELLECT
elif kind == "edit_add_detail":
return PRESET_DIVINE_INTELLECT # Assuming adding detail uses the same preset as divine intellect
elif kind == "edit_fix_exposition":
return PRESET_DIVINE_INTELLECT # Assuming fixing exposition uses the same preset as divine intellect
return PRESET_DETERMINISTIC  # fixing exposition now uses the deterministic preset
elif kind == "edit_fix_continuity":
return PRESET_DETERMINISTIC
elif kind == "visualize":
return PRESET_SIMPLE_1
else:
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
def max_tokens_for_kind(kind: str, total_budget: int):
if kind == "conversation":
return 75 # Example value, adjust as needed
return 75
elif kind == "conversation_old":
return 75 # Example value, adjust as needed
return 75
elif kind == "conversation_long":
return 300 # Example value, adjust as needed
return 300
elif kind == "conversation_select_talking_actor":
return 30 # Example value, adjust as needed
return 30
elif kind == "summarize":
return 500 # Example value, adjust as needed
return 500
elif kind == "analyze":
return 500 # Example value, adjust as needed
return 500
elif kind == "analyze_creative":
return 1024 # Example value, adjust as needed
return 1024
elif kind == "analyze_long":
return 2048 # Example value, adjust as needed
return 2048
elif kind == "analyze_freeform":
return 500 # Example value, adjust as needed
return 500
elif kind == "analyze_freeform_medium":
return 192
elif kind == "analyze_freeform_medium_short":
return 128
elif kind == "analyze_freeform_short":
return 10 # Example value, adjust as needed
return 10
elif kind == "narrate":
return 500 # Example value, adjust as needed
return 500
elif kind == "story":
return 300 # Example value, adjust as needed
return 300
elif kind == "create":
return min(1024, int(total_budget * 0.35)) # Example calculation, adjust as needed
return min(1024, int(total_budget * 0.35))
elif kind == "create_concise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
return min(400, int(total_budget * 0.25))
elif kind == "create_precise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
return min(400, int(total_budget * 0.25))
elif kind == "create_short":
return 25
elif kind == "director":
return min(192, int(total_budget * 0.25)) # Example calculation, adjust as needed
return min(192, int(total_budget * 0.25))
elif kind == "director_short":
return 25 # Example value, adjust as needed
return 25
elif kind == "director_yesno":
return 2 # Example value, adjust as needed
return 2
elif kind == "edit_dialogue":
return 100 # Example value, adjust as needed
return 100
elif kind == "edit_add_detail":
return 200 # Example value, adjust as needed
return 200
elif kind == "edit_fix_exposition":
return 1024 # Example value, adjust as needed
return 1024
elif kind == "edit_fix_continuity":
return 512
elif kind == "visualize":
return 150
# tag based
elif "extensive" in kind:
return 2048
elif "long" in kind:
return 1024
elif "medium2" in kind:
return 512
elif "medium" in kind:
return 192
elif "short2" in kind:
return 128
elif "short" in kind:
return 75
elif "tiny2" in kind:
return 25
elif "tiny" in kind:
return 10
elif "yesno" in kind:
return 2
else:
return 150  # Default value if none of the kinds match
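
A small sketch of the tag-based fallbacks added above: substring tags in the prompt kind resolve to a token budget (and, analogously, a preset) when no exact kind matches. The tag order mirrors the diff, so more specific tags are checked first.

def budget_for_kind(kind: str) -> int:
    for tag, budget in [
        ("extensive", 2048), ("long", 1024), ("medium2", 512),
        ("medium", 192), ("short2", 128), ("short", 75),
        ("tiny2", 25), ("tiny", 10), ("yesno", 2),
    ]:
        if tag in kind:
            return budget
    return 150  # default when no tag matches

print(budget_for_kind("analyze_freeform_medium"))  # 192
print(budget_for_kind("narrate_extensive"))        # 2048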

Some files were not shown because too many files have changed in this diff.